In [None]:
import os
import time
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torch.utils.data import Subset
from torchvision import datasets, transforms
from torchvision import models
from sklearn.metrics import precision_recall_fscore_support
from sklearn.utils.class_weight import compute_class_weight
from tqdm import tqdm

In [None]:
# Run on the GPU whenever PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)

In [None]:
# ======================
#  DATASET & NORMALIZATION
# ======================
train_dir = './data/Training'
TRAIN_FRACTION = 0.8  # share of train_dir used for training; the remainder is the validation split
SPLIT_SEED = 42       # fixed so the train/val split (and therefore the val metrics) is reproducible
# os.cpu_count() can return None; `or 1` keeps the worker count an int
NUM_WORKERS = min(2, (os.cpu_count() or 1) // 2)
PIN_MEMORY = device.type == "cuda"


def compute_mean_std(loader):
    """Per-channel pixel mean and unbiased std over every image yielded by ``loader``.

    Accumulates running sums batch by batch instead of concatenating the whole
    dataset into one tensor, so peak memory is a single batch.

    Args:
        loader: iterable of ``(images, labels)`` pairs; images must be 4-D
            (N, C, H, W), as produced by ToTensor + DataLoader batching.

    Returns:
        (mean, std): float32 tensors of shape (C,), equal (up to rounding) to
        ``imgs.mean([0, 2, 3])`` / ``imgs.std([0, 2, 3])`` on the concatenated images.

    Raises:
        ValueError: if fewer than two pixels per channel were seen.
    """
    channel_sum = channel_sq_sum = None
    n_pixels = 0
    for batch, _ in loader:
        batch = batch.double()  # float64 accumulators: summing many pixel values in float32 drifts
        batch_sum = batch.sum(dim=(0, 2, 3))
        batch_sq_sum = batch.pow(2).sum(dim=(0, 2, 3))
        channel_sum = batch_sum if channel_sum is None else channel_sum + batch_sum
        channel_sq_sum = batch_sq_sum if channel_sq_sum is None else channel_sq_sum + batch_sq_sum
        n_pixels += batch.numel() // batch.size(1)
    if n_pixels < 2:
        raise ValueError("compute_mean_std needs at least two pixels per channel")
    mean = channel_sum / n_pixels
    # (n - 1) denominator = unbiased estimate, the same as torch.Tensor.std's default
    var = (channel_sq_sum - n_pixels * mean.pow(2)) / (n_pixels - 1)
    return mean.float(), var.clamp_min(0).sqrt().float()


# Un-normalized grayscale view: used to draw the split and to measure pixel statistics
temp_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
temp_dataset = datasets.ImageFolder(train_dir, transform=temp_transform)

# Draw the split once with a seeded generator; the same indices are reused for every view below
train_size = int(TRAIN_FRACTION * len(temp_dataset))
val_size = len(temp_dataset) - train_size
train_split, val_split = random_split(
    temp_dataset, [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
train_indices, val_indices = list(train_split.indices), list(val_split.indices)

# Normalization statistics come from the training images only, so the validation split stays unseen
stats_loader = DataLoader(Subset(temp_dataset, train_indices), batch_size=64, shuffle=False,
                          num_workers=NUM_WORKERS)
mean, std = compute_mean_std(stats_loader)
print("Mean:", mean.item(), "Std:", std.item())

# Final transforms. Training gets mild augmentation (flip, small rotation, light brightness/contrast
# jitter) so the model is less sensitive to orientation and lighting; heavier augmentation risks
# distorting the images and slowing convergence. Validation is deterministic.
train_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.2),
    transforms.RandomRotation(degrees=10),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])
val_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

# Two views of the same folder: augmented for training, deterministic for validation.
# (Splitting a single dataset would make the validation subset inherit train_transform.)
full_dataset = datasets.ImageFolder(train_dir, transform=train_transform)
val_base_dataset = datasets.ImageFolder(train_dir, transform=val_transform)
train_dataset = Subset(full_dataset, train_indices)
val_dataset = Subset(val_base_dataset, val_indices)

# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

print(f"Train images: {len(train_dataset)}, Validation images: {len(val_dataset)}")
print("Classes:", full_dataset.classes)

In [None]:
# ======================
#  CLASS WEIGHTS
# ======================
num_classes = len(full_dataset.classes)
# 'balanced' weights computed from the training split's labels only
train_labels = [full_dataset.samples[i][1] for i in train_dataset.indices]
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(train_labels),
    y=train_labels
)
# CrossEntropyLoss needs exactly one weight per output class; fail early with a clear message
assert len(class_weights) == num_classes, (
    f"Training split is missing a class: got weights for {len(class_weights)} of {num_classes} classes"
)
weights = torch.tensor(class_weights, dtype=torch.float).to(device)
print("Class Weights:", weights)

# ======================
#  MODEL (ResNet-18 Transfer Learning)
# ======================
model = models.resnet18(weights='IMAGENET1K_V1')
# Replace the 3-channel stem with a 1-channel one. Instead of a random init, reuse the pretrained
# filters summed over RGB: that is the response the RGB stem would give to a gray image copied
# into all three channels.
pretrained_stem = model.conv1.weight.detach().clone()  # (64, 3, 7, 7)
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
with torch.no_grad():
    model.conv1.weight.copy_(pretrained_stem.sum(dim=1, keepdim=True))
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

# Freeze the two earliest residual stages; the stem, layer3, layer4 and fc keep training
for name, param in model.named_parameters():
    if "layer1" in name or "layer2" in name:
        param.requires_grad = False

# Class-weighted cross-entropy compensates for class imbalance in the training split
criterion = nn.CrossEntropyLoss(weight=weights)

# AdamW with a small LR for fine-tuning; only trainable parameters are handed to it.
# ReduceLROnPlateau multiplies the LR by 0.3 once val loss has not improved for more than 3 epochs.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable_params, lr=1e-4, weight_decay=1e-4)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.3, patience=3)

# ======================
#  VALIDATION FUNCTION
# ======================
def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` over ``val_loader`` with gradient tracking disabled.

    The model is switched to eval mode and left there.

    Returns:
        (mean_loss, accuracy): ``criterion`` averaged over batches (each batch
        counts equally, regardless of its size) and the fraction of samples
        whose highest-scoring class equals the label.
    """
    model.eval()
    batch_losses = []
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)
            batch_losses.append(criterion(logits, targets).item())
            _, predicted = logits.max(dim=1)
            n_correct += (predicted == targets).sum().item()
            n_seen += targets.size(0)
    return sum(batch_losses) / len(val_loader), n_correct / n_seen

# ======================
#  SAVE UTIL
# ======================
save_dir = "./data"


def next_version(prefix, ext, directory=None):
    """Return the next free version number for files named ``{prefix}_{N}{ext}``.

    Args:
        prefix: file-name stem in front of the ``_N`` version suffix.
        ext: file extension including the dot, e.g. ``".pth"``.
        directory: folder to scan; defaults to ``<save_dir>/models``. It is
            created if missing, so the very first run does not crash.

    Returns:
        One more than the largest existing N, or 1 when no matching file exists.
    """
    if directory is None:
        directory = os.path.join(save_dir, "models")
    os.makedirs(directory, exist_ok=True)
    head = prefix + "_"
    versions = []
    for fname in os.listdir(directory):
        if not (fname.startswith(head) and fname.endswith(ext)):
            continue
        # Only an all-decimal suffix counts, so e.g. "<prefix>_backup_9.pth" is ignored
        suffix = fname[len(head):len(fname) - len(ext)]
        if suffix.isdecimal():
            versions.append(int(suffix))
    return max(versions, default=0) + 1

# Take the first unused version number so earlier checkpoints are never overwritten
model_number = next_version("Brain_Tumor_Model_pretrained", ".pth")
model_path = os.path.join(save_dir, "models", f"Brain_Tumor_Model_pretrained_{model_number}.pth")
print(f"🧩 Saving model to: {model_path}")

# ======================
# TRAINING LOOP
# ======================
total_start = time.time()
best_val_loss = float('inf')
epochs_no_improve = 0
early_stop_patience = 7   # stop after this many consecutive epochs without a val-loss improvement
num_epochs = 30

for epoch in range(num_epochs):
    # validate() switches the model to eval mode, so re-enable training mode every epoch
    model.train()
    running_loss = 0.0
    loop = tqdm(train_loader, leave=True)
    loop.set_description(f"Epoch [{epoch+1}/{num_epochs}]")

    for imgs, labels in loop:
        # non_blocking lets the host->GPU copy overlap compute when the loader pins memory
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()  # one device sync per step instead of two
        running_loss += batch_loss
        loop.set_postfix(loss=batch_loss)

    avg_train_loss = running_loss / len(train_loader)
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    scheduler.step(val_loss)  # ReduceLROnPlateau tracks the validation loss

    print(f"Epoch [{epoch+1}/{num_epochs}] | Train Loss: {avg_train_loss:.4f} | "
          f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc*100:.2f}%")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
        # Only the best epoch is kept on disk. Scheduler state and class names are stored as well
        # so training can be resumed and predictions mapped back to label names.
        checkpoint = {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "epoch": epoch + 1,
            "val_loss": val_loss,
            "val_acc": val_acc,
            "classes": list(full_dataset.classes),
            "training_time": time.time() - total_start
        }
        torch.save(checkpoint, model_path)
        print(f"💾 New best model saved (Val Loss: {val_loss:.4f})")
    else:
        epochs_no_improve += 1
        print(f"🕒 No improvement for {epochs_no_improve} epoch(s)")
        if epochs_no_improve >= early_stop_patience:
            print(f"\n🛑 Early stopping at epoch {epoch+1}")
            break

print(f"\n✅ Total Training Time: {(time.time()-total_start)/60:.2f} min")

In [None]:
# ======================
# LOAD SPECIFIC SAVED MODEL
# ======================

import os
import torch
import torch.nn as nn
from torchvision import models
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

save_dir = "./data"
model_number = 7  # Change to the version you want to evaluate
model_path = os.path.join(save_dir, "models", f"Brain_Tumor_Model_pretrained_{model_number}.pth")

print(f"🧩 Loading pretrained ResNet-18 model from: {model_path}")

# Load the full checkpoint (weights + training metadata), not just a bare state_dict.
# NOTE(review): torch.load unpickles the file; only load checkpoints you produced yourself.
checkpoint = torch.load(model_path, map_location=device)
state_dict = checkpoint["model_state_dict"]

# Re-create the architecture used during training. No ImageNet weights are downloaded:
# every parameter is overwritten by the checkpoint, and the class count is read from the saved fc layer.
num_classes = state_dict["fc.weight"].shape[0]
model = models.resnet18(weights=None)
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
model.fc = nn.Linear(model.fc.in_features, num_classes)
model.load_state_dict(state_dict)
model = model.to(device)

training_time = checkpoint.get("training_time", 0)
epoch_trained = checkpoint.get("epoch", "N/A")

print(f"✅ Model loaded from checkpoint trained for {training_time/60:.2f} minutes "
      f"({epoch_trained} epochs)")

# Set to evaluation mode (BatchNorm uses its stored running statistics)
model.eval()
print("✅ Model ready for inference.")

In [None]:
# ======================
# TEST + REPORT
# ======================
# `mean` / `std` are the training statistics from the DATASET & NORMALIZATION cell;
# test images must be preprocessed exactly like the validation images were.
test_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

test_dir = "./data/Testing"
test_dataset = datasets.ImageFolder(root=test_dir, transform=test_transform)
# Same worker policy as the training loaders; pinned memory only helps when copying to a GPU
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False,
                         num_workers=min(2, (os.cpu_count() or 1) // 2),
                         pin_memory=(device.type == "cuda"))
class_names = test_dataset.classes
label_ids = list(range(len(class_names)))  # ImageFolder indexes classes 0..K-1 in class_names order
assert model.fc.out_features == len(class_names), (
    f"Model predicts {model.fc.out_features} classes but {test_dir} has {len(class_names)}"
)

model.eval()
y_true, y_pred = [], []

with torch.no_grad():
    for imgs, labels in test_loader:
        outputs = model(imgs.to(device))
        preds = outputs.argmax(dim=1)
        # Labels stay on the CPU: they are only compared with the predictions after the loop
        y_true.extend(labels.numpy())
        y_pred.extend(preds.cpu().numpy())

y_true, y_pred = np.array(y_true), np.array(y_pred)
acc = float((y_true == y_pred).mean())
print(f"\n✅ Test Accuracy: {acc*100:.2f}%")

# Classification report. Passing every label keeps the rows aligned with class_names
# even if some class never appears in y_true / y_pred (otherwise sklearn raises).
report = classification_report(y_true, y_pred, labels=label_ids, target_names=class_names)
print("\nClassification Report:\n", report)

reports_dir = os.path.join(save_dir, "reports")
os.makedirs(reports_dir, exist_ok=True)  # first run: the folder may not exist yet
report_path = os.path.join(reports_dir, f"eval_report_{model_number}.txt")

cm = confusion_matrix(y_true, y_pred, labels=label_ids)
fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set(title="Confusion Matrix", xlabel="Predicted Label", ylabel="True Label")
plt.show()

metrics = precision_recall_fscore_support(y_true, y_pred, labels=label_ids)
print("Per-class Precision:", metrics[0])
print("Per-class Recall:", metrics[1])

# ======================
# SAVE REPORT
# ======================
with open(report_path, "w", encoding="utf-8") as f:
    f.write(f"Model file: {os.path.basename(model_path)}\n")
    f.write(f"Total Training Time: {training_time/60:.2f} minutes\n")
    f.write(f"\nTest Accuracy: {acc*100:.2f}%\n\n")
    f.write(report)
print(f"📄 Report saved to: {report_path}")