# Assignment 4: Flower Classification with Pretrained CNNs

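This notebook classifies the 102 flower categories of the Oxford Flowers-102 dataset with a pretrained YOLOv5 classification model (YOLOv5s-cls) whose output head is replaced for 102 classes.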

## Requirements:
- PyTorch
- Models: YOLOv5s-cls
- Dataset: Oxford Flowers-102
- Random split: 50% train / 25% val / 25% test
- Two different random seeds for repeated experiments
- Probabilistic output: softmax probabilities for each class

## CONFIGURATION & ENVIRONMENT

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import os
import shutil
import numpy as np
import scipy.io
import tarfile
from sklearn.model_selection import train_test_split
from IPython.display import display
from tqdm.notebook import tqdm
import torch.nn.functional as F
from PIL import Image
import random

BATCH_SIZE = 64
EPOCHS = 20
LR = 0.001
IMG_SIZE = 224
NUM_CLASSES = 102
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

# Clone YOLOv5 if needed, then work from inside the repository (get_model loads via torch.hub.load('.', ...))
if not os.path.exists('yolov5') and not os.path.exists('/content/yolov5'):
    print("YOLOv5 not found. Cloning repository...")
    !git clone https://github.com/ultralytics/yolov5
    !pip install -qr yolov5/requirements.txt

if os.path.exists('yolov5'):
    os.chdir('yolov5')
elif os.path.exists('/content/yolov5'):
    os.chdir('/content/yolov5')

## DATA PREPROCESSING

The preprocessing consists of two stages: Physical Organization and Input Transformation.

A. Physical Data Organization (prepare_dataset function):
   1. Download: Fetches '102flowers.tgz' (images) and 'imagelabels.mat' (labels).
   2. Stratified Random Split: The 8,189 images are divided into 50% Train, 25% Validation, and 25% Test in two steps. First a 50/50 split separates Train from a held-out half, then the held-out half is split 50/50 into Validation and Test.
      - 'Stratified' means each flower category keeps approximately the same proportion in every split, so every class appears in Train, Validation, and Test.
   3. Repetition: The entire splitting and training process is run twice with different random seeds (42 and 123).
      This produces two different random partitions of the data, which lets us check that the model's accuracy is stable rather than the result of one lucky split.
   4. Restructuring: Copies images into class-specific folders, the layout expected by PyTorch's ImageFolder:
      datasets/flowers102_<run_name>/[train|val|test]/<label_id>/image_XXXXX.jpg
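The two-stage stratified split in step 2 can be checked on a small example. The sketch below runs the same `train_test_split` calls as `prepare_dataset` on a synthetic, deliberately imbalanced label array. It is a stand-in for `imagelabels.mat`, not the real data. The output shows that every split keeps the 1:2:3 class ratio of the full set.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic stand-in for imagelabels.mat (assumption): 3 classes with 40/80/120 images
labels = np.array([1] * 40 + [2] * 80 + [3] * 120)
indices = np.arange(len(labels))

# Same scheme as prepare_dataset: 50% train, then halve the held-out 50% into val/test
train_idx, temp_idx = train_test_split(indices, test_size=0.5, random_state=42, stratify=labels)
val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=42, stratify=labels[temp_idx])

for name, idx in [('train', train_idx), ('val', val_idx), ('test', test_idx)]:
    classes, counts = np.unique(labels[idx], return_counts=True)
    # Each split keeps the 1:2:3 class ratio of the full label array
    print(name, len(idx), dict(zip(classes.tolist(), counts.tolist())))
```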

B. Input Transformation (On-the-Fly transforms):
   1. Training (Augmentation):
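      - RandomResizedCrop(224): Crops a random region of the image and resizes it to 224×224. By default the region covers 8% to 100% of the area, with an aspect ratio between 3/4 and 4/3. The model therefore has to recognize flowers from partial and rescaled views.
      - RandomHorizontalFlip: Mirrors the image left-right with probability 0.5, encouraging invariance to horizontal orientation.
      - ToTensor & Normalize: Converts the image to a (3, H, W) float tensor in [0, 1] and standardizes each channel with the ImageNet mean/std.
   2. Validation/Test:
      - Resize(256) -> CenterCrop(224): Scales the shorter side to 256 px, then takes the central 224×224 patch. There is no randomness, so evaluation is consistent across epochs and runs.
      - ToTensor & Normalize: Same as for training.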

In [2]:
# Transforms (B. Input Transformation)
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

In [3]:
def prepare_dataset(run_name, seed):
    """(A. Physical Organization) Downloads and splits dataset."""
    raw_dir = 'flowers102_raw'
    dest_dir = f'datasets/flowers102_{run_name}'

    # 1. Download & Extract
    if not os.path.exists(raw_dir):
        print("Raw data not found. Downloading...")
        os.makedirs(raw_dir, exist_ok=True)
        urls = [
            'https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz',
            'https://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat'
        ]
        for url in urls:
            filename = os.path.join(raw_dir, os.path.basename(url))
            torch.hub.download_url_to_file(url, filename)

        print("Extracting images...")
        with tarfile.open(os.path.join(raw_dir, '102flowers.tgz'), 'r:gz') as tar:
            tar.extractall(raw_dir)

    # 2. Create Split
    if os.path.exists(dest_dir):
        print(f"Dataset split '{dest_dir}' already exists.")
        return dest_dir

    print(f"Creating dataset split: {run_name} (Seed {seed})...")
    labels = scipy.io.loadmat(os.path.join(raw_dir, 'imagelabels.mat'))['labels'][0]
    indices = np.arange(len(labels))

    # Stratified Split: 50% Train, 25% Val, 25% Test
    train_idx, temp_idx = train_test_split(indices, test_size=0.5, random_state=seed, stratify=labels)
    val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=seed, stratify=labels[temp_idx])

    splits = {'train': train_idx, 'val': val_idx, 'test': test_idx}

    for split_name, split_indices in splits.items():
        for idx in split_indices:
            filename = f"image_{idx + 1:05d}.jpg"
            src = os.path.join(raw_dir, 'jpg', filename)
            label = labels[idx]
            dst = os.path.join(dest_dir, split_name, str(label))
            os.makedirs(dst, exist_ok=True)
            shutil.copy(src, os.path.join(dst, filename))

    return dest_dir
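Because `prepare_dataset` names class folders by their label ID, it matters how `ImageFolder` numbers those folders. `ImageFolder` sorts folder names as strings and assigns indices 0..101 in that order. As a result, model output index `i` is generally not label `i + 1`. The sketch below reproduces that ordering with plain Python. This is why the inference code later sorts the folder names in the same way before mapping predictions back to flower names.

```python
# prepare_dataset names class folders by label ID: '1', '2', ..., '102'
folder_names = [str(label) for label in range(1, 103)]

# ImageFolder sorts folder names as strings, then numbers them 0..101 in that order
classes = sorted(folder_names)
class_to_idx = {name: i for i, name in enumerate(classes)}

# Inverse map: model output index -> original Oxford label ID
idx_to_label = {i: int(name) for name, i in class_to_idx.items()}

print(classes[:5])          # ['1', '10', '100', '101', '102']
print(class_to_idx['2'])    # 14, not 1
print(idx_to_label[14])     # 2
```

For a real split, `datasets.ImageFolder(path).classes` and `.class_to_idx` expose the same mapping directly, which avoids re-deriving it by hand.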

## Detailed Network Architecture (YOLOv5s-cls) — *Classification model*

This notebook uses **YOLOv5s-cls** for **image classification** (not object detection). It follows the common transfer-learning setup: a pretrained **backbone** extracts features, then a **classification head** predicts one of 102 flower classes.

### 1) Input
- Each preprocessed image is a $(3, 224, 224)$ tensor, so a batch has shape $(B, 3, 224, 224)$, where $B$ is the batch size (`BATCH_SIZE = 64`).

### 2) Backbone (feature extractor)
- A stack of convolutional blocks turns pixels into higher-level features (edges → textures → parts).
- YOLOv5-style backbones include CSP/C3 blocks and an SPPF-style layer to combine information across scales.

### 3) Head (102 classes)
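- The `Classify` head applies a 1×1 convolution, global average pooling, and flatten to produce one feature vector per image.
- Its final linear layer (1000 ImageNet classes in the pretrained checkpoint) is replaced with a new, randomly initialized `nn.Linear` that outputs **102 logits**: shape $(B, 102)$.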
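The head swap can be illustrated on a network small enough to run anywhere. `ToyClassifier` and `replace_last_linear` below are hypothetical names made up for this sketch; they are not YOLOv5 classes. The helper finds the last `nn.Linear` by module name and replaces it with a fresh layer that has the same input width and `NUM_CLASSES` outputs.

```python
import torch
import torch.nn as nn

NUM_CLASSES = 102

class ToyClassifier(nn.Module):
    """Hypothetical stand-in for a pretrained backbone plus classification head."""
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Conv2d(3, 16, 3, stride=2), nn.ReLU())
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Sequential(nn.Flatten(), nn.Linear(16, 1000))  # e.g. 1000 ImageNet classes

    def forward(self, x):
        return self.head(self.pool(self.backbone(x)))

def replace_last_linear(model: nn.Module, num_classes: int) -> nn.Module:
    # Collect the names of all nn.Linear modules in registration order
    linear_names = [name for name, m in model.named_modules() if isinstance(m, nn.Linear)]
    if not linear_names:
        raise RuntimeError("No nn.Linear layer found to replace")
    parent_name, _, attr = linear_names[-1].rpartition('.')  # 'head.1' -> ('head', '.', '1')
    parent = model.get_submodule(parent_name)                # '' returns the model itself
    old = getattr(parent, attr)
    setattr(parent, attr, nn.Linear(old.in_features, num_classes))  # same input width, new output size
    return model

model = replace_last_linear(ToyClassifier(), NUM_CLASSES)
logits = model(torch.randn(2, 3, 64, 64))
print(logits.shape)  # torch.Size([2, 102])
```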

### 4) Probabilities
- **Training:** `CrossEntropyLoss` takes logits directly and applies log-softmax internally, so the model itself has no softmax layer.
- **Inference:** `softmax(logits)` converts them into class probabilities that sum to 1.
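For logits $z \in \mathbb{R}^{102}$ and true class $y$, softmax gives $p_k = e^{z_k} / \sum_j e^{z_j}$, and cross-entropy is $\mathcal{L} = -\log p_y$. The sketch below uses random logits (an assumption, in place of real model output). It checks three things: `CrossEntropyLoss` equals log-softmax followed by negative log-likelihood, each softmax row sums to 1, and softmax does not change the argmax. The third point is why accuracy can be computed on raw logits, as the training loop does.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

gen = torch.Generator().manual_seed(0)
logits = torch.randn(4, 102, generator=gen)   # (B, 102) raw scores from the head
targets = torch.tensor([0, 5, 17, 101])       # class indices in [0, 101]

# Training: CrossEntropyLoss == log_softmax followed by negative log-likelihood
ce = nn.CrossEntropyLoss()(logits, targets)
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)

# Inference: softmax turns each row of logits into a probability distribution
probs = F.softmax(logits, dim=1)

print(torch.allclose(ce, manual))                        # True
print(torch.allclose(probs.sum(dim=1), torch.ones(4)))   # True
print(torch.equal(probs.argmax(1), logits.argmax(1)))    # softmax preserves the ranking
```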

In [4]:
def get_model(model_name='yolov5s-cls.pt'):
    """Loads pretrained YOLOv5s-cls and replaces its head for NUM_CLASSES classes."""
    # Download the checkpoint if missing
    if not os.path.exists(model_name):
        print(f"Downloading {model_name}...")
        torch.hub.download_url_to_file(f'https://github.com/ultralytics/yolov5/releases/download/v7.0/{model_name}', model_name)

    # Prefer the local clone (we chdir'd into yolov5); fall back to fetching the repo via torch.hub
    try:
        model = torch.hub.load('.', 'custom', path=model_name, source='local')
    except Exception:
        model = torch.hub.load('ultralytics/yolov5', 'custom', path=model_name)

    # Head replacement: swap the final nn.Linear for a fresh NUM_CLASSES-way layer
    success = False
    if hasattr(model, 'model') and hasattr(model.model, 'model'):  # Nested wrapper around the classification model
        head = model.model.model[-1]
        if isinstance(getattr(head, 'linear', None), nn.Linear):
            head.linear = nn.Linear(head.linear.in_features, NUM_CLASSES)
            success = True
    elif hasattr(model, 'model') and hasattr(model.model[-1], 'linear'):  # Unwrapped classification model
        head = model.model[-1]
        head.linear = nn.Linear(head.linear.in_features, NUM_CLASSES)
        success = True

    if not success:
        # Generic fallback: replace the last nn.Linear found anywhere in the model
        linear_names = [name for name, m in model.named_modules() if isinstance(m, nn.Linear)]
        if not linear_names:
            raise RuntimeError(f"Could not find a linear head to replace in {model_name}")
        parent_name, _, attr = linear_names[-1].rpartition('.')
        parent = model.get_submodule(parent_name)  # '' returns the model itself
        setattr(parent, attr, nn.Linear(getattr(parent, attr).in_features, NUM_CLASSES))

    if hasattr(model, 'nc'):
        model.nc = NUM_CLASSES
    return model.to(DEVICE)
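Whether this setup fine-tunes the whole network or trains only the new head depends on the `requires_grad` flags of the loaded checkpoint, and the notebook never sets them explicitly. Suppose the released weights load with gradients disabled; this is an assumption worth checking, not something the notebook confirms. In that case Adam would update only the freshly created linear layer. The sketch below shows a simple check and an optional unfreeze. `count_parameters` and `unfreeze_all` are hypothetical helpers, and the demo uses a toy model rather than YOLOv5.

```python
import torch.nn as nn

def count_parameters(model: nn.Module):
    """Return (trainable, total) parameter counts."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return trainable, total

def unfreeze_all(model: nn.Module) -> nn.Module:
    """Enable gradients for every parameter (full fine-tuning)."""
    for p in model.parameters():
        p.requires_grad_(True)
    return model

# Toy stand-in: a frozen "backbone" plus a freshly created 102-way head
backbone = nn.Linear(10, 20)
for p in backbone.parameters():
    p.requires_grad_(False)
model = nn.Sequential(backbone, nn.ReLU(), nn.Linear(20, 102))

print(count_parameters(model))   # (2142, 2362): only the head is trainable
unfreeze_all(model)
print(count_parameters(model))   # (2362, 2362)
```

In the notebook, `count_parameters(get_model())` would show which case applies. If full fine-tuning is intended, call `unfreeze_all` before the optimizer is constructed.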

## TRAINING LOOP

In [5]:
def train_and_evaluate(run_name, data_root, model_name='yolov5s-cls.pt'):
    """Runs the full training loop and returns (model, history)."""
    print(f"\n>>> Starting {run_name} (Model: {model_name}) <<<")

    train_loader = DataLoader(datasets.ImageFolder(os.path.join(data_root, 'train'), transform=train_transforms), batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    val_loader = DataLoader(datasets.ImageFolder(os.path.join(data_root, 'val'), transform=val_transforms), batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
    test_loader = DataLoader(datasets.ImageFolder(os.path.join(data_root, 'test'), transform=val_transforms), batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    model = get_model(model_name)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LR)

    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': [], 'test_loss': [], 'test_acc': []}

    def evaluate(loader):
        """Returns (mean loss, accuracy) over a loader without updating weights."""
        model.eval()
        r_loss, correct, total = 0.0, 0, 0
        with torch.no_grad():
            for inputs, labels in loader:
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                r_loss += loss.item() * inputs.size(0)
                correct += (outputs.argmax(1) == labels).sum().item()
                total += labels.size(0)
        return r_loss / total, correct / total

    for epoch in range(EPOCHS):
        # Train for one epoch
        model.train()
        r_loss, correct, total = 0.0, 0, 0
        for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1} Train", leave=False):
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            r_loss += loss.item() * inputs.size(0)
            correct += (outputs.argmax(1) == labels).sum().item()
            total += labels.size(0)
        history['train_loss'].append(r_loss / total)
        history['train_acc'].append(correct / total)

        # Evaluate: test metrics are logged every epoch for the plots only
        v_loss, v_acc = evaluate(val_loader)
        t_loss, t_acc = evaluate(test_loader)

        history['val_loss'].append(v_loss); history['val_acc'].append(v_acc)
        history['test_loss'].append(t_loss); history['test_acc'].append(t_acc)

        print(f"Epoch {epoch+1:02d}: Train Acc={history['train_acc'][-1]:.4f} | Val Acc={v_acc:.4f} | Test Acc={t_acc:.4f}")

    # Plot
    fig, ax = plt.subplots(1, 2, figsize=(16, 6))
    ep = range(1, EPOCHS+1)
    ax[0].plot(ep, history['train_acc'], label='Train', marker='o')
    ax[0].plot(ep, history['val_acc'], label='Val', marker='o')
    ax[0].plot(ep, history['test_acc'], label='Test', marker='o', linestyle='--')
    ax[0].set_title(f'{run_name} Accuracy'); ax[0].legend(); ax[0].grid(True, alpha=0.3)

    ax[1].plot(ep, history['train_loss'], label='Train', marker='o')
    ax[1].plot(ep, history['val_loss'], label='Val', marker='o')
    ax[1].plot(ep, history['test_loss'], label='Test', marker='o', linestyle='--')
    ax[1].set_title(f'{run_name} Cross Entropy Loss'); ax[1].legend(); ax[1].grid(True, alpha=0.3)

    display(fig)
    plt.close(fig)
    return model, history
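`train_and_evaluate` returns the final-epoch model and records test metrics at every epoch. A common protocol picks the epoch on validation accuracy and reports test accuracy only for that epoch. That choice can be read from `history` after the fact. `select_epoch_by_val` is a hypothetical helper, and the demo numbers are invented for illustration; they are not results from this notebook.

```python
def select_epoch_by_val(history):
    """Return (epoch, val_acc, test_acc) for the epoch with the highest validation accuracy.

    Epochs are 1-indexed to match the printed training log; on ties the earliest epoch wins.
    """
    val_acc = history['val_acc']
    best = max(range(len(val_acc)), key=lambda i: val_acc[i])
    return best + 1, val_acc[best], history['test_acc'][best]

# Invented numbers for illustration only
demo_history = {'val_acc': [0.60, 0.75, 0.81, 0.79], 'test_acc': [0.58, 0.74, 0.80, 0.80]}
print(select_epoch_by_val(demo_history))  # (3, 0.81, 0.8)
```

On the real runs this would be `select_epoch_by_val(history_run1)`. Recovering the weights from that epoch would also require the loop to keep a copy of `model.state_dict()` whenever validation accuracy improves.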

## RUN 1

In [None]:
print("\n--- Executing Run 1 (Seed 42) ---")
data_path_1 = prepare_dataset('run1', seed=42)
model_run1, history_run1 = train_and_evaluate('Run 1', data_path_1)

## RUN 2

In [None]:
print("\n--- Executing Run 2 (Seed 123) ---")
data_path_2 = prepare_dataset('run2', seed=123)
model_run2, history_run2 = train_and_evaluate('Run 2', data_path_2)

## COMPARISON PLOTS

In [None]:
import matplotlib.pyplot as plt
from IPython.display import display

if 'history_run1' in globals() and 'history_run2' in globals():
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))

    # Number of epochs actually recorded (falls back to the configured EPOCHS)
    if isinstance(history_run1, dict) and 'train_acc' in history_run1:
        n_epochs = len(history_run1['train_acc'])
    else:
        n_epochs = EPOCHS
    ep = range(1, n_epochs + 1)

    metrics = [
        ('train_acc', 'Train Accuracy'), ('val_acc', 'Val Accuracy'), ('test_acc', 'Test Accuracy'),
        ('train_loss', 'Train Loss'),    ('val_loss', 'Val Loss'),    ('test_loss', 'Test Loss')
    ]

    for i, (metric_key, title) in enumerate(metrics):
        ax = axes[i // 3, i % 3]
        # Run 1
        if metric_key in history_run1:
            ax.plot(ep, history_run1[metric_key], label='Run 1 (Seed 42)', marker='o', linestyle='-', color='tab:blue', alpha=0.7)
        # Run 2
        if metric_key in history_run2:
            ax.plot(ep, history_run2[metric_key], label='Run 2 (Seed 123)', marker='x', linestyle='--', color='tab:orange', alpha=0.7)

        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Accuracy' if 'acc' in metric_key else 'Loss')
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.suptitle("Performance Comparison: Run 1 vs Run 2 (Stability Check)", fontsize=16)
    plt.tight_layout()
    display(fig)
    plt.close(fig)
else:
    print("Cannot plot comparison: Missing history data.")

## RANDOM SAMPLE

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn.functional as F
import os
import random
from IPython.display import display

# --- 1. Define Flower Names Mapping ---
# Mapping from Label ID (1-102) to English Name
FLOWER_NAMES = {
    '1': 'pink primrose', '2': 'hard-leaved pocket orchid', '3': 'canterbury bells', '4': 'sweet pea', '5': 'english marigold',
    '6': 'tiger lily', '7': 'moon orchid', '8': 'bird of paradise', '9': 'monkshood', '10': 'globe thistle',
    '11': 'snapdragon', '12': "colt's foot", '13': 'king protea', '14': 'spear thistle', '15': 'yellow iris',
    '16': 'globe-flower', '17': 'purple coneflower', '18': 'peruvian lily', '19': 'balloon flower', '20': 'giant white arum lily',
    '21': 'fire lily', '22': 'pincushion flower', '23': 'fritillary', '24': 'red ginger', '25': 'grape hyacinth',
    '26': 'corn poppy', '27': 'prince of wales feathers', '28': 'stemless gentian', '29': 'artichoke', '30': 'sweet william',
    '31': 'carnation', '32': 'garden phlox', '33': 'love in the mist', '34': 'mexican aster', '35': 'alpine sea holly',
    '36': 'ruby-lipped cattleya', '37': 'cape flower', '38': 'great masterwort', '39': 'siam tulip', '40': 'lenten rose',
    '41': 'barbeton daisy', '42': 'daffodil', '43': 'sword lily', '44': 'poinsettia', '45': 'bolero deep blue',
    '46': 'wallflower', '47': 'marigold', '48': 'buttercup', '49': 'oxeye daisy', '50': 'common dandelion',
    '51': 'petunia', '52': 'wild pansy', '53': 'primula', '54': 'sunflower', '55': 'pelargonium',
    '56': 'bishop of llandaff', '57': 'gaura', '58': 'geranium', '59': 'orange dahlia', '60': 'pink-yellow dahlia',
    '61': 'cautleya spicata', '62': 'japanese anemone', '63': 'black-eyed susan', '64': 'silverbush', '65': 'californian poppy',
    '66': 'osteospermum', '67': 'spring crocus', '68': 'bearded iris', '69': 'windflower', '70': 'tree poppy',
    '71': 'gazania', '72': 'azalea', '73': 'water lily', '74': 'rose', '75': 'thorn apple',
    '76': 'morning glory', '77': 'passion flower', '78': 'lotus', '79': 'toad lily', '80': 'anthurium',
    '81': 'frangipani', '82': 'clematis', '83': 'hibiscus', '84': 'columbine', '85': 'desert-rose',
    '86': 'tree mallow', '87': 'magnolia', '88': 'cyclamen', '89': 'watercress', '90': 'canna lily',
    '91': 'hippeastrum', '92': 'bee balm', '93': 'ball moss', '94': 'foxglove', '95': 'bougainvillea',
    '96': 'camellia', '97': 'mallow', '98': 'mexican petunia', '99': 'bromelia', '100': 'blanket flower',
    '101': 'trumpet creeper', '102': 'blackberry lily'
}

def get_class_name(idx, classes_list):
    """Maps model index -> folder name -> english name."""
    if 0 <= idx < len(classes_list):
        folder_name = classes_list[idx]
        return FLOWER_NAMES.get(folder_name, f"ID {folder_name}")
    return "Unknown"

def predict_and_plot(image, label_id_str):
    """Inference helper."""
    print("Running inference...")
    if 'model_run1' not in globals() or 'model_run2' not in globals():
        print("Models not loaded. Please run the training cells above.")
        return

    # Preprocess
    try:
        if 'val_transforms' not in globals():
            print("Error: 'val_transforms' not defined.")
            return
        input_tensor = val_transforms(image).unsqueeze(0).to(DEVICE)
    except Exception as e:
        print(f"Preprocessing Error: {e}")
        return

    # Get Class Mapping (Model Index -> Folder Name)
    test_dir = 'datasets/flowers102_run1/test'
    if os.path.exists(test_dir):
        classes = sorted([d for d in os.listdir(test_dir) if os.path.isdir(os.path.join(test_dir, d))])
    else:
        # Fallback if folders are missing: sort as strings to match ImageFolder's class order
        classes = sorted(str(i) for i in range(1, 103))

    # Inference
    try:
        model_run1.eval()
        model_run2.eval()
        with torch.no_grad():
            out1 = F.softmax(model_run1(input_tensor), dim=1)
            out2 = F.softmax(model_run2(input_tensor), dim=1)
    except Exception as e:
        print(f"Inference Error: {e}")
        return

    # Top 3
    prob1, cat1 = torch.topk(out1, 3)
    prob2, cat2 = torch.topk(out2, 3)

    # Visualization
    print("Generating results...")
    try:
        fig, ax = plt.subplots(1, 2, figsize=(16, 6), gridspec_kw={'width_ratios': [1, 1.5]})

        true_name = FLOWER_NAMES.get(label_id_str, label_id_str)
        ax[0].imshow(image)
        ax[0].axis('off')
        ax[0].set_title(f"Input Image\nTrue Label: {true_name}", fontsize=14, color='darkgreen')

        ax[1].axis('off')
        table_data = [["Rank", "Run 1 (Seed 42)", "Run 2 (Seed 123)"]]
        for i in range(3):
            # Run 1
            idx1 = cat1[0][i].item()
            p1 = f"{prob1[0][i].item()*100:.1f}%"
            name1 = get_class_name(idx1, classes)

            # Run 2
            idx2 = cat2[0][i].item()
            p2 = f"{prob2[0][i].item()*100:.1f}%"
            name2 = get_class_name(idx2, classes)

            table_data.append([f"#{i+1}", f"{name1} ({p1})", f"{name2} ({p2})"])

        table = ax[1].table(cellText=table_data, loc='center', cellLoc='left', colWidths=[0.1, 0.45, 0.45])
        table.auto_set_font_size(False); table.set_fontsize(11); table.scale(1, 2.5)
        ax[1].set_title("Prediction Comparison", fontsize=14)

        plt.tight_layout()
        display(fig)
        plt.close(fig)
    except Exception as e:
        print(f"Plotting Error: {e}")

# --- 2. Random Test Selection ---
print("Selecting random image from test set...")
test_dir = 'datasets/flowers102_run1/test'
all_files = []
if os.path.exists(test_dir):
    for r, _, f in os.walk(test_dir):
        for file in f:
            if file.endswith('.jpg'):
                all_files.append(os.path.join(r, file))

if not all_files:
    print(f"No test images found in '{test_dir}'. Did you run the training cells?")
else:
    path = random.choice(all_files)
    img = Image.open(path).convert('RGB')
    # Folder name is the Label ID
    label_str = os.path.basename(os.path.dirname(path))
    predict_and_plot(img, label_str)