# Alzheimer's MRI Detection - Colab Training
## Train the ResNet classifier and download it locally

### Step 1: Clone & Install

In [None]:
# Anchor to Colab's working root so re-running this cell does not clone
# BrainGuard *inside* BrainGuard (the previous %cd leaves us in the repo).
%cd /content
# Clone only if the checkout is not already there.
![ -d BrainGuard ] || git clone https://github.com/CreativeDragon1/BrainGuard.git
%cd /content/BrainGuard
!pip install -q torch torchvision torchaudio pandas scikit-learn pillow tqdm matplotlib

### Step 2: Upload Dataset

In [None]:
from google.colab import files
import os, shutil

print('Click Choose Files and upload: train.parquet and test.parquet')
# files.upload() returns {filename: bytes}; the original code also relies on
# the uploaded files being present in the current directory (moved below).
uploaded = files.upload()

os.makedirs('Assets/Datasets/MRI Dataset', exist_ok=True)
for f in uploaded:
    if f.endswith('.parquet'):
        shutil.move(f, f'Assets/Datasets/MRI Dataset/{f}')
        print(f'✓ {f} ready')
    else:
        # Make the skip visible instead of silently ignoring the file.
        print(f'Skipping {f}: not a .parquet file')

# Step 5 hard-requires train.parquet (test.parquet is optional), so only
# report success when it is actually in place under the expected name.
if os.path.exists('Assets/Datasets/MRI Dataset/train.parquet'):
    print('Dataset uploaded!')
else:
    print('WARNING: Assets/Datasets/MRI Dataset/train.parquet is missing. '
          'Upload a file named exactly train.parquet before running Step 5.')

### Step 3: Import & Setup

In [None]:
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from pathlib import Path
import numpy as np, pandas as pd, io, json
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

from models.cnn_model import ResNetModel

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Device: ' + str(DEVICE))

### Step 4: Define Classes

In [None]:
class MRIDataset(Dataset):
    """In-memory dataset of encoded MRI images and their class labels.

    Each record is decoded with PIL, converted to single-channel ('L'),
    resized to 224x224, converted to a tensor and normalised.

    Args:
        records: sequence of encoded image bytes (any format PIL can open).
        labels: sequence of labels aligned index-by-index with ``records``.
        train: marks a training split. By itself it does not change the
            output; augmentation additionally requires ``augment=True``.
        augment: opt-in light augmentation (random horizontal flip and a
            small random rotation), applied only when ``train`` is also True.
            Defaults to False so existing callers keep the deterministic
            pipeline.

    Raises:
        ValueError: if ``records`` and ``labels`` differ in length.
    """

    def __init__(self, records, labels, train=True, augment=False):
        if len(records) != len(labels):
            raise ValueError(
                f'records ({len(records)}) and labels ({len(labels)}) must have the same length'
            )
        self.records = records
        self.labels = labels
        self.train = train
        steps = [transforms.Resize((224, 224))]
        if train and augment:
            steps += [transforms.RandomHorizontalFlip(), transforms.RandomRotation(10)]
        steps += [
            transforms.ToTensor(),
            # Single-channel mean/std; these equal ImageNet's first-channel
            # statistics, presumably chosen for the pretrained backbone — TODO confirm.
            transforms.Normalize(mean=[0.485], std=[0.229]),
        ]
        self.transform = transforms.Compose(steps)

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        # convert() yields an independent image, so the source can be closed.
        with Image.open(io.BytesIO(self.records[idx])) as raw:
            img = raw.convert('L')
        return self.transform(img), self.labels[idx]

print('Dataset class ready')

In [None]:
class Trainer:
    """Train a classifier, validate after every epoch, checkpoint the best model.

    Args:
        model: network to train; moved to ``DEVICE``.
        train_loader: DataLoader yielding ``(images, labels)`` for training.
        val_loader: DataLoader yielding ``(images, labels)`` for validation.
        test_loader: optional; stored on the instance, not used by ``train``.
        epochs: number of passes over ``train_loader``.
        lr: Adam learning rate (weight decay is fixed at 1e-4).
        checkpoint_path: where ``save_model`` writes the state dict.
    """

    def __init__(self, model, train_loader, val_loader, test_loader=None, epochs=50, lr=1e-3,
                 checkpoint_path='models/best_resnet.pth'):
        self.model = model.to(DEVICE)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.epochs = epochs
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-4)
        self.history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
        self.checkpoint_path = Path(checkpoint_path)
        self.best_val_acc = float('-inf')
        self.best_epoch = None  # 1-based epoch of the saved checkpoint

    def _run_epoch(self, loader, training, desc):
        """Run one pass over *loader*; return (mean per-sample loss, accuracy).

        Raises:
            ValueError: if *loader* yields no samples.
        """
        self.model.train(training)  # nn.Module.train(False) == eval()
        loss_total = 0.0
        correct = 0
        count = 0
        with torch.set_grad_enabled(training):
            for imgs, lbls in tqdm(loader, desc=desc):
                imgs, lbls = imgs.to(DEVICE), lbls.to(DEVICE)
                if training:
                    self.optimizer.zero_grad()
                out = self.model(imgs)
                loss = self.criterion(out, lbls)
                if training:
                    loss.backward()
                    self.optimizer.step()
                n = lbls.size(0)
                # CrossEntropyLoss returns the batch mean; re-weight by batch
                # size so a short final batch does not skew the epoch average
                # (and so loss is averaged the same way as accuracy).
                loss_total += loss.item() * n
                correct += (out.argmax(1) == lbls).sum().item()
                count += n
        if count == 0:
            raise ValueError(f'{desc} loader produced no samples')
        return loss_total / count, correct / count

    def train_epoch(self):
        """Run one optimisation epoch; return (loss, accuracy)."""
        return self._run_epoch(self.train_loader, training=True, desc='Train')

    def val(self):
        """Evaluate on the validation loader without gradients; return (loss, accuracy)."""
        return self._run_epoch(self.val_loader, training=False, desc='Val')

    def train(self):
        """Train for ``self.epochs`` epochs, saving weights whenever val accuracy improves."""
        for ep in range(self.epochs):
            tl, ta = self.train_epoch()
            vl, va = self.val()
            self.history['train_loss'].append(tl)
            self.history['val_loss'].append(vl)
            self.history['train_acc'].append(ta)
            self.history['val_acc'].append(va)
            print(f'Ep {ep+1}/{self.epochs}: TL={tl:.4f} TA={ta:.4f} VL={vl:.4f} VA={va:.4f}')
            # The checkpoint is named "best", so keep the best validation
            # epoch rather than whatever the final epoch happens to be.
            if va > self.best_val_acc:
                self.best_val_acc = va
                self.best_epoch = ep + 1
                self.save_model()
        if self.best_epoch is None:
            # epochs == 0: still write a checkpoint, as train() always did.
            self.save_model()
        else:
            print(f'Best val acc {self.best_val_acc:.4f} at epoch {self.best_epoch}')

    def save_model(self, path=None):
        """Write the model's state dict to *path* (default: ``self.checkpoint_path``)."""
        target = Path(path) if path is not None else self.checkpoint_path
        target.parent.mkdir(parents=True, exist_ok=True)
        torch.save(self.model.state_dict(), target)
        print('Model saved!')

print('Trainer ready')

### Step 5: Load Data & Train

In [None]:
train_path = Path('Assets/Datasets/MRI Dataset/train.parquet')
test_path = Path('Assets/Datasets/MRI Dataset/test.parquet')

df_train = pd.read_parquet(train_path)
df_test = pd.read_parquet(test_path) if test_path.exists() else None

print(f'Train: {len(df_train)}, Test: {len(df_test) if df_test is not None else 0}')

# Each 'image' cell is indexed as a mapping holding encoded bytes under
# 'bytes' (same access as before); iterating the column avoids iterrows().
train_recs = [img['bytes'] for img in df_train['image']]
train_lbls = df_train['label'].tolist()

# Seeded permutation so the train/val split is reproducible across runs.
VAL_FRACTION = 0.1
SPLIT_SEED = 42
idx = np.random.default_rng(SPLIT_SEED).permutation(len(train_recs))
# At least one validation sample, so validation never runs on an empty loader.
split = max(1, int(VAL_FRACTION * len(idx)))

train_ds = MRIDataset([train_recs[i] for i in idx[split:]], [train_lbls[i] for i in idx[split:]])
val_ds = MRIDataset([train_recs[i] for i in idx[:split]], [train_lbls[i] for i in idx[:split]],
                    train=False)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=32, shuffle=False, num_workers=2)

# Optional held-out loader for final evaluation (None when no test.parquet).
test_loader = None
if df_test is not None:
    test_ds = MRIDataset([img['bytes'] for img in df_test['image']],
                         df_test['label'].tolist(), train=False)
    test_loader = DataLoader(test_ds, batch_size=32, shuffle=False, num_workers=2)

print('Data ready')

In [None]:
# Build the 4-class ResNet (pretrained=True) and run the full training loop.
model = ResNetModel(num_classes=4, pretrained=True)
trainer = Trainer(model, train_loader, val_loader, lr=1e-3, epochs=50)
trainer.train()

### Step 6: Download the Model and Training Curves

In [None]:
from google.colab import files

print('Downloading model...')
files.download('models/best_resnet.pth')

# Number epochs from 1 so the x-axis matches the "Ep k/N" training log.
epoch_axis = range(1, len(trainer.history['train_loss']) + 1)

plt.figure(figsize=(12, 5))
# One panel per metric: (panel title, history-key suffix).
for pos, (title, metric) in enumerate((('Loss', 'loss'), ('Accuracy', 'acc')), start=1):
    plt.subplot(1, 2, pos)
    plt.plot(epoch_axis, trainer.history[f'train_{metric}'], label='Train')
    plt.plot(epoch_axis, trainer.history[f'val_{metric}'], label='Val')
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(title)
    plt.legend()
    plt.grid(True, alpha=0.3)
plt.tight_layout()
# Save before show(): some backends clear the figure once it is displayed.
plt.savefig('training_results.png')
plt.show()

files.download('training_results.png')
print('Done!')