In [1]:
!pip install ml_collections
!pip install perceiver-pytorch

Collecting ml_collections
  Downloading ml_collections-0.1.0-py3-none-any.whl (88 kB)
[?25l[K     |███▊                            | 10 kB 24.7 MB/s eta 0:00:01[K     |███████▍                        | 20 kB 27.8 MB/s eta 0:00:01[K     |███████████                     | 30 kB 22.9 MB/s eta 0:00:01[K     |██████████████▉                 | 40 kB 18.1 MB/s eta 0:00:01[K     |██████████████████▌             | 51 kB 14.4 MB/s eta 0:00:01[K     |██████████████████████▏         | 61 kB 11.9 MB/s eta 0:00:01[K     |█████████████████████████▉      | 71 kB 10.8 MB/s eta 0:00:01[K     |█████████████████████████████▋  | 81 kB 11.9 MB/s eta 0:00:01[K     |████████████████████████████████| 88 kB 5.1 MB/s 
Installing collected packages: ml-collections
Successfully installed ml-collections-0.1.0
Collecting perceiver-pytorch
  Downloading perceiver_pytorch-0.7.4-py3-none-any.whl (11 kB)
Collecting einops>=0.3
  Downloading einops-0.3.2-py3-none-any.whl (25 kB)
Installing collected pa

In [62]:
!wget http://cs231n.stanford.edu/tiny-imagenet-200.zip
!unzip -qq 'tiny-imagenet-200.zip'

--2021-10-10 06:54:30--  http://cs231n.stanford.edu/tiny-imagenet-200.zip
Resolving cs231n.stanford.edu (cs231n.stanford.edu)... 171.64.68.10
Connecting to cs231n.stanford.edu (cs231n.stanford.edu)|171.64.68.10|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 248100043 (237M) [application/zip]
Saving to: ‘tiny-imagenet-200.zip’


2021-10-10 06:55:10 (5.96 MB/s) - ‘tiny-imagenet-200.zip’ saved [248100043/248100043]



In [122]:
from ml_collections import ConfigDict

def C(**kwargs):
    return ConfigDict(initial_dictionary=kwargs)

def get_config():
    return C(
        cuda                = True,
        dataset             = 'cifar10',
        image_size          = 32,
        num_classes         = 10,

        train=C(
            batch_size          = 128,
            num_epochs          = 100,
        ),

        optimizer_type      = 'adamw',
        optimizer_args=C(
            lr                  = 3e-4,
        ),

        model_type          = 'perceiver_io',
        model_args=C(
            dim = 32 * 32 * 3,                    # dimension of sequence to be encoded
            queries_dim = 10,            # dimension of decoder queries
            logits_dim = 10,            # dimension of final logits
            depth = 2,                   # depth of net
            num_latents = 32,           # number of latents, or induced set points, or centroids. different papers giving it different names
            latent_dim = 64,            # latent dimension
            cross_heads = 1,             # number of heads for cross attention. paper said 1
            latent_heads = 8,            # number of heads for latent self attention, 8
            cross_dim_head = 128,         # number of dimensions per cross attention head
            latent_dim_head = 128,        # number of dimensions per latent self attention head
            weight_tie_layers = False,    # whether to weight tie layers (optional, as indicated in the diagram)
            decoder_ff = False
        ),
    )

In [3]:
def load_config(config_name):
    return {
        'perceiver_io': get_config(),
    }[config_name]

In [135]:
import os
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from perceiver_pytorch import Perceiver, PerceiverIO

def get_datasets(dataset, data_root):
    assert dataset in ('stl10', 'cifar10', 'tinyimagenet')
    data_root = os.path.abspath(os.path.expanduser(data_root))
    root_dir = os.path.join(data_root, dataset)
    transform_train = transforms.Compose([
      transforms.RandomHorizontalFlip(p = 0.5),
      transforms.ColorJitter(brightness=0.5, hue = 0.25),
      transforms.ToTensor(),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])
    if dataset == 'stl10':
        train_dataset = datasets.STL10(root=root_dir, split='train', transform=transforms.ToTensor(), download=True)
        test_dataset = datasets.STL10(root=root_dir, split='test', transform=transforms.ToTensor(), download=True)
    elif dataset == 'cifar10':
        train_dataset = datasets.CIFAR10(root=root_dir, train=True, download=True, transform=transform_train)
        test_dataset = datasets.CIFAR10(root=root_dir, train=False, download=True, transform=transform_test)
    elif dataset == 'tinyimagenet':
        train_dataset = datasets.ImageFolder("./tiny-imagenet-200/train", transform=transform_train)
        test_dataset = datasets.ImageFolder("./tiny-imagenet-200/val", transform=transform_test)

    return train_dataset, test_dataset

def get_model(config):
    model_cls = {'perceiver': Perceiver,
                 'perceiver_io': PerceiverIO
                 }[config.model_type]
    model = model_cls(**config.model_args)
    param_count = sum(np.prod(p.shape).item() for p in model.parameters())
    print(f'Created {config.model_type} model with {param_count} parameters.')
    if config.cuda:
        model.cuda()
    temp = torch.nn.Parameter(torch.rand(config.train['batch_size'], 32, config['num_classes']), requires_grad = True)
    model.register_parameter(name='query_io', param=temp)
    return model

def get_optimizer(model, config):
    optimizer_cls = {'adamw': optim.AdamW,
                     'rmsprop': optim.RMSprop
                     }[config.optimizer_type]
    optimizer = optimizer_cls(model.parameters(), **config.optimizer_args)
    return optimizer

def train(config):
    train_dataset, test_dataset = get_datasets(config.dataset, 'data')
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train.batch_size,
                                               shuffle=True, pin_memory=config.cuda, drop_last=False)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=config.train.batch_size,
                                              shuffle=False, pin_memory=config.cuda, drop_last=False)
    model = get_model(config)
    opt = get_optimizer(model, config)

    total_steps = 0
    for epoch in range(1, config.train.num_epochs+1):
        print(f'Starting epoch {epoch}')

        for _, (x, y) in enumerate(train_loader):
            total_steps += 1
            if len(x) != 128:
              break
            if config.cuda:
                x, y = x.cuda(), y.cuda()
            x = x.permute(0, 2, 3, 1) * 2 - 1
            x = torch.flatten(x, start_dim = 1, end_dim = 3)
            queries = model.query_io.data
            queries = queries.cuda()
            y_hat = model(x[:, None, :], queries = queries).mean(1).squeeze()
            opt.zero_grad()
            loss = F.cross_entropy(y_hat, y)
            loss.backward()
            opt.step()
            if total_steps % 100 == 0:
                print(f'epoch {epoch} step {total_steps} loss {loss.item():.4f}')
        print(f'epoch {epoch} step {total_steps} loss {loss.item():.4f}')

        ep_train_loss, ep_train_acc = evaluate(model, train_dataset, config)
        print(f'epoch {epoch} Train accuracy: {ep_train_acc:.4f} loss: {ep_train_loss:.4f}')
        ep_test_loss, ep_test_acc = evaluate(model, test_dataset, config)
        print(f'epoch {epoch} Test accuracy: {ep_test_acc:.4f} loss: {ep_test_loss:.4f}')
        print()


@torch.no_grad()
def evaluate(model, dataset, config):
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=config.train.batch_size,
                                              shuffle=False, pin_memory=config.cuda,
                                              drop_last=False)
    model.eval()
    losses, preds = [], []
    for x, y in data_loader:
        if len(x) != 128:
            break
        if config.cuda:
            x, y = x.cuda(), y.cuda()
        x = x.permute(0, 2, 3, 1) * 2 - 1
        x = torch.flatten(x, start_dim = 1, end_dim = 3)
        queries = model.query_io.data
        queries = queries.cuda()
        y_hat = model(x[:, None, :], queries = queries).mean(1).squeeze()
        y_pred = y_hat.argmax(axis=-1)
        loss = F.cross_entropy(y_hat, y, reduction='none')
        pred = (y_pred == y).long()
        losses.append(loss)
        preds.append(pred)

    avg_loss = torch.cat(losses).cpu().mean().item()
    accuracy = torch.cat(preds).cpu().float().mean().item()

    return avg_loss, accuracy


In [136]:
config = load_config('perceiver_io')
train(config)

Files already downloaded and verified
Files already downloaded and verified
Created perceiver_io model with 1504652 parameters.
Starting epoch 1
epoch 1 step 100 loss 1.9702
epoch 1 step 200 loss 1.9486
epoch 1 step 300 loss 1.8389
epoch 1 step 391 loss 1.8121
epoch 1 Train accuracy: 0.3877 loss: 1.7115
epoch 1 Test accuracy: 0.4094 loss: 1.6638

Starting epoch 2
epoch 2 step 400 loss 1.5476
epoch 2 step 500 loss 1.8173
epoch 2 step 600 loss 1.5951
epoch 2 step 700 loss 1.5754
epoch 2 step 782 loss 1.5566
epoch 2 Train accuracy: 0.4348 loss: 1.5992
epoch 2 Test accuracy: 0.4426 loss: 1.5640

Starting epoch 3
epoch 3 step 800 loss 1.6381
epoch 3 step 900 loss 1.5764
epoch 3 step 1000 loss 1.5905
epoch 3 step 1100 loss 1.4677
epoch 3 step 1173 loss 1.8356
epoch 3 Train accuracy: 0.4592 loss: 1.5191
epoch 3 Test accuracy: 0.4677 loss: 1.5011

Starting epoch 4
epoch 4 step 1200 loss 1.5768
epoch 4 step 1300 loss 1.4858
epoch 4 step 1400 loss 1.4332
epoch 4 step 1500 loss 1.5959
epoch 4 ste

Starting epoch 1
epoch 1 Train accuracy: 0.2149 loss: 2.1175
epoch 1 Test accuracy: 0.2099 loss: 2.1230

Starting epoch 2
epoch 2 Train accuracy: 0.2630 loss: 2.0121
epoch 2 Test accuracy: 0.2568 loss: 2.0234

Starting epoch 3
epoch 3 Train accuracy: 0.3049 loss: 1.9212
epoch 3 Test accuracy: 0.2926 loss: 1.9452

Starting epoch 4
epoch 4 Train accuracy: 0.3238 loss: 1.8783
epoch 4 Test accuracy: 0.3044 loss: 1.9204

Starting epoch 5
epoch 5 Train accuracy: 0.3507 loss: 1.8155
epoch 5 Test accuracy: 0.3311 loss: 1.8677

Starting epoch 6
epoch 6 Train accuracy: 0.3613 loss: 1.7780
epoch 6 Test accuracy: 0.3315 loss: 1.8519

Starting epoch 7
epoch 7 Train accuracy: 0.3804 loss: 1.7315
epoch 7 Test accuracy: 0.3404 loss: 1.8337

Starting epoch 8
epoch 8 Train accuracy: 0.3959 loss: 1.6926
epoch 8 Test accuracy: 0.3443 loss: 1.8230

Starting epoch 9
epoch 9 Train accuracy: 0.4155 loss: 1.6319
epoch 9 Test accuracy: 0.3509 loss: 1.8061

Starting epoch 10
epoch 10 Train accuracy: 0.4371 loss: 1.5848
epoch 10 Test accuracy: 0.3573 loss: 1.7931

Starting epoch 11
epoch 11 Train accuracy: 0.4526 loss: 1.5444
epoch 11 Test accuracy: 0.3527 loss: 1.8109

Starting epoch 12
epoch 12 Train accuracy: 0.4790 loss: 1.4834
epoch 12 Test accuracy: 0.3623 loss: 1.7963

Starting epoch 13
epoch 13 Train accuracy: 0.4951 loss: 1.4388
epoch 13 Test accuracy: 0.3534 loss: 1.8314

Starting epoch 14
epoch 14 Train accuracy: 0.5201 loss: 1.3662
epoch 14 Test accuracy: 0.3575 loss: 1.8386

Starting epoch 15
epoch 15 Train accuracy: 0.5374 loss: 1.3202
epoch 15 Test accuracy: 0.3539 loss: 1.8746

Starting epoch 16
epoch 16 Train accuracy: 0.5430 loss: 1.2864
epoch 16 Test accuracy: 0.3480 loss: 1.9401

Starting epoch 17
epoch 17 Train accuracy: 0.5807 loss: 1.2018
epoch 17 Test accuracy: 0.3478 loss: 1.9427

Starting epoch 18
epoch 18 Train accuracy: 0.6062 loss: 1.1391
epoch 18 Test accuracy: 0.3426 loss: 1.9859

Starting epoch 19
epoch 19 Train accuracy: 0.6255 loss: 1.0780
epoch 19 Test accuracy: 0.3333 loss: 2.0762

In [8]:
# Batch Size Ablation on Smaller Architecture, 12 epochs, CIFAR-10
# Fixed model architecture for ablation study
# model_args=C(
#             dim = 32,                    # dimension of sequence to be encoded
#             queries_dim = 10,            # dimension of decoder queries
#             logits_dim = 10,            # dimension of final logits
#             depth = 2,                   # depth of net
#             num_latents = 32,           # number of latents, or induced set points, or centroids. different papers giving it different names
#             latent_dim = 64,            # latent dimension
#             cross_heads = 1,             # number of heads for cross attention. paper said 1
#             latent_heads = 8,            # number of heads for latent self attention, 8
#             cross_dim_head = 64,         # number of dimensions per cross attention head
#             latent_dim_head = 64,        # number of dimensions per latent self attention head
#             weight_tie_layers = False,    # whether to weight tie layers (optional, as indicated in the diagram)
#             decoder_ff = False
#         ),
batch_size = [8, 16, 32, 64, 128]
final_test_accuracies = [0.1468, 0.2099, 0.3, 0.3623, 0.3852]



In [None]:
epoch = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]

