## Import Libraries

In [1]:
import torch
import torchvision
import torch.nn as nn
from tqdm import tqdm

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

## Load & Preprocessing Dataset
### Cat & Dog Dataset

In [3]:
vit_model = torchvision.models.ViT_B_16_Weights.DEFAULT
transform = vit_model.transforms()
transform

ImageClassification(
    crop_size=[224]
    resize_size=[256]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)

In [4]:
train_dataset = torchvision.datasets.ImageFolder("/kaggle/input/cat-and-dog/training_set/training_set", transform=transform)
test_dataset = torchvision.datasets.ImageFolder("/kaggle/input/cat-and-dog/test_set/test_set", transform=transform)

In [5]:
train = torch.utils.data.DataLoader(train_dataset, shuffle=True, batch_size=32)
test = torch.utils.data.DataLoader(test_dataset, shuffle=False, batch_size=32)

## VisionTransformer Transfer Learning
## ViT_B_16 Model

In [6]:
vit_model = torchvision.models.vit_b_16(weights="IMAGENET1K_V1")

Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:02<00:00, 168MB/s]  


In [7]:
for param in vit_model.parameters():
    param.requires_grad = False

vit_model.heads = nn.Linear(768, 2)
vit_model = vit_model.to(device)

In [8]:
optimizer = torch.optim.Adam(vit_model.parameters(),lr=0.001)
loss_function = nn.CrossEntropyLoss()

In [9]:
def accuracy(preds,labels):
    _,pred_max = torch.max(preds,1)
    acc = torch.sum(pred_max==labels,dtype=torch.float64) / len(preds)
    
    return acc

In [10]:
for epoch in range(10):
    train_loss = 0.0
    train_acc = 0.0
    test_loss = 0.0
    test_acc = 0.0
    for image, label in tqdm(train):
        image = image.to(device)
        label = label.to(device)
        optimizer.zero_grad()
        pred = vit_model(image)
        loss = loss_function(pred, label)
        loss.backward()
        optimizer.step()
        
        train_loss += loss
        train_acc += accuracy(pred, label)
        
    for image, label in test:
        image = image.to(device)
        label = label.to(device)
        pred = vit_model(image)
        loss = loss_function(pred, label)
        
        test_loss += loss
        test_acc += accuracy(pred, label)
        
    total_acc = train_acc / len(train)
    total_val_acc = test_acc / len(test)
    print(f"EPOCHS: {epoch+1}, ACC: {total_acc}, VAL_ACC: {total_val_acc}")

100%|██████████| 251/251 [01:29<00:00,  2.82it/s]


EPOCHS: 1, ACC: 0.9874252988047808, VAL_ACC: 0.99560546875


100%|██████████| 251/251 [00:54<00:00,  4.57it/s]


EPOCHS: 2, ACC: 0.9962649402390438, VAL_ACC: 0.99462890625


100%|██████████| 251/251 [00:55<00:00,  4.55it/s]


EPOCHS: 3, ACC: 0.9972609561752988, VAL_ACC: 0.994140625


100%|██████████| 251/251 [00:55<00:00,  4.56it/s]


EPOCHS: 4, ACC: 0.9977589641434262, VAL_ACC: 0.994140625


100%|██████████| 251/251 [00:54<00:00,  4.58it/s]


EPOCHS: 5, ACC: 0.9983814741035857, VAL_ACC: 0.9931640625


100%|██████████| 251/251 [00:55<00:00,  4.54it/s]


EPOCHS: 6, ACC: 0.9985059760956175, VAL_ACC: 0.99365234375


100%|██████████| 251/251 [00:55<00:00,  4.50it/s]


EPOCHS: 7, ACC: 0.9988794820717132, VAL_ACC: 0.99365234375


100%|██████████| 251/251 [00:54<00:00,  4.59it/s]


EPOCHS: 8, ACC: 0.9995019920318725, VAL_ACC: 0.994140625


100%|██████████| 251/251 [00:54<00:00,  4.59it/s]


EPOCHS: 9, ACC: 0.9995019920318725, VAL_ACC: 0.9931640625


100%|██████████| 251/251 [00:54<00:00,  4.59it/s]


EPOCHS: 10, ACC: 0.9998754980079682, VAL_ACC: 0.994140625


## Save Weights

In [11]:
torch.save(vit_model.state_dict(),"weights.pth")