# TinyViT Training with Pretrained Weights and Two-Stage Fine-tuning

In [None]:
# Silence UserWarnings emitted from the TinyViT model module only;
# warnings raised by any other module are left untouched.
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="libs.tiny_vit.tiny_vit")

import os
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.amp import GradScaler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.metrics import classification_report
# Reference normalization constants; mean/std are read from the settings file instead.
#from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

# Project-local helpers: YAML/JSON dict I/O, an albumentations->torchvision adapter,
# and the TinyViT training utilities used by both training stages.
from libs.common import load_dict, dump_dict
from libs.albumentations_utils import AlbumentationsTransform
from libs.tiny_vit.tiny_vit_train import get_model, load_pretrained_weights, freeze_backbone, unfreeze_all, \
    set_weight_decay, get_cosine_scheduler_with_warmup, train_epoch, validate_epoch, plot_confusion_matrix, EarlyStopping


## Settings

In [None]:
# Load run settings; the configured TinyViT variant is resolved against tiny_vit.json.
settings = load_dict('settings_tinyvit.yaml')

base_model = settings['base_model']
variants_path = "libs/tiny_vit/tiny_vit.json"
variants = load_dict(variants_path)

# Fail fast, listing the known variants, when the configured one is unknown.
if base_model not in variants:
    known_variants = list(variants.keys())
    raise ValueError(f"Base model '{base_model}' not found in available variants: {known_variants}")

base_model_info = variants[base_model]

# A variant entry supplies the architecture config and the pretrained weights path;
# the input resolution is read from the architecture config.
model_config = base_model_info['model_config']
pretrained_path = base_model_info['weights']
img_size = model_config['img_size']

print(f"Using base model: {base_model}")
print(f"Image size: {img_size}x{img_size}")
print(f"Pretrained weights: {pretrained_path}")

# Top-level directories and the output checkpoint file name.
data_dir, output_dir, logs_dir = (settings[key] for key in ('data_dir', 'output_dir', 'logs_dir'))
checkpoint_name = settings['checkpoint_name']

# One sub-directory per data split.
train_dir, test_dir, val_dir = (os.path.join(data_dir, split) for split in ("train", "test", "val"))

for directory in (output_dir, logs_dir):
    os.makedirs(directory, exist_ok=True)

# Number of classes = number of sub-directories inside the training split.
num_classes = len([entry for entry in os.listdir(train_dir) if os.path.isdir(os.path.join(train_dir, entry))])
print("num_classes:", num_classes)

# Per-channel normalization statistics (timm's IMAGENET_DEFAULT_MEAN/STD are the
# reference values mentioned in the imports).
mean, std = (np.array(settings[key]) for key in ('mean', 'std'))

# float() conversions presumably guard against YAML values parsed as strings
# (e.g. exponent notation without a decimal point) — TODO confirm.

# Stage 1: head-only training with a frozen backbone.
stage1_cfg = settings['stage1']
stage1_epochs = stage1_cfg['epochs']
stage1_lr = float(stage1_cfg['learning_rate'])
stage1_warmup_epochs = stage1_cfg['warmup_epochs']
stage1_min_lr = float(stage1_cfg['min_lr'])
stage1_batch_size = stage1_cfg['batch_size']

# Stage 2: full-network fine-tuning.
stage2_cfg = settings['stage2']
stage2_epochs = stage2_cfg['epochs']
stage2_lr = float(stage2_cfg['learning_rate'])
stage2_warmup_epochs = stage2_cfg['warmup_epochs']
stage2_min_lr = float(stage2_cfg['min_lr'])
stage2_batch_size = stage2_cfg['batch_size']

# Settings shared by both stages.
# NOTE(review): layer_lr_decay is loaded here but never used elsewhere in this
# notebook — confirm whether Stage 2 should apply layer-wise LR decay.
layer_lr_decay = float(settings['layer_lr_decay'])
weight_decay = float(settings['weight_decay'])
use_amp = settings['use_amp']
patience = settings['patience']
gradient_clip_norm = float(settings['gradient_clip_norm'])
eval_bn = settings['eval_bn']

# AdamW hyper-parameters; the settings['optimizer'] dict is normalized in place.
optimizer_config = settings['optimizer']
optimizer_config.update(
    eps=float(optimizer_config['eps']),
    betas=[float(beta) for beta in optimizer_config['betas']],
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# TensorBoard event files are written to logs_dir.
writer = SummaryWriter(logs_dir)

## Transforms, Dataloaders, and Model Setup

In [None]:
# Get augmentation settings from config
aug_config = settings['augmentation']

train_transform = A.Compose([
    A.Resize(img_size, img_size),
    A.HorizontalFlip(p=aug_config['horizontal_flip']),
    A.Rotate(limit=aug_config['rotation_limit'], p=aug_config['rotation_prob']),
    A.ColorJitter(
        brightness=aug_config['color_jitter']['brightness'], 
        contrast=aug_config['color_jitter']['contrast'], 
        saturation=aug_config['color_jitter']['saturation'], 
        hue=aug_config['color_jitter']['hue'], 
        p=aug_config['color_jitter']['prob']
    ),
    A.OneOf([
        A.GaussNoise(
            var_limit=aug_config['noise_and_blur']['gaussian_noise']['var_limit'], 
            p=aug_config['noise_and_blur']['gaussian_noise']['prob']
        ),
        A.GaussianBlur(
            blur_limit=aug_config['noise_and_blur']['gaussian_blur']['blur_limit'], 
            p=aug_config['noise_and_blur']['gaussian_blur']['prob']
        ),
    ], p=aug_config['noise_and_blur']['prob']),
    A.CoarseDropout(
        max_holes=aug_config['coarse_dropout']['max_holes'], 
        max_height=aug_config['coarse_dropout']['max_height'], 
        max_width=aug_config['coarse_dropout']['max_width'], 
        p=aug_config['coarse_dropout']['prob']
    ),
    A.Normalize(mean=mean, std=std),
    ToTensorV2()
])

eval_transform = A.Compose([
    A.Resize(img_size, img_size),
    A.Normalize(mean=mean, std=std),
    ToTensorV2()
])

# Datasets and Dataloaders
train_dataset = datasets.ImageFolder(
    os.path.join(data_dir, "train"),
    transform=AlbumentationsTransform(train_transform)
)

val_dataset = datasets.ImageFolder(
    os.path.join(data_dir, "val"),
    transform=AlbumentationsTransform(eval_transform)
)

test_dataset = datasets.ImageFolder(
    os.path.join(data_dir, "test"),
    transform=AlbumentationsTransform(eval_transform)
)

train_loader = DataLoader(train_dataset, batch_size=stage1_batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=stage1_batch_size)
test_loader = DataLoader(test_dataset, batch_size=stage1_batch_size)

# Get class names and create class labels mapping
class_labels = train_dataset.classes
print(f"Classes found: {class_labels}")

# Build model using integrated model configuration
model = get_model(model_config, num_classes, img_size, device)

# Load pretrained weights
model = load_pretrained_weights(model, pretrained_path, num_classes, device)

# Initialize mixed precision scaler
scaler = GradScaler('cuda') if use_amp else None

# Loss function
criterion = nn.CrossEntropyLoss()

## Stage 1: Training the classifier head with a frozen backbone

In [None]:
# Freeze backbone and train only the head
model = freeze_backbone(model)

# Optimizer and scheduler for stage 1 with improved weight decay
param_groups = set_weight_decay(model, weight_decay=weight_decay)
optimizer = optim.AdamW(
    param_groups, 
    lr=stage1_lr, 
    eps=optimizer_config['eps'], 
    betas=optimizer_config['betas']
)

# Cosine annealing scheduler with warmup
total_steps = stage1_epochs * len(train_loader)
warmup_steps = stage1_warmup_epochs * len(train_loader)
scheduler = get_cosine_scheduler_with_warmup(optimizer, warmup_steps, total_steps, stage1_min_lr)

# Early stopping for stage 1
early_stopping = EarlyStopping(patience=patience)

# Stage 1 training loop
global_step = 0
for epoch in range(stage1_epochs):
    print(f"\nEpoch {epoch+1}/{stage1_epochs}")
    
    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device, scaler, use_amp, eval_bn=eval_bn, gradient_clip_norm=gradient_clip_norm)
    
    # Validate
    val_loss, val_acc, val_preds, val_labels = validate_epoch(model, val_loader, criterion, device)
    
    # Log to TensorBoard
    writer.add_scalar('Stage1/Train_Loss', train_loss, epoch)
    writer.add_scalar('Stage1/Train_Acc', train_acc, epoch)
    writer.add_scalar('Stage1/Val_Loss', val_loss, epoch)
    writer.add_scalar('Stage1/Val_Acc', val_acc, epoch)
    writer.add_scalar('Stage1/Learning_Rate', optimizer.param_groups[0]['lr'], epoch)
    
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
    
    # Step scheduler
    scheduler.step()
    
    # Early stopping check
    if early_stopping(val_loss, model):
        print(f"Early stopping triggered at epoch {epoch+1}")
        break
    
    global_step += 1

## Stage 2: Fine-tuning the entire network

In [None]:
# Unfreeze all parameters and fine-tune
model = unfreeze_all(model)

# Create new data loaders with smaller batch size for stage 2
train_loader_stage2 = DataLoader(train_dataset, batch_size=stage2_batch_size, shuffle=True)
val_loader_stage2 = DataLoader(val_dataset, batch_size=stage2_batch_size)

# New optimizer and scheduler for stage 2 with lower learning rate
param_groups = set_weight_decay(model, weight_decay=weight_decay)
optimizer = optim.AdamW(
    param_groups, 
    lr=stage2_lr, 
    eps=optimizer_config['eps'], 
    betas=optimizer_config['betas']
)

# Cosine annealing scheduler with warmup for stage 2
total_steps = stage2_epochs * len(train_loader_stage2)  
warmup_steps = stage2_warmup_epochs * len(train_loader_stage2)
scheduler = get_cosine_scheduler_with_warmup(optimizer, warmup_steps, total_steps, stage2_min_lr)

# Reset early stopping for stage 2
early_stopping = EarlyStopping(patience=patience)

# Stage 2 training loop
for epoch in range(stage2_epochs):
    print(f"\nEpoch {epoch+1}/{stage2_epochs}")
    
    # Train
    train_loss, train_acc = train_epoch(model, train_loader_stage2, criterion, optimizer, device, scaler, use_amp, eval_bn=eval_bn, gradient_clip_norm=gradient_clip_norm)
    
    # Validate
    val_loss, val_acc, val_preds, val_labels = validate_epoch(model, val_loader_stage2, criterion, device)
    
    # Log to TensorBoard
    writer.add_scalar('Stage2/Train_Loss', train_loss, global_step + epoch)
    writer.add_scalar('Stage2/Train_Acc', train_acc, global_step + epoch)
    writer.add_scalar('Stage2/Val_Loss', val_loss, global_step + epoch)
    writer.add_scalar('Stage2/Val_Acc', val_acc, global_step + epoch)
    writer.add_scalar('Stage2/Learning_Rate', optimizer.param_groups[0]['lr'], global_step + epoch)
    
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
    
    # Step scheduler
    scheduler.step()
    
    # Early stopping check
    if early_stopping(val_loss, model):
        print(f"Early stopping triggered at epoch {epoch+1}")
        break

In [None]:
# Save final model with metadata
model_save_dict = {
    'model_state_dict': model.state_dict(),
    'num_classes': num_classes,
    'class_labels': class_labels,
    'img_size': img_size,
    'pretrained_path': pretrained_path
}

checkpoint_savepath = os.path.join(output_dir, checkpoint_name)
torch.save(model_save_dict, checkpoint_savepath)

print(f"\nModel saved to {checkpoint_savepath}")


## Evaluation on the Test Set

In [None]:
# Evaluate the final model on the held-out test split.
model.eval()
test_correct, test_total = 0, 0
test_preds, test_labels = [], []

with torch.no_grad():
    for inputs, labels in tqdm(test_loader, desc="Testing"):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        # Predicted class = index of the highest output score (first index on ties).
        preds = outputs.argmax(dim=1)
        test_correct += (preds == labels).sum().item()
        test_total += labels.size(0)

        test_preds.extend(preds.cpu().numpy())
        test_labels.extend(labels.cpu().numpy())

if test_total == 0:
    raise ValueError("Test set is empty; cannot compute test accuracy.")

test_acc = test_correct / test_total
print(f"Test Accuracy: {test_acc:.4f}")

writer.add_scalar('Final/Test_Acc', test_acc, 0)

# Output file names are derived from the checkpoint name without its extension.
checkpoint_base = os.path.splitext(checkpoint_name)[0]

cm_path = os.path.join(logs_dir, f'{checkpoint_base}_confusion_matrix.png')
plot_confusion_matrix(test_labels, test_preds, class_labels, cm_path)
print(f"Confusion matrix saved to {cm_path}")

# Pass the full label set explicitly: otherwise sklearn infers the classes from the
# data and raises when a class is absent from both test_labels and test_preds,
# because the inferred class count no longer matches target_names.
report = classification_report(
    test_labels,
    test_preds,
    labels=list(range(len(class_labels))),
    target_names=class_labels,
)
print("\nClassification Report:")
print(report)

# Explicit UTF-8 so non-ASCII class names do not depend on the platform encoding.
report_path = os.path.join(logs_dir, f'{checkpoint_base}_classification_report.txt')
with open(report_path, 'w', encoding='utf-8') as f:
    f.write(report)

# Training summary saved alongside the checkpoint.
summary_info = {
    'final_test_accuracy': float(test_acc),
    'num_classes': num_classes,
    'class_labels': class_labels,
    'model_architecture': 'TinyViT',
    'base_model': base_model,
    'pretrained_checkpoint': pretrained_path,
    'training_stages': {
        'stage1_epochs': stage1_epochs,
        'stage1_lr': stage1_lr,
        'stage2_epochs': stage2_epochs,
        'stage2_lr': stage2_lr
    },
    'image_size': img_size,
    'stage1_batch_size': stage1_batch_size,
    'stage2_batch_size': stage2_batch_size
}

summary_path = os.path.join(output_dir, f'{checkpoint_base}_metadata.json')
dump_dict(summary_info, summary_path)
print(f"Training summary saved to {summary_path}")

writer.close()
print("Training complete!")