In [23]:
print("Hello")

Hello


In [24]:
from transformers import ViTImageProcessor, ViTModel, AutoImageProcessor
from PIL import Image
import requests

url = 'https://lumiere-a.akamaihd.net/v1/images/darth-vader-main_4560aff7.jpeg?region=71%2C0%2C1139%2C854'
image = Image.open(requests.get(url, stream=True).raw)

processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
inputs = processor(images=image, return_tensors="pt")

outputs = model(**inputs)
last_hidden_states = outputs.last_hidden_state


In [25]:
# split into train and test, take 10000 train samples and 1000 test samples
# randomly shuffle the data list
import random
import torch


In [28]:
train_split = torch.load('train_split_mnist.pt')
test_split = train_split[:1000]
train_split = train_split[1000:]

In [29]:
print(len(train_split), len(test_split))

9000 1000


In [34]:
from torch import nn, optim
class ptuned_VIT(nn.Module):
    def __init__(self, num_classes):
        super(ptuned_VIT, self).__init__()
        self.model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        self.embeddings = nn.Embedding(3, 768)
        self.classification_layer = nn.Linear(768, num_classes)
        self.softmax = nn.Softmax(dim=1)
        self.trainable = [self.embeddings, self.classification_layer]
        for param in self.model.parameters():
            param.requires_grad = False
    def forward(self, x):
        tens = [0,1,2]
        tens = torch.tensor(tens)
        tens = tens.to(device)
        x1 = self.embeddings(tens)
        x2 = self.model.embeddings(x)
        # change x1 shape from x, 768 to 1, x, 768
        x1 = x1.unsqueeze(0)
        # concat x1 and x2 along the first dimension
        # x1 is 1, 3, 768, make is x2.shape[0], 3, 768
        x1 = x1.expand(x2.shape[0], -1, -1)
        x = torch.cat((x1, x2), 1)
        x = self.model.encoder(x)
        x = x['last_hidden_state']
        x = self.model.pooler(x)
        x = self.classification_layer(x)
        x = self.softmax(x)
        return x
        
classes = 10
p_tokens = 3
import torch
CUDA_LAUNCH_BLOCKING=1
from torch import nn, optim
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# use data loader
train_loader = torch.utils.data.DataLoader(train_split, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_split, batch_size=32, shuffle=True)




In [35]:
print(len(train_loader))

282


In [36]:
# use corss entropy loss
# import accuracy and f1 from sklearn
model = ptuned_VIT(classes)
model = model.to(device)
from sklearn.metrics import accuracy_score, f1_score
loss_func = nn.CrossEntropyLoss()
# use adam optimizer
optimizer = optim.Adam(model.parameters(),lr=0.001)
acc_list = []
n_epochs = 10
fin_acc = []
for epoch in range(n_epochs):
    batch_no = 0
    for sample in train_loader:
        image = sample[0]
        label = sample[1]
        image = image.to(device)
        label = label.to(device)
        # iamge shape is a, 1, b, c, d make it a, b, c, d
        image = image.squeeze(1)
        output = model(image)
        loss = loss_func(output, label)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # calculate accuracy
        pred = torch.argmax(output, dim=1)
        acc = accuracy_score(label.cpu(), pred.cpu())
        acc_list.append(acc)
        running_acc = sum(acc_list)/len(acc_list)
        fin_acc.append(acc)
        print(f"Epoch: {epoch}, Batch: {batch_no}/313, Loss: {loss.item()}, Running Accuracy: {running_acc}, Current Accuracy {acc}")
        batch_no += 1
    # write testing loop using test_loader
    acc_list = []
    for sample in test_loader:
        image = sample[0]
        label = sample[1]
        image = image.to(device)
        label = label.to(device)
        image = image.squeeze(1)
        output = model(image)
        pred = torch.argmax(output, dim=1)
        acc = accuracy_score(label.cpu(), pred.cpu())
        acc_list.append(acc)
        running_acc = sum(acc_list)/len(acc_list)
        print(f"Epoch: {epoch}, Test Accuracy: {running_acc}")
    torch.save(model, "pt_vit_mnist_model.pt")
    torch.save(fin_acc, "pt_vit_mnist_acc.pt")



Epoch: 0, Batch: 0/313, Loss: 2.3010013103485107, Running Accuracy: 0.0625, Current Accuracy 0.0625
Epoch: 0, Batch: 1/313, Loss: 2.2837631702423096, Running Accuracy: 0.140625, Current Accuracy 0.21875
Epoch: 0, Batch: 2/313, Loss: 2.2730467319488525, Running Accuracy: 0.16666666666666666, Current Accuracy 0.21875
Epoch: 0, Batch: 3/313, Loss: 2.2582132816314697, Running Accuracy: 0.203125, Current Accuracy 0.3125
Epoch: 0, Batch: 4/313, Loss: 2.290294647216797, Running Accuracy: 0.1875, Current Accuracy 0.125
Epoch: 0, Batch: 5/313, Loss: 2.208986759185791, Running Accuracy: 0.20833333333333334, Current Accuracy 0.3125
Epoch: 0, Batch: 6/313, Loss: 2.2569849491119385, Running Accuracy: 0.20982142857142858, Current Accuracy 0.21875
Epoch: 0, Batch: 7/313, Loss: 2.2954745292663574, Running Accuracy: 0.19921875, Current Accuracy 0.125
Epoch: 0, Batch: 8/313, Loss: 2.2345309257507324, Running Accuracy: 0.2013888888888889, Current Accuracy 0.21875
Epoch: 0, Batch: 9/313, Loss: 2.263293027

KeyboardInterrupt: 

In [38]:
model_new = torch.load("pt_vit_mnist_model.pt")
print(model_new)

ptuned_VIT(
  (model): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_featu

In [39]:
new_list = torch.load("pt_vit_mnist_acc.pt")
print(new_list)

[0.0625, 0.21875, 0.21875, 0.3125, 0.125, 0.3125, 0.21875, 0.125, 0.21875, 0.25, 0.1875, 0.34375, 0.40625, 0.34375, 0.25, 0.34375, 0.53125, 0.40625, 0.5, 0.53125, 0.5625, 0.4375, 0.46875, 0.5, 0.5, 0.625, 0.625, 0.5625, 0.53125, 0.53125, 0.40625, 0.53125, 0.5625, 0.53125, 0.5625, 0.59375, 0.34375, 0.5, 0.53125, 0.5625, 0.6875, 0.5, 0.5, 0.65625, 0.6875, 0.53125, 0.5625, 0.40625, 0.65625, 0.5625, 0.78125, 0.59375, 0.59375, 0.75, 0.625, 0.625, 0.6875, 0.65625, 0.71875, 0.71875, 0.65625, 0.5625, 0.75, 0.6875, 0.625, 0.625, 0.71875, 0.6875, 0.59375, 0.3125, 0.59375, 0.6875, 0.6875, 0.65625, 0.4375, 0.65625, 0.5625, 0.59375, 0.5625, 0.75, 0.6875, 0.65625, 0.6875, 0.65625, 0.875, 0.6875, 0.5625, 0.625, 0.75, 0.6875, 0.59375, 0.6875, 0.5625, 0.78125, 0.84375, 0.625, 0.8125, 0.65625, 0.6875, 0.5, 0.8125, 0.6875, 0.71875, 0.75, 0.84375, 0.75, 0.6875, 0.6875, 0.8125, 0.78125, 0.75, 0.78125, 0.8125, 0.8125, 0.78125, 0.84375, 0.59375, 0.78125, 0.6875, 0.75, 0.71875, 0.75, 0.71875, 0.8125, 0.84375,