<a href="https://colab.research.google.com/github/arogya-gyawali/brainscan_AI/blob/main/notebooks/data_prep/06_build_pytorch_dataloader.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🔄 06_build_pytorch_dataloader.ipynb: Custom Dataset and DataLoader

This notebook builds and tests a **PyTorch-compatible data loading pipeline** for the BrainScan AI project.

---

## 🧩 What It Does

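- Defines a custom `BrainMRIDataset` class that loads `.png` MRI images listed in the split metadata CSVs (`train.csv`, `val.csv`, `test.csv`)
- Applies torchvision transforms: resize to 224×224, conversion to tensor, and normalization to the range [-1, 1]
- Returns grayscale image batches as PyTorch tensors with shape `[batch_size, 1, 224, 224]`
- Verifies label loading, tensor shapes, and batch integrity by drawing one batch from the training `DataLoader`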

---

## 📤 Output

- `train_loader`, `val_loader`, and `test_loader` objects
- Ready-to-use data pipeline for training models

> ⚠️ Make sure the two image folder paths and the split CSV folder path used below point to the correct locations in your Google Drive.


In [1]:
# Mount Drive and import libraries
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [11]:
import os
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset

class BrainMRIDataset(Dataset):
    def __init__(self, dataframe, image_root_mri, image_root_healthy, transform=None):
        """
        Args:
            dataframe (pd.DataFrame): Split metadata with a 'file' column (image filename) and a 'label' column (integer class)
            image_root_mri (str): Path to the tumor MRI image folder
            image_root_healthy (str): Path to the healthy MRI image folder
            transform (callable, optional): Transform applied to each PIL image, e.g. a torchvision Compose
        """
        self.df = dataframe
        self.image_root_mri = image_root_mri
        self.image_root_healthy = image_root_healthy
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        filename = row["file"]

        # Healthy images are identified by "IXI" in the filename;
        # everything else is read from the tumor MRI folder
        if "IXI" in filename:
            img_path = os.path.join(self.image_root_healthy, filename)
        else:
            img_path = os.path.join(self.image_root_mri, filename)

        # Load as single-channel grayscale; the context manager closes the file handle
        with Image.open(img_path) as img:
            image = img.convert("L")

        if self.transform is not None:
            image = self.transform(image)

        label = int(row["label"])
        return image, label
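
Because `__getitem__` only touches a file when that sample is requested, a missing image can surface partway through an epoch. The sketch below is one possible way to check paths up front; it mirrors the `"IXI"` routing rule from the class above. The helper names `resolve_image_path` and `find_missing_files` are hypothetical and not part of the notebook.

```python
import os


def resolve_image_path(filename, image_root_mri, image_root_healthy):
    """Pick the folder the same way BrainMRIDataset.__getitem__ does."""
    root = image_root_healthy if "IXI" in filename else image_root_mri
    return os.path.join(root, filename)


def find_missing_files(filenames, image_root_mri, image_root_healthy):
    """Return the filenames whose resolved path does not exist on disk."""
    return [
        name
        for name in filenames
        if not os.path.exists(
            resolve_image_path(name, image_root_mri, image_root_healthy)
        )
    ]
```

Assuming a split DataFrame with a `file` column, `find_missing_files(train_df["file"], image_root_mri, image_root_healthy)` should return an empty list when every listed image is present.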


In [12]:
from torch.utils.data import DataLoader
from torchvision import transforms

# === Paths ===
image_root_mri = "/content/drive/MyDrive/BrainScanAI/BrainScanAI_final_output/mri"
image_root_healthy = "/content/drive/MyDrive/BrainScanAI/BrainScanAI_final_output/no_tumor_IXI_png"

# Load split metadata
split_path = "/content/drive/MyDrive/BrainScanAI/splits"
train_df = pd.read_csv(os.path.join(split_path, "train.csv"))
val_df   = pd.read_csv(os.path.join(split_path, "val.csv"))
test_df  = pd.read_csv(os.path.join(split_path, "test.csv"))

# Define transforms: resize to 224x224, scale pixels to [0, 1], then normalize to [-1, 1]
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# === Create Dataset Instances ===
train_dataset = BrainMRIDataset(train_df, image_root_mri, image_root_healthy, transform)
val_dataset   = BrainMRIDataset(val_df, image_root_mri, image_root_healthy, transform)
test_dataset  = BrainMRIDataset(test_df, image_root_mri, image_root_healthy, transform)

# === Create DataLoaders ===
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
val_loader   = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2)
test_loader  = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=2)
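
For a grayscale PIL image, `ToTensor()` scales 8-bit pixel values to [0, 1], and `Normalize(mean=[0.5], std=[0.5])` then computes `(x - 0.5) / 0.5`, so the tensors produced by the loaders lie in [-1, 1]. A minimal pure-Python sketch of that arithmetic for a single pixel follows; the function names are illustrative and are not torchvision APIs.

```python
def normalize_pixel(value_8bit, mean=0.5, std=0.5):
    """Map an 8-bit intensity (0-255) the way ToTensor + Normalize would."""
    scaled = value_8bit / 255.0       # ToTensor: [0, 255] -> [0.0, 1.0]
    return (scaled - mean) / std      # Normalize: (x - mean) / std


def denormalize_value(value, mean=0.5, std=0.5):
    """Invert Normalize to recover a [0, 1] intensity, e.g. for plotting."""
    return value * std + mean


print(normalize_pixel(0))        # black pixel
print(normalize_pixel(255))      # white pixel
print(denormalize_value(-1.0))   # back to the [0, 1] scale
```

The inverse mapping is handy when displaying a batch, since plotting libraries generally expect intensities in [0, 1] and not in [-1, 1].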


In [13]:
images, labels = next(iter(train_loader))
print("Batch shape:", images.shape)   # (batch_size, 1, 224, 224)
print("Labels:", labels)


Batch shape: torch.Size([32, 1, 224, 224])
Labels: tensor([0, 0, 0, 0, 2, 2, 0, 2, 3, 1, 2, 2, 0, 1, 3, 1, 1, 3, 0, 2, 3, 3, 2, 3,
        2, 0, 2, 0, 2, 2, 3, 3])
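
A single batch also gives a quick, rough look at which label values occur. The sketch below tallies the labels printed above using only the standard library; with the real loader, `labels.tolist()` would supply the list. The variable name `batch_labels` is illustrative.

```python
from collections import Counter

# Labels copied from the batch printed above
batch_labels = [0, 0, 0, 0, 2, 2, 0, 2, 3, 1, 2, 2, 0, 1, 3, 1, 1, 3, 0, 2, 3, 3, 2, 3,
                2, 0, 2, 0, 2, 2, 3, 3]

label_counts = Counter(batch_labels)
for label in sorted(label_counts):
    share = label_counts[label] / len(batch_labels)
    print(f"label {label}: {label_counts[label]} samples ({share:.0%})")
```

One shuffled batch of 32 is a small sample, so treat these counts as a sanity check that all four label values (0 to 3) appear, not as an estimate of the dataset's class distribution.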
