In [None]:
import os
import time
import copy
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import timm
from tqdm.notebook import tqdm

# --- Configuration ---
BATCH_SIZE = 32
MODEL_NAME = 'vit_tiny_patch16_224'
LEARNING_RATE = 0.001
NUM_EPOCHS = 10
SEED = 42
NUM_WORKERS = 4
IMG_SIZE = 224
DATA_DIR = 'vit_data'

# ImageNet per-channel statistics used for input normalization.
# NOTE(review): some timm pretrained ViT weights were trained with a different
# normalization (e.g. mean=std=0.5); confirm against the model's pretrained_cfg.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Seed torch's global RNG so DataLoader shuffling and later random weight
# initialization are reproducible across Restart & Run All.
torch.manual_seed(SEED)

# --- Data Transforms ---
def build_transform():
    """Return a pipeline: resize to IMG_SIZE x IMG_SIZE, convert to tensor, normalize."""
    return transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(NORM_MEAN, NORM_STD),
    ])

# Both splits currently share the same deterministic preprocessing.
# TODO: consider adding augmentation to the 'train' pipeline.
data_transforms = {split: build_transform() for split in ['train', 'val']}

# --- Data Loading ---
train_dir = os.path.join(DATA_DIR, 'train')
val_dir = os.path.join(DATA_DIR, 'val')
split_dirs = {'train': train_dir, 'val': val_dir}

image_datasets = {x: datasets.ImageFolder(split_dirs[x], data_transforms[x])
                  for x in ['train', 'val']}

# Only the training split needs shuffling; val order does not affect the metrics.
dataloaders = {x: DataLoader(image_datasets[x], batch_size=BATCH_SIZE,
                             shuffle=(x == 'train'), num_workers=NUM_WORKERS)
              for x in ['train', 'val']}

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

# ImageFolder derives label indices from each root's own sorted sub-folders, so a
# class missing from one split would silently shift every label after it.
assert image_datasets['val'].classes == class_names, (
    "train/val class folders differ; labels would be misaligned")

print(f"Dataset sizes: {dataset_sizes}")
print(f"Class names: {class_names}")
print(f"Number of classes: {len(class_names)}")

NameError: name 'BATCH_SIZE' is not defined

In [None]:
# --- Model Setup ---
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

num_classes = len(class_names)

# Create the pretrained model.
model = timm.create_model(MODEL_NAME, pretrained=True)

# Freeze every pretrained parameter; only the new classifier created below trains.
for param in model.parameters():
    param.requires_grad = False

# Replace the classifier with a fresh one sized for our classes. reset_classifier is
# timm's architecture-agnostic API, so this keeps working if MODEL_NAME is changed to
# a model whose head is not named `.head`. The new layer is created after the freeze
# loop, so its parameters keep requires_grad=True.
model.reset_classifier(num_classes)
model = model.to(device)
head = model.get_classifier()
assert all(p.requires_grad for p in head.parameters()), "classifier head is frozen"

# Optimizer (only for head) & Loss
optimizer = optim.Adam(head.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

print("Model initialized. Training only the head.")

model.safetensors:   0%|          | 0.00/22.9M [00:00<?, ?B/s]

Using device: cpu
Number of classes: 16


In [None]:
# --- Training Function ---
def _run_phase(model, criterion, optimizer, loader, phase, device):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)`` as floats.

    In the ``'train'`` phase the model is put in train mode and the optimizer is
    stepped after every batch. Any other phase runs in eval mode with gradients
    disabled.

    Raises:
        ValueError: if ``loader`` yields no samples (the averages are undefined).
    """
    is_train = phase == 'train'
    model.train(is_train)  # train(False) is equivalent to eval()

    running_loss = 0.0
    running_corrects = 0
    n_seen = 0

    for inputs, labels in tqdm(loader, desc=f'{phase} Phase'):
        inputs = inputs.to(device)
        labels = labels.to(device)

        if is_train:
            optimizer.zero_grad()

        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if is_train:
                loss.backward()
                optimizer.step()

        # softmax is monotonic, so argmax over the raw logits gives the same class.
        preds = outputs.argmax(dim=1)
        batch_size = inputs.size(0)
        # Assumes criterion returns the batch mean (CrossEntropyLoss default);
        # re-weight by batch size so the epoch average is per-sample.
        running_loss += loss.item() * batch_size
        running_corrects += (preds == labels).sum().item()
        n_seen += batch_size

    if n_seen == 0:
        raise ValueError(f"'{phase}' loader yielded no samples")
    return running_loss / n_seen, running_corrects / n_seen


def train_model(model, criterion, optimizer, num_epochs=10, dataloaders=None, device=None):
    """Train ``model`` and return it loaded with its best-validation weights.

    Each epoch runs a ``'train'`` phase followed by a ``'val'`` phase, printing
    the loss and accuracy of each phase.

    Args:
        model: module to train; it is modified in place.
        criterion: loss, called as ``criterion(outputs, labels)``.
        optimizer: stepped once per training batch.
        num_epochs: number of epochs to run.
        dataloaders: mapping with ``'train'`` and ``'val'`` iterables of
            ``(inputs, labels)`` batches. Defaults to the notebook-level
            ``dataloaders``.
        device: device that inputs and labels are moved to. Defaults to the
            notebook-level ``device``.

    Returns:
        The same ``model`` object, with the state dict from the epoch of highest
        val accuracy loaded. If val accuracy never exceeds 0, the initial weights
        are loaded.

    Raises:
        ValueError: if a phase's loader yields no samples.
    """
    if dataloaders is None:
        dataloaders = globals()['dataloaders']
    if device is None:
        device = globals()['device']

    since = time.time()
    # state_dict() returns live references to the parameters, so snapshot a copy.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        for phase in ['train', 'val']:
            epoch_loss, epoch_acc = _run_phase(
                model, criterion, optimizer, dataloaders[phase], phase, device)
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    model.load_state_dict(best_model_wts)
    return model

# --- Start Training ---
# train_model hands back the same model object after reloading the weights from
# its best validation epoch.
# NOTE(review): the saved run used the CPU and was interrupted during epoch 0;
# expect this cell to be slow without a GPU.
model_ft = train_model(model, criterion, optimizer, NUM_EPOCHS)

Epoch 0/9
----------


KeyboardInterrupt: 