## CNN Architecture

#### Basic model: 2 convolutional layers with max pooling, followed by 2 fully connected layers.

In [None]:
class SimpleCNN(nn.Module):
    """Small CNN classifier: two conv/ReLU/max-pool stages, then two linear layers.

    The layer sizes assume 28x28 RGB inputs (``fc1`` expects exactly
    ``7 * 7 * 32`` features). Each 3x3 conv with padding=1 keeps the spatial
    size and each 2x2/stride-2 max-pool halves it:

        (N, 3, 28, 28) -conv1/pool-> (N, 16, 14, 14)
                       -conv2/pool-> (N, 32, 7, 7)
                       -flatten->    (N, 1568)
                       -fc1->        (N, 128)
                       -fc2->        (N, num_classes)

    ``forward`` returns raw logits (no softmax); pair it with
    ``nn.CrossEntropyLoss``, which applies log-softmax internally.

    Args:
        num_classes: Number of output logits. Defaults to 4.
    """

    def __init__(self, num_classes: int = 4):
        super().__init__()
        # in_channels=3 (RGB) -> 16 feature maps
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)

        # 16 -> 32 feature maps
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)

        # Max pooling: downsample height and width by a factor of 2
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Fully connected 1: flattened 7*7*32 feature maps -> 128
        self.fc1 = nn.Linear(in_features=7 * 7 * 32, out_features=128)

        # Fully connected 2: 128 -> one logit per class
        self.fc2 = nn.Linear(in_features=128, out_features=num_classes)

        # Shared activation (stateless, so one module can be reused)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return class logits of shape ``(N, num_classes)`` for a batch ``x``."""
        # Stage 1: convolution + ReLU + pooling
        x = self.pool(self.relu(self.conv1(x)))

        # Stage 2: convolution + ReLU + pooling
        x = self.pool(self.relu(self.conv2(x)))

        # Flatten everything except the batch dimension. Unlike
        # view(-1, 7 * 7 * 32), this never folds extra spatial elements into
        # the batch axis: a wrongly sized input makes fc1 raise a shape error
        # instead of silently returning a different number of rows.
        x = x.flatten(1)

        # Hidden fully connected layer
        x = self.relu(self.fc1(x))

        # Output layer (no activation; CrossEntropyLoss handles the softmax)
        return self.fc2(x)

## Exploratory Data Analysis

### 1. Data balancing

#### 1.1 Check the data for class imbalance to prevent model bias

In [None]:
import os
import matplotlib.pyplot as plt

data_dir = "data/train"

# Image file extensions accepted by torchvision's default ImageFolder loader.
# They are matched case-insensitively so these counts agree with what
# ImageFolder will actually load (e.g. "x.JPG", "x.jpeg" and "x.png" count too).
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")


def _count_images(folder):
    """Return the number of image files under *folder*, searched recursively."""
    total = 0
    for _root, _dirs, files in os.walk(folder, followlinks=True):
        total += sum(name.lower().endswith(IMAGE_EXTENSIONS) for name in files)
    return total


# Keep only class folders (skips stray files such as .DS_Store). Sorting makes
# the bar order match ImageFolder's alphabetical class-index order.
classes = sorted(
    c for c in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, c))
)

class_counts = {cls: _count_images(os.path.join(data_dir, cls)) for cls in classes}

# ---- Plot style (global rcParams, so this also affects later figures) ----
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = ['Arial', 'DejaVu Sans']  # prefer Arial, fall back to DejaVu Sans

# ---- Data preparation ----
keys = list(class_counts.keys())
values = list(class_counts.values())

# ---- Create canvas ----
fig, ax = plt.subplots(figsize=(9, 6), dpi=100)  # increase dpi for a sharper figure

bars = ax.bar(keys, values, color='#4C72B0', width=0.6, zorder=3, alpha=0.9)

# 1. Title and axis labels
ax.set_title("Class Distribution", fontsize=16, fontweight='bold', pad=20, color='#333333')
ax.set_xlabel("Class Name", fontsize=12, labelpad=10, color='#555555')
ax.set_ylabel("Count", fontsize=12, labelpad=10, color='#555555')

# 2. Grid: keep only light-gray dashed horizontal lines, drawn behind the bars
ax.grid(axis='y', linestyle='--', alpha=0.5, zorder=0)

# 3. Remove unneeded borders (spines); keep a light bottom axis line
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_visible(False)
ax.spines['bottom'].set_color('#CCCCCC')

# 4. Hide tick marks (the tick labels stay)
ax.tick_params(axis='x', length=0)
ax.tick_params(axis='y', length=0)

# ---- Count labels above each bar ----
# Offset each label by 1% of the tallest bar; default=0 keeps an empty plot safe.
label_offset = max(values, default=0) * 0.01
for bar in bars:
    height = bar.get_height()
    ax.text(
        bar.get_x() + bar.get_width() / 2,
        height + label_offset,
        f"{int(height)}",
        ha='center', va='bottom',
        fontsize=11, fontweight='bold', color='#4C72B0'
    )

plt.tight_layout()
plt.show()


#### 1.2 Split the training data into train & validation sets

(Splitting before any transform is applied lets the train and validation sets use completely separate transforms, so the validation set is never "polluted" by data augmentation.)

In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader, random_split, WeightedRandomSampler
from torchvision import datasets, transforms
import numpy as np

# Base transform for validation / test: resize to 28x28, convert to a tensor,
# then normalize each RGB channel with (x - 0.5) / 0.5, mapping ToTensor's
# [0, 1] range to [-1, 1].
base_transform = transforms.Compose(
    [
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# Fixed seed so the train/validation split is reproducible across runs.
generator = torch.Generator().manual_seed(42)

# Load the images WITHOUT a transform, so that each split can attach its own
# transform later instead of sharing one.
base_dataset = datasets.ImageFolder(root=data_dir, transform=None)
class_names = base_dataset.classes
print("Classes:", class_names)  # expected e.g. ['Apple', 'Banana', 'Mix', 'Orange']

# 80/20 split: num_val is rounded down, and the remainder goes to training.
val_ratio = 0.2
num_total = len(base_dataset)
num_val = int(num_total * val_ratio)
num_train = num_total - num_val

train_subset, val_subset = random_split(
    base_dataset, lengths=[num_train, num_val], generator=generator
)


Define a wrapper `Dataset` so that each split can apply its own transform.

In [None]:
class TransformedDataset(Dataset):
    """Wrap an ``(image, label)`` dataset and transform each image on access.

    This lets the train and validation Subsets returned by ``random_split``
    share one un-transformed ImageFolder while each applies its own transform.

    Args:
        subset: Indexable dataset yielding ``(image, label)`` pairs, e.g. a
            ``Subset`` from ``random_split``.
        transform: Callable applied to the image only (the label is passed
            through unchanged). It is skipped when falsy, e.g. ``None``.
    """

    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        sample, target = self.subset[idx]
        if not self.transform:
            return sample, target
        return self.transform(sample), target


Construct a validation set (which remains constant across all experimental scenarios).

In [None]:
# Validation data always uses the deterministic base transform (no
# augmentation) and a fixed order, so every experiment is scored on the same
# inputs.
val_dataset = TransformedDataset(subset=val_subset, transform=base_transform)
val_loader = DataLoader(dataset=val_dataset, batch_size=32, shuffle=False)


#### 1.3 Experiment

1.3.1 Baseline (no augmentation, classes still imbalanced)

In [None]:
# 训练用的“无增强” transform
train_transform_no_aug = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
])

train_dataset_no_aug = TransformedDataset(train_subset, train_transform_no_aug)
train_loader_no_aug = DataLoader(train_dataset_no_aug, batch_size=32, shuffle=True)


1.3.2 Uniform lightweight augmentation for all classes (classes still imbalanced)

In [None]:
# Training transform with the same light augmentation applied to every class.
# The random augmentations run on the PIL image, before ToTensor. Normalize
# must match base_transform (used by the validation loader) so that train and
# validation inputs share the same [-1, 1] range.
train_transform_aug_all = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.RandomHorizontalFlip(p=0.5),  # flip left-right with 50% probability
    transforms.RandomRotation(degrees=15),  # rotate by a random angle in [-15°, +15°]
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # random brightness/contrast/saturation changes
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

train_dataset_aug_all = TransformedDataset(train_subset, train_transform_aug_all)
train_loader_aug_all = DataLoader(train_dataset_aug_all, batch_size=32, shuffle=True)
