## Implementing the Vision Transformer

In [1]:
#imports
import torch
import torch.nn as nn
from torch.optim import Adam


import torchvision as tv
from torchvision import datasets, models, transforms

import numpy as np
import matplotlib.pyplot as plt

import pandas as pd
import time
import os
import copy
import requests
import io

import timm 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Execute the companion notebook `torch.ipynb` (located next to this one) so
# that the names it defines are loaded into this notebook's namespace.
# NOTE(review): presumably this is where `loader_train` / `loader_val` (passed
# to the training call further down) come from -- they are not defined
# anywhere in this notebook; confirm against torch.ipynb.
%run ./torch.ipynb

In [3]:
# Choose where the model and tensors will live: the current CUDA device when
# PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using device: ", device, f"({torch.cuda.get_device_name(device)})")
else:
    device = torch.device("cpu")
    # The empty third argument keeps the printed line identical to the
    # single-statement form (the line ends with a trailing space).
    print("Using device: ", device, "")

# Disabled alternative, kept for reference: try Apple's MPS backend first and
# fall back to the CUDA/CPU selection above when it is unavailable.
# if torch.backends.mps.is_available():
#    device = torch.device("mps")
#    x = torch.ones(1, device=device)
#    print (x)
# else:
#    print ("MPS device not found.")
#    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#    print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")

Using device:  cpu 


In [4]:
# Number of output logits the replacement classification head will produce.
num_classes = 15

# Build the pretrained ViT-Base backbone (16x16 patches, 224x224 input).
model = timm.create_model('vit_base_patch16_224', pretrained=True)

# Replace the pretrained classification head with a freshly initialised linear
# layer of the same input width but `num_classes` outputs.
model.head = nn.Linear(
    in_features=model.head.in_features,
    out_features=num_classes,
)

# Move the model to the selected device. This is left as the bare last
# expression so the notebook renders the module summary as the cell output.
model.to(device)

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity(

In [5]:
# Adam over every model parameter (backbone and new head) with a small
# learning rate of 1e-5.
optimizer = Adam(params=model.parameters(), lr=1e-5)

# Binary cross-entropy on raw logits: the sigmoid is applied inside the loss,
# so the model's outputs are passed to it un-activated.
criterion = nn.BCEWithLogitsLoss()

In [6]:
def train_model(model, criterion, optimizer, loader_train, loader_val, num_epochs=10):
    """Train `model` for `num_epochs` epochs, validating after each one.

    For every epoch the function iterates once over `loader_train`, performs a
    forward/backward pass and an optimizer step per batch, prints the
    sample-weighted mean training loss, and then calls `validate_model` on
    `loader_val`.

    Args:
        model: The `torch.nn.Module` to optimise. Batches are moved to the
            device of its first parameter (CPU if it has no parameters).
        criterion: Loss callable invoked as `criterion(outputs, labels)`.
        optimizer: Optimizer holding the parameters of `model`.
        loader_train: Iterable yielding `(inputs, labels)` tensor pairs.
        loader_val: Iterable yielding `(inputs, labels)` tensor pairs, handed
            to `validate_model` unchanged.
        num_epochs: Number of passes over `loader_train`.

    Returns:
        None. Progress is reported with `print`.
    """
    # Use the device the model actually lives on rather than a notebook-level
    # global, so the function does not depend on hidden kernel state.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    for epoch in range(num_epochs):
        # validate_model() switches the model to eval mode at the end of every
        # epoch; training mode must therefore be re-enabled at the start of
        # each epoch, not just once before the loop.
        model.train()

        running_loss = 0.0
        samples_seen = 0
        for inputs, labels in loader_train:
            inputs = inputs.to(target_device, dtype=torch.float32)
            labels = labels.to(target_device, dtype=torch.float32)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward and optimize
            loss.backward()
            optimizer.step()

            # Weight the batch loss by batch size so that a smaller final
            # batch does not skew the epoch average.
            # NOTE(review): assumes `criterion` returns a batch-mean loss.
            batch_size = inputs.size(0)
            running_loss += loss.item() * batch_size
            samples_seen += batch_size

        # Divide by the samples actually processed: len(loader.dataset) is
        # wrong with samplers / drop_last and unavailable for plain iterables.
        epoch_loss = running_loss / samples_seen if samples_seen else float('nan')
        print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}')

        # Validation phase
        validate_model(model, loader_val)

def validate_model(model, loader_val, threshold=0.5):
    """Report the per-label accuracy of `model` on `loader_val`.

    Each output logit is passed through a sigmoid and compared against
    `threshold` to obtain a binary prediction; accuracy is the fraction of
    individual label entries (not whole samples) predicted correctly.

    The model is switched to eval mode and is left in eval mode on return, so
    a caller that continues training must call `model.train()` again.

    Args:
        model: The `torch.nn.Module` to evaluate. Batches are moved to the
            device of its first parameter (CPU if it has no parameters).
        loader_val: Iterable yielding `(inputs, labels)` tensor pairs.
        threshold: Probability above which a label counts as predicted.

    Returns:
        float: Accuracy as a percentage in [0, 100], or NaN when `loader_val`
        yields no label entries. The value is also printed.
    """
    # Use the device the model actually lives on rather than a notebook-level
    # global, so the function does not depend on hidden kernel state.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    total_samples = 0
    total_correct = 0

    with torch.no_grad():
        for inputs, labels in loader_val:
            inputs = inputs.to(target_device, dtype=torch.float32)
            labels = labels.to(target_device, dtype=torch.float32)
            outputs = model(inputs)
            # Apply sigmoid and threshold to convert logits to binary predictions
            predicted = outputs.sigmoid() > threshold
            # Correct predictions, counted per individual label entry
            total_correct += (predicted == labels.byte()).sum().item()
            # Total number of label entries seen
            total_samples += labels.numel()

    # An empty loader used to raise ZeroDivisionError; report it instead.
    if total_samples == 0:
        print('Validation Accuracy: n/a (no validation samples)')
        return float('nan')

    accuracy = total_correct / total_samples * 100
    print(f'Validation Accuracy: {accuracy:.2f}%')
    return accuracy


# Start training with the default number of epochs.
# `loader_train` and `loader_val` must already exist in the kernel namespace
# before this cell runs.
train_model(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loader_train=loader_train,
    loader_val=loader_val,
)

KeyboardInterrupt: 