In [1]:
#https://pytorch.org/vision/stable/models/generated/torchvision.models.vit_h_14.html

#ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1:

#These weights are learnt via transfer learning by end-to-end fine-tuning the original SWAG weights on ImageNet-1K data. Also available as ViT_H_14_Weights.DEFAULT.

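#The inference transforms (available as ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1.transforms)
#accept PIL.Image objects as well as batched (B, C, H, W) and single (C, H, W)
#image torch.Tensor objects, and perform the following preprocessing operations: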
#The images are resized to resize_size=[518] using interpolation=InterpolationMode.BICUBIC, 
#followed by a central crop of crop_size=[518].
#Finally the values are first rescaled to [0.0, 1.0] and then normalized using mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].
#      transforms=partial(
#             ImageClassification,
#             crop_size=518,
#             resize_size=518,
#             interpolation=InterpolationMode.BICUBIC,
#         )

import torchvision
import torch.nn as nn
import torch
import torch.optim as optim
import warnings
warnings.filterwarnings("ignore")

#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#print(f"Using device: {device}")

# Select the pretrained ViT-H/14 checkpoint and build the model from it.
vit_checkpoint = torchvision.models.ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1
print(f"model starting weights: {vit_checkpoint}")

model = torchvision.models.vit_h_14(weights=vit_checkpoint, progress=False)

# Freeze every pretrained parameter so that only the new classifier layer is
# fine-tuned. The model must NOT be re-created after this loop: a freshly
# built model comes back with requires_grad=True on all parameters, which
# silently undoes the freeze and makes autograd track the whole backbone.
for par in model.parameters():
    par.requires_grad = False

num_ftrs = model.heads[-1].in_features
print(f"num_ftrs : {num_ftrs}")

# Swap the final layer for a new single-output classifier. A newly constructed
# nn.Linear has requires_grad=True, so it is the only trainable part left.
num_output = 1
model.heads[-1] = torch.nn.Linear(num_ftrs, num_output)

model = nn.DataParallel(model)

# DataParallel wraps the network, so the classifier lives under `model.module`.
optimizer = optim.Adam(model.module.heads.parameters(), lr=1e-4)

# The network emits one raw logit per image; BCEWithLogitsLoss takes raw
# logits together with float targets in [0, 1].
criterion = nn.BCEWithLogitsLoss()
    



model starting weights: ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1
num_ftrs : 1280


In [2]:
def train_step(trn_dl, model, criterion, optimizer):
    """Run one training epoch and return the average per-sample loss.

    Args:
        trn_dl: Iterable yielding ``(x, t)`` batches, where ``x`` is the input
            tensor (batch on dim 0) and ``t`` the matching target tensor.
        model: Module called as ``model(x)``; it is switched to train mode.
        criterion: Loss callable invoked as ``criterion(y_hat, t)``. Its
            ``reduction`` attribute (``"mean"`` assumed when absent) decides
            how batch losses are accumulated.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called per batch.

    Returns:
        The loss averaged over all samples seen, or ``0.0`` when ``trn_dl``
        yields no batches.
    """
    model.train()
    reduction = getattr(criterion, "reduction", "mean")
    total_loss = 0.0
    total_samples = 0
    # Iterate over data.
    for i, (x, t) in enumerate(trn_dl):
        print(i, x.shape, t.shape)
        batch_size = x.size(0)
        # zero the parameter gradients
        optimizer.zero_grad()
        y_hat = model(x)
        # Drop only a trailing singleton dimension, (B, 1) -> (B,). A bare
        # torch.squeeze would also remove the batch dimension when B == 1 and
        # make the prediction shape disagree with the target shape.
        if y_hat.dim() > 1 and y_hat.size(-1) == 1:
            y_hat = y_hat.squeeze(-1)
        loss = criterion(y_hat, t.to(y_hat.device).float())

        # backward + optimize
        loss.backward()
        optimizer.step()

        # A mean-reduced loss is already divided by the batch size, so weight
        # it by the batch size before summing; a sum-reduced loss is added
        # as-is. Either way the final division yields a per-sample average.
        batch_loss = loss.item()
        total_loss += batch_loss if reduction == "sum" else batch_loss * batch_size
        total_samples += batch_size

    if total_samples == 0:
        # Empty loader: nothing was trained, avoid dividing by zero.
        return 0.0
    # Return the average training loss of this epoch
    return total_loss / total_samples



In [3]:
from torch.utils.data import Dataset, DataLoader
import numpy as np


class ImageDataset(Dataset):
    """Map-style dataset that pairs indexable images with indexable labels.

    Args:
        img_labels: Indexable container of labels; its length is the dataset
            length.
        img_images: Indexable container of images, addressed with the same
            index as ``img_labels``.
        transforms_fn: Optional callable applied to each image on access.
    """

    def __init__(self, img_labels, img_images, transforms_fn=None):
        self.img_labels, self.img_images = img_labels, img_images
        self.transforms_fn = transforms_fn

    def __len__(self):
        # The label container defines how many samples there are.
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at ``idx``, transformed if configured."""
        sample, target = self.img_images[idx], self.img_labels[idx]
        if not self.transforms_fn:
            return sample, target
        return self.transforms_fn(sample), target
    
### Generate a synthetic test dataset
NUMBER_IMAGES = 20
# Binary {0, 1} targets: the model has a single output logit trained with
# BCEWithLogitsLoss, which expects targets in [0, 1]. Class ids such as
# 1/2/3 would give a meaningless loss.
label_list = np.random.choice([0, 1], size=NUMBER_IMAGES).tolist()
# torch.randint excludes its upper bound, so 256 is needed to cover the full
# uint8 range 0..255.
images = torch.randint(0, 256, size=[NUMBER_IMAGES, 3, 518, 518], dtype=torch.uint8)
# The checkpoint's own preprocessing is applied to each image on access.
image_dataset = ImageDataset(label_list, images, transforms_fn=vit_checkpoint.transforms())
trn_dl = DataLoader(image_dataset, batch_size=64, shuffle=True)

In [4]:
# Release cached allocator blocks, then report the CUDA allocator counters
# for device 0 in GiB.
bytes_per_gib = 1024 ** 3
torch.cuda.empty_cache()
memory_report = (
    ("torch.cuda.memory_allocated: %fGB", torch.cuda.memory_allocated),
    ("torch.cuda.memory_reserved: %fGB", torch.cuda.memory_reserved),
    ("torch.cuda.max_memory_reserved: %fGB", torch.cuda.max_memory_reserved),
)
for template, query in memory_report:
    print(template % (query(0) / bytes_per_gib))


torch.cuda.memory_allocated: 0.000000GB
torch.cuda.memory_reserved: 0.000000GB
torch.cuda.max_memory_reserved: 0.000000GB


In [None]:
import os
import time

import psutil


def _rss_gb():
    """Return this process's resident set size in GB (1e9 bytes)."""
    return psutil.Process(os.getpid()).memory_info().rss / 1e9


torch.cuda.empty_cache()
# Move the model to the GPU only when one is visible, so the cell does not
# crash on a CPU-only machine.
if torch.cuda.is_available():
    model = model.cuda()
num_epochs = 10
trn_loss_list = []
val_loss_list = []

st = time.time()

for e in range(num_epochs):
    print(f'Epoch {e}/{num_epochs - 1}')
    print('-' * 10)
    print(f"Start of epoch Memory usage: {_rss_gb():0.2f}GB")
    # memory_reserved() is the current name of the deprecated memory_cached().
    print(f'Before forward pass - Cuda memory cached: {torch.cuda.memory_reserved()/1e9}')

    trn_loss = train_step(trn_dl, model, criterion, optimizer)
    trn_loss_list.append(trn_loss)
    # The two fields previously ran together ("Loss: 0.1234end epoch ...")
    # because adjacent f-strings concatenate without a separator.
    print(f'Loss: {trn_loss:.4f} | '
          f"end epoch Memory usage: {_rss_gb():0.2f}GB")
print(f"After Training Memory usage: {_rss_gb():0.2f}GB")

et = time.time()

# get the execution time
elapsed_time = et - st
print('Execution time:', elapsed_time, 'seconds')

Epoch 0/9
----------
Start of epoch Memory usage: 5.56GB
Before forward pass - Cuda memory cached: 2.671771648
0 torch.Size([20, 3, 518, 518]) torch.Size([20])


In [None]:
# Report the CUDA allocator counters for device 0 in GiB.
bytes_per_gib = 1024 ** 3
memory_report = (
    ("torch.cuda.memory_allocated: %fGB", torch.cuda.memory_allocated),
    ("torch.cuda.memory_reserved: %fGB", torch.cuda.memory_reserved),
    ("torch.cuda.max_memory_reserved: %fGB", torch.cuda.max_memory_reserved),
)
for template, query in memory_report:
    print(template % (query(0) / bytes_per_gib))