In [1]:
import torch

from data_setup import create_dataloaders
from engine import train_model
from pathlib import Path
from torch import nn
from torchmetrics import Accuracy
from torchvision.models import ViT_B_16_Weights, vit_b_16
from utils import create_writer

# Pick the accelerator. `torch.cuda.is_available` is a function and must be
# *called*: the bare function object is always truthy, which would select
# 'cuda' even on machines without a GPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE(review): DEVICE is not referenced in the visible cells; presumably
# `engine.train_model` handles device placement itself -- confirm.

# Preprocessing transform shipped with the pretrained weights. Despite the
# name, this is an image transform pipeline, not a transformer model.
TRANSFORMERS = ViT_B_16_Weights.DEFAULT.transforms()

# Pretrained ViT-B/16 used as a frozen feature extractor. Freeze *every*
# pretrained parameter: freezing only `conv_proj` and `encoder` left the
# pretrained `class_token` trainable, so it would drift (and be decayed by the
# optimizer's weight decay). The classification head is replaced later with a
# fresh, trainable layer sized to the dataset's classes.
MODEL = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
MODEL.requires_grad_(False)

In [2]:
# Dataset location, relative to the notebook's working directory. It is
# expected to contain one sub-folder per split: `train/` and `test/`.
imgs_path = Path('../data/restrc-oxford-iiit-pet/')
train_dir, test_dir = (imgs_path / split for split in ('train', 'test'))

In [None]:
# Train a new classification head on top of the pretrained ViT-B/16 backbone.
EPOCHS = 5
SEED = 42

# `create_dataloaders` also returns the class names; their count sizes the new
# head and the accuracy metric.
train_dataloader, test_dataloader, class_names = create_dataloaders(
    train_dir=train_dir,
    test_dir=test_dir,
    train_transform=TRANSFORMERS
)
# NOTE(review): only `train_transform` is passed. If `create_dataloaders` takes
# a separate test-transform argument, the test split presumably also needs the
# pretrained-weights preprocessing (TRANSFORMERS) -- confirm against data_setup.

# Seed before building the head: `train_model` is given `random_seed`, but the
# head's weights are initialised here, before that call, so without this the
# starting weights would differ between runs.
torch.manual_seed(SEED)

# Replace the classifier with a fresh (trainable) linear layer. The input width
# is read from the backbone instead of the hard-coded 768 of ViT-B/16.
MODEL.heads = nn.Sequential(
    nn.Linear(in_features=MODEL.hidden_dim, out_features=len(class_names), bias=True)
)

# Hand only trainable parameters to the optimizer so frozen backbone weights
# are never part of the update set.
optimizer = torch.optim.Adam((p for p in MODEL.parameters() if p.requires_grad),
                             lr=1e-3,
                             betas=(0.9, 0.999),
                             weight_decay=0.1)
loss_fn = nn.CrossEntropyLoss()
metric_fn = Accuracy(task='multiclass', num_classes=len(class_names))
experiment_results_1 = train_model(model=MODEL,
                                   train_data=train_dataloader,
                                   test_data=test_dataloader,
                                   loss_fn=loss_fn,
                                   optimizer=optimizer,
                                   metric_fn=metric_fn,
                                   epochs=EPOCHS,
                                   random_seed=SEED,
                                   verbose=1,
                                   # The run tag is derived from EPOCHS so it
                                   # cannot drift from the actual epoch count.
                                   writer=create_writer('cat-dog-b-exp_1',
                                                        model_name='cat-dog-b',
                                                        extra=f'{EPOCHS}_epochs'))
experiment_results_1

  0%|          | 0/5 [00:00<?, ?it/s]