In [3]:
import os
from copy import deepcopy

## Imports for plotting
import matplotlib.pyplot as plt
plt.set_cmap('cividis')
# %matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.set()

## tqdm for loading bars
# from tqdm.notebook import tqdm
from tqdm import tqdm

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim

## Torchvision
import torchvision
from torchvision.datasets import STL10
from torchvision import transforms

# PyTorch Lightning
# try:
#     import pytorch_lightning as pl
# except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
#     !pip install --quiet pytorch-lightning>=1.4
#     import pytorch_lightning as pl
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

from PIL import Image
import glob

pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

from torchmetrics.classification import MultilabelAUROC
from torchmetrics.classification import MultilabelF1Score
from torchmetrics.classification import MultilabelPrecision
from torchmetrics.classification import MultilabelRecall
from torchmetrics.classification import MultilabelAccuracy
from torchmetrics.classification import MultilabelSpecificity
from torchmetrics.classification import MultilabelConfusionMatrix
from torchmetrics.classification import MulticlassAUROC
from torchmetrics.classification import MulticlassF1Score
from torchmetrics.classification import MulticlassPrecision
from torchmetrics.classification import MulticlassRecall
from torchmetrics.classification import MulticlassAccuracy

  set_matplotlib_formats('svg', 'pdf') # For export
Global seed set to 42


Device: cuda:0


<Figure size 640x480 with 0 Axes>

In [4]:
class SimCLR(pl.LightningModule):

    def __init__(self, hidden_dim, lr, temperature, weight_decay, max_epochs=1000):
        super().__init__()
        self.save_hyperparameters()
        assert self.hparams.temperature > 0.0, 'The temperature must be a positive float!'
        # Base model f(.)
        self.convnet = torchvision.models.resnet34(num_classes=4*hidden_dim)  # Output of last linear layer
        # The MLP for g(.) consists of Linear->ReLU->Linear
        self.convnet.fc = nn.Sequential(
            self.convnet.fc,  # Linear(ResNet output, 4*hidden_dim)
            nn.ReLU(inplace=True),
            nn.Linear(4*hidden_dim, hidden_dim)
        )


    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(),
                                lr=self.hparams.lr,
                                weight_decay=self.hparams.weight_decay)
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                            T_max=self.hparams.max_epochs,
                                                            eta_min=self.hparams.lr/50)
        return [optimizer], [lr_scheduler]

    def info_nce_loss(self, batch, mode='train'):
        imgs, _ = batch
        imgs = torch.cat(imgs, dim=0)

        # Encode all images
        feats = self.convnet(imgs)
        # Calculate cosine similarity
        cos_sim = F.cosine_similarity(feats[:,None,:], feats[None,:,:], dim=-1)
        # Mask out cosine similarity to itself
        self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool, device=cos_sim.device)
        cos_sim.masked_fill_(self_mask, -9e15)
        # Find positive example -> batch_size//2 away from the original example
        pos_mask = self_mask.roll(shifts=cos_sim.shape[0]//2, dims=0)
        # InfoNCE loss
        cos_sim = cos_sim / self.hparams.temperature
        nll = -cos_sim[pos_mask] + torch.logsumexp(cos_sim, dim=-1)
        nll = nll.mean()

        # Logging loss
        self.log(mode+'_loss', nll)
        # Get ranking position of positive example
        comb_sim = torch.cat([cos_sim[pos_mask][:,None],  # First position positive example
                              cos_sim.masked_fill(pos_mask, -9e15)],
                             dim=-1)
        sim_argsort = comb_sim.argsort(dim=-1, descending=True).argmin(dim=-1)
        # Logging ranking metrics
        self.log(mode+'_acc_top1', (sim_argsort == 0).float().mean(), prog_bar=True)
        self.log(mode+'_acc_top5', (sim_argsort < 5).float().mean(), prog_bar=True)
        self.log(mode+'_acc_mean_pos', 1+sim_argsort.float().mean(), prog_bar=True)

        return nll

    def training_step(self, batch, batch_idx):
        return self.info_nce_loss(batch, mode='train')

    def validation_step(self, batch, batch_idx):
        self.info_nce_loss(batch, mode='val')

In [5]:
class LogisticRegression(pl.LightningModule):

    def __init__(self, feature_dim, num_classes, lr, weight_decay, max_epochs=100):
        super().__init__()
        self.save_hyperparameters()
        # Mapping from representation h to classes
        # self.network = network
        ####################################################
        simclr = simclr_model
        self.network = deepcopy(simclr.convnet)
        self.network.fc = nn.Identity()  # Removing projection head g(.)
        self.network.eval()
        self.network.to(device)
        self.model = nn.Linear(512, 6)
        #######################################################
        # self.network = torchvision.models.resnet34(num_classes=6, pretrained=False)
        #####################barlow#############################
        # simclr = encoder
        # # self.network = deepcopy(simclr.convnet)
        # self.network = deepcopy(simclr)
        # # self.network.fc = nn.Identity()  # Removing projection head g(.)
        # self.network.eval()
        # self.network.to(device)
        # self.model = nn.Linear(512, 6)


    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(),
                                lr=self.hparams.lr,
                                weight_decay=self.hparams.weight_decay)
        lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                      milestones=[int(self.hparams.max_epochs*0.6),
                                                                  int(self.hparams.max_epochs*0.8)],
                                                      gamma=0.1)
        return [optimizer], [lr_scheduler]

    def _calculate_loss(self, batch, mode='train'):
        # feats, labels = batch
        # preds = self.model(feats)
        # loss = F.cross_entropy(preds, labels)
        # acc = (preds.argmax(dim=-1) == labels).float().mean()
        #
        # self.log(mode + '_loss', loss)
        # self.log(mode + '_acc', acc)
        imgs, labels = batch

        # preds = self.network(imgs)

        features = self.network(imgs)
        preds = self.model(features)
        # preds = torch.sigmoid(preds)

        # preds = preds.squeeze()
        # criterion = nn.BCEWithLogitsLoss()
        criterion = nn.CrossEntropyLoss()
        labels = labels.squeeze()
        # loss = F.cross_entropy(preds, labels)
        loss = criterion(preds, labels)

        # acc = (preds.argmax(dim=-1) == labels).float().mean()
        acc = MultilabelAccuracy(num_labels=6).to(device)
        auc = MultilabelAUROC(num_labels=6, average="macro", thresholds=None).to(device)
        f1 = MultilabelF1Score(num_labels=6, average="macro").to(device)
        precision = MultilabelPrecision(num_labels=6, average="macro").to(device)
        recall = MultilabelRecall(num_labels=6, average="macro", threshold=0.3).to(device)

        # acc = MulticlassAccuracy(num_classes=5).to(device)
        # auc = MulticlassAUROC(num_classes=5, average="macro", thresholds=None).to(device)
        # f1 = MulticlassF1Score(num_classes=5, average="macro").to(device)
        # precision = MulticlassPrecision(num_classes=5, average="macro").to(device)
        # recall = MulticlassRecall(num_classes=5, average="macro", threshold=0.3).to(device)
        # specificity = MultilabelSpecificity(num_labels=7).to(device)
        # confusion_matrix = MultiClassConfusionMatrix(num_labels=6).to(device)

        accuracy = acc(preds, labels)
        auc_score = auc(preds, labels)
        f1_score = f1(preds, labels)
        precision_score = precision(preds, labels)
        recall_score = recall(preds, labels)
        # specificity_score = specificity(preds, labels)
        # confusion_matrix_score = confusion_matrix(preds, labels)
        # tn = confusion_matrix_score[:, 0, 0]
        # tp = confusion_matrix_score[:, 1, 1]
        # fn = confusion_matrix_score[:, 1, 0]
        # fp = confusion_matrix_score[:, 0, 1]
        # specificity_score_matrix = tn / (tn + fp)
        # sensitivity_score_matrix = tp / (tp + fn)

        self.log(mode + '_acc', accuracy, prog_bar=True)
        self.log(mode + '_auc', auc_score, prog_bar=True)
        self.log(mode + '_f1', f1_score, prog_bar=True)
        self.log(mode + '_precision', precision_score, prog_bar=True)
        self.log(mode + '_recall', recall_score, prog_bar=True)
        # self.log(mode + '_specificity', specificity_score, prog_bar=True)
        # self.log(mode + '_specificity_matrix', specificity_score_matrix.mean(), prog_bar=True)
        # self.log(mode + '_sensitivity_matrix', sensitivity_score_matrix.mean(), prog_bar=True)

        self.log(mode + '_loss', loss, prog_bar=True)


        # if mode == 'val':
        #     print("accuracy ", accuracy)
        #     print("auc score ", auc_score)
        #     print("f1", f1_score)
        #     print("precision", precision_score)
            # print("recall", recall_score)
            # print("specificity_matrix", specificity_score_matrix.mean())
            # print("sensitivity_matrix", sensitivity_score_matrix.mean())

        # if mode=='test':
        #     # auc_class = MultilabelAUROC(num_labels=6, average=None, thresholds=None).to(device)
        #     # auc_score_class = auc_class(preds, labels)
        #     # print("AUC score for each class: ", auc_score_class)
        #
        #     acc_class =  MultilabelAccuracy(num_labels=6, average=None).to(device)
        #     acc_score_class = acc_class(preds, labels)
        #     print("Accuracy for each class: ", acc_score_class)
        #
        #
        #     f1_class = MultilabelF1Score(num_labels=6, average=None).to(device)
        #     f1_score_class = f1_class(preds, labels)
        #     print("F1 score for each class: ", f1_score_class)
        #
        #     precision_class = MultilabelPrecision(num_labels=6, average=None).to(device)
        #     pre_score_class = precision_class(preds, labels)
        #     print("Precison for each class: ", pre_score_class)
        #
        #     recall_class = MultilabelRecall(num_labels=6, average=None, threshold=0.2).to(device)
        #     recall_score_class = recall_class(preds, labels)
        #     print("AUC score for each class: ", recall_score_class)



        return loss

    def training_step(self, batch, batch_idx):
        return self._calculate_loss(batch, mode='train')

    def validation_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode='val')

    def test_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode='test')

In [6]:
import os
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
        self.label_arr = np.asarray(self.img_labels.iloc[:, 1:])

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path+'.jpeg')
        # label = self.img_labels.iloc[idx, 1]
        label = self.label_arr[idx]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        # print(img_path,label)
        return image, label


In [9]:
from torchvision.datasets.imagenet import ImageFolder
train_transforms = transforms.Compose([transforms.ToPILImage(),
                                       transforms.RandomHorizontalFlip(),
                                       transforms.RandomVerticalFlip(),
                                       transforms.RandomRotation(180),
                                       transforms.Resize((512,512)),
                                       transforms.RandomGrayscale(p=0.2),
                                       transforms.GaussianBlur(kernel_size=9, sigma=(0.1, 0.5)),
                                       transforms.ToTensor(),
                                       transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                            std=[0.229, 0.224, 0.225]),
                                       # ToFeature(simclr_model)
                                       # transforms.Normalize((0.5,), (0.5,))
                                       ])
dataset = CustomImageDataset(annotations_file="C:/Users/User/Fundus Dataset/UFI_multidisease/labels_nohead_no_normal.csv",
                             img_dir="C:/Users/User/Fundus Dataset/UFI_multidisease/all_no_normal" ,transform=train_transforms)

train_img_aug_data, test_img_aug_data = torch.utils.data.random_split(dataset, [2885, 400])

In [10]:
simclr_model = SimCLR.load_from_checkpoint("models/epoch=1002-step=116348.ckpt")
model = LogisticRegression.load_from_checkpoint("checkpoints/dt_simclr_m4/lightning_logs/version_0/checkpoints/epoch=89-step=4140-v2.ckpt")

Lightning automatically upgraded your loaded checkpoint from v1.8.4.post0 to v2.0.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file c:\Users\User\PycharmProjects\ssl\models\epoch=1002-step=116348.ckpt`
