# BrainGuard - Alzheimer's MRI Detection
## Google Colab Training Notebook

In [None]:
# Clone the project and make its root the working directory, so that
# `models.cnn_model` is importable and the relative 'Assets/...' / 'models/...' paths resolve.
!git clone https://github.com/CreativeDragon1/BrainGuard.git
%cd BrainGuard

In [None]:
# Install the libraries the cells below import (-q keeps pip output quiet).
# NOTE(review): torchaudio and scikit-learn are not imported by any cell in this notebook.
!pip install -q torch torchvision torchaudio pandas scikit-learn pillow tqdm matplotlib

### Upload Your Dataset

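Click **Choose Files** and upload `train.parquet` (required) and, optionally, `test.parquet`. Keep these exact file names, because the loading cell looks the files up by name.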

In [None]:
from google.colab import files
import os
import shutil

# Directory the data-loading cell reads train.parquet / test.parquet from.
DATASET_DIR = 'Assets/Datasets/MRI Dataset'

# Maps each uploaded file name to its contents. The move below relies on each
# file also being present under that name in the current working directory.
uploaded = files.upload()
os.makedirs(DATASET_DIR, exist_ok=True)

for f in uploaded:
    if f.endswith('.parquet'):
        shutil.move(f, os.path.join(DATASET_DIR, f))
        print(f'Uploaded: {f}')
    else:
        # Previously these were dropped silently; tell the user instead.
        print(f'Skipped (not a .parquet file): {f}')

# The loading cell requires train.parquet; surface a missing/misnamed upload now
# rather than as a FileNotFoundError several cells later.
if not os.path.exists(os.path.join(DATASET_DIR, 'train.parquet')):
    print(f"WARNING: '{DATASET_DIR}/train.parquet' not found - upload a file named exactly train.parquet.")

In [None]:
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from pathlib import Path
import numpy as np, pandas as pd, io
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
from models.cnn_model import ResNetModel

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(f'Device: {DEVICE}')

In [None]:
class MRIDataset(Dataset):
    """In-memory map-style dataset of encoded MRI images and their labels.

    Args:
        records: Sequence of encoded image bytes (anything ``PIL.Image.open`` can decode).
        labels: Sequence of class labels aligned index-for-index with ``records``.
        transform: Optional callable applied to each decoded grayscale PIL image.
            When omitted, the original pipeline is used: resize to 224x224,
            convert to a tensor, then normalize with a single-channel mean/std.
            Passing a custom transform lets the training set use augmentation
            while the validation set keeps the default.

    Raises:
        ValueError: If ``records`` and ``labels`` have different lengths.
    """

    def __init__(self, records, labels, transform=None):
        if len(records) != len(labels):
            raise ValueError(
                f'records and labels must have the same length, '
                f'got {len(records)} and {len(labels)}'
            )
        self.records = records
        self.labels = labels
        if transform is None:
            # NOTE(review): 0.485 / 0.229 look like the ImageNet red-channel
            # mean/std reused for a single gray channel - confirm intended.
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485], std=[0.229])
            ])
        self.transform = transform

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        """Decode record ``idx`` as 1-channel ('L') grayscale, transform it, return (image, label)."""
        # NOTE(review): images are 1-channel here; confirm ResNetModel accepts 1-channel input.
        img = Image.open(io.BytesIO(self.records[idx])).convert('L')
        return self.transform(img), self.labels[idx]

In [None]:
# Load data
train_path = Path('Assets/Datasets/MRI Dataset/train.parquet')
test_path = Path('Assets/Datasets/MRI Dataset/test.parquet')

VAL_FRACTION = 0.1  # share of train.parquet held out for validation
SPLIT_SEED = 42     # fixed seed so the train/val split is reproducible across runs

df_train = pd.read_parquet(train_path)
# test.parquet is optional. NOTE(review): df_test is not used by any later cell.
df_test = pd.read_parquet(test_path) if test_path.exists() else None

# Each 'image' cell holds the encoded image under the 'bytes' key. Iterating the
# column directly yields the same objects as iterrows() without building a Series per row.
train_recs = [img['bytes'] for img in df_train['image']]
train_lbls = df_train['label'].tolist()

# Seeded shuffle: the previous unseeded np.random.shuffle gave a different split every run.
idx = np.random.default_rng(SPLIT_SEED).permutation(len(train_recs))
# Keep at least one validation sample, so the per-epoch validation averages are never 0/0.
split = max(1, int(VAL_FRACTION * len(idx)))

train_ds = MRIDataset([train_recs[i] for i in idx[split:]], [train_lbls[i] for i in idx[split:]])
val_ds = MRIDataset([train_recs[i] for i in idx[:split]], [train_lbls[i] for i in idx[:split]])

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=32, shuffle=False, num_workers=2)

print(f'Train: {len(train_ds)}, Val: {len(val_ds)}')

In [None]:
# Training loop
NUM_EPOCHS = 50

model = ResNetModel(pretrained=True, num_classes=4).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}


def _run_epoch(net, loader, loss_fn, opt=None, desc=None):
    """Run one pass over ``loader``; train when ``opt`` is given, otherwise evaluate.

    Returns:
        (mean per-sample loss, accuracy) over every sample in ``loader``.
        Both are NaN when the loader yields no samples, instead of raising ZeroDivisionError.
    """
    training = opt is not None
    net.train(training)  # train(False) is equivalent to eval()
    total_loss = correct = seen = 0
    batches = tqdm(loader, desc=desc) if desc is not None else loader
    # Gradients are only tracked while training (replaces the separate no_grad block).
    with torch.set_grad_enabled(training):
        for imgs, lbls in batches:
            imgs, lbls = imgs.to(DEVICE), lbls.to(DEVICE)
            out = net(imgs)
            loss = loss_fn(out, lbls)
            if training:
                opt.zero_grad()
                loss.backward()
                opt.step()
            n = lbls.size(0)
            # Weight by batch size so a short final batch does not skew the epoch mean.
            total_loss += loss.item() * n
            correct += (out.argmax(1) == lbls).sum().item()
            seen += n
    if seen == 0:
        return float('nan'), float('nan')
    return total_loss / seen, correct / seen


best_val_acc = -1.0
best_state = None  # CPU copy of the weights with the highest validation accuracy so far

for epoch in range(NUM_EPOCHS):
    train_loss, train_acc = _run_epoch(model, train_loader, criterion, optimizer, desc=f'Ep {epoch+1}')
    val_loss, val_acc = _run_epoch(model, val_loader, criterion)

    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    # The weights are later saved as 'best_resnet.pth', so keep the best epoch, not the last.
    # A NaN val_acc (empty validation set) never compares greater, so nothing is stored.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    if (epoch + 1) % 5 == 0:
        print(f'Epoch {epoch+1}: TL={history["train_loss"][-1]:.4f} TA={history["train_acc"][-1]:.4f} VL={history["val_loss"][-1]:.4f} VA={history["val_acc"][-1]:.4f}')

# Leave `model` holding the best-validation weights for the save/export cells.
if best_state is not None:
    model.load_state_dict(best_state)
    print(f'Restored best weights (val acc {best_val_acc:.4f})')

print('Training complete!')

In [None]:
# Save model
# Only the parameters/buffers (state_dict) are written, not the pickled module object.
weights_path = Path('models') / 'best_resnet.pth'
weights_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), weights_path)
print('Model saved to models/best_resnet.pth')

In [None]:
# Plot results
# One panel per metric: (title, train key, train label, val key, val label).
panels = (
    ('Loss', 'train_loss', 'Train Loss', 'val_loss', 'Val Loss'),
    ('Accuracy', 'train_acc', 'Train Acc', 'val_acc', 'Val Acc'),
)

plt.figure(figsize=(12, 5))
for position, (title, train_key, train_label, val_key, val_label) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.plot(history[train_key], label=train_label)
    plt.plot(history[val_key], label=val_label)
    plt.title(title)
    plt.xlabel('Epoch')
    plt.legend()
    plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('results.png')
plt.show()

In [None]:
# Download files
from google.colab import files

# (progress message, path) pairs, fetched in order through the browser.
downloads = (
    ('Downloading model...', 'models/best_resnet.pth'),
    ('Downloading results...', 'results.png'),
)
for message, artifact in downloads:
    print(message)
    files.download(artifact)
print('Done! Files are ready in your Downloads folder')