In [2]:
import torch
import torch.nn as nn
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from PIL import Image
import os
from pathlib import Path
import numpy as np
from collections import Counter

In [8]:
class BankLogoDataset(Dataset):
    """Folder-per-class image dataset: every sub-directory of ``root_dir`` is one bank.

    Directory names become class names; the class index is the position of the
    name in the alphabetically sorted list of directories. Only files whose
    suffix (case-insensitive) is listed in ``IMAGE_EXTENSIONS`` are collected.

    Args:
        root_dir: Directory containing one sub-directory per bank.
        transform: Optional callable applied to the RGB ``PIL.Image`` in
            ``__getitem__``.

    Attributes:
        root_dir: ``root_dir`` as a ``pathlib.Path``.
        banks: Sorted list of class (sub-directory) names.
        class_to_idx: Mapping from class name to integer label.
        image_paths: ``Path`` of every collected image, grouped by class and
            sorted within each class.
        labels: Integer label for each entry of ``image_paths``.

    Raises:
        FileNotFoundError / NotADirectoryError: if ``root_dir`` cannot be listed.
    """

    # Suffixes (lower-case, with leading dot) that count as images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.tif')

    def __init__(self, root_dir, transform=None):
        self.root_dir = Path(root_dir)
        self.transform = transform

        # Directory enumeration order is filesystem-dependent. Sorting both
        # the class folders and the files inside them makes the sample order
        # (and therefore index-based splits and seeded sampling) reproducible.
        self.banks = sorted(
            entry.name for entry in self.root_dir.iterdir() if entry.is_dir()
        )
        self.class_to_idx = {bank: idx for idx, bank in enumerate(self.banks)}

        # Parallel lists: image_paths[i] belongs to class labels[i].
        self.image_paths = []
        self.labels = []

        for bank in self.banks:
            bank_label = self.class_to_idx[bank]
            for img_path in sorted((self.root_dir / bank).glob('*')):
                # is_file() keeps directories with an image-like name out of
                # the sample list.
                if img_path.is_file() and img_path.suffix.lower() in self.IMAGE_EXTENSIONS:
                    self.image_paths.append(img_path)
                    self.labels.append(bank_label)

        # Print dataset statistics
        self._print_stats()

    def _print_stats(self):
        """Print the number of collected images per bank and the overall total."""
        label_counts = Counter(self.labels)
        print("\nDataset Statistics:")
        print("------------------")
        for bank, idx in self.class_to_idx.items():
            count = label_counts[idx]
            print(f"{bank}: {count} images")
        print(f"Total: {len(self.labels)} images")

    def __len__(self):
        """Return the number of collected images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is opened, converted to RGB and passed through
        ``self.transform`` when one is set. If anything in that chain raises,
        the error is printed and an all-zero ``(3, 224, 224)`` float tensor is
        returned with the sample's real label, so one unreadable file does not
        abort iteration.
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        try:
            # convert() yields an independent in-memory image, so the source
            # file can be closed right away instead of whenever the garbage
            # collector reclaims the original Image object.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
            if self.transform:
                image = self.transform(image)
            return image, label
        except Exception as e:
            # Deliberate best-effort fallback; the hard-coded shape matches a
            # 224x224 RGB tensor.
            print(f"Error loading {img_path}: {e}")
            return torch.zeros((3, 224, 224)), label

In [4]:
def create_balanced_sampler(dataset, oversample_factor=2):
    """Create a weighted sampler that draws every class with equal probability.

    Each sample is weighted by the inverse of its class frequency, so the
    total weight of every class is the same regardless of how many images it
    has. Sampling is done with replacement.

    Args:
        dataset: Object exposing ``labels`` (one hashable label per sample)
            and ``__len__``.
        oversample_factor: Number of draws per epoch as a multiple of
            ``len(dataset)``. Defaults to 2, i.e. twice the dataset size.

    Returns:
        A ``WeightedRandomSampler`` yielding
        ``int(len(dataset) * oversample_factor)`` indices per epoch.

    Raises:
        ValueError: if the dataset has no samples or the requested number of
            draws is not positive.
    """
    labels = list(dataset.labels)
    if not labels:
        raise ValueError("Cannot create a balanced sampler for an empty dataset")

    # WeightedRandomSampler requires a positive int, so truncate fractional
    # factors and validate here to give a readable error.
    num_samples = int(len(dataset) * oversample_factor)
    if num_samples <= 0:
        raise ValueError(
            f"oversample_factor={oversample_factor!r} yields {num_samples} samples; "
            "it must produce at least one"
        )

    # Count samples per class, then weight each sample by 1 / class size.
    class_counts = Counter(labels)
    weights = torch.tensor(
        [1.0 / class_counts[label] for label in labels],
        dtype=torch.double,
    )

    return WeightedRandomSampler(
        weights=weights,
        num_samples=num_samples,
        replacement=True,
    )

In [5]:
def get_transforms(train=True):
    """Build the image preprocessing pipeline.

    Both modes resize to 224x224, convert to a tensor and normalise each
    channel with a fixed mean/std (these look like the commonly used ImageNet
    statistics - confirm against the backbone being trained). With
    ``train=True`` a chain of random augmentations is inserted between the
    resize and the tensor conversion.

    Args:
        train: When true, include the random augmentations.

    Returns:
        A ``transforms.Compose`` for the requested mode.
    """
    resize = transforms.Resize((224, 224))

    # Steps shared by both modes that run after any augmentation.
    tensor_steps = [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]

    if not train:
        return transforms.Compose([resize, *tensor_steps])

    augmentations = [
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(20),
        # Translation, scaling and shearing only; rotation is handled above.
        transforms.RandomAffine(
            degrees=0,
            translate=(0.1, 0.1),
            scale=(0.9, 1.1),
            shear=10,
        ),
        transforms.ColorJitter(
            brightness=0.3,
            contrast=0.3,
            saturation=0.2,
            hue=0.1,
        ),
        # Occasionally drop colour information entirely.
        transforms.RandomGrayscale(p=0.1),
    ]
    return transforms.Compose([resize, *augmentations, *tensor_steps])

In [6]:
def create_dataloader(root_dir, batch_size=32, train=True, num_workers=4, pin_memory=None):
    """Create a ``DataLoader`` (and its dataset) for the logo folder tree.

    In training mode the augmenting transforms are used and batches are drawn
    through the class-balancing sampler. Otherwise the plain transforms are
    used and the loader shuffles the dataset itself.

    Args:
        root_dir: Directory containing one sub-directory per bank.
        batch_size: Number of samples per batch.
        train: Selects training transforms plus balanced sampling.
        num_workers: Worker processes used for loading. Pass 0 to load in the
            main process, e.g. when the dataset class is defined inside a
            notebook and cannot be pickled for spawned workers.
        pin_memory: Whether to pin host memory for faster GPU transfer.
            ``None`` (default) enables it only when CUDA is available, which
            avoids a warning on CPU-only machines.

    Returns:
        Tuple ``(dataloader, dataset)``.
    """
    transform = get_transforms(train=train)
    dataset = BankLogoDataset(root_dir, transform=transform)

    if pin_memory is None:
        pin_memory = torch.cuda.is_available()

    # A sampler and shuffle=True are mutually exclusive in DataLoader: the
    # balanced sampler already randomises order, so shuffle only without it.
    sampler = create_balanced_sampler(dataset) if train else None

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        shuffle=sampler is None,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return dataloader, dataset

In [None]:
# Example usage: build the balanced training loader for the logo folder.
data_dir = "Misc/logos"
train_loader, dataset = create_dataloader(data_dir, batch_size=32, train=True)

# Report class count, raw dataset size and the number of batches per epoch
# (the latter reflects the sampler's oversampling), one statistic per line.
print(
    f"Number of banks: {len(dataset.banks)}",
    f"Total images before oversampling: {len(dataset)}",
    f"Total batches after oversampling: {len(train_loader)}",
    sep="\n",
)

In [None]:
def main():
    """Run one full pass over the balanced training loader and print how many
    samples of each bank were drawn."""
    train_loader, train_dataset = create_dataloader("Misc/logos", batch_size=32, train=True)

    print("\nBalanced Dataset Statistics:")
    print("--------------------------")

    # Tally labels batch by batch; the images themselves are not needed here.
    balanced_counts = Counter()
    for _images, labels in train_loader:
        balanced_counts.update(labels.tolist())

    for bank, idx in train_dataset.class_to_idx.items():
        print(f"{bank}: {balanced_counts[idx]} samples after balancing")


if __name__ == "__main__":
    main()