In [1]:
import pandas as pd
import numpy as np
import seaborn as sns
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms, utils, datasets
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, random_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import StratifiedKFold
import pdb
from sklearn.metrics import roc_curve, auc
import sklearn.metrics as metrics

In [2]:
def multi_acc(y_pred, y_test):
    # Apply sigmoid to each prediction (for multilabel or binary classification)
    y_pred_sigmoid = torch.sigmoid(y_pred)

    # Convert probabilities to binary labels (threshold > 0.5)
    y_pred_tags = (y_pred_sigmoid > 0.5).float()

    # Compare with the ground truth
    correct_pred = (y_pred_tags == y_test).float()

    # Calculate accuracy
    acc = correct_pred.sum() / len(correct_pred)

    # Convert to percentage
    acc = torch.round(acc * 100)

    return acc

In [3]:
from transformers import ViTForImageClassification
import torch.nn as nn
import torch

# Load the pre-trained Vision Transformer model
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",  # Pre-trained ViT model
    num_labels=6  # Number of output classes
)

# Freeze the feature extraction layers
for param in model.vit.parameters():
    param.requires_grad = False

# Replace the classifier head with a custom classifier
model.classifier = nn.Sequential(
    nn.Linear(model.config.hidden_size, 256),  # ViT's hidden size
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(256, 6),  # Output layer for 6 classes
    nn.Sigmoid()  # Change to Sigmoid for multilabel classification
)

# Print the updated classifier
print(model.classifier)

# Move the model to the appropriate device (GPU/CPU)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Sequential(
  (0): Linear(in_features=768, out_features=256, bias=True)
  (1): ReLU()
  (2): Dropout(p=0.5, inplace=False)
  (3): Linear(in_features=256, out_features=6, bias=True)
  (4): Sigmoid()
)


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe

In [4]:
total_params = sum(p.numel() for p in model.parameters())
print(f'{total_params:,} total parameters.')
total_trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad)
print(f'{total_trainable_params:,} training parameters.')

# Define class weight if necessary. Otherwise use default
#weights = [1, 1] # 1 weight for class 0 and 0.5 weight for class 1
#class_weights=torch.FloatTensor(weights).cuda()

# Define loss function and optimizer
#criterion = nn.CrossEntropyLoss(weight=class_weights)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())
#optimizer=optim.SGD(model.parameters(), lr=0.001)

85,997,062 total parameters.
198,406 training parameters.


In [5]:
# Main Code
# Data loader and transformaiton
image_transforms = {
    "train": transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()]),
    "test": transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()]),
        }

# Place the train and test file directories in the root (Check the organization folder text file for details)
BraTS_dataset = datasets.ImageFolder(root =r"E:\python\deep learning\fsjfr\BraTs\BraTs",
                                      transform = image_transforms["train"]
                                     )
test_dataset = datasets.ImageFolder(root =r"E:\python\deep learning\fsjfr\Test",
                                     transform = image_transforms["test"]
)
 #

In [8]:

import torch
from torch.utils.data import random_split

# Get the total size of the dataset
# total_size = len(BraTS_dataset)

# # Calculate the size of the subset (40% of the dataset)
# subset_size = int(0.2 * total_size)

# # Split the subset into train, validation, and test sets (70%, 20%, 10% of the 40%)
# train_size = int(0.7 * subset_size)  # 70% of 40% for training
# val_size = int(0.2 * subset_size)    # 20% of 40% for validation
# test_size = subset_size - train_size - val_size  # Remaining for test

# # Split the dataset into train, validation, and test sets
# subset_dataset, _ = random_split(BraTS_dataset, [subset_size, total_size - subset_size])

# # Now, split the subset into train, validation, and test
# train_dataset, val_dataset, test_dataset = random_split(subset_dataset, [train_size, val_size, test_size])



total_size = len(BraTS_dataset)

# Split the BraTS dataset into train, validation sets (80%, 20%)
train_size = int(0.8 * total_size)
val_size = total_size - train_size  # Ensure all remaining data goes to validation

# Split the dataset using `random_split`
train_dataset, val_dataset = random_split(BraTS_dataset, [train_size, val_size])



# Create DataLoaders for each set
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Print dataset sizes to verify
print(f"Total dataset size: {total_size}")
print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

Total dataset size: 16821
Train dataset size: 13456
Validation dataset size: 3365
Test dataset size: 1360


In [9]:
accuracy_stats = {
    'train': [],
    "val": []
}
loss_stats = {
    'train': [],
    "val": []
}

In [10]:
import numpy as np
import torch.nn as nn
import torch

def multi_acc(y_pred, y_test):
    # Access the logits attribute of the ImageClassifierOutput object
    y_pred_sigmoid = F.logsigmoid(y_pred.logits)  # Use F.logsigmoid
    _, y_pred_tags = torch.max(y_pred_sigmoid, dim=1)

    correct_pred = (y_pred_tags == y_test).float()
    acc = correct_pred.sum() / len(correct_pred)

    return acc

#Change 1
criterion = nn.CrossEntropyLoss() # use CrossEntropyLoss for multi-class


In [11]:
n_epochs = 2

for e in range(n_epochs):

    # TRAINING
    i = 0
    min_val_loss = np.Inf
    train_epoch_loss = 0
    train_epoch_acc = 0
    model.train()
    for X_train_batch, y_train_batch in train_loader:
        X_train_batch, y_train_batch = X_train_batch.to(device), y_train_batch.to(device)
        optimizer.zero_grad()
        y_train_pred = model(X_train_batch)

        #Change 2: Access the logits attribute from y_train_pred
        train_loss = criterion(y_train_pred.logits, y_train_batch)
        train_acc = multi_acc(y_train_pred, y_train_batch)
        train_loss.backward()
        optimizer.step()
        train_epoch_loss += train_loss.item()
        train_epoch_acc += train_acc.item()

    # VALIDATION
    with torch.no_grad():
        model.eval()
        val_epoch_loss = 0
        val_epoch_acc = 0
        for X_val_batch, y_val_batch in val_loader:
            X_val_batch, y_val_batch = X_val_batch.to(device), y_val_batch.to(device)
            y_val_pred = model(X_val_batch)

            #Change 3: Access the logits attribute from y_val_pred
            val_loss = criterion(y_val_pred.logits, y_val_batch)
            val_acc = multi_acc(y_val_pred, y_val_batch)
            val_epoch_loss += val_loss.item()
            val_epoch_acc += val_acc.item()
    loss_stats['train'].append(train_epoch_loss / len(train_loader))
    loss_stats['val'].append(val_epoch_loss / len(val_loader))
    vLoss = val_epoch_loss / len(val_loader)

    # If the validation loss is at a minimum
    if vLoss < min_val_loss:
      # Save the model
      torch.save(model.state_dict(), "net.pth")  # OK
      epochs_no_improve = 0
      min_val_loss = vLoss

    accuracy_stats['train'].append(train_epoch_acc/len(train_loader))
    accuracy_stats['val'].append(val_epoch_acc/len(val_loader))
    print(f'Epoch {e+0:02}: | Train Loss: {train_epoch_loss/len(train_loader):.5f} | Val Loss: {val_epoch_loss/len(val_loader):.5f} | Train Acc: {train_epoch_acc/len(train_loader):.3f} | Val Acc: {val_epoch_acc/len(val_loader):.3f}')


KeyboardInterrupt: 

In [None]:
train_val_acc_df = pd.DataFrame.from_dict(accuracy_stats).reset_index().melt(id_vars=['index']).rename(columns={"index":"epochs"})
train_val_loss_df = pd.DataFrame.from_dict(loss_stats).reset_index().melt(id_vars=['index']).rename(columns={"index":"epochs"})
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(30,10))
sns.lineplot(data=train_val_acc_df, x = "epochs", y="value", hue="variable",  ax=axes[0]).set_title('Train-Val Accuracy/Epoch')
sns.lineplot(data=train_val_loss_df, x = "epochs", y="value", hue="variable", ax=axes[1]).set_title('Train-Val Loss/Epoch')

In [None]:
model.load_state_dict(torch.load("net.pth"))
# TEST
y_pred_list = []
y_true_list = []
with torch.no_grad():
    for x_batch, y_batch in test_loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        y_test_pred = model(x_batch)
        # Change here: Apply softmax to the logits
        y_test_pred = torch.softmax(y_test_pred.logits, dim=1)
        y_pred_list.append(y_test_pred.cpu().numpy())
        y_true_list.append(y_batch.cpu().numpy())

# ... (rest of the code remains the same)

# Evaluation
# Convert the lists of arrays into 1D arrays
y_pred_list = np.concatenate(y_pred_list)
y_true_list = np.concatenate(y_true_list)


# Confusion Matrix, Accuracy, F1-Score
cm = confusion_matrix(y_true_list, np.argmax(y_pred_list, axis=1))
print(cm)
# Calculate metrics for multi-class classification
FP = cm.sum(axis=0) - np.diag(cm)
FN = cm.sum(axis=1) - np.diag(cm)
TP = np.diag(cm)
TN = cm.sum() - (FP + FN + TP)
#TN, FP, FN, TP = confusion_matrix(y_true_list,y_pred_list).ravel()
sensitivity = np.round((TP / (TP + FN)),25)
specificity = np.round((TN / (FP + TN)),25)
precision = np.round((TP / (TP + FP)),25)
recall = np.round((TP / (TP + FN)),25)
f1 = np.round(((2*precision*recall)/(precision+recall)),25)
if np.isnan(f1).any():
   f1 = np.nan_to_num(f1, nan=0)
accuracy = np.round(((TP + TN) / len(y_true_list)),25)

from sklearn.metrics import roc_auc_score
roc_auc_score(y_true_list, y_pred_list, multi_class='ovr')

In [None]:
sns.heatmap(cm, annot=True, fmt='g')