In [2]:
import torch
import torchvision
import matplotlib.pyplot as plt
import random
import numpy as np

from torch import nn
from torchvision import transforms

In [3]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
   device = 'cuda'
else:
   device = 'cpu'
print(device)

cuda


In [4]:
def set_seeds(seed=42):
   """Seed every random number generator this notebook relies on.

   Args:
      seed: Value passed to the torch (CPU and all CUDA devices), `random`
         and NumPy seeding functions. Defaults to 42.
   """
   # Same call order as before: torch, random, NumPy, then every CUDA device.
   for seed_rng in (torch.manual_seed, random.seed, np.random.seed, torch.cuda.manual_seed_all):
      seed_rng(seed)

   # cuDNN flags: deterministic mode on, benchmark mode off.
   cudnn = torch.backends.cudnn
   cudnn.deterministic = True
   cudnn.benchmark = False

In [5]:
# Load the default pretrained ViT-B/16 weights and build the model on `device`.
pre_trained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT

pretrained_vit = torchvision.models.vit_b_16(weights=pre_trained_vit_weights).to(device)

# Freeze every pretrained parameter. The head assigned below is a fresh module,
# so its parameters keep the default requires_grad=True and stay trainable.
for parameter in pretrained_vit.parameters():
   parameter.requires_grad = False

class_names = ['CBB', 'CBSD', 'CGM', 'CMD', 'Healthy']

# Seed right before creating the head so its random initialisation is repeatable.
set_seeds()
# Read the head's input width from the model (768 for ViT-B/16) instead of
# hard-coding it, so this cell keeps working if the backbone variant changes.
pretrained_vit.heads = nn.Linear(
   in_features=pretrained_vit.hidden_dim,
   out_features=len(class_names),
).to(device)

In [6]:
from torchinfo import summary

# Per-layer table of shapes, parameter counts and trainable flags for an
# input of shape (32, 3, 224, 224).
summary(
   pretrained_vit,
   input_size=(32, 3, 224, 224),
   col_names=["input_size", "output_size", "num_params", "trainable"],
   col_width=20,
   row_settings=["var_names"],
)

Layer (type (var_name))                                      Input Shape          Output Shape         Param #              Trainable
VisionTransformer (VisionTransformer)                        [32, 3, 224, 224]    [32, 5]              768                  Partial
├─Conv2d (conv_proj)                                         [32, 3, 224, 224]    [32, 768, 14, 14]    (590,592)            False
├─Encoder (encoder)                                          [32, 197, 768]       [32, 197, 768]       151,296              False
│    └─Dropout (dropout)                                     [32, 197, 768]       [32, 197, 768]       --                   --
│    └─Sequential (layers)                                   [32, 197, 768]       [32, 197, 768]       --                   False
│    │    └─EncoderBlock (encoder_layer_0)                   [32, 197, 768]       [32, 197, 768]       (7,087,872)          False
│    │    └─EncoderBlock (encoder_layer_1)                   [32, 197, 768]       [32, 

In [7]:
# Train / test image folders, given relative to this notebook's working directory.
train_dir = "../../datasets/vipr_dataset/masked_3/train"
test_dir = "../../datasets/vipr_dataset/masked_3/test"

In [8]:
# Reuse the preprocessing bundled with the pretrained weights. The printed
# preset shows resize_size=256, crop_size=224, and the mean/std used for
# normalisation.
pretrained_vit_transforms = pre_trained_vit_weights.transforms()
print(pretrained_vit_transforms)

ImageClassification(
    crop_size=[224]
    resize_size=[256]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)


In [9]:
import os

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# os.cpu_count() returns None when the CPU count cannot be determined. Fall back
# to 0 (DataLoader's "load in the main process" setting) so the value used as the
# default `num_workers` is always an int.
NUM_WORKERS = os.cpu_count() or 0

def create_dataloaders(
   train_dir: str,
   test_dir: str,
   transforms: transforms.Compose,
   batch_size: int,
   num_workers: int = NUM_WORKERS,
):
   """Build train and test DataLoaders from two image-folder directories.

   Both directories are read with `datasets.ImageFolder`, and the same
   `transforms` callable is applied to each dataset.

   Args:
      train_dir: Root directory of the training images.
      test_dir: Root directory of the test images.
      transforms: Transform passed to both ImageFolder datasets.
      batch_size: Batch size used by both loaders.
      num_workers: Worker count for both loaders. Defaults to NUM_WORKERS.

   Returns:
      A tuple `(train_dataloader, test_dataloader, class_names)`, where
      `class_names` is the `classes` attribute of the training dataset.
   """
   train_data = datasets.ImageFolder(root=train_dir, transform=transforms)
   test_data = datasets.ImageFolder(root=test_dir, transform=transforms)

   # Settings shared by both loaders; only the shuffling differs.
   loader_options = dict(
      batch_size=batch_size,
      num_workers=num_workers,
      pin_memory=True,
   )

   # Shuffle the training data only; the test loader keeps dataset order.
   train_dataloader = DataLoader(train_data, shuffle=True, **loader_options)
   test_dataloader = DataLoader(test_data, shuffle=False, **loader_options)

   return train_dataloader, test_dataloader, train_data.classes

In [10]:
# Build both loaders with the preprocessing that ships with the pretrained weights.
# This rebinds `class_names` to the classes found under `train_dir`.
train_dataloader_pretrained, test_dataloader_pretrained, class_names = create_dataloaders(
   train_dir=train_dir,
   test_dir=test_dir,
   transforms=pretrained_vit_transforms,
   batch_size=32,
)

# The classification head was sized before the data was loaded. Fail fast if the
# dataset's class count disagrees with the head's output width.
assert len(class_names) == pretrained_vit.heads.out_features, (
   f"Dataset has {len(class_names)} classes {class_names}, but the model head "
   f"outputs {pretrained_vit.heads.out_features}"
)

In [11]:
from going_modular.going_modular import engine

# Cross-entropy loss; the optimizer only receives the parameters of the new head.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(pretrained_vit.heads.parameters(), lr=1e-3)

# Re-seed immediately before training.
set_seeds()

# NOTE(review): `pretrained_vit` is rebound to whatever `engine.train` returns.
# If that is a metrics/results object rather than the model, the model reference
# is lost after this cell. Confirm against going_modular/engine.py.
pretrained_vit = engine.train(
   model=pretrained_vit,
   train_dataloader=train_dataloader_pretrained,
   test_dataloader=test_dataloader_pretrained,
   optimizer=optimizer,
   loss_fn=loss_fn,
   epochs=10,
   device=device,
)

  0%|          | 0/10 [04:01<?, ?it/s]


KeyboardInterrupt: 