In [4]:
# Cell 3: Import necessary libraries
import random
import torch
import os
import zipfile
import urllib.request
import numpy as np
from collections import defaultdict
import logging

In [5]:
# Cell 4: Configuration
TINY_IMAGENET_URL = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip' # Or use .tar.gz if preferred

DRIVE_MOUNT_POINT = '/Users/alexc/Education/OMSCS/Masters_Project/msproject_repo/data' # Optional: Google Drive mount point
DATA_DIR = DRIVE_MOUNT_POINT + '/tiny-imagenet-200'
SAVE_DIR = DATA_DIR + '/preprocessed_tinyimagenet'

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
RANDOM_SEED = 42

# Data split ratios
IND_CLASS_RATIO = 0.80  # 80% of classes for In-Distribution
PRETRAIN_EXAMPLE_RATIO = 0.75  # 75% of examples from ID classes for actual pretraining

In [6]:
# Cell 5: Setup
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

set_seed(RANDOM_SEED)
print(f"Using device: {DEVICE}")

Using device: cpu


In [7]:
# Cell 6: Download + extract Tiny Imagenet
# Create a directory for our dataset
os.makedirs(DATA_DIR, exist_ok=True)
print(f"Data will be saved in: {DATA_DIR}")

os.makedirs(SAVE_DIR, exist_ok=True)
print(f"Preprocessed data will be saved in: {SAVE_DIR}")


# URL for the Tiny-ImageNet dataset
zip_path = os.path.join(DATA_DIR, 'tiny-imagenet-200.zip')

# Download the dataset if it doesn't exist
if not os.path.exists(zip_path):
    print("Downloading Tiny-ImageNet dataset...")

    # Create a progress bar for download
    def report_progress(block_num, block_size, total_size):
        progress = float(block_num * block_size) / float(total_size) * 100.0
        print(f"\rDownloading: {progress:.2f}%", end="")

    # Download with progress reporting
    urllib.request.urlretrieve(TINY_IMAGENET_URL, zip_path, reporthook=report_progress)
    print("\nDownload complete!")
else:
    print("Dataset already downloaded.")

# Extract the dataset if not already extracted
extract_dir = os.path.join(DATA_DIR, 'tiny-imagenet-200')
if not os.path.exists(extract_dir):
    print("Extracting dataset...")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(DATA_DIR)
    print("Extraction complete!")
else:
    print("Dataset already extracted.")

# Basic validation to check the dataset structure
train_dir = os.path.join(extract_dir, 'train')
val_dir = os.path.join(extract_dir, 'test')

if os.path.exists(train_dir) and os.path.exists(val_dir):
    # Count the number of classes in training set
    train_classes = os.listdir(train_dir)
    print(f"Number of classes in training set: {len(train_classes)}")

    # Check a few example classes
    print(f"Example classes: {train_classes[:5]}")

    # Check the structure of one class
    example_class = train_classes[0]
    example_class_dir = os.path.join(train_dir, example_class)
    example_images_dir = os.path.join(example_class_dir, 'images')
    example_images = os.listdir(example_images_dir)

    print(f"Number of images in {example_class}: {len(example_images)}")
    print(f"Example image paths: {example_images[:3]}")
    print("Dataset structure validation complete!")
else:
    print("Dataset structure seems incorrect. Please check the extraction.")

Data will be saved in: /Users/alexc/Education/OMSCS/Masters_Project/msproject_repo/data/tiny-imagenet-200
Preprocessed data will be saved in: /Users/alexc/Education/OMSCS/Masters_Project/msproject_repo/data/tiny-imagenet-200/preprocessed_tinyimagenet
Downloading Tiny-ImageNet dataset...
Downloading: 100.00%
Download complete!
Extracting dataset...
Extraction complete!
Number of classes in training set: 200
Example classes: ['n02795169', 'n02769748', 'n07920052', 'n02917067', 'n01629819']
Number of images in n02795169: 500
Example image paths: ['n02795169_369.JPEG', 'n02795169_386.JPEG', 'n02795169_105.JPEG']
Dataset structure validation complete!


In [8]:
# ====================== 1. Split classes into In-Distribution (ID) and Out-of-Distribution (OOD) ======================

"""
Gathers details of the Tiny ImageNet training set.
Returns:
    - sample_details: List of dicts, each {'path': str, 'original_label_idx': int, 'original_example_dataset_idx': int, 'wnid': str}
    - wnid_to_idx: Dict mapping WNID to original integer label (0-199)
    - idx_to_wnid: Dict mapping original integer label to WNID
    - all_wnids_ordered: List of WNIDs, order defines the 0-199 mapping
"""
tiny_imagenet_dir = os.path.join(DATA_DIR, 'tiny-imagenet-200')
train_dir = os.path.join(tiny_imagenet_dir, 'train')
wnids_file = os.path.join(tiny_imagenet_dir, 'wnids.txt')

with open(wnids_file, 'r') as f:
    all_wnids_ordered = [line.strip() for line in f]

print(f"Number of classes in training set (all_wnids_ordered): {len(all_wnids_ordered)}")

wnid_to_idx = {wnid: i for i, wnid in enumerate(all_wnids_ordered)}
idx_to_wnid = {i: wnid for wnid, i in wnid_to_idx.items()}

sample_details = []
current_original_example_dataset_idx = 0

print("Scanning Tiny ImageNet training directory...")
for wnid in os.listdir(train_dir):
    if wnid not in wnid_to_idx:
        print(f"Warning: WNID '{wnid}' not found in wnids.txt. Skipping...")
        continue # Skip non-class folders like .DS_Store

    wnid_idx = wnid_to_idx[wnid]
    class_image_dir = os.path.join(train_dir, wnid, 'images')

    # Check if 'images' subdirectory exists, if not, check current wnid directory
    if not os.path.isdir(class_image_dir):
          img_files_dir = os.path.join(train_dir, wnid)
    else:
        img_files_dir = class_image_dir

    for img_name in os.listdir(img_files_dir):
        if img_name.lower().endswith(('.jpeg', '.jpg', '.png')):
            img_path = os.path.join(img_files_dir, img_name)
            sample_details.append({
                'path': img_path,
                'wnid_idx': wnid_idx, # index in the parent directory
                'original_example_dataset_idx': current_original_example_dataset_idx, # index as if all the examples from every class were in one folder
                'wnid': wnid
            })
            current_original_example_dataset_idx += 1

if not sample_details:
    raise FileNotFoundError(
        f"No images found. Please check the structure of {train_dir}. "
        f"Expected structure: {train_dir}/<wnid>/images/<image_file.JPEG> or {train_dir}/<wnid>/<image_file.JPEG>"
    )

print(f"Found {len(sample_details)} training images from {len(all_wnids_ordered)} classes.")

if not os.path.exists(tiny_imagenet_dir):
    print(f"Error: Tiny ImageNet directory not found at '{tiny_imagenet_dir}'.")
    print(f"Please download and extract Tiny ImageNet to the {DATA_DIR}")
    raise FileNotFoundError

num_total_original_classes = len(all_wnids_ordered) # Should be 200

logging.info(f"\nSplitting classes into In-Distribution (ID) and Out-of-Distribution (OOD)...")

all_original_class_indices = list(range(num_total_original_classes)) # [0, 1, ..., 199]
random.shuffle(all_original_class_indices)

num_of_ind_classes = int(IND_CLASS_RATIO * num_total_original_classes) # ex., 122

pretrain_classes_original_idxs = sorted(all_original_class_indices[:num_of_ind_classes])
ood_classes_original_idxs = sorted(all_original_class_indices[num_of_ind_classes:])

print(f"\nTotal classes: {num_total_original_classes}")
print(f"In-distribution (ID) classes selected (original indices): {len(pretrain_classes_original_idxs)}")
print(f"Out-of-distribution (OOD) classes selected (original indices): {len(ood_classes_original_idxs)}")

Number of classes in training set (all_wnids_ordered): 200
Scanning Tiny ImageNet training directory...
Found 100000 training images from 200 classes.

Total classes: 200
In-distribution (ID) classes selected (original indices): 160
Out-of-distribution (OOD) classes selected (original indices): 40


In [9]:
# ====================== 2. Create mapping for ID classes to new contiguous labels [0, num_of_ind_classes-1] ======================
ind_wnid_idx_mapping_from_original = {
    original_wnid_idx: new_idx for new_idx, original_wnid_idx in enumerate(pretrain_classes_original_idxs)
}

# 3. Group samples by their original class
samples_by_original_class = defaultdict(list)
for sample in sample_details:
    samples_by_original_class[sample['wnid_idx']].append(sample)

pretrained_ind_indices = []      # List of original_example_dataset_idx for ViT training (75% of ID)
pretrained_left_out_indices = [] # List of original_example_dataset_idx for ID data not used in ViT training (25% of ID)

print("\nSplitting samples for ID classes:")
for wnid_idx in pretrain_classes_original_idxs:
    class_samples = samples_by_original_class[wnid_idx]
    random.shuffle(class_samples) # Shuffle samples within the class

    num_samples_in_class = len(class_samples)
    num_for_pretrain = int(PRETRAIN_EXAMPLE_RATIO * num_samples_in_class)

    # Ensure at least one sample for pretraining if class is not empty
    if num_samples_in_class > 0 and num_for_pretrain == 0:
        num_for_pretrain = 1

    for i, sample in enumerate(class_samples):
        if i < num_for_pretrain:
            pretrained_ind_indices.append(sample['original_example_dataset_idx'])
        else:
            pretrained_left_out_indices.append(sample['original_example_dataset_idx'])

pretrained_ind_indices.sort()
pretrained_left_out_indices.sort()

print(f"Total samples for ViT pretraining (75% of ID): {len(pretrained_ind_indices)}")
print(f"Total ID samples left out from pretraining (25% of ID): {len(pretrained_left_out_indices)}")

# Store class information
class_info = {
    'num_of_classes': num_total_original_classes,
    'pretrain_classes': pretrain_classes_original_idxs, # List of original class indices (0-199) for ID
    'left_out_classes': ood_classes_original_idxs,   # List of original class indices (0-199) for OOD
    'pretrained_ind_indices': pretrained_ind_indices,
    'left_out_ind_indices': pretrained_left_out_indices,
    'pretrain_class_mapping': ind_wnid_idx_mapping_from_original, # Maps original ID class_idx -> new contiguous idx
    'wnid_to_idx': wnid_to_idx, # wnid -> original_idx (0-199)
    'idx_to_wnid': idx_to_wnid, # original_idx (0-199) -> wnid
    'all_wnids_ordered': all_wnids_ordered # Defines the 0-199 mapping
}

# Save all image paths and their original labels for easy access in the training script
# This avoids re-scanning the directory in the training script.
all_training_image_paths = [s['path'] for s in sample_details]
all_training_original_labels = [s['wnid_idx'] for s in sample_details]

data_to_save = {
    'class_info': class_info,
    'all_training_image_paths': all_training_image_paths,
    'all_training_original_labels': all_training_original_labels
}

save_file = os.path.join(SAVE_DIR, 'tiny_imagenet_data_info.pth')
torch.save(data_to_save, save_file)
print(f"\nPreprocessing complete. Data information saved to: {save_file}")

# For convenience, also save just the class_info as requested by user for OOD experiments
class_info_only_file = os.path.join(SAVE_DIR, 'class_info.pth')
torch.save(class_info, class_info_only_file)
print(f"Class_info (for OOD experiments) saved to: {class_info_only_file}")


Splitting samples for ID classes:
Total samples for ViT pretraining (75% of ID): 60000
Total ID samples left out from pretraining (25% of ID): 20000

Preprocessing complete. Data information saved to: /Users/alexc/Education/OMSCS/Masters_Project/msproject_repo/data/tiny-imagenet-200/preprocessed_tinyimagenet/tiny_imagenet_data_info.pth
Class_info (for OOD experiments) saved to: /Users/alexc/Education/OMSCS/Masters_Project/msproject_repo/data/tiny-imagenet-200/preprocessed_tinyimagenet/class_info.pth
