In [None]:
!pip install pytorch_lightning

In [None]:
import torch
import torchvision
import pytorch_lightning as pl
import torch.nn.functional as F
from torch import nn
import os
import pandas as pd
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sn
import io
from PIL import Image

def calculate_confusion_matrix(y_true, y_pred, num_classes):
    """Build a confusion matrix over the class indices ``0 .. num_classes - 1``.

    Thin wrapper around ``sklearn.metrics.confusion_matrix`` that always passes
    the label set explicitly instead of letting it be inferred from the data.

    Args:
        y_true: Ground-truth class indices.
        y_pred: Predicted class indices.
        num_classes: Number of classes; labels are ``range(num_classes)``.

    Returns:
        Whatever ``confusion_matrix`` returns for the given label set.
    """
    label_ids = range(num_classes)
    return confusion_matrix(y_true, y_pred, labels=label_ids)


import numpy as np
import math

def inv_normalize(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo channel-wise normalisation and return a channels-last NumPy image.

    Args:
        tensor: Image tensor laid out as (channels, height, width). It may live
            on any device and may be tracking gradients.
        mean: Per-channel mean that was subtracted during normalisation.
            Defaults to the ImageNet statistics used by ``DataModule``.
        std: Per-channel standard deviation that was divided by during
            normalisation. Defaults to the ImageNet statistics used by
            ``DataModule``.

    Returns:
        ``numpy.ndarray`` of shape (height, width, channels) holding
        ``tensor * std + mean`` per channel.
    """
    # detach()/cpu() make the conversion safe for CUDA tensors and for tensors
    # that require grad; a bare .numpy() raises for both.
    chw = tensor.detach().cpu().numpy()
    hwc = np.transpose(chw, axes=[1, 2, 0])
    return (hwc * np.asarray(std)) + np.asarray(mean)

def list_to_image_grid(images, titles, cols=3):
    """Show normalised image tensors as a grid of subplots, one per image.

    Args:
        images: Sequence of image tensors; each one is passed through
            ``inv_normalize`` before being drawn.
        titles: Sequence of subplot titles paired with ``images`` by position,
            or ``None`` to leave every subplot untitled.
        cols: Number of grid columns; the row count is derived from it.
    """
    captions = [''] * len(images) if titles is None else titles
    n_rows = math.ceil(len(images) / cols)

    plt.figure(figsize=(16, 14))
    for position, (image, caption) in enumerate(zip(images, captions), start=1):
        axis = plt.subplot(n_rows, cols, position)
        axis.imshow(inv_normalize(image))
        axis.set_title(caption)
        axis.axis('off')
    plt.show()

In [None]:
import torch 
# Let float32 matmuls use the faster, slightly lower-precision kernels.
torch.set_float32_matmul_precision('high')
# Show which GPU will be used; guarded so the cell does not raise on a
# machine without CUDA.
torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CUDA not available'

In [None]:
class _IntHandler:
    """Matplotlib legend handler that draws an integer handle as plain text."""

    def legend_artist(self, legend, orig_handle, fontsize, handlebox):
        # Local import: only matplotlib.pyplot is imported at file level.
        from matplotlib.text import Text

        text = Text(handlebox.xdescent, handlebox.ydescent, str(orig_handle))
        handlebox.add_artist(text)
        return text


class ResNet(pl.LightningModule):
    """ResNet-18 backbone with two classification heads (angle and level).

    The backbone's final fully connected layer is replaced by a 512-feature
    projection; each head is a single linear layer on top of that projection.

    Args:
        num_angle_classes: Number of output classes of the angle head.
        num_level_classes: Number of output classes of the level head.
    """

    def __init__(self, num_angle_classes=5, num_level_classes=6):
        super().__init__()
        # NOTE(review): `pretrained=True` is the legacy torchvision spelling for
        # requesting ImageNet weights; kept as-is for compatibility.
        self.resnet = torchvision.models.resnet18(pretrained=True)
        self.resnet.fc = torch.nn.Linear(self.resnet.fc.in_features, 512)
        self.angle_head = torch.nn.Linear(512, num_angle_classes)
        self.level_head = torch.nn.Linear(512, num_level_classes)

    def forward(self, x):
        """Return ``(angle_logits, level_logits)`` from one backbone pass."""
        x = self.resnet(x)
        angle_logits = self.angle_head(x)
        level_logits = self.level_head(x)
        return angle_logits, level_logits

    def forward_angle(self, x):
        """Return only the angle-head logits for ``x``."""
        x = self.resnet(x)
        angle_logits = self.angle_head(x)
        return angle_logits

    def forward_level(self, x):
        """Return only the level-head logits for ``x``."""
        x = self.resnet(x)
        level_logits = self.level_head(x)
        return level_logits

    def _log_confusion_matrix(self, tag, class_to_idx, target, preds, num_classes):
        """Render the confusion matrix of one batch and log it as an image.

        Args:
            tag: Image tag passed to the logger's ``add_image``.
            class_to_idx: Mapping of class name to class index; its values label
                the matrix rows/columns and its keys label the legend.
            target: Ground-truth class indices of the batch.
            preds: Predicted class indices of the batch.
            num_classes: Number of classes spanned by the matrix.
        """
        experiment = getattr(self.logger, "experiment", None)
        if not hasattr(experiment, "add_image"):
            # No logger, or one without image support: metrics are still
            # logged by the caller, only the figure is skipped.
            return

        class_ids = list(class_to_idx.values())
        class_names = list(class_to_idx.keys())
        cm = calculate_confusion_matrix(
            target.cpu().numpy(), preds.cpu().numpy(), num_classes=num_classes
        )
        df_cm = pd.DataFrame(cm, index=class_ids, columns=class_ids)

        fig, ax = plt.subplots(figsize=(10, 5))
        try:
            # Leave room on the right for the index -> class-name legend.
            fig.subplots_adjust(left=0.05, right=.65)
            sn.set(font_scale=1.2)
            sn.heatmap(df_cm, annot=True, annot_kws={"size": 16}, fmt='d', ax=ax)
            ax.legend(
                class_ids,
                class_names,
                handler_map={int: _IntHandler()},
                loc='upper left',
                bbox_to_anchor=(1.2, 1)
            )
            buf = io.BytesIO()
            fig.savefig(buf, format='jpeg', bbox_inches='tight')
        finally:
            # One figure is created per batch; close it so figures do not
            # accumulate for the lifetime of the kernel.
            plt.close(fig)

        buf.seek(0)
        im = torchvision.transforms.ToTensor()(Image.open(buf))
        experiment.add_image(tag, im, global_step=self.current_epoch)

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Evaluate one batch with the head that matches the dataloader.

        Dataloader index 0 is scored with the angle head, any other index with
        the level head. Logs loss and accuracy, then a confusion-matrix image.

        NOTE(review): the confusion matrix covers the current batch only, and
        every batch is logged under the same ``global_step``.
        """
        x, target = batch

        if dataloader_idx == 0:
            task = "angle"
            output = self.forward_angle(x)
            num_classes = self.angle_head.out_features
        else:
            task = "level"
            output = self.forward_level(x)
            num_classes = self.level_head.out_features

        loss = F.cross_entropy(output, target)

        preds = torch.argmax(output, dim=1)
        acc = torch.sum(preds == target).item() / len(target)

        self.log(f"test_loss_{task}", loss, prog_bar=True)
        self.log(f"test_{task}_acc", acc, prog_bar=True)

        # Class names come from the notebook-level `data_module`, which must
        # already have been set up when testing starts.
        if task == "angle":
            class_to_idx = data_module.angle_data.class_to_idx
        else:
            class_to_idx = data_module.level_data.class_to_idx

        self._log_confusion_matrix(
            f"test_confusion_matrix_{task}", class_to_idx, target, preds, num_classes
        )

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Predict both angle and level classes for every image in the batch.

        Returns:
            Dict with the input images (``img``), predicted class indices
            (``angle``, ``level``), the batch targets (``target``) and the
            originating dataloader index (``dl``).
        """
        x, target = batch

        # A single backbone pass feeds both heads.
        output_angle, output_level = self(x)

        preds_angle = torch.argmax(output_angle, dim=1)
        preds_level = torch.argmax(output_level, dim=1)

        return {'img': x, 'angle': preds_angle, 'level': preds_level, 'target': target, 'dl': dataloader_idx}

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with default settings."""
        optimizer = torch.optim.Adam(self.parameters())
        return optimizer

In [None]:
# Load the trained model from a checkpoint.
# `load_from_checkpoint` is a classmethod that builds and returns a new model,
# so it is called on the class: constructing a throwaway instance first builds
# the network twice, and recent Lightning releases reject the instance call.
# The keyword arguments are forwarded to `ResNet.__init__`.
model = ResNet.load_from_checkpoint(
    'model.ckpt',
    num_angle_classes=5,
    num_level_classes=6,
)  #95 94  (NOTE(review): presumably the accuracies reached by this checkpoint; confirm)

In [None]:
from torchvision import transforms
from torch.utils.data import random_split

class DataModule(pl.LightningDataModule):
    """Data module serving the ``angle`` and ``level`` image-folder datasets.

    Each dataset is read from ``<data_dir>/<task>`` and split into
    train / validation / test subsets by the given fractions.

    Args:
        data_dir: Directory containing the ``angle`` and ``level`` sub-folders.
        batch_size: Base batch size; evaluation loaders use ten times this.
        val_size: Fraction of each dataset assigned to the validation subset.
        train_size: Fraction of each dataset assigned to the training subset.
        seed: Seed for the random split, so repeated ``setup()`` calls and
            notebook re-runs yield the same subsets. Pass ``None`` to split
            with the global RNG instead.
            NOTE(review): the test subset is only truly held out if training
            used the same split; confirm against the training code.
    """

    def __init__(self, data_dir, batch_size, val_size=0.1, train_size=0.7, seed=42):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.val_size = val_size
        self.train_size = train_size
        self.seed = seed
        self.imnet_mean = [0.485, 0.456, 0.406]
        self.imnet_std = [0.229, 0.224, 0.225]

    def _build_transforms(self):
        """Return the resize / centre-crop / normalise pipeline used for both tasks."""
        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.imnet_mean, std=self.imnet_std)
        ])

    def _load_and_split(self, task):
        """Load ``<data_dir>/<task>`` and return ``(dataset, train, val, test)``."""
        dataset = torchvision.datasets.ImageFolder(
            os.path.join(self.data_dir, task), transform=self._build_transforms()
        )
        train_size = int(len(dataset) * self.train_size)
        val_size = int(len(dataset) * self.val_size)
        # The test subset takes whatever is left after rounding the other two down.
        test_size = len(dataset) - train_size - val_size

        split_kwargs = {}
        if self.seed is not None:
            # A fresh generator per call keeps every split independent of how
            # much randomness was consumed beforehand.
            split_kwargs["generator"] = torch.Generator().manual_seed(self.seed)
        train, val, test = random_split(
            dataset, [train_size, val_size, test_size], **split_kwargs
        )
        return dataset, train, val, test

    def setup(self, stage=None):
        """Load both datasets and create their train / val / test subsets.

        ``stage`` is accepted for the Lightning hook signature but not used.
        """
        (
            self.angle_data,
            self.angle_data_train,
            self.angle_data_val,
            self.angle_data_test,
        ) = self._load_and_split('angle')
        (
            self.level_data,
            self.level_data_train,
            self.level_data_val,
            self.level_data_test,
        ) = self._load_and_split('level')

    def _eval_loaders(self):
        """Return the test-subset loaders keyed by task name."""
        eval_batch_size = self.batch_size * 10
        angle_loader = torch.utils.data.DataLoader(self.angle_data_test, batch_size=eval_batch_size)
        level_loader = torch.utils.data.DataLoader(self.level_data_test, batch_size=eval_batch_size)
        return {'angle': angle_loader, 'level': level_loader}

    def test_dataloader(self):
        """Return ``{'angle': loader, 'level': loader}`` over the test subsets."""
        return self._eval_loaders()

    def predict_dataloader(self):
        """Return ``{'angle': loader, 'level': loader}`` over the test subsets."""
        return self._eval_loaders()

# Notebook-level data module. `ResNet.test_step` reads this global by name to
# obtain the class-name mappings, so it has to exist before testing starts.
data_module = DataModule(data_dir='./dataset_level_angle/', batch_size=64)
# Called by hand so the datasets and their splits are available right away.
data_module.setup()

In [None]:
# Run the test loop of the loaded model on the data module's test loaders.
# A GPU is required by `accelerator='gpu'`.
# NOTE(review): `max_epochs` presumably has no effect on `trainer.test`; confirm.
trainer = pl.Trainer(accelerator='gpu', max_epochs=1)
trainer.test(model, data_module) 

In [None]:
from PIL import Image

# Same normalisation statistics and preprocessing steps that
# `DataModule.setup` applies to the datasets.
imnet_mean = [0.485, 0.456, 0.406]
imnet_std = [0.229, 0.224, 0.225]
t = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=imnet_mean, std=imnet_std)
        ])


# One sample image taken from the `level/eye` class folder.
img = Image.open('./dataset_level_angle/level/eye/f51419ee4bcf2bd63b70dd26b779e5db6b46819d9997c41921faea0cfa131fb2.jpg')

# Preprocess the image and show a single channel of the normalised result:
# index 2 along the first (channel) axis, i.e. blue if the file is RGB.
timg = t(img)
plt.imshow(timg.numpy()[2, ...])

In [None]:
# Display the preprocessed image tensor built in the previous cell.
timg