In [16]:
import os
import sys

import torch
import torchvision
import torchmetrics

from torchinfo import summary

from tqdm.auto import tqdm

from pathlib import Path
from typing import Dict, List

super_directory = os.path.abspath('..')
sys.path.append(super_directory)

from data_setup import data_download, get_dataloaders
from vit import ViT
from utils import create_writer

In [17]:
# Device agnostic code: default to the CPU and switch to CUDA only when torch can see a GPU
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [18]:
# Hyperparameters
COLOR_CHANNELS = 3
HEIGHT_WIDTH = (224, 224)              # resized to 224 in get_dataloaders function

BATCH_SIZE = 32

PATCH_SIZE = (16, 16)
# Patches per image = (patch rows) * (patch columns). Integer division keeps the
# value an exact int and the rows/cols product stays correct for non-square
# images or patches (the squared-height form only holds for the square case).
NUM_PATCHES = (HEIGHT_WIDTH[0] // PATCH_SIZE[0]) * (HEIGHT_WIDTH[1] // PATCH_SIZE[1])

EMBED_DIMS = 768
NUM_ATTN_HEADS = 12
RATIO_HIDDEN_MLP = 4
NUM_ENC_BLOCKS = 12

# NOTE: with 0 epochs a loop over range(NUM_EPOCHS) never executes, so no training
# happens; raise this value to actually train.
NUM_EPOCHS = 0

# Data

In [19]:
# Download the data only when the data directory does not exist yet
data_path = Path('..', 'data')

if not data_path.is_dir():
    data_download()

In [20]:
# Dataloaders built from the desserts train/test directories under data_path
train_path = data_path.joinpath('desserts', 'train')
test_path = data_path.joinpath('desserts', 'test')

train_dataloader, test_dataloader, class_labels = get_dataloaders(
    train_path=train_path,
    test_path=test_path,
    batch_size=BATCH_SIZE,
)

# Model

In [21]:
# Instantiate the ViT that is trained from scratch and verify its structure.
# Every shape-related argument comes from the hyperparameter constants so the
# model and the summary input cannot drift apart from the configuration.
model_train_fs = ViT(in_channels=COLOR_CHANNELS,
                     out_dims=len(class_labels),
                     patch_size=PATCH_SIZE,
                     num_patches=NUM_PATCHES,
                     embed_dims=EMBED_DIMS,
                     num_attn_heads=NUM_ATTN_HEADS,
                     ratio_hidden_mlp=RATIO_HIDDEN_MLP,
                     num_encoder_blocks=NUM_ENC_BLOCKS)

summary(model_train_fs,
        input_size=(BATCH_SIZE, COLOR_CHANNELS, *HEIGHT_WIDTH),   # Batch dim, color channels, height, width
        col_names=['input_size', 'output_size', 'num_params', 'trainable'],
        col_width=20,
        row_settings=['var_names'])

Layer (type (var_name))                                 Input Shape          Output Shape         Param #              Trainable
ViT (ViT)                                               [32, 3, 224, 224]    [32, 5]              --                   True
├─DataEmbeddings (data_embeddings)                      [32, 3, 224, 224]    [32, 197, 768]       152,064              True
│    └─Conv2d (conv_layer)                              [32, 3, 224, 224]    [32, 768, 14, 14]    590,592              True
│    └─Flatten (flatten)                                [32, 768, 14, 14]    [32, 768, 196]       --                   --
├─Sequential (encoder_blocks)                           [32, 197, 768]       [32, 197, 768]       --                   True
│    └─EncoderBlock (0)                                 [32, 197, 768]       [32, 197, 768]       --                   True
│    │    └─LayerNorm (layer_norm)                      [32, 197, 768]       [32, 197, 768]       1,536                True
│    

In [22]:
# ViT-B/16 from torchvision, loaded with its default weights
weights = torchvision.models.ViT_B_16_Weights.DEFAULT
model_finetune = torchvision.models.vit_b_16(weights=weights)

summary(
    model_finetune,
    input_size=(32, 3, 224, 224),   # (batch, color channels, height, width)
    col_names=['input_size', 'output_size', 'num_params', 'trainable'],
    col_width=20,
    row_settings=['var_names'],
)

Layer (type (var_name))                                      Input Shape          Output Shape         Param #              Trainable
VisionTransformer (VisionTransformer)                        [32, 3, 224, 224]    [32, 1000]           768                  True
├─Conv2d (conv_proj)                                         [32, 3, 224, 224]    [32, 768, 14, 14]    590,592              True
├─Encoder (encoder)                                          [32, 197, 768]       [32, 197, 768]       151,296              True
│    └─Dropout (dropout)                                     [32, 197, 768]       [32, 197, 768]       --                   --
│    └─Sequential (layers)                                   [32, 197, 768]       [32, 197, 768]       --                   True
│    │    └─EncoderBlock (encoder_layer_0)                   [32, 197, 768]       [32, 197, 768]       7,087,872            True
│    │    └─EncoderBlock (encoder_layer_1)                   [32, 197, 768]       [32, 197, 76

In [23]:
# Swap the classification head for a linear layer sized to our class labels
model_finetune.heads = torch.nn.Linear(
    in_features=EMBED_DIMS,
    out_features=len(class_labels),
)

# Freeze the patch projection and the encoder; every parameter in these two
# submodules stops receiving gradients
model_finetune.conv_proj.requires_grad_(False)
model_finetune.encoder.requires_grad_(False)

summary(
    model_finetune,
    input_size=(32, 3, 224, 224),   # (batch, color channels, height, width)
    col_names=['input_size', 'output_size', 'num_params', 'trainable'],
    col_width=20,
    row_settings=['var_names'],
)

Layer (type (var_name))                                      Input Shape          Output Shape         Param #              Trainable
VisionTransformer (VisionTransformer)                        [32, 3, 224, 224]    [32, 5]              768                  Partial
├─Conv2d (conv_proj)                                         [32, 3, 224, 224]    [32, 768, 14, 14]    (590,592)            False
├─Encoder (encoder)                                          [32, 197, 768]       [32, 197, 768]       151,296              False
│    └─Dropout (dropout)                                     [32, 197, 768]       [32, 197, 768]       --                   --
│    └─Sequential (layers)                                   [32, 197, 768]       [32, 197, 768]       --                   False
│    │    └─EncoderBlock (encoder_layer_0)                   [32, 197, 768]       [32, 197, 768]       (7,087,872)          False
│    │    └─EncoderBlock (encoder_layer_1)                   [32, 197, 768]       [32, 

# Train

In [24]:
# Loss, optimizer and accuracy metric for the from-scratch model
loss_function = torch.nn.CrossEntropyLoss()

# NOTE(review): lr=0.01 looks high for Adam on a transformer - confirm it is intended
optimizer = torch.optim.Adam(model_train_fs.parameters(), lr=0.01)

accuracy_function = torchmetrics.Accuracy(
    task='multiclass',
    num_classes=len(class_labels),
)

In [25]:
def train_epoch(model: torch.nn.Module,
                train_dataloader: torch.utils.data.DataLoader,
                loss_function: torch.nn.Module,
                optimizer: torch.optim.Optimizer,
                accuracy_function: "torchmetrics.Accuracy",
                device: torch.device) -> Dict[str, float]:
    """Train ``model`` for one pass over ``train_dataloader``.

    Args:
        model: Network to optimise; it is moved to ``device`` and put in train mode.
        train_dataloader: Sized iterable yielding ``(X, y)`` batches.
        loss_function: Called as ``loss_function(logits, y)``.
        optimizer: Optimizer that updates the parameters of ``model``.
        accuracy_function: Called as ``accuracy_function(predicted_classes, y)``;
            its return value is used as the accuracy of that batch.
        device: Device the model, the metric and every batch are moved to.

    Returns:
        ``{'train_loss': ..., 'train_acc': ...}`` as Python floats, each being the
        mean of the per-batch values.

    Raises:
        ZeroDivisionError: If ``train_dataloader`` has length 0.
    """
    # Model to device
    model = model.to(device)

    # Model to train mode
    model.train()

    # The metric only has to be moved once, not once per batch
    accuracy_function = accuracy_function.to(device)

    # Running sums are plain floats: adding the loss tensor itself would keep
    # autograd history alive across batches
    train_loss = 0.0
    train_acc = 0.0

    for X, y in train_dataloader:
        X = X.to(device)
        y = y.to(device)

        # Forward pass -> loss -> zero grad -> back prop -> gradient descent
        y_logits = model(X)
        loss = loss_function(y_logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accuracy: argmax over the class dimension of the raw logits (softmax is
        # monotonic, so it does not change the winner). No squeeze, so a batch
        # with a single sample keeps the same shape as its targets.
        y_preds = torch.argmax(y_logits, dim=1)
        accuracy = accuracy_function(y_preds, y)

        # Accumulate
        train_loss += loss.item()
        train_acc += float(accuracy)

    # Average per batch
    num_batches = len(train_dataloader)
    return {'train_loss': train_loss / num_batches,
            'train_acc': train_acc / num_batches}

In [26]:
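# Optional smoke test of a single training epoch (left disabled).
# Uses model_train_fs because no variable named `model` is defined in this notebook.
# train_results = train_epoch(model=model_train_fs,
#                             train_dataloader=train_dataloader,
#                             loss_function=loss_function,
#                             optimizer=optimizer,
#                             accuracy_function=accuracy_function,
#                             device=device)
# print(train_results)
# torch.cuda.empty_cache()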

In [27]:
def test_epoch(model: torch.nn.Module,
               test_dataloader: torch.utils.data.DataLoader,
               loss_function: torch.nn.Module,
               accuracy_function: torchmetrics.Accuracy,
               device: torch.device) -> Dict[str, float]:
    # Model to device
    model = model.to(device)

    # Set model to evaluation mode
    model = model.eval()

    # Track avg loss
    test_loss = 0
    test_acc = 0

    for X, y in test_dataloader:
        X = X.to(device)
        y = y.to(device)

        # With inference to save cuda memory
        with torch.inference_mode():
            # Loss
            y_logits = model(X)
            loss = loss_function(y_logits, y)
            
            # Accuracy
            accuracy_function = accuracy_function.to(device)
            y_preds = torch.argmax(torch.softmax(y_logits, dim=1), dim=1).squeeze()
            accuracy = accuracy_function(y_preds, y)
            
            # Accumulate
            test_loss += loss
            test_acc += accuracy

    # Average per batch
    with torch.inference_mode():
        test_loss /= len(test_dataloader)
        test_acc /= len(test_dataloader)
    
    return {'test_loss': test_loss.item(),
            'test_acc': test_acc.item()}

In [28]:
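# Optional smoke test of a single evaluation pass (left disabled).
# Uses model_train_fs because no variable named `model` is defined in this notebook.
# test_results = test_epoch(model=model_train_fs,
#                           test_dataloader=test_dataloader,
#                           loss_function=loss_function,
#                           accuracy_function=accuracy_function,
#                           device=device)
# print(test_results)
# torch.cuda.empty_cache()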

In [29]:
def train(num_epochs: int,
          model: torch.nn.Module,
          train_dataloader: torch.utils.data.DataLoader,
          test_dataloader: torch.utils.data.DataLoader,
          loss_function: torch.nn.Module,
          optimizer: torch.optim.Optimizer,
          accuracy_function: torchmetrics.Accuracy,
          device: torch.device,
          writer: "torch.utils.tensorboard.writer.SummaryWriter") -> Dict[str, List[float]]:
    """Train and evaluate ``model`` for ``num_epochs`` epochs, logging each epoch.

    Every epoch runs ``train_epoch`` followed by ``test_epoch``, prints the four
    metrics, appends them to the returned history and writes them to ``writer``
    under the "Loss" and "Accuracy" tags with the epoch index as the step.

    Args:
        num_epochs: Number of epochs to run; with 0 the loop body never executes.
        model: Network handed to ``train_epoch`` and ``test_epoch``.
        train_dataloader: Batches used for training.
        test_dataloader: Batches used for evaluation.
        loss_function: Loss passed through to both epoch functions.
        optimizer: Optimizer passed through to ``train_epoch``.
        accuracy_function: Metric passed through to both epoch functions.
        device: Device passed through to both epoch functions.
        writer: Object providing ``add_scalars`` and ``close``. The annotation is
            quoted so that defining this function does not require
            ``torch.utils.tensorboard`` to have been imported beforehand.

    Returns:
        Dict with keys ``'train_loss'``, ``'train_acc'``, ``'test_loss'`` and
        ``'test_acc'``, each a list holding one float per completed epoch.
    """
    results = {
        'train_loss': [],
        'train_acc': [],
        'test_loss': [],
        'test_acc': []
    }

    try:
        for epoch in tqdm(range(num_epochs)):
            print("-"*50 + "\n")

            # Train for one epoch
            train_result = train_epoch(model=model,
                                       train_dataloader=train_dataloader,
                                       loss_function=loss_function,
                                       optimizer=optimizer,
                                       accuracy_function=accuracy_function,
                                       device=device)

            # Do testing after one epoch
            test_result = test_epoch(model=model,
                                     test_dataloader=test_dataloader,
                                     loss_function=loss_function,
                                     accuracy_function=accuracy_function,
                                     device=device)

            # Print results
            print(f"Epoch: {epoch}  |  Train Loss: {train_result['train_loss']:.2f}  |  Test Loss: {test_result['test_loss']:.2f}  |  Train Accuracy: {train_result['train_acc']:.2f}  |  Test Accuracy: {test_result['test_acc']:.2f}")

            # Track results
            results['train_loss'].append(train_result['train_loss'])
            results['train_acc'].append(train_result['train_acc'])
            results['test_loss'].append(test_result['test_loss'])
            results['test_acc'].append(test_result['test_acc'])

            # Using tensorboard writer for result tracking
            writer.add_scalars(main_tag="Loss",
                               tag_scalar_dict={'train_loss': train_result['train_loss'],
                                                'test_loss': test_result['test_loss']},
                               global_step=epoch)

            writer.add_scalars(main_tag="Accuracy",
                               tag_scalar_dict={'train_acc': train_result['train_acc'],
                                                'test_acc': test_result['test_acc']},
                               global_step=epoch)

            # Empty cuda cache for memory management
            torch.cuda.empty_cache()
    finally:
        # Close the writer even when an epoch raises, so the epochs that were
        # already logged are not left in an unclosed writer
        writer.close()

    return results

In [30]:
# Writer for this run of the custom ViT
writer = create_writer(
    model_name='custom_vit',
    experiment_name='test',
)

# Train the from-scratch model for NUM_EPOCHS epochs and keep the metric history
results = train(
    num_epochs=NUM_EPOCHS,
    model=model_train_fs,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    loss_function=loss_function,
    optimizer=optimizer,
    accuracy_function=accuracy_function,
    device=device,
    writer=writer,
)

[INFO] Created SummaryWriter, saving to: runs/24-11-25/custom_vit/test...


0it [00:00, ?it/s]