# 🎯 Tutorial: Deep Learning for Imbalanced Classification

![Imbalanced Classification](https://miro.medium.com/v2/resize:fit:1400/1*ZygUZWafGHOI02wI5ML5aA.png)

## Welcome to the World of Imbalanced Data! ⚖️

In this comprehensive tutorial, you'll learn:
- 📊 What is imbalanced classification and why it's challenging
- 🧮 Understanding class distribution and its impact on model performance
- 🤖 Advanced techniques for handling imbalanced datasets
- 💻 Hands-on implementation with PyTorch and CNNs
- 🎯 Weighted sampling and custom loss functions
- 🧪 Interactive exercises to master imbalanced classification

By the end, you'll be equipped to tackle any imbalanced classification problem with confidence!


## 📚 Table of Contents

1. [🎓 Understanding Imbalanced Classification](#1--understanding-imbalanced-classification)
2. [📊 The Problem with Traditional Approaches](#2--the-problem-with-traditional-approaches)
3. [🔧 Setting Up the Environment](#3--setting-up-the-environment)
4. [📈 Analyzing Class Distribution](#4--analyzing-class-distribution)
5. [🏗️ CNN Architecture for Binary Classification](#5--cnn-architecture-for-binary-classification)
6. [⚖️ Weighted Sampling Techniques](#6--weighted-sampling-techniques)
7. [🎯 Custom Loss Functions](#7--custom-loss-functions)
8. [🧠 Building the Complete Solution](#8--building-the-complete-solution)
9. [📊 Evaluation Metrics](#9--evaluation-metrics)
10. [🎮 Interactive Exercises](#10--interactive-exercises)
11. [🚀 Advanced Techniques](#11--advanced-techniques)
12. [📖 Summary and Next Steps](#12--summary-and-next-steps)


## 1. 🎓 Understanding Imbalanced Classification

### What is Imbalanced Classification?

Imagine you're building a medical diagnosis system 🏥. Out of 1000 patients:
- 950 are healthy (Class 0: "normal")
- 50 have a rare disease (Class 1: "positive")

This is a **severely imbalanced dataset** with a ratio of 19:1!

### The Challenge 😰

A naive classifier could achieve 95% accuracy by simply predicting "healthy" for everyone. But this would be **catastrophic** - we'd miss all the sick patients!

### Key Concepts:

- **Majority Class**: The class with more samples (Class 0: "normal")
- **Minority Class**: The class with fewer samples (Class 1: "positive" in the example above, "onion" in our dataset)
- **Class Imbalance Ratio**: How skewed the distribution is
- **Sampling Bias**: Traditional training favors the majority class

### Real-World Examples:

🏥 **Medical Diagnosis**: Rare diseases vs. healthy patients  
🔒 **Fraud Detection**: Fraudulent vs. legitimate transactions  
📧 **Spam Detection**: Spam vs. legitimate emails  
🏭 **Quality Control**: Defective vs. good products  
🚨 **Anomaly Detection**: Abnormal vs. normal behavior  

### Mathematical Definition:

Given a dataset $D = \{(x_1, y_1), (x_2, y_2), \ldots, (x_n, y_n)\}$ where $y_i \in \{0, 1\}$:

$$\text{Imbalance Ratio} = \frac{|\{i : y_i = 0\}|}{|\{i : y_i = 1\}|}$$

When this ratio is much greater than 1 (the 950:50 example above gives 19), we have an imbalanced dataset.
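
As a quick check of the definition, the ratio can be computed directly from a list of labels. This is a minimal sketch using only the standard library; the function name `imbalance_ratio` and the toy label list are illustrative and not part of the tutorial's dataset code.

```python
from collections import Counter


def imbalance_ratio(labels):
    """Return (#class 0) / (#class 1) for binary labels."""
    counts = Counter(labels)
    if counts[1] == 0:
        raise ValueError("no minority (class 1) samples; ratio is undefined")
    return counts[0] / counts[1]


# Toy labels matching the medical example: 950 healthy, 50 positive
labels = [0] * 950 + [1] * 50
print(f"Imbalance ratio: {imbalance_ratio(labels):.1f}:1")
```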

### Our Specific Problem 🎯

In this tutorial, we'll work with:
- **Images**: 224×224 grayscale images
- **Classes**: "normal" (geometric shapes) vs. "onion" (layered patterns)
- **Challenge**: Training data is imbalanced, test data is balanced
- **Goal**: Build a CNN that works well on balanced test data


## 2. 📊 The Problem with Traditional Approaches

### Why Standard Training Fails

Standard training procedures implicitly assume **roughly balanced datasets**. When the data is imbalanced, three problems appear:

#### ❌ Problem 1: Biased Learning
The model sees the majority class 19× more often, so it learns to predict it by default.

#### ❌ Problem 2: Misleading Accuracy
High accuracy doesn't mean good performance - it might just reflect class distribution.

#### ❌ Problem 3: Poor Generalization
Models trained on imbalanced data often fail on balanced test sets.

### The Mathematics Behind the Problem

Consider a standard loss function (Cross-Entropy):

$$L = -\frac{1}{N} \sum_{i=1}^{N} [y_i \log(\hat{y}_i) + (1-y_i) \log(1-\hat{y}_i)]$$

With imbalanced data:
- Majority-class samples make up ~95% of the terms in the sum
- Minority-class samples make up only ~5% of the terms
- The optimizer therefore mainly reduces majority-class error
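
To make the 95%/5% split concrete, the sketch below evaluates the cross-entropy sum for an untrained model that outputs $\hat{y}_i = 0.5$ for every sample (an assumption chosen purely for illustration). Every sample then contributes the same $\log 2$, so each class's share of the loss equals its share of the data.

```python
import math


def bce_term(y, y_hat):
    """Per-sample binary cross-entropy term from the formula above."""
    return -(y * math.log(y_hat) + (1 - y) * math.log(1 - y_hat))


labels = [0] * 950 + [1] * 50
y_hat = 0.5  # assumed constant prediction of an untrained model

majority_loss = sum(bce_term(y, y_hat) for y in labels if y == 0)
minority_loss = sum(bce_term(y, y_hat) for y in labels if y == 1)
total_loss = majority_loss + minority_loss

majority_share = majority_loss / total_loss
minority_share = minority_loss / total_loss
print(f"Majority share of loss: {majority_share:.1%}")
print(f"Minority share of loss: {minority_share:.1%}")
```

The exact shares shift during training, since well-classified samples contribute less loss, but the starting point is dominated by the majority class.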

### Evaluation Metrics Trap 📊

**Accuracy** can be misleading:
```
Dataset: 950 normal, 50 onion samples
Dumb classifier: Always predict "normal"
Accuracy: 950/1000 = 95% ✨ (Looks great!)
But: 0% recall for "onion" class 💀 (Catastrophic!)
```
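
The same numbers can be reproduced in a few lines. This sketch hard-codes the 950/50 split from the example above and a classifier that always predicts class 0; the variable names are illustrative.

```python
# Ground truth: 950 "normal" (0) and 50 "onion" (1) samples
y_true = [0] * 950 + [1] * 50
# "Dumb" classifier: always predict "normal"
y_pred = [0] * len(y_true)

correct = sum(t == p for t, p in zip(y_true, y_pred))
accuracy = correct / len(y_true)

# Recall for the "onion" class: found positives / actual positives
true_positives = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
onion_recall = true_positives / sum(y_true)

print(f"Accuracy:     {accuracy:.1%}")
print(f"Onion recall: {onion_recall:.1%}")
```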

### Better Metrics for Imbalanced Data:

- **Precision**: $\frac{TP}{TP + FP}$ - Of predicted positives, how many are correct?
- **Recall**: $\frac{TP}{TP + FN}$ - Of actual positives, how many did we find?
- **F1-Score**: $\frac{2 \cdot \text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}}$ - Harmonic mean
- **AUC-ROC**: Area under ROC curve - threshold-independent
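
A minimal sketch of the three count-based metrics, written in plain Python so the formulas stay visible. The function name `binary_metrics` and the confusion counts in the usage example are hypothetical. In practice, `sklearn.metrics.classification_report` (imported in the setup cell of this tutorial) reports the same quantities.

```python
def binary_metrics(y_true, y_pred):
    """Compute precision, recall, and F1 for the positive class (label 1)."""
    tp = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
    fp = sum(t == 0 and p == 1 for t, p in zip(y_true, y_pred))
    fn = sum(t == 1 and p == 0 for t, p in zip(y_true, y_pred))

    # Guard against division by zero when a class is never predicted or present
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = (
        2 * precision * recall / (precision + recall)
        if (precision + recall) > 0
        else 0.0
    )
    return precision, recall, f1


# Hypothetical classifier on 950 normal / 50 onion samples:
# 40 onions found (TP), 10 missed (FN), 20 false alarms (FP), 930 TN
y_true = [1] * 40 + [1] * 10 + [0] * 20 + [0] * 930
y_pred = [1] * 40 + [0] * 10 + [1] * 20 + [0] * 930

precision, recall, f1 = binary_metrics(y_true, y_pred)
print(f"Precision: {precision:.3f}, Recall: {recall:.3f}, F1: {f1:.3f}")
```

Applied to the always-"normal" classifier, the same function returns zero for all three metrics, which exposes the failure that 95% accuracy hides.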

### The Solution Preview 🎯

We'll use several techniques:
1. **Weighted Sampling**: Balance the training batches
2. **Proper Architecture**: CNN designed for binary classification
3. **Regularization**: Prevent overfitting to majority class
4. **Smart Training**: Monitor the right metrics


## 3. 🔧 Setting Up the Environment

Let's start by importing all necessary libraries and setting up our development environment. We'll be working with PyTorch for neural networks and various other libraries for data analysis and visualization.


In [None]:
# Essential imports for imbalanced classification tutorial
import abc
import os
import glob
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc

# PyTorch for neural networks
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# For downloading and handling data
import gdown
import zipfile
import warnings

warnings.filterwarnings("ignore")

# Set up device - GPU greatly speeds up CNN training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🚀 Using device: {device}")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print("✅ Environment setup complete!")
print(f"📦 PyTorch version: {torch.__version__}")
print(f"🖥️  CUDA available: {torch.cuda.is_available()}")

# Set up plotting style
plt.style.use("default")
sns.set_palette("husl")

In [None]:
# Let's create some sample data to understand the problem better
def create_sample_dataset():
    """Create a simulated imbalanced dataset for demonstration."""
    np.random.seed(42)

    # Simulate class distribution similar to our real problem
    n_majority = 950  # Class 0 (normal)
    n_minority = 50  # Class 1 (onion)

    # Create synthetic features (normally we'd have images)
    X_majority = np.random.normal(0, 1, (n_majority, 2))
    X_minority = np.random.normal(2, 1, (n_minority, 2))

    # Create labels
    y_majority = np.zeros(n_majority)
    y_minority = np.ones(n_minority)

    # Combine data
    X = np.vstack([X_majority, X_minority])
    y = np.hstack([y_majority, y_minority])

    return X, y


# Create sample dataset
X_sample, y_sample = create_sample_dataset()

print(f"📊 Sample dataset created:")
print(f"   Total samples: {len(X_sample)}")
print(f"   Class 0 (majority): {np.sum(y_sample == 0)} samples")
print(f"   Class 1 (minority): {np.sum(y_sample == 1)} samples")
print(f"   Imbalance ratio: {np.sum(y_sample == 0) / np.sum(y_sample == 1):.1f}:1")

# Visualize the imbalanced dataset
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Scatter plot
colors = ["blue", "red"]
for i, class_name in enumerate(["Normal", "Onion"]):
    mask = y_sample == i
    ax1.scatter(
        X_sample[mask, 0],
        X_sample[mask, 1],
        c=colors[i],
        label=f"{class_name} (n={np.sum(mask)})",
        alpha=0.7,
        s=20,
    )

ax1.set_xlabel("Feature 1")
ax1.set_ylabel("Feature 2")
ax1.set_title("🎯 Imbalanced Dataset Visualization")
ax1.legend()
ax1.grid(True, alpha=0.3)

# Bar plot showing class distribution
class_counts = [np.sum(y_sample == 0), np.sum(y_sample == 1)]
ax2.bar(
    ["Class 0\n(Normal)", "Class 1\n(Onion)"],
    class_counts,
    color=["blue", "red"],
    alpha=0.7,
)
ax2.set_ylabel("Number of Samples")
ax2.set_title("📊 Class Distribution")
ax2.grid(True, alpha=0.3)

# Add percentage labels on bars
for i, count in enumerate(class_counts):
    percentage = count / len(y_sample) * 100
    ax2.text(
        i,
        count + 10,
        f"{count}\n({percentage:.1f}%)",
        ha="center",
        va="bottom",
        fontweight="bold",
    )

plt.tight_layout()
plt.show()

## 4. 📈 Analyzing Class Distribution

Now let's set up our data loading and analyze the real class distribution in our problem. This is crucial for understanding the scope of the imbalance.


In [None]:
# Let's simulate the data loading process (similar to the actual problem)
# In the real problem, you would download data from Google Drive
# Here we'll create a mock dataset class to demonstrate the concepts


class MockImageDataset(torch.utils.data.Dataset):
    """Mock dataset class that simulates the real imbalanced image dataset."""

    def __init__(self, dataset_type: str, imbalance_ratio: float = 19.0):
        """
        Initialize mock dataset.

        Args:
            dataset_type: "train" or "valid"
            imbalance_ratio: Ratio of majority to minority class
        """
        self.dataset_type = dataset_type

        # Simulate realistic dataset sizes
        if dataset_type == "train":
            total_samples = 1000
            minority_samples = int(total_samples / (imbalance_ratio + 1))
            majority_samples = total_samples - minority_samples
        else:  # validation - balanced
            total_samples = 100
            minority_samples = majority_samples = total_samples // 2

        # Create mock file paths and labels
        self.filelist = []
        self.labels = []

        # Add majority class samples (normal)
        for i in range(majority_samples):
            self.filelist.append(f"{dataset_type}_data/normal_{i}.jpg")
            self.labels.append(0)

        # Add minority class samples (onion)
        for i in range(minority_samples):
            self.filelist.append(f"{dataset_type}_data/onion_{i}.jpg")
            self.labels.append(1)

        print(f"📊 {dataset_type.capitalize()} dataset created:")
        print(f"   Total samples: {len(self.labels)}")
        print(f"   Class 0 (normal): {self.labels.count(0)} samples")
        print(f"   Class 1 (onion): {self.labels.count(1)} samples")
        if self.labels.count(1) > 0:
            ratio = self.labels.count(0) / self.labels.count(1)
            print(f"   Imbalance ratio: {ratio:.1f}:1")

    def __len__(self):
        return len(self.filelist)

    def __getitem__(self, idx):
        # In real dataset, we would load actual images
        # Here we create mock tensors that simulate grayscale images
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Create a mock image (224x224 grayscale)
        if self.labels[idx] == 0:  # normal class
            # Simulate geometric shapes with noise
            image = torch.randn(1, 224, 224) * 0.1 + 0.5
        else:  # onion class
            # Simulate layered patterns with noise
            image = torch.randn(1, 224, 224) * 0.1 + 0.7

        label = self.labels[idx]
        return image, label

    def loader(self, **kwargs):
        """Create DataLoader for this dataset."""
        return torch.utils.data.DataLoader(self, **kwargs)


# Create mock datasets
print("🔄 Creating mock datasets...")
train_dataset = MockImageDataset("train", imbalance_ratio=19.0)
print()
valid_dataset = MockImageDataset("valid", imbalance_ratio=1.0)  # Balanced validation

In [None]:
# Let's analyze the class distribution in detail
def analyze_class_distribution(dataset, name):
    """Analyze and visualize class distribution in a dataset."""

    # Count classes
    class_counts = Counter()
    for _, label in dataset:
        class_counts[label] += 1

    total = sum(class_counts.values())
    print(f"\n📊 {name} Dataset Analysis:")
    print(f"   Total samples: {total}")

    for class_id in sorted(class_counts.keys()):
        count = class_counts[class_id]
        percentage = count / total * 100
        class_name = "Normal" if class_id == 0 else "Onion"
        print(
            f"   Class {class_id} ({class_name}): {count} samples ({percentage:.1f}%)"
        )

    if len(class_counts) == 2:
        ratio = class_counts[0] / class_counts[1]
        print(f"   Imbalance ratio: {ratio:.2f}:1")

    return class_counts


# Analyze both datasets
train_counts = analyze_class_distribution(train_dataset, "Training")
valid_counts = analyze_class_distribution(valid_dataset, "Validation")

# Visualize the distributions
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Training dataset distribution
train_classes = ["Normal", "Onion"]
train_values = [train_counts[0], train_counts[1]]
bars1 = ax1.bar(train_classes, train_values, color=["skyblue", "salmon"], alpha=0.8)
ax1.set_title("🏋️ Training Dataset\n(Imbalanced)", fontsize=14, fontweight="bold")
ax1.set_ylabel("Number of Samples")
ax1.grid(True, alpha=0.3)

# Add value labels on bars
for bar, value in zip(bars1, train_values):
    height = bar.get_height()
    ax1.text(
        bar.get_x() + bar.get_width() / 2.0,
        height + 5,
        f"{value}\n({value/sum(train_values)*100:.1f}%)",
        ha="center",
        va="bottom",
        fontweight="bold",
    )

# Validation dataset distribution
valid_classes = ["Normal", "Onion"]
valid_values = [valid_counts[0], valid_counts[1]]
bars2 = ax2.bar(valid_classes, valid_values, color=["skyblue", "salmon"], alpha=0.8)
ax2.set_title("🎯 Validation Dataset\n(Balanced)", fontsize=14, fontweight="bold")
ax2.set_ylabel("Number of Samples")
ax2.grid(True, alpha=0.3)

# Add value labels on bars
for bar, value in zip(bars2, valid_values):
    height = bar.get_height()
    ax2.text(
        bar.get_x() + bar.get_width() / 2.0,
        height + 1,
        f"{value}\n({value/sum(valid_values)*100:.1f}%)",
        ha="center",
        va="bottom",
        fontweight="bold",
    )

plt.tight_layout()
plt.show()

print("\n💡 Key Observation:")
print("   Training data is highly imbalanced, but validation data is balanced!")
print(
    "   This means our model must generalize from imbalanced to balanced distributions."
)

## 5. 🏗️ CNN Architecture for Binary Classification

Now let's design a Convolutional Neural Network specifically for binary image classification. Our architecture needs to be robust enough to handle the challenges of imbalanced data.

### Architecture Design Principles:

1. **Feature Extraction**: Convolutional layers to detect patterns
2. **Dimensionality Reduction**: Pooling layers to reduce computational cost
3. **Regularization**: Dropout to prevent overfitting
4. **Binary Output**: Single neuron with sigmoid activation for probability

### The Challenge:
- Input: 224×224 grayscale images (1 channel)
- Output: Binary classification probability [0, 1]
- Goal: Distinguish between geometric shapes ("normal") and layered patterns ("onion")
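
Before fixing the size of the first fully connected layer, it helps to trace how the spatial dimensions shrink. The sketch below applies the standard output-size formula $\lfloor (n + 2p - k)/s \rfloor + 1$ to a stack of two 5×5 convolutions (no padding, stride 1), each followed by max pooling (2×2, then 4×4), which is the configuration used for the CNN in this tutorial. The helper name `out_size` is illustrative.

```python
def out_size(n, kernel, stride=1, padding=0):
    """Spatial output size of a conv or pooling layer (floor division)."""
    return (n + 2 * padding - kernel) // stride + 1


size = 224                                  # input: 224x224, 1 channel
size = out_size(size, kernel=5)             # conv1 (5x5)    -> 220
size = out_size(size, kernel=2, stride=2)   # max pool (2x2) -> 110
size = out_size(size, kernel=5)             # conv2 (5x5)    -> 106
size = out_size(size, kernel=4, stride=4)   # max pool (4x4) -> 26

channels = 8
flat_features = size * size * channels
print(f"Final feature map: {size}x{size}x{channels} -> {flat_features} features")
```

If the kernel sizes, pooling, or input resolution change, the `in_features` of the first linear layer has to change with them.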


In [None]:
# First, let's understand the abstract base class we need to implement
class CnnClassifier(torch.nn.Module, abc.ABC):
    """
    Abstract base class for CNN classifiers.
    This defines the interface our solution must implement.
    """

    MODEL_PATH: str = "cnn-classifier.pth"

    @classmethod
    def load(cls):
        """Load model from file."""
        model = cls()
        model.load_state_dict(torch.load(cls.MODEL_PATH, map_location=device))
        return model

    @classmethod
    @abc.abstractmethod
    def create_with_training(cls):
        """Train model and save to file."""
        pass


print("✅ Base class defined!")
print("   Our CNN must inherit from CnnClassifier")
print("   Must implement: create_with_training() method")
print("   Must save model weights to: cnn-classifier.pth")

In [None]:
# Let's build our CNN architecture step by step


class BasicCNN(CnnClassifier):
    """
    A basic CNN for binary image classification.

    This will serve as our foundation for understanding CNN architecture
    before we add imbalanced data techniques.
    """

    def __init__(self):
        super().__init__()

        # Calculate dimensions after each layer
        # Input: 224x224x1

        # First convolutional block
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=5)
        # After conv1: (224-5+1) = 220x220x8
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # After pool1: 220/2 = 110x110x8

        # Second convolutional block
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=8, kernel_size=5)
        # After conv2: (110-5+1) = 106x106x8
        self.pool2 = nn.MaxPool2d(kernel_size=4, stride=4)
        # After pool2: 106/4 = 26x26x8 (rounded down)

        # Calculate flattened size: 26*26*8 = 5408
        self.flatten = nn.Flatten()

        # Fully connected layers
        self.fc1 = nn.Linear(in_features=5408, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=1)

        # Activation functions and regularization
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(p=0.5)

        print("🏗️ Basic CNN Architecture:")
        print("   Conv1: 1→8 channels, 5x5 kernel")
        print("   Pool1: 2x2 MaxPool")
        print("   Conv2: 8→8 channels, 5x5 kernel")
        print("   Pool2: 4x4 MaxPool")
        print("   FC1: 5408→256 neurons")
        print("   FC2: 256→1 neuron (binary output)")
        print("   Regularization: 50% Dropout")

    def forward(self, x):
        """Forward pass through the network."""
        # First conv block
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool1(x)
        x = self.dropout(x)

        # Second conv block
        x = self.conv2(x)
        x = self.relu(x)
        x = self.pool2(x)
        x = self.dropout(x)

        # Flatten and fully connected layers
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)  # Binary probability output

        return x

    @classmethod
    def create_with_training(cls):
        """Basic training (we'll improve this later for imbalanced data)."""
        return cls()  # Placeholder - no actual training yet


# Test our architecture
basic_model = BasicCNN().to(device)
basic_model.eval()  # Disable dropout so the test forward pass is deterministic

# Test with a sample input
with torch.no_grad():
    sample_input = torch.randn(2, 1, 224, 224).to(device)  # Batch of 2 images
    output = basic_model(sample_input)
    print(f"\n✅ Architecture test successful!")
    print(f"   Input shape: {sample_input.shape}")
    print(f"   Output shape: {output.shape}")
    print(f"   Output values: {output.flatten().cpu().numpy()}")
    print(f"   Output range: [{output.min():.3f}, {output.max():.3f}]")

# Count parameters
total_params = sum(p.numel() for p in basic_model.parameters())
trainable_params = sum(p.numel() for p in basic_model.parameters() if p.requires_grad)
print(f"\n📊 Model Statistics:")
print(f"   Total parameters: {total_params:,}")
print(f"   Trainable parameters: {trainable_params:,}")
print(f"   Model size: ~{total_params * 4 / 1024 / 1024:.1f} MB (FP32)")

## 6. ⚖️ Weighted Sampling Techniques

A central technique for handling imbalanced data is **weighted sampling**. Instead of drawing training samples uniformly from the imbalanced dataset, we draw them so that the batches are roughly balanced.

### The Weighted Sampling Strategy:

1. **Calculate Class Weights**: Inverse frequency weighting
2. **Assign Sample Weights**: Each sample gets a weight based on its class
3. **Weighted Random Sampler**: PyTorch's built-in balanced sampling
4. **Balanced Batches**: Each training batch has roughly equal class representation

### Mathematical Foundation:

For a dataset with $n_0$ samples of class 0 and $n_1$ samples of class 1:

$$w_0 = \frac{1}{n_0}, \quad w_1 = \frac{1}{n_1}$$

Since $n_0 \gg n_1$, we have $w_1 \gg w_0$: each minority sample is drawn $n_0/n_1$ times as often as each majority sample, so the two classes are drawn with equal probability overall.
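
To see why inverse-frequency weights balance the batches, note that the probability of drawing class $c$ is proportional to $n_c \cdot w_c = 1$, so both classes are equally likely. The sketch below checks this with the standard library's `random.choices` as a stand-in for PyTorch's `WeightedRandomSampler` with `replacement=True` (both draw indices with replacement in proportion to the weights). The 950/50 split is the tutorial's example; the number of draws is an arbitrary choice.

```python
import random

random.seed(42)

labels = [0] * 950 + [1] * 50
n0, n1 = labels.count(0), labels.count(1)

# Inverse-frequency weight per class, then one weight per sample
class_weight = {0: 1.0 / n0, 1: 1.0 / n1}
sample_weights = [class_weight[y] for y in labels]

# Total probability mass per class: n_c * w_c, normalized
total = sum(sample_weights)
p_class_1 = sum(w for w, y in zip(sample_weights, labels) if y == 1) / total
print(f"P(draw class 1) = {p_class_1:.2f}")

# Draw indices with replacement, as a weighted sampler would
indices = random.choices(range(len(labels)), weights=sample_weights, k=3200)
minority_fraction = sum(labels[i] for i in indices) / len(indices)
print(f"Minority fraction in sampled indices: {minority_fraction:.3f}")
```

A side effect worth keeping in mind: each minority sample is drawn about 19 times as often as each majority sample, so the model sees the same few minority images repeatedly.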


In [None]:
# Let's implement weighted sampling step by step


def analyze_and_create_weights(dataset):
    """
    Analyze class distribution and create sample weights for balanced sampling.

    Returns:
        torch.Tensor: Weight for each sample in the dataset
        dict: Class analysis information
    """

    print("🔍 Analyzing class distribution...")

    # Count classes by iterating through dataset
    class_counts = {0: 0, 1: 0}
    labels = []

    for idx in range(len(dataset)):
        _, label = dataset[idx]
        labels.append(label)
        class_counts[label] += 1

    # Calculate class weights (inverse frequency)
    total_samples = len(labels)
    weight_class_0 = 1.0 / class_counts[0]
    weight_class_1 = 1.0 / class_counts[1]

    print(f"   Class 0 (Normal): {class_counts[0]} samples")
    print(f"   Class 1 (Onion):  {class_counts[1]} samples")
    print(f"   Total samples: {total_samples}")
    print(f"   Imbalance ratio: {class_counts[0]/class_counts[1]:.2f}:1")
    print(f"   Weight for class 0: {weight_class_0:.6f}")
    print(f"   Weight for class 1: {weight_class_1:.6f}")
    print(f"   Weight ratio: {weight_class_1/weight_class_0:.2f}:1")

    # Assign weight to each sample based on its class
    sample_weights = torch.zeros(total_samples)
    for idx, label in enumerate(labels):
        if label == 0:
            sample_weights[idx] = weight_class_0
        else:
            sample_weights[idx] = weight_class_1

    analysis_info = {
        "class_counts": class_counts,
        "total_samples": total_samples,
        "weight_class_0": weight_class_0,
        "weight_class_1": weight_class_1,
        "imbalance_ratio": class_counts[0] / class_counts[1],
    }

    return sample_weights, analysis_info


# Analyze our training dataset
sample_weights, analysis = analyze_and_create_weights(train_dataset)

print(f"\n📊 Sample weights created:")
print(f"   Weights shape: {sample_weights.shape}")
print(f"   Unique weights: {torch.unique(sample_weights).tolist()}")
print(f"   Weight distribution:")
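# Count with isclose rather than ==, since the weights are stored as float32
n_w0 = torch.isclose(sample_weights, torch.tensor(analysis["weight_class_0"])).sum().item()
n_w1 = torch.isclose(sample_weights, torch.tensor(analysis["weight_class_1"])).sum().item()
print(f"     - {n_w0} samples with weight {analysis['weight_class_0']:.6f}")
print(f"     - {n_w1} samples with weight {analysis['weight_class_1']:.6f}")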

In [None]:
# Now let's create a weighted sampler and test it


def create_balanced_dataloader(dataset, sample_weights, batch_size=32):
    """
    Create a DataLoader with weighted sampling for balanced batches.

    Args:
        dataset: The imbalanced dataset
        sample_weights: Weight for each sample
        batch_size: Size of each batch

    Returns:
        DataLoader with balanced sampling
    """

    # Calculate number of samples per epoch
    # We want full batches, so we'll use a round number
    total_samples = len(dataset)
    samples_per_epoch = batch_size * (total_samples // batch_size)

    print(f"🎯 Creating balanced DataLoader:")
    print(f"   Original dataset size: {total_samples}")
    print(f"   Samples per epoch: {samples_per_epoch}")
    print(f"   Batch size: {batch_size}")
    print(f"   Batches per epoch: {samples_per_epoch // batch_size}")

    # Create weighted random sampler
    sampler = torch.utils.data.WeightedRandomSampler(
        weights=sample_weights,
        num_samples=samples_per_epoch,
        replacement=True,  # Allow sampling with replacement
    )

    # Create DataLoader with the sampler
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        drop_last=True,  # Ensure all batches have same size
    )

    return loader


# Create balanced DataLoader
balanced_loader = create_balanced_dataloader(
    train_dataset, sample_weights, batch_size=32
)

# Test the balanced sampling by analyzing a few batches
print(f"\n🧪 Testing balanced sampling:")
class_counts_in_batches = []

for batch_idx, (images, labels) in enumerate(balanced_loader):
    if batch_idx >= 5:  # Test first 5 batches
        break

    # Count classes in this batch
    unique, counts = torch.unique(labels, return_counts=True)
    batch_class_counts = {int(u): int(c) for u, c in zip(unique, counts)}

    # Fill in missing classes
    for class_id in [0, 1]:
        if class_id not in batch_class_counts:
            batch_class_counts[class_id] = 0

    class_counts_in_batches.append(batch_class_counts)

    print(
        f"   Batch {batch_idx + 1}: Class 0: {batch_class_counts[0]}, Class 1: {batch_class_counts[1]}"
    )

# Calculate average class distribution in batches
avg_class_0 = np.mean([counts[0] for counts in class_counts_in_batches])
avg_class_1 = np.mean([counts[1] for counts in class_counts_in_batches])

print(f"\n📊 Average class distribution in balanced batches:")
print(f"   Class 0 (Normal): {avg_class_0:.1f} samples per batch")
print(f"   Class 1 (Onion):  {avg_class_1:.1f} samples per batch")
balance_ratio = f"{avg_class_0 / avg_class_1:.2f}" if avg_class_1 > 0 else "inf"
print(f"   Balance ratio: {balance_ratio}:1")
print(f"   🎉 Much more balanced than original {analysis['imbalance_ratio']:.1f}:1!")

## 8. 🧠 Building the Complete Solution

Now let's put everything together and build our complete CNN classifier that handles imbalanced data properly. This will be the full implementation that matches the problem requirements.


In [None]:
# Complete implementation of our CNN classifier for imbalanced data


class YourCnnClassifier(CnnClassifier):
    """
    CNN classifier for imbalanced binary classification.

    This implementation includes:
    - Proper CNN architecture for binary classification
    - Weighted sampling for handling imbalanced data
    - Regularization techniques to prevent overfitting
    - Complete training pipeline with progress monitoring
    """

    def __init__(self):
        """Initialize the CNN architecture."""
        super().__init__()

        # Define the network architecture using Sequential
        self.network = torch.nn.Sequential(
            # First convolutional block: 1 -> 8 channels
            torch.nn.Conv2d(in_channels=1, out_channels=8, kernel_size=5),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Dropout(p=0.5),
            # Second convolutional block: 8 -> 8 channels
            torch.nn.Conv2d(in_channels=8, out_channels=8, kernel_size=5),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=4, stride=4),
            torch.nn.Dropout(p=0.5),
            # Flatten for fully connected layers
            torch.nn.Flatten(),
            # Fully connected layers
            torch.nn.Linear(in_features=5408, out_features=256),
            torch.nn.Dropout(p=0.5),
            torch.nn.ReLU(),
            # Output layer (binary classification)
            torch.nn.Linear(in_features=256, out_features=1),
            torch.nn.Sigmoid(),
        )

    def forward(self, x):
        """Forward pass through the network."""
        return self.network(x)

    @classmethod
    def create_with_training(cls):
        """
        Create and train the model with imbalanced data handling.

        This method implements the complete training pipeline:
        1. Class distribution analysis
        2. Weighted sampling setup
        3. Model training with proper techniques
        4. Model saving
        """
        # Initialize model
        model = cls().to(device)

        print("🚀 Starting training with imbalanced data handling...")

        # === STEP 1: ANALYZE CLASS DISTRIBUTION ===
        print("\n📊 Step 1: Analyzing class distribution")
        class_counts = {0: 0, 1: 0}

        # Count classes in training dataset
        for idx in range(len(train_dataset)):
            _, label = train_dataset[idx]
            class_counts[label] += 1

        total_samples = sum(class_counts.values())
        print(
            f"   Class 0 (Normal): {class_counts[0]} samples ({class_counts[0]/total_samples*100:.1f}%)"
        )
        print(
            f"   Class 1 (Onion):  {class_counts[1]} samples ({class_counts[1]/total_samples*100:.1f}%)"
        )
        print(f"   Imbalance ratio: {class_counts[0]/class_counts[1]:.2f}:1")

        # === STEP 2: CREATE WEIGHTED SAMPLING ===
        print("\n⚖️ Step 2: Setting up weighted sampling")

        # Calculate class weights (inverse frequency)
        weight_class_0 = 1.0 / class_counts[0]
        weight_class_1 = 1.0 / class_counts[1]

        print(f"   Weight for class 0: {weight_class_0:.6f}")
        print(f"   Weight for class 1: {weight_class_1:.6f}")
        print(f"   Weight ratio: {weight_class_1/weight_class_0:.2f}:1")

        # Assign weights to each sample
        sample_weights = torch.zeros(total_samples)

        for idx in range(len(train_dataset)):
            _, label = train_dataset[idx]
            sample_weights[idx] = weight_class_0 if label == 0 else weight_class_1

        # === STEP 3: SETUP TRAINING CONFIGURATION ===
        print("\n🔧 Step 3: Configuring training parameters")

        batch_size = 32
        n_epochs = 20
        learning_rate = 0.001

        print(f"   Batch size: {batch_size}")
        print(f"   Number of epochs: {n_epochs}")
        print(f"   Learning rate: {learning_rate}")

        # Create balanced DataLoader
        samples_per_epoch = batch_size * (total_samples // batch_size)
        sampler = torch.utils.data.WeightedRandomSampler(
            weights=sample_weights, num_samples=samples_per_epoch, replacement=True
        )

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, sampler=sampler, drop_last=True
        )

        # Setup optimizer and loss function
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        criterion = torch.nn.BCELoss()

        # === STEP 4: TRAINING LOOP ===
        print("\n🏋️ Step 4: Training the model")

        model.train()

        for epoch in range(n_epochs):
            epoch_loss = 0.0
            batch_count = 0

            for batch_idx, (images, labels) in enumerate(train_loader):
                # Move data to device
                images, labels = images.to(device), labels.to(device)

                # Zero gradients
                optimizer.zero_grad()

                # Forward pass
                predictions = model(images).flatten()

                # Calculate loss
                loss = criterion(predictions, labels.float())

                # Backward pass
                loss.backward()

                # Update weights
                optimizer.step()

                # Track loss
                epoch_loss += loss.item()
                batch_count += 1

                # Progress update every 50 batches
                if batch_idx % 50 == 0:
                    print(
                        f"      Epoch [{epoch+1}/{n_epochs}], Batch [{batch_idx}], Loss: {loss.item():.4f}"
                    )

            # End of epoch statistics
            avg_loss = epoch_loss / batch_count
            print(
                f"   ✅ Epoch {epoch+1}/{n_epochs} completed. Average loss: {avg_loss:.4f}"
            )

        # === STEP 5: SAVE MODEL ===
        print("\n💾 Step 5: Saving trained model")
        torch.save(model.state_dict(), cls.MODEL_PATH)
        print(f"   Model saved to: {cls.MODEL_PATH}")

        print("\n🎉 Training completed successfully!")
        return model


# Test our complete implementation
print("🧪 Testing complete CNN classifier implementation...")
complete_model = YourCnnClassifier().to(device)
complete_model.eval()  # Disable dropout so the test forward pass is deterministic

# Test forward pass
with torch.no_grad():
    sample_input = torch.randn(4, 1, 224, 224).to(device)
    output = complete_model(sample_input)
    print(f"\n✅ Complete model test successful!")
    print(f"   Input shape: {sample_input.shape}")
    print(f"   Output shape: {output.shape}")
    print(f"   Output range: [{output.min():.3f}, {output.max():.3f}]")
    print(f"   Sample predictions: {output.flatten().cpu().numpy()}")

## 10. 🎮 Interactive Exercises

Now it's your turn to experiment and deepen your understanding! Try these challenges to master imbalanced classification techniques.

### 🎯 Exercise 1: Experiment with Different Imbalance Ratios

Try training models with different levels of class imbalance and observe how it affects performance.


In [None]:
# 🧪 Exercise: Compare different imbalance ratios


def create_imbalanced_dataset(imbalance_ratio, total_size=1000):
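    """Create a mock dataset with the given imbalance ratio for experimentation.

    Returns the dataset plus the majority and minority counts expected for
    `total_size` samples. The counts are computed from the ratio, not read
    back from the dataset.
    """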
    minority_samples = int(total_size / (imbalance_ratio + 1))
    majority_samples = total_size - minority_samples

    return (
        MockImageDataset("train", imbalance_ratio),
        majority_samples,
        minority_samples,
    )


# Test different imbalance ratios
ratios_to_test = [1.0, 5.0, 10.0, 20.0]  # From balanced to highly imbalanced

print("🧪 Experimenting with different imbalance ratios:")
print("=" * 60)

for ratio in ratios_to_test:
    print(f"\n📊 Testing imbalance ratio: {ratio}:1")

    # Create dataset with this ratio
    test_dataset, maj_count, min_count = create_imbalanced_dataset(ratio, 1000)

    print(f"   Majority class: {maj_count} samples")
    print(f"   Minority class: {min_count} samples")

    # Create weights for this dataset
    weights, analysis = analyze_and_create_weights(test_dataset)

    print(
        f"   Weight ratio: {analysis['weight_class_1']/analysis['weight_class_0']:.2f}:1"
    )

    # Create balanced loader
    balanced_loader = create_balanced_dataloader(test_dataset, weights, batch_size=32)

    # Test a few batches
    class_balance_in_batches = []
    for batch_idx, (_, labels) in enumerate(balanced_loader):
        if batch_idx >= 3:
            break
        unique, counts = torch.unique(labels, return_counts=True)
        batch_counts = {int(u): int(c) for u, c in zip(unique, counts)}
        for class_id in [0, 1]:
            if class_id not in batch_counts:
                batch_counts[class_id] = 0
        class_balance_in_batches.append(batch_counts)

    avg_0 = np.mean([batch[0] for batch in class_balance_in_batches])
    avg_1 = np.mean([batch[1] for batch in class_balance_in_batches])

    print(f"   Avg batch composition: {avg_0:.1f} vs {avg_1:.1f}")
    balance_ratio = f"{avg_0 / avg_1:.2f}" if avg_1 > 0 else "inf"
    print(f"   Batch balance ratio: {balance_ratio}:1")

print("\n💡 Key Observations:")
print("   - Higher imbalance ratios require stronger reweighting")
print(
    "   - Weighted sampling yields roughly balanced batches on average, regardless of the original imbalance"
)
print(
    "   - The minority class gets sampled much more frequently in highly imbalanced cases"
)

## 12. 📖 Summary and Next Steps

Congratulations! 🎉 You've mastered the fundamentals of imbalanced classification with deep learning!

### What You've Learned:

1. **🎯 Imbalanced Data Challenges**:
   - Why traditional methods fail with imbalanced datasets
   - The importance of proper evaluation metrics
   - Real-world examples of imbalanced classification problems

2. **🏗️ CNN Architecture Design**:
   - Building CNNs for binary image classification
   - Proper layer dimensions and parameter calculations
   - Regularization techniques (dropout) to prevent overfitting

3. **⚖️ Weighted Sampling Techniques**:
   - Inverse frequency weighting for class balance
   - PyTorch's WeightedRandomSampler implementation
   - Creating balanced batches from imbalanced datasets

4. **🧠 Complete Solution Implementation**:
   - End-to-end training pipeline
   - Progress monitoring and debugging
   - Model saving and loading mechanisms
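
The training pipeline saves weights with `torch.save(model.state_dict(), ...)`; loading them back is the mirror step. Below is a minimal sketch, assuming a tiny stand-in model and a temporary file path (both hypothetical, chosen only so the example runs on its own). In practice you would instantiate the tutorial's CNN and point at its saved file instead.

```python
import os
import tempfile

import torch

# Tiny stand-in model (assumption); the tutorial's CNN would be used in practice.
model = torch.nn.Sequential(torch.nn.Linear(4, 1), torch.nn.Sigmoid())

with tempfile.TemporaryDirectory() as tmp_dir:
    model_path = os.path.join(tmp_dir, "model.pt")  # hypothetical path
    torch.save(model.state_dict(), model_path)

    # Rebuild the same architecture, then load the saved weights into it.
    restored = torch.nn.Sequential(torch.nn.Linear(4, 1), torch.nn.Sigmoid())
    restored.load_state_dict(torch.load(model_path, map_location="cpu"))

restored.eval()  # switch off training-only behavior such as dropout

sample = torch.randn(2, 4)
with torch.no_grad():
    same_output = torch.allclose(model(sample), restored(sample))
print(f"Restored model reproduces the original outputs: {same_output}")
```

Saving only the `state_dict` keeps the file independent of the class definition's location, at the cost of having to rebuild the architecture before loading.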

### For the Actual Problem Implementation:

You now have all the knowledge to implement the complete solution! The key components are:

```python
class YourCnnClassifier(CnnClassifier):
    def __init__(self):
        super().__init__()  # Required before registering submodules
        # CNN architecture with proper dimensions
        self.network = torch.nn.Sequential(
            # Conv layers with ReLU, MaxPool, Dropout
            # FC layers with regularization
            # Sigmoid output for binary classification
        )

    @classmethod
    def create_with_training(cls):
        model = cls()
        # 1. Analyze class distribution
        # 2. Create weighted sampling
        # 3. Set up balanced DataLoader
        # 4. Train with proper loss function
        # 5. Save model weights
        return model
```
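
Steps 1 to 3 of that outline can be sketched on toy data as follows. This is an illustration under stated assumptions, not the tutorial's exact helper code: the `TensorDataset` of random features stands in for the image dataset, and names such as `toy_dataset` and `toy_loader` are hypothetical.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)

# Toy stand-in for an imbalanced dataset: 950 negatives, 50 positives (19:1).
labels = torch.cat([torch.zeros(950), torch.ones(50)]).long()
features = torch.randn(len(labels), 8)  # placeholder features, not real images
toy_dataset = TensorDataset(features, labels)

# 1. Analyze class distribution.
class_counts = torch.bincount(labels, minlength=2)

# 2. Inverse-frequency weight per class, then one weight per sample.
class_weights = 1.0 / class_counts.float()
sample_weights = class_weights[labels]

# 3. Sample with replacement so minority examples can repeat within an epoch.
sampler = WeightedRandomSampler(
    weights=sample_weights.double(),
    num_samples=len(sample_weights),
    replacement=True,
)
toy_loader = DataLoader(toy_dataset, batch_size=50, sampler=sampler, drop_last=True)

# Fraction of positives actually drawn over one pass through the loader.
drawn = torch.cat([batch_labels for _, batch_labels in toy_loader])
positive_fraction = drawn.float().mean().item()
print(f"Positive fraction in sampled batches: {positive_fraction:.2f}")
```

Because each class receives the same total sampling weight, the positive fraction should land near one half even though positives make up only 5% of the data. Individual batches still vary, since sampling is random.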

### 🚀 Advanced Techniques to Explore:

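- **Focal Loss**: A loss function that down-weights easy, well-classified examples so training focuses on hard ones
- **SMOTE**: Synthetic Minority Over-sampling Technique, which interpolates new minority samples in feature space
- **Ensemble Methods**: Combining multiple models
- **Transfer Learning**: Using pre-trained models
- **Data Augmentation**: Creating transformed variants of existing samples (flips, crops, noise), especially for the minority class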
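
As a starting point for the first item, here is a minimal sketch of a binary focal loss that operates on sigmoid probabilities, matching a model with a sigmoid output. The helper name `binary_focal_loss` is hypothetical, and the `gamma=2.0` and `alpha=0.25` defaults are commonly used values that you should expect to tune for your own class ratio.

```python
import torch


def binary_focal_loss(probs, targets, gamma=2.0, alpha=0.25, eps=1e-7):
    """Binary focal loss on sigmoid probabilities (hypothetical helper).

    gamma down-weights well-classified examples; alpha weights the positive class.
    """
    probs = probs.clamp(eps, 1.0 - eps)  # avoid log(0)
    targets = targets.float()
    # p_t is the probability the model assigns to the true class.
    p_t = targets * probs + (1.0 - targets) * (1.0 - probs)
    alpha_t = targets * alpha + (1.0 - targets) * (1.0 - alpha)
    loss = -alpha_t * (1.0 - p_t) ** gamma * torch.log(p_t)
    return loss.mean()


probs = torch.tensor([0.9, 0.1, 0.6, 0.4])
targets = torch.tensor([1, 0, 1, 0])
focal = binary_focal_loss(probs, targets)
bce = torch.nn.functional.binary_cross_entropy(probs, targets.float())
print(f"focal={focal.item():.4f}  bce={bce.item():.4f}")
```

With `gamma=0.0` and `alpha=0.5` the function reduces to half of the ordinary binary cross-entropy, which is a handy sanity check before swapping it in for `BCELoss`.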

### 📚 Useful Resources:

- 📖 [Imbalanced Learn Documentation](https://imbalanced-learn.org/)
- 🛠️ [PyTorch WeightedRandomSampler](https://pytorch.org/docs/stable/data.html#torch.utils.data.WeightedRandomSampler)
- 📑 [Research: Learning from Imbalanced Data](https://link.springer.com/article/10.1007/s10115-007-0089-4)
- 🎨 [Evaluation Metrics for Imbalanced Classification](https://machinelearningmastery.com/metrics-evaluate-imbalanced-classification/)

### 💡 Key Takeaways:

1. **Weighted Sampling is Crucial**: Balance your training batches, not just your dataset
2. **Proper Evaluation Matters**: Use precision, recall, and F1-score, not just accuracy
3. **Regularization is Essential**: Prevent overfitting with dropout and proper architecture
4. **Monitor Training Progress**: Track loss and class distribution in batches
5. **Test on Balanced Data**: Your model should generalize to balanced test sets
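
To make the second takeaway concrete, the sketch below computes accuracy, precision, recall, and F1 for the naive "always healthy" classifier from the introduction (950 healthy, 50 sick). The helper `binary_classification_metrics` is a hypothetical pure-Python illustration; a library such as scikit-learn offers equivalent functions.

```python
def binary_classification_metrics(y_true, y_pred):
    """Compute accuracy, precision, recall, and F1 for the positive class (label 1)."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)

    # Fall back to 0.0 when a denominator is zero (no predicted or actual positives).
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = (
        2 * precision * recall / (precision + recall)
        if (precision + recall)
        else 0.0
    )
    accuracy = (tp + tn) / len(pairs)
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# The 19:1 example from the introduction: 950 healthy, 50 sick patients.
y_true = [0] * 950 + [1] * 50
always_healthy = [0] * 1000  # naive classifier that never predicts the disease
metrics = binary_classification_metrics(y_true, always_healthy)
print(metrics)
```

The naive classifier scores 95% accuracy while its recall and F1 for the sick class are both zero, which is exactly why accuracy alone is misleading here.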

**Good luck with your implementation!** 🌟

Remember: The key insight is that **training on balanced batches** (via weighted sampling) allows your model to learn both classes effectively, even when the original dataset is highly imbalanced. This technique is fundamental to many real-world applications where class imbalance is common.
