# **AI-Powered Brain Tumor & Dementia Training Suite**

This notebook trains three deep learning models:

1.  **Gatekeeper:** A ResNet50 classifier (Normal vs. Tumor vs. Dementia).
2.  **Tumor Specialist:** An EfficientNet-B3 model (Glioma vs. Meningioma vs. Pituitary).
3.  **Dementia Specialist:** A MobileNetV3-Large model (Mild vs. Moderate vs. Very Mild).

## **Instructions**
1.  **Upload Data:** Upload the `data.zip` file generated by the CLI tool to the Files tab on the left.
2.  **Runtime:** Go to `Runtime` -> `Change runtime type` -> Select `T4 GPU`.
3.  **Run All:** Click `Runtime` -> `Run all`.
4.  **Download:** Once finished, download the `trained_models.zip` file containing your new models and training graphs.

In [None]:
# Check for GPU
# Report which CUDA device training will use, or tell the user how to enable one.
import torch

print(
    f"GPU is available: {torch.cuda.get_device_name(0)}"
    if torch.cuda.is_available()
    else "WARNING: GPU not found. Runtime -> Change runtime type -> T4 GPU"
)

In [None]:
# Install dependencies (if needed).
# `!` is IPython shell magic: this cell only runs inside Jupyter/Colab, not as a plain .py script.
# NOTE(review): tqdm is installed but no cell shown here imports it — confirm it is still needed.
!pip install torch torchvision matplotlib tqdm

In [None]:
# Unzip Data
# Extract data.zip into the working directory unless a 'data' folder already exists.
# Missing input raises instead of printing, so "Run all" stops here rather than
# continuing into the training cell with no images.
import os
import zipfile

if os.path.exists('data'):
    print("Data folder already exists.")
elif os.path.exists('data.zip'):
    print("Unzipping data.zip...")
    with zipfile.ZipFile('data.zip', 'r') as zip_ref:
        zip_ref.extractall('.')
    if not os.path.isdir('data'):
        # The training cell reads images from ./data, so the archive must create it.
        raise FileNotFoundError(
            "data.zip was extracted but no 'data' folder was created; "
            "the archive must contain a top-level 'data/' directory."
        )
    print("Data extracted successfully.")
else:
    raise FileNotFoundError("ERROR: 'data.zip' not found. Please upload it to the Files tab.")

In [None]:
# Define Gatekeeper Model (Inline)
import torch.nn as nn
from torchvision import models

class GatekeeperClassifier(nn.Module):
    """ResNet50-based triage classifier (Normal vs. Tumor vs. Dementia).

    Loads an ImageNet-pretrained ResNet50 from torchvision, optionally freezes
    every backbone parameter, and replaces the final fully connected layer with
    a small MLP head (Linear -> ReLU -> Dropout(0.3) -> Linear).

    Args:
        num_classes: Number of output logits produced by the new head.
        freeze_base: If True, set ``requires_grad=False`` on all backbone
            parameters *before* the head is attached, so only the new head is
            trainable.
    """

    def __init__(self, num_classes=3, freeze_base=True):
        super().__init__()
        try:
            # Multi-weight enum API of newer torchvision releases.
            self.base_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        except AttributeError:
            # Older torchvision has no ``ResNet50_Weights``; use the legacy flag.
            # Only AttributeError is caught so that real failures (e.g. a failed
            # weight download or KeyboardInterrupt) surface unmasked.
            self.base_model = models.resnet50(pretrained=True)

        if freeze_base:
            for param in self.base_model.parameters():
                param.requires_grad = False

        # The replacement head is created after freezing, so its parameters keep
        # the default requires_grad=True.
        num_ftrs = self.base_model.fc.in_features
        self.base_model.fc = nn.Sequential(
            nn.Linear(num_ftrs, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return the head's unnormalized logits (last dimension = ``num_classes``)."""
        return self.base_model(x)

In [None]:
# Main Training Script (Inline)
import os
import glob
import time
import datetime
import logging
import json
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
from PIL import Image

# Logging Setup
# basicConfig does nothing if the root logger already has handlers; notebook
# front-ends may pre-install one, which would silently hide the INFO-level
# training progress. force=True (Python 3.8+) replaces any existing handlers.
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s] [%(levelname)s] %(message)s',
    datefmt='%H:%M:%S',
    force=True,
)
logger = logging.getLogger()

# Configuration (shared by all three models)
BATCH_SIZE = 32
NUM_EPOCHS = 5
LEARNING_RATE = 0.001  # Adam learning rate
IMG_SIZE = 224  # images are resized to IMG_SIZE x IMG_SIZE
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Paths
DATA_ROOT = 'data'  # populated by the "Unzip Data" cell
MODELS_DIR = 'models'  # best checkpoints (state dicts)
LOGS_DIR = 'training_logs'  # history JSON + loss/accuracy plots
os.makedirs(MODELS_DIR, exist_ok=True)
os.makedirs(LOGS_DIR, exist_ok=True)

# Dataset Class
class MedicalImageDataset(Dataset):
    """Map-style dataset pairing image files with integer class labels.

    Args:
        file_paths: Sequence of image file paths.
        labels: Sequence of integer labels aligned index-by-index with
            ``file_paths``.
        transform: Optional callable applied to the RGB PIL image.

    Files that cannot be opened/decoded are logged and replaced by an all-zero
    ``(3, IMG_SIZE, IMG_SIZE)`` tensor, so one bad file does not abort an epoch.
    Errors raised by ``transform`` itself are NOT swallowed: they indicate a
    pipeline bug that would otherwise zero out every sample silently.

    Raises:
        ValueError: if ``file_paths`` and ``labels`` differ in length.
    """

    def __init__(self, file_paths, labels, transform=None):
        if len(file_paths) != len(labels):
            raise ValueError(
                f"file_paths ({len(file_paths)}) and labels ({len(labels)}) must have the same length"
            )
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` where label is a 0-dim ``torch.long`` tensor."""
        path = self.file_paths[idx]
        target = torch.tensor(self.labels[idx], dtype=torch.long)
        try:
            # The context manager closes the file handle once pixels are decoded;
            # convert() returns an independent in-memory image.
            with Image.open(path) as img:
                image = img.convert('RGB')
        except Exception as e:
            # Best-effort: a corrupt/unreadable file becomes a blank sample.
            logger.error(f"Error loading image {path}: {e}")
            return torch.zeros((3, IMG_SIZE, IMG_SIZE)), target
        if self.transform:
            image = self.transform(image)
        return image, target

# Helpers
def format_time(seconds):
    """Format a duration, truncated to whole seconds, using timedelta's string form.

    Example: 3661 -> '1:01:01'.
    """
    whole_seconds = int(seconds)
    elapsed = datetime.timedelta(seconds=whole_seconds)
    return f"{elapsed}"

def get_transforms():
    """Build the ``(train, validation)`` preprocessing pipelines.

    Both pipelines resize to ``IMG_SIZE x IMG_SIZE``, convert to a tensor, and
    normalize with the ImageNet channel mean/std expected by the pretrained
    backbones. The training pipeline additionally applies a random horizontal
    flip and a random rotation in [-10, 10] degrees between resize and tensor
    conversion.
    """
    resize = transforms.Resize((IMG_SIZE, IMG_SIZE))
    tensor_steps = [
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    augmentations = [transforms.RandomHorizontalFlip(), transforms.RandomRotation(10)]

    train_pipeline = transforms.Compose([resize, *augmentations, *tensor_steps])
    val_pipeline = transforms.Compose([resize, *tensor_steps])
    return train_pipeline, val_pipeline

def plot_training_history(history, model_name):
    """Save loss and accuracy curves for one training run as PNGs in LOGS_DIR.

    Args:
        history: dict of per-epoch lists under 'train_loss', 'val_loss',
            'train_acc' and 'val_acc'.
        model_name: Used in plot titles; spaces become underscores in the
            output file names (``<name>_loss.png``, ``<name>_accuracy.png``).
    """
    epoch_axis = range(1, len(history['train_loss']) + 1)
    file_stem = model_name.replace(" ", "_")
    panels = (
        ('loss', 'Loss', 'Training Loss', 'Validation Loss'),
        ('acc', 'Accuracy', 'Training Accuracy', 'Validation Accuracy'),
    )
    for key, metric, train_label, val_label in panels:
        plt.figure(figsize=(10, 5))
        plt.plot(epoch_axis, history[f'train_{key}'], label=train_label)
        plt.plot(epoch_axis, history[f'val_{key}'], label=val_label)
        plt.title(f'{model_name} - {metric}')
        plt.xlabel('Epochs')
        plt.ylabel(metric)
        plt.legend()
        plt.savefig(os.path.join(LOGS_DIR, f'{file_stem}_{metric.lower()}.png'))
        plt.close()

def _run_epoch(model, loader, criterion, optimizer=None):
    """Run one pass over ``loader``: train if ``optimizer`` is given, else evaluate.

    Returns:
        ``(mean_loss, accuracy)`` over the samples actually seen; ``(0.0, 0.0)``
        if the loader yields no samples (avoids a ZeroDivisionError).
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    running_loss = 0.0
    correct = 0
    total = 0
    # Gradients are only tracked during the training pass.
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            # criterion returns a batch mean; re-weight by batch size so the
            # epoch average is per-sample.
            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        return 0.0, 0.0
    return running_loss / total, correct / total


def train_and_save(model, train_loader, val_loader, save_path, num_classes, model_name):
    """Train ``model`` with Adam + cross-entropy and keep the best checkpoint.

    After each of NUM_EPOCHS epochs the model is evaluated on ``val_loader``;
    whenever validation accuracy improves, ``model.state_dict()`` is written to
    ``save_path``. At the end, the per-epoch history is written to
    ``LOGS_DIR/<model_name>_history.json`` and plotted.

    Args:
        model: The ``nn.Module`` to train (moved to DEVICE).
        train_loader: DataLoader yielding ``(inputs, labels)`` training batches.
        val_loader: DataLoader yielding ``(inputs, labels)`` validation batches.
        save_path: File path for the best ``state_dict``.
        num_classes: Unused; kept for backward compatibility with callers.
        model_name: Human-readable name for logs; spaces become underscores in
            log file names.

    Returns:
        dict of per-epoch lists: 'train_loss', 'train_acc', 'val_loss', 'val_acc'.
    """
    logger.info(f"STARTING TRAINING: {model_name} on {DEVICE}")
    model = model.to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint (with 0.0, an all-zero validation accuracy never saved a file).
    best_acc = -1.0
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    for epoch in range(NUM_EPOCHS):
        epoch_start = time.time()
        epoch_loss, epoch_acc = _run_epoch(model, train_loader, criterion, optimizer)
        val_loss, val_acc = _run_epoch(model, val_loader, criterion)

        history['train_loss'].append(epoch_loss)
        history['train_acc'].append(epoch_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        logger.info(
            f"Epoch {epoch+1}/{NUM_EPOCHS}: Train Loss={epoch_loss:.4f}, Acc={epoch_acc:.4f} | "
            f"Val Loss={val_loss:.4f}, Acc={val_acc:.4f} | Time={format_time(time.time() - epoch_start)}"
        )

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), save_path)  # state dict is safer for portability
            logger.info(f"New best model saved! (Acc: {best_acc:.4f})")

    file_stem = model_name.replace(" ", "_")
    with open(os.path.join(LOGS_DIR, f'{file_stem}_history.json'), 'w') as f:
        json.dump(history, f, indent=4)
    plot_training_history(history, model_name)
    return history

def gather_files(data_root=None):
    """Collect image file paths for every class used by the three models.

    Expected layout under the root::

        brain_tumor/{Training,Testing}/{glioma,meningioma,pituitary,notumor}/*
        alzheimers/{MildDemented,ModerateDemented,VeryMildDemented,NonDemented}/*

    Args:
        data_root: Directory holding ``brain_tumor/`` and ``alzheimers/``.
            Defaults to the module-level ``DATA_ROOT``.

    Returns:
        dict mapping class key -> list of file paths. Each folder's listing is
        sorted, so the order is deterministic across filesystems. Tumor classes
        concatenate the ``Training`` then ``Testing`` folders. Missing folders
        yield empty lists.
    """
    root = DATA_ROOT if data_root is None else data_root

    def get_files(*parts):
        # glob.escape keeps '[', '*', '?' in the directory path literal; only the
        # trailing '*' is a wildcard. Sub-directories are skipped because they
        # would later fail to open as images.
        pattern = os.path.join(glob.escape(os.path.join(root, *parts)), '*')
        return sorted(p for p in glob.glob(pattern) if os.path.isfile(p))

    def tumor_class(folder):
        return (get_files('brain_tumor', 'Training', folder)
                + get_files('brain_tumor', 'Testing', folder))

    return {
        'glioma': tumor_class('glioma'),
        'meningioma': tumor_class('meningioma'),
        'pituitary': tumor_class('pituitary'),
        'tumor_normal': tumor_class('notumor'),
        'mild': get_files('alzheimers', 'MildDemented'),
        'moderate': get_files('alzheimers', 'ModerateDemented'),
        'very_mild': get_files('alzheimers', 'VeryMildDemented'),
        'alz_normal': get_files('alzheimers', 'NonDemented'),
    }

# Execution Block
# Fixed seed so the 80/20 train/validation split is reproducible across runs.
SPLIT_SEED = 42
TRAIN_FRACTION = 0.8


def _paths_and_labels(groups):
    """Flatten ``[(label, paths), ...]`` into parallel path and label lists."""
    paths, labels = [], []
    for label, group in groups:
        paths.extend(group)
        labels.extend([label] * len(group))
    return paths, labels


def _make_loaders(paths, labels, train_transform, val_transform, model_name):
    """Split ``paths``/``labels`` 80/20 with a seeded RNG and wrap each part in a DataLoader.

    Raises:
        RuntimeError: if fewer than two images were found (both splits must be
            non-empty for training/validation to be meaningful).
    """
    if len(paths) < 2:
        raise RuntimeError(
            f"{model_name}: found only {len(paths)} image(s) under '{DATA_ROOT}'. "
            "Check that data.zip was uploaded and extracted correctly."
        )
    n_train = int(TRAIN_FRACTION * len(paths))
    generator = torch.Generator().manual_seed(SPLIT_SEED)
    train_idx, val_idx = random_split(
        range(len(paths)), [n_train, len(paths) - n_train], generator=generator
    )

    def subset(indices, transform):
        return MedicalImageDataset(
            [paths[i] for i in indices], [labels[i] for i in indices], transform=transform
        )

    train_dl = DataLoader(subset(train_idx.indices, train_transform), batch_size=BATCH_SIZE, shuffle=True)
    val_dl = DataLoader(subset(val_idx.indices, val_transform), batch_size=BATCH_SIZE, shuffle=False)
    return train_dl, val_dl


files = gather_files()
train_tf, val_tf = get_transforms()

# 1. Gatekeeper: Normal (0) vs. Tumor (1) vs. Dementia (2)
gk_paths, gk_labels = _paths_and_labels([
    (0, files['tumor_normal'] + files['alz_normal']),
    (1, files['glioma'] + files['meningioma'] + files['pituitary']),
    (2, files['mild'] + files['moderate'] + files['very_mild']),
])
train_loader, val_loader = _make_loaders(gk_paths, gk_labels, train_tf, val_tf, "Gatekeeper")
train_and_save(GatekeeperClassifier(num_classes=3), train_loader, val_loader,
               os.path.join(MODELS_DIR, 'gatekeeper_classifier.pt'), 3, "Gatekeeper")

# 2. Tumor: glioma (0) vs. meningioma (1) vs. pituitary (2)
tm_paths, tm_labels = _paths_and_labels([
    (0, files['glioma']),
    (1, files['meningioma']),
    (2, files['pituitary']),
])
train_loader, val_loader = _make_loaders(tm_paths, tm_labels, train_tf, val_tf, "Tumor Specialist")
tm_model = models.efficientnet_b3(weights=models.EfficientNet_B3_Weights.DEFAULT)
tm_model.classifier[1] = nn.Linear(tm_model.classifier[1].in_features, 3)
train_and_save(tm_model, train_loader, val_loader,
               os.path.join(MODELS_DIR, 'brain_tumor_classifier.pt'), 3, "Tumor Specialist")

# 3. Dementia: mild (0) vs. moderate (1) vs. very mild (2)
dm_paths, dm_labels = _paths_and_labels([
    (0, files['mild']),
    (1, files['moderate']),
    (2, files['very_mild']),
])
train_loader, val_loader = _make_loaders(dm_paths, dm_labels, train_tf, val_tf, "Dementia Specialist")
dm_model = models.mobilenet_v3_large(weights=models.MobileNet_V3_Large_Weights.DEFAULT)
dm_model.classifier[3] = nn.Linear(dm_model.classifier[3].in_features, 3)
train_and_save(dm_model, train_loader, val_loader,
               os.path.join(MODELS_DIR, 'alzheimers_classifier.pt'), 3, "Dementia Specialist")

In [None]:
# Zip Results
!zip -r trained_models.zip models/ training_logs/
from google.colab import files
files.download('trained_models.zip')