In [None]:
import os
import gc
import glob
import random
import pickle
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from PIL import Image
import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import utils

import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, early_stopping
seed_everything(7, workers = True)

import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.metrics import roc_auc_score, roc_curve, auc

import sys
sys.path.append('../input/timm-pytorch-image-models/pytorch-image-models-master')
import timm

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

In [None]:
def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs for reproducibility.

    Args:
        seed: integer seed applied to every random number generator.
    """
    random.seed(seed)
    # Only affects interpreters started after this point (e.g. worker
    # subprocesses); the running process's hash seed is already fixed.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        # Deterministic cuDNN kernels, and no benchmark autotuning (which can
        # select different algorithms from run to run).
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

set_seed(7)

In [None]:
# Local folder for model checkpoints, created up front if it is missing.
checkpoint_dir = 'checkpoints'
if not os.path.isdir(checkpoint_dir):
    os.makedirs(checkpoint_dir)

In [None]:
# Training hyperparameters shared by the data module, model and trainer.
config = {
    "batch_size": 128,
    "epochs": 10,
    "learning_rate": 3e-4,
    "model_name": "Model2_Classifier",
}

In [None]:
# Map each class directory name to an integer label. glob() returns entries in
# filesystem order, so sort first to make the label indices reproducible.
class_map = {
    os.path.basename(class_dir): index
    for index, class_dir in enumerate(sorted(glob.glob(r'../input/gsocml4scimodel2/Model_II/*')))
}

class_map

In [None]:
class CustomDataset(Dataset):
    """Dataset of ``.npy`` images labelled by the name of their parent directory.

    Args:
        img_paths_and_labels_list: paths to ``.npy`` files laid out as
            ``.../<class_name>/<file>.npy``.
        class_map: mapping from class (directory) name to integer label.
        transform: optional albumentations-style callable, invoked as
            ``transform(image=img)`` and returning a dict with an ``'image'`` key.

    Each item is ``(image_tensor_float32, label_tensor)``.
    """

    def __init__(self, img_paths_and_labels_list, class_map, transform = None):
        root_list = list(img_paths_and_labels_list)

        self.transform = transform
        self.class_map = class_map

        # Count samples per class for a quick sanity print.
        self.class_distribution = {}
        for img_path in root_list:
            class_name = self._class_name_from_path(img_path)
            self.class_distribution[class_name] = self.class_distribution.get(class_name, 0) + 1

        print("Dataset Distribution:\n")
        print(self.class_distribution)

        self.data = [[img_path, self._class_name_from_path(img_path)] for img_path in tqdm(root_list)]

    @staticmethod
    def _class_name_from_path(img_path):
        # The class is the name of the directory that contains the file.
        return os.path.basename(os.path.dirname(img_path))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path, class_name = self.data[idx]
        # allow_pickle is needed for the object arrays stored for 'axion'; it
        # can execute arbitrary code, so only load trusted dataset files.
        img = np.load(img_path, allow_pickle = True)
        # 'axion' files wrap the image as element 0 of a pickled array
        # (presumably next to extra metadata — TODO confirm), so unwrap it.
        if class_name == 'axion':
            img = img[0]

        if self.transform:
            img = self.transform(image = img)['image']

        # Without a tensor-producing transform the sample is still a NumPy
        # array, which has no .to(); convert it first.
        if not torch.is_tensor(img):
            img = torch.as_tensor(np.asarray(img, dtype = np.float32))
        img = img.to(torch.float)

        class_id = torch.tensor(self.class_map[class_name])

        return img, class_id

In [None]:
class CustomDataModule(pl.LightningDataModule):
    """Splits one glob of ``.npy`` files into train/val and uses another as test.

    Args:
        data_root_dir: glob pattern for the training files (``.../<class>/<file>``).
        test_data_root_dir: glob pattern for the test files.
        val_split_ratio: fraction of the training files held out for validation.

    Relies on the notebook-level ``class_map`` and ``config["batch_size"]``.
    """

    def __init__(self, data_root_dir, test_data_root_dir, val_split_ratio = 0.1):
        super().__init__()
        # glob order depends on the filesystem; sort so the seeded shuffle in
        # setup() yields the same split on every machine.
        self.dataset_img_paths_list = sorted(glob.glob(data_root_dir))
        self.test_paths_list = sorted(glob.glob(test_data_root_dir))

        self.test_data_root_dir = test_data_root_dir
        self.val_split_ratio = val_split_ratio
        self._split_done = False

    def setup(self, stage = None):
        """Build the train/val/test datasets; only the first call does work.

        Lightning calls ``setup`` again from ``trainer.fit``/``trainer.test``.
        Re-splitting there would reshuffle and silently replace the train/val
        split used by loaders created before fitting.
        """
        if self._split_done:
            return

        paths = list(self.dataset_img_paths_list)
        random.shuffle(paths)
        val_split = int(self.val_split_ratio * len(paths))

        val_paths_list = paths[:val_split]
        train_paths_list = paths[val_split:]

        assert len(paths) == (len(train_paths_list) + len(val_paths_list))

        train_transforms = A.Compose(
                    [
                        A.CenterCrop(height = 50, width = 50, p=1.0),
                        A.HorizontalFlip(p = 0.5),
                        A.VerticalFlip(p = 0.5),
                        A.Rotate(limit = 360, p = 0.4),
                        ToTensorV2()
                    ]
                )

        test_transforms = A.Compose(
                    [
                        A.CenterCrop(height = 50, width = 50, p=1.0),
                        ToTensorV2()
                    ]
                )

        self.train_dataset = CustomDataset(train_paths_list, class_map = class_map, transform = train_transforms)
        self.val_dataset = CustomDataset(val_paths_list, class_map = class_map, transform = test_transforms)
        self.test_dataset = CustomDataset(self.test_paths_list, class_map = class_map, transform = test_transforms)
        self._split_done = True

    def train_dataloader(self):
        batch_size = config["batch_size"]
        # A trailing batch of exactly one sample makes BatchNorm1d layers in
        # training mode raise, so drop the last batch only in that case.
        drop_last = len(self.train_dataset) % batch_size == 1
        return DataLoader(self.train_dataset, batch_size = batch_size, shuffle = True,
                          pin_memory = True, drop_last = drop_last)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size = config["batch_size"], shuffle = False, pin_memory=True)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size = config["batch_size"], shuffle = False, pin_memory=True)


In [None]:
data_module = CustomDataModule(
    data_root_dir=r'../input/gsocml4scimodel2/Model_II/*/*',
    test_data_root_dir=r'../input/gsocml4scimodel2test/Model_II_test/*/*',
    val_split_ratio=0.1,
)
# Set up eagerly: the model reads len(train_loader) when it is constructed.
data_module.setup(stage='fit')
train_loader = data_module.train_dataloader()
test_loader = data_module.test_dataloader()

In [None]:
single_batch = next(iter(train_loader))
# A bare expression only displays when it is a cell's last line, so print it.
print(single_batch[0].shape)

# normalize=True rescales the grid to [0, 1]; imshow clips float RGB data
# outside that range.
single_batch_grid = utils.make_grid(single_batch[0][:16], nrow=8, normalize=True)
fig, ax = plt.subplots(figsize=(20, 5))
ax.imshow(single_batch_grid.permute(1, 2, 0))
ax.axis('off');

In [None]:
def calculate_accuracy(y_pred, y_truth):
    """Return batch accuracy as an integer-rounded percentage (0-d tensor).

    Args:
        y_pred: logits with classes along dim 1.
        y_truth: integer class labels.
    """
    predicted = torch.log_softmax(y_pred, dim = 1).max(dim = 1).indices
    hits = predicted.eq(y_truth).float()
    return torch.round(hits.sum() / len(hits) * 100)

In [None]:
class PreTrainedModel(pl.LightningModule):
    """timm ``efficientnet_b1`` backbone (1 input channel) with an MLP classifier head.

    Args:
        pretrained: load pretrained backbone weights through timm.
        num_classes: number of output logits (default 3, as before).
        steps_per_epoch: optimizer steps per epoch for OneCycleLR; defaults to
            ``len(train_loader)`` from the notebook's global scope.

    Uses the notebook-level ``config`` for learning rate and epochs.
    """

    # Flattened backbone feature size: 1280 channels x a 2x2 spatial map.
    # The 2x2 assumes the 50x50 centre crop used by the data module — TODO
    # confirm if the crop size or backbone changes.
    FEATURE_DIM = 1280 * 2 * 2

    def __init__(self, pretrained = True, num_classes = 3, steps_per_epoch = None):
        super().__init__()
        self.model = timm.create_model('efficientnet_b1', pretrained = pretrained, in_chans = 1)
        # Fine-tune the whole backbone.
        for param in self.model.parameters():
            param.requires_grad = True

        self.fc = nn.Sequential(
                                nn.Linear(self.FEATURE_DIM, 1024),
                                nn.PReLU(),
                                nn.BatchNorm1d(1024),
                                nn.Dropout(p = 0.5),

                                nn.Linear(1024, 512),
                                nn.BatchNorm1d(512),
                                nn.PReLU(),
                                nn.Dropout(p = 0.5),

                                nn.Linear(512, 128),
                                nn.PReLU(),
                                nn.BatchNorm1d(128),
                                nn.Dropout(p = 0.3),

                                nn.Linear(128, num_classes)
                                )

        # Per-batch history used for plotting after training.
        self.train_losses = []
        self.val_losses = []
        self.train_accuracies = []
        self.val_accuracies = []

        if steps_per_epoch is None:
            steps_per_epoch = len(train_loader)

        self.optimizer = torch.optim.Adam(self.parameters(), lr = config['learning_rate'])
        self.scheduler = optim.lr_scheduler.OneCycleLR(self.optimizer, max_lr = config['learning_rate'],
                                                       epochs = config['epochs'],
                                                       steps_per_epoch = steps_per_epoch)

    def forward(self, x):
        features = self.model.forward_features(x)
        # Flatten per sample: unlike view(-1, ...), a wrong spatial size now
        # fails in the Linear layer instead of silently mixing samples.
        features = torch.flatten(features, start_dim = 1)
        return self.fc(features)

    def cross_entropy_loss(self, logits, labels):
        """Mean cross-entropy between ``logits`` and integer ``labels``."""
        return nn.functional.cross_entropy(logits, labels)

    def _is_sanity_checking(self):
        # Lightning runs a few validation batches before training starts; keep
        # them out of the recorded history. self.trainer raises when the
        # module is not attached to a Trainer.
        try:
            trainer = self.trainer
        except RuntimeError:
            return False
        return bool(trainer is not None and trainer.sanity_checking)

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        accuracy = calculate_accuracy(logits, y)
        self.log('train_loss', loss)
        self.log('train_accuracy', accuracy)
        self.train_losses.append(loss.item())
        self.train_accuracies.append(accuracy.item())
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        accuracy = calculate_accuracy(logits, y)
        self.log('val_loss', loss)
        self.log('val_accuracy', accuracy)
        if not self._is_sanity_checking():
            self.val_losses.append(loss.item())
            self.val_accuracies.append(accuracy.item())
        return loss

    def configure_optimizers(self):
        # Lightning steps the scheduler after every optimizer step. It must
        # not also be stepped manually (the removed on_batch_end hook did),
        # or OneCycleLR overruns its total step count and raises.
        return [self.optimizer], [{"scheduler": self.scheduler, "interval": "step"}]

In [None]:
model = PreTrainedModel()

x = torch.randn(single_batch[0].shape)
print(x.shape)
# Shape check only, so skip building an autograd graph.
with torch.no_grad():
    print(model(x).shape)

del model
gc.collect()

In [None]:
def train_model():
    """Train a fresh ``PreTrainedModel`` on ``data_module`` and plot its history.

    Early-stops on ``val_accuracy`` and saves the best checkpoint by that
    metric. Returns the model as it is at the end of training.
    """
    model = PreTrainedModel()

    early_stop = early_stopping.EarlyStopping(monitor = "val_accuracy", patience = 5, mode = "max")
    checkpoint_callback = ModelCheckpoint(save_top_k = 1, monitor = "val_accuracy", mode = "max")

    trainer = pl.Trainer(
        # Mirror the notebook's `device` fallback instead of requiring a GPU.
        accelerator = "gpu" if torch.cuda.is_available() else "cpu",
        devices = 1,
        max_epochs = config['epochs'],
        # The checkpoint callback only takes effect once registered here.
        callbacks = [early_stop, checkpoint_callback],
        enable_checkpointing = True
    )

    trainer.fit(model, data_module)

    # History is recorded once per batch (train and validation separately),
    # so the x-axis is a step count, not epochs.
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

    loss_ax.plot(model.train_losses, label='Training Loss')
    loss_ax.plot(model.val_losses, label='Validation Loss')
    loss_ax.set_xlabel('Step (batch)')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()

    acc_ax.plot(model.train_accuracies, label='Training Accuracy')
    acc_ax.plot(model.val_accuracies, label='Validation Accuracy')
    acc_ax.set_xlabel('Step (batch)')
    acc_ax.set_ylabel('Accuracy (%)')
    acc_ax.legend()

    plt.show()

    return model

In [None]:
# model = train_model()

In [None]:
model = PreTrainedModel()

early_stop = early_stopping.EarlyStopping(monitor = "val_accuracy", patience = 5, mode = "max")
# Keeps the single best checkpoint by validation accuracy; it only takes
# effect once registered in the Trainer's callbacks.
checkpoint_callback = ModelCheckpoint(save_top_k = 1, monitor = "val_accuracy", mode = "max")

trainer = pl.Trainer(
    # Mirror the notebook's `device` fallback instead of requiring a GPU.
    accelerator = "gpu" if torch.cuda.is_available() else "cpu",
    devices = 1,
    max_epochs = config['epochs'],
    callbacks = [early_stop, checkpoint_callback],
    enable_checkpointing = True
)

trainer.fit(model, data_module)

In [None]:
# Plot training and validation loss/accuracy trends. History is recorded once
# per batch (train and validation separately), so the x-axis is a step count,
# not epochs, and the two curves have different lengths.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

loss_ax.plot(model.train_losses, label='Training Loss')
loss_ax.plot(model.val_losses, label='Validation Loss')
loss_ax.set_xlabel('Step (batch)')
loss_ax.set_ylabel('Loss')
loss_ax.legend()

acc_ax.plot(model.train_accuracies, label='Training Accuracy')
acc_ax.plot(model.val_accuracies, label='Validation Accuracy')
acc_ax.set_xlabel('Step (batch)')
acc_ax.set_ylabel('Accuracy (%)')
acc_ax.legend()

plt.show()

In [None]:
def test_epoch(model, dataloader, criterion):
    """Evaluate ``model`` over every batch of ``dataloader`` without gradients.

    Args:
        model: classifier returning logits with classes along dim 1.
        dataloader: yields ``(images, labels)`` batches.
        criterion: loss function; assumed to average over the batch
            (``reduction='mean'``) — TODO confirm for custom criteria.

    Returns:
        y_pred_prob_list: per-batch NumPy arrays of softmax probabilities.
        y_pred_list: per-batch NumPy arrays of predicted class indices.
        y_truth_list: per-batch NumPy arrays of ground-truth labels.
        test_loss: mean loss over all samples (batches weighted by size).
        test_accuracy: overall accuracy in percent over all samples.
        Both scalars are NaN when the loader yields no samples.
    """
    model.eval()

    y_pred_list = []
    y_pred_prob_list = []
    y_truth_list = []

    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for img_batch, labels in tqdm(dataloader, total=len(dataloader)):
            X = img_batch.to(device)
            y_truth = labels.to(device)

            # forward prop
            y_pred = model(X)
            y_pred_prob = torch.softmax(y_pred, dim = 1)
            y_pred_labels = torch.argmax(y_pred, dim = 1)

            # Weight by batch size so a smaller final batch does not skew the
            # averages (a plain mean of per-batch values would).
            batch_size = y_truth.size(0)
            total_loss += criterion(y_pred, y_truth).item() * batch_size
            total_correct += (y_pred_labels == y_truth).sum().item()
            total_samples += batch_size

            y_truth_list.append(y_truth.cpu().numpy())
            y_pred_prob_list.append(y_pred_prob.cpu().numpy())
            y_pred_list.append(y_pred_labels.cpu().numpy())

    if total_samples == 0:
        return y_pred_prob_list, y_pred_list, y_truth_list, float('nan'), float('nan')

    test_loss = total_loss / total_samples
    test_accuracy = 100.0 * total_correct / total_samples
    return y_pred_prob_list, y_pred_list, y_truth_list, test_loss, test_accuracy

In [None]:
criterion = nn.CrossEntropyLoss()
# nn.Module.to moves the parameters in place and returns the same module.
model.to(device)
(y_pred_prob_list, y_pred_list, y_truth_list,
 test_loss, test_accuracy) = test_epoch(model, test_loader, criterion)

print(test_loss, test_accuracy)

In [None]:
def flatten_list(x):
    """Concatenate an iterable of iterables into a single flat list (one level)."""
    flat = []
    for chunk in x:
        flat.extend(chunk)
    return flat


# Merge the per-batch arrays into flat per-sample sequences for sklearn.
y_pred_list_flattened, y_truth_list_flattened, y_pred_prob_list_flattened = (
    flatten_list(batches) for batches in (y_pred_list, y_truth_list, y_pred_prob_list)
)

In [None]:
idx2class = {index: name for name, index in class_map.items()}
# Order names by label index so they line up with sklearn's sorted labels,
# independent of class_map's insertion order.
class_names = [idx2class[index] for index in sorted(idx2class)]
idx2class

In [None]:
# Pass the label indices explicitly so target_names stays aligned even if a
# class never appears in the test truths or predictions.
print(classification_report(y_truth_list_flattened, y_pred_list_flattened,
                            labels = list(range(len(class_names))),
                            target_names = class_names))

In [None]:
# sklearn's signature is confusion_matrix(y_true, y_pred): rows are true
# classes, columns are predictions (the arguments were previously swapped).
print(confusion_matrix(y_truth_list_flattened, y_pred_list_flattened))

In [None]:
# Pass every label explicitly: otherwise a class absent from both truths and
# predictions shrinks the matrix and the rename below mislabels rows/columns.
confusion_matrix_df = pd.DataFrame(
    confusion_matrix(y_truth_list_flattened, y_pred_list_flattened, labels = sorted(idx2class))
).rename(columns=idx2class, index=idx2class)
fig, ax = plt.subplots(figsize=(19,12))
sns.heatmap(confusion_matrix_df, fmt = ".0f", annot=True, ax=ax)
ax.set(xlabel = 'Predicted', ylabel = 'True');

In [None]:
# Macro then weighted averages, each as one-vs-rest followed by one-vs-one.
for group_index, average in enumerate(('macro', 'weighted')):
    if group_index:
        print()
    for multi_class in ('ovr', 'ovo'):
        print(roc_auc_score(y_truth_list_flattened, y_pred_prob_list_flattened,
                            average=average, multi_class=multi_class))

In [None]:
fpr = {}
tpr = {}
roc_auc = {}
thresh = {}

# Derive the class count from the label map instead of hard-coding 3.
n_class = len(idx2class)

# Convert once rather than on every loop iteration.
y_truth_array = np.array(y_truth_list_flattened)
y_prob_array = np.array(y_pred_prob_list_flattened)

for i in range(n_class):
    # One-vs-rest: class i is the positive label, its probability column the score.
    fpr[i], tpr[i], thresh[i] = roc_curve(y_truth_array, y_prob_array[:, i], pos_label=i)
    roc_auc[i] = auc(fpr[i], tpr[i])

In [None]:
# Large fonts must be in effect before the axes (and tick labels) are created.
plt.rcParams['font.size'] = '30'
fig, ax = plt.subplots(figsize=(17,12))

for class_index in range(0, n_class):
    ax.plot(fpr[class_index], tpr[class_index], linestyle='--',
            label=f'{idx2class[class_index]} AUC = {roc_auc[class_index]:.3f}')

ax.set_title('Multiclass ROC curve')
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive rate')
ax.legend(loc='best')

fig.savefig('ROC_curves.png',dpi=352);