# **PerceiverIO Classifier**
# Eric Buehler 2022

# Import libraries and download the Perceiver library from GitHub
# Set up the environment for training

Import libraries

In [None]:
import os,sys
import re

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import tqdm
from PIL import Image
from sklearn.model_selection import train_test_split
from torch import nn
from torch import optim
from torch.optim.lr_scheduler import ExponentialLR, StepLR
from torch.utils.data import DataLoader,Dataset
from torchvision import datasets, transforms, models

from google.colab import drive

# Shared converter used to turn (C, H, W) image tensors into PIL images for plotting.
to_pil = transforms.ToPILImage()

print(f"Torch version: {torch.__version__}")

Display information on current GPU

In [None]:
# List the GPU(s) attached to this runtime.
!nvidia-smi -L

In [None]:
# Full nvidia-smi status report for the attached GPU(s).
!nvidia-smi 

Set up the environment and install the Perceiver library

Note: this cell creates the "Colab Notebooks/PercieverIO_Classifier" directory (with `images/` and `models/` subdirectories for the current run) in your Google Drive.

In [None]:
# Run name: used as the subdirectory for this run's sample images and checkpoints.
modelname="2_4_22_m1"


drive.mount('/content/drive')

# All outputs live under this Google Drive directory (trailing slash is relied on
# by the string concatenations below and in later cells).
prefix='/content/drive/MyDrive/Colab Notebooks/PercieverIO_Classifier/'
prefix_images=prefix+'images/'
prefix_models=prefix+'models/'+modelname+"/"

# makedirs(exist_ok=True) creates any missing parent directories and is a no-op
# for directories that already exist, so this cell can be re-run safely.
for _dir in (prefix, prefix_images+modelname+"/", prefix_models):
    os.makedirs(_dir, exist_ok=True)

# Set True to force CPU execution even when a GPU is available.
CPUonly=False

In [None]:
!git clone https://github.com/lucidrains/perceiver-pytorch
!cd perceiver-pytorch/
!pip install perceiver-pytorch

os.chdir(prefix)

In [6]:
from perceiver_pytorch import Perceiver
from perceiver_pytorch import PerceiverIO

# Important variables




In [7]:
# Input geometry: MNIST digits are 28x28 single-channel images.
im_res = 28        # image height/width in pixels
numchannel = 1     # channels per input image
dim_ = 1           # channel count used when reshaping flattened inputs back into images for display

# Load a previously saved per-epoch checkpoint instead of building a fresh model.
autoload = False

# DataLoader batch sizes.
batch_size_train = 64
batch_size_test = 1000

# Set up dataloaders, display a sample image, and define other key variables

Set up dataloaders

In [None]:

# Both splits share the same preprocessing: tensor conversion, then normalisation
# with mean 0.1307 / std 0.3081 (later display code must undo this).
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])
data_root = os.path.join(prefix, 'files')

trainloader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(data_root, train=True, download=True, transform=mnist_transform),
    batch_size=batch_size_train, shuffle=True)

# Presumably shuffled so the sample test images plotted each epoch vary.
testloader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(data_root, train=False, download=True, transform=mnist_transform),
    batch_size=batch_size_test, shuffle=True)

print ("Number of training batches: ", len(trainloader), "batch size= ", batch_size_train, "total: ",len(trainloader)*batch_size_train )
print ("Number of test batches: ", len(testloader), "batch size= ", batch_size_test, "total: ",len(testloader)*batch_size_test)

# Upper bound (assumes every batch is full); each split uses its own batch size.
print("TOTAL images (account for full batches): ", len(trainloader)*batch_size_train+len(testloader)*batch_size_test )

nclasses=len(trainloader.dataset.classes)
print("Total classes: ",nclasses)

Display sample image

In [None]:
# Show one test image. The dataloader applies Normalize((0.1307,), (0.3081,)),
# but ToPILImage expects float tensors in [0, 1] (it scales by 255 and casts to
# uint8), so undo the normalisation first; otherwise out-of-range values do not
# map to valid grey levels. Named `sample_batch` to avoid shadowing builtin `input`.
sample_batch = next(iter(testloader))[0]
sample = (sample_batch[0] * 0.3081 + 0.1307).clamp(0, 1)
image = to_pil(sample)
print(image.size)
plt.imshow(image,cmap='gray')
plt.axis('off')
plt.show()

Define device

In [10]:
# Honour the CPUonly switch from the setup cell; otherwise prefer a GPU when present.
if CPUonly == True:
    print ("Using CPU.")
    device = torch.device("cpu")
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Autoload/define the model and set up the criterion, optimizer, and scheduler

Autoload/define model

In [11]:
# Find the epoch numbers of per-epoch checkpoints ("model_epoch_<N>.pth") saved by
# the training cell. Other files in the directory (the *.txt logs and
# model_final.pth) are ignored; "model_final.pth" does not match the pattern.
saved_epochs = [
    int(match.group(1))
    for match in (re.fullmatch(r"model_epoch_(\d+)\.pth", item) for item in os.listdir(prefix_models))
    if match
]
maxepoch = max(saved_epochs, default=0)

if autoload and saved_epochs:
    # Resume from the newest checkpoint and continue with the following epoch
    # (the training cell saves model_epoch_<epoch>.pth after finishing <epoch>).
    # NOTE(review): checkpoints are whole pickled modules (torch.save(model)). On
    # PyTorch >= 2.6 torch.load defaults to weights_only=True and may refuse them;
    # pass weights_only=False there, and only for trusted, self-written files.
    startepoch = maxepoch + 1
    model = torch.load(f"{prefix_models}model_epoch_{maxepoch}.pth", map_location=device)
else:
    if autoload:
        print("autoload=True but no saved checkpoints were found; building a new model.")
    startepoch = 0

    model = PerceiverIO(
        dim = im_res*im_res*numchannel,   # dimension of sequence to be encoded (whole flattened image)
        queries_dim = 32,                 # dimension of decoder queries (must match the queries cell)
        logits_dim = nclasses,            # dimension of final logits
        depth = 6,                        # depth of net
        num_latents = 256,                # number of latents, or induced set points, or centroids. different papers giving it different names
        latent_dim = 512,                 # latent dimension
        cross_heads = 1,                  # number of heads for cross attention. paper said 1
        latent_heads = 8,                 # number of heads for latent self attention, 8
        cross_dim_head = 64,              # number of dimensions per cross attention head
        latent_dim_head = 64,             # number of dimensions per latent self attention head
        weight_tie_layers = False         # whether to weight tie layers (optional, as indicated in the diagram)
    )
    model.to(device)

Define criterion, optimizer, and scheduler

In [12]:
# Loss: mean squared error between the model's logits and one-hot targets.
criterion = nn.MSELoss(reduction="mean")

# Adam over all model parameters; the learning rate is multiplied by 0.9 on each
# scheduler.step(), which the training loop calls once per epoch.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)
scheduler = ExponentialLR(optimizer=optimizer, gamma=0.9)

Define queries

In [13]:
# One all-zero decoder query of width 32 (must equal the model's queries_dim), so the
# model presumably emits one logit vector per image, matching the (batch, 1, nclasses)
# one-hot labels built in the training loop.
queries = torch.zeros((1, 32), device=device)

# Utility functions

Unflatten

In [14]:
# Reshapes (batch, pixels) tensors back into (batch, dim_, im_res, im_res) images.
m2 = nn.Sequential(
    nn.Unflatten(dim=1, unflattened_size=(dim_, im_res, im_res)),
)

Unflattening function (reshapes flattened model inputs back into images for display)

In [15]:
def conv_flattened_to_image_ind(inputs):
    """Turn flattened model inputs back into image tensors for display.

    Expects the (batch, 1, pixels) layout the train/eval loops feed the model
    (flatten + unsqueeze(1)); returns (batch, dim_, im_res, im_res) via `m2`.
    The input tensor is not modified (a clone is reshaped).
    """
    # (batch, 1, pixels) -> (batch, pixels, 1) -> (batch, pixels)
    as_columns = inputs.clone().permute(0, 2, 1)
    flat = as_columns.flatten(start_dim=1, end_dim=2)
    return m2(flat)


# Train

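Train the model.
The model is saved after every epoch, and per-epoch losses, test accuracy, and sample predictions are displayed and logged to the run's `models/` and `images/` directories.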

In [None]:
#@title Train the model
# Train from `startepoch` (set by the autoload cell) up to `epochs`. After each epoch
# the whole test set is evaluated, sample predictions from the first test batch are
# plotted, the model is checkpointed, and losses/accuracy are appended to text logs.
epochs = 50   # epochs startepoch .. epochs-1 are trained
numimgs = 8   # test images plotted per epoch
steps = 0     # running count of training batches

# Inverse of the dataloaders' Normalize((0.1307,), (0.3081,)); keep in sync with that cell.
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081

# NOTE(review): on resume the optimizer and scheduler are rebuilt from scratch, so
# Adam state and the decayed learning rate are not restored.

torch.cuda.empty_cache()


def _prepare_batch(inputs, labels):
    """Reshape a DataLoader batch into the layout the model and loss use.

    Images (batch, C, H, W) are permuted to channels-last and flattened into a
    single token: (batch, 1, H*W*C). Labels become one-hot tensors of shape
    (batch, 1, nclasses). Both are moved to `device`.
    """
    inputs = torch.permute(inputs, (0, 2, 3, 1))
    inputs = torch.flatten(inputs, start_dim=1, end_dim=3).unsqueeze(1)
    onehot = F.one_hot(labels, num_classes=nclasses).unsqueeze(1)
    return inputs.to(device), onehot.to(device)


def _show_predictions(inputs, predicted, actual, epoch):
    """Plot, save and show up to `numimgs` test images with predicted and real classes.

    inputs:    flattened model inputs, (batch, 1, pixels), on any device
    predicted: predicted class index per image, (batch,)
    actual:    true class index per image, (batch,)
    """
    # Undo normalisation so to_pil receives values in [0, 1].
    images = (conv_flattened_to_image_ind(inputs).cpu() * MNIST_STD + MNIST_MEAN).clamp(0, 1)
    count = min(numimgs, len(images))
    ncols = 3
    nrows = max(1, (count + ncols - 1) // ncols)  # enough rows for `count` panels

    fig = plt.figure(figsize=(8, 3 * nrows))
    for i in range(count):
        pred, real = int(predicted[i]), int(actual[i])
        print ("Image # from batch considered: ", i)
        print(f"Predicted class: {pred}")
        print(f"Real class: {real}")
        ax = fig.add_subplot(nrows, ncols, i + 1)
        ax.set_title(f"Image sample\nPredicted class: {pred}\nReal class: {real}")
        ax.axis('off')
        ax.imshow(to_pil(images[i]), cmap='gray')
    fig.tight_layout()
    fig.savefig(prefix_images+f"{modelname}/image_epoch_{epoch}.png")
    plt.show()
    plt.close(fig)  # avoid accumulating one open figure per epoch


for epoch in range(startepoch, epochs):
    train_losses_epoch, test_losses_epoch, val_acc = [], [], []

    print(f"Epoch {epoch+1}/{epochs}")

    # --- Training ---
    model.train()
    for inputs, labels in tqdm.tqdm(trainloader):
        steps += 1
        inputs, labels = _prepare_batch(inputs, labels)

        optimizer.zero_grad()
        outputs = model(inputs, queries=queries)
        loss = criterion(outputs.float(), labels.float())
        loss.backward()
        optimizer.step()

        train_losses_epoch.append(loss.item())

    # --- Evaluation: every test batch counts towards loss and accuracy ---
    model.eval()
    print("\nNow evaluate test batches...")
    correct, total = 0, 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, labels = _prepare_batch(inputs, targets)
            targets = targets.to(device)

            outputs = model(inputs, queries=queries)
            test_losses_epoch.append(criterion(outputs.float(), labels.float()).item())

            # A single decoder query, so outputs[:, 0, :] holds each image's class scores.
            predicted = outputs[:, 0, :].argmax(dim=-1)
            correct += (predicted == targets).sum().item()
            total += targets.numel()

            # Plot sample predictions from the first test batch only.
            if batch_idx == 0:
                _show_predictions(inputs, predicted, targets, epoch)

    print("Evaluation of test images done.")
    # Accuracy over all test images (not a mean of per-batch means, which would
    # mis-weight a smaller final batch).
    val_acc.append(correct / total)

    print(f"Epoch {epoch+1}/{epochs}\n"
            f"Train loss: {sum(train_losses_epoch)/len(train_losses_epoch):.6f}\n "
            f"Test loss: {sum(test_losses_epoch)/len(test_losses_epoch):.6f}\n "
            f"Accuracy: {val_acc[-1]:.6f}"
            )

    # Return to train mode before checkpointing.
    model.train()

    # Save model for current epoch (the autoload cell resumes from these files).
    namesve = prefix_models+f"model_epoch_{epoch}.pth"
    torch.save(model, namesve)

    # Append per-batch losses and the epoch accuracy, each line prefixed "E<epoch>_".
    for log_name, values in (("train_loss.txt", train_losses_epoch),
                             ("test_loss.txt", test_losses_epoch),
                             ("accuracy.txt", val_acc)):
        with open(prefix_models+log_name, "a") as file:
            file.writelines(f"E{epoch}_{item}\n" for item in values)

    scheduler.step()

Save final model

In [None]:
# Persist the model as it stands after the last training epoch.
namesve = os.path.join(prefix_models, "model_final.pth")
torch.save(model, namesve)

print('Finished Training')