In [None]:
import torch

# Report whether PyTorch can see a CUDA device (training below falls back to CPU otherwise).
cuda_available = torch.cuda.is_available()
print(cuda_available)

True


# Loading the Dataset and Building Training and Validation Loaders

The dataset's official test split is used as the validation set; no separate test set is built.



In [None]:
# Fetch and extract the Tiny ImageNet archive; only needed when `dataset == "TinyImageNet"`
# is selected in the data-loading cell.
# NOTE(review): in the recorded run below, unzip reported "End-of-central-directory
# signature not found", i.e. the file on disk was not a complete zip archive.
# Re-download before selecting TinyImageNet.
!wget http://cs231n.stanford.edu/tiny-imagenet-200.zip
!unzip -qq 'tiny-imagenet-200.zip'

--2021-10-09 16:19:35--  http://cs231n.stanford.edu/tiny-imagenet-200.zip
Resolving cs231n.stanford.edu (cs231n.stanford.edu)... 171.64.68.10
Connecting to cs231n.stanford.edu (cs231n.stanford.edu)|171.64.68.10|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 248100043 (237M) [application/zip]
Saving to: ‘tiny-imagenet-200.zip’

[tiny-imagenet-200.zip]
  End-of-central-directory signature not found.  Either this file is not
  a zipfile, or it constitutes one disk of a multi-part archive.  In the
  latter case the central directory and zipfile comment will be found on
  the last disk(s) of this archive.
unzip:  cannot find zipfile directory in one of tiny-imagenet-200.zip or
        tiny-imagenet-200.zip.zip, and cannot find tiny-imagenet-200.zip.ZIP, period.


In [None]:
# Load the selected dataset and build the training and validation loaders.
# The dataset's official test split serves as the validation set; no test set is built here.
import csv
import os
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torchvision import datasets,transforms
import torch.utils.tensorboard as tb
from tqdm import tqdm


dataset = "CIFAR-10"   # one of "STL-10", "CIFAR-10", "TinyImageNet"
BATCH_SIZE = 16
SEED = 0               # seeds torch's global RNG: loader shuffle order and model init

torch.manual_seed(SEED)


def load_tiny_imagenet_val(val_root, class_to_idx, transform=None):
    """Build the Tiny ImageNet validation set using the training label mapping.

    A plain ``ImageFolder`` on ``val/`` would put every image in a single class
    named "images", because the validation images are not split into class folders.
    This helper instead reads the labels from ``val_annotations.txt``.

    Args:
        val_root: path to the extracted ``tiny-imagenet-200/val`` directory.
        class_to_idx: wnid -> label mapping taken from the training ImageFolder,
            so that validation labels agree with training labels.
        transform: per-image transform, as for ImageFolder.

    Returns:
        An ImageFolder whose samples/targets are the annotated validation images.
    """
    # Assumes the standard Tiny ImageNet layout: val/images/*.JPEG plus
    # val/val_annotations.txt with tab-separated "<file>\t<wnid>\t<bbox...>" lines.
    val_set = datasets.ImageFolder(val_root, transform=transform)
    with open(os.path.join(val_root, "val_annotations.txt")) as annotations:
        rows = [line.split("\t")[:2] for line in annotations if line.strip()]
    samples = [(os.path.join(val_root, "images", fname), class_to_idx[wnid]) for fname, wnid in rows]
    # ImageFolder.__getitem__ reads self.samples, so overriding it relabels the set.
    val_set.samples = samples
    val_set.imgs = samples
    val_set.targets = [label for _, label in samples]
    val_set.classes = sorted(class_to_idx, key=class_to_idx.get)
    val_set.class_to_idx = dict(class_to_idx)
    return val_set


# create datasets
to_tensor = transforms.ToTensor()
if dataset == "STL-10":
    training_dataset = datasets.STL10(root='./data/STL10', split='train', transform=to_tensor, download=True)
    validation_dataset = datasets.STL10(root='./data/STL10', split='test', transform=to_tensor, download=True)
    in_size = 96
    num_classes = 10
elif dataset == "CIFAR-10":
    training_dataset = datasets.CIFAR10(root='./data/CIFAR10', train=True, download=True, transform=to_tensor)
    validation_dataset = datasets.CIFAR10(root='./data/CIFAR10', train=False, download=True, transform=to_tensor)
    in_size = 32
    num_classes = 10
elif dataset == "TinyImageNet":
    training_dataset = datasets.ImageFolder("./tiny-imagenet-200/train", transform=to_tensor)
    validation_dataset = load_tiny_imagenet_val("./tiny-imagenet-200/val", training_dataset.class_to_idx, transform=to_tensor)
    in_size = 64
    num_classes = 200
else:
    raise ValueError(f"Unknown dataset {dataset!r}; expected 'STL-10', 'CIFAR-10' or 'TinyImageNet'")

# create dataloaders
training_loader = torch.utils.data.DataLoader(training_dataset, batch_size=BATCH_SIZE, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
!pip install perceiver-pytorch

Collecting perceiver-pytorch
  Downloading perceiver_pytorch-0.7.4-py3-none-any.whl (11 kB)
Collecting einops>=0.3
  Downloading einops-0.3.2-py3-none-any.whl (25 kB)
Installing collected packages: einops, perceiver-pytorch
Successfully installed einops-0.3.2 perceiver-pytorch-0.7.4


In [None]:
import torch
from torch import nn, einsum
from perceiver_pytorch import PerceiverIO


class QueryFFN(nn.Module):
    """Map each single-channel square image to one Perceiver IO decoder query.

    The image is flattened and passed through one linear layer, so the output
    has shape ``(batch, 1, query_dim)``. This matches the
    ``(batch, num_queries, queries_dim)`` query tensor that the PerceiverIO
    model in this notebook is called with.
    """

    def __init__(self, in_size, query_dim=None):
        """
        Args:
            in_size: side length of the square input; the layer expects
                ``in_size * in_size`` features after flattening.
            query_dim: width of each produced query. Defaults to the notebook-level
                ``num_classes`` (the model's ``queries_dim``), which was the
                previous hard-wired behavior.
        """
        super().__init__()
        if query_dim is None:
            query_dim = num_classes
        self.query = nn.Linear(in_size * in_size, query_dim)

    def forward(self, images):
        """Return queries of shape ``(batch, 1, query_dim)``.

        ``images`` may be ``(batch, in_size, in_size)`` or carry a singleton
        channel axis; every axis after the batch axis is flattened.
        """
        flat = torch.flatten(images, start_dim=1)
        # Insert the num_queries axis (a single query per example).
        return self.query(flat)[:, None, :]


# Perceiver IO classifier configuration. Inputs are passed as
# (batch, seq_len, in_size), so `dim` is the width of each input token.
# `queries_dim` must equal the width of the decoder queries (QueryFFN output).
perceiver_config = dict(
    dim=in_size,               # feature width of each input token (last input axis)
    queries_dim=num_classes,   # width of each decoder query
    logits_dim=num_classes,    # output width per query = number of classes
    depth=5,                   # PerceiverIO `depth`
    num_latents=256,           # number of latent vectors
    latent_dim=512,            # width of each latent vector
    cross_heads=1,             # cross-attention heads (the paper uses 1)
    latent_heads=8,            # latent self-attention heads
    cross_dim_head=64,         # width per cross-attention head
    latent_dim_head=64,        # width per latent self-attention head
    weight_tie_layers=False,   # do not share weights across depth
    decoder_ff=False,          # no feed-forward block in the decoder
)
original_model = PerceiverIO(**perceiver_config)

# Smoke test on random data shaped like a batch of 3 grayscale images: (batch, in_size, in_size).
seq = torch.randn(3, in_size, in_size)

with torch.no_grad():
    # Random decoder queries: one query of width num_classes per example.
    queries = torch.rand(3, 1, num_classes)
    logits = original_model(seq, queries = queries)  # (batch, num_queries, logits_dim) = (3, 1, num_classes)
    print(logits.squeeze(1).size())

    # Same forward pass with learned queries from QueryFFN, as used during training.
    smoke_query_model = QueryFFN(in_size)
    logits = original_model(seq, queries = smoke_query_model(seq))
    print(logits.size())

assert logits.shape == (3, 1, num_classes)


torch.Size([3, 10])
torch.Size([3, 1, 10])


In [None]:
# Training hyperparameters (defaults for train_model).
# LR was 1e-1, which made Adam diverge: the logged loss rose from ~2.3 to >100
# within a few steps. 1e-3 is Adam's conventional default.
LR = 1e-3
MOMENTUM = 0.9        # used as Adam's beta1, not SGD momentum
WEIGHT_DECAY = 1e-6
EPOCHS = 8

import itertools


def _to_grayscale(images):
    """Average over axis 1 (channels for ToTensor batches): (B, C, H, W) -> (B, H, W)."""
    return images.mean(1)


def _train_one_epoch(model, query_model, loader, optimizers, loss_f, device, max_batches):
    """Run one training pass; return (list of per-batch losses, sample-weighted accuracy)."""
    model.train()
    query_model.train()
    batches = loader if max_batches is None else itertools.islice(loader, max_batches)
    n_batches = len(loader) if max_batches is None else min(len(loader), max_batches)
    batch_losses = []
    n_correct, n_seen = 0, 0
    progress = tqdm(batches, total=n_batches, desc="train", leave=False)
    for images, labels in progress:
        images = _to_grayscale(images.to(device))
        labels = labels.to(device)
        # squeeze(1) drops only the single-query axis; a bare squeeze() would also
        # drop the batch axis for a batch holding one sample.
        logits = model(images, queries=query_model(images)).squeeze(1)
        loss = loss_f(logits, labels)
        for opt in optimizers:
            opt.zero_grad()
        loss.backward()
        for opt in optimizers:
            opt.step()
        batch_losses.append(loss.item())
        n_correct += (logits.argmax(1) == labels).sum().item()
        n_seen += labels.numel()
        progress.set_postfix(loss=f"{batch_losses[-1]:.4f}")
    accuracy = n_correct / n_seen if n_seen else float("nan")
    return batch_losses, accuracy


@torch.no_grad()
def _evaluate(model, query_model, loader, device):
    """Return the sample-weighted accuracy on ``loader`` without building autograd graphs."""
    model.eval()
    query_model.eval()
    n_correct, n_seen = 0, 0
    for images, labels in loader:
        images = _to_grayscale(images.to(device))
        labels = labels.to(device)
        logits = model(images, queries=query_model(images)).squeeze(1)
        n_correct += (logits.argmax(1) == labels).sum().item()
        n_seen += labels.numel()
    return n_correct / n_seen if n_seen else float("nan")


def train_model(model_in, learning_rate = LR, weight_d = WEIGHT_DECAY, momentum = MOMENTUM,
                epochs = EPOCHS, max_train_batches = 500, train_loader = None,
                val_loader = None, query_lr = 1e-2):
    """Jointly train a PerceiverIO classifier and the QueryFFN that supplies its decoder queries.

    Images are reduced to one channel by averaging axis 1 before reaching either model.

    Args:
        model_in: PerceiverIO model; moved to the available device and trained in place.
        learning_rate: Adam learning rate for ``model_in``.
        weight_d: Adam weight decay, used for both optimizers.
        momentum: Adam ``beta1`` for ``model_in``.
        epochs: number of training epochs.
        max_train_batches: training batches per epoch (``None`` = the whole loader).
        train_loader: defaults to the notebook-level ``training_loader``.
        val_loader: defaults to the notebook-level ``validation_loader``.
        query_lr: Adam learning rate for the QueryFFN.

    Returns:
        ``(epoch_loss, validation_acc, query_model)``. ``epoch_loss`` maps each epoch
        to its list of per-batch training losses. ``validation_acc`` maps each epoch
        to its validation accuracy. ``query_model`` is the trained QueryFFN, which is
        needed to produce queries at inference time.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_loader = training_loader if train_loader is None else train_loader
    val_loader = validation_loader if val_loader is None else val_loader

    model = model_in.to(device)
    query_model = QueryFFN(in_size).to(device)

    loss_f = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_d,
                                 betas=(momentum, 0.999))
    query_optimizer = torch.optim.Adam(query_model.parameters(), lr=query_lr, weight_decay=weight_d)

    epoch_loss = {}
    validation_acc = {}
    for epoch in range(epochs):
        print("EPOCH : ", epoch)
        batch_losses, train_acc = _train_one_epoch(
            model, query_model, train_loader, (optimizer, query_optimizer),
            loss_f, device, max_train_batches)
        mean_loss = sum(batch_losses) / len(batch_losses) if batch_losses else float("nan")
        print("******Train ACCURACY****** : ", train_acc)
        print("Mean TRAIN_LOSS over batches: ", mean_loss)
        epoch_loss[epoch] = batch_losses

        validation_acc[epoch] = _evaluate(model, query_model, val_loader, device)
        print("******VALIDATION ACCURACY****** : ", validation_acc[epoch])

    return epoch_loss, validation_acc, query_model

In [None]:
# Train the Perceiver IO classifier defined above on the loaders built in the data cell.
train_model(original_model)

EPOCH :  0
tensor(2.2611, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(6.2336, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(85.2318, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(163.6691, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(118.8763, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(63.4948, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(41.1605, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(260.7211, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(55.8646, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(144.7871, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(180.2741, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(105.3125, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(49.1882, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(71.1342, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(96.3773, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(193.8064, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(122.4519, device=

KeyboardInterrupt: ignored

In [None]:
import torch
# Release cached-but-unused GPU memory held by PyTorch's caching allocator, e.g. after
# the interrupted training run above. Memory still referenced by live tensors
# (models, optimizer state) is not freed.
torch.cuda.empty_cache()