# 03. Transfer Learning (ResNet50)

## Introduction
This notebook implements Transfer Learning using a pre-trained ResNet50 model.
It is fully self-contained.

## Setup

In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input

# Seed Python's `random`, NumPy and TensorFlow in a single call. The class
# balancing helper defined later in this notebook picks source images with
# `random.choice`, so seeding only NumPy and TensorFlow would leave that step
# non-reproducible between runs.
tf.keras.utils.set_random_seed(42)

2025-11-24 10:10:49.468598: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2025-11-24 10:10:49.636615: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-11-24 10:10:53.246618: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.


## 1. Data Loading
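Using the same strategy as the baseline: the `train/good` images and the labelled defect folders under `test/` are merged into one multi-class dataset, which is then re-split into train/validation/test sets for demonstration.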

In [3]:
import sys
import os
import glob
import pandas as pd
import random
import shutil
import logging
import json
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, f1_score, confusion_matrix
from skimage import measure

# --- Inlined from src.evaluation.metrics ---

def get_metrics():
    """
    Build the Keras metric objects used when compiling a classifier.

    Returns:
        list: A single-element list holding a SparseCategoricalAccuracy
        metric registered under the name 'accuracy'.
    """
    accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')
    return [accuracy_metric]

def calculate_auc(y_true, y_pred):
    """
    Calculate the AUC-ROC score.

    Args:
        y_true: Integer class labels, one per sample.
        y_pred: Predicted scores, either one score per sample or one column
            per class (e.g. the softmax output of a classifier).

    Returns:
        The score returned by sklearn's roc_auc_score. One-vs-rest averaging
        is used when y_true contains more than two distinct labels.

    Raises:
        ValueError: Propagated from roc_auc_score, e.g. when y_true contains a
            single class or when the number of score columns does not match
            the number of classes present in y_true.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    # For multi-class, use one-vs-rest on the full score matrix.
    if len(np.unique(y_true)) > 2:
        return roc_auc_score(y_true, y_pred, multi_class='ovr')

    # Binary case: roc_auc_score expects a 1-D score for the positive class.
    # A two-column softmax output would otherwise be rejected, so keep only
    # the second column (assumes column 1 corresponds to the larger label,
    # which is the label roc_auc_score treats as positive).
    if y_pred.ndim == 2:
        if y_pred.shape[1] == 2:
            y_pred = y_pred[:, 1]
        elif y_pred.shape[1] == 1:
            y_pred = y_pred[:, 0]

    return roc_auc_score(y_true, y_pred)

def calculate_f1(y_true, y_pred_classes):
    """
    Macro-averaged F1 score.

    Args:
        y_true: Ground-truth class labels.
        y_pred_classes: Predicted class labels (not probabilities).

    Returns:
        The value of sklearn's f1_score with average='macro'.
    """
    macro_f1 = f1_score(y_true, y_pred_classes, average='macro')
    return macro_f1

def get_confusion_matrix(y_true, y_pred_classes):
    """
    Confusion matrix of predicted versus ground-truth class labels.

    Args:
        y_true: Ground-truth class labels.
        y_pred_classes: Predicted class labels.

    Returns:
        The matrix produced by sklearn's confusion_matrix(y_true, y_pred_classes).
    """
    matrix = confusion_matrix(y_true, y_pred_classes)
    return matrix

def calculate_iou(y_true, y_pred, threshold=0.5):
    """
    Calculate Intersection over Union (IoU) for segmentation masks.

    Args:
        y_true: Ground truth masks (0 or 1). Any array-like is accepted.
        y_pred: Predicted anomaly maps (0 to 1), same shape as y_true.
            Any array-like is accepted.
        threshold: Scores strictly greater than this value count as predicted
            foreground.

    Returns:
        float: intersection / union of the two binary masks, or 1.0 when both
        masks are empty (nothing to find and nothing predicted).
    """
    # np.asarray lets callers pass nested lists as well as ndarrays.
    y_true = np.asarray(y_true).astype(int)
    y_pred_bin = np.asarray(y_pred) > threshold

    intersection = np.logical_and(y_true, y_pred_bin).sum()
    union = np.logical_or(y_true, y_pred_bin).sum()

    # An empty union implies an empty intersection, so this is the only
    # possible outcome for two empty masks.
    if union == 0:
        return 1.0

    return float(intersection / union)

def calculate_pro(y_true, y_pred, threshold=0.5):
    """
    Per-Region Overlap (PRO).

    Each connected component of the ground-truth mask is scored by the
    fraction of its pixels that the thresholded prediction covers; the
    result is the mean of those fractions.

    Args:
        y_true: Ground-truth mask array (cast to int before labelling).
        y_pred: Predicted anomaly map; values above `threshold` are foreground.
        threshold: Binarisation threshold for y_pred.

    Returns:
        Mean coverage over all ground-truth regions, or 1.0 when the ground
        truth contains no region at all.

    NOTE(review): components are labelled over the array exactly as passed
    (connectivity=2), so for batched input regions are not separated per
    sample - confirm this is intended.
    """
    y_pred_bin = (y_pred > threshold).astype(int)
    gt_mask = y_true.astype(int)

    # Connected components of the ground truth; label 0 is background.
    labeled_gt, num_features = measure.label(gt_mask, return_num=True, connectivity=2)

    # Nothing anomalous to cover: defined as perfect coverage.
    if num_features == 0:
        return 1.0

    region_masks = (labeled_gt == region_idx for region_idx in range(1, num_features + 1))
    coverages = [
        np.logical_and(region, y_pred_bin).sum() / region.sum()
        for region in region_masks
    ]

    return np.mean(coverages)

# --- Inlined from src.evaluation.benchmark ---

# Module-level logger shared by the inlined helper functions below.
# NOTE(review): no logging configuration call is visible in this notebook, so
# INFO-level messages are presumably not displayed unless logging is
# configured elsewhere - confirm.
logger = logging.getLogger(__name__)

def evaluate_model(model, test_dataset, test_masks=None):
    """
    Evaluates a model on the test dataset.

    Args:
        model: Compiled Keras model exposing `evaluate` and `predict`.
        test_dataset: Iterable of (images, labels) batches; labels must
            expose `.numpy()`.
        test_masks: Optional ground-truth segmentation masks. IoU/PRO are only
            computed when both the predictions and the masks are 4-D arrays.

    Returns:
        dict: The metrics reported by `model.evaluate(..., return_dict=True)`,
        extended with 'auc', 'f1', 'iou' and 'pro' for every metric that could
        be computed. A metric that fails is logged and omitted; it does not
        prevent the remaining metrics from being calculated.
    """
    logger.info("Evaluating model...")
    results = model.evaluate(test_dataset, return_dict=True)

    # Collect labels and predictions batch by batch so both stay aligned even
    # if the dataset yields batches in a different order on each pass.
    y_true = []
    y_pred = []

    for images, labels in test_dataset:
        preds = model.predict(images, verbose=0)
        y_true.extend(labels.numpy())
        y_pred.extend(preds)

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    y_pred_classes = np.argmax(y_pred, axis=1)

    # Each metric gets its own guard: AUC can fail (e.g. a class missing from
    # the test split) and must not discard F1 or the segmentation metrics.
    try:
        results['auc'] = calculate_auc(y_true, y_pred)
    except Exception as e:
        logger.warning(f"Could not calculate AUC: {e}")

    try:
        results['f1'] = calculate_f1(y_true, y_pred_classes)
    except Exception as e:
        logger.warning(f"Could not calculate F1: {e}")

    if test_masks is not None:
        # IoU/PRO need per-pixel maps. A classifier outputs (N, num_classes),
        # which is not comparable to (N, H, W, 1) masks, so only proceed when
        # both sides are 4-D. Assumes test_masks follows the y_pred order.
        if len(y_pred.shape) == 4 and len(test_masks.shape) == 4:
            try:
                results['iou'] = calculate_iou(test_masks, y_pred)
                results['pro'] = calculate_pro(test_masks, y_pred)
            except Exception as e:
                logger.warning(f"Could not calculate IoU/PRO: {e}")
        else:
            logger.warning("Skipping IoU/PRO: Predictions or masks shape mismatch for segmentation.")

    return results

def compare_models(models_dict, test_dataset):
    """
    Benchmark several models on the same test dataset.

    Args:
        models_dict: Mapping of display name -> model.
        test_dataset: Dataset forwarded unchanged to evaluate_model.

    Returns:
        dict: Mapping of display name -> the result of evaluate_model for
        that model, in the iteration order of models_dict.
    """
    def _benchmark(name, model):
        # Announce the model, then delegate to the shared evaluation routine.
        logger.info(f"Benchmarking {name}...")
        return evaluate_model(model, test_dataset)

    return {name: _benchmark(name, model) for name, model in models_dict.items()}

# --- Inlined from src.preprocessing.dataset ---

def augment_image(image_path, save_dir, prefix, count):
    """
    Create one randomly augmented copy of a PNG image on disk.

    The image is decoded as 3-channel PNG, passed through random horizontal
    and vertical flips, a random brightness shift (max_delta=0.2) and a random
    contrast change (0.8-1.2), then written to `<save_dir>/<prefix>_<count>.png`.

    Returns:
        The path of the written file, or None if any step raised (the error
        is logged as a warning rather than propagated).
    """
    try:
        pixels = tf.image.decode_png(tf.io.read_file(image_path), channels=3)

        # Apply the random transforms in a fixed order.
        for transform in (
            tf.image.random_flip_left_right,
            tf.image.random_flip_up_down,
            lambda t: tf.image.random_brightness(t, max_delta=0.2),
            lambda t: tf.image.random_contrast(t, lower=0.8, upper=1.2),
        ):
            pixels = transform(pixels)

        # Keep values inside [0, 255] and store as 8-bit before encoding.
        pixels = tf.cast(tf.clip_by_value(pixels, 0, 255), tf.uint8)
        png_bytes = tf.image.encode_png(pixels)

        save_path = os.path.join(save_dir, f"{prefix}_{count}.png")
        tf.io.write_file(save_path, png_bytes)
        return save_path
    except Exception as e:
        logger.warning(f"Failed to augment image {image_path}: {e}")
        return None

def balance_classes(df, target_count=1000, save_root="data/processed/augmented"):
    """
    Top up under-represented classes with augmented images.

    For every label with fewer than `target_count` rows, randomly chosen
    source images of that label are augmented and written to
    `<save_root>/<label_str>/aug_<i>.png`; one row per successfully written
    file is appended to the frame.

    Args:
        df: Frame with columns 'filepath', 'category', 'label', 'label_str'.
        target_count: Desired number of rows per label.
        save_root: Directory that receives the augmented images.

    Returns:
        The input frame unchanged when it is empty or nothing was added,
        otherwise a new frame with the augmented rows appended (index reset).
        Augmentations that fail are skipped, so a class may end up below
        target_count.
    """
    if df.empty:
        return df

    logger.info(f"Balancing classes to target count: {target_count}...")

    # Output from earlier runs is left in place; files with the same name are
    # simply overwritten.
    if not os.path.exists(save_root):
        os.makedirs(save_root, exist_ok=True)

    synthetic_rows = []

    for label in df['label'].unique():
        members = df[df['label'] == label]
        n_existing = len(members)
        shortfall = target_count - n_existing

        if shortfall <= 0:
            continue

        # All rows of one label share label_str and category; read the first.
        exemplar = members.iloc[0]
        label_str = exemplar['label_str']
        category = exemplar['category']

        logger.info(f"Augmenting class '{label_str}': {n_existing} -> {target_count} (+{shortfall})")

        class_save_dir = os.path.join(save_root, label_str)
        os.makedirs(class_save_dir, exist_ok=True)

        candidates = members['filepath'].tolist()

        for aug_idx in range(shortfall):
            saved_path = augment_image(random.choice(candidates), class_save_dir, "aug", aug_idx)
            if not saved_path:
                continue
            synthetic_rows.append(dict(
                filepath=saved_path,
                category=category,
                label=label,
                label_str=label_str,
            ))

    if synthetic_rows:
        df = pd.concat([df, pd.DataFrame(synthetic_rows)], ignore_index=True)

    logger.info(f"Balancing complete. Total samples: {len(df)}")
    return df

def load_and_split_data(data_dir, split_ratios=(0.8, 0.1, 0.1), seed=42, target_category=None, augment=False):
    """
    Load MVTec AD data, merge normal and abnormal, and split into train/val/test.
    Assigns unique labels for each (Category, DefectType) pair.

    Images are read from `<category>/train/good/*.png` and from every
    sub-folder of `<category>/test/`. Files are collected in sorted order so
    that, for a given seed, the split does not depend on filesystem ordering.

    Args:
        data_dir (str): Path to the root of the MVTec AD dataset (containing category folders).
        split_ratios (tuple): (train_ratio, val_ratio, test_ratio). Must sum to 1
            (checked with a small floating-point tolerance).
        seed (int): Random seed for reproducibility.
        target_category (str, optional): If provided, only load data for this specific category.
        augment (bool): If True, augment the training set to balance classes (1000 samples/class).

    Returns:
        tuple: (train_df, val_df, test_df, class_names)
               Each df has columns ['filepath', 'category', 'label', 'label_str']
               class_names: list of string labels indexed by the label integer.
               When no image is found, three empty frames and an empty list
               are returned.

    Raises:
        ValueError: If the ratios do not sum to 1, or target_category is not a
            directory inside data_dir.
    """
    # Compare with a tolerance: valid ratios such as (0.7, 0.2, 0.1) sum to
    # 0.9999999999999999 in floating point and must not be rejected.
    if abs(sum(split_ratios) - 1.0) > 1e-8:
        raise ValueError("Split ratios must sum to 1.0")

    train_ratio, val_ratio, test_ratio = split_ratios

    # 1. Collect all data and identify classes
    if target_category:
        if not os.path.isdir(os.path.join(data_dir, target_category)):
            raise ValueError(f"Category '{target_category}' not found in {data_dir}")
        categories = [target_category]
    else:
        # Sorted to ensure a deterministic order.
        categories = sorted(
            d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))
        )

    # First pass: discover every label string so the integer mapping is known
    # before any row is created.
    unique_labels = set()

    for category in categories:
        # The 'good' class is registered even if its folder turns out empty.
        unique_labels.add(f"{category}_good")

        test_dir = os.path.join(data_dir, category, 'test')
        if os.path.exists(test_dir):
            for defect_type in os.listdir(test_dir):
                if os.path.isdir(os.path.join(test_dir, defect_type)):
                    unique_labels.add(f"{category}_{defect_type}")

    class_names = sorted(unique_labels)
    class_to_idx = {name: i for i, name in enumerate(class_names)}

    logger.info(f"Found {len(class_names)} unique classes: {class_names}")

    # Second pass: collect one row per image.
    data = []

    def _collect(folder, category, label_str):
        # Append a row for every PNG in `folder`, in sorted (stable) order.
        label = class_to_idx[label_str]
        for img_path in sorted(glob.glob(os.path.join(folder, '*.png'))):
            data.append({
                'filepath': img_path,
                'category': category,
                'label': label,
                'label_str': label_str
            })

    for category in categories:
        cat_dir = os.path.join(data_dir, category)

        # Train data (only 'good' usually in MVTec AD train set)
        train_good_dir = os.path.join(cat_dir, 'train', 'good')
        if os.path.exists(train_good_dir):
            _collect(train_good_dir, category, f"{category}_good")

        # Test data (contains 'good' and various defect types)
        test_dir = os.path.join(cat_dir, 'test')
        if os.path.exists(test_dir):
            for defect_type in sorted(os.listdir(test_dir)):
                defect_dir = os.path.join(test_dir, defect_type)
                if not os.path.isdir(defect_dir):
                    continue
                _collect(defect_dir, category, f"{category}_{defect_type}")

    df = pd.DataFrame(data)

    if df.empty:
        logger.warning(f"No data found in {data_dir}")
        return pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), []

    # 2. Stratified Split
    # Stratify by label (which encodes both category and defect type). Classes
    # with a single sample make stratification impossible, so warn up front;
    # the split below falls back to a plain random split if it fails.
    class_counts = df['label'].value_counts()
    rare_classes = class_counts[class_counts < 2].index.tolist()

    if rare_classes:
        logger.warning(f"Classes {rare_classes} have fewer than 2 samples. Stratification for these will fail/be imperfect.")

    # First split: Train vs (Val + Test)
    test_val_ratio = val_ratio + test_ratio

    if test_val_ratio == 0:
        # Everything is training data; still honour the augment flag.
        train_df = balance_classes(df, target_count=1000) if augment else df
        return train_df, pd.DataFrame(), pd.DataFrame(), class_names

    try:
        train_df, temp_df = train_test_split(
            df,
            train_size=train_ratio,
            stratify=df['label'],
            random_state=seed
        )
    except ValueError as e:
        logger.warning(f"Stratified split failed (likely due to rare classes): {e}. Falling back to random split.")
        train_df, temp_df = train_test_split(
            df,
            train_size=train_ratio,
            random_state=seed
        )

    # 3. Augmentation (Balance Classes) - ONLY on Train set, so no augmented
    # copy of a training image can end up in validation or test.
    if augment:
        train_df = balance_classes(train_df, target_count=1000)

    # Second split: Val vs Test
    if test_ratio == 0:
        val_df = temp_df
        test_df = pd.DataFrame()
    elif val_ratio == 0:
        val_df = pd.DataFrame()
        test_df = temp_df
    else:
        relative_test_size = test_ratio / (val_ratio + test_ratio)
        try:
            val_df, test_df = train_test_split(
                temp_df,
                test_size=relative_test_size,
                stratify=temp_df['label'],
                random_state=seed
            )
        except ValueError as e:
            # Same fallback as the first split, and logged the same way so the
            # loss of stratification is visible.
            logger.warning(f"Stratified val/test split failed (likely due to rare classes): {e}. Falling back to random split.")
            val_df, test_df = train_test_split(
                temp_df,
                test_size=relative_test_size,
                random_state=seed
            )

    logger.info("Data split complete.")
    logger.info(f"Train: {len(train_df)} images")
    logger.info(f"Val: {len(val_df)} images")
    logger.info(f"Test: {len(test_df)} images")

    return train_df, val_df, test_df, class_names

from tensorflow.keras.applications.resnet50 import preprocess_input

IMG_SIZE = (256, 256)
BATCH_SIZE = 32
DATA_DIR = "../data/raw"
TARGET_CATEGORY = 'capsule' # Train only on this category

# Load data
print(f"Loading and splitting data for {TARGET_CATEGORY}...")
train_df, val_df, test_df, class_names = load_and_split_data(DATA_DIR, target_category=TARGET_CATEGORY, augment=True)
num_classes = len(class_names)

# Fail fast: the loader returns empty frames when it finds no PNG files, and
# the dataset code below would otherwise fail later with an obscure KeyError.
if train_df.empty:
    raise RuntimeError(f"No images found for category '{TARGET_CATEGORY}' under {DATA_DIR}")

print(f"Classes: {class_names}")
print(f"Number of classes: {num_classes}")

# Dataset creation helper
def process_path(filepath, label):
    """
    Read one image file and resize it to IMG_SIZE.

    Args:
        filepath: Path of the image file.
        label: Label passed through unchanged.

    Returns:
        (image, label) where image is the 3-channel decoded file resized to
        IMG_SIZE.
    """
    raw_bytes = tf.io.read_file(filepath)
    # expand_animations=False keeps the decoded result a single 3-D image.
    decoded = tf.io.decode_image(raw_bytes, channels=3, expand_animations=False)
    return tf.image.resize(decoded, IMG_SIZE), label

def create_dataset(dataframe, batch_size=32, shuffle=False):
    """
    Build a batched, prefetched tf.data pipeline from a split DataFrame.

    Args:
        dataframe: Frame with 'filepath' and 'label' columns.
        batch_size: Number of samples per batch.
        shuffle: If True, shuffle the samples (reshuffled on every iteration).

    Returns:
        tf.data.Dataset yielding (image_batch, label_batch) pairs, with images
        loaded through process_path.
    """
    filepaths = dataframe['filepath'].values
    labels = dataframe['label'].values
    ds = tf.data.Dataset.from_tensor_slices((filepaths, labels))
    if shuffle:
        # Shuffle the lightweight (path, label) pairs before decoding, with a
        # buffer covering the whole split. Shuffling decoded images through a
        # fixed 1000-element buffer keeps up to 1000 full-size images in
        # memory and mixes poorly, because balance_classes appends augmented
        # rows class by class at the end of the frame.
        ds = ds.shuffle(buffer_size=max(len(filepaths), 1))
    ds = ds.map(process_path, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(buffer_size=tf.data.AUTOTUNE)
    return ds

AUTOTUNE = tf.data.AUTOTUNE

# Build one pipeline per split; only the training split is shuffled.
train_ds = create_dataset(dataframe=train_df, batch_size=BATCH_SIZE, shuffle=True)
val_ds = create_dataset(dataframe=val_df, batch_size=BATCH_SIZE)
test_ds = create_dataset(dataframe=test_df, batch_size=BATCH_SIZE)

# Apply ResNet preprocessing
def preprocess_resnet(image, label):
    """
    Apply the ResNet50 `preprocess_input` transform to the image and pass the
    label through unchanged.
    """
    prepared = preprocess_input(image)
    return prepared, label

# Run every split through the same ResNet preprocessing step.
train_ds, val_ds, test_ds = (
    split.map(preprocess_resnet, num_parallel_calls=AUTOTUNE)
    for split in (train_ds, val_ds, test_ds)
)

print("Datasets created and preprocessed.")

Loading and splitting data for capsule...
Classes: ['capsule_crack', 'capsule_faulty_imprint', 'capsule_good', 'capsule_poke', 'capsule_scratch', 'capsule_squeeze']
Number of classes: 6
Datasets created and preprocessed.


## 2. Model Definition
We load ResNet50 (ImageNet weights), freeze the base, and add a custom classification head.

In [None]:
def create_resnet_model(input_shape, num_classes):
    """
    Build a classifier on top of a frozen, ImageNet-pretrained ResNet50.

    Args:
        input_shape: Shape of one input image, e.g. (height, width, 3).
        num_classes: Number of output classes.

    Returns:
        A Keras model named "resnet50_transfer": ResNet50 backbone (no top,
        not trainable) -> global average pooling -> Dense(256, relu) ->
        Dropout(0.5) -> Dense(num_classes, softmax).
    """
    backbone = ResNet50(weights='imagenet', include_top=False, input_shape=input_shape)
    backbone.trainable = False  # only the new head is trained at first

    image_input = tf.keras.Input(shape=input_shape)

    # training=False runs the backbone in inference mode inside this model.
    features = backbone(image_input, training=False)
    pooled = layers.GlobalAveragePooling2D()(features)
    hidden = layers.Dense(256, activation='relu')(pooled)
    hidden = layers.Dropout(0.5)(hidden)
    class_probs = layers.Dense(num_classes, activation='softmax')(hidden)

    return models.Model(image_input, class_probs, name="resnet50_transfer")

# Input is an RGB image of IMG_SIZE; one output unit per class.
model = create_resnet_model(input_shape=(*IMG_SIZE, 3), num_classes=num_classes)
model.summary()

## 3. Training
Training the top layers.

In [None]:
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# restore_best_weights=True rolls the model back to the epoch with the best
# monitored value (val_loss by default). Without it the model keeps the
# weights of the last epoch run, i.e. after `patience` epochs without
# improvement.
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=10,
    callbacks=[tf.keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True)]
)

## 4. Fine-tuning (Optional)
Unfreezing the last 20 layers of the ResNet50 base and continuing training with a lower learning rate (1e-5) for better performance.

In [None]:
# model.layers[0] is the input layer; index 1 is the ResNet50 backbone.
base_model = model.layers[1]
base_model.trainable = True

# Freeze all except last 20 layers
for layer in base_model.layers[:-20]:
    layer.trainable = False

# Changing `trainable` only takes effect after recompiling.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5), # Lower LR
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# restore_best_weights=True keeps the weights of the best validation epoch
# instead of those of the last epoch run before early stopping triggered.
history_fine = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=10,
    callbacks=[tf.keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True)]
)

In [None]:
# Advanced evaluation: Keras metrics plus whichever extra metrics
# evaluate_model managed to compute on the held-out test split.

print("Evaluating model with advanced metrics...")
results = evaluate_model(model, test_ds)

print("\nEvaluation Results:")
for metric, value in results.items():
    # One line per metric, four decimal places.
    print("{}: {:.4f}".format(metric, value))