In [2]:
# =======================================
# 📦 Step 1: Imports and Setup
# =======================================
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.metrics import classification_report, confusion_matrix

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [4]:
from torch.utils.data import Dataset
from PIL import Image
import os

class BrainMRIDataset(Dataset):
    def __init__(self, root_dir, mode='Training', transform=None):
        """
        Args:
            root_dir (str): Path to 'split_data'
            mode (str): 'Training' or 'Testing'
            transform: torchvision transforms
        """
        self.root_dir = root_dir
        self.mode = mode
        self.transform = transform
        self.data = []

        # Labels: Real = 0, Fake = 1
        class_map = {'Real': 0, 'Fake': 1}

        for class_name, label in class_map.items():
            folder_path = os.path.join(root_dir, class_name, mode)
            for img_file in os.listdir(folder_path):
                img_path = os.path.join(folder_path, img_file)
                self.data.append((img_path, label))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path, label = self.data[idx]
        img = Image.open(img_path).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, label

In [5]:
# Step 3: Transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])
root_dir='./data/Dataset/'
# Load pre-split datasets from Step 2
train_dataset = BrainMRIDataset(root_dir=root_dir , mode='Training', transform=transform)
val_dataset = BrainMRIDataset(root_dir=root_dir , mode='Testing', transform=transform)

# Step 4: DataLoaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

In [6]:
from torchvision import models

# Load pre-trained ResNet-18 model
model = models.resnet18(pretrained=True)

# Modify the final fully connected layer to match the number of classes
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)  # Assuming binary classification (Real vs Fake)

# Move the model to the appropriate device (GPU or CPU)
model = model.to(device)



In [7]:
# =======================================
# 🏃 Step 5: Training & Validation Functions
# =======================================
from sklearn.metrics import accuracy_score

# =======================================
# 🏃 Step 5: Training & Validation Functions
# =======================================
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # You can try a higher lr for faster convergence
num_epochs = 10

def train(model, train_loader, val_loader, epochs=num_epochs):
    best_acc = 0.0
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0

        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        avg_loss = running_loss / len(train_loader)
        print(f"\nEpoch {epoch+1} - Training Loss: {avg_loss:.4f}")
        
        # Validation
        acc = evaluate(model, val_loader)
        
        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), "best_model.pth")
            print("✅ Best model saved!")

def evaluate(model, val_loader):
    model.eval()
    y_true, y_pred = [], []

    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc="Validating"):
            images = images.to(device)
            outputs = model(images)
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            y_pred.extend(preds)
            y_true.extend(labels.cpu().numpy())

    acc = accuracy_score(y_true, y_pred)
    print("\nClassification Report:\n", classification_report(y_true, y_pred))
    print("Confusion Matrix:\n", confusion_matrix(y_true, y_pred))
    print("Accuracy:", acc)
    return acc

In [None]:
# =======================================
# 🔥 Step 6: Start Training
# =======================================
train(model, train_loader, val_loader, epochs=num_epochs)

Epoch 1: 100%|███████████████████████████████████████████████████████████████████████| 399/399 [05:27<00:00,  1.22it/s]



Epoch 1 - Training Loss: 0.0037


Validating: 100%|██████████████████████████████████████████████████████████████████████| 91/91 [01:52<00:00,  1.23s/it]



Classification Report:
               precision    recall  f1-score   support

           0       1.00      1.00      1.00      3240
           1       1.00      1.00      1.00      2552

    accuracy                           1.00      5792
   macro avg       1.00      1.00      1.00      5792
weighted avg       1.00      1.00      1.00      5792

Confusion Matrix:
 [[3240    0]
 [   0 2552]]
Accuracy: 1.0
✅ Best model saved!


Epoch 2: 100%|███████████████████████████████████████████████████████████████████████| 399/399 [04:25<00:00,  1.51it/s]



Epoch 2 - Training Loss: 0.0000


Validating: 100%|██████████████████████████████████████████████████████████████████████| 91/91 [00:51<00:00,  1.76it/s]



Classification Report:
               precision    recall  f1-score   support

           0       1.00      1.00      1.00      3240
           1       1.00      1.00      1.00      2552

    accuracy                           1.00      5792
   macro avg       1.00      1.00      1.00      5792
weighted avg       1.00      1.00      1.00      5792

Confusion Matrix:
 [[3240    0]
 [   0 2552]]
Accuracy: 1.0


Epoch 3: 100%|███████████████████████████████████████████████████████████████████████| 399/399 [05:38<00:00,  1.18it/s]



Epoch 3 - Training Loss: 0.0000


Validating: 100%|██████████████████████████████████████████████████████████████████████| 91/91 [01:30<00:00,  1.01it/s]



Classification Report:
               precision    recall  f1-score   support

           0       1.00      1.00      1.00      3240
           1       1.00      1.00      1.00      2552

    accuracy                           1.00      5792
   macro avg       1.00      1.00      1.00      5792
weighted avg       1.00      1.00      1.00      5792

Confusion Matrix:
 [[3240    0]
 [   0 2552]]
Accuracy: 1.0


Epoch 4:  13%|█████████▍                                                              | 52/399 [00:58<06:21,  1.10s/it]

In [None]:
# =======================================
# 💾 Step 7: Save the Trained Model
# =======================================
if not os.path.exists('model'):
    os.makedirs('model')
torch.save(model.state_dict(), "model/best_resnet_model.pth")
print("Model saved!")