In [None]:
import sys
sys.path.append("..")
import torchvision.transforms as transforms
from Utils.TinyImageNet_loader import get_tinyimagenet_dataloaders

# Side length (in pixels) that every image is resized to.
image_size = 224

# Mean / std tuples handed to Normalize in all three pipelines.
_norm_mean = (0.485, 0.456, 0.406)
_norm_std = (0.229, 0.224, 0.225)


def _build_eval_transform():
    """Return a new augmentation-free pipeline: resize -> tensor -> normalize."""
    return transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(_norm_mean, _norm_std),
    ])


# Training pipeline: random flip, resize, padded random crop and a small
# random rotation, followed by tensor conversion and normalization.
_train_steps = [
    transforms.RandomHorizontalFlip(),
    transforms.Resize((image_size, image_size)),
    transforms.RandomCrop(image_size, padding=5),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
]
tiny_transform_train = transforms.Compose(_train_steps)

# Validation and test use the same steps but are kept as separate objects.
tiny_transform_val = _build_eval_transform()
tiny_transform_test = _build_eval_transform()

# Build the three loaders through the project helper.
_loader_kwargs = dict(
    data_dir='../datasets',
    transform_train=tiny_transform_train,
    transform_val=tiny_transform_val,
    transform_test=tiny_transform_test,
    batch_size=64,
    image_size=image_size,
)
train_loader, val_loader, test_loader = get_tinyimagenet_dataloaders(**_loader_kwargs)


In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import timm  # PyTorch Image Models

# ----------------------------
# Configuration and Parameters
# ----------------------------
num_classes = 200  # Tiny ImageNet has 200 classes
# NOTE: batch_size and num_workers are not read anywhere in this cell; the
# loaders iterated below (train_loader / val_loader) are created elsewhere.
batch_size = 64
num_workers = 4
num_epochs = 200
learning_rate = 1e-3
weight_decay = 0.05
step_size = 10  # epochs after which to decay the learning rate
gamma = 0.1
final_checkpoint_path = "swin_tiny_imagenet.pth"
# Written whenever validation accuracy improves, so an interrupted run still
# leaves usable weights on disk.
best_checkpoint_path = "swin_tiny_imagenet_best.pth"


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Args:
        model: Network to run; switched to train mode when ``optimizer`` is
            given and to eval mode otherwise.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable applied as ``criterion(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.
        optimizer: If provided, gradients are computed and a step is taken
            per batch; if ``None`` the pass runs with gradients disabled.

    Returns:
        Tuple of sample-weighted mean loss and fraction of correct
        predictions over all samples seen.
    """
    is_training = optimizer is not None
    model.train(is_training)

    loss_sum = 0.0
    correct = 0
    seen = 0
    with torch.set_grad_enabled(is_training):
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)

            if is_training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            if is_training:
                loss.backward()
                optimizer.step()

            # Weight the batch loss by batch size so the epoch mean is exact
            # even when the last batch is smaller.
            batch_count = images.size(0)
            loss_sum += loss.item() * batch_count
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            seen += batch_count

    return loss_sum / seen, correct / seen


# ----------------------------
# Model, Loss, Optimizer, Scheduler
# ----------------------------
# Create a Swin Transformer model from timm. pretrained=False means no
# pretrained weights are loaded (training starts from scratch); num_classes
# sets the size of the classification head.
model = timm.create_model('swin_tiny_patch4_window7_224', pretrained=False, num_classes=num_classes)

# Move model to GPU if available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Print the number of trainable parameters.
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {num_params:,}")

# Loss function and optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# NOTE(review): with step_size=10 and gamma=0.1 the learning rate is divided
# by 10 every 10 epochs, i.e. by roughly 19 orders of magnitude over 200
# epochs. Confirm this schedule is intended.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

# ----------------------------
# Training and Validation Loop
# ----------------------------
best_val_acc = float("-inf")
for epoch in range(num_epochs):
    epoch_loss, epoch_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
    print(f"Epoch [{epoch+1}/{num_epochs}] Training Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")

    # Validation phase
    val_epoch_loss, val_epoch_acc = _run_epoch(model, val_loader, criterion, device)
    print(f"Epoch [{epoch+1}/{num_epochs}] Validation Loss: {val_epoch_loss:.4f}, Accuracy: {val_epoch_acc:.4f}")

    # Keep the best weights on disk so interrupting the run does not lose them.
    if val_epoch_acc > best_val_acc:
        best_val_acc = val_epoch_acc
        torch.save(model.state_dict(), best_checkpoint_path)
        print(f"Best validation accuracy so far ({best_val_acc:.4f}); checkpoint saved to {best_checkpoint_path}")

    scheduler.step()  # Adjust learning rate

# Save the trained model.
torch.save(model.state_dict(), final_checkpoint_path)
print("Training complete and model saved.")


Number of trainable parameters: 27,673,154


KeyboardInterrupt: 