In [None]:
pip install nibabel torch torchvision monai numpy matplotlib



In [None]:
import os
import numpy as np
import nibabel as nib
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.utils.class_weight import compute_class_weight

In [None]:
from google.colab import drive
drive.mount('/content/drive')

shared_drives_path = "/content/drive/Shared drives/"

if os.path.exists(shared_drives_path):
    drives = os.listdir(shared_drives_path)
    print("Available Shared Drives:")
    for drive_name in drives:
        print(f"  📁 {drive_name}")
else:
    print("No Shared Drives found. Make sure you've mounted your drive.")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Available Shared Drives:
  📁 Fall2025BrainTumorDiagnosis


In [None]:
base_path = "/content/drive/Shared drives/Fall2025BrainTumorDiagnosis/"

In [None]:
print("📂 Contents of Shared Drive:")
for item in sorted(os.listdir(base_path)):
    full_path = os.path.join(base_path, item)
    if os.path.isfile(full_path):
        print(f"  📄 {item}")
    else:
        print(f"  📁 {item}/")

# Get all patient folders
patient_folders = sorted([f for f in os.listdir(base_path)
                         if os.path.isdir(os.path.join(base_path, f))])

print(f"Total patient folders: {len(patient_folders)}")
print(f"\nFirst 5 patients:")
for folder in patient_folders[:5]:
    print(f"  📁 {folder}")

# Check scans in the first patient
first_patient = patient_folders[0]
patient_path = os.path.join(base_path, first_patient)
scans = sorted(os.listdir(patient_path))

print(f"\n📂 Scans in {first_patient}:")
print(f"Total scans: {len(scans)}\n")
for i, scan in enumerate(scans, 1):
    print(f"{i}. {scan}")

# Check for label files in base directory
print("\n📄 Looking for label files in base directory:")
for item in os.listdir(base_path):
    if item.endswith(('.csv', '.txt', '.xlsx', '.json')):
        print(f"  Found: {item}")

📂 Contents of Shared Drive:
  📁 UCSF-PDGM-0004_nifti/
  📁 UCSF-PDGM-0005_nifti/
  📁 UCSF-PDGM-0007_nifti/
  📁 UCSF-PDGM-0008_nifti/
  📁 UCSF-PDGM-0009_nifti/
  📁 UCSF-PDGM-0010_nifti/
  📁 UCSF-PDGM-0011_nifti/
  📁 UCSF-PDGM-0012_nifti/
  📁 UCSF-PDGM-0013_nifti/
  📁 UCSF-PDGM-0014_nifti/
  📁 UCSF-PDGM-0015_nifti/
  📁 UCSF-PDGM-0016_nifti/
  📁 UCSF-PDGM-0017_nifti/
  📁 UCSF-PDGM-0018_nifti/
  📁 UCSF-PDGM-0019_nifti/
  📁 UCSF-PDGM-0020_nifti/
  📁 UCSF-PDGM-0021_nifti/
  📁 UCSF-PDGM-0022_nifti/
  📁 UCSF-PDGM-0023_nifti/
  📁 UCSF-PDGM-0024_nifti/
  📁 UCSF-PDGM-0025_nifti/
  📁 UCSF-PDGM-0026_nifti/
  📁 UCSF-PDGM-0027_nifti/
  📁 UCSF-PDGM-0029_nifti/
  📁 UCSF-PDGM-0030_nifti/
  📁 UCSF-PDGM-0031_nifti/
  📁 UCSF-PDGM-0032_nifti/
  📁 UCSF-PDGM-0033_nifti/
  📁 UCSF-PDGM-0034_nifti/
  📁 UCSF-PDGM-0035_nifti/
  📁 UCSF-PDGM-0036_nifti/
  📁 UCSF-PDGM-0037_nifti/
  📁 UCSF-PDGM-0038_nifti/
  📁 UCSF-PDGM-0039_nifti/
  📁 UCSF-PDGM-0040_nifti/
  📁 UCSF-PDGM-0041_nifti/
  📁 UCSF-PDGM-0042_nifti/
  📁 UCSF-P

In [None]:
patient_folders = sorted([f for f in os.listdir(base_path)
                         if os.path.isdir(os.path.join(base_path, f))])

tumor_count = 0
no_tumor_count = 0
error_count = 0

print(f"Checking all {len(patient_folders)} patients...\n")

for patient in patient_folders:
    tumor_seg_path = os.path.join(base_path, patient,
                                   f"{patient.replace('_nifti', '')}_tumor_segmentation.nii.gz")

    try:
        img = nib.load(tumor_seg_path).get_fdata()
        if np.any(img > 0):
            tumor_count += 1
        else:
            no_tumor_count += 1
            print(f"NO TUMOR: {patient}")  # Print patients without tumor
    except:
        error_count += 1
        print(f"ERROR loading: {patient}")

print(f"\n{'='*50}")
print(f"SUMMARY:")
print(f"  Total patients: {len(patient_folders)}")
print(f"  Has tumor: {tumor_count}")
print(f"  No tumor: {no_tumor_count}")
print(f"  Errors: {error_count}")
print(f"{'='*50}")

Checking all 501 patients...


SUMMARY:
  Total patients: 501
  Has tumor: 501
  No tumor: 0
  Errors: 0


In [None]:
# Path to your CSV file
csv_path = "https://raw.githubusercontent.com/Erdos-Projects/fall-2025-brain-tumor-diagnosis/refs/heads/main/Clinical-data/UCSF-PDGM-metadata_v5.csv"

# Read the CSV
df = pd.read_csv(csv_path)

# Display basic info
print(f"CSV loaded successfully!")
print(f"Shape: {df.shape[0]} rows × {df.shape[1]} columns\n")

# Show column names
print("Column names:")
for i, col in enumerate(df.columns, 1):
    print(f"  {i}. {col}")

# Show first few rows
print("\nFirst 5 rows:")
print(df.head())

# Check if there's a patient ID column that matches our folder names
print("\nChecking for patient ID column...")
potential_id_cols = [col for col in df.columns if 'id' in col.lower() or 'patient' in col.lower() or 'case' in col.lower()]
if potential_id_cols:
    print(f"Potential ID columns: {potential_id_cols}")
    print(f"\nSample values from '{potential_id_cols[0]}':")
    print(df[potential_id_cols[0]].head(10))

CSV loaded successfully!
Shape: 501 rows × 16 columns

Column names:
  1. ID
  2. Sex
  3. Age at MRI
  4. WHO CNS Grade
  5. Final pathologic diagnosis (WHO 2021)
  6. MGMT status
  7. MGMT index
  8. 1p/19q
  9. IDH
  10. 1-dead 0-alive
  11. OS
  12. EOR
  13. Biopsy prior to imaging
  14. BraTS21 ID
  15. BraTS21 Segmentation Cohort
  16. BraTS21 MGMT Cohort

First 5 rows:
              ID Sex  Age at MRI  WHO CNS Grade  \
0  UCSF-PDGM-004   M          66              4   
1  UCSF-PDGM-005   F          80              4   
2  UCSF-PDGM-007   M          70              4   
3  UCSF-PDGM-008   M          70              4   
4  UCSF-PDGM-009   F          68              4   

  Final pathologic diagnosis (WHO 2021)    MGMT status MGMT index   1p/19q  \
0            Glioblastoma, IDH-wildtype       negative          0  unknown   
1            Glioblastoma, IDH-wildtype  indeterminate    unknown  unknown   
2            Glioblastoma, IDH-wildtype  indeterminate    unknown  unknown   
3

In [None]:
# Load CSV
csv_path = "https://raw.githubusercontent.com/Erdos-Projects/fall-2025-brain-tumor-diagnosis/refs/heads/main/Clinical-data/UCSF-PDGM-metadata_v5.csv"
df = pd.read_csv(csv_path)

# Get patient folders
patient_folders = sorted([f for f in os.listdir(base_path)
                         if os.path.isdir(os.path.join(base_path, f))])

# Extract patient number and pad to 4 digits
df['patient_num'] = df['ID'].str.extract(r'UCSF-PDGM-(\d+)')[0]
df['folder_name'] = 'UCSF-PDGM-' + df['patient_num'].str.zfill(4) + '_nifti'

# Filter to only patients we have imaging data for
df_with_imaging = df[df['folder_name'].isin(patient_folders)]

print(f"Total patients in CSV: {len(df)}")
print(f"Total folders: {len(patient_folders)}")
print(f"Matched patients (have both data and labels): {len(df_with_imaging)}")

# Check WHO Grade distribution
print(f"\n{'='*50}")
print(f"WHO Grade distribution:")
print(df_with_imaging['WHO CNS Grade'].value_counts().sort_index())
print(f"{'='*50}")

# Show first few matches
print(f"\nFirst 5 matched patients:")
print(df_with_imaging[['ID', 'folder_name', 'WHO CNS Grade']].head())

# Create labels dictionary
labels_dict = dict(zip(df_with_imaging['folder_name'],
                       df_with_imaging['WHO CNS Grade']))

print(f"\nLabels dictionary created!")
print(f"Example entries:")
for i, (folder, grade) in enumerate(list(labels_dict.items())[:5]):
    print(f"  {folder}: Grade {grade}")

Total patients in CSV: 501
Total folders: 501
Matched patients (have both data and labels): 501

WHO Grade distribution:
WHO CNS Grade
2     56
3     43
4    402
Name: count, dtype: int64

First 5 matched patients:
              ID           folder_name  WHO CNS Grade
0  UCSF-PDGM-004  UCSF-PDGM-0004_nifti              4
1  UCSF-PDGM-005  UCSF-PDGM-0005_nifti              4
2  UCSF-PDGM-007  UCSF-PDGM-0007_nifti              4
3  UCSF-PDGM-008  UCSF-PDGM-0008_nifti              4
4  UCSF-PDGM-009  UCSF-PDGM-0009_nifti              4

Labels dictionary created!
Example entries:
  UCSF-PDGM-0004_nifti: Grade 4
  UCSF-PDGM-0005_nifti: Grade 4
  UCSF-PDGM-0007_nifti: Grade 4
  UCSF-PDGM-0008_nifti: Grade 4
  UCSF-PDGM-0009_nifti: Grade 4


In [None]:
# =============================================
# 1. LOAD AND PREPARE LABELS
# =============================================

def load_labels(csv_path, base_path):
    """
    Load labels from CSV and match to folder names

    Args:
        csv_path: Path to metadata CSV file
        base_path: Path to folder containing patient folders

    Returns:
        labels_dict: Dictionary mapping folder_name -> grade (0, 1, 2)
        patient_folders: List of patient folder names
    """
    # Load CSV
    df = pd.read_csv(csv_path)

    # Get all patient folders
    patient_folders = sorted([f for f in os.listdir(base_path)
                             if os.path.isdir(os.path.join(base_path, f))])

    # Match CSV IDs to folder names (pad numbers to 4 digits)
    df['patient_num'] = df['ID'].str.extract(r'UCSF-PDGM-(\d+)')[0]
    df['folder_name'] = 'UCSF-PDGM-' + df['patient_num'].str.zfill(4) + '_nifti'

    # Filter to patients with imaging data
    df_matched = df[df['folder_name'].isin(patient_folders)]

    # Convert WHO grades to class indices: Grade 2->0, Grade 3->1, Grade 4->2
    grade_to_class = {2: 0, 3: 1, 4: 2}
    df_matched['class_label'] = df_matched['WHO CNS Grade'].map(grade_to_class)

    # Create labels dictionary
    labels_dict = dict(zip(df_matched['folder_name'], df_matched['class_label']))

    print(f"Loaded {len(labels_dict)} patients with labels")
    print(f"Class distribution: {df_matched['WHO CNS Grade'].value_counts().sort_index()}")

    return labels_dict, patient_folders


# =============================================
# 2. DATASET CLASS
# =============================================

class MRIDataset(Dataset):
    def __init__(self, base_path, patient_list, labels_dict, slice_idx=120, selected_modalities=None):
        """
        Dataset for multi-channel 2D MRI slices

        Args:
            base_path: Path to folder containing patient folders
            patient_list: List of patient folder names to include
            labels_dict: Dictionary mapping folder_name -> label
            slice_idx: Which slice to extract from 3D volume (default: 120)
            selected_modalities: List of modality names to use. If None, uses all available.
        """
        self.base_path = base_path
        self.patient_list = patient_list
        self.labels_dict = labels_dict
        self.slice_idx = slice_idx

        # Define modality order (consistent for all patients)
        # Excluding non-NIfTI files and segmentation masks
        #self.all_modalities = [
            #'ADC', 'ASL', 'DTI_eddy_FA', 'DTI_eddy_L1', 'DTI_eddy_L2',
            #'DTI_eddy_L3', 'DTI_eddy_MD', 'DWI', 'DWI_bias',
            #'FLAIR', 'FLAIR_bias', 'SWI', 'SWI_bias',
            #'T1', 'T1_bias', 'T1c', 'T1c_bias', 'T2', 'T2_bias'
      #]
        self.all_modalities = [
            'ADC', 'FLAIR', 'T1', 'T1c', 'T2'
        ]

        # Use selected modalities or all
        self.modalities = selected_modalities if selected_modalities else self.all_modalities

        print(f"Using {len(self.modalities)} modalities: {self.modalities[:3]}... (showing first 3)")

    def __len__(self):
        return len(self.patient_list)

    def __getitem__(self, idx):
        patient_folder = self.patient_list[idx]
        patient_id = patient_folder.replace('_nifti', '')
        patient_path = os.path.join(self.base_path, patient_folder)
        seg_path = patient_path / f"{patient_id}_tumor_segmentation.nii.gz"

        try:
            seg_data = nib.load(str(seg_path)).get_fdata()
            seg_bin = (seg_data > 0).astype(np.uint8)
            tumor_area_by_slice = seg_bin.sum(axis=(0, 1))
            z = int(np.argmax(tumor_area_by_slice)) if tumor_area_by_slice.max() > 0 else seg_bin.shape[2] // 2
        except Exception as e:
            print(f"Warning: Could not load segmentation for {patient_folder}: {e}")

        # Load all modalities and stack as channels
        channels = []
        for modality in self.modalities:
            file_path = patient_path / f"{patient_id}_{modality}.nii.gz"

            try:
                # Load 3D volume
                img_3d = nib.load(file_path).get_fdata().astype(np.float32)

                # Extract 2D slice
                if img_3d.shape[2] > self.slice_idx:
                    img_2d = img_3d[:, :, self.slice_idx]
                else:
                    # If slice index out of range, use middle slice
                    img_2d = img_3d[:, :, img_3d.shape[2]//2]

                # Normalize to [0, 1]
                img_min, img_max = img_2d.min(), img_2d.max()
                if img_max > img_min:
                    img_2d = (img_2d - img_min) / (img_max - img_min)
                else:
                    img_2d = np.zeros_like(img_2d)

                channels.append(img_2d)

            except Exception as e:
                # If file missing or error, use zero-filled channel
                print(f"Warning: Could not load {modality} for {patient_folder}: {e}")
                channels.append(np.zeros((240, 240), dtype=np.float32))

        # Stack channels: (num_channels, H, W)
        img_tensor = np.stack(channels, axis=0)

        # Get label
        label = self.labels_dict[patient_folder]

        return torch.tensor(img_tensor, dtype=torch.float32), torch.tensor(label, dtype=torch.long)


# =============================================
# 3. CNN MODEL
# =============================================

class CNN2D(nn.Module):
    def __init__(self, n_channels=5, n_classes=3):
        """
        2D CNN for multi-channel MRI classification

        Args:
            n_channels: Number of input channels (MRI modalities)
            n_classes: Number of output classes (3 for WHO grades 2,3,4)
        """
        super().__init__()

        self.conv_layers = nn.Sequential(
            # Block 1: n_channels -> 32
            nn.Conv2d(n_channels, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 240x240 -> 120x120

            # Block 2: 32 -> 64
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 120x120 -> 60x60

            # Block 3: 64 -> 128
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 60x60 -> 30x30

            # Block 4: 128 -> 256
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 30x30 -> 15x15
        )

        # Fully connected layers
        self.fc_layers = nn.Sequential(
            nn.Linear(256 * 15 * 15, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, n_classes)
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc_layers(x)
        return x


# =============================================
# 4. TRAINING FUNCTION
# =============================================

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=5, device='cuda'):
    """
    Train the CNN model

    Returns:
        train_losses, val_losses, train_accs, val_accs
    """
    model = model.to(device)

    train_losses, val_losses = [], []
    train_accs, val_accs = [], []

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        train_loss = running_loss / len(train_loader)
        train_acc = 100. * correct / total
        train_losses.append(train_loss)
        train_accs.append(train_acc)

        # Validation phase
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)

                val_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

        val_loss = val_loss / len(val_loader)
        val_acc = 100. * correct / total
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

    return train_losses, val_losses, train_accs, val_accs


# =============================================
# 5. MAIN EXECUTION
# =============================================

if __name__ == "__main__":
    # Paths
    csv_path = "/content/drive/MyDrive/UCSF-PDGM-metadata_v5-checkpoint.csv"
    base_path = "/content/drive/Shareddrives/Fall2025BrainTumorDiagnosis/"

    # Load labels
    labels_dict, all_patients = load_labels(csv_path, base_path)

    # Get patients with labels
    patients_with_labels = list(labels_dict.keys())

    # Train/val split (80/20)
    train_patients, val_patients = train_test_split(
        patients_with_labels, test_size=0.2, random_state=42,
        stratify=[labels_dict[p] for p in patients_with_labels]
    )

    print(f"\nTrain: {len(train_patients)}, Val: {len(val_patients)}")

    # Create datasets
    train_dataset = MRIDataset(base_path, train_patients, labels_dict, slice_idx=120)
    val_dataset = MRIDataset(base_path, val_patients, labels_dict, slice_idx=120)

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=2)

    # Initialize model
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"\nUsing device: {device}")

    model = CNN2D(n_channels=5, n_classes=3)

    # Calculate class weights to handle imbalance
    from sklearn.utils.class_weight import compute_class_weight
    all_labels = [labels_dict[p] for p in patients_with_labels]
    class_weights = compute_class_weight(
        class_weight='balanced',
        classes=np.unique(all_labels),
        y=all_labels
    )
    class_weights = torch.FloatTensor(class_weights).to(device)
    print(f"\nClass weights (to handle imbalance): {class_weights}")
    print(f"  Grade 2 weight: {class_weights[0]:.2f}")
    print(f"  Grade 3 weight: {class_weights[1]:.2f}")
    print(f"  Grade 4 weight: {class_weights[2]:.2f}")


    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Train
    print("\nStarting training...")
    train_losses, val_losses, train_accs, val_accs = train_model(
        model, train_loader, val_loader, criterion, optimizer,
        num_epochs=20, device=device
    )

    # Plot results
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    ax1.plot(train_losses, label='Train Loss')
    ax1.plot(val_losses, label='Val Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.set_title('Training and Validation Loss')

    ax2.plot(train_accs, label='Train Acc')
    ax2.plot(val_accs, label='Val Acc')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.legend()
    ax2.set_title('Training and Validation Accuracy')

    plt.tight_layout()
    plt.show()

    print("\nTraining complete!")

Loaded 495 patients with labels
Class distribution: WHO CNS Grade
2     56
3     43
4    402
Name: count, dtype: int64

Train: 396, Val: 99
Using 5 modalities: ['ADC', 'FLAIR', 'T1']... (showing first 3)
Using 5 modalities: ['ADC', 'FLAIR', 'T1']... (showing first 3)

Using device: cpu

Starting training...
Epoch 1/20
  Train Loss: 3.3239, Train Acc: 64.39%
  Val Loss: 0.7581, Val Acc: 79.80%
Epoch 2/20
  Train Loss: 1.1228, Train Acc: 71.97%
  Val Loss: 0.6227, Val Acc: 79.80%
Epoch 3/20
  Train Loss: 0.7716, Train Acc: 76.77%
  Val Loss: 0.7543, Val Acc: 79.80%


KeyboardInterrupt: 

#Balanced Input

In [None]:
import os
import numpy as np
import pandas as pd
import nibabel as nib
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

# =============================================
# 1. LOAD LABELS AND FILTER DUPLICATES
# =============================================
def load_labels(csv_path, base_path):
    df = pd.read_csv(csv_path)
    df = df[~df['ID'].str.contains('_FU')].copy()
    patient_folders = sorted([f for f in os.listdir(base_path)
                              if os.path.isdir(os.path.join(base_path, f))])
    df['patient_num'] = df['ID'].str.extract(r'UCSF-PDGM-(\d+)')[0]
    df['folder_name'] = 'UCSF-PDGM-' + df['patient_num'].str.zfill(4) + '_nifti'
    df_matched = df[df['folder_name'].isin(patient_folders)]
    grade_to_class = {2:0, 3:1, 4:2}
    df_matched['class_label'] = df_matched['WHO CNS Grade'].map(grade_to_class)
    labels_dict = dict(zip(df_matched['folder_name'], df_matched['class_label']))
    print(f"Loaded {len(labels_dict)} patients with labels")
    print("Class distribution:", df_matched['WHO CNS Grade'].value_counts().sort_index())
    return labels_dict, patient_folders

# =============================================
# 2. DATASET CLASS WITH AUGMENTATION
# =============================================
class MRIDataset(Dataset):
    def __init__(self, base_path, patient_list, labels_dict, modalities=None, augment=False):
        self.base_path = base_path
        self.patient_list = patient_list
        self.labels_dict = labels_dict
        self.modalities = modalities if modalities else ['ADC','FLAIR','T1','T1c','T2']
        self.augment = augment

    def __len__(self):
        return len(self.patient_list)

    def __getitem__(self, idx):
        folder = self.patient_list[idx]
        pid = folder.replace('_nifti','')
        patient_path = os.path.join(self.base_path, folder)

        # ----------------------
        # Find z-slice with max tumor
        # ----------------------
        try:
            seg_file = os.path.join(patient_path, f"{pid}_tumor_segmentation.nii.gz")
            seg = nib.load(seg_file).get_fdata()
            seg_bin = (seg>0).astype(np.uint8)
            tumor_area_by_slice = seg_bin.sum(axis=(0,1))
            z = int(np.argmax(tumor_area_by_slice)) if tumor_area_by_slice.max()>0 else seg.shape[2]//2
        except:
            z = 120

        # ----------------------
        # Load MRI slices
        # ----------------------
        channels = []
        for mod in self.modalities:
            img_file = os.path.join(patient_path, f"{pid}_{mod}.nii.gz")
            try:
                img = nib.load(img_file).get_fdata().astype(np.float32)
                img2d = img[:,:,z] if img.shape[2] > z else img[:,:,img.shape[2]//2]
                img2d = (img2d - img2d.min()) / (img2d.max()-img2d.min()) if img2d.max()>img2d.min() else np.zeros_like(img2d)
            except:
                img2d = np.zeros((240,240), dtype=np.float32)
            channels.append(img2d)

        img_tensor = torch.tensor(np.stack(channels, axis=0), dtype=torch.float32)
        label = torch.tensor(self.labels_dict[folder], dtype=torch.long)

        # ----------------------
        # Augmentation for minority classes
        # ----------------------
        if self.augment and label.item() in [0,1]:  # Only for grade 2 and 3
            if np.random.rand() > 0.5:
                img_tensor = img_tensor.flip(-1)  # Horizontal flip
            if np.random.rand() > 0.5:
                img_tensor = img_tensor.flip(-2)  # Vertical flip
            k = np.random.choice([0,1,2,3])
            img_tensor = torch.rot90(img_tensor, k, dims=[1,2])  # Random 90 deg rotation

        return img_tensor, label

# =============================================
# 3. CNN MODEL
# =============================================
class CNN2D(nn.Module):
    def __init__(self, n_channels=5, n_classes=3):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(n_channels,32,3,padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32,64,3,padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64,128,3,padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(128,256,3,padding=1), nn.BatchNorm2d(256), nn.ReLU(), nn.MaxPool2d(2)
        )
        self.fc = nn.Sequential(
            nn.Linear(256*15*15,512), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(512,128), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(128,n_classes)
        )

    def forward(self,x):
        x = self.conv_layers(x)
        x = torch.flatten(x,1)
        x = self.fc(x)
        return x

# =============================================
# 4. TRAINING FUNCTION
# =============================================
def train_model_per_class(model, train_loader, val_loader, criterion, optimizer, device='cuda', epochs=20):
    model = model.to(device)
    n_classes = 3
    history = {
        'train_loss': [], 'val_loss': [],
        'train_acc': [], 'val_acc': [],
        'train_acc_per_class': [], 'val_acc_per_class': [],
        'train_loss_per_class': [], 'val_loss_per_class': []
    }

    for epoch in range(epochs):
        # --------- TRAIN ---------
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        class_correct = np.zeros(n_classes)
        class_total = np.zeros(n_classes)
        class_loss_sum = np.zeros(n_classes)

        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * labels.size(0)
            _, preds = outputs.max(1)
            correct += preds.eq(labels).sum().item()
            total += labels.size(0)

            for c in range(n_classes):
                idxs = (labels==c)
                class_total[c] += idxs.sum().item()
                class_correct[c] += (preds[idxs]==labels[idxs]).sum().item()
                if idxs.sum()>0:
                    class_loss_sum[c] += criterion(outputs[idxs], labels[idxs]).item()*idxs.sum().item()

        train_loss = running_loss/total
        train_acc = 100.*correct/total
        train_acc_pc = 100.*class_correct/np.maximum(class_total,1)
        train_loss_pc = class_loss_sum/np.maximum(class_total,1)

        # --------- VALIDATION ---------
        model.eval()
        running_loss = 0.0
        correct = 0
        total = 0
        class_correct = np.zeros(n_classes)
        class_total = np.zeros(n_classes)
        class_loss_sum = np.zeros(n_classes)

        with torch.no_grad():
            for imgs, labels in val_loader:
                imgs, labels = imgs.to(device), labels.to(device)
                outputs = model(imgs)
                loss = criterion(outputs, labels)

                running_loss += loss.item()*labels.size(0)
                _, preds = outputs.max(1)
                correct += preds.eq(labels).sum().item()
                total += labels.size(0)

                for c in range(n_classes):
                    idxs = (labels==c)
                    class_total[c] += idxs.sum().item()
                    class_correct[c] += (preds[idxs]==labels[idxs]).sum().item()
                    if idxs.sum()>0:
                        class_loss_sum[c] += criterion(outputs[idxs], labels[idxs]).item()*idxs.sum().item()

        val_loss = running_loss/total
        val_acc = 100.*correct/total
        val_acc_pc = 100.*class_correct/np.maximum(class_total,1)
        val_loss_pc = class_loss_sum/np.maximum(class_total,1)

        # Save history
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)
        history['train_acc_per_class'].append(train_acc_pc)
        history['val_acc_per_class'].append(val_acc_pc)
        history['train_loss_per_class'].append(train_loss_pc)
        history['val_loss_per_class'].append(val_loss_pc)

        # Print
        print(f"Epoch {epoch+1}/{epochs} | Train Loss {train_loss:.4f}, Acc {train_acc:.2f}% | Val Loss {val_loss:.4f}, Acc {val_acc:.2f}%")
        for c in range(n_classes):
            print(f"  Grade {c+2}: Train Acc {train_acc_pc[c]:.2f}%, Loss {train_loss_pc[c]:.4f} | Val Acc {val_acc_pc[c]:.2f}%, Loss {val_loss_pc[c]:.4f}")

    return history

# =============================================
# 5. MAIN EXECUTION
# =============================================
if __name__=="__main__":
    labels_dict, all_patients = load_labels(csv_path, base_path)
    patients = list(labels_dict.keys())

    train_patients, val_patients = train_test_split(
        patients, test_size=0.2, random_state=42,
        stratify=[labels_dict[p] for p in patients]
    )
    print(f"Train {len(train_patients)}, Val {len(val_patients)}")

    # Compute sampling weights for WeightedRandomSampler
    all_labels_train = [labels_dict[p] for p in train_patients]
    class_sample_count = np.array([all_labels_train.count(c) for c in np.unique(all_labels_train)])
    weights = 1. / class_sample_count
    samples_weight = np.array([weights[c] for c in all_labels_train])
    sampler = WeightedRandomSampler(weights=samples_weight, num_samples=len(samples_weight), replacement=True)

    train_dataset = MRIDataset(base_path, train_patients, labels_dict, augment=True)
    val_dataset = MRIDataset(base_path, val_patients, labels_dict, augment=False)

    train_loader = DataLoader(train_dataset, batch_size=8, sampler=sampler, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=2)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = CNN2D(n_channels=5, n_classes=3)

    # Class weights for loss
    all_labels = [labels_dict[p] for p in patients]
    class_weights = torch.FloatTensor(compute_class_weight('balanced', classes=np.unique(all_labels), y=all_labels)).to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Train
    history = train_model_per_class(model, train_loader, val_loader, criterion, optimizer, device=device, epochs=20)


Loaded 495 patients with labels
Class distribution: WHO CNS Grade
2     56
3     43
4    396
Name: count, dtype: int64
Train 396, Val 99
Epoch 1/20 | Train Loss 3.9023, Acc 38.13% | Val Loss 2.1232, Acc 11.11%
  Grade 2: Train Acc 41.09%, Loss 3.9063 | Val Acc 100.00%, Loss 0.1742
  Grade 3: Train Acc 56.72%, Loss 2.9159 | Val Acc 0.00%, Loss 1.9050
  Grade 4: Train Acc 16.54%, Loss 11.2427 | Val Acc 0.00%, Loss 3.6459
Epoch 2/20 | Train Loss 1.4194, Acc 37.63% | Val Loss 1.8829, Acc 9.09%
  Grade 2: Train Acc 44.92%, Loss 1.5234 | Val Acc 0.00%, Loss 1.5574
  Grade 3: Train Acc 55.30%, Loss 1.0483 | Val Acc 100.00%, Loss 0.4047
  Grade 4: Train Acc 15.75%, Loss 3.7271 | Val Acc 0.00%, Loss 3.1312
Epoch 3/20 | Train Loss 1.0473, Acc 42.68% | Val Loss 1.4654, Acc 16.16%
  Grade 2: Train Acc 49.30%, Loss 1.0346 | Val Acc 54.55%, Loss 0.8617
  Grade 3: Train Acc 52.71%, Loss 0.8612 | Val Acc 33.33%, Loss 0.7835
  Grade 4: Train Acc 24.80%, Loss 2.4763 | Val Acc 8.86%, Loss 2.3475
Epoch 4/

In [None]:
#balanced input + smooth out the class weights loss + k-fold stratified + metadata in the CNN model

import os
import numpy as np
import pandas as pd
import nibabel as nib
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler

# =============================================
# 1. LOAD LABELS AND METADATA
# =============================================
def load_labels(csv_path, base_path):
    df = pd.read_csv(csv_path)
    df = df[~df['ID'].str.contains('_FU')].copy()
    patient_folders = sorted([
        f for f in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, f))
    ])
    df['patient_num'] = df['ID'].str.extract(r'UCSF-PDGM-(\d+)')[0]
    df['folder_name'] = 'UCSF-PDGM-' + df['patient_num'].str.zfill(4) + '_nifti'
    df_matched = df[df['folder_name'].isin(patient_folders)]

    grade_to_class = {2: 0, 3: 1, 4: 2}
    df_matched['class_label'] = df_matched['WHO CNS Grade'].map(grade_to_class)

    # Encode metadata
    df_matched['Sex_encoded'] = df_matched['Sex'].apply(lambda x: 1 if x == "M" else 0)
    df_matched['IDH_encoded'] = df_matched['IDH'].apply(lambda x: 0 if x == "wildtype" else 1)
    df_matched['Codeletion_encoded'] = df_matched['1p/19q'].apply(lambda x: 1 if x == "co-deleted" else 0)
    df_matched['Age_scaled'] = StandardScaler().fit_transform(df_matched[['Age at MRI']])

    labels_dict = dict(zip(df_matched['folder_name'], df_matched['class_label']))
    metadata_dict = df_matched.set_index('folder_name')[[
        'Age_scaled', 'Sex_encoded', 'IDH_encoded', 'Codeletion_encoded'
    ]].to_dict(orient='index')

    print(f"Loaded {len(labels_dict)} patients with labels")
    print("Class distribution:")
    print(df_matched['WHO CNS Grade'].value_counts().sort_index())

    return labels_dict, metadata_dict, patient_folders

# =============================================
# 2. DATASET CLASS
# =============================================
class MRIDataset(Dataset):
    def __init__(self, base_path, patient_list, labels_dict, metadata_dict, modalities=None, augment=False):
        self.base_path = base_path
        self.patient_list = patient_list
        self.labels_dict = labels_dict
        self.metadata_dict = metadata_dict
        self.modalities = modalities if modalities else ['ADC', 'FLAIR', 'T1', 'T1c', 'T2']
        self.augment = augment

    def __len__(self):
        return len(self.patient_list)

    def __getitem__(self, idx):
        folder = self.patient_list[idx]
        pid = folder.replace('_nifti', '')
        patient_path = os.path.join(self.base_path, folder)

        # --- Find z-slice with max tumor ---
        try:
            seg_file = os.path.join(patient_path, f"{pid}_tumor_segmentation.nii.gz")
            seg = nib.load(seg_file).get_fdata()
            seg_bin = (seg > 0).astype(np.uint8)
            tumor_area_by_slice = seg_bin.sum(axis=(0, 1))
            z = int(np.argmax(tumor_area_by_slice)) if tumor_area_by_slice.max() > 0 else seg.shape[2] // 2
        except:
            z = 120

        # --- Load MRI slices ---
        channels = []
        for mod in self.modalities:
            img_file = os.path.join(patient_path, f"{pid}_{mod}.nii.gz")
            try:
                img = nib.load(img_file).get_fdata().astype(np.float32)
                img2d = img[:, :, z] if img.shape[2] > z else img[:, :, img.shape[2] // 2]
                img2d = (img2d - img2d.min()) / (img2d.max() - img2d.min()) if img2d.max() > img2d.min() else np.zeros_like(img2d)
            except:
                img2d = np.zeros((240, 240), dtype=np.float32)
            channels.append(img2d)

        img_tensor = torch.tensor(np.stack(channels, axis=0), dtype=torch.float32)
        label = torch.tensor(self.labels_dict[folder], dtype=torch.long)

        # --- Metadata ---
        meta_values = list(self.metadata_dict[folder].values())
        meta_tensor = torch.tensor(meta_values, dtype=torch.float32)

        # --- Augmentation ---
        if self.augment and label.item() in [0, 1]:
            if np.random.rand() > 0.5:
                img_tensor = img_tensor.flip(-1)
            if np.random.rand() > 0.5:
                img_tensor = img_tensor.flip(-2)
            k = np.random.choice([0, 1, 2, 3])
            img_tensor = torch.rot90(img_tensor, k, dims=[1, 2])

        return img_tensor, meta_tensor, label

# =============================================
# 3. CNN + METADATA MODEL
# =============================================
class CNN2D_Meta(nn.Module):
    def __init__(self, n_channels=5, n_meta=4, n_classes=3):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(n_channels, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(), nn.MaxPool2d(2)
        )
        self.flatten_dim = 256 * 15 * 15
        self.fc_img = nn.Sequential(
            nn.Linear(self.flatten_dim, 512), nn.ReLU(), nn.Dropout(0.5)
        )
        self.fc_meta = nn.Sequential(
            nn.Linear(n_meta, 32), nn.ReLU(), nn.Dropout(0.2)
        )
        self.fc_combined = nn.Sequential(
            nn.Linear(512 + 32, 128), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(128, n_classes)
        )

    def forward(self, x_img, x_meta):
        x_img = self.conv_layers(x_img)
        x_img = torch.flatten(x_img, 1)
        x_img = self.fc_img(x_img)
        x_meta = self.fc_meta(x_meta)
        x = torch.cat((x_img, x_meta), dim=1)
        out = self.fc_combined(x)
        return out

# =============================================
# 4. TRAINING FUNCTION WITH PER-GRADE METRICS
# =============================================
def train_model_per_class(model, train_loader, val_loader, criterion, optimizer, device='cuda', epochs=20):
    model.to(device)
    n_classes = 3
    for epoch in range(epochs):
        # ---------- TRAIN ----------
        model.train()
        running_loss, correct, total = 0, 0, 0
        class_correct, class_total, class_loss_sum = np.zeros(n_classes), np.zeros(n_classes), np.zeros(n_classes)

        for imgs, metas, labels in train_loader:
            imgs, metas, labels = imgs.to(device), metas.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(imgs, metas)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * labels.size(0)
            _, preds = outputs.max(1)
            correct += preds.eq(labels).sum().item()
            total += labels.size(0)

            for c in range(n_classes):
                idx = (labels == c)
                class_total[c] += idx.sum().item()
                if idx.sum() > 0:
                    class_correct[c] += (preds[idx] == labels[idx]).sum().item()
                    class_loss_sum[c] += criterion(outputs[idx], labels[idx]).item() * idx.sum().item()

        train_loss = running_loss / total
        train_acc = 100. * correct / total
        train_acc_pc = 100. * class_correct / np.maximum(class_total, 1)
        train_loss_pc = class_loss_sum / np.maximum(class_total, 1)

        # ---------- VALIDATION ----------
        model.eval()
        running_loss, correct, total = 0, 0, 0
        class_correct, class_total, class_loss_sum = np.zeros(n_classes), np.zeros(n_classes), np.zeros(n_classes)

        with torch.no_grad():
            for imgs, metas, labels in val_loader:
                imgs, metas, labels = imgs.to(device), metas.to(device), labels.to(device)
                outputs = model(imgs, metas)
                loss = criterion(outputs, labels)

                running_loss += loss.item() * labels.size(0)
                _, preds = outputs.max(1)
                correct += preds.eq(labels).sum().item()
                total += labels.size(0)

                for c in range(n_classes):
                    idx = (labels == c)
                    class_total[c] += idx.sum().item()
                    if idx.sum() > 0:
                        class_correct[c] += (preds[idx] == labels[idx]).sum().item()
                        class_loss_sum[c] += criterion(outputs[idx], labels[idx]).item() * idx.sum().item()

        val_loss = running_loss / total
        val_acc = 100. * correct / total
        val_acc_pc = 100. * class_correct / np.maximum(class_total, 1)
        val_loss_pc = class_loss_sum / np.maximum(class_total, 1)

        print(f"Epoch {epoch+1}/{epochs} | Train Loss {train_loss:.4f}, Acc {train_acc:.2f}% | Val Loss {val_loss:.4f}, Acc {val_acc:.2f}%")
        for c in range(n_classes):
            print(f"  Grade {c+2}: Train Acc {train_acc_pc[c]:.2f}%, Loss {train_loss_pc[c]:.4f} | Val Acc {val_acc_pc[c]:.2f}%, Loss {val_loss_pc[c]:.4f}")

# =============================================
# 5. MAIN EXECUTION WITH STRATIFIED K-FOLD
# =============================================
if __name__ == "__main__":
    labels_dict, metadata_dict, all_patients = load_labels(csv_path, base_path)
    patients = list(labels_dict.keys())
    labels = np.array([labels_dict[p] for p in patients])

    # --- Calibrate class weights ---
    unique_classes, class_counts = np.unique(labels, return_counts=True)
    class_weights = 1.0 / np.sqrt(class_counts)
    class_weights = class_weights / class_weights.sum() * len(unique_classes)
    print("Adjusted class weights:", class_weights)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # --- Stratified K-Fold ---
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    for fold, (train_idx, val_idx) in enumerate(skf.split(patients, labels)):
        print(f"\n===== Fold {fold+1} / 5 =====")

        train_patients = [patients[i] for i in train_idx]
        val_patients = [patients[i] for i in val_idx]

        all_labels_train = [labels_dict[p] for p in train_patients]
        class_sample_count = np.array([all_labels_train.count(c) for c in np.unique(all_labels_train)])
        weights = 1.0 / np.sqrt(class_sample_count)
        samples_weight = np.array([weights[c] for c in all_labels_train])
        sampler = WeightedRandomSampler(samples_weight, num_samples=len(samples_weight), replacement=True)

        train_dataset = MRIDataset(base_path, train_patients, labels_dict, metadata_dict, augment=True)
        val_dataset = MRIDataset(base_path, val_patients, labels_dict, metadata_dict, augment=False)

        train_loader = DataLoader(train_dataset, batch_size=8, sampler=sampler, num_workers=2)
        val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=2)

        model = CNN2D_Meta(n_channels=5, n_meta=4, n_classes=3).to(device)
        weight_tensor = torch.FloatTensor(class_weights).to(device)
        criterion = nn.CrossEntropyLoss(weight=weight_tensor)
        optimizer = optim.Adam(model.parameters(), lr=0.001)

        train_model_per_class(model, train_loader, val_loader, criterion, optimizer, device=device, epochs=20)



Loaded 495 patients with labels
Class distribution:
WHO CNS Grade
2     56
3     43
4    396
Name: count, dtype: int64
Adjusted class weights: [1.19177886 1.36005184 0.44816929]

===== Fold 1 / 5 =====
Epoch 1/20 | Train Loss 5.2270, Acc 40.91% | Val Loss 1.7112, Acc 24.24%
  Grade 2: Train Acc 27.59%, Loss 6.0375 | Val Acc 63.64%, Loss 0.7162
  Grade 3: Train Acc 36.36%, Loss 6.1647 | Val Acc 0.00%, Loss 1.2538
  Grade 4: Train Acc 47.41%, Loss 3.8774 | Val Acc 21.25%, Loss 2.0504


#Balanced Input and Taking Multiple Slices

In [None]:
for g in optimizer.param_groups:
    g['lr'] = g['lr'] * 0.3
print("Reduced LR:", optimizer.param_groups[0]['lr'])

#Save the model to Google drive

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
save_path = "/content/drive/MyDrive/brain_tumor_cnn/model_final.pth"
torch.save(model.state_dict(), save_path)
print("Model saved to:", save_path)