In [9]:
import sys
import os
from dotenv import load_dotenv
load_dotenv()
ROOT_DIR_PATH = os.environ.get('ROOT_PATH')
sys.path.append(os.path.abspath(ROOT_DIR_PATH)) 
import urllib.request
import zipfile
import shutil

DESTINATION_PATH =f'{ROOT_DIR_PATH}/data/TINYIMAGENET/'
ZIP_NAME = 'tiny-imagenet-200.zip'


def download_tiny_imagenet(save_path=f'{DESTINATION_PATH}{ZIP_NAME}'):
    url = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'
    print(f"Downloading TinyImageNet to {save_path}...")
    urllib.request.urlretrieve(url, save_path)
    print("Download complete!")

def extract_dataset(zip_path=f'{DESTINATION_PATH}{ZIP_NAME}', extract_to=DESTINATION_PATH):
    print(f"Extracting {zip_path} to {extract_to}...")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_to)
    print("Extraction complete!")

def rearrange_val_folder(base_dir=f'{DESTINATION_PATH}tiny-imagenet-200/val'):
    print(f"Reorganizing validation folder at {base_dir}...")
    img_dir = os.path.join(base_dir, 'images')
    ann_file = os.path.join(base_dir, 'val_annotations.txt')

    with open(ann_file, 'r') as f:
        for line in f:
            file_name, class_name = line.split('\t')[:2]
            class_dir = os.path.join(base_dir, class_name)
            os.makedirs(class_dir, exist_ok=True)
            src = os.path.join(img_dir, file_name)
            dst = os.path.join(class_dir, file_name)
            if os.path.exists(src):
                shutil.move(src, dst)

    shutil.rmtree(img_dir)
    print("Validation images reorganized!")

def main():
    dst_dir = DESTINATION_PATH
    zip_fname = f"{DESTINATION_PATH}{ZIP_NAME}"

    if not os.path.isdir(dst_dir):
        print('creating the destination dir. & downloading')
        os.makedirs(os.path.dirname(dst_dir), exist_ok=True)
        download_tiny_imagenet(zip_fname)
        extract_dataset(zip_path=zip_fname, extract_to=os.path.dirname(dst_dir))
        rearrange_val_folder(base_dir=f'{DESTINATION_PATH}tiny-imagenet-200/val')
        print(f"TinyImageNet is ready under: {dst_dir}/")
    else:
        print(f"Dataset directory {dst_dir} already exists, zip downloaded.")

    if not os.path.isdir(f'{DESTINATION_PATH}tiny-imagenet-200/'):
        extract_dataset(zip_path=zip_fname, extract_to=os.path.dirname(dst_dir))
        rearrange_val_folder(base_dir=f'{DESTINATION_PATH}tiny-imagenet-200/val')
        print(f"TinyImageNet is ready under: {dst_dir}")
    else : print(f'TinyImageNet already exists under: {dst_dir}')


if __name__ == '__main__':
    main()

Dataset directory /home/wd/Documents/work_stuff/ViT_REPLICATION/data/TINYIMAGENET/ already exists, zip downloaded.
TinyImageNet already exists under: /home/wd/Documents/work_stuff/ViT_REPLICATION/data/TINYIMAGENET/


In [4]:
dst_dir = DESTINATION_PATH
zip_fname = f"{DESTINATION_PATH}{ZIP_NAME}"

In [6]:
def main():
    # You can customize this path
    dst_dir = DESTINATION_PATH
    zip_fname = f"{DESTINATION_PATH}{ZIP_NAME}"

    if not os.path.isdir(dst_dir):
        print('creating the destination dir. & downloading')
        os.makedirs(os.path.dirname(dst_dir), exist_ok=True)
        download_tiny_imagenet(zip_fname)
        extract_dataset(zip_path=zip_fname, extract_to=os.path.dirname(dst_dir))
        
        print(f"TinyImageNet is ready under: {dst_dir}/")
    else:
        print(f"Dataset directory {dst_dir}/ already exists. Skipping download.")

if __name__ == '__main__':
    main()


creating the destination dir. & downloading
Downloading TinyImageNet to /home/wd/Documents/work_stuff/ViT_REPLICATION/data/TINYIMAGENET/tinyimagenet200.zip...


KeyboardInterrupt: 

In [None]:
def rearrange_val_folder(base_dir=f'{EXTRACT_PATH}/tiny-imagenet-200/val'):
    print(f"Reorganizing validation folder at {base_dir}...")
    img_dir = os.path.join(base_dir, 'images')
    ann_file = os.path.join(base_dir, 'val_annotations.txt')

    with open(ann_file, 'r') as f:
        for line in f:
            file_name, class_name = line.split('\t')[:2]
            class_dir = os.path.join(base_dir, class_name)
            os.makedirs(class_dir, exist_ok=True)
            src = os.path.join(img_dir, file_name)
            dst = os.path.join(class_dir, file_name)
            if os.path.exists(src):
                shutil.move(src, dst)

    shutil.rmtree(img_dir)
    print("Validation images reorganized!")

# training test

In [11]:
import sys
import os
from dotenv import load_dotenv
load_dotenv()
ROOT_DIR_PATH = os.environ.get('ROOT_PATH')
sys.path.append(os.path.abspath(ROOT_DIR_PATH))  # Adds root directory to sys.path

import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from model.vit import VisionTransformerSmall
from utils.model_io import save_model
from utils.config_loader import load_config
from utils.data_loader import DatasetLoader
from pynvml import (
    nvmlInit, nvmlDeviceGetName, nvmlShutdown,
    nvmlDeviceGetHandleByIndex,
    nvmlDeviceGetMemoryInfo,
    nvmlDeviceGetUtilizationRates
)
from torch.optim.lr_scheduler import CosineAnnealingLR
from timm.data import Mixup
import numpy as np
from transformers import get_cosine_schedule_with_warmup

def train_one_epoch(model, loader, criterion, optimizer, device, 
                    mixup_fn=None, scheduler_warmup_enabled=False, scheduler_warmup=None):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    progress_bar = tqdm(loader, desc="Training", leave=True)
    for  inputs, targets in progress_bar:
        #print(f'input shape : {inputs.shape}, taget_shape : {targets.shape}, target dim : {targets.ndim}')
        inputs, targets = inputs.to(device), targets.to(device)
        if mixup_fn is not None:
            inputs, targets = mixup_fn(inputs, targets)

        if targets.ndim == 2:
            targets = targets.type_as(inputs)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        if scheduler_warmup_enabled:
            if scheduler_warmup is None : raise Exception(f'scheduler warmup is enabled, but no scheduler object has been passed in train_one_epoch function')
            scheduler_warmup.step()

        running_loss += loss.item() * inputs.size(0)

        if targets.ndim == 2:
            # MixUp with soft labels
            _, predicted = outputs.max(1)
            _, true_classes = targets.max(1)  # Take argmax of soft labels as true class
            correct += predicted.eq(true_classes).sum().item()
            total += targets.size(0)
        else :
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()
            total += targets.size(0)

        # Update progress bar with metrics
        if total > 0:
            avg_loss = running_loss / total
            accuracy = 100. * correct / total
            progress_bar.set_postfix({
                "Loss": f"{avg_loss:.4f}",
                "Acc": f"{accuracy:.2f}%"
            })

        else : raise Exception(f'Expected non-zero batch size, but got 0 targets. Check if the dataset is empty or DataLoader is misconfigured.')

    
    return avg_loss, accuracy

def validate(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    progress_bar = tqdm(loader, desc="Validation", leave=True)
    with torch.no_grad():
        for inputs, labels in progress_bar:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * inputs.size(0)
            # Compute accuracy
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

            # Avoid division by zero on first step
            if total > 0:
                avg_loss = running_loss / total
                accuracy = 100. * correct / total

                progress_bar.set_postfix({
                    "Loss": f"{avg_loss:.4f}",
                    "Acc": f"{accuracy:.2f}%"
                })
                
    return avg_loss, accuracy



  from .autonotebook import tqdm as notebook_tqdm


In [12]:
# def main():
# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load config
config = load_config(f"{ROOT_DIR_PATH}/config/vit_config.yaml")
# loading cifar100
#cifar100_config = config["data"]['CIFAR100']
dataset_config = config["data"]['TINYIMAGENET']
DATASET = dataset_config["dataset"]
DATA_DIR = dataset_config["data_path"]
BATCH = dataset_config["batch_size"]
NUM_WORKERS = dataset_config["num_workers"]
IMAGE = dataset_config["img_size"]
NUM_CLASSES = dataset_config["num_classes"]
CHANNELS = dataset_config["channels"]

Using device: cuda


In [None]:
# loading data
print(f'loading dataset : {DATASET}')
loader = DatasetLoader(dataset_name=DATASET,
                        data_dir=DATA_DIR,
                        batch_size=BATCH,
                        num_workers=NUM_WORKERS,
                        img_size=IMAGE)
train_loader, val_loader = loader.get_loaders()
print(f"Train batches: {len(train_loader)}, Validation batches: {len(val_loader)}")
print('data sanity check')
for images, labels in train_loader:
    print(f'image shape and labels shape in training data - one batch : {images.shape}, {labels.shape}')
    break

loading dataset : TINYIMAGENET
Dataset directory /home/wd/Documents/work_stuff/ViT_REPLICATION/data/TINYIMAGENET/ already exists, zip downloaded.
TinyImageNet already exists under: /home/wd/Documents/work_stuff/ViT_REPLICATION/data/TINYIMAGENET/
Dataset directory /home/wd/Documents/work_stuff/ViT_REPLICATION/data/TINYIMAGENET/ already exists, zip downloaded.
TinyImageNet already exists under: /home/wd/Documents/work_stuff/ViT_REPLICATION/data/TINYIMAGENET/
Train batches: 782, Validation batches: 79
data sanity check
image shape and labels shape in training data - one batch : torch.Size([128, 3, 64, 64]), torch.Size([128])


: 