from google.colab import drive
drive.mount('/content/drive')

!unzip -qq /content/drive/MyDrive/RUPESH_RESEARCH_IMPLEMENTATIONS/DATASETS/NPY_DATA_patient_wise_split_128_x_128.zip

In [None]:
# import shutil
# shutil.rmtree("/content/NPY_DATA_patient_wise_split", ignore_errors = True)

# Install the required libraries

In [None]:
!pip install wandb --quiet
!pip install torchsummary --quiet
!pip install torchsampler --quiet
!pip install torchmetrics --quiet
!pip install grad-cam --quiet
!pip install torchfunc==0.1.1 --quiet
#!pip install timm --quiet
#!pip install vit-pytorch --quiet
!pip install -q timm pytorch-metric-learning

In [None]:
# import os
# from glob import glob

# os.makedirs("/content/PPMI_SPECT", exist_ok= True)

# zip_list = glob(f"/content/drive/MyDrive/RUPESH_RESEARCH_IMPLEMENTATIONS/DATASETS/PPMI_SPECT_40_42/*")
# for fzip in zip_list:
#     !unzip $fzip -d /content/PPMI_SPECT

# Importing necessary libraries

In [None]:
import os
import torch
import torchvision
import torchfunc
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets.utils import download_url
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as tt
from torchvision.ops import sigmoid_focal_loss
from torch.utils.data import random_split
from torchvision.utils import make_grid
from torchsummary import summary
from pytorch_metric_learning import losses
from torch.cuda import amp

import random

import timm

import copy

import tarfile

import numpy as np

import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

from sklearn.metrics import f1_score

from tqdm import tqdm
import wandb

from glob import glob

import cv2 
import time


matplotlib.rcParams['figure.facecolor'] = '#ffffff'

In [None]:
timm.list_models("vit*")

In [None]:
# model = timm.create_model("deit_tiny_patch16_224", num_classes = 2, img_size = 128, pretrained = True)
# #print(model)
# device = get_default_device()
# model =to_device(model, device)
# x     = to_device(torch.randn(2, 3, 128, 128), device)
# output = model(x)
# print(output.shape)
# summary(model, input_size = (3,128, 128))

In [None]:
print(torch.cuda.is_available())

In [None]:
torch.cuda.get_device_name()

In [None]:
def set_seed(seed = 42):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    switches cuDNN to deterministic mode and fixes ``PYTHONHASHSEED`` so that
    repeated runs produce the same results.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one. This call is safe
    # (silently ignored) when CUDA is not available.
    torch.cuda.manual_seed_all(seed)

    # cuDNN must use deterministic kernels AND must not auto-tune: with
    # benchmark=True it may pick a different algorithm on each run, which
    # breaks reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Set a fixed value for the hash seed
    os.environ['PYTHONHASHSEED'] = str(seed)
    



In [None]:
import wandb

# SECURITY: never embed the Weights & Biases API key in the notebook.
# wandb.login() reads the key from the WANDB_API_KEY environment variable or
# from ~/.netrc, and otherwise prompts for it interactively.
# NOTE(review): the key that was previously committed on this line should be
# revoked and regenerated in the W&B account settings.
wandb.login()

# Run from here

In [None]:


######################  CONFIGURATION  #########################
# Central hyperparameter dictionary for the run. The same dict is handed to
# wandb.init() so the settings are recorded with each experiment.
config_dict= {"num_epochs": 300,
              
              # for LEARNING RATE
              "lr": 1e-5,                           # initial learning rate given to the optimizer
              "min_lr": 1e-6,                       # eta_min of the cosine schedule
              "opt_func": "AdamW",                  # AdamW / Adam / RMSprop
              # NOTE(review): "scheduler" is informational only in the visible code;
              # fit_one_fold always constructs CosineAnnealingLR.
              "scheduler": 'CosineAnnealingLR',
              "T_max": 10,                          # T_max of CosineAnnealingLR
              "weight_decay": 1e-6,
              
              "grad_clip": 0.1,                     # passed to nn.utils.clip_grad_value_
              
              # for MODEL HYPERPARAMETERS
              "batch_size": 16,
              "model": "resnet",                    # alexnet / vgg16 / vgg19 / resnet / effnetv2 / DeiT / ViT
              "num_classes":2, 
              "pretrained": False,
              "random_seed": 42,
              "target_names": ["HC", "SWEDD"],      # class names used for the confusion matrix and report
              # NOTE(review): "metric_task" is not read anywhere in the visible code.
              "metric_task": "binary",       
              
              # LOSS FUNCTION
              "Loss_fn": "focal",                   # CE/ focal/ SCL
                  # for FOCAL loss
                  "fl_alpha": 0.25,                 # Must be in range [0, 1]
                  "fl_gamma": 2,
                  # for SCL
                  "temperature": 0.1
                }
###############################################################
# Weights & Biases project name and run name (run name encodes the model and
# whether pretrained weights were requested).
project_name = 'May_30_FL_PPMI_SPECT_128x128_2_class_HC_SWEDD_imbalance'
project_run_name = config_dict["model"] + "_pretrained_" + str(config_dict["pretrained"])
###############################################################

# Seed all RNGs before any model or data loader is created.
set_seed(seed = config_dict["random_seed"])
# torch.manual_seed(config_dict["random_seed"])
# if torch.cuda.is_available():
#     torch.cuda.manual_seed_all(config_dict["random_seed"])
# torch.backends.cudnn.deterministic = True
# np.random.seed(config_dict["random_seed"])

def seed_worker(worker_id):
    """Seed NumPy and ``random`` inside a DataLoader worker process.

    Intended to be passed as ``worker_init_fn`` to ``DataLoader``. Each worker
    derives its seed from ``torch.initial_seed()``, so the workers get
    different but reproducible random streams. Seeding every worker with the
    same global seed would make all workers produce identical random
    augmentations.

    Args:
        worker_id: Index of the worker (supplied by the DataLoader; unused).
    """
    # NumPy only accepts 32-bit seeds, hence the modulo.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(config_dict["random_seed"]) 


In [None]:
str(config_dict["pretrained"])

In [None]:
# wandb.finish()

In [None]:
# Start the Weights & Biases run. The tags are derived from the configuration
# so they cannot drift from the optimizer / pretraining setting actually used.
wandb.init(config = config_dict,
           name = project_run_name,
           project=project_name,
           notes=config_dict["model"],
           tags=[config_dict["opt_func"],
                 "pretrained" if config_dict["pretrained"] else "from_scratch"])

# Loss Function

In [None]:
#WEIGHTED FOCAL LOSS
class WeightedFocalLoss(nn.Module):
    """Alpha-weighted sigmoid focal loss on one-hot encoded class targets.

    Thin wrapper around ``torchvision.ops.sigmoid_focal_loss`` that accepts
    integer class labels, one-hot encodes them and returns the mean loss.

    Args:
        alpha: Weighting factor passed to ``sigmoid_focal_loss``
            (must be in [0, 1]).
        gamma: Focusing exponent passed to ``sigmoid_focal_loss``.
    """
    def __init__(self, alpha=config_dict["fl_alpha"], gamma=config_dict["fl_gamma"]):
        super().__init__()
        # Plain floats: no device placement is needed, so the loss can be
        # constructed on CPU-only machines as well.
        self.alpha = float(alpha)
        self.gamma = float(gamma)

    def forward(self, inputs, targets):
        """Return the mean focal loss.

        Args:
            inputs: Raw logits; the last dimension indexes the classes.
            targets: Integer class indices, one per sample.
        """
        # The number of classes is taken from the logits so the one-hot
        # targets always match the model output width.
        targets_one_hot = F.one_hot(targets, num_classes = inputs.shape[-1]).to(inputs.dtype)
        # Use the values given to the constructor rather than the global config.
        loss = sigmoid_focal_loss(inputs, targets_one_hot, alpha = self.alpha, gamma = self.gamma, reduction = "mean")
        return loss
    
# SUPERVISED CONTRASTIVE LOSS    
class SupervisedContrastiveLoss(nn.Module):
    """Label-aware contrastive loss on L2-normalised embeddings.

    Wraps ``pytorch_metric_learning.losses.NTXentLoss``. That loss expects
    embeddings of shape (N, D) plus N labels and computes the pairwise
    similarities and temperature scaling itself, so the embeddings are passed
    directly. Feeding it a precomputed, temperature-scaled similarity matrix
    would treat the matrix rows as embeddings and apply the temperature twice.

    Args:
        temperature: Softmax temperature of the NT-Xent loss.
    """
    def __init__(self, temperature=0.1):
        super().__init__()
        self.temperature = temperature
        # Build the underlying loss once instead of on every forward pass.
        self.criterion = losses.NTXentLoss(temperature=temperature)

    def forward(self, feature_vectors, labels):
        """Return the contrastive loss for a batch.

        Args:
            feature_vectors: Embeddings; normalised along dim 1.
            labels: Class label per embedding (flattened to 1-D).
        """
        # Normalize feature vectors
        feature_vectors_normalized = F.normalize(feature_vectors, p=2, dim=1)
        # reshape(-1) keeps a 1-D label vector even for a batch of size one,
        # where squeeze() would produce a 0-dim tensor.
        return self.criterion(feature_vectors_normalized, labels.reshape(-1))

In [None]:
# input = torch.randn(3, requires_grad=True)
# target = torch.empty(3).random_(2)
# floss = WeightedFocalLoss()
# loss = floss(input, target)
# print(input, target)
# loss.backward()
# print(input)
# print(input.shape, target.shape)


# DATASET reading from the disk

In [None]:
class CustomDataset(Dataset):
    """Dataset of ``.npy`` images stored as ``<imgs_path>/<class_name>/*.npy``.

    Each image is min-max normalised on its own to [0, 1], converted to a
    float32 tensor and optionally passed through ``tfms``.

    Args:
        imgs_path: Root directory containing one sub-directory per class.
        tfms: Callable applied to the image tensor, or ``None`` for no transform.
    """
    def __init__(self, imgs_path, tfms):
        self.imgs_path = imgs_path
        # Only directories are classes; stray files in the root (e.g. hidden
        # files) must not shift the class indices.
        class_list = sorted(
            entry for entry in os.listdir(imgs_path)
            if os.path.isdir(os.path.join(imgs_path, entry))
        )
        print(class_list)
        # Sort the file names so the sample order is reproducible across
        # file systems.
        self.data = [
            [img_path, class_name]
            for class_name in class_list
            for img_path in sorted(glob(os.path.join(imgs_path, class_name, "*.npy")))
        ]
        # Class index = position of the class name in alphabetical order.
        self.class_map = {class_name: idx for idx, class_name in enumerate(class_list)}
        print(self.class_map)
        self.__class_to_idx__()
        self.img_dim = (128, 128)
        self.tfms = tfms

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __class_to_idx__(self):
        """Print the class-name to class-index mapping."""
        for key, val in self.class_map.items():
            print(f"{key}: {val}")

    def __getitem__(self, idx):
        """Return ``(image_tensor, class_id_tensor)`` for sample ``idx``."""
        img_path, class_name = self.data[idx]
        img = np.load(img_path)
        lo, hi = img.min(), img.max()
        if hi > lo:
            img = (img - lo) / (hi - lo)
        else:
            # Constant image: avoid 0/0 (NaN) and return an all-zero image.
            img = np.zeros_like(img, dtype=np.float32)
        img_tensor = torch.from_numpy(img).type(torch.float32)
        if self.tfms is not None:
            img_tensor = self.tfms(img_tensor)
        class_id = torch.tensor(self.class_map[class_name])
        return img_tensor, class_id

class CustomDataset_patient_wise(Dataset):
    """Dataset of ``.npy`` images stored as ``<imgs_path>/<class>/<subdir>/*.npy``.

    On construction every file is loaded once to find the minimum and maximum
    intensity over the whole split; ``__getitem__`` then normalises each image
    with those split-wide statistics instead of per-image statistics.

    Args:
        imgs_path: Root directory containing one sub-directory per class.
        tfms: Callable applied to the image tensor, or ``None`` for no transform.

    Raises:
        FileNotFoundError: If no ``.npy`` file is found under ``imgs_path``.
    """
    def __init__(self, imgs_path, tfms):
        self.imgs_path = imgs_path
        class_list = sorted(os.listdir(imgs_path))
        print(class_list)
        self.data = []
        for class_name in class_list:
            # glob is called without recursive=True, so "**" behaves like "*"
            # and matches exactly one directory level below the class folder.
            pattern = os.path.join(imgs_path, class_name, "**/*.npy")
            print(pattern)
            for img_path in glob(pattern):
                self.data.append([img_path, class_name])
        if not self.data:
            # Fail with a clear message instead of a NameError further down.
            raise FileNotFoundError(f"No .npy files found under {imgs_path!r}")

        # Start from +/- infinity so the result is correct for any value
        # range (fixed sentinels silently fail for data beyond them).
        self.OVERALL_MAX, self.OVERALL_MIN = -np.inf, np.inf
        for img_path, _ in tqdm(self.data, desc = "Processing the data for finding local MIN and MAX ....."):
            x = np.load(img_path)
            self.OVERALL_MAX = max(x.max(), self.OVERALL_MAX)
            self.OVERALL_MIN = min(x.min(), self.OVERALL_MIN)
        print(f"DATASET_MAX: {self.OVERALL_MAX}, DATASET_MIN: {self.OVERALL_MIN}")

        # Class index = position of the class name in alphabetical order.
        self.class_map = {class_name: idx for idx, class_name in enumerate(class_list)}
        print(self.class_map)
        self.__class_to_idx__()
        # Shape of the last image that was scanned.
        self.img_dim = x.shape
        print(self.img_dim)
        self.tfms = tfms

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __class_to_idx__(self):
        """Print the class-name to class-index mapping."""
        for key, val in self.class_map.items():
            print(f"{key}: {val}")

    def __getitem__(self, idx):
        """Return ``(image_tensor, class_id_tensor)`` for sample ``idx``."""
        img_path, class_name = self.data[idx]
        img = np.load(img_path)
        value_range = self.OVERALL_MAX - self.OVERALL_MIN
        if value_range > 0:
            img = (img - self.OVERALL_MIN) / value_range
        else:
            # Degenerate split (all voxels equal): avoid division by zero.
            img = np.zeros_like(img, dtype=np.float32)
        img_tensor = torch.from_numpy(img).type(torch.float32)
        if self.tfms is not None:
            img_tensor = self.tfms(img_tensor)
        class_id = torch.tensor(self.class_map[class_name])
        return img_tensor, class_id

In [None]:
# train_ds.__class_to_idx__()

# Transformations and Augmentations

In [None]:
# Per-channel (mean, std) used by Normalize and by denormalize().
stats = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))


def _to_three_channels(x):
    """Tile the tensor three times along its first dimension."""
    return x.repeat(int(3), 1, 1)


def _make_pipeline(*augmentations):
    """Compose: channel tiling -> given augmentations -> normalisation."""
    steps = [tt.Lambda(_to_three_channels)]
    steps.extend(augmentations)
    steps.append(tt.Normalize(*stats, inplace=True))
    return tt.Compose(steps)


# Baseline pipeline without augmentation.
tfms_org = _make_pipeline()

# Random rotation of up to 10 degrees.
tfms_1 = _make_pipeline(tt.RandomRotation(10))

# Random horizontal flip with probability 0.5.
tfms_2 = _make_pipeline(tt.RandomHorizontalFlip(0.5))

# Gaussian blur with a 3x3 kernel and sigma drawn from [0.1, 2.0].
tfms_3 = _make_pipeline(tt.GaussianBlur(3, sigma = (0.1, 2.0)))

## Reading the train, validation and test datasets

In [None]:
# PyTorch datasets
# Builds the train / valid / test datasets from the "train", "valid" and
# "test" sub-folders of data_dir. Each CustomDataset_patient_wise instance
# scans its own files for the min/max used in normalisation, so the three
# splits are normalised with their own statistics.

# NOTE(review): the string below says the transforms are ignored, but tfms_org
# is passed to every dataset and CustomDataset_patient_wise applies a
# non-None tfms in __getitem__ -- confirm which is intended.
"""NOTE in this version the tfms are ignored"""
# NOTE(review): this is a Kaggle input path, while the first cell mounts
# Google Drive and unzips the data under Colab -- confirm the environment.
data_dir = "/kaggle/input/npy-train-test-split-may-30/npy_train_test_split"

train_ds_org = CustomDataset_patient_wise(data_dir+'/train', tfms_org)
# train_ds_1 = CustomDataset_patient_wise(data_dir+'/train', tfms_1)
# train_ds_2 = CustomDataset_patient_wise(data_dir+'/train', tfms_2)
# train_ds_3 = CustomDataset_patient_wise(data_dir+'/train', tfms_3)

# Only the un-augmented training set is used; the augmented copies are disabled.
train_ds = train_ds_org #+ train_ds_1 + train_ds_2 + train_ds_3           #Augmentation is ONLY for the Training dataset, test dataset must have original DATA
valid_ds = CustomDataset_patient_wise(data_dir+'/valid', tfms = tfms_org)
test_ds  = CustomDataset_patient_wise(data_dir+'/test', tfms = tfms_org)

In [None]:
train_ds.__len__()

# DataLoaders

In [None]:
batch_size = config_dict["batch_size"]
# PyTorch data loaders
# Only the training loader shuffles. Validation and test data are read in a
# fixed order so that evaluation is identical from run to run.
train_dl = DataLoader(train_ds, batch_size, num_workers=2, pin_memory=True, shuffle = True, worker_init_fn=seed_worker,
    generator=g)
valid_dl = DataLoader(valid_ds, batch_size, num_workers=2, pin_memory=True, shuffle = False, worker_init_fn=seed_worker,
    generator=g)
test_dl = DataLoader(test_ds, batch_size, num_workers =2, pin_memory = True, shuffle = False, worker_init_fn=seed_worker,
    generator=g)

In [None]:
# temp = train_dl.__iter__()
# print(temp.size)

In [None]:
train_dl.__len__()

# Denormalize images and view a batch of images

In [None]:
def denormalize(images, means, stds):
    """Undo a per-channel ``Normalize``: return ``images * stds + means``.

    Args:
        images: Batch of images shaped (N, C, H, W).
        means: Sequence of C per-channel means used for normalisation.
        stds: Sequence of C per-channel standard deviations.

    Returns:
        Tensor with the same shape as ``images``.
    """
    # reshape(1, -1, 1, 1) broadcasts over any channel count, and the stats
    # are created on the images' device so GPU batches work too.
    means = torch.tensor(means, device=images.device).reshape(1, -1, 1, 1)
    stds = torch.tensor(stds, device=images.device).reshape(1, -1, 1, 1)
    return images * stds + means

def show_batch(dl):
    """Display up to 16 images of the first batch of ``dl`` in a 4-column grid.

    Each image is denormalised and then min-max rescaled on its own to
    [0, 1], the range ``imshow`` expects for float RGB data. Rescaling to
    0-255 and then clamping to [0, 1] would saturate almost every pixel to
    white.

    Args:
        dl: Iterable yielding ``(images, labels)`` batches.
    """
    batch = next(iter(dl), None)
    if batch is None:
        # Empty loader: nothing to display.
        return
    images, labels = batch
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.set_xticks([]); ax.set_yticks([])
    # Move to CPU so batches coming from a device-wrapping loader can be plotted.
    denorm_images = denormalize(images, *stats).detach().cpu()
    print(denorm_images.max().item())
    # Per-image contrast stretch; clamp_min guards against constant images.
    lo = denorm_images.amin(dim=(1, 2, 3), keepdim=True)
    hi = denorm_images.amax(dim=(1, 2, 3), keepdim=True)
    denorm_images = (denorm_images - lo) / (hi - lo).clamp_min(1e-8)
    ax.imshow(make_grid(denorm_images[:16], nrow=4).permute(1, 2, 0).clamp(0,1))

In [None]:
show_batch(train_dl)

# Utility functions
 - choosing from available devices
 - Loading the data/ model into the DEFAULT device

In [None]:
def get_default_device():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)
    
def to_device(data, device):
    """Move a tensor, or a (nested) list/tuple of tensors, to ``device``.

    Lists and tuples are processed element by element and always come back
    as lists; anything else is moved with a non-blocking ``.to(device)``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    moved = []
    for item in data:
        moved.append(to_device(item, device))
    return moved

class DeviceDataLoader():
    """Data loader wrapper that places every batch on a fixed device."""
    def __init__(self, dl, device):
        self.dl, self.device = dl, device

    def __iter__(self):
        """Iterate the wrapped loader, yielding each batch after moving it to the device."""
        target = self.device
        for batch in self.dl:
            moved = to_device(batch, target)
            yield moved

    def __len__(self):
        """Return how many batches the wrapped loader provides."""
        n_batches = len(self.dl)
        return n_batches

In [None]:
device = get_default_device()
device

In [None]:
train_dl = DeviceDataLoader(train_dl, device)
valid_dl = DeviceDataLoader(valid_dl, device)
test_dl = DeviceDataLoader(test_dl, device)

# Training Codes

In [None]:
from sklearn.metrics import f1_score
import torchmetrics
from torchmetrics import AUROC

def accuracy(outputs, labels):
    """Return the fraction of rows whose highest-scoring column equals the label.

    The result is a 0-dim tensor.
    """
    preds = torch.max(outputs, dim=1).indices
    n_correct = (preds == labels).sum().item()
    return torch.tensor(n_correct / len(preds))

def training_step(model, batch, criterion):
    """Run the model on one batch and return the loss from ``criterion``.

    ``batch`` is an ``(images, labels)`` pair; the criterion is called as
    ``criterion(model(images), labels)``.
    """
    inputs, targets = batch
    predictions = model(inputs)
    return criterion(predictions, targets)

@torch.no_grad()
def validation_step(model, batch, criterion, fold):
    """Evaluate one batch without tracking gradients.

    Returns a dict, keyed with the ``fold`` suffix, holding the detached loss,
    the batch accuracy, the raw model outputs and the labels.
    """
    inputs, targets = batch
    logits = model(inputs)
    batch_loss = criterion(logits, targets)
    batch_acc = accuracy(logits, targets)
    return {
        f'val_loss_{fold}': batch_loss.detach(),
        f'val_acc_{fold}': batch_acc,
        f"preds_{fold}": logits,
        f"labels_{fold}": targets,
    }

@torch.no_grad()
def validation_epoch_end(outputs, fold):
    """Aggregate the per-batch results of ``validation_step`` into epoch metrics.

    Loss and accuracy are weighted by batch size. A plain mean of per-batch
    means over-weights a smaller final batch.

    Args:
        outputs: List of dicts produced by ``validation_step``.
        fold: Fold identifier used as the suffix of the dict keys.

    Returns:
        Dict with ``val_loss_{fold}`` and ``val_acc_{fold}`` (floats) and
        ``val_f1_{fold}`` and ``val_auroc_{fold}`` (torchmetrics results).
    """
    y_true = torch.cat([x[f"labels_{fold}"].detach() for x in outputs])
    y_pred = torch.cat([x[f"preds_{fold}"].detach() for x in outputs])

    batch_losses = torch.stack([x[f'val_loss_{fold}'] for x in outputs]).float()
    batch_sizes = torch.tensor([len(x[f"labels_{fold}"]) for x in outputs],
                               dtype=batch_losses.dtype, device=batch_losses.device)
    # Sample-weighted mean of the batch losses.
    epoch_loss = (batch_losses * batch_sizes).sum() / batch_sizes.sum()
    # Accuracy over all samples at once (equivalent to size-weighted batch accuracy).
    epoch_acc = (y_pred.argmax(dim=1) == y_true).float().mean()

    f1 = torchmetrics.functional.f1_score(y_pred, y_true, num_classes=config_dict["num_classes"],task = 'multiclass', average = "micro")
    auroc = AUROC(num_classes=config_dict["num_classes"], average="macro", task = 'multiclass')
    auroc = auroc(y_pred, y_true)
    return {f'val_loss_{fold}': epoch_loss.item(), f'val_acc_{fold}': epoch_acc.item(), f"val_f1_{fold}": f1, f"val_auroc_{fold}": auroc}

def epoch_end(epoch, result, fold):
    """Print the train and validation metrics of one epoch.

    Args:
        epoch: Epoch index shown in the header.
        result: Dict with the ``train_*_{fold}`` and ``val_*_{fold}`` entries
            for accuracy, loss, F1 and AUROC.
        fold: Fold identifier used as the suffix of the dict keys.
    """
    # All metrics use four decimals. train_auroc previously used "{:4f}"
    # (minimum width 4), which printed six decimals.
    print("Epoch [{}], \ntrain_acc: {:.4f}, train_loss: {:.4f}, train_f1: {:.4f}, train_auroc: {:.4f}, \nval_acc: {:.4f}, val_loss: {:.4f}, val_f1: {:.4f}, val_auroc: {:.4f}".format(
        epoch, result[f'train_acc_{fold}'], result[f'train_loss_{fold}'], result[f'train_f1_{fold}'], result[f"train_auroc_{fold}"], result[f'val_acc_{fold}'], result[f'val_loss_{fold}'], result[f'val_f1_{fold}'], result[f"val_auroc_{fold}"]))

    


# Model Definitions

In [None]:
import torch.nn as nn
from torchvision import models


"""AlexNet model"""
def build_model(name, num_classes, pretrained):
    """Create the network selected by ``name``.

    Args:
        name: One of "alexnet", "vgg16", "vgg19" (torchvision) or "resnet",
            "effnetv2", "DeiT", "ViT" (timm).
        num_classes: Number of output units of the classification head.
        pretrained: Whether to load pretrained weights.

    Returns:
        The constructed ``nn.Module``.

    Raises:
        ValueError: If ``name`` is not a supported model.
    """
    torchvision_builders = {
        "alexnet": models.alexnet,
        "vgg16": models.vgg16,
        "vgg19": models.vgg19,
    }
    # timm architecture name plus extra keyword arguments; the transformer
    # variants are created for 128x128 inputs.
    timm_specs = {
        "resnet": ("resnet18", {}),
        "effnetv2": ("efficientnetv2_rw_t", {}),
        "DeiT": ("deit_tiny_patch16_224", {"img_size": 128}),
        "ViT": ("vit_tiny_patch16_224", {"img_size": 128}),
    }

    if name in torchvision_builders:
        builder = torchvision_builders[name]
        if pretrained:
            # The pretrained checkpoint has its own output width, so load it
            # first and then swap the last classifier layer for a new head.
            net = builder(pretrained=True)
            old_head = net.classifier[6]
            net.classifier[6] = nn.Linear(old_head.in_features, num_classes)
        else:
            net = builder(pretrained=False, num_classes=num_classes)
        return net

    if name in timm_specs:
        arch, extra_kwargs = timm_specs[name]
        return timm.create_model(arch, pretrained=pretrained, num_classes=num_classes, **extra_kwargs)

    supported = sorted(list(torchvision_builders) + list(timm_specs))
    raise ValueError(f"Unknown model {name!r}; expected one of {supported}")


model = build_model(config_dict["model"], config_dict["num_classes"], config_dict["pretrained"])
print(model)

# """Vision Transformer"""
# from torchvision.models import resnet50

# from vit_pytorch.distill import DistillableViT, DistillWrapper
# from vit_pytorch import ViT, SimpleViT
# from vit_pytorch.cct import CCT

# if(config_dict['model'] == "ViT"):
#     # if(config_dict["pretrained"]==True):
#     #     model = 
        
#     # else:
#     model = ViT(
#                 image_size = 128,
#                 patch_size = 8,
#                 num_classes = 2,
#                 dim = 1024,
#                 depth = 6,
#                 heads = 8,
#                 mlp_dim = 2048,
#                 dropout = 0.1,
#                 emb_dropout = 0.1
#                 )
#     print(model)     
# print(model.fc)
# for param in model.parameters():
#     param.requires_grad = False
    # Replace the last fully-connected layer
    # Parameters of newly constructed modules have requires_grad=True by default
#model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

#print(model)



# model.training_step = training_step
# model.validation_step = validation_step
# model.validation_epoch_end = validation_epoch_end
# model.epoch_end = epoch_end

#print(model)


    
# Move the selected model to the chosen device (GPU when available).
model = to_device(model, device)

# Register the model with the active Weights & Biases run.
wandb.watch(model)


# Print a layer-by-layer summary for a 3-channel 128x128 input.
summary(model, input_size = (3, 128, 128))


In [None]:
# model.EfficientNet.conv_head

In [None]:
@torch.no_grad()
def evaluate(model, val_loader, fold, criterion = None):
    """Evaluate ``model`` on every batch of ``val_loader``.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        val_loader: Iterable of ``(images, labels)`` batches.
        fold: Fold identifier used as the suffix of the result keys.
        criterion: Loss function. When omitted, the module-level ``criterion``
            is used, which keeps existing three-argument calls working.

    Returns:
        The metrics dict produced by ``validation_epoch_end``.
    """
    if criterion is None:
        # Backward-compatible fallback to the notebook-level loss.
        criterion = globals()["criterion"]
    model.eval()
    outputs = [validation_step(model, batch, criterion, fold) for batch in val_loader]
    return validation_epoch_end(outputs, fold)

def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group."""
    for param_group in optimizer.param_groups:
        return param_group['lr']

def fit_one_fold(epochs, lr, model, train_loader, val_loader, test_dl, grad_clip = config_dict["grad_clip"],
                 weight_decay = config_dict["weight_decay"], opt_func=torch.optim.SGD, criterion = F.cross_entropy,
                fold = 0):
    """Train ``model`` for one fold and return it with its best weights.

    After every epoch the model is evaluated on the validation and training
    loaders and the metrics are logged to Weights & Biases. The weights with
    the lowest validation loss are kept; they are restored before the final
    test evaluation, so the reported test metrics belong to the returned model.

    Args:
        epochs: Number of training epochs.
        lr: Initial learning rate.
        model: Network to train (already on the target device).
        train_loader: Training batches.
        val_loader: Validation batches.
        test_dl: Test batches, evaluated once after training.
        grad_clip: Gradient value clipping threshold; falsy disables clipping.
        weight_decay: Weight decay passed to the optimizer.
        opt_func: Optimizer class.
        criterion: Loss function used for both training and evaluation.
        fold: Fold identifier used as the suffix of the metric keys.

    Returns:
        ``(model, history)`` where ``history`` has one metrics dict per epoch.
    """
    start = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = np.inf
    best_acc = -np.inf
    history = []
    # Honour the weight_decay argument (its default equals the config value).
    optimizer = opt_func(model.parameters(), lr = lr, weight_decay = weight_decay)
    # Referenced through torch.optim so this function does not depend on an
    # import executed in a later cell.
    # NOTE(review): the scheduler is stepped once per batch, so T_max counts
    # batches, not epochs -- confirm this is the intended schedule.
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config_dict["T_max"], eta_min=config_dict["min_lr"])

    for epoch in range(epochs):
        # Training Phase
        model.train()
        train_losses = []
        for batch in tqdm(train_loader):
            loss = training_step(model, batch, criterion)
            # Store a detached copy; keeping the attached loss would hold the
            # autograd graph of every batch alive for the whole epoch.
            train_losses.append(loss.detach())
            loss.backward()

            ## GRADIENT CLIPPING
            if grad_clip:
                nn.utils.clip_grad_value_(model.parameters(), grad_clip)

            optimizer.step()
            optimizer.zero_grad()
            sched.step()

        # Validation phase (evaluate() already disables gradient tracking).
        result = evaluate(model, val_loader, fold, criterion)
        if result[f'val_loss_{fold}'] <= best_loss:
            best_loss = result[f'val_loss_{fold}']
            best_model_wts = copy.deepcopy(model.state_dict())
        # Best validation accuracy seen so far (reported after training).
        best_acc = max(best_acc, result[f'val_acc_{fold}'])

        # Training metrics are measured in eval mode on the training loader.
        result_2 = evaluate(model, train_loader, fold, criterion)
        result[f'train_acc_{fold}'] = result_2[f'val_acc_{fold}']
        result[f'train_loss_{fold}'] = torch.stack(train_losses).mean().item()
        result[f'train_f1_{fold}'] = result_2[f"val_f1_{fold}"]
        result[f"train_auroc_{fold}"] = result_2[f"val_auroc_{fold}"]
        result[f"epoch_{fold}"] = epoch
        wandb.log({'Epoch': epoch,
                   "fold": fold,
                   "train_acc": result[f'train_acc_{fold}'],
                   "train_loss": result[f'train_loss_{fold}'],
                   "train_f1": result[f'train_f1_{fold}'],
                   "train_auroc": result[f"train_auroc_{fold}"],
                   "val_acc": result[f'val_acc_{fold}'],
                   "val_loss": result[f'val_loss_{fold}'],
                   "val_f1": result[f'val_f1_{fold}'],
                   "val_auroc": result[f"val_auroc_{fold}"]})

        epoch_end(epoch, result, fold)
        history.append(result)

    # Restore the best weights BEFORE testing so the logged test metrics
    # describe the model that is returned.
    model.load_state_dict(best_model_wts)

    # Testing phase
    result_temp = evaluate(model, test_dl, fold, criterion)
    test_result = {"test_acc": result_temp[f'val_acc_{fold}'],
                  "test_loss": result_temp[f'val_loss_{fold}'],
                  "test_f1": result_temp[f'val_f1_{fold}']}
    wandb.log(test_result)
    end = time.time()
    time_elapsed = end - start
    print('Training complete in {:.0f}h {:.0f}m {:.0f}s'.format(
        time_elapsed // 3600, (time_elapsed % 3600) // 60, (time_elapsed % 3600) % 60))
    print("Best Loss ",best_loss)
    print("Best acc ", best_acc)

    return model, history

In [None]:
epochs = config_dict["num_epochs"]
lr = config_dict["lr"]

# for fixing OPTIMIZER
# Map the configured name to the optimizer class; an unknown name fails here
# with a clear message instead of leaving opt_func undefined.
_OPTIMIZERS = {
    "AdamW": torch.optim.AdamW,
    "Adam": torch.optim.Adam,
    "RMSprop": torch.optim.RMSprop,
}
if config_dict["opt_func"] not in _OPTIMIZERS:
    raise ValueError(f"Unknown opt_func {config_dict['opt_func']!r}; expected one of {sorted(_OPTIMIZERS)}")
opt_func = _OPTIMIZERS[config_dict["opt_func"]]

def cross_entropy(inputs, targets):
    """Placeholder kept for compatibility; it is not referenced in the visible code."""
    pass

# for fixing LOSS
if config_dict["Loss_fn"] == "CE":
    # F.cross_entropy is a function and cannot be called without arguments;
    # the module form gives an object with the same (logits, labels) call.
    criterion = nn.CrossEntropyLoss().to(device)
elif config_dict["Loss_fn"] == 'focal':
    criterion = WeightedFocalLoss().to(device)
elif config_dict["Loss_fn"] == 'SCL':
    criterion = SupervisedContrastiveLoss(temperature=config_dict["temperature"]).to(device)
else:
    raise ValueError(f"Unknown Loss_fn {config_dict['Loss_fn']!r}; expected 'CE', 'focal' or 'SCL'")

print(opt_func)

In [None]:
%%time
from torch.optim import lr_scheduler

history = []
model, history_temp = fit_one_fold(epochs, lr, model, train_dl, valid_dl, test_dl, opt_func=opt_func, criterion = criterion, fold = 0)
history += history_temp

In [None]:
x = F.one_hot(torch.tensor([1,0,1,0,1,1,1,1]), num_classes = 2) 
print(x.shape)

In [None]:

test_result = evaluate(model, test_dl, fold = 0)
print(test_result)

In [None]:
print((model))

# Confusion matrix

In [None]:
from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd
import sklearn

y_pred = []
y_true = []

# iterate over test data
# Inference only: eval mode fixes dropout / batch-norm behaviour and no_grad
# avoids building autograd graphs for every batch.
model.eval()
with torch.no_grad():
    for inputs, labels in test_dl:
        output = model(inputs) # Feed Network

        # argmax of the logits is the predicted class.
        output = output.argmax(dim=1).cpu().numpy()
        y_pred.extend(output) # Save Prediction

        labels = labels.cpu().numpy()
        y_true.extend(labels) # Save Truth

# constant for classes
classes = tuple(config_dict["target_names"])
target_names = config_dict["target_names"]
# Explicit label ids keep the matrix the same size as `classes` even when a
# class is missing from both the predictions and the ground truth.
class_ids = list(range(len(classes)))

# Build confusion matrix
cf_matrix = confusion_matrix(y_true, y_pred, labels = class_ids)
df_cm = pd.DataFrame(cf_matrix, index = list(classes),
                     columns = list(classes))
# 'normal' is not a font family known to matplotlib; use a generic family.
font = {'family' : 'sans-serif',
        'weight' : 'bold',
        'size'   : 22}

matplotlib.rc('font', **font)
plt.figure(figsize = (12,7))
sn.heatmap(df_cm, annot=True, fmt = "g")
"""For Google-Colab"""
#BASE_PATH = os.getcwd()
"""For Kaggle"""
BASE_PATH = "/kaggle/working"
model_root_folder = os.path.join(BASE_PATH, project_name, config_dict["model"])
os.makedirs(model_root_folder, exist_ok = True)
cm_figName = os.path.join(model_root_folder, config_dict["model"] + "_Confusion_Matrix.png")
plt.savefig(cm_figName)

print(sklearn.metrics.classification_report(y_true, y_pred, labels = class_ids, target_names = target_names))

In [None]:
# wandb.save("/content/"+cm_figName, base_path = model_root_folder)

# classification report

In [None]:
report = sklearn.metrics.classification_report(y_true, y_pred,target_names = target_names, output_dict = True)
print(report)

In [None]:
df = pd.DataFrame(report).transpose()

In [None]:
df

In [None]:
summary_fname = os.path.join(model_root_folder, config_dict["model"] + "_summary.txt")
df.to_csv(summary_fname)

In [None]:
# wandb.save("/content/"+summary_fname, base_path = "/content/")

# Save model weights

In [None]:
# Despite the "_model_weights" file name, this pickles the whole model object
# (not just a state_dict); show_grad_CAM later reloads it with torch.load().
model_path = os.path.join(model_root_folder, config_dict["model"] + "_model_weights.pth")
torch.save(model, model_path)

# GradCAM

In [None]:
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision.models import resnet50
from scipy import misc
from PIL import Image


def reshape_transform(tensor, height=8, width=8):
    """Rearrange a (batch, tokens, channels) tensor into (batch, channels, height, width).

    The first token along dim 1 is dropped; the remaining ``height * width``
    tokens are laid out row by row on a 2-D grid and the channel axis is
    moved to dim 1, as in CNN feature maps.
    """
    batch_size, n_channels = tensor.size(0), tensor.size(2)
    spatial_tokens = tensor[:, 1:, :]
    grid = spatial_tokens.reshape(batch_size, height, width, n_channels)
    # (B, H, W, C) -> (B, C, H, W)
    return grid.permute(0, 3, 1, 2)


def _save_image_figure(fig_num, image, path):
    """Draw ``image`` on pyplot figure ``fig_num`` and save the figure to ``path``."""
    plt.figure(fig_num)
    # Figure numbers are reused on every call; clear the figure so images do
    # not pile up on the same axes when this runs in a loop.
    plt.clf()
    plt.imshow(image)
    plt.savefig(path, bbox_inches='tight')


def show_grad_CAM(valid_ds, model_path, img_number):
    """Compute Grad-CAM for one validation sample and save three figures.

    The figures (input image, raw heatmap, heatmap overlaid on the image) are
    written into the global ``model_root_folder``. The architecture is chosen
    through the global ``config_dict["model"]``.

    Args:
        valid_ds: indexable dataset yielding ``(image, label)``. The image is
            permuted with ``(1, 2, 0)`` for plotting, so it is presumably a
            channel-first tensor.
        model_path: path of a whole-model pickle as written by
            ``torch.save(model, path)``.
        img_number: index of the sample inside ``valid_ds``.

    Returns:
        ``(img_path, label, gradcam_path, gradcam_overlay_path)``: the three
        saved file paths plus the sample's label.

    Raises:
        ValueError: if ``config_dict["model"]`` is not one of the supported
            architecture names.
    """
    model = torch.load(model_path, map_location=torch.device('cpu'))
    model_name = config_dict["model"]

    # Layer whose activations are explained, per architecture. Lambdas keep the
    # attribute access lazy so only the selected model layout is touched.
    target_layer_for = {
        "alexnet": lambda m: m.features[10],
        "vgg16": lambda m: m.features[-1],
        "vgg19": lambda m: m.features[-1],
        "resnet": lambda m: m.layer4[-1],
        "effnetv2": lambda m: m.conv_head,
        "ViT": lambda m: m.blocks[-1].norm1,
        "DeiT": lambda m: m.blocks[-1].norm1,
    }
    if model_name not in target_layer_for:
        raise ValueError(
            f"Unsupported model {model_name!r} for Grad-CAM; "
            f"expected one of {sorted(target_layer_for)}"
        )
    target_layers = [target_layer_for[model_name](model)]

    # Only the transformer backbones emit (batch, tokens, dim) activations that
    # must be folded back into a 2-D grid; applying reshape_transform to 4-D
    # CNN activations breaks the CAM computation.
    cam_reshape = reshape_transform if model_name in ("ViT", "DeiT") else None

    # Fetch the sample once so the plotted image is exactly the one explained
    # (repeated indexing could differ if the dataset applies random transforms).
    img, label = valid_ds[img_number]
    print(img.shape)
    print(target_layers)
    print(label)
    input_tensor = img.unsqueeze(0)  # add the batch dimension

    # `use_cuda` is intentionally not passed: older grad-cam releases default
    # it to False and newer releases no longer accept the argument.
    cam = GradCAM(model=model, target_layers=target_layers, reshape_transform=cam_reshape)

    # Explain the ground-truth class of this sample.
    targets = [ClassifierOutputTarget(label)]
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets, aug_smooth=True, eigen_smooth=True)

    # Only one image in the batch.
    grayscale_cam = grayscale_cam[0, :]

    original_image = img.permute(1, 2, 0)
    img_path = os.path.join(model_root_folder, f'_Orginal_img_{img_number}_{label}.jpg')
    _save_image_figure(0, original_image, img_path)

    gradcam_path = os.path.join(model_root_folder, f'_GradCAM_active_{img_number}_{label}.jpg')
    _save_image_figure(1, grayscale_cam, gradcam_path)

    # Min-max scale the image to [0, 1] for show_cam_on_image; a constant image
    # would otherwise divide by zero.
    lowest = original_image.min()
    value_range = original_image.max() - lowest
    if value_range > 0:
        original_image = (original_image - lowest) / value_range
    else:
        original_image = torch.zeros_like(original_image)
    print(f"grayscale_cam_max: {grayscale_cam.max()} , grayscale_cam_min: {grayscale_cam.min()}")
    print(f"original_img_max: {original_image.max()} , original_img_min: {original_image.min()}")
    visualization = show_cam_on_image(
        original_image.detach().cpu().numpy().astype(np.float32), grayscale_cam, use_rgb=True
    )

    gradcam_overlay_path = os.path.join(model_root_folder, f'_GradCAM_overlap_{img_number}_{label}.jpg')
    _save_image_figure(2, visualization, gradcam_overlay_path)

    return img_path, label, gradcam_path, gradcam_overlay_path

In [None]:
# Try Grad-CAM on a single validation sample (index 5).
show_grad_CAM(valid_ds, model_path, img_number = 5)

## Feature maps visualization

In [None]:
def intermediate_feature_maps(valid_ds, model_path, img_number):
    """Plot the channel-averaged output of the model's Conv2d layers for one sample.

    The sample is pushed through the complete model once and the output of
    every ``nn.Conv2d`` is captured with forward hooks. Feeding each conv
    layer's output straight into the next conv layer, as was done before,
    skips everything in between and fails on branching architectures (for
    example residual shortcuts) as soon as channel counts stop lining up.

    Args:
        valid_ds: indexable dataset yielding ``(image, label)`` with an image
            tensor the model accepts once a batch dimension is added.
        model_path: path of a whole-model pickle as written by
            ``torch.save(model, path)``.
        img_number: index of the sample inside ``valid_ds``.

    Returns:
        Path of the saved figure, inside the global ``model_root_folder``.
        At most the first 40 captured maps are drawn (10 x 4 grid).
    """
    model = torch.load(model_path, map_location=torch.device('cpu'))
    model.eval()  # inference behaviour for layers such as batch norm / dropout

    conv_layers = [module for module in model.modules() if isinstance(module, nn.Conv2d)]
    print(f"Total convolution layers: {len(conv_layers)}")

    outputs = []
    names = []

    def _record(layer, _inputs, output):
        # Called after each hooked layer runs; keeps maps in execution order.
        outputs.append(output.detach())
        names.append(str(layer))

    hook_handles = [layer.register_forward_hook(_record) for layer in conv_layers]
    image, label = valid_ds[img_number]
    try:
        with torch.no_grad():
            model(image.unsqueeze(0))
    finally:
        # Always detach the hooks, even if the forward pass fails.
        for handle in hook_handles:
            handle.remove()

    print(len(outputs))
    for layer_name, feature_map in zip(names, outputs):
        print(layer_name)
        print(feature_map.shape)

    # Collapse each (1, C, H, W) activation to one (H, W) map by averaging
    # over the channel axis.
    processed = [
        feature_map.squeeze(0).mean(dim=0).cpu().numpy()
        for feature_map in outputs
    ]

    fig = plt.figure(figsize=(30, 50))
    for i, gray_scale in enumerate(processed[:40]):
        a = fig.add_subplot(10, 4, i + 1)
        a.imshow(gray_scale, cmap="gray")
        a.axis("off")
        # Keep only the layer type (text before the first parenthesis).
        a.set_title(names[i].split('(')[0], fontsize=30)
    fmap_save_path = os.path.join(model_root_folder, f'feature_maps_{img_number}_{label}.jpg')
    fig.savefig(fmap_save_path, bbox_inches='tight')

    return fmap_save_path

In [None]:
def get_attention_map(img_number, get_mask=False):
    """Compute an attention-rollout map for one validation sample.

    Reads the globals ``model_path`` and ``valid_ds``.

    Args:
        img_number: index of the sample inside ``valid_ds``.
        get_mask: when True return only the resized attention mask; otherwise
            return the image multiplied by the mask.

    Returns:
        ``get_mask=True``: float array of shape (H, W), mask scaled so its
        maximum is 1. ``get_mask=False``: ``uint8`` array of shape (H, W, C).
    """
    model = torch.load(model_path, map_location=torch.device('cpu'))
    img, _ = valid_ds[img_number]
    # NOTE(review): assumes the model's forward returns
    # (logits, sequence of per-layer attention tensors) - confirm, since a
    # plain classifier returns logits only.
    _, att_mat = model(img.unsqueeze(0))

    att_mat = torch.stack(att_mat).squeeze(1)

    # Average the attention weights across all heads.
    att_mat = torch.mean(att_mat, dim=1)

    # To account for residual connections, add an identity matrix to the
    # attention matrix and re-normalize the weights.
    residual_att = torch.eye(att_mat.size(1))
    aug_att_mat = att_mat + residual_att
    aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1)

    # Recursively multiply the weight matrices (attention rollout).
    joint_attentions = torch.zeros(aug_att_mat.size())
    joint_attentions[0] = aug_att_mat[0]

    for n in range(1, aug_att_mat.size(0)):
        joint_attentions[n] = torch.matmul(aug_att_mat[n], joint_attentions[n-1])

    v = joint_attentions[-1]
    grid_size = int(np.sqrt(aug_att_mat.size(-1)))
    # Row 0 is the first token's attention; entry 0 (itself) is skipped.
    mask = v[0, 1:].reshape(grid_size, grid_size).detach().numpy()

    # The sample is a tensor, not a PIL image: `img.size` is a method and
    # tensors have no `.astype`, so convert to a channel-last NumPy array and
    # take the target size from its shape. cv2.resize expects (width, height).
    img_hwc = img.permute(1, 2, 0).detach().cpu().numpy()
    img_height, img_width = img_hwc.shape[:2]
    mask = cv2.resize(mask / mask.max(), (img_width, img_height))

    if get_mask:
        return mask

    # NOTE(review): the uint8 cast is kept from the original; it is only
    # meaningful if pixel values are on a 0-255 scale - confirm against the
    # dataset transform.
    return (mask[..., np.newaxis] * img_hwc).astype("uint8")

def plot_attention_map(original_img, att_map):
    """Show the original image and its attention map side by side."""
    _, axes = plt.subplots(ncols=2, figsize=(16, 16))
    panels = (
        ('Original', original_img),
        ('Attention Map Last Layer', att_map),
    )
    for axis, (title, picture) in zip(axes, panels):
        axis.set_title(title)
        axis.imshow(picture)

In [None]:
#result1 = get_attention_map(6)

In [None]:
# Convolutional feature maps are only produced for the CNN backbones.
# The earlier test `!= 'ViT' or != 'DeiT'` was true for every model name,
# so it never skipped the transformer models.
if config_dict['model'] not in ('ViT', 'DeiT'):
    intermediate_feature_maps(valid_ds, model_path, img_number = 5)

In [None]:
# Hand every file currently in the run's output folder to wandb.save.
# base_path presumably makes the stored paths relative to that folder -
# confirm against the wandb.save documentation.
wandb_file_save_list = glob(os.path.join(model_root_folder, "*"))
for fname in wandb_file_save_list:
    wandb.save(fname, base_path = model_root_folder)


# WandB Table

In [None]:
# Integer label of every validation sample, in dataset order.
# Iterating the dataset this way fetches each image too, only to discard it.
labels_array = [int(label) for img, label in valid_ds]

In [None]:
# Group validation-set indices by class label: index_log[label] -> [indices].
# Keyed on whichever labels actually occur instead of a hard-coded 0/1/2, so
# a dataset with more classes is not silently truncated.
unique_label = list(set(labels_array))
index_log = {class_id: [] for class_id in unique_label}

for sample_idx, class_id in enumerate(labels_array):
    index_log[class_id].append(sample_idx)

In [None]:
# Take up to the first 10 sample indices of each class for visualisation.
interest_index_list = []
for key, val in index_log.items():
    print(key)
    interest_index_list += val[:10]

In [None]:
# Sanity check: print the label of every selected sample.
for i in interest_index_list:
    print(valid_ds[i][1])

In [None]:
# Log one table row per selected validation sample. Transformer runs (ViT /
# DeiT) log the image and its Grad-CAM outputs; every other model also logs
# the convolutional feature-map figure.
log_feature_maps = not (config_dict["model"] == 'ViT' or config_dict["model"] == 'DeiT')

table_columns = ["Image", "Class label", "gradCAM", "gradCAM_overlay"]
if log_feature_maps:
    table_columns.append("feature_maps")
img_table = wandb.Table(columns = table_columns)

for imgNumber in interest_index_list:
    img, true_label, gradcam_img, gradcam_overlay = show_grad_CAM(valid_ds, model_path, img_number = imgNumber)
    image_files = [img, gradcam_img, gradcam_overlay]
    if log_feature_maps:
        feature_maps = intermediate_feature_maps(valid_ds, model_path, img_number = imgNumber)
        image_files.append(feature_maps)

    image_cells = [wandb.Image(image_file) for image_file in image_files]
    # Column order: image, class name, then the remaining figures.
    img_table.add_data(image_cells[0],
                       config_dict["target_names"][true_label],
                       *image_cells[1:])

# Attach the table to an artifact so it is stored with the run.
log_batch_data = wandb.Artifact(name = project_run_name,
                                type="predictions",
                                description = project_name,
                                metadata = config_dict)
log_batch_data.add(img_table, "validation_batch_record")
wandb.log_artifact(log_batch_data)

# Finish WandB process

In [None]:
# NOTE(review): 10-minute pause before the run is finished, presumably to let
# pending W&B uploads complete - confirm it is still needed.
time.sleep(600)

In [None]:
wandb.finish()