In [12]:
%matplotlib inline
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader,Dataset
import matplotlib.pyplot as plt
import torchvision.utils
import numpy as np
import random
from PIL import Image
import torch
from torch.autograd import Variable
import PIL.ImageOps    
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from pathlib import Path
import torch.nn.init as init

class SiameseNetwork(nn.Module):
    """Siamese embedding network used for nearest-reference image classification.

    Both images of a pair are pushed through the same weights (``forward_once``):
    a two-stage convolution/average-pool feature extractor followed by six fully
    connected layers that end in a 2-dimensional embedding.  ``classify``
    embeds a query image next to every reference image in ``ground_truth`` and
    returns the label of the reference whose embedding is closest.

    Attributes:
        ground_truth: iterable of reference image tensors; ``None`` until the
            caller assigns it.  Must be set before ``classify`` is used.
        labels: labels indexed in parallel with ``ground_truth``; ``None`` until
            the caller assigns it.
        img_transform_pre: PIL image -> resized to 128x128 -> tensor.
        img_transform_post: tensor/ndarray -> PIL image -> resized to 128x128
            -> tensor.
    """

    def __init__(self):
        super(SiameseNetwork, self).__init__()
        self.ground_truth = None
        self.labels = None
        self.img_transform_pre = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor()
        ])
        self.img_transform_post = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((128, 128)),
            transforms.ToTensor()
        ])
        # For a 128x128 input: conv5 -> 124, pool -> 62, conv5 -> 58, pool -> 29,
        # which is where the 2 * 29 * 29 input size of ``f1`` comes from.
        self.cnn1 = nn.Sequential(
            nn.Conv2d(3, 6, kernel_size=5),
            nn.PReLU(),
            nn.AvgPool2d(2, stride=2),
            nn.Conv2d(6, 2, kernel_size=5),
            nn.PReLU(),
            nn.AvgPool2d(2, stride=2))

        self.f1 = nn.Linear(2 * 29 * 29, 1024)
        self.t = nn.PReLU()
        self.f2 = nn.Linear(1024, 256)
        self.f3 = nn.Linear(256, 64)
        self.f4 = nn.Linear(64, 16)
        self.f5 = nn.Linear(16, 4)
        self.f6 = nn.Linear(4, 2)

        # ``init.xavier_uniform`` is the deprecated spelling of the in-place
        # ``init.xavier_uniform_``; use the new name when this torch version has
        # it and fall back to the old one otherwise.
        xavier_init = getattr(init, 'xavier_uniform_', None) or init.xavier_uniform
        for layer in (self.f1, self.f2, self.f3, self.f4, self.f5, self.f6):
            xavier_init(layer.weight, gain=np.sqrt(2))

    def forward_once(self, x):
        """Embed one batch of images.

        Args:
            x: batch of image tensors accepted by ``cnn1`` (3 input channels;
                128x128 pixels so that the flattened size matches ``f1``).

        Returns:
            Tensor with one 2-dimensional embedding per batch element.
        """
        features = self.cnn1(x)
        # Flatten everything except the batch dimension for the dense layers.
        output = features.view(features.size()[0], -1)
        output = self.f1(output)
        # NOTE(review): the PReLU result is discarded (nn.PReLU is not in-place),
        # so f1..f6 currently act as a purely linear stack.  Assigning the result
        # would change the embeddings produced by already-trained checkpoints, so
        # the behaviour is intentionally left untouched here.
        # TODO: apply the activation and retrain.
        self.t(output)
        output = self.f2(output)
        output = self.f3(output)
        output = self.f4(output)
        output = self.f5(output)
        output = self.f6(output)
        return output

    def forward(self, input1, input2):
        """Embed both batches with the shared weights and return the two embeddings."""
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)
        return output1, output2

    def regularize(self, lam=0.00055):
        """Return ``lam`` times the sum of the norms of all parameter tensors.

        Args:
            lam: scale applied to the summed norms.

        Returns:
            Tensor of shape ``(1,)`` holding the penalty term.
        """
        loss = Variable(torch.zeros(1))
        for param in self.parameters():
            loss += torch.norm(param)
        loss *= lam
        return loss

    def process_image(self, img_array):
        """Run ``img_array`` through ``img_transform_post`` and return the tensor."""
        return self.img_transform_post(img_array)

    def classify(self, current, omit=None):
        """Label ``current`` with the label of its closest reference image.

        Args:
            current: image accepted by ``img_transform_post`` (tensor/ndarray).
            omit: optional index into ``ground_truth`` to leave out of the
                comparison; ``None`` compares against every reference.

        Returns:
            Tuple ``(distances, i, label)`` where ``distances`` holds the
            embedding distance to every reference that was compared (the
            omitted one is absent), ``i`` is the position of the smallest entry
            in ``distances`` and ``label`` is the label of that reference.

        Raises:
            ValueError: if there is no reference image to compare against.
        """
        if self.ground_truth is None:
            raise ValueError('ground_truth must be set before calling classify')

        query = self.img_transform_post(current).unsqueeze(0)

        # Remember the original index of every reference that is kept so that
        # the label lookup stays aligned with ``self.labels`` when ``omit``
        # removes an entry from the middle of the list.
        kept = []
        references = []
        for idx, ref in enumerate(self.ground_truth):
            if omit is None or idx != omit:
                kept.append(idx)
                references.append(ref.unsqueeze(0))
        if not references:
            raise ValueError('no reference images left to compare against')

        inputs1 = Variable(torch.cat(references))
        inputs2 = Variable(torch.cat([query] * len(kept)))
        output1, output2 = self(inputs1, inputs2)
        distances = F.pairwise_distance(output1, output2).data.numpy()

        # argmin over the flattened array returns the first minimum, whether the
        # distances come back as shape (N,) or (N, 1).
        i = int(np.argmin(distances))
        return distances, i, self.labels[kept[i]]
    
    
    
    
        


In [13]:
import torch
# SECURITY: torch.load unpickles the file, and unpickling can execute arbitrary
# code -- only load checkpoints that come from a trusted source.
# The loaded object is used as a full model below (model.classify), so the
# SiameseNetwork class has to be defined in this kernel before this cell runs.
# NOTE(review): newer PyTorch releases may default to weights_only=True, which
# rejects whole-model pickles -- confirm against the installed version.
# map_location keeps every tensor on the CPU: classify() feeds CPU tensors to
# the model and converts the distances with .numpy(), both of which need a CPU
# model even if the checkpoint was saved from a GPU.
model = torch.load('recognizer.pkl', map_location=lambda storage, loc: storage)

In [14]:
# Build the held-out test set: one 128x128 tensor and one label per image file
# named '<fish>_<num>.jpg' under ``path``.
fishes = ['lionfish', 'bluecrab']
training_num = [1,2,3,4,5,6,7,8]
testing_num = [9,10,11,12]
path = Path('../data')

transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])
raw_data1 = []
labels1 = []
for fish in fishes:
    for num in testing_num:
        input_path = path / f'{fish}_{num}.jpg'
        # Close the file handle as soon as the pixels are converted, and force
        # three channels because the network's first layer is Conv2d(3, ...);
        # a grayscale or CMYK JPEG would otherwise yield the wrong channel count.
        with Image.open(input_path) as img:
            arr = transform(img.convert('RGB'))
        raw_data1.append(arr)
        labels1.append(fish)

In [15]:
# Classify the first held-out test image.  The displayed tuple is
# (distances to the compared references, position of the smallest distance,
# predicted label); the true label of this image is labels1[0].
model.classify(current=raw_data1[0])

(array([[ 1.33662808],
        [ 1.85500658],
        [ 3.21879148],
        [ 2.59950399],
        [ 4.80777025],
        [ 4.82734919],
        [ 3.71394777],
        [ 3.71394777],
        [ 0.3742694 ],
        [ 3.4450171 ],
        [ 4.59232664],
        [ 1.48784506],
        [ 2.33763671],
        [ 1.22073984],
        [ 4.49013567],
        [ 3.44501805]], dtype=float32), 8, 'bluecrab')