In [2]:
# Install transformers from the default branch of its GitHub repository.
# NOTE(review): this is unpinned, so a re-run may pick up a different upstream
# version (and different ViTFeatureExtractor behaviour). Consider pinning a
# release, e.g. `%pip install -q transformers==<version>`.
!pip install -q git+https://github.com/huggingface/transformers

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
[K     |████████████████████████████████| 636 kB 8.3 MB/s 
[K     |████████████████████████████████| 3.3 MB 44.3 MB/s 
[K     |████████████████████████████████| 895 kB 60.5 MB/s 
[K     |████████████████████████████████| 56 kB 6.0 MB/s 
[?25h  Building wheel for transformers (PEP 517) ... [?25l[?25hdone


In [3]:

# Imports
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split



# --- dataset / optimisation config ---
batch_size = 32  # images per DataLoader batch
LR = 0.001  # NOTE(review): not referenced elsewhere in this notebook; the optimizer uses LEARNING_RATE
# Seeded RNG handed to random_split so the train/valid split is reproducible.
generator = torch.Generator().manual_seed(42)

In [4]:

import numpy as np
# Pick the training device: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [30]:
# Per-channel mean / std used by the CIFAR-style pipelines below.
# NOTE(review): these are CIFAR-10 statistics; they are also applied to the
# other datasets that use transform_train / transform_test.
_CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR_STD = (0.2023, 0.1994, 0.2010)

# Generic normalisation with mean = std = 0.5 on every channel.
# NOTE(review): `_transform` is not referenced elsewhere in this notebook.
_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,) * 3, (0.5,) * 3),
])

# Training pipeline: random 32x32 crop (4px padding) + random horizontal
# flip, then tensor conversion and normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
])

# Evaluation pipeline: no augmentation, only tensor conversion + normalisation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
])

# Class-count placeholder (0 = not determined yet).
NUM_CLASSES = 0

def getTrainingSet(dataset_name):
  if dataset_name == 'CIFAR-10':
    print("in cifar-10")
    NUM_CLASSES=10

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                          download=True, transform=transform_train)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                          download=True, transform=transform_test)
    
    
    

    trainset, validset = torch.utils.data.random_split(trainset, 
                                                      [int(len(trainset)*0.8),len(trainset)- 
                                                      int(len(trainset)*0.8)], generator=generator)
    
  elif dataset_name == 'STL10':
    NUM_CLASSES=10
    
    trainset = torchvision.datasets.STL10(root='./data', split='train',
                                          download=True, transform=transform_train)

    testset = torchvision.datasets.STL10(root='./data', split='test',
                                          download=True, transform=transform_train)
    

    trainset, validset = torch.utils.data.random_split(trainset, 
                                                      [int(len(trainset)*0.8),len(trainset)- 
                                                      int(len(trainset)*0.8)], generator=generator)
    

  elif dataset_name == 'Caltech101':
    NUM_CLASSES=101
    !gdown https://drive.google.com/uc?id=1DX_XeKHn3yXtZ18DD7qc1wf-Jy5lnhD5
    !unzip -qq '101_ObjectCategories.zip' 

    PATH = '101_ObjectCategories/'

    transform = transforms.Compose(
      [transforms.CenterCrop(256),
      transforms.Resize((64,64)),
      transforms.ToTensor(),
      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    
    totalset = torchvision.datasets.ImageFolder(PATH, transform=transform_train)

    X, y = zip(*totalset)

    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size = 0.3, 
                                                      stratify=y)
    X_val, X_test, y_val, y_test = train_test_split(X_val, y_val, 
                                                    test_size = 0.5, 
                                                    stratify=y_val)

    trainset, validset, testset = list(zip(X_train, y_train)), list(zip(X_val, y_val)), list(zip(X_test, y_test))




  trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                            shuffle=True, num_workers=2)
  validloader = torch.utils.data.DataLoader(validset, batch_size=batch_size,
                                            shuffle=False,num_workers=2)
  testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                          shuffle=False, num_workers=2)
  return trainset, testset, trainloader, testloader

In [31]:
# Build the STL10 datasets and loaders, then report the test-set size.
(train_ds, test_ds,
 train_loader, test_loader) = getTrainingSet('STL10')
n_test_images = len(test_ds)
print(n_test_images)


Files already downloaded and verified
Files already downloaded and verified
8000


In [24]:
from transformers import ViTModel
from transformers.modeling_outputs import SequenceClassifierOutput
import torch.nn as nn
import torch.nn.functional as F

class ViTForImageClassification(nn.Module):
    """Pretrained ViT backbone with a dropout + linear classification head.

    Args:
        num_labels: number of output classes produced by the linear head.
    """

    def __init__(self, num_labels=10):
        super().__init__()
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        hidden_size = self.vit.config.hidden_size
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_size, num_labels)
        self.num_labels = num_labels

    def forward(self, pixel_values):
        """Return unnormalised class logits for a batch of pixel values."""
        hidden_states = self.vit(pixel_values=pixel_values).last_hidden_state
        # Classify from the first token of the sequence (ViT's [CLS] position).
        first_token = hidden_states[:, 0]
        return self.classifier(self.dropout(first_token))

In [25]:

# Upper-case aliases for the optimisation settings.
LEARNING_RATE = 1e-4  # Adam learning rate
BATCH_SIZE = batch_size  # same value as the DataLoader batch size

In [32]:
from transformers import ViTFeatureExtractor
import torch.nn as nn
import torch
VIT_CHECKPOINT = 'google/vit-base-patch16-224-in21k'

# 10-way classifier on top of the pretrained ViT backbone.
model = ViTForImageClassification(10)
# Input preprocessor matching the checkpoint (presumably resize + normalise).
feature_extractor = ViTFeatureExtractor.from_pretrained(VIT_CHECKPOINT)
# Adam over every model parameter.
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Run on the GPU when one is available; moving to the CPU is a no-op otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

In [27]:
def train(epoch):
  """Run one training epoch of the global `model` over `train_loader`.

  Reads the module-level `model`, `optimizer`, `criterion`,
  `feature_extractor`, `train_loader` and `device`, and prints a one-line
  summary of the epoch.

  Args:
    epoch: epoch index; used only in the printed summary.

  Returns:
    (mean loss per batch, training accuracy in percent) for this epoch.

  Raises:
    ValueError: if `train_loader` yields no batches.
  """
  model.train()
  correct_images = 0
  total_images = 0
  training_loss = 0
  num_batches = 0
  for images, labels in tqdm(train_loader):
    optimizer.zero_grad()
    # The feature extractor takes a list of images, so turn the batch tensor
    # into one array per image. Splitting along the batch axis this way also
    # works for a final batch smaller than BATCH_SIZE (np.split(..., BATCH_SIZE)
    # raises there) and for a batch of one (np.squeeze would drop that axis).
    frames = list(images.cpu().numpy())
    # NOTE(review): the dataset transforms already apply Normalize(), and the
    # feature extractor preprocesses again. Depending on the transformers
    # version, float inputs outside [0, 1] may not be handled as intended;
    # confirm the extractor's expected input range.
    pixel_values = torch.tensor(
        np.stack(feature_extractor(frames)['pixel_values'], axis=0))
    pixel_values, labels = pixel_values.to(device), labels.to(device)
    outputs = model(pixel_values)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    training_loss += loss.item()
    _, predicted = outputs.max(1)
    total_images += labels.size(0)
    correct_images += predicted.eq(labels).sum().item()
    num_batches += 1

  if num_batches == 0:
    raise ValueError("train_loader yielded no batches")
  mean_loss = training_loss / num_batches
  accuracy = 100. * correct_images / total_images
  print('Epoch: %d, Loss: %.3f, '
        'Accuracy: %.3f%% (%d/%d)' % (epoch, mean_loss, accuracy,
                                      correct_images, total_images))
  return mean_loss, accuracy



In [28]:
def test():
    """Evaluate the global `model` on `test_loader` and return its accuracy.

    Reads the module-level `model`, `feature_extractor`, `test_loader` and
    `device`, runs without gradient tracking, and prints the accuracy.

    Returns:
        Test accuracy as a percentage (float).

    Raises:
        ValueError: if `test_loader` yields no images.
    """
    total_images = 0
    correct_images = 0
    model.eval()
    with torch.no_grad():
        for images, labels in tqdm(test_loader):
            # One array per image in the batch. The previous hard-coded
            # np.split(..., 32) raised on any final batch smaller than 32, i.e.
            # whenever the test-set size is not a multiple of 32.
            frames = list(images.cpu().numpy())
            # NOTE(review): inputs are already Normalize()-d by the dataset
            # transform and are preprocessed again by the feature extractor;
            # confirm the extractor's expected input range.
            pixel_values = torch.tensor(
                np.stack(feature_extractor(frames)['pixel_values'], axis=0))
            pixel_values, labels = pixel_values.to(device), labels.to(device)
            outputs = model(pixel_values)
            _, predicted = outputs.max(1)
            total_images += labels.size(0)
            correct_images += predicted.eq(labels).sum().item()

    if total_images == 0:
        raise ValueError("test_loader yielded no images")
    test_accuracy = 100. * correct_images / total_images
    print("accuracy of test set is", test_accuracy)
    return test_accuracy

In [33]:
# Model

NUM_EPOCHS = 20

criterion = nn.CrossEntropyLoss()
model.to(device)

# Per-epoch history (one entry appended per epoch):
#   loss          - mean training loss returned by train()
#   train_acc     - training accuracy (%) returned by train()
#   accuracy_test - test accuracy (%) returned by test()
accuracy_test, loss, train_acc = [], [], []

for epoch in range(NUM_EPOCHS):
    epoch_loss, epoch_acc = train(epoch)
    loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    accuracy_test.append(test())
        

100%|██████████| 125/125 [01:20<00:00,  1.55it/s]


Epoch: 0, Loss: 1.994, Accuracy: 25.400% (1016/4000)


100%|██████████| 250/250 [01:06<00:00,  3.75it/s]


accuracy of test set is 32.325


100%|██████████| 125/125 [01:20<00:00,  1.55it/s]


Epoch: 1, Loss: 1.720, Accuracy: 37.525% (1501/4000)


100%|██████████| 250/250 [01:06<00:00,  3.75it/s]


accuracy of test set is 38.025


100%|██████████| 125/125 [01:20<00:00,  1.55it/s]


Epoch: 2, Loss: 1.608, Accuracy: 40.750% (1630/4000)


100%|██████████| 250/250 [01:06<00:00,  3.75it/s]


accuracy of test set is 41.825


100%|██████████| 125/125 [01:20<00:00,  1.55it/s]


Epoch: 3, Loss: 1.550, Accuracy: 44.775% (1791/4000)


100%|██████████| 250/250 [01:06<00:00,  3.75it/s]


accuracy of test set is 44.675


100%|██████████| 125/125 [01:20<00:00,  1.55it/s]


Epoch: 4, Loss: 1.440, Accuracy: 48.475% (1939/4000)


100%|██████████| 250/250 [01:06<00:00,  3.76it/s]


accuracy of test set is 46.0875


100%|██████████| 125/125 [01:20<00:00,  1.55it/s]


Epoch: 5, Loss: 1.394, Accuracy: 49.425% (1977/4000)


100%|██████████| 250/250 [01:06<00:00,  3.75it/s]


accuracy of test set is 47.2125


100%|██████████| 125/125 [01:20<00:00,  1.55it/s]


Epoch: 6, Loss: 1.360, Accuracy: 51.000% (2040/4000)


100%|██████████| 250/250 [01:06<00:00,  3.75it/s]


accuracy of test set is 46.9


100%|██████████| 125/125 [01:20<00:00,  1.55it/s]


Epoch: 7, Loss: 1.304, Accuracy: 53.375% (2135/4000)


100%|██████████| 250/250 [01:06<00:00,  3.75it/s]


accuracy of test set is 48.325


100%|██████████| 125/125 [01:20<00:00,  1.55it/s]


Epoch: 8, Loss: 1.302, Accuracy: 53.775% (2151/4000)


100%|██████████| 250/250 [01:06<00:00,  3.75it/s]


accuracy of test set is 48.8375


100%|██████████| 125/125 [01:20<00:00,  1.55it/s]


Epoch: 9, Loss: 1.246, Accuracy: 55.900% (2236/4000)


100%|██████████| 250/250 [01:06<00:00,  3.77it/s]


accuracy of test set is 49.0875


100%|██████████| 125/125 [01:20<00:00,  1.55it/s]


Epoch: 10, Loss: 1.221, Accuracy: 56.675% (2267/4000)


100%|██████████| 250/250 [01:06<00:00,  3.77it/s]


accuracy of test set is 50.55


100%|██████████| 125/125 [01:20<00:00,  1.55it/s]


Epoch: 11, Loss: 1.187, Accuracy: 58.375% (2335/4000)


100%|██████████| 250/250 [01:06<00:00,  3.77it/s]


accuracy of test set is 49.45


100%|██████████| 125/125 [01:20<00:00,  1.55it/s]


Epoch: 12, Loss: 1.177, Accuracy: 58.000% (2320/4000)


100%|██████████| 250/250 [01:06<00:00,  3.77it/s]


accuracy of test set is 50.95


100%|██████████| 125/125 [01:20<00:00,  1.55it/s]


Epoch: 13, Loss: 1.117, Accuracy: 60.550% (2422/4000)


100%|██████████| 250/250 [01:06<00:00,  3.77it/s]


accuracy of test set is 52.0


100%|██████████| 125/125 [01:20<00:00,  1.55it/s]


Epoch: 14, Loss: 1.098, Accuracy: 60.425% (2417/4000)


100%|██████████| 250/250 [01:06<00:00,  3.77it/s]


accuracy of test set is 49.5625


100%|██████████| 125/125 [01:20<00:00,  1.55it/s]


Epoch: 15, Loss: 1.066, Accuracy: 62.675% (2507/4000)


100%|██████████| 250/250 [01:06<00:00,  3.77it/s]


accuracy of test set is 50.7875


100%|██████████| 125/125 [01:20<00:00,  1.55it/s]


Epoch: 16, Loss: 1.061, Accuracy: 62.025% (2481/4000)


100%|██████████| 250/250 [01:06<00:00,  3.77it/s]


accuracy of test set is 52.0


100%|██████████| 125/125 [01:20<00:00,  1.55it/s]


Epoch: 17, Loss: 1.035, Accuracy: 63.225% (2529/4000)


100%|██████████| 250/250 [01:06<00:00,  3.77it/s]


accuracy of test set is 52.575


100%|██████████| 125/125 [01:20<00:00,  1.55it/s]


Epoch: 18, Loss: 1.036, Accuracy: 63.225% (2529/4000)


100%|██████████| 250/250 [01:06<00:00,  3.77it/s]


accuracy of test set is 52.4375


100%|██████████| 125/125 [01:20<00:00,  1.55it/s]


Epoch: 19, Loss: 1.002, Accuracy: 64.875% (2595/4000)


100%|██████████| 250/250 [01:06<00:00,  3.77it/s]

accuracy of test set is 53.4625



