In [None]:
# Cell 1: Connect to Google Drive
# Mount Google Drive at /content/drive so that the DATA_DIR / SAVE_DIR paths under
# /content/drive/MyDrive/ (defined in the configuration cell) resolve to Drive storage.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Cell 2: Install necessary libraries
!pip install torch torchvision torchaudio matplotlib numpy einops scikit-learn timm tqdm

In [None]:
# Cell 3: Import necessary libraries
import random
import torch
import os
import zipfile
from urllib import request
import urllib.request
import numpy as np
from collections import defaultdict
import logging

In [None]:
# Cell 4: Configuration
# --- Configuration ---

# Source archive. The download/extract cell opens it with `zipfile`, so this must be the .zip release.
TINY_IMAGENET_URL = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'

# Storage locations under the Google Drive mount (/content/drive) created in Cell 1.
DRIVE_MOUNT_POINT = '/content/drive/MyDrive/'
DATA_DIR = os.path.join(DRIVE_MOUNT_POINT, 'data/tiny-imagenet-200-1')       # zip + extracted dataset
SAVE_DIR = os.path.join(DRIVE_MOUNT_POINT, 'data/preprocessed_tinyimagenet')  # saved split metadata (.pth)

# Runtime settings.
RANDOM_SEED = 42
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Split ratios.
IND_CLASS_RATIO = 0.80         # fraction of classes treated as In-Distribution (ID)
PRETRAIN_EXAMPLE_RATIO = 0.75  # fraction of each ID class's images used for pretraining

In [None]:
# Cell 5: Setup
def set_seed(seed):
    """Seed Python's `random`, NumPy's global RNG and torch with the same value.

    When CUDA is available, all CUDA devices are seeded as well and cuDNN is
    switched to deterministic mode with autotuning (benchmark) disabled.

    Args:
        seed: integer seed passed unchanged to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

# Seed every RNG (Python, NumPy, torch) before later cells draw random numbers.
set_seed(RANDOM_SEED)
# NOTE(review): logging.info output only appears if the root logger is configured at
# INFO level; under Python's default level (WARNING) this line prints nothing.
logging.info(f"Using device: {DEVICE}")

In [None]:

# Cell 6: Download + extract Tiny Imagenet
# Make logging.info() visible in the notebook. The root logger defaults to WARNING, so
# without this every status message below is silently dropped. basicConfig is a no-op
# when the root logger already has handlers, hence the explicit setLevel as well.
logging.basicConfig(level=logging.INFO, format="%(message)s")
logging.getLogger().setLevel(logging.INFO)


def _make_download_progress_hook(step_pct=5.0):
    """Build a `urlretrieve` reporthook that logs progress every `step_pct` percent.

    Logging on every block would flood the output for a large archive, so reports
    are throttled. If the server does not report a size (`total_size <= 0`) the hook
    logs nothing instead of dividing by zero.
    """
    next_report_pct = 0.0

    def report_progress(block_num, block_size, total_size):
        nonlocal next_report_pct
        if total_size <= 0:
            return  # size unknown: no meaningful percentage
        progress = min(100.0, 100.0 * block_num * block_size / total_size)
        if progress >= next_report_pct:
            logging.info(f"Downloading: {progress:.2f}%")
            next_report_pct = progress + step_pct

    return report_progress


def _download_archive(url, dest_path):
    """Download `url` to `dest_path` so that `dest_path` only ever holds a complete file.

    Data is written to `dest_path + '.part'` and renamed into place after `urlretrieve`
    returns, so an interrupted download cannot leave a truncated zip that a later run
    would treat as already downloaded.
    """
    tmp_path = dest_path + '.part'
    urllib.request.urlretrieve(url, tmp_path, reporthook=_make_download_progress_hook())
    os.replace(tmp_path, dest_path)


def _extract_archive(archive_path, dest_dir, marker_path):
    """Extract the zip at `archive_path` into `dest_dir`, then create `marker_path`.

    The marker is written only after `extractall` returns, so its presence (rather than
    the mere existence of the extracted folder) signals a finished extraction.
    """
    with zipfile.ZipFile(archive_path, 'r') as zip_ref:
        zip_ref.extractall(dest_dir)
    os.makedirs(os.path.dirname(marker_path), exist_ok=True)
    with open(marker_path, 'w'):
        pass


def _validate_dataset_structure(train_root, val_root, wnids_path):
    """Log a summary of the extracted dataset and warn about obvious problems.

    Returns:
        Sorted list of class folder names under `train_root` (empty if the layout is wrong).
    """
    if not (os.path.isdir(train_root) and os.path.isdir(val_root)):
        logging.warning("Dataset structure seems incorrect. Please check the extraction.")
        return []

    # Only directories are classes; sorted so the logged examples are stable across runs.
    classes = sorted(
        name for name in os.listdir(train_root)
        if os.path.isdir(os.path.join(train_root, name))
    )
    logging.info(f"Number of classes in training set: {len(classes)}")

    if os.path.isfile(wnids_path):
        with open(wnids_path, 'r') as f:
            expected_classes = sum(1 for line in f if line.strip())
        if len(classes) != expected_classes:
            logging.warning(
                f"Expected {expected_classes} training classes (from wnids.txt) but found "
                f"{len(classes)}; the extraction is probably incomplete. Delete the extracted "
                "folder and re-run this cell."
            )

    if not classes:
        logging.warning("No class folders found in the training directory.")
        return classes

    logging.info(f"Example classes: {classes[:5]}")

    # Inspect one class to confirm the <wnid>/images/<file> layout.
    example_class = classes[0]
    example_images_dir = os.path.join(train_root, example_class, 'images')
    if os.path.isdir(example_images_dir):
        example_images = sorted(os.listdir(example_images_dir))
        logging.info(f"Number of images in {example_class}: {len(example_images)}")
        logging.info(f"Example image paths: {example_images[:3]}")
    else:
        logging.warning(f"Expected an 'images' folder at {example_images_dir}.")

    logging.info("Dataset structure validation complete!")
    return classes


# Create a directory for our dataset
os.makedirs(DATA_DIR, exist_ok=True)
logging.info(f"Data will be saved in: {DATA_DIR}")

# Download the dataset if it doesn't exist
zip_path = os.path.join(DATA_DIR, 'tiny-imagenet-200.zip')
if not os.path.exists(zip_path):
    logging.info("Downloading Tiny-ImageNet dataset...")
    _download_archive(TINY_IMAGENET_URL, zip_path)
    logging.info("Download complete!")
else:
    logging.info("Dataset already downloaded.")

# Extract unless a previous extraction finished (marker written only on success).
extract_dir = os.path.join(DATA_DIR, 'tiny-imagenet-200')
extraction_marker = os.path.join(extract_dir, '.extraction_complete')
if not os.path.exists(extraction_marker):
    if os.path.exists(extract_dir):
        logging.warning(
            f"'{extract_dir}' exists but has no completion marker (a previous extraction "
            "may have been interrupted); re-extracting."
        )
    logging.info("Extracting dataset...")
    _extract_archive(zip_path, DATA_DIR, extraction_marker)
    logging.info("Extraction complete!")
else:
    logging.info("Dataset already extracted.")

# Basic validation to check the dataset structure
train_dir = os.path.join(extract_dir, 'train')
val_dir = os.path.join(extract_dir, 'val')
train_classes = _validate_dataset_structure(
    train_dir, val_dir, os.path.join(extract_dir, 'wnids.txt')
)

# Output from Colab
# Data will be saved in: /content/drive/MyDrive/data/tiny-imagenet-200-1
# Dataset already downloaded.
# Dataset already extracted.
# Number of classes in training set: 122
# Example classes: ['n03584254', 'n02403003', 'n02056570', 'n02769748', 'n01443537']
# Number of images in n03584254: 500
# Example image paths: ['n03584254_251.JPEG', 'n03584254_348.JPEG', 'n03584254_465.JPEG']
# Dataset structure validation complete!

In [None]:
# Cell 7: Get Tiny ImageNet train data details
def get_tiny_imagenet_train_data_details(dataset_path):
    """
    Gathers details of the Tiny ImageNet training set.

    Directory listings are sorted, so the order of `sample_details` (and therefore every
    `original_dataset_idx`) is the same on every run and filesystem. Downstream seeded
    shuffles operate on this order, so sorting is what makes the data split reproducible.

    Args:
        dataset_path: root of the extracted dataset; must contain `wnids.txt` and `train/`.
            Images are read from `train/<wnid>/images/` or, if that folder is absent,
            directly from `train/<wnid>/`.

    Returns:
        - sample_details: List of dicts, each {'path': str, 'original_label_idx': int, 'original_dataset_idx': int, 'wnid': str}
        - wnid_to_idx: Dict mapping WNID to original integer label (its line position in wnids.txt, 0-199 for the full dataset)
        - idx_to_wnid: Dict mapping original integer label to WNID
        - all_wnids_ordered: List of WNIDs, order defines the 0-199 mapping

    Raises:
        FileNotFoundError: if `wnids.txt` cannot be opened or no images are found.
    """
    train_dir = os.path.join(dataset_path, 'train')
    wnids_file = os.path.join(dataset_path, 'wnids.txt')

    with open(wnids_file, 'r') as f:
        # Skip blank lines so a stray empty line cannot become a bogus '' class.
        all_wnids_ordered = [line.strip() for line in f if line.strip()]

    wnid_to_idx = {wnid: i for i, wnid in enumerate(all_wnids_ordered)}  # WNID -> original label
    idx_to_wnid = {i: wnid for wnid, i in wnid_to_idx.items()}  # original label -> WNID

    image_extensions = ('.jpeg', '.jpg', '.png')
    sample_details = []
    found_wnids = set()
    logging.info("Scanning Tiny ImageNet training directory...")

    for wnid in sorted(os.listdir(train_dir)):
        class_root = os.path.join(train_dir, wnid)
        if wnid not in wnid_to_idx or not os.path.isdir(class_root):
            continue  # Skip non-class entries such as .DS_Store or unknown folders

        # Prefer the standard <wnid>/images/ layout; fall back to files directly in <wnid>/.
        images_subdir = os.path.join(class_root, 'images')
        img_files_dir = images_subdir if os.path.isdir(images_subdir) else class_root

        for img_name in sorted(os.listdir(img_files_dir)):
            if not img_name.lower().endswith(image_extensions):
                continue
            sample_details.append({
                'path': os.path.join(img_files_dir, img_name),
                'original_label_idx': wnid_to_idx[wnid],
                # Evaluated before the append, so this is the sample's position in the list.
                'original_dataset_idx': len(sample_details),
                'wnid': wnid,
            })
            found_wnids.add(wnid)

    if not sample_details:
        raise FileNotFoundError(
            f"No images found. Please check the structure of {train_dir}. "
            f"Expected structure: {train_dir}/<wnid>/images/<image_file.JPEG> or {train_dir}/<wnid>/<image_file.JPEG>"
        )

    num_missing = len(all_wnids_ordered) - len(found_wnids)
    if num_missing > 0:
        logging.warning(
            f"{num_missing} of {len(all_wnids_ordered)} classes listed in wnids.txt have no "
            "training images; the dataset may be incompletely extracted."
        )

    logging.info(
        f"Found {len(sample_details)} training images from {len(found_wnids)} of "
        f"{len(all_wnids_ordered)} classes."
    )
    return sample_details, wnid_to_idx, idx_to_wnid, all_wnids_ordered

In [None]:
# Cell 8: Split classes into In-Distribution (ID) and Out-of-Distribution (OOD)
tiny_imagenet_dir = os.path.join(DATA_DIR, 'tiny-imagenet-200')

if not os.path.exists(tiny_imagenet_dir):
    logging.error(f"Error: Tiny ImageNet directory not found at '{tiny_imagenet_dir}'.")
    logging.error(f"Please download and extract Tiny ImageNet to the {DATA_DIR}.")
    raise FileNotFoundError(f"Tiny ImageNet directory not found at '{tiny_imagenet_dir}'")

sample_details, wnid_to_idx, idx_to_wnid, all_wnids_ordered = \
    get_tiny_imagenet_train_data_details(tiny_imagenet_dir)

num_total_original_classes = len(all_wnids_ordered)  # 200 for the complete dataset

# A dedicated generator seeded from the config makes this cell idempotent: re-running it
# alone reproduces the same split instead of continuing the global `random` stream.
# Random(RANDOM_SEED) yields the same sequence as the global `random` module right
# after set_seed(RANDOM_SEED).
split_rng = random.Random(RANDOM_SEED)

# 1. Split classes into In-Distribution (ID) and Out-of-Distribution (OOD)
all_original_class_indices = list(range(num_total_original_classes))
split_rng.shuffle(all_original_class_indices)

num_id_classes = int(IND_CLASS_RATIO * num_total_original_classes)  # 160 for 200 classes at 0.80

pretrain_classes_original_idxs = sorted(all_original_class_indices[:num_id_classes])
ood_classes_original_idxs = sorted(all_original_class_indices[num_id_classes:])

logging.info(f"\nTotal classes: {num_total_original_classes}")
logging.info(f"In-distribution (ID) classes selected (original indices): {len(pretrain_classes_original_idxs)}")
logging.info(f"Out-of-distribution (OOD) classes selected (original indices): {len(ood_classes_original_idxs)}")

# 2. Create mapping for ID classes to new contiguous labels [0, num_id_classes-1]
ind_class_mapping_from_original = {
    original_idx: new_idx for new_idx, original_idx in enumerate(pretrain_classes_original_idxs)
}

# 3. Group samples by their original class
samples_by_original_class = defaultdict(list)
for sample in sample_details:
    samples_by_original_class[sample['original_label_idx']].append(sample)

# Classes listed in wnids.txt but without images still take part in the ID/OOD split
# above; surface them instead of silently producing empty classes.
empty_classes = [
    idx for idx in range(num_total_original_classes) if not samples_by_original_class.get(idx)
]
if empty_classes:
    logging.warning(
        f"{len(empty_classes)} of {num_total_original_classes} classes have no training images "
        f"(e.g. {[idx_to_wnid.get(idx) for idx in empty_classes[:5]]}); check that the dataset "
        "extraction is complete before using this split."
    )

pretrained_ind_indices = []       # original_dataset_idx values used for ViT pretraining
pretrained_left_out_indices = []  # original_dataset_idx values of ID data held out from pretraining

logging.info("\nSplitting samples for ID classes:")
for original_class_idx in pretrain_classes_original_idxs:
    # Shuffle a copy so samples_by_original_class keeps its scan order.
    class_samples = list(samples_by_original_class[original_class_idx])
    split_rng.shuffle(class_samples)

    num_for_pretrain = int(PRETRAIN_EXAMPLE_RATIO * len(class_samples))
    # Ensure at least one sample for pretraining if class is not empty
    if class_samples and num_for_pretrain == 0:
        num_for_pretrain = 1

    pretrained_ind_indices.extend(s['original_dataset_idx'] for s in class_samples[:num_for_pretrain])
    pretrained_left_out_indices.extend(s['original_dataset_idx'] for s in class_samples[num_for_pretrain:])

pretrained_ind_indices.sort()
pretrained_left_out_indices.sort()

pretrain_pct = round(PRETRAIN_EXAMPLE_RATIO * 100)
logging.info(f"Total samples for ViT pretraining ({pretrain_pct}% of ID): {len(pretrained_ind_indices)}")
logging.info(f"Total ID samples left out from pretraining ({100 - pretrain_pct}% of ID): {len(pretrained_left_out_indices)}")

# Store class information
class_info = {
    'num_of_classes': num_total_original_classes,
    'pretrain_classes': pretrain_classes_original_idxs,  # List of original class indices (0-199) for ID
    'left_out_classes': ood_classes_original_idxs,       # List of original class indices (0-199) for OOD
    'left_out_ind_indices': pretrained_left_out_indices,
    'pretrained_ind_indices': pretrained_ind_indices,
    'pretrain_class_mapping': ind_class_mapping_from_original,  # Maps original ID class_idx -> new contiguous idx
    'wnid_to_idx': wnid_to_idx,  # wnid -> original_idx (0-199)
    'idx_to_wnid': idx_to_wnid,  # original_idx (0-199) -> wnid
    'all_wnids_ordered': all_wnids_ordered  # Defines the 0-199 mapping
}

# Save all image paths and their original labels so the training script does not need to
# re-scan the directory. Position i in these lists corresponds to original_dataset_idx == i.
all_training_image_paths = [s['path'] for s in sample_details]
all_training_original_labels = [s['original_label_idx'] for s in sample_details]

data_to_save = {
    'class_info': class_info,
    'all_training_image_paths': all_training_image_paths,
    'all_training_original_labels': all_training_original_labels
}

# torch.save does not create missing parent directories.
os.makedirs(SAVE_DIR, exist_ok=True)

save_file = os.path.join(SAVE_DIR, 'tiny_imagenet_data_info.pth')
torch.save(data_to_save, save_file)
logging.info(f"\nPreprocessing complete. Data information saved to: {save_file}")

# For convenience, also save just the class_info for OOD experiments
class_info_only_file = os.path.join(SAVE_DIR, 'class_info.pth')
torch.save(class_info, class_info_only_file)
logging.info(f"Class_info (for OOD experiments) saved to: {class_info_only_file}")
