In [None]:
# Mount Google Drive at /content/drive; DRIVE_MOUNT_POINT in the config cell
# ('/content/drive/MyDrive/') assumes this mount location.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
import shutil
import random
import tarfile
import requests
import pickle
import pandas as pd
import numpy as np
import zipfile
import shutil
import urllib.request
import time
import matplotlib.pyplot as plt


import torch
import torch.nn as nn
import torch.optim as optim
import timm
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm.notebook import tqdm  # Use notebook tqdm for Colab
from sklearn.model_selection import train_test_split
from torch.amp.autocast_mode import autocast

In [None]:
# --- Configuration ---
# Every tunable constant used by the later cells is defined here.

# Paths and storage
TINY_IMAGENET_URL = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'  # Or use .tar.gz if preferred
DRIVE_MOUNT_POINT = '/content/drive/MyDrive/'  # lives under the drive.mount('/content/drive') target
DATA_DIR = f"{DRIVE_MOUNT_POINT}data/tiny-imagenet-200-1"  # dataset download/extract location on Drive
SAVE_DIR = f"{DRIVE_MOUNT_POINT}vit_pretrained_tinyimagenet"  # checkpoint + split metadata are written here
OUTPUT_DIR = 'output_data_splits'  # NOTE(review): not read by any later cell in this notebook
SAVE_TO_DRIVE = True  # NOTE(review): not read by any later cell in this notebook

# Data splits
ID_CLASS_RATIO = 0.80  # fraction of classes kept in-distribution (ID); the rest become OOD
EXAMPLE_PRETRAIN_RATIO = 0.75  # fraction of each ID class's images used for pretraining; the rest are reserved

# Model / input
MODEL_NAME = 'vit_tiny_patch16_224'  # timm model name
IMAGE_SIZE = 224  # side length images are resized/cropped to before entering the model
MEAN = [0.485, 0.456, 0.406]  # per-channel (RGB) normalization mean
STD = [0.229, 0.224, 0.225]  # per-channel (RGB) normalization std

# Optimization
BATCH_SIZE = 128  # adjust to the available GPU memory
LEARNING_RATE = 1e-4  # peak learning rate, reached at the end of warmup
WEIGHT_DECAY = 0.05  # AdamW weight decay
WARMUP_EPOCHS = 5  # epochs of linear LR warmup before cosine decay
NUM_EPOCHS = 100  # maximum number of epochs (early stopping may end training sooner)
LABEL_SMOOTHING = 0.1  # CrossEntropyLoss label smoothing (0.0 disables it)

# Early stopping
PATIENCE = 10  # epochs without improvement before training stops
EARLY_STOPPING_METRIC = 'val_acc'  # 'val_acc' (higher is better) or 'val_loss' (lower is better)
MIN_DELTA = 0.001  # minimum improvement that resets patience; its absolute value is used for
                   # either metric, and val_acc is compared in percent

# Runtime
RANDOM_SEED = 42
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(DEVICE)

In [None]:
# --- Setup ---
def set_seed(seed):
    """Seed Python's `random`, NumPy and PyTorch with the same value.

    When CUDA is available, every CUDA device is seeded too. cuDNN is also
    switched to deterministic mode with autotuning disabled, trading some speed
    for run-to-run reproducibility.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(RANDOM_SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Create a directory for our dataset
os.makedirs(DATA_DIR, exist_ok=True)
print(f"Data will be saved in: {DATA_DIR}")

zip_path = os.path.join(DATA_DIR, 'tiny-imagenet-200.zip')

# Download the dataset if it doesn't exist
if not os.path.exists(zip_path):
    print("Downloading Tiny-ImageNet dataset...")

    def report_progress(block_num, block_size, total_size):
        """`urlretrieve` reporthook that prints download progress in place.

        `total_size` is -1 when the server sends no Content-Length; in that case
        megabytes are reported instead of a percentage. The last block can
        overshoot the total, so the percentage is clamped to 100.
        """
        downloaded = block_num * block_size
        if total_size > 0:
            progress = min(100.0, downloaded / total_size * 100.0)
            print(f"\rDownloading: {progress:.2f}%", end="")
        else:
            print(f"\rDownloading: {downloaded / 1e6:.1f} MB", end="")

    # Download to a temporary name and rename only on success. Otherwise an
    # interrupted download leaves a truncated zip at `zip_path`, which every
    # later run would skip ("already downloaded") and then fail to extract.
    partial_path = zip_path + '.part'
    try:
        urllib.request.urlretrieve(TINY_IMAGENET_URL, partial_path, reporthook=report_progress)
        os.replace(partial_path, zip_path)
    finally:
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print("\nDownload complete!")
else:
    print("Dataset already downloaded.")

# Extract the dataset if not already extracted
extract_dir = os.path.join(DATA_DIR, 'tiny-imagenet-200')
if not os.path.exists(extract_dir):
    print("Extracting dataset...")
    # Extract into a staging folder and move the result into place only once
    # extraction has finished. Otherwise an interrupted extraction would leave a
    # partial `extract_dir` that the next run reports as "already extracted".
    # Like the original layout check, this assumes the archive's top-level folder
    # is 'tiny-imagenet-200'.
    staging_dir = os.path.join(DATA_DIR, '.extract_staging')
    shutil.rmtree(staging_dir, ignore_errors=True)
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(staging_dir)
    os.replace(os.path.join(staging_dir, 'tiny-imagenet-200'), extract_dir)
    shutil.rmtree(staging_dir, ignore_errors=True)
    print("Extraction complete!")
else:
    print("Dataset already extracted.")

# Basic validation to check the dataset structure
train_dir = os.path.join(extract_dir, 'train')
val_dir = os.path.join(extract_dir, 'val')

if os.path.exists(train_dir) and os.path.exists(val_dir):
    # Sorted so the printed examples are the same on every filesystem/run.
    train_classes = sorted(os.listdir(train_dir))
    print(f"Number of classes in training set: {len(train_classes)}")
    print(f"Example classes: {train_classes[:5]}")

    # Inspect one class (guarded: an empty train dir would otherwise IndexError).
    if train_classes:
        example_class = train_classes[0]
        example_class_dir = os.path.join(train_dir, example_class)
        example_images_dir = os.path.join(example_class_dir, 'images')
        if os.path.isdir(example_images_dir):
            example_images = sorted(os.listdir(example_images_dir))
            print(f"Number of images in {example_class}: {len(example_images)}")
            print(f"Example image paths: {example_images[:3]}")
        else:
            print(f"Warning: {example_images_dir} not found.")

    print("Dataset structure validation complete!")
else:
    print("Dataset structure seems incorrect. Please check the extraction.")

In [None]:
# --- Copy data from Drive to local Colab storage (if not already local) ---
# Reading many small image files through the Drive mount is slow, so the dataset
# directory is mirrored once onto the VM's local disk and later cells read from it.
# NOTE(review): this cell reassigns DATA_DIR. When it is re-run, `gdrive_data_dir`
# is already the local path: the copy is correctly skipped, but the
# "Google Drive" path printed below is then the local one.
gdrive_data_dir = DATA_DIR # ADJUST THIS PATH if needed
local_data_dir = '/content/tiny-imagenet-local' # Destination on Colab's fast local disk

print(f"Source data path (Google Drive): {gdrive_data_dir}")
print(f"Local data destination: {local_data_dir}")

start_copy_time = time.time()
if os.path.exists(gdrive_data_dir):
    if not os.path.exists(local_data_dir):
        print("Copying dataset from Google Drive to local Colab storage...")
        print("This might take 10 - 20 minutes...")
        # Copy into a staging directory and rename it into place only when the copy
        # finishes. An interrupted copy is then never mistaken for a complete one
        # ("already exists") on the next run.
        staging_dir = local_data_dir + '.partial'
        try:
            shutil.rmtree(staging_dir, ignore_errors=True)
            shutil.copytree(gdrive_data_dir, staging_dir)
            os.replace(staging_dir, local_data_dir)
            print(f"Dataset copied successfully to {local_data_dir}")
        except Exception as e:
            print(f"ERROR copying dataset: {e}")
            print("Proceeding might be extremely slow if using Drive directly.")
            shutil.rmtree(staging_dir, ignore_errors=True)
    else:
        print("Dataset already exists in local Colab storage.")
    # Switch to the local copy only if it really exists; after a failed copy, keep
    # reading from Drive instead of pointing at a missing directory.
    DATA_DIR = local_data_dir if os.path.exists(local_data_dir) else gdrive_data_dir
else:
    print(f"Warning: Google Drive data directory not found at {gdrive_data_dir}")
    print("Assuming data might already be local or script needs adjustment.")
    # If you previously downloaded directly to /content/tiny-imagenet-200, set that path
    if os.path.exists('/content/tiny-imagenet-200'):
         DATA_DIR = '/content/tiny-imagenet-200' # Example if downloaded directly
         print(f"Using existing directory: {DATA_DIR}")
    else:
         print(f"ERROR: Cannot find data. Set DATA_DIR manually or check paths.")

# The download cell derived `extract_dir` (and `train_dir`/`val_dir`) from the
# Drive DATA_DIR, and every following cell builds its paths from `extract_dir`.
# Re-derive them from the DATA_DIR chosen above; otherwise the local copy is
# never read. Both layouts are accepted: DATA_DIR containing a
# 'tiny-imagenet-200' folder, or DATA_DIR being that folder itself (e.g. if
# '/content/tiny-imagenet-200' is the extracted dataset).
if os.path.exists(os.path.join(DATA_DIR, 'wnids.txt')):
    extract_dir = DATA_DIR
else:
    extract_dir = os.path.join(DATA_DIR, 'tiny-imagenet-200')
train_dir = os.path.join(extract_dir, 'train')
val_dir = os.path.join(extract_dir, 'val')

end_copy_time = time.time()
print(f"Data setup took {end_copy_time - start_copy_time:.2f} seconds.")
print(f"Using DATA_DIR: {DATA_DIR}") # Verify this path is used later
print(f"Using dataset root (extract_dir): {extract_dir}")

# --- Now proceed with the rest of your script ---
# Ensure Steps 2, 5, etc., use this updated DATA_DIR variable

In [None]:
# --- 2. Load Class Information ---
wnids_path = os.path.join(extract_dir, 'wnids.txt')
words_path = os.path.join(extract_dir, 'words.txt')

if not os.path.exists(wnids_path) or not os.path.exists(words_path):
    print(f"Error: Cannot find {wnids_path} or {words_path}. Dataset might be corrupted or incomplete.")
    # Raise rather than call exit(): in a Jupyter/Colab kernel exit() asks the
    # kernel to shut down (losing all state) instead of stopping this cell.
    raise FileNotFoundError(
        f"Cannot find {wnids_path} or {words_path}. Dataset might be corrupted or incomplete."
    )

# One WNID per line; blank lines (e.g. a trailing newline) are ignored.
with open(wnids_path, 'r') as f:
    all_wnids = [line.strip() for line in f if line.strip()]
num_total_classes = len(all_wnids)
print(f"Total classes found: {num_total_classes}")

# Each line is "<wnid>\t<name text>". Split on the first tab only, so an extra
# tab inside the name text does not cause the entry to be dropped.
wnid_to_name = {}
with open(words_path, 'r') as f:
    for line in f:
        parts = line.strip().split('\t', 1)
        if len(parts) == 2:
            wnid_to_name[parts[0]] = parts[1]

# --- 3. Create Mappings ---
wnid_to_idx = {wnid: i for i, wnid in enumerate(all_wnids)}
idx_to_wnid = {i: wnid for wnid, i in wnid_to_idx.items()}

# --- 4. Class Split (ID vs OOD) ---
num_id_classes = int(num_total_classes * ID_CLASS_RATIO)
num_ood_classes = num_total_classes - num_id_classes

print(f"Splitting classes: {num_id_classes} ID classes, {num_ood_classes} OOD classes.")

# A dedicated RNG seeded with RANDOM_SEED gives the same ID/OOD split every time
# this cell runs, however much of the global `random` state other cells have
# consumed. It produces the same order as random.seed(RANDOM_SEED); random.shuffle(...).
split_rng = random.Random(RANDOM_SEED)
shuffled_wnids = list(all_wnids)
split_rng.shuffle(shuffled_wnids)

id_wnids = sorted(shuffled_wnids[:num_id_classes]) # Sort for consistency
ood_wnids = sorted(shuffled_wnids[num_id_classes:]) # Sort for consistency

print(f"\nSelected {len(id_wnids)} ID WNIDs (first 5): {id_wnids[:5]}")
print(f"Selected {len(ood_wnids)} OOD WNIDs (first 5): {ood_wnids[:5]}")

# Map ID WNIDs to the classifier's label space (0 to num_id_classes-1).
wnid_to_pretrain_label = {wnid: i for i, wnid in enumerate(id_wnids)}
pretrain_label_to_wnid = {i: wnid for wnid, i in wnid_to_pretrain_label.items()}

In [None]:
# --- 5. Prepare File Lists and Example Splits ---
# For every ID class, EXAMPLE_PRETRAIN_RATIO of its training images go to
# pretraining and the remainder is reserved; all images of OOD classes are
# collected separately.
train_dir = os.path.join(extract_dir, 'train')

pretrain_files = []  # (filepath, pretrain_label) pairs for the pretraining set
id_reserved_files_with_wnid = []  # (filepath, wnid) pairs held out from ID classes
ood_class_files = []  # filepaths of all images from OOD classes
file_to_label_map = {}  # filepath -> pretrain_label, for pretraining files only
ood_wnid_set = set(ood_wnids)  # constant-time membership checks in the scan loop

print("\nScanning training directory and splitting examples...")
start_time = time.time()

# Reproducibility: os.listdir() order depends on the filesystem (the Drive mount
# and the local copy need not agree), so names are sorted before shuffling. Each
# class is shuffled with its own RNG seeded from (RANDOM_SEED, wnid). The split
# therefore does not depend on directory order, class iteration order, or how
# much of the global `random` state other cells have consumed.
for wnid in tqdm(sorted(os.listdir(train_dir))):
    class_dir = os.path.join(train_dir, wnid)
    if not os.path.isdir(class_dir):
        continue
    images_dir = os.path.join(class_dir, 'images')
    if not os.path.exists(images_dir):
        print(f"Warning: 'images' subdirectory not found in {class_dir}. Skipping this class.")
        continue

    image_files = sorted(
        os.path.join(images_dir, fname)
        for fname in os.listdir(images_dir)
        if fname.lower().endswith(('.png', '.jpg', '.jpeg'))
    )
    random.Random(f"{RANDOM_SEED}-{wnid}").shuffle(image_files)

    if wnid in wnid_to_pretrain_label:
        # In-distribution class: split its examples into pretrain / reserved.
        num_pretrain = int(len(image_files) * EXAMPLE_PRETRAIN_RATIO)
        pretrain_label = wnid_to_pretrain_label[wnid]
        id_reserved_files_with_wnid.extend((fpath, wnid) for fpath in image_files[num_pretrain:])
        for fpath in image_files[:num_pretrain]:
            pretrain_files.append((fpath, pretrain_label))
            file_to_label_map[fpath] = pretrain_label
    elif wnid in ood_wnid_set:
        # Out-of-distribution class: keep every example.
        ood_class_files.extend(image_files)
    # Directories matching neither list are ignored.

end_time = time.time()
print(f"Finished splitting files in {end_time - start_time:.2f} seconds.")
print(f"Total pretraining examples: {len(pretrain_files)}")
print(f"Total ID-reserved examples (for validation): {len(id_reserved_files_with_wnid)}")
print(f"Total OOD-class examples: {len(ood_class_files)}")

# --- Create Validation File List with Correct Labels ---
# The reserved ID examples double as the validation set, labelled in the
# pretraining label space.
validation_files = []
print("\nMapping reserved files to pretraining labels for validation set...")
for fpath, wnid in tqdm(id_reserved_files_with_wnid, desc="Mapping val files"):
    if wnid in wnid_to_pretrain_label:
        validation_files.append((fpath, wnid_to_pretrain_label[wnid]))
    else:
        print(f"Warning: WNID {wnid} from reserved files not found in pretrain label map. Skipping file {fpath}")

print(f"Created validation file list with {len(validation_files)} samples.")

In [None]:
# --- 6. Custom Dataset (No significant changes needed here) ---
# The dataset remains simple, loading verified images one by one.
class TinyImageNetPretrain(Dataset):
    """Map-style dataset over an explicit list of ``(image_path, label)`` pairs.

    Images are read lazily in ``__getitem__``, converted to RGB, and passed
    through the optional ``transform``.
    """

    def __init__(self, file_label_list, transform=None):
        """
        Args:
            file_label_list (list): List of tuples (filepath, label_index).
            transform (callable, optional): Transform applied to the RGB PIL image.
        """
        self.file_label_list = file_label_list
        self.transform = transform

    def __len__(self):
        return len(self.file_label_list)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        Raises:
            RuntimeError: if the image file cannot be opened or decoded. The
                original exception is chained as ``__cause__``.
        """
        img_path, label = self.file_label_list[idx]
        try:
            # `with` releases the file handle even if decoding fails; convert()
            # returns an independent, fully loaded image.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
        except Exception as e:
            raise RuntimeError(f"Failed to load image {img_path} during training") from e
        # Apply the transform outside the try block so augmentation bugs surface
        # with their own traceback instead of being reported as load failures.
        if self.transform:
            image = self.transform(image)
        return image, label

In [None]:
# --- 7. Data Transformations ---
# Training augmentation:
# - a random crop covering 60-100% of the image area (aspect ratio 0.75-1.33),
#   resized to IMAGE_SIZE with bicubic interpolation;
# - a random horizontal flip;
# - tensor conversion and MEAN/STD normalization.
train_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(
            IMAGE_SIZE,
            scale=(0.6, 1.0),
            ratio=(0.75, 1.33),
            interpolation=transforms.InterpolationMode.BICUBIC,
        ),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=MEAN, std=STD),
    ]
)

# --- 8. Create Datasets and DataLoaders ---
pretrain_dataset = TinyImageNetPretrain(file_label_list=pretrain_files, transform=train_transform)

# Loader settings (watch Colab CPU/RAM; if workers crash, lower num_workers to 2):
#   num_workers=4            parallel decoding/augmentation processes
#   pin_memory=True          page-locked host memory for faster copies to the GPU
#   persistent_workers=True  keep worker processes alive across epochs
#   prefetch_factor=2        batches loaded ahead by each worker (PyTorch's default)
pretrain_loader = DataLoader(
    pretrain_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=2,
)

print(f"\nOptimized pretraining DataLoader created:")
print(f"  Dataset size: {len(pretrain_dataset)} samples.")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Num workers: {pretrain_loader.num_workers}")
print(f"  Persistent workers: {pretrain_loader.persistent_workers}")
print(f"  Pin memory: {pretrain_loader.pin_memory}")

In [None]:
# --- 8b. Create Validation Dataset and DataLoader ---
# Deterministic preprocessing: bicubic resize + center crop to the training
# resolution, followed by the same normalization used for training.
val_transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=MEAN, std=STD),
    ]
)

# The reserved ID examples (`validation_files`) reuse the pretraining Dataset class.
val_dataset = TinyImageNetPretrain(file_label_list=validation_files, transform=val_transform)

# Stays None when there is nothing to validate on; the training loop checks this.
val_loader = None
if len(val_dataset) == 0:
    print("\nWarning: Validation dataset is empty. Skipping validation loader creation.")
else:
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE * 2,  # no gradients are kept, so a larger batch usually fits
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        persistent_workers=False,
    )
    print(f"\nValidation DataLoader created:")
    print(f"  Dataset size: {len(val_dataset)} samples.")
    print(f"  Batch size: {val_loader.batch_size}")


In [None]:
# --- 9. Model Setup ---
# The ViT is trained from scratch (pretrained=False), with a classification head
# sized to the ID classes only.
print(f"\nLoading model: {MODEL_NAME}")
model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=num_id_classes).to(device)
print(f"Model loaded and moved to {device}.")

In [None]:
# --- 10. Training Setup (Optimizer, Scaler, and Scheduler) ---
from torch.amp.grad_scaler import GradScaler

criterion = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTHING)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# --- Initialize AMP GradScaler ---
# Mixed precision is only used when CUDA is available. With enabled=use_amp the
# scaler becomes a pass-through when AMP is off: scale() returns the loss
# unchanged and step() simply calls optimizer.step().
use_amp = torch.cuda.is_available()
scaler = GradScaler('cuda' if torch.cuda.is_available() else 'cpu', enabled=use_amp)
print(f"AMP (Automatic Mixed Precision) enabled: {use_amp}")

# --- Initialize Learning Rate Scheduler ---
# The training loop steps the scheduler once per batch, so every length below is
# measured in optimizer steps, not epochs. If setup fails, `scheduler` stays None
# and training proceeds without an LR schedule.
scheduler = None
try:
    steps_per_epoch = len(pretrain_loader)
    total_training_steps = NUM_EPOCHS * steps_per_epoch
    warmup_steps = WARMUP_EPOCHS * steps_per_epoch
    cosine_steps = total_training_steps - warmup_steps

    print(f"\nScheduler Configuration:")
    print(f"  Total Epochs: {NUM_EPOCHS}")
    print(f"  Warmup Epochs: {WARMUP_EPOCHS}")
    print(f"  Steps per Epoch: {steps_per_epoch}")
    print(f"  Total Training Steps: {total_training_steps}")
    print(f"  Warmup Steps: {warmup_steps}")
    print(f"  Cosine Annealing Steps: {cosine_steps}")
    print(f"  Peak LR: {LEARNING_RATE}")

    if warmup_steps < 0:
        raise ValueError("WARMUP_EPOCHS cannot be negative")
    if total_training_steps <= 0:
        raise ValueError("No training steps: check NUM_EPOCHS and that pretrain_loader is not empty")
    # The cosine phase needs T_max >= 1, so warmup must end strictly before training does.
    if warmup_steps >= total_training_steps:
        raise ValueError("WARMUP_EPOCHS cannot be greater than or equal to NUM_EPOCHS")

    if warmup_steps > 0:
        # Phase 1: linear warmup from 0.01 * LEARNING_RATE up to LEARNING_RATE.
        linear_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer,
            start_factor=0.01,
            end_factor=1.0,
            total_iters=warmup_steps,
        )
        # Phase 2: cosine decay from LEARNING_RATE down to eta_min over the remaining steps.
        cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=cosine_steps,
            eta_min=1e-6,
        )
        # Switch from phase 1 to phase 2 at step `warmup_steps`.
        scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer,
            schedulers=[linear_scheduler, cosine_scheduler],
            milestones=[warmup_steps],
        )
        print("  Using SequentialLR (Linear Warmup + Cosine Annealing).")
    else:
        # No warmup requested: anneal from the peak LR directly instead of
        # chaining a zero-length LinearLR phase into SequentialLR.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=cosine_steps,
            eta_min=1e-6,
        )
        print("  Using CosineAnnealingLR (no warmup).")

except NameError:
     print("Error: pretrain_loader not defined before scheduler setup. Ensure DataLoader is created first.")
     scheduler = None
except Exception as e:
     print(f"Error setting up scheduler: {e}")
     scheduler = None


In [None]:
# [(After Step 10: Training Setup)]

# --- 10b. Validation Function ---
def validate_one_epoch(model, loader, criterion, device, use_amp):
    """Evaluate ``model`` over every batch of ``loader`` without updating it.

    Args:
        model: network to evaluate; it is switched to eval mode and left there.
        loader: iterable of ``(inputs, labels)`` tensor batches.
        criterion: loss called as ``criterion(outputs, labels)``. It is assumed to
            average over the batch, since the value is re-weighted by batch size.
        device: ``torch.device`` or device string that batches are moved to. Its
            type also selects the autocast backend.
        use_amp: run the forward pass under autocast only when True.

    Returns:
        ``(mean_loss, accuracy)`` over all samples, with accuracy as a fraction in
        [0, 1]; ``(0.0, 0.0)`` if the loader yields no samples.
    """
    model.eval()
    # Autocast needs a device *type* ('cuda'/'cpu'); derive it from the device we
    # actually run on rather than from whether CUDA happens to be installed.
    device_type = device.type if isinstance(device, torch.device) else torch.device(device).type

    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    with torch.no_grad():
        progress_bar = tqdm(loader, desc="Validation", leave=False)
        for inputs, labels in progress_bar:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            with autocast(device_type, enabled=use_amp):
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            batch_loss = loss.item()
            batch_correct = (outputs.argmax(dim=1) == labels).sum().item()

            running_loss += batch_loss * batch_size
            correct_predictions += batch_correct
            total_samples += batch_size
            progress_bar.set_postfix(loss=batch_loss, batch_acc=batch_correct / batch_size)

    if total_samples == 0:
        print("Warning: No samples processed during validation.")
        return 0.0, 0.0

    epoch_loss = running_loss / total_samples
    epoch_acc = correct_predictions / total_samples
    return epoch_loss, epoch_acc

In [None]:
# --- 11. Training Loop (AMP + per-step LR schedule + early stopping) ---
# Trains for up to NUM_EPOCHS and validates after every epoch. The best weights
# according to EARLY_STOPPING_METRIC are snapshotted and restored at the end, so
# `model` leaves this cell holding the early-stopping selection.
print(f"\nStarting Optimized Pretraining w/ Scheduler for {NUM_EPOCHS} epochs...")
start_train_time = time.time()

# Autocast follows the device the model was moved to, and is active only when
# `use_amp` (set in the training-setup cell) is True.
amp_device_type = device.type

# Per-epoch metrics; accuracies are stored as percentages.
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': [], 'lr': []}

# --- Early stopping state ---
# `patience` is a local copy, so disabling early stopping for an invalid metric
# does not overwrite the PATIENCE config constant (keeps the cell re-runnable).
patience = PATIENCE
epochs_no_improve = 0
best_metric_value = None
best_model_state = None  # CPU copy of the best weights seen so far
best_epoch = None
if EARLY_STOPPING_METRIC == 'val_loss':
    best_metric_value = float('inf')
    print(f"Early stopping monitoring '{EARLY_STOPPING_METRIC}' (lower is better). Patience: {patience}")
elif EARLY_STOPPING_METRIC == 'val_acc':
    best_metric_value = float('-inf')
    print(f"Early stopping monitoring '{EARLY_STOPPING_METRIC}' (higher is better). Patience: {patience}")
else:
    print(f"Warning: Invalid EARLY_STOPPING_METRIC '{EARLY_STOPPING_METRIC}'. Disabling early stopping.")
    patience = float('inf')  # early-stopping block below is skipped entirely

if scheduler is None:
     print("ERROR: Scheduler not initialized. Check Step 10.")
     # Training still runs, just without an LR schedule.

for epoch in range(NUM_EPOCHS):
    # --- Training Phase ---
    model.train()
    train_running_loss = 0.0
    train_correct_predictions = 0
    train_total_samples = 0

    progress_bar = tqdm(pretrain_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} Train", leave=False)

    for i, (inputs, labels) in enumerate(progress_bar):
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)

        with autocast(amp_device_type, enabled=use_amp):
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # The schedule is defined in optimizer steps, so advance it every batch.
        if scheduler is not None:
            scheduler.step()

        # --- Training Statistics ---
        batch_size = labels.size(0)
        batch_correct = (outputs.detach().argmax(dim=1) == labels).sum().item()
        train_running_loss += loss.item() * batch_size
        train_correct_predictions += batch_correct
        train_total_samples += batch_size

        if (i + 1) % 100 == 0 or (i + 1) == len(pretrain_loader): # Update less often
             current_lr = optimizer.param_groups[0]['lr']
             progress_bar.set_postfix(loss=loss.item(), batch_acc=batch_correct / batch_size, lr=f"{current_lr:.1e}")

    epoch_train_loss = train_running_loss / train_total_samples if train_total_samples > 0 else 0
    epoch_train_acc = train_correct_predictions / train_total_samples if train_total_samples > 0 else 0

    # --- Validation Phase ---
    current_val_loss, current_val_acc = 0.0, 0.0 # Default values when there is no val_loader
    if val_loader:
        current_val_loss, current_val_acc = validate_one_epoch(model, val_loader, criterion, device, use_amp)

    # --- Store History ---
    current_lr = optimizer.param_groups[0]['lr']
    history['train_loss'].append(epoch_train_loss)
    history['train_acc'].append(epoch_train_acc * 100)
    history['val_loss'].append(current_val_loss)
    history['val_acc'].append(current_val_acc * 100)
    history['lr'].append(current_lr)

    # --- Print Epoch Results ---
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} -> "
          f"Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f} | "
          f"Val Loss: {current_val_loss:.4f}, Val Acc: {current_val_acc:.4f} | "
          f"LR: {current_lr:.1e}")

    # --- Early Stopping Check ---
    if val_loader and patience != float('inf'):
        # val_acc is compared in percent; MIN_DELTA's absolute value is used either way.
        if EARLY_STOPPING_METRIC == 'val_loss':
            metric_to_check = current_val_loss
            improved = metric_to_check < best_metric_value - abs(MIN_DELTA)
        else:
            metric_to_check = current_val_acc * 100
            improved = metric_to_check > best_metric_value + abs(MIN_DELTA)

        if improved:
            print(f"  ({EARLY_STOPPING_METRIC} improved from {best_metric_value:.4f} to {metric_to_check:.4f})")
            best_metric_value = metric_to_check
            epochs_no_improve = 0
            best_epoch = epoch + 1
            # state_dict() returns references to the live parameter tensors, which the
            # optimizer keeps updating. Copy them to CPU to get a real snapshot
            # without holding a second copy on the GPU.
            best_model_state = {name: tensor.detach().to('cpu', copy=True)
                                for name, tensor in model.state_dict().items()}
        else:
            epochs_no_improve += 1
            print(f"  ({EARLY_STOPPING_METRIC} did not improve from {best_metric_value:.4f}. Patience: {epochs_no_improve}/{patience})")

        if epochs_no_improve >= patience:
            print(f"\nEarly stopping triggered after {epoch + 1} epochs.")
            break

end_train_time = time.time()
print(f"\nOptimized pretraining w/ Scheduler & Validation finished in {(end_train_time - start_train_time)/60:.2f} minutes.")

# Restore the best snapshot, so the model saved in the next cell is the one early
# stopping selected rather than the (possibly worse) last-epoch weights.
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print(f"Restored best model weights from epoch {best_epoch} ({EARLY_STOPPING_METRIC}: {best_metric_value:.4f}).")

# --- Output Final Accuracy (last completed epoch, from history) ---
final_train_acc = history['train_acc'][-1] if history['train_acc'] else 0
final_val_acc = history['val_acc'][-1] if history['val_acc'] else 0
print(f"\nFinal Training Accuracy: {final_train_acc:.2f}%")
print(f"Final Validation Accuracy: {final_val_acc:.2f}%")

# --- (Proceed to Step 12: Save Results) ---

In [None]:
# --- 12. Save Results ---
print("\nSaving results...")

# a) Make sure the output directory exists (no-op when it already does).
os.makedirs(SAVE_DIR, exist_ok=True)

# Only file paths are kept for the pretraining split; its labels are not stored.
pretrain_files_onlypaths = [path for path, _label in pretrain_files]

# One checkpoint holds both the weights and everything needed to rebuild the
# ID/OOD split later.
model_info = {
    'model_state_dict': model.state_dict(),
    'class_info': {
        'ind_all_ids': id_wnids,  # WNIDs of all in-distribution (ID) classes
        'ind_pretrain_files_paths': pretrain_files_onlypaths,
        # NOTE: despite the key name, entries are (filepath, wnid) tuples.
        'ind_reserved_files_paths': id_reserved_files_with_wnid,
        'ood_files_paths': ood_class_files,
        'ood_class_ids': ood_wnids,  # WNIDs of all out-of-distribution (OOD) classes
        'class_id_to_pretrain_label': wnid_to_pretrain_label,
        'wnid_to_name': wnid_to_name,
    },
}

model_save_path = os.path.join(SAVE_DIR, f"{MODEL_NAME}_tinyimagenet_pretrained.pth")
torch.save(model_info, model_save_path)
print(f"All Model data saved to: {model_save_path}")

print("\n--- Summary ---")
print(f"Pretrained Model: {MODEL_NAME}")
print(f"Total Classes: {num_total_classes}")
print(f"ID Classes ({len(id_wnids)}): Used for pretraining (subset) and ID detection examples.")
print(f"OOD Classes ({len(ood_wnids)}): Held out entirely, used for OOD detection examples.")
print(f"Pretraining Examples: {len(pretrain_files_onlypaths)} (from {len(id_wnids)} classes)")
print(f"ID-Reserved Examples: {len(id_reserved_files_with_wnid)} (from {len(id_wnids)} classes, for later OOD detection)")
print(f"OOD-Class Examples: {len(ood_class_files)} (from {len(ood_wnids)} classes, for later OOD detection)")
print(f"Saved artifacts directory: {SAVE_DIR}")
print("--- Script Finished ---")