# Installing Dependencies

In [2]:
!pip install monai

Collecting monai
  Downloading monai-1.4.0-py3-none-any.whl.metadata (11 kB)
Downloading monai-1.4.0-py3-none-any.whl (1.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m42.4 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hInstalling collected packages: monai
Successfully installed monai-1.4.0


In [3]:
!pip install nibabel



## Importing Libraries

In [12]:
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt 
import os
import torch
from monai.data import DataLoader, Dataset


In [13]:
data_dir = "/kaggle/input/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/"
sample_patient = "BraTS20_Training_001"

modalities = {
    "FLAIR": f"{data_dir}{sample_patient}/{sample_patient}_flair.nii",
    "T1": f"{data_dir}{sample_patient}/{sample_patient}_t1.nii",
    "T1ce": f"{data_dir}{sample_patient}/{sample_patient}_t1ce.nii",
    "T2": f"{data_dir}{sample_patient}/{sample_patient}_t2.nii",
}

### Displaying some sample images

In [None]:
fig, axes = plt.subplots(1,4, figsize=(20,10))
for i, (modality, path) in enumerate(modalities.items()):
    print(path)
    img = nib.load(path).get_fdata()

    mid_slice = img.shape[2] // 2

    axes[i].imshow(img[:, :, mid_slice], cmap="gray", aspect="auto")
    axes[i].set_title(modality, fontsize=14)
    axes[i].axis("off")

plt.show()

    

In [None]:
# Load the segmentation mask
seg_path = f"{data_dir}{sample_patient}/{sample_patient}_seg.nii"
seg_img = nib.load(seg_path).get_fdata()

# Plot segmentation mask
plt.figure(figsize=(6,6))
plt.imshow(seg_img[:, :, mid_slice], cmap="jet")  # Use "jet" colormap for segmentation
plt.axis("off")
plt.title("Tumor Segmentation Mask", fontsize=14)
plt.colorbar()
plt.show()


### Viewing 3D

In [None]:
flair_path = f"{data_dir}{sample_patient}/{sample_patient}_t1ce.nii"

img = nib.load(flair_path).get_fdata()

img = (img - np.min(img)) / (np.max(img) - np.min(img))

mid_slice = img.shape[2] // 2  
alt_slice = img.shape[2] // 3  

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(img[:, :, mid_slice], cmap="hot")  
axes[0].set_title("Middle Slice")
axes[0].axis("off")

axes[1].imshow(img[:, :, alt_slice], cmap="gray") 
axes[1].set_title("Alternative Slice")
axes[1].axis("off")

plt.show()

In [None]:
import matplotlib.pyplot as plt

plt.imshow(img[:, :, img.shape[2] // 2], cmap="gray")
plt.title("Middle Slice of FLAIR MRI")
plt.axis("off")
plt.show()


In [None]:
import plotly.graph_objects as go

In [None]:
img = nib.load(flair_path).get_fdata()

threshold = np.percentile(img,50)

x,y,z = np.where(img > threshold)

values = img[x,y,z]

#downsample
sample_size = min(5000, len(x))
indices = np.random.choice(len(x), sample_size, replace=False)


x, y, z, values = x[indices], y[indices], z[indices], values[indices]

fig = go.Figure(data=go.Scatter3d(
    x=x, y=y, z=z,
    mode="markers",
    marker=dict(
        size=3,  
        color=values,  
        colorscale="hot", 
        opacity=0.8  
    )    
))

fig.update_layout(
    title="3D MRI Scan Visualization (FLAIR)",
    scene=dict(
        xaxis_title="X",
        yaxis_title="Y",
        zaxis_title="Z"
    )
)

fig.show()

In [None]:
!pip install hd-bet

# Get the Dataset Ready and Start Preprocessing

In [1]:
import os
import glob
from sklearn.preprocessing import MinMaxScaler
import torch
import torch.nn as nn
import monai
import numpy as np
import nibabel as nib
from monai.transforms import (
    Compose, LoadImage,ScaleIntensity, RandRotated, RandFlipd, RandZoomd,
    EnsureTyped, Resize
)

In [2]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score

In [3]:
base_path = "/kaggle/input/brats20-dataset-training-validation/"
train_path = os.path.join(base_path, "BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/")
val_path = os.path.join(base_path, "BraTS2020_ValidationData/MICCAI_BraTS2020_ValidationData/")

In [4]:
import os
from tqdm import tqdm  # For progress tracking

def get_file_list(data_path, is_train=True):
    if not os.path.exists(data_path):
        print(f"Path not found: {data_path}")
        return []

    patients = [d for d in os.listdir(data_path) if os.path.isdir(os.path.join(data_path, d))]
    print(f"Found {len(patients)} patients in {data_path}") 

    data = []
    for patient in tqdm(patients, desc="Processing Patients"):
        patient_path = os.path.join(data_path, patient)
        modalities = {
            't1': os.path.join(patient_path, f"{patient}_t1.nii"),
            't1ce': os.path.join(patient_path, f"{patient}_t1ce.nii"),
            't2': os.path.join(patient_path, f"{patient}_t2.nii"),
            'flair': os.path.join(patient_path, f"{patient}_flair.nii"),
        }

        # Also check for .nii.gz files if .nii files are missing
        for mod in modalities.keys():
            if not os.path.exists(modalities[mod]):
                gz_path = modalities[mod] + ".gz"
                if os.path.exists(gz_path):
                    modalities[mod] = gz_path  # Use .nii.gz instead

        if is_train:
            modalities['seg'] = os.path.join(patient_path, f"{patient}_seg.nii")
            if not os.path.exists(modalities['seg']):
                seg_gz_path = modalities['seg'] + ".gz"
                if os.path.exists(seg_gz_path):
                    modalities['seg'] = seg_gz_path

        # Check if any modality is still missing
        missing_files = [mod for mod, path in modalities.items() if not os.path.exists(path)]
        if missing_files:
            print(f"Skipping {patient} (Missing: {', '.join(missing_files)})")
            continue  

        data.append(modalities)
    
    print(f"Total valid patients: {len(data)}")
    return data


In [5]:
train_data = get_file_list(train_path, is_train=True)
val_data = get_file_list(val_path, is_train=False)

Found 369 patients in /kaggle/input/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/


Processing Patients: 100%|██████████| 369/369 [00:01<00:00, 223.66it/s]


Skipping BraTS20_Training_355 (Missing: seg)
Total valid patients: 368
Found 125 patients in /kaggle/input/brats20-dataset-training-validation/BraTS2020_ValidationData/MICCAI_BraTS2020_ValidationData/


Processing Patients: 100%|██████████| 125/125 [00:00<00:00, 242.78it/s]

Total valid patients: 125





In [6]:
print(f"Training samples: {len(train_data)}")
print(f"Validation samples: {len(val_data)}")

Training samples: 368
Validation samples: 125


In [8]:
import shutil

# Path to Kaggle working directory
output_dir = "/kaggle/working/preprocessed_brats2020/"

# Delete existing output directory (if it exists) to free up space
if os.path.exists(output_dir):
    shutil.rmtree(output_dir)
    print("Deleted existing output directory.")

# Recreate the output directory
os.makedirs(output_dir, exist_ok=True)
print("Clean workspace ready!")


Deleted existing output directory.
Clean workspace ready!


In [None]:
import os
import numpy as np
from tqdm import tqdm
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, Spacingd,
    ScaleIntensityd, RandRotated, RandFlipd, RandZoomd, EnsureTyped
)
from monai.data import Dataset, DataLoader

# Define Augmentations for Training Data
train_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    ScaleIntensityd(keys=['image'], minv=0.0, maxv=1.0),
    RandRotated(keys=['image'], range_x=15, prob=0.5),
    RandFlipd(keys=['image'], prob=0.5),
    RandZoomd(keys=['image'], min_zoom=0.9, max_zoom=1.1, prob=0.5),
    EnsureTyped(keys=['image', 'label'])
])

# Transformations for Validation (No Augmentation)
val_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    ScaleIntensityd(keys=['image'], minv=0.0, maxv=1.0),
    EnsureTyped(keys=['image', 'label'])
])

# Function to Apply Transformations and Save as .npy
def preprocess_and_save(dataset, save_dir, transforms):
    os.makedirs(save_dir, exist_ok=True)

    for idx, patient_data in enumerate(tqdm(dataset, desc="Processing Patients")):
        transformed = transforms(patient_data)  # Apply transformations

        np.save(os.path.join(save_dir, f"patient_{idx}_image.npy"), transformed["image"])
        np.save(os.path.join(save_dir, f"patient_{idx}_label.npy"), transformed["label"])

# Apply Augmentations & Save (Training with Augmentation, Validation without)
preprocess_and_save(train_data, save_dir="train_preprocessed", transforms=train_transforms)
preprocess_and_save(val_data, save_dir="val_preprocessed", transforms=val_transforms)


In [None]:
import os
import numpy as np
from tqdm import tqdm
from monai.transforms import (
    Compose, Resize, ScaleIntensity, EnsureChannelFirst,
    RandRotate, RandFlip, RandZoom
)
import nibabel as nib  # To read NIfTI files

# Define Preprocessing Pipeline (For Validation - No Augmentation)
val_preprocess = Compose([
    Resize(spatial_size=(128, 128, 128)),  # Resize to 128x128x128
    ScaleIntensity(minv=0, maxv=1),  # Normalize intensity
    EnsureChannelFirst()  # Ensure correct shape (C, H, W, D)
])

# Define Preprocessing + Augmentation Pipeline (For Training)
train_preprocess = Compose([
    Resize(spatial_size=(128, 128, 128)),  # Resize to 128x128x128
    ScaleIntensity(minv=0, maxv=1),  # Normalize intensity
    EnsureChannelFirst(),  # Ensure correct shape (C, H, W, D)
    RandRotate(range_x=15, prob=0.5),  # Randomly rotate (±15 degrees)
    RandFlip(spatial_axis=0, prob=0.5),  # Randomly flip horizontally
    RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5)  # Random zoom
])

# Function to extract label from segmentation (.seg) file
def get_label(seg_path):
    seg_img = nib.load(seg_path).get_fdata()  # Load segmentation image
    return 1 if 4 in seg_img else 0  # HGG if 4 is present, else LGG

# Function to process and save images & labels
def process_and_save(data_list, save_dir, transform):
    os.makedirs(save_dir, exist_ok=True)
    
    for patient in tqdm(data_list, desc=f"Processing {save_dir}"):
        patient_id = patient["id"]  # Unique identifier for patient
        seg_path = patient["seg"]  # Segmentation file path
        
        # Extract label from segmentation file
        label = get_label(seg_path)

        # Apply preprocessing and augmentation (for training)
        image_stack = np.stack([
            transform(patient["FLAIR"]),
            transform(patient["T1"]),
            transform(patient["T1ce"]),
            transform(patient["T2"])
        ], axis=0)  # Shape: (4, 128, 128, 128)

        # Save as NumPy file (image & label separately)
        np.save(os.path.join(save_dir, f"{patient_id}_image.npy"), image_stack)
        np.save(os.path.join(save_dir, f"{patient_id}_label.npy"), np.array(label))

# Process training data with augmentation
process_and_save(train_data, save_dir="train_preprocessed", transform=train_preprocess)

# Process validation data without augmentation
process_and_save(val_data, save_dir="val_preprocessed", transform=val_preprocess)


## Checking if Preprocessing is done properly

In [None]:
#Checking Numpy Vectors
import numpy as np

# Load a preprocessed sample
sample_image = np.load("train_preprocessed/patient_0_0_image.npy")

# Print details
print("🔹 NumPy Array Details:")
print(f" - Shape: {sample_image.shape}")
print(f" - Data Type: {sample_image.dtype}")
print(f" - Min Intensity: {sample_image.min():.4f}")
print(f" - Max Intensity: {sample_image.max():.4f}")

"""
Shape: (1, 128, 128, 128)  # Single channel (grayscale MRI)
 - Data Type: float32
 - Min Intensity: 0.0000
 - Max Intensity: 1.0000
 
 """

In [None]:
#Visualizing
import matplotlib.pyplot as plt

# Select the middle slice
mid_slice = sample_image.shape[2] // 2  

plt.imshow(sample_image[0, :, :, mid_slice], cmap="gray")  # Show middle slice
plt.title("Preprocessed MRI - Middle Slice")
plt.axis("off")
plt.show()


## Doing the Preprocessing for a single Patient

In [22]:
#Checking for one patient
# MONAI Transforms
load_nifti = LoadImage(image_only=True)
resize_transform = Resize(spatial_size=(128, 128, 128), mode="trilinear")
normalize_transform = ScaleIntensity(minv=0, maxv=1)

# Select one patient from training data
one_patient = train_data[0]  # First patient in train_data
print(f"Processing Patient: {one_patient}")

# Preprocess & store results
preprocessed_images = {}

for modality, file_path in one_patient.items():
    if modality == 'seg':  # Ignore segmentation
        continue

    if not os.path.exists(file_path):
        print(f"Warning: File not found {file_path}")
        continue

    img = load_nifti(file_path)  # Load NIfTI image
    img = img[0] if img.shape[0] == 1 else img
    img = resize_transform(img)  # Resize to (128, 128, 128)
    img = normalize_transform(img)  # Normalize intensity

    preprocessed_images[modality] = img.numpy()

Processing Patient: {'t1': '/kaggle/input/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/BraTS20_Training_083/BraTS20_Training_083_t1.nii', 't1ce': '/kaggle/input/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/BraTS20_Training_083/BraTS20_Training_083_t1ce.nii', 't2': '/kaggle/input/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/BraTS20_Training_083/BraTS20_Training_083_t2.nii', 'flair': '/kaggle/input/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/BraTS20_Training_083/BraTS20_Training_083_flair.nii', 'seg': '/kaggle/input/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/BraTS20_Training_083/BraTS20_Training_083_seg.nii'}


In [23]:
# Show NumPy array details
modality_to_show = "flair"  # Change to 't1', 't1ce', or 't2' if needed

if modality_to_show in preprocessed_images:
    sample_img = preprocessed_images[modality_to_show]
    
    # Print NumPy details
    print(f"\n🔹 NumPy Array Details for {modality_to_show.upper()}:")
    print(f" - Shape: {sample_img.shape}")
    print(f" - Data Type: {sample_img.dtype}")
    print(f" - Min Value: {np.min(sample_img):.4f}")
    print(f" - Max Value: {np.max(sample_img):.4f}")


🔹 NumPy Array Details for FLAIR:
 - Shape: (240, 128, 128, 128)
 - Data Type: float32
 - Min Value: 0.0000
 - Max Value: 1.0000


In [None]:
#Visualizing Few Results
#Setting up function

def show_samples(samples):
    fig, ax = plt.subplots(1,4, figsize=(10,20))
    modalities = ['T1', 'T1ce', 'T2', 'FLAIR']
    print("Image shape:", sample['image'].shape)

    for i in range(4):
        ax[i].imshow(sample['image'][i], cmap='gray')
        ax[i].set_title(f"{modalities[i]}\nClass: {'HGG' if sample['label'] else 'LGG'}")
        ax[i].axis('off')
    plt.show()

# Visualize random training sample
train_dataset = BrainTumourDataset(train_data[5:], train_transforms)
sample = train_dataset[45]
show_samples(sample)

In [26]:
import torch
print("GPU Available:", torch.cuda.is_available())
print("GPU Name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "None")


GPU Available: True
GPU Name: Tesla P100-PCIE-16GB


# Creating Custom Pytorch Dataset

In [None]:
from torch.utils.data import Dataset

class GliomaDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        """
        Args:
            data_dir (str): Directory containing the preprocessed .npy files.
            transform (callable, optional): Optional transform to apply on images.
        """
        self.data_dir = data_dir
        self.transform = transform
        
        # Get list of all patient IDs (only unique ones)
        self.patient_ids = sorted(set([f.split("_")[0] for f in os.listdir(data_dir) if f.endswith("_image.npy")]))
    
    def __len__(self):
        return len(self.patient_ids)

    def __getitem__(self, idx):
        patient_id = self.patient_ids[idx]
        
        # Load image and label
        image_path = os.path.join(self.data_dir, f"{patient_id}_image.npy")
        label_path = os.path.join(self.data_dir, f"{patient_id}_label.npy")

        image = np.load(image_path)  # Shape: (4, 128, 128, 128)
        label = np.load(label_path)  # 0: LGG, 1: HGG
        
        # Convert to PyTorch tensors
        image = torch.tensor(image, dtype=torch.float32)  # (C, H, W, D)
        label = torch.tensor(label, dtype=torch.long)  # Scalar label

        # Apply transformations if any
        if self.transform:
            image = self.transform(image)

        return image, label

# Data Loader

In [None]:
from torch.utils.data import DataLoader

# Define dataset paths
train_dir = "train_preprocessed"
val_dir = "val_preprocessed"

# Create Dataset instances
train_dataset = GliomaDataset(train_dir)
val_dataset = GliomaDataset(val_dir)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=2)


## 3D CNN Model for Glioma Classification

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GliomaCNN(nn.Module):
    def __init__(self):
        super(GliomaCNN, self).__init__()

        # Conv Block 1
        self.conv1 = nn.Conv3d(in_channels=4, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm3d(32)
        self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2)

        # Conv Block 2
        self.conv2 = nn.Conv3d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm3d(64)
        self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2)

        # Conv Block 3
        self.conv3 = nn.Conv3d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm3d(128)
        self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2)

        # Fully Connected Layers
        self.fc1 = nn.Linear(128 * 16 * 16 * 16, 256)  # Adjust based on input size
        self.fc2 = nn.Linear(256, 2)  # 2 Classes: LGG & HGG

    def forward(self, x):
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))

        x = x.view(x.size(0), -1)  # Flatten
        x = F.relu(self.fc1(x))
        x = self.fc2(x)  # No activation, since CrossEntropyLoss includes softmax
        return x


## Defining the Loss and Optimizer

In [None]:
import torch.optim as optim

# Define Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GliomaCNN().to(device)

# Loss Function & Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


### Train the Model

In [None]:
num_epochs = 10  # Adjust based on Kaggle runtime

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")

print("Training Complete!")


## Evaluation of Model

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import numpy as np

def evaluate_model(model, dataloader, device):
    model.eval()  # Set model to evaluation mode
    all_preds = []
    all_labels = []

    with torch.no_grad():  # No gradient calculation
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)  # Get the class with highest probability
            
            all_preds.extend(preds.cpu().numpy())  # Store predictions
            all_labels.extend(labels.cpu().numpy())  # Store true labels

    # Compute metrics
    acc = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='binary')
    recall = recall_score(all_labels, all_preds, average='binary')
    f1 = f1_score(all_labels, all_preds, average='binary')
    cm = confusion_matrix(all_labels, all_preds)

    print(f"🔹 Accuracy: {acc:.4f}")
    print(f"🔹 Precision: {precision:.4f}")
    print(f"🔹 Recall: {recall:.4f}")
    print(f"🔹 F1 Score: {f1:.4f}")
    print("🔹 Confusion Matrix:")
    print(cm)

    return acc, precision, recall, f1, cm


# Checking on Vision Transformer

In [27]:
#importing Libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from timm.models.vision_transformer import vit_base_patch16_224
import numpy as np
import os
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix


## Custom PyTorch Dataset for ViT

In [None]:
class GliomaViTDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.data_list = [f for f in os.listdir(data_dir) if f.endswith('.npy')]

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        file_path = os.path.join(self.data_dir, self.data_list[idx])
        data = np.load(file_path, allow_pickle=True).item()

        # Extract the center slice from each modality
        center_slice = 64  # Middle slice of (128,128,128)
        flair = data["FLAIR"][:, :, center_slice]
        t1 = data["T1"][:, :, center_slice]
        t1ce = data["T1ce"][:, :, center_slice]
        t2 = data["T2"][:, :, center_slice]

        # Stack as RGB image (ViT requires 3 channels)
        image = np.stack([flair, t1ce, t2], axis=0)  # Shape: (3, 128, 128)

        label = 1 if np.max(data["seg"]) == 4 else 0  # HGG if label==4 else LGG

        if self.transform:
            image = self.transform(image)

        return torch.tensor(image, dtype=torch.float32), torch.tensor(label, dtype=torch.long)

# Define Transformations
vit_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),  # Resize to 224x224 for ViT
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Load Dataset
train_dataset = GliomaViTDataset("train_preprocessed", transform=vit_transform)
val_dataset = GliomaViTDataset("val_preprocessed", transform=vit_transform)

# Create Dataloaders
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=2)


## Define the ViT Model

In [None]:
class ViTClassifier(nn.Module):
    def __init__(self, num_classes=2):
        super(ViTClassifier, self).__init__()
        self.vit = vit_base_patch16_224(pretrained=True)
        self.vit.head = nn.Linear(self.vit.head.in_features, num_classes)

    def forward(self, x):
        return self.vit(x)

# Move model to GPU (if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ViTClassifier().to(device)


In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)


# Training

In [None]:
def train_vit(model, train_loader, val_loader, criterion, optimizer, epochs=10):
    model.train()
    
    for epoch in range(epochs):
        total_loss = 0
        correct = 0
        total = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

        train_acc = correct / total
        print(f"Epoch {epoch+1}/{epochs} | Loss: {total_loss:.4f} | Train Acc: {train_acc:.4f}")

# Train Model
train_vit(model, train_loader, val_loader, criterion, optimizer, epochs=10)


In [None]:
# Evaluation
def evaluate_model(model, dataloader):
    model.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    acc = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='binary')
    recall = recall_score(all_labels, all_preds, average='binary')
    f1 = f1_score(all_labels, all_preds, average='binary')
    cm = confusion_matrix(all_labels, all_preds)

    print(f"🔹 Accuracy: {acc:.4f}")
    print(f"🔹 Precision: {precision:.4f}")
    print(f"🔹 Recall: {recall:.4f}")
    print(f"🔹 F1 Score: {f1:.4f}")
    print("🔹 Confusion Matrix:")
    print(cm)

# Run Evaluation
evaluate_model(model, val_loader)
