In [65]:
#/r

# Load Data

In [41]:
OUTPUT_FOLDER = "/scratch/aakash_ks.iitr/dr-scnn/"
DATA_FOLDER = "/scratch/aakash_ks.iitr/data/diabetic-retinopathy/"
# TRAIN_DATA_FOLDER = DATA_FOLDER + 'resized_train/'
TRAIN_DATA_FOLDER = DATA_FOLDER + 'resized_train_c/'

TEST_DATA_FOLDER = DATA_FOLDER + 'test/'

# Imports

In [42]:
import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

from PIL import Image

plt.rcParams['figure.dpi'] = 100

In [43]:
import torch
import torch.nn.functional as F
import torch.nn as nn

from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torchvision.transforms import v2

import timm

In [44]:

NUM_CLASSES = 5

class CFG:
    seed = 42
    N_folds = 6
    train_folds = [0, ] # [0,1,2,3,4]

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    apex=True # use half precision
    workers = 16

    model_name = "resnet50.a1_in1k"
    epochs = 20
    cropped = True
    # weights =  torch.tensor([0.206119, 0.793881],dtype=torch.float32)

    clip_val = 1000.
    batch_size = 64
    # gradient_accumulation_steps = 1

    lr = 5e-3
    weight_decay=1e-2
    
    resolution = 224
    samples_per_class = 1000
    frozen_layers = 0

In [45]:
import wandb
# from kaggle_secrets import UserSecretsClient
# user_secrets = UserSecretsClient()
# wandb.login(key=user_secrets.get_secret("wandb_api"))

run = wandb.init(
    project="hello-world", 
    dir=OUTPUT_FOLDER,
    config={
    k:v for k, v in CFG.__dict__.items() if not k.startswith('__')}
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33maakashks_[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [46]:
from sklearn.manifold import TSNE
import matplotlib.colors as mcolors

class LinearClassifier(nn.Module):
    def __init__(self, in_features=2048, num_classes=NUM_CLASSES):
        super().__init__()
        self.model = nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.model(x)
    

class SupConModel(nn.Module):
    def __init__(self, encoder, input_dim=2048, output_dim=128):        # assuming either resnet50 or resnet101 is used
        super().__init__()
        self.encoder = encoder
        self.head = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, output_dim)
        )
    
    def forward(self, x):
        ft = self.encoder(x)
        return F.normalize(self.head(ft), dim=1)


class ImageTrainDataset(Dataset):
    def __init__(
            self,
            folder,
            data,
            transforms,
    ):
        self.folder = folder
        self.data = data
        self.transforms = transforms

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        d = self.data.loc[index]
        image = Image.open(f"{self.folder}{d.image}.jpeg")
        image = self.transforms(image)
        label = d.level

        return image, torch.tensor(label, dtype=torch.long)


def plot_tsne(embeddings, labels):
    # Apply t-SNE to the embeddings
    tsne = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=300)
    tsne_results = tsne.fit_transform(embeddings.numpy())

    # Define the number of unique labels/classes
    num_classes = len(np.unique(labels.numpy()))
    # Create a custom color map with specific color transitions
    colors = ['blue', 'green', 'yellow', 'orange', 'red']
    cmap = mcolors.LinearSegmentedColormap.from_list("Custom", colors, N=num_classes)

    # Create a boundary norm with boundaries and colors
    norm = mcolors.BoundaryNorm(np.arange(-0.5, num_classes + 0.5, 1), cmap.N)

    fig = plt.figure(figsize=(10, 8))
    scatter = plt.scatter(tsne_results[:, 0], tsne_results[:, 1], c=labels, cmap=cmap, norm=norm, alpha=0.5)
    colorbar = plt.colorbar(scatter, ticks=np.arange(num_classes))
    colorbar.set_label('Severity Level')
    colorbar.set_ticklabels(np.arange(num_classes))  # Set discrete labels if needed
    plt.title('t-SNE of Image Embeddings with Discrete Severity Levels')
    plt.xlabel('t-SNE Axis 1')
    plt.ylabel('t-SNE Axis 2')
    fg = wandb.Image(fig)
    wandb.log({"t-SNE": fg})
    plt.savefig(os.path.join(wandb.run.dir, f"tsne.png"), dpi=300, bbox_inches='tight')



class style:
    BLUE = '\033[94m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    END = '\033[0m'
    BOLD = '\033[1m'


def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

In [47]:
device = torch.device(CFG.device)

# Dataset

In [48]:
from torchvision.transforms import functional as func

class CustomTransform:
    def __init__(self, output_size=(CFG.resolution, CFG.resolution), radius_factor=0.9):
        self.output_size = output_size
        self.radius_factor = radius_factor

    def __call__(self, img):
        # Assuming img is a PIL Image
        # Normalize and preprocess as previously defined
        img = func.resize(img, int(min(img.size) / self.radius_factor))
        img_tensor = func.to_tensor(img)
        mean, std = img_tensor.mean([1, 2]), img_tensor.std([1, 2])
        img_normalized = func.normalize(img_tensor, mean.tolist(), std.tolist())
        kernel_size = 15
        padding = kernel_size // 2
        avg_pool = torch.nn.AvgPool2d(kernel_size, stride=1, padding=padding)
        local_avg = avg_pool(img_normalized.unsqueeze(0)).squeeze(0)
        img_subtracted = img_normalized - local_avg
        center_crop_size = int(min(img_subtracted.shape[1:]) * self.radius_factor)
        img_cropped = func.center_crop(img_subtracted, [center_crop_size, center_crop_size])

        # Apply augmentations
        img_resized = func.resize(img_cropped, self.output_size)

        return img_resized

In [49]:
# train_transforms = CustomTransform()

train_transforms = v2.Compose([
    v2.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 2)),  # Gaussian blur with random kernel size and sigma
    v2.RandomRotation(degrees=(0, 90)),  # Random rotation between 0 and 360 degrees
    CustomTransform(),
    # v2.RandomResizedCrop(CFG.resolution, scale=(0.8, 1.0)),  # Krizhevsky style random cropping
    v2.RandomHorizontalFlip(),  # Random horizontal flip
    v2.RandomVerticalFlip(),  # Random vertical flip
    v2.ToDtype(torch.float32, scale=False),
])

val_transforms = v2.Compose([
    CustomTransform(),
    v2.ToDtype(torch.float32, scale=False),
])

In [50]:
os.listdir(TEST_DATA_FOLDER )

['DR2', 'Normal', 'DR3', 'DR1']

In [51]:
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder


In [66]:
from torch.utils.data import random_split

# visualize the transformations
data = ImageFolder(TEST_DATA_FOLDER, transform=val_transforms)

# Define the ratio for the split
train_ratio = 0.75  # 75% training, 25% validation
train_size = int(train_ratio * len(data))
val_size = len(data) - train_size

# Split the dataset
train_dataset, val_dataset = random_split(data, [train_size, val_size])

In [68]:
len(val_dataset)

29

In [60]:
# split the data into train and validation



{'DR1': 0, 'DR2': 1, 'DR3': 2, 'Normal': 3}

# Metric

In [61]:
from sklearn.metrics import f1_score as sklearn_f1
from sklearn.metrics import confusion_matrix, roc_auc_score, accuracy_score, precision_score

# Train and evaluate functions

In [35]:
def evaluate_model(cfg, feature_extractor, model, data_loader, epoch=-1):
    targets = []
    predictions = []

    total_len = len(data_loader)
    tk0 = tqdm(enumerate(data_loader), total=total_len)
    
    with torch.no_grad():
        for step, (images, labels) in tk0:
            images = images.to(device)
            target = labels.to(device)
            
            features = feature_extractor(images)
            logits = model(features)
            
            targets.append(target.detach().cpu())
            predictions.append(logits.detach().cpu())
            del images, target, logits

    targets = torch.cat(targets, dim=0)
    predictions = torch.cat(predictions, dim=0)
    probabilities = F.softmax(predictions, dim=1)

    # base_score, best_score, best_th = find_best_threshold(targets, predictions[:, 1])
    # For multi-class classification, you might need the class with the highest probability
    predicted_classes = predictions.argmax(dim=1)

    try:
        wandb.log({"roc": wandb.plot.roc_curve(targets.numpy(), probabilities.numpy())})
        roc_auc = roc_auc_score(targets.numpy(), probabilities.numpy(), multi_class='ovo')
        
        wandb.log({"pr": wandb.plot.pr_curve(targets.numpy(), probabilities.numpy())})
        
    except:
        roc_auc = 0

    # Calculate accuracy
    accuracy = accuracy_score(targets.numpy(), predicted_classes.numpy())

    precision = precision_score(targets.numpy(), predicted_classes.numpy(), average='weighted')

    print(f'Epoch {epoch}: auc = {roc_auc:.4f} accuracy = {accuracy:.4f} precision = {precision:.4f}')
    return roc_auc, accuracy, precision

In [36]:
def create_model():
    # get the feature extractor
    resnet = timm.create_model(CFG.model_name, num_classes=0, pretrained=False)
    feature_extractor = SupConModel(resnet)
    feature_extractor.load_state_dict(torch.load(OUTPUT_FOLDER + 'ckpt_epoch_8.pth'))
    
    # remove the projection head
    feature_extractor = nn.Sequential(*list(feature_extractor.children())[:-1])

    # create a simple linear classifier
    classifier = LinearClassifier()
    classifier.load_state_dict(torch.load(OUTPUT_FOLDER + 'lc_sclr_11.pth'))
    return feature_extractor.to(device), classifier.to(device)

In [37]:
from sklearn.manifold import TSNE
import matplotlib.colors as mcolors

def get_embeddings(model, data_loader):
    model.eval()
    
    features = []
    targets = []

    total_len = len(data_loader)
    tk0 = tqdm(enumerate(data_loader), total=total_len)
    with torch.no_grad():
        for step, (images, labels) in tk0:
            images = images.to(device)
            target = labels.to(device)

            embds = model(images)

            features.append(embds.detach().cpu())
            targets.append(target.detach().cpu())

    features = torch.cat(features, dim=0)
    targets = torch.cat(targets, dim=0)
    
    # # store the embeddings for future use
    # torch.save(features, os.path.join(wandb.run.dir, f"embeddings.pth"))
    # torch.save(targets, os.path.join(wandb.run.dir, f"targets.pth"))

    return features, targets


## Train folds

In [71]:
seed_everything(CFG.seed)

test_loader = DataLoader(
    test_data,
    batch_size=CFG.batch_size,
    shuffle=False,
    num_workers=CFG.workers,
    pin_memory=True,
    drop_last=False,
)

# PREPARE MODEL, OPTIMIZER AND SCHEDULER
feature_extractor, model = create_model()
feature_extractor.eval()
model.eval()

val_auc, val_accuracy, val_precision = evaluate_model(CFG, model, valid_loader, loss_criterion, epoch)

# Log metrics to wandb
wandb.log({
    'val_auc': val_auc,
    'val_accuracy': val_accuracy,
    'val_precision': val_precision,
})

features, targets = get_embeddings(feature_extractor, test_loader)
plot_tsne(features, targets)

Layer conv1 is trainable.
Layer bn1 is trainable.
Layer act1 is trainable.
Layer maxpool is trainable.
Layer layer1 is trainable.
Layer layer2 is trainable.
Layer layer3 is trainable.
Layer layer4 is trainable.
Layer global_pool is trainable.
Layer fc is trainable.
Model parameters: 23_518_277


Epoch 0 training 20/20 [LR 0.004969] - loss: 1.6485: 100%|███████████████| 20/20 [00:21<00:00,  1.09s/it]


Epoch 0: training loss = 1.6485 auc = 0.4868 accuracy = 0.2150 precision = 0.2188


100%|██████████████████████████████████████████████████████████████████████| 5/5 [00:09<00:00,  1.90s/it]


Epoch 0: validation loss = 1.6251 auc = 0.0000 accuracy = 0.2143 precision = 0.0940
[92mNew best score: 0.0000 -> 0.2143[0m


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
Epoch 1 training 20/20 [LR 0.004878] - loss: 1.5778: 100%|███████████████| 20/20 [00:21<00:00,  1.06s/it]


Epoch 1: training loss = 1.5778 auc = 0.5908 accuracy = 0.2550 precision = 0.2785


100%|██████████████████████████████████████████████████████████████████████| 5/5 [00:09<00:00,  1.93s/it]


Epoch 1: validation loss = 1.6735 auc = 0.0000 accuracy = 0.2619 precision = 0.3811
[92mNew best score: 0.2143 -> 0.2619[0m


Epoch 2 training 20/20 [LR 0.004728] - loss: 1.5015: 100%|███████████████| 20/20 [00:21<00:00,  1.05s/it]

Epoch 2: training loss = 1.5015 auc = 0.6787 accuracy = 0.3000 precision = 0.2893



100%|██████████████████████████████████████████████████████████████████████| 5/5 [00:09<00:00,  1.86s/it]


Epoch 2: validation loss = 2.1124 auc = 0.0000 accuracy = 0.2619 precision = 0.2526


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
Epoch 3 training 20/20 [LR 0.004523] - loss: 1.3510: 100%|███████████████| 20/20 [00:20<00:00,  1.02s/it]


Epoch 3: training loss = 1.3510 auc = 0.7412 accuracy = 0.3800 precision = 0.3749


100%|██████████████████████████████████████████████████████████████████████| 5/5 [00:09<00:00,  1.94s/it]


Epoch 3: validation loss = 2.7870 auc = 0.0000 accuracy = 0.2143 precision = 0.1909


Epoch 4 training 20/20 [LR 0.004268] - loss: 1.3476: 100%|███████████████| 20/20 [00:21<00:00,  1.06s/it]


Epoch 4: training loss = 1.3476 auc = 0.7447 accuracy = 0.3550 precision = 0.3476


100%|██████████████████████████████████████████████████████████████████████| 5/5 [00:09<00:00,  1.85s/it]


Epoch 4: validation loss = 3.1248 auc = 0.0000 accuracy = 0.2381 precision = 0.1518


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
Epoch 5 training 20/20 [LR 0.003970] - loss: 1.3567: 100%|███████████████| 20/20 [00:20<00:00,  1.02s/it]


Epoch 5: training loss = 1.3567 auc = 0.7494 accuracy = 0.3750 precision = 0.3971


100%|██████████████████████████████████████████████████████████████████████| 5/5 [00:09<00:00,  1.86s/it]


Epoch 5: validation loss = 2.2044 auc = 0.0000 accuracy = 0.2143 precision = 0.1821


Epoch 6 training 20/20 [LR 0.003635] - loss: 1.4598: 100%|███████████████| 20/20 [00:20<00:00,  1.04s/it]


Epoch 6: training loss = 1.4598 auc = 0.6979 accuracy = 0.3350 precision = 0.3336


  0%|                                                                              | 0/5 [00:05<?, ?it/s]


KeyboardInterrupt: 

In [None]:
wandb.finish()