In [52]:
from perceiver_pytorch import Perceiver
import torch
import torchvision
import torchvision.transforms as transforms

In [53]:
# Pick the compute device once: CUDA when PyTorch reports it as available,
# otherwise the CPU. The chosen device is echoed so it shows in the cell output.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [68]:
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import numpy as np

# Training-time pipeline: RandomCrop to 32 with padding=4, a random horizontal
# flip, then conversion to a tensor.
train_augmentations = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
]
transform_train = transforms.Compose(train_augmentations)

# Evaluation pipeline: tensor conversion only, no augmentation.
transform_test = transforms.Compose([transforms.ToTensor()])

# CIFAR-10 train and test splits rooted at ./data, with download enabled.
cifar_kwargs = dict(root='./data', download=True)
trainset = torchvision.datasets.CIFAR10(train=True, transform=transform_train, **cifar_kwargs)
testset = torchvision.datasets.CIFAR10(train=False, transform=transform_test, **cifar_kwargs)

#x_train= trainset.data
#x_test= testset.data
#y_train= np.array(trainset.targets)
#y_test= np.array(testset.targets)

Files already downloaded and verified
Files already downloaded and verified


In [69]:
# Both loaders use mini-batches of 5 samples and num_workers=2; only the
# training loader shuffles, the test loader keeps dataset order.
loader_kwargs = dict(batch_size=5, num_workers=2)
train_loader = DataLoader(trainset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(testset, shuffle=False, **loader_kwargs)

In [70]:

# Perceiver hyperparameters, collected in one mapping so they can be read (and
# tweaked) in a single place before the model is instantiated.
perceiver_config = dict(
    input_channels=3,            # channels carried by each input token
    input_axis=2,                # number of input axes (2 for images, 3 for video)
    num_freq_bands=6,            # number of frequency bands, with original value (2 * K + 1)
    max_freq=10.0,               # maximum frequency; depends on how fine the data is
    depth=6,                     # depth of the network
    num_latents=32,              # number of latents (a.k.a. induced set points / centroids)
    latent_dim=128,              # latent dimension
    cross_heads=1,               # cross-attention heads (the paper used 1)
    latent_heads=2,              # latent self-attention heads (the original note mentions 8)
    cross_dim_head=8,            # presumably the per-head dim for cross attention; confirm in perceiver_pytorch
    latent_dim_head=8,           # presumably the per-head dim for latent self attention; confirm in perceiver_pytorch
    num_classes=10,              # number of output classes
    attn_dropout=0.0,
    ff_dropout=0.0,
    weight_tie_layers=False,     # whether to weight-tie layers (optional)
    fourier_encode_data=True,    # auto Fourier-encode the data using input_axis; turn off if encoding it yourself
    self_per_cross_attn=2,       # self-attention blocks per cross-attention block
)
model = Perceiver(**perceiver_config)

In [None]:
import torch.optim as optim

# Train on the device selected at the top of the notebook instead of a
# hard-coded CPU, so the model ends up on the same device the evaluation cell
# sends its inputs to. The model is moved before the optimizer is built.
dev = device
model.to(dev)

num_epochs = 50
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.09, momentum=0.9, weight_decay=5e-4)
# Cosine learning-rate schedule spanning the full run; stepped once per epoch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)

for epoch in range(num_epochs):  # loop over the dataset multiple times
    model.train()
    train_loss = 0.0   # sum of per-batch losses for this epoch
    total = 0          # samples seen this epoch
    correct = 0        # correctly classified samples this epoch

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(dev), targets.to(dev)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize; the batch is permuted from
        # (N, C, H, W) to (N, H, W, C) before being fed to the model
        outputs = model(inputs.permute(0, 2, 3, 1))
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # running statistics
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        print(".", end="")

    scheduler.step()

    # Epoch metrics: mean batch loss, and accuracy as correct / total samples.
    # (Averaging the running correct/total ratio over batches, as before,
    # does not give the epoch accuracy.)
    train_loss_epoch = train_loss / max(len(train_loader), 1)
    train_accuracy = correct / total if total else 0.0
    train_acc_epoch = train_accuracy

    print("\n")
    print('Epoch {} Train Loss: {:.3f} | Train Acc: {:.3f}'.format(epoch+1, train_loss_epoch, train_acc_epoch), end="\n")
print('Finished Training')

In [None]:
# Evaluate the trained model on the test split.
total = 0      # samples evaluated
correct = 0    # correctly classified samples

# eval() must be *called*; a bare `model.eval` is a no-op attribute access and
# leaves the model in training mode.
model.eval()

# Run on whatever device the model's parameters live on, so inputs and weights
# can never end up on different devices.
eval_device = next(model.parameters()).device

with torch.no_grad():
    for inputs, targets in test_loader:
        inputs, targets = inputs.to(eval_device), targets.to(eval_device)
        # the batch is permuted from (N, C, H, W) to (N, H, W, C) for the model
        outputs = model(inputs.permute(0, 2, 3, 1))
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

# Accuracy is correct / total over the whole split, reported once; summing the
# running ratio per batch would grow past 1.0.
test_accuracy = correct / total if total else 0.0
print(f"test accuracy: {test_accuracy}")