In [1]:
#!pip install import-ipynb
import import_ipynb
from RetinaUtils import Cortex
from PCN import PCN

import numpy as np
import cv2
import zmq
import pickle
from matplotlib import pyplot as plt
import zlib

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import transforms

importing Jupyter notebook from RetinaUtils.ipynb
importing Jupyter notebook from PCN.ipynb


In [2]:
#Check if gpu is available
gpu = torch.cuda.is_available()
device = torch.device('cuda:0' if gpu else 'cpu')

In [3]:
#Load Piotr's Cortex
C = Cortex(gpu=gpu)
data_dir = '../retina_data/cortices/'
C.loadLocs(data_dir+'50k_Lloc_tight.pkl', data_dir+'50k_Rloc_tight.pkl')
C.loadCoeffs(data_dir+'50k_Lcoeff_tight.pkl', data_dir+'50k_Rcoeff_tight.pkl')

In [4]:
#Load the pre-trained model
model = PCN(num_classes=11, circles=4)
model_dict = torch.load('../models/final_PCN_cortical.pt', map_location=device)
model.load_state_dict(model_dict['model_state_dict'])
model.to(device)

PCN(
  (PC_block1): Sequential(
    (0): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): PcConv(
      (relu): ReLU(inplace=True)
      (FFconv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (FBconv): ConvTranspose2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bypass): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    )
  )
  (PC_block2): Sequential(
    (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): PcConv(
      (relu): ReLU(inplace=True)
      (FFconv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (FBconv): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bypass): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    )
  )
  (PC_block3): Sequential(
    (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_r

In [18]:
#Connect to the Client
context = zmq.Context()
socket = context.socket(zmq.PULL)
socket.bind("tcp://*:5555")

In [19]:
#Get number of batches being sent
dataset_length = int(socket.recv())

#Used to calculate final accuracy
correct=0
total=0

#Receive the images
for i in range(0, dataset_length):
    
    data = socket.recv()
    #Decompress and load the retinal vector
    data = zlib.decompress(data)
    vectors, labels = pickle.loads(data)
    #Backproject into cortical image
    batch = []
    for V in vectors:
        #Model expects normalised images, so divide by 255.0
        cortical_img = C.cort_img(V).astype('float32')/255.0
        batch.append(cortical_img)
    
    #Get the batches to the correct dimensions
    batch = torch.as_tensor(batch).permute(0,3,1,2).to(device)
    batch = F.interpolate(batch, (64,64))
    labels = torch.as_tensor(labels).float().to(device)
    
    #Calculate accuracy
    with torch.no_grad():
        #Get prediction
        out = model(batch)
        _, predicted = torch.max(out.data, 1)
        total += out.size(0)
        #Check if its correct
        correct += (predicted==labels).sum().item()
    
#Print final accuracy
print(f'Accuracy: {100.0*correct/total}')

Accuracy: 90.0
