In [9]:
# NOTE(review): nothing visible in this notebook imports a module named `dataloader`;
# the DataLoader class used in the setup cell is imported from torch.utils.data.
# Confirm this install is actually required, and pin the version if it is.
%pip install dataloader


Collecting dataloader
  Downloading dataloader-2.0.tar.gz (9.1 kB)
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Getting requirements to build wheel: started
  Getting requirements to build wheel: finished with status 'done'
  Preparing metadata (pyproject.toml): started
  Preparing metadata (pyproject.toml): finished with status 'done'
Building wheels for collected packages: dataloader
  Building wheel for dataloader (pyproject.toml): started
  Building wheel for dataloader (pyproject.toml): finished with status 'done'
  Created wheel for dataloader: filename=dataloader-2.0-py3-none-any.whl size=10124 sha256=8b2a92e00d22b576b05a57c72cffc4a27ac9350dace361062db9991fd3404b97
  Stored in directory: c:\users\aryan\appdata\local\pip\cache\wheels\bf\90\e6\4bbc34a5c10ed4c30423dc34f3977454039d002d9450b0e5f6
Successfully built dataloader
Installing collected packages: dataloader
Successfully installed dataloader-2.0
Note: you may need to restart the kernel to use updated packages.

In [1]:
import torch
from torchvision import transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from PIL import Image
from torch.utils.data import WeightedRandomSampler

# 1. TRAINING pipeline: resize -> 3-channel grayscale -> random augmentation -> tensor -> normalize
train_transforms = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.Grayscale(3),  # positional form of num_output_channels=3
        # Random augmentations, applied in this order on every access
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.RandomAffine(10, scale=(0.8, 1.2), translate=(0.2, 0.2)),
        transforms.ColorJitter(contrast=0.2, brightness=0.2),
        # Tensor conversion, then the same mean/std on each of the three channels
        transforms.ToTensor(),
        transforms.Normalize([0.485] * 3, [0.229] * 3),
    ]
)

# 2. VALIDATION pipeline: deterministic only (no augmentation)
val_transforms = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.Grayscale(3),  # positional form of num_output_channels=3
        transforms.ToTensor(),
        # Same mean/std on each of the three channels
        transforms.Normalize([0.485] * 3, [0.229] * 3),
    ]
)

# 3. Custom Dataset class; samples that cannot be produced come back as (None, None)
class MammogramDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs described by a DataFrame.

    Every row of ``df`` must provide:

    * ``image_path`` -- path of the image file, opened with PIL and converted
      to single-channel (``'L'``) mode;
    * ``label`` -- a key of ``label_map`` (``'benign'`` -> 0, ``'malignant'`` -> 1).

    A sample that cannot be produced (unreadable file, failing transform,
    label missing from ``label_map``) is reported on stdout and returned as
    ``(None, None)`` so that a collate function can drop it instead of
    aborting the whole epoch.
    """

    def __init__(self, df, transform=None):
        """Store a re-indexed copy of ``df`` and the optional image transform.

        Args:
            df: DataFrame with ``image_path`` and ``label`` columns.
            transform: Optional callable applied to the PIL image.
        """
        # Fresh 0..n-1 index so that positions and index labels coincide.
        self.df = df.reset_index(drop=True)
        self.transform = transform
        self.label_map = {'benign': 0, 'malignant': 1}

    def __len__(self):
        """Return the number of rows, including rows that may fail to load."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx`` or ``(None, None)`` on failure.

        Raises:
            IndexError: if ``idx`` is out of range.
        """
        # Positional lookup, deliberately outside the try block: an out-of-range
        # index must raise IndexError so that plain iteration over the dataset
        # terminates, and negative indices behave like on any sequence.
        row = self.df.iloc[idx]
        try:
            # Resolve the label first; it is cheap and avoids decoding an
            # image that would be discarded anyway.
            label = self.label_map[row['label']]
            # The context manager closes the underlying file deterministically;
            # convert() returns an independent, fully loaded image.
            with Image.open(row['image_path']) as raw:
                img = raw.convert('L')
            if self.transform is not None:
                img = self.transform(img)
            return img, label
        except Exception as e:
            print(f"⚠️  Error loading image {row['image_path']} at index {idx}: {e}")
            # (None, None) marks a bad sample; the collate function filters it out.
            return None, None

# 4. Custom collate_fn to filter out None samples
def collate_fn(batch):
    """Collate a batch while dropping samples flagged as unreadable.

    Args:
        batch: list of ``(image, label)`` samples. A sample is discarded when
            it is ``None`` or when its image is ``None``.

    Returns:
        The default-collated ``[images, labels]`` of the remaining samples,
        or a pair of empty tensors (float images, int64 labels) when nothing
        is left, so callers can detect the case with ``images.numel() == 0``.
    """
    valid = [sample for sample in batch if sample is not None and sample[0] is not None]
    if not valid:
        # Integer labels collate to int64, so the empty placeholder uses the
        # same dtype and does not promote concatenated labels to float.
        return torch.empty(0), torch.empty(0, dtype=torch.long)

    # default_collate is a function of torch.utils.data, not an attribute of
    # the DataLoader class.
    return torch.utils.data.default_collate(valid)

# 5. Load the processed CSVs and split train/validation at patient level
import os

import pandas as pd
from sklearn.model_selection import GroupShuffleSplit

# The original machine-specific location stays the default; set the
# HOPESCAN_DATA_DIR environment variable to run this cell elsewhere without
# editing it.
BASE_DIR = os.environ.get(
    'HOPESCAN_DATA_DIR',
    r'C:\Users\aryan\OneDrive\Desktop\hopescan_project\data',
)
train_df = pd.read_csv(os.path.join(BASE_DIR, 'processed_train_data.csv'))
test_df = pd.read_csv(os.path.join(BASE_DIR, 'processed_test_data.csv'))

print("➡️ Splitting data based on 'patient_id' to prevent data leakage...")
# Grouping by patient keeps all images of one patient in the same split.
splitter = GroupShuffleSplit(test_size=0.2, n_splits=1, random_state=42)
train_indices, val_indices = next(splitter.split(train_df, groups=train_df['patient_id']))

# 6. Create the final training and validation DataFrames
train_df_split = train_df.iloc[train_indices]
val_df_split = train_df.iloc[val_indices]

# Invariants of the split: no patient on both sides, and no row lost.
shared_patients = set(train_df_split['patient_id']) & set(val_df_split['patient_id'])
assert not shared_patients, f"{len(shared_patients)} patient(s) appear in both train and validation splits"
assert len(train_df_split) + len(val_df_split) == len(train_df), "split does not cover every row exactly once"

print(f"Training samples: {len(train_df_split)} | Validation samples: {len(val_df_split)}")
print(f"Number of unique patients in training set: {train_df_split['patient_id'].nunique()}")
print(f"Number of unique patients in validation set: {val_df_split['patient_id'].nunique()}")

# 7. One Dataset per split; only the training split gets the augmenting pipeline
train_dataset, val_dataset, test_dataset = (
    MammogramDataset(frame, transform=pipeline)
    for frame, pipeline in (
        (train_df_split, train_transforms),
        (val_df_split, val_transforms),
        (test_df, val_transforms),
    )
)

print("➡️ Creating weighted sampler to address class imbalance...")
class_counts = train_df_split['label'].value_counts()

# Fail early, with a readable message, when the labels and label_map disagree;
# otherwise the problem surfaces later as an obscure indexing error.
unmapped_labels = set(train_df_split['label']) - set(train_dataset.label_map)
if unmapped_labels:
    raise ValueError(f"Labels missing from label_map: {sorted(map(str, unmapped_labels))}")

# Order the class names by their integer index so that class_weights[i] is
# always the weight of class i, whatever order label_map was written in.
class_names = sorted(train_dataset.label_map, key=train_dataset.label_map.get)
absent_classes = [name for name in class_names if name not in class_counts.index]
if absent_classes:
    raise ValueError(f"Classes without any training sample: {absent_classes}")

# Inverse-frequency weights: the rarer class gets the larger weight.
class_weights = torch.tensor([1.0 / class_counts[name] for name in class_names]).float()
print(f"Class counts: {class_counts.to_dict()}")
print(f"Class weights: {class_weights}")

# 8. Creating sample weights for each training sample
print("➡️ Creating sample weights...")
labels = train_df_split['label'].map(train_dataset.label_map).tolist()
# One vectorized lookup instead of a Python-level loop over every sample.
sample_weights = class_weights[torch.tensor(labels, dtype=torch.long)].tolist()


# Create the sampler (seeded with the same value as the split's random_state,
# so the sequence of sampled indices is reproducible across runs)
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
    generator=torch.Generator().manual_seed(42),
)

print("✅ WeightedRandomSampler created successfully.")

# 9. Create DataLoaders
import multiprocessing

print("➡️ Creating DataLoaders...")
BATCH_SIZE = 16
# Worker processes started with anything other than "fork" have to re-import
# the dataset class and collate_fn by name. Both are defined inside this
# notebook rather than in an importable module, so such workers cannot find
# them and the first batch fails or hangs. Fall back to loading in the main
# process in that case (this is the situation on Windows).
NUM_WORKERS = 4 if multiprocessing.get_start_method() == 'fork' else 0

# Settings shared by the three loaders.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, collate_fn=collate_fn)
# The sampler already decides the order of training samples, so no shuffle flag there.
train_loader = DataLoader(train_dataset, sampler=sampler, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# These are the dataset sizes: len() counts every row, and unreadable images
# are only discovered (and dropped by collate_fn) while iterating.
print(f"Valid training samples: {len(train_dataset)}")
print(f"Valid validation samples: {len(val_dataset)}")
print(f"Valid test samples: {len(test_dataset)}")
print("✅ Setup complete, ready to train!")

➡️ Splitting data based on 'patient_id' to prevent data leakage...
Training samples: 2277 | Validation samples: 585
Number of unique patients in training set: 998
Number of unique patients in validation set: 250
➡️ Creating weighted sampler to address class imbalance...
Class counts: {'benign': 1340, 'malignant': 937}
Class weights: tensor([0.0007, 0.0011])
➡️ Creating sample weights...
✅ WeightedRandomSampler created successfully.
➡️ Creating DataLoaders...
Valid training samples: 2277
Valid validation samples: 585
Valid test samples: 704
✅ Setup complete, ready to train!
