## **Classifier model based on PerceiverIO**
## Main driver for The Neuron
Eric Buehler, 2022

In [None]:
import os
from google.colab import drive
# Mount Google Drive and switch into the project folder on it.
drive.mount('/content/drive', force_remount=True)

prefix = '/content/drive/MyDrive/Colab Notebooks/neuron'
# Same path with the space-containing folder quoted (presumably for `!` shell
# commands; not used in the cells shown here).
prefix_ = '/content/drive/MyDrive/"Colab Notebooks"/neuron'
modelname = "4_24_22_m2"

# Per-run output folder for checkpoints and loss/accuracy logs.
prefix_models = f"{prefix}/models/{modelname}/"

if not os.path.exists(prefix_models):
    os.makedirs(prefix_models)
os.chdir(prefix)

In [None]:
import os,sys
import math
import matplotlib.pyplot as plt

import numpy as np
import cv2
import PIL 
from PIL import Image, ImageOps

import pickle
import tqdm


In [None]:
!nvidia-smi -L

In [None]:
!nvidia-smi 

In [None]:
import torch
 
import torchvision
 
import matplotlib.pyplot as plt
import numpy as np
 
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models

import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR, StepLR, LambdaLR

print("Torch version:", torch.__version__)  # record the runtime version for reproducibility

In [None]:
# NOTE(review): each `!` line runs in its own subshell, so `!cd` does not
# change the notebook's working directory (IPython's `%cd` magic would).
# `pip install perceiver-pytorch` installs the package by name, so the cloned
# checkout appears unused here -- confirm before removing the clone.
!git clone https://github.com/lucidrains/perceiver-pytorch
!cd perceiver-pytorch/
!pip install perceiver-pytorch

In [None]:
import torch
from perceiver_pytorch import Perceiver
from perceiver_pytorch import PerceiverIO
 
from torch.utils.data import DataLoader,Dataset
from torchvision.io import read_image
import pandas as pd
from sklearn.model_selection import train_test_split

from PIL import Image

to_pil = transforms.ToPILImage()  # tensor -> PIL image converter, used for previews

In [None]:
from torch.autograd import Variable

In [None]:
# Global settings shared by the cells below.
numchannel = 3   # input image channels (images are converted to RGB on load)
batchSize = 8    # training batch size
CPUonly = False  # True forces the CPU even when CUDA is available

In [None]:
class ImageDataset(Dataset):
    """Map-style dataset yielding ``(mode_label, image_tensor)`` pairs.

    Images are read from ``<dir>/images/image_<id>.png``. Each one is converted
    to RGB and then passed through ``transform``.

    Args:
        dir: dataset root directory.
        mode: DataFrame-like object with ``"Mode"`` and ``"Speed"`` columns,
            row-aligned with ``images``.
        images: iterable of image ids; each is cast to ``int`` when the file
            name is built.
        transform: callable applied to the loaded PIL image.
    """

    def __init__(self, dir, mode, images, transform):
        self.images = list(images)
        self.mode = list(mode["Mode"])
        self.speed = list(mode["Speed"])  # stored, but not returned by __getitem__
        self.transform = transform
        self.dir = dir + "/images/image_"  # file-name prefix, not a directory

    def __len__(self):
        return len(self.images)

    def _image_path(self, index):
        """Return the PNG path of sample ``index``."""
        return f"{self.dir}{int(self.images[index])}.png"

    def __getitem__(self, index):
        # Use a context manager so the file handle is released deterministically.
        # convert() forces a full load, so the returned image no longer needs
        # the file.
        with Image.open(self._image_path(index)) as img:
            rgb = img.convert('RGB')
        return (self.mode[index], self.transform(rgb))

In [None]:
# Presumably the integer encoding of the "Mode" label column (5 values, matching
# nclasses=5 used further down) -- confirm against the data. This bare string
# is documentation only; nothing in the notebook reads it.
"""
Left=1
Right=2
Fwd=3
Bwd=4
Stop=0
"""

In [None]:
import torchvision.transforms as T

data_dir = "./data_v4"  # dataset root: train.csv and images/ live beneath it

def load_split_train_test(data_dir, csvfile, valid_size=.2, batch_size=None,
                          random_state=None):
    """Split the CSV index into train/test sets and wrap them in DataLoaders.

    The CSV's first column holds image ids. The remaining columns hold the
    labels; ImageDataset reads "Mode" and "Speed" from them.

    Args:
        data_dir: dataset root, passed through to ImageDataset.
        csvfile: path to the CSV index.
        valid_size: fraction of rows assigned to the test split.
        batch_size: training batch size; defaults to the notebook-level
            ``batchSize``.
        random_state: forwarded to ``train_test_split``. ``None`` keeps the
            original unseeded split.

    Returns:
        ``(trainloader, testloader)``. The test loader uses batch size 1.
    """
    if batch_size is None:
        batch_size = batchSize

    # Train and test share the same deterministic preprocessing (no augmentation).
    preprocess = transforms.Compose([
        transforms.Resize((80, 160)),
        transforms.ToTensor(),
    ])

    df = pd.read_csv(csvfile)
    labels = df.iloc[:, 1:]
    image_ids = df.iloc[:, 0]
    labels_train, labels_test, ids_train, ids_test = train_test_split(
        labels, image_ids, test_size=valid_size, random_state=random_state)

    train_data = ImageDataset(data_dir, labels_train, ids_train, preprocess)
    test_data = ImageDataset(data_dir, labels_test, ids_test, preprocess)

    # Reshuffle the training set every epoch. Without shuffle=True every epoch
    # iterates the batches in the same fixed order.
    trainloader = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, shuffle=True)
    testloader = torch.utils.data.DataLoader(test_data, batch_size=1)
    return trainloader, testloader

csvfile = f"{data_dir}/train.csv"
# valid_size=.5: half of the CSV rows are held out for testing.
dataloader, testloader = load_split_train_test(data_dir, csvfile, valid_size=.5)
print("Number of training batches: ", len(dataloader),
      "batch size= ", batchSize,
      "total: ", len(dataloader) * batchSize)
print("Number of test batches: ", len(testloader),
      "batch size= ", 1,
      "total: ", len(testloader))

# The training count assumes every batch is full; the test loader uses batch size 1.
print("TOTAL images (account for full batches): ",
      len(dataloader) * batchSize + len(testloader))

In [None]:
# Training batches per epoch, then the sample count if every batch were full.
print(len(dataloader), len(dataloader) * batchSize, sep="\n")

In [None]:
# Show one test image to check the preprocessing; `image` is reused below to
# read the resolution. Named `sample` so the builtin input() is not shadowed.
sample = next(iter(testloader))[1]
image = to_pil(sample[0])
print(image.size)
plt.imshow(image)
plt.show()

In [None]:
im_resx, im_resy = image.size  # PIL reports (width, height)

In [None]:
# Use the GPU when one is visible, unless CPUonly forces the CPU.
if CPUonly == True:
    device = torch.device("cpu")
    print("CPU")
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
from torch.nn.modules.dropout import Dropout2d


class network(nn.Module):
    """PerceiverIO-based image classifier.

    Shapes below assume ``x`` is ``(B, numchannel, im_resy, im_resx)``. That is
    the resolution the positional index buffers and ``convert_down`` are sized
    for.

    1. ``convert_up`` lifts each pixel's channels to ``embed_dim`` features.
    2. A conv encoder/decoder stack (embed_dim -> embed_dim_min -> embed_dim)
       produces a second per-pixel feature map.
    3. Both maps are flattened to one token per pixel and concatenated with
       learned row/column position embeddings, giving ``4 * embed_dim``
       features per token.
    4. ``query_gen`` derives one query per token. PerceiverIO maps the tokens
       and queries to ``nclasses`` logits per token.
    5. ``convert_down`` collapses the token axis to one score per class. A
       softmax is then applied.

    Note: the stacked ``nn.Linear`` layers in ``convert_up``, ``query_gen`` and
    ``convert_down`` have no non-linearities between them, so each stack is
    equivalent to a single affine map. The layers are kept unchanged so
    parameter names stay compatible with existing checkpoints.

    Relies on the notebook globals ``numchannel``, ``im_resx``, ``im_resy`` and
    ``device``.
    """

    def __init__(self, nclasses):
        super().__init__()
        self.queries_dim_ = 256
        latent_dim_ = 784
        embed_dim = 128
        embed_dim_min = 16
        self.embed_dim_ = embed_dim

        # Per-pixel channel lift: numchannel -> embed_dim.
        self.convert_up = nn.Sequential(
            nn.Linear(numchannel, 4),
            nn.Linear(4, 8),
            nn.Linear(8, 16),
            nn.Linear(16, 32),
            nn.Linear(32, 48),
            nn.Linear(48, embed_dim),
        )

        # Encoder: halve the channel count until embed_dim_min is reached ...
        self.encoders = []
        width = embed_dim
        while width > embed_dim_min:
            self.encoders.append(
                nn.Conv2d(in_channels=width, out_channels=width // 2,
                          kernel_size=3, stride=1, padding=1).to(device))
            width //= 2

        # ... decoder: double it back up to embed_dim.
        self.decoders = []
        width = embed_dim_min
        while width < embed_dim:
            self.decoders.append(
                nn.Conv2d(in_channels=width, out_channels=width * 2,
                          kernel_size=3, stride=1, padding=1).to(device))
            width *= 2

        # Registered in the original order so the parameter names (convs.<i>)
        # match existing checkpoints.
        self.convs = nn.ModuleList(
            [*self.encoders, nn.BatchNorm2d(embed_dim_min),
             *self.decoders, nn.BatchNorm2d(embed_dim)])

        # Per-token query generator: embed_dim*4 -> queries_dim_.
        self.query_gen = nn.Sequential(
            nn.Linear(embed_dim * 4, 384),
            nn.Linear(384, 512),
            nn.Linear(512, 640),
            nn.Linear(640, 784),
            nn.Linear(784, 1024),
            nn.Linear(1024, 784),
            nn.Linear(784, 640),
            nn.Linear(640, 512),
            nn.Linear(512, self.queries_dim_),
        )

        # Collapses the token axis: im_resx*im_resy -> 1.
        self.convert_down = nn.Sequential(
            nn.Linear(im_resx * im_resy, 2048),
            nn.Linear(2048, 1024),
            nn.Linear(1024, 784),
            nn.Linear(784, 512),
            nn.Linear(512, 256),
            nn.Linear(256, 128),
            nn.Linear(128, 1),
        )

        # pos_emb_x has one row per image ROW (im_resy entries) and pos_emb_y
        # one row per image COLUMN (im_resx entries). The attribute names are
        # kept for checkpoint compatibility.
        self.pos_emb_x = nn.Embedding(im_resy, embed_dim * 1)
        self.pos_emb_y = nn.Embedding(im_resx, embed_dim * 1)

        # Token k of a feature map flattened from (H=im_resy, W=im_resx) sits at
        # row k // im_resx and column k % im_resx. The previous loop built these
        # indices with the axes transposed, i.e. (k % im_resy, k // im_resy), so
        # the "row"/"column" codes did not follow the pixel grid. Models
        # trained with the old indexing will see different position codes.
        # Non-persistent buffers follow model.to(...) without changing the
        # state_dict.
        rows = torch.arange(im_resy, dtype=torch.long).repeat_interleave(im_resx)
        cols = torch.arange(im_resx, dtype=torch.long).repeat(im_resy)
        self.register_buffer("pos_matrix_i", rows, persistent=False)
        self.register_buffer("pos_matrix_j", cols, persistent=False)

        self.model = PerceiverIO(
            dim=embed_dim * 4,
            queries_dim=self.queries_dim_,
            logits_dim=nclasses,
            depth=12,
            num_latents=512,
            latent_dim=latent_dim_,
            cross_heads=1,
            latent_heads=8,
            cross_dim_head=64,
            latent_dim_head=64,
            weight_tie_layers=False,
            decoder_ff=True,
        ).to(device)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return class probabilities of shape ``(B, nclasses)``.

        The batch axis is kept even when ``B == 1``. The previous
        ``squeeze_()`` dropped it, returning ``(nclasses,)``.
        """
        # (B, C, H, W) -> (B, H, W, C) so the Linear stack acts on channels, then back.
        x = self.convert_up(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        batch = x.shape[0]

        # Conv branch on the (B, embed_dim, H, W) map.
        conv_feats = x
        for layer in self.convs:
            conv_feats = layer(conv_feats)

        # One token per pixel: (B, E, H, W) -> (B, H*W, E).
        tokens = torch.flatten(x, start_dim=2).permute(0, 2, 1)
        conv_tokens = torch.flatten(conv_feats, start_dim=2).permute(0, 2, 1)

        # .to(x.device) is a no-op for the buffers, but also keeps models
        # unpickled from older versions (plain-tensor attributes) working.
        rows = self.pos_matrix_i.to(x.device)
        cols = self.pos_matrix_j.to(x.device)
        pos_emb_y = self.pos_emb_y(cols).unsqueeze(0).expand(batch, -1, -1)
        pos_emb_x = self.pos_emb_x(rows).unsqueeze(0).expand(batch, -1, -1)

        inputs = torch.cat([tokens, conv_tokens, pos_emb_y, pos_emb_x], dim=2)
        queries = self.query_gen(inputs)
        # Presumably (B, H*W, nclasses): one logit vector per query -- per the
        # PerceiverIO API; convert_down's input width relies on it.
        outputs = self.model(inputs, queries=queries)

        # (B, N, K) -> (B, K, N) -> (B, K, 1) -> (B, K); only the token axis is squeezed.
        outputs = self.convert_down(outputs.permute(0, 2, 1)).squeeze(-1)
        return self.softmax(outputs)
        

In [None]:
print(im_resx * im_resy)  # pixels per image == tokens per image fed to the model

In [None]:
nclasses = 5    # number of output classes / width of the one-hot labels
startepoch = 0  # epoch index the training loop starts from

model = network(nclasses)
# nn.Module.to() moves the module in place and returns it, so `params` is the
# same object as `model`.
params = model.to(device)

In [None]:
from prettytable import PrettyTable

def count_parameters(model):
    """Print a per-tensor table of trainable parameters, then overall totals.

    Args:
        model: a ``torch.nn.Module``. Only tensors with ``requires_grad`` get
            a table row, but every parameter counts towards the totals.
    """
    table = PrettyTable(["Modules", "Parameters"])
    total = 0
    trainable = 0
    for name, tensor in model.named_parameters():
        count = tensor.numel()
        total += count
        if tensor.requires_grad:
            trainable += count
            table.add_row([name, count])
    print(table)

    print(f"Total params: {total}  Trainable params: {trainable}  Untrainable params: {total - trainable}")

In [None]:
count_parameters(model=model)

In [None]:
!nvidia-smi

In [None]:
# Loss: mean squared error between the model's softmax output and one-hot labels.
criterion = nn.MSELoss()

optimizer = optim.AdamW(model.parameters(), lr=5e-6)

# Multiply the learning rate by 0.9 every 2 scheduler steps.
scheduler = StepLR(optimizer, step_size=2, gamma=0.9)


In [None]:
print(im_resx, im_resy, sep="\n")  # width, then height

In [None]:
epochs = 190
train_losses, test_losses, val_acc = [], [], []
steps = 0
print_every = len(dataloader)  # batches per epoch: running_loss / print_every is the mean batch loss
running_loss = 0.0

# Exact sample counts. len(loader) * batchSize over-counts whenever the last
# training batch is partial, which under-reports the training accuracy.
n_train = len(dataloader.dataset)
n_test = len(testloader.dataset)

torch.cuda.empty_cache()

for epoch in range(startepoch, epochs):
    train_losses_epoch, test_losses_epoch = [], []
    correct_train, correct_test = 0, 0
    print(f"Epoch {epoch+1}/{epochs}")

    # ---- training pass ----
    for mode, inputs in tqdm.tqdm(dataloader):
        optimizer.zero_grad()
        steps += 1

        labels = F.one_hot(mode.long(), num_classes=nclasses)
        inputs, labels = inputs.to(device), labels.to(device)
        # Guard against the model dropping the batch axis for a batch of one:
        # force the output to the one-hot labels' (batch, nclasses) shape.
        outputs = model(inputs).reshape(labels.shape)
        loss = criterion(outputs.float(), labels.float())

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Plain Python ints, so the logs read "12.5" rather than "tensor(12.5)".
        correct_train += (outputs.argmax(-1) == labels.argmax(-1)).sum().item()

    accuracy_train = 100 * correct_train / n_train

    # ---- evaluation pass ----
    test_loss = 0
    model.eval()

    print("\nNow evaluate test batches...")
    with torch.no_grad():
        for mode, inputs in testloader:
            labels = F.one_hot(mode.long(), num_classes=nclasses)
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs).reshape(labels.shape)
            batch_loss = criterion(outputs.float(), labels.float())
            test_loss += batch_loss.item()
            correct_test += (outputs.argmax(-1) == labels.argmax(-1)).sum().item()

    accuracy_test = 100 * correct_test / n_test

    train_losses.append(running_loss/print_every)
    train_losses_epoch.append(running_loss/print_every)
    test_losses.append(test_loss/len(testloader))
    test_losses_epoch.append(test_loss/len(testloader))

    print(f"Epoch {epoch+1}/{epochs} "
            f"Train loss: {running_loss/print_every:.6f} "
            f"Test loss: {test_loss/len(testloader):.6f} "
            f"Train accuracy: {accuracy_train}% "
            f"Test accuracy: {accuracy_test}% "
            )

    running_loss = 0
    model.train()

    # Append this epoch's metrics to the per-run log files.
    with open(prefix_models+"train_loss.txt","a") as file:
        for item in train_losses_epoch:
            file.write(f"E{epoch}_{item}\n")

    with open(prefix_models+"test_loss.txt","a") as file:
        for item in test_losses_epoch:
            file.write(f"E{epoch}_{item}\n")

    with open(prefix_models+"train_acc.txt","a") as file:
        file.write(f"E{epoch}_{accuracy_train}\n")

    with open(prefix_models+"test_acc.txt","a") as file:
        file.write(f"E{epoch}_{accuracy_test}\n")

    scheduler.step()

    # Checkpoint the whole pickled module every epoch.
    fgg=f"model_E{epoch}.pth"
    namesve = prefix_models+fgg
    torch.save(model, namesve)


print('Finished Training')

In [None]:
# Inference demo on one test sample. The training loop leaves the model in
# train mode; switch to eval so BatchNorm uses its running statistics, and skip
# building an autograd graph. `sample` avoids shadowing the builtin input().
model.eval()
label, sample = next(iter(testloader))
with torch.no_grad():
    outputs = model(sample.to(device))
print("OUTPUTS",outputs)
outputs_ = outputs.argmax(-1).cpu()
print("PRED",outputs_)

image = to_pil(sample[0])
print(image.size)
plt.imshow(image)
plt.show()

print("REAL",label)

In [None]:

# Save the whole pickled module (not just its state_dict) after training.
fgg = "model_final.pth"
namesve = f"{prefix_models}{fgg}"
torch.save(model, namesve)


In [None]:
# Print raw outputs plus predicted vs. true class for every test sample.
# Run in eval mode without gradients; the optimizer is not involved here.
model.eval()
with torch.no_grad():
    for mode, inputs in testloader:
        labels = F.one_hot(mode.long(), num_classes=nclasses)

        inputs, labels = inputs.to(device), labels.to(device)
        # Align with the one-hot labels' (batch, nclasses) shape in case the
        # model dropped the batch axis for a batch of one.
        outputs = model(inputs).reshape(labels.shape)
        print(outputs)

        loss = criterion(outputs.float(), labels.float())
        pred_mode = outputs.argmax(-1).cpu().numpy()
        real_mode = labels.argmax(-1).cpu().numpy()
        print("Pred:",pred_mode," Real:",real_mode)