In [1]:
import torch

# Report up front whether PyTorch can see a CUDA device.
cuda_visible = torch.cuda.is_available()
print(cuda_visible)

True


# Loading Data Set and Assembling Training, Validation, and Test Sets



In [13]:
!wget http://cs231n.stanford.edu/tiny-imagenet-200.zip
!unzip -qq 'tiny-imagenet-200.zip'

--2021-10-09 22:46:31--  http://cs231n.stanford.edu/tiny-imagenet-200.zip
Resolving cs231n.stanford.edu (cs231n.stanford.edu)... 171.64.68.10
Connecting to cs231n.stanford.edu (cs231n.stanford.edu)|171.64.68.10|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 248100043 (237M) [application/zip]
Saving to: ‘tiny-imagenet-200.zip’


2021-10-09 22:46:46 (16.7 MB/s) - ‘tiny-imagenet-200.zip’ saved [248100043/248100043]



In [18]:
# Load the selected dataset's predefined training split and held-out split (used as validation), then wrap both in DataLoaders
import csv
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torchvision import datasets,transforms
import torch.utils.tensorboard as tb
from tqdm import tqdm


# Name of the dataset to train on: "STL-10", "CIFAR-10" or "TinyImageNet".
dataset = "CIFAR-10"
BATCH_SIZE = 16

# The only preprocessing applied to any split is conversion to a tensor.
to_tensor = transforms.ToTensor()

# create datasets
# Each branch defines the same four names: training_dataset, validation_dataset,
# in_size (image side length in pixels) and num_classes.
if dataset == "STL-10":
    training_dataset = datasets.STL10(root='./data/STL10', split='train', transform=to_tensor, download=True)
    validation_dataset = datasets.STL10(root='./data/STL10', split='test', transform=to_tensor, download=True)
    in_size = 96
    num_classes = 10
elif dataset == "CIFAR-10":
    training_dataset = datasets.CIFAR10(root='./data/CIFAR10', train=True, download=True, transform=to_tensor)
    validation_dataset = datasets.CIFAR10(root='./data/CIFAR10', train=False, download=True, transform=to_tensor)
    in_size = 32
    num_classes = 10
elif dataset == "TinyImageNet":
    training_dataset = datasets.ImageFolder("./tiny-imagenet-200/train", transform=to_tensor)
    # NOTE(review): the stock tiny-imagenet-200 archive is believed to keep all
    # validation images in a single val/images folder with labels listed in
    # val_annotations.txt. If so, ImageFolder would see one class here and the
    # labels would not match the training classes -- confirm before relying on it.
    validation_dataset = datasets.ImageFolder("./tiny-imagenet-200/val", transform=to_tensor)
    in_size = 64
    num_classes = 200
else:
    # Fail here with a clear message instead of a NameError further down.
    raise ValueError(
        f"Unknown dataset {dataset!r}; expected 'STL-10', 'CIFAR-10' or 'TinyImageNet'"
    )

# create dataloaders
# Only the training split is shuffled; validation order stays fixed.
training_loader = torch.utils.data.DataLoader(training_dataset, batch_size=int(BATCH_SIZE), shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_dataset, batch_size=int(BATCH_SIZE), shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
!pip install perceiver-pytorch

Collecting perceiver-pytorch
  Downloading perceiver_pytorch-0.7.4-py3-none-any.whl (11 kB)
Collecting einops>=0.3
  Downloading einops-0.3.2-py3-none-any.whl (25 kB)
Installing collected packages: einops, perceiver-pytorch
Successfully installed einops-0.3.2 perceiver-pytorch-0.7.4


In [20]:
import torch
from torch import nn, einsum
from perceiver_pytorch import PerceiverIO


# class QueryFFN(nn.Module):
#     def __init__(self, in_size):
#         super().__init__()
#         self.query = nn.Linear(in_size*in_size, num_classes)
      
#     def forward(self, images):
#         images = torch.flatten(images, start_dim = 1)
#         return self.query(images)[:, None, :]


# Perceiver IO classifier. The input is fed as a sequence whose elements have
# `in_size` features each (see the (batch, in_size, in_size) smoke-test input below);
# a single decoder query per sample yields one vector of `num_classes` logits.
original_model = PerceiverIO(
    dim = in_size,                    # dimension of sequence to be encoded
    queries_dim = num_classes,            # dimension of decoder queries
    logits_dim = num_classes,            # dimension of final logits
    depth = 6,                   # depth of net
    num_latents = 256,           # number of latents, or induced set points, or centroids. different papers giving it different names
    latent_dim = 512,            # latent dimension
    cross_heads = 1,             # number of heads for cross attention. paper said 1
    latent_heads = 8,            # number of heads for latent self attention, 8
    cross_dim_head = 64,         # number of dimensions per cross attention head
    latent_dim_head = 64,        # number of dimensions per latent self attention head
    weight_tie_layers = False,    # whether to weight tie layers (optional, as indicated in the diagram)
    decoder_ff = False
)

# Smoke test: one forward pass on random data to confirm the output shape.
seq = torch.randn(3, in_size, in_size)
queries = torch.rand(3, 1, num_classes)

with torch.no_grad():  # shape check only, so no autograd graph is needed
    logits = original_model(seq, queries = queries) # (3, 1, num_classes) - (batch, decoder seq, logits dim)
# Drop only the length-1 decoder-sequence axis, never the batch axis.
print(logits.squeeze(1).size())

# Learnable decoder query, registered on the model so that model.parameters()
# (and any optimizer built from it) includes it under the name `query_io`.
# NOTE(review): the leading dimension is BATCH_SIZE, i.e. one independent query per
# position within a batch rather than a single shared query -- confirm this is intended.
query_param = torch.nn.Parameter(torch.rand(BATCH_SIZE, 1, num_classes), requires_grad = True)
original_model.register_parameter(name='query_io', param=query_param)


torch.Size([3, 10])


In [21]:
# Default optimisation hyper-parameters for train_model.
LR = 1e-1
MOMENTUM = 0.9
WEIGHT_DECAY = 1e-6
EPOCHS = 8

def train_model(model_in, learning_rate = LR, weight_d = WEIGHT_DECAY, momentum = MOMENTUM,
                epochs = None, loader = None, max_batches = 50):
  """Train `model_in` with SGD and cross-entropy, printing per-epoch metrics.

  The model must expose a `query_io` parameter and accept
  `model(images, queries=...)`, returning logits with a length-1 axis at
  position 1 (that axis is squeezed away before the loss is computed).

  Args:
    model_in: the module to train; it is moved to CUDA when available, else CPU.
    learning_rate: SGD learning rate.
    weight_d: SGD weight decay.
    momentum: SGD momentum.
    epochs: number of passes over `loader`; defaults to the global EPOCHS.
    loader: iterable of (images, labels) batches; defaults to the global
      training_loader. Images are averaged over dimension 1 before being
      passed to the model.
    max_batches: cap on the batches used per epoch (None means no cap).

  Returns:
    None. Mean training accuracy and loss are printed after every epoch.
  """
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  model = model_in.to(device)
  if epochs is None:
    epochs = EPOCHS
  if loader is None:
    loader = training_loader

  # Define Loss Function and get optimizer
  loss_f = torch.nn.CrossEntropyLoss()
  # Build the optimizer from the function arguments, not the module-level
  # constants, so callers can actually override the hyper-parameters.
  optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=weight_d, momentum=momentum)

  for epoch in range(epochs):
    print("EPOCH : ", epoch)
    model.train()
    batch_loss = []
    accuracies = []

    # Iterate through training set, collect loss values and update model
    for num, (im, truth_labels) in enumerate(loader, start=1):
      if max_batches is not None and num > max_batches:
        break

      # Average over dimension 1 of the batch, then move it to the device.
      im = im.mean(1).to(device)
      truth_labels = truth_labels.to(device)

      # Use the parameter itself (not `.data`) so gradients reach it and the
      # optimizer can update the learned query.
      queries = model.query_io
      if queries.size(0) == 1:
        # A single shared query is repeated for every sample in the batch.
        queries = queries.expand(im.size(0), -1, -1)
      else:
        # Match a batch that is shorter than the stored query tensor.
        queries = queries[:im.size(0)]

      optimizer.zero_grad()
      # squeeze(1) removes only the length-1 axis and keeps a batch of size 1 intact.
      predicted_labels = model(im, queries = queries).squeeze(1)
      loss = loss_f(predicted_labels, truth_labels)
      loss.backward()
      optimizer.step()

      batch_loss.append(loss.item())
      accuracy = (predicted_labels.argmax(1) == truth_labels).float().mean().item()
      accuracies.append(accuracy)

    # Both figures are means over the batches processed in this epoch.
    print("******Train ACCURACY****** : ", torch.FloatTensor(accuracies).mean().item())
    print("TRAIN_LOSS of Each Batch: ", torch.FloatTensor(batch_loss).mean().item())

  # TODO: validation and test evaluation are not implemented in this function.

In [22]:
# Train the Perceiver IO model using train_model's default hyper-parameters.
train_model(original_model)

EPOCH :  0
******Train ACCURACY****** :  0.11124999821186066
TRAIN_LOSS of Each Batch:  2.329042911529541
EPOCH :  1
******Train ACCURACY****** :  0.09125000238418579
TRAIN_LOSS of Each Batch:  2.3251492977142334
EPOCH :  2
******Train ACCURACY****** :  0.08874999731779099
TRAIN_LOSS of Each Batch:  2.3311307430267334
EPOCH :  3
******Train ACCURACY****** :  0.10750000178813934
TRAIN_LOSS of Each Batch:  2.316765546798706
EPOCH :  4
******Train ACCURACY****** :  0.12125000357627869
TRAIN_LOSS of Each Batch:  2.313443660736084
EPOCH :  5
******Train ACCURACY****** :  0.1262499988079071
TRAIN_LOSS of Each Batch:  2.3125839233398438
EPOCH :  6
******Train ACCURACY****** :  0.08375000208616257
TRAIN_LOSS of Each Batch:  2.3394484519958496
EPOCH :  7
******Train ACCURACY****** :  0.14875000715255737
TRAIN_LOSS of Each Batch:  2.2910404205322266


In [103]:
# Release unused cached GPU memory held by PyTorch's caching allocator.
import torch
torch.cuda.empty_cache()