In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import TripletData as TripletDataloader
import numpy as np

def getLabels(classes):
    result = []
    for i in classes:
        if i[0] == i[1]:
            result.append(1)
        else:
            result.append(0)
    return result


def resultsCounter(corr, count, res, label):
    for i in range(len(label)):
        if(res[i] == label[i]):
            corr += 1
        count += 1
    return corr, count

def compareDistance(d1, d2):
    result = []
    for i in range(len(d1)):
        if (d1[i] < d2[i]):
            result.append(1)
        else:
            result.append(0)
    return result

class TripletNet(nn.Module):
    def __init__(self, convNet):
        super(TripletNet, self).__init__()
        self.convNet = convNet

    def forward(self, x, y, z):
        embedded_x = self.convNet(x)
        embedded_y = self.convNet(y)
        embedded_z = self.convNet(z)
        dist_a = F.pairwise_distance(embedded_x, embedded_y, 2)
        dist_b = F.pairwise_distance(embedded_x, embedded_z, 2)
        return dist_a, dist_b, embedded_x, embedded_y, embedded_z
    
    
class SaimeseConv(torch.nn.Module):
    def __init__(self):
        super(SaimeseConv, self).__init__()
      
        self.conv1 = nn.Conv2d(1, 32, 14)
        self.conv2 = nn.Conv2d(32, 64, 12)
        #self.conv3 = nn.Conv2d(128, 256, 4)
        self.conv4 = nn.Conv2d(64, 128, 10)
        self.pool = nn.MaxPool2d(2)
        self.linear1 = nn.Linear(128*8*8, 2048)
        self.linear2 = nn.Linear(2048, 10)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        print(x.shape)
        x = self.conv1(x)
        print(x.shape)
        x = F.relu(x)
        x = self.pool(x)
        print(x.shape)

        x = self.conv2(x)
        print(x.shape)
        x = F.relu(x)
        x = self.pool(x)
        print(x.shape)
        x = self.conv4(x)
        print(x.shape)
        x = F.relu(x)
        #print(x.shape)
        x = x.view(x.shape[0], -1)
        #print(x.shape)
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.sigmoid(x)
        return x
""" 
class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
            self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
            self.conv2_drop = nn.Dropout2d()
            self.fc1 = nn.Linear(320, 50)
            self.fc2 = nn.Linear(50, 10)

        def forward(self, x):
            x = F.relu(F.max_pool2d(self.conv1(x), 2))
            x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
            x = x.view(-1, 320)
            x = F.relu(self.fc1(x))
            x = F.dropout(x, training=self.training)
            return self.fc2(x)

"""

' \nclass Net(nn.Module):\n        def __init__(self):\n            super(Net, self).__init__()\n            self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n            self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n            self.conv2_drop = nn.Dropout2d()\n            self.fc1 = nn.Linear(320, 50)\n            self.fc2 = nn.Linear(50, 10)\n\n        def forward(self, x):\n            x = F.relu(F.max_pool2d(self.conv1(x), 2))\n            x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n            x = x.view(-1, 320)\n            x = F.relu(self.fc1(x))\n            x = F.dropout(x, training=self.training)\n            return self.fc2(x)\n\n'

In [2]:
width = 105
height = 105

trainSet, testSet = TripletDataloader.main('parsedData24.csv',0.9, 12, width, height)

numChars = 24
model = SaimeseConv()
tnet = TripletNet(model).cuda()
tnet.train()

optimizer = optim.Adam(tnet.parameters(),lr = 0.0002)
criterion = nn.TripletMarginLoss(margin=0.5, p=2)
#criterion = torch.nn.MarginRankingLoss(margin = args.margin)

trainCorrect = 0
trainCount = 0

print("~~~| Training |~~~")
for epoch in range(0,30):
    for batch_i, batch_data in enumerate(trainSet):
        A, P, N, classes = batch_data['A'],batch_data['P'], batch_data['N'],batch_data['character']
        A = A.unsqueeze(1)
        P = P.unsqueeze(1)
        N = N.unsqueeze(1)
        # x1: A,  x2: P, x3: N
        x1 = Variable(torch.tensor(A, dtype=torch.float32)).cuda()
        x2 = Variable(torch.tensor(P, dtype=torch.float32)).cuda()
        x3 = Variable(torch.tensor(N, dtype=torch.float32)).cuda()
        #print(x1.shape)
        y = Variable(torch.tensor(getLabels(classes),dtype=torch.float32)).cuda()
        
        dista, distb, out1, out2, out3 = tnet(x1, x2, x3)
   
        result = compareDistance(dista, distb)
        trainCorrect, trainCount = resultsCounter(trainCorrect, trainCount, result, y)
        loss = criterion(out1, out2, out3)
        #print(loss.item())
        #loss_triplet = criterion(dista, distb, target)
        #loss_embedd = embedded_x.norm(2) + embedded_y.norm(2) + embedded_z.norm(2)
        #loss = loss_triplet + 0.001 * loss_embedd
        
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
print("Training Accuracy: ", trainCorrect/trainCount)
        
print("\n\n~~~| Testing |~~~")
model.eval()
testCorrect = 0
testCount = 0

for batch_i, test_data in enumerate(testSet):
    A, P, N, classes = test_data['A'],test_data['P'], test_data['N'],test_data['character']
    A = A.unsqueeze(1)
    P = P.unsqueeze(1)
    N = N.unsqueeze(1)
    # x1: A,  x2: P, x3: N
    x1 = Variable(torch.tensor(A, dtype=torch.float32)).cuda()
    x2 = Variable(torch.tensor(P, dtype=torch.float32)).cuda()
    x3 = Variable(torch.tensor(N, dtype=torch.float32)).cuda()
    #print(x1.shape)
    y = Variable(torch.tensor(getLabels(classes),dtype=torch.float32)).cuda()

    dista, distb, out1, out2, out3 = tnet(x1, x2, x3)
    result = compareDistance(dista, distb)
            
    testCorrect, testCount = resultsCounter(testCorrect, testCount, result, y)
    print("Loss: ", round(loss.item(),5))
    print("Generated: ", result)
    print("Label: ", y)

print("Testing Accuracy: ",testCorrect/testCount)


~~~| ModelDataloader.py Execution |~~~
Loaded dataset
~~~| ModelDataloader.py Complete |~~~

~~~| Training |~~~
torch.Size([12, 1, 105, 105])
torch.Size([12, 32, 92, 92])
torch.Size([12, 32, 46, 46])
torch.Size([12, 64, 35, 35])
torch.Size([12, 64, 17, 17])
torch.Size([12, 128, 8, 8])
torch.Size([12, 1, 105, 105])
torch.Size([12, 32, 92, 92])
torch.Size([12, 32, 46, 46])
torch.Size([12, 64, 35, 35])
torch.Size([12, 64, 17, 17])
torch.Size([12, 128, 8, 8])
torch.Size([12, 1, 105, 105])
torch.Size([12, 32, 92, 92])
torch.Size([12, 32, 46, 46])
torch.Size([12, 64, 35, 35])
torch.Size([12, 64, 17, 17])
torch.Size([12, 128, 8, 8])
torch.Size([12, 1, 105, 105])
torch.Size([12, 32, 92, 92])
torch.Size([12, 32, 46, 46])
torch.Size([12, 64, 35, 35])




torch.Size([12, 64, 17, 17])
torch.Size([12, 128, 8, 8])
torch.Size([12, 1, 105, 105])
torch.Size([12, 32, 92, 92])
torch.Size([12, 32, 46, 46])
torch.Size([12, 64, 35, 35])
torch.Size([12, 64, 17, 17])
torch.Size([12, 128, 8, 8])
torch.Size([12, 1, 105, 105])
torch.Size([12, 32, 92, 92])
torch.Size([12, 32, 46, 46])
torch.Size([12, 64, 35, 35])
torch.Size([12, 64, 17, 17])
torch.Size([12, 128, 8, 8])
torch.Size([12, 1, 105, 105])
torch.Size([12, 32, 92, 92])
torch.Size([12, 32, 46, 46])
torch.Size([12, 64, 35, 35])
torch.Size([12, 64, 17, 17])
torch.Size([12, 128, 8, 8])
torch.Size([12, 1, 105, 105])
torch.Size([12, 32, 92, 92])
torch.Size([12, 32, 46, 46])
torch.Size([12, 64, 35, 35])
torch.Size([12, 64, 17, 17])
torch.Size([12, 128, 8, 8])
torch.Size([12, 1, 105, 105])
torch.Size([12, 32, 92, 92])
torch.Size([12, 32, 46, 46])
torch.Size([12, 64, 35, 35])
torch.Size([12, 64, 17, 17])
torch.Size([12, 128, 8, 8])
torch.Size([12, 1, 105, 105])
torch.Size([12, 32, 92, 92])
torch.Size([12

torch.Size([12, 1, 105, 105])
torch.Size([12, 32, 92, 92])
torch.Size([12, 32, 46, 46])
torch.Size([12, 64, 35, 35])
torch.Size([12, 64, 17, 17])
torch.Size([12, 128, 8, 8])
torch.Size([12, 1, 105, 105])
torch.Size([12, 32, 92, 92])
torch.Size([12, 32, 46, 46])
torch.Size([12, 64, 35, 35])
torch.Size([12, 64, 17, 17])
torch.Size([12, 128, 8, 8])
torch.Size([12, 1, 105, 105])
torch.Size([12, 32, 92, 92])
torch.Size([12, 32, 46, 46])
torch.Size([12, 64, 35, 35])
torch.Size([12, 64, 17, 17])
torch.Size([12, 128, 8, 8])
torch.Size([12, 1, 105, 105])
torch.Size([12, 32, 92, 92])
torch.Size([12, 32, 46, 46])
torch.Size([12, 64, 35, 35])
torch.Size([12, 64, 17, 17])
torch.Size([12, 128, 8, 8])
torch.Size([12, 1, 105, 105])
torch.Size([12, 32, 92, 92])
torch.Size([12, 32, 46, 46])
torch.Size([12, 64, 35, 35])
torch.Size([12, 64, 17, 17])
torch.Size([12, 128, 8, 8])
torch.Size([12, 1, 105, 105])
torch.Size([12, 32, 92, 92])
torch.Size([12, 32, 46, 46])
torch.Size([12, 64, 35, 35])
torch.Size([1

torch.Size([12, 1, 105, 105])
torch.Size([12, 32, 92, 92])
torch.Size([12, 32, 46, 46])
torch.Size([12, 64, 35, 35])
torch.Size([12, 64, 17, 17])
torch.Size([12, 128, 8, 8])
torch.Size([12, 1, 105, 105])
torch.Size([12, 32, 92, 92])
torch.Size([12, 32, 46, 46])
torch.Size([12, 64, 35, 35])
torch.Size([12, 64, 17, 17])
torch.Size([12, 128, 8, 8])
torch.Size([12, 1, 105, 105])
torch.Size([12, 32, 92, 92])
torch.Size([12, 32, 46, 46])
torch.Size([12, 64, 35, 35])
torch.Size([12, 64, 17, 17])
torch.Size([12, 128, 8, 8])
torch.Size([12, 1, 105, 105])
torch.Size([12, 32, 92, 92])
torch.Size([12, 32, 46, 46])
torch.Size([12, 64, 35, 35])
torch.Size([12, 64, 17, 17])
torch.Size([12, 128, 8, 8])
torch.Size([12, 1, 105, 105])
torch.Size([12, 32, 92, 92])
torch.Size([12, 32, 46, 46])
torch.Size([12, 64, 35, 35])
torch.Size([12, 64, 17, 17])
torch.Size([12, 128, 8, 8])
torch.Size([12, 1, 105, 105])
torch.Size([12, 32, 92, 92])
torch.Size([12, 32, 46, 46])
torch.Size([12, 64, 35, 35])
torch.Size([1

torch.Size([12, 1, 105, 105])
torch.Size([12, 32, 92, 92])
torch.Size([12, 32, 46, 46])
torch.Size([12, 64, 35, 35])
torch.Size([12, 64, 17, 17])
torch.Size([12, 128, 8, 8])
torch.Size([12, 1, 105, 105])
torch.Size([12, 32, 92, 92])
torch.Size([12, 32, 46, 46])
torch.Size([12, 64, 35, 35])
torch.Size([12, 64, 17, 17])
torch.Size([12, 128, 8, 8])
torch.Size([12, 1, 105, 105])
torch.Size([12, 32, 92, 92])
torch.Size([12, 32, 46, 46])
torch.Size([12, 64, 35, 35])
torch.Size([12, 64, 17, 17])
torch.Size([12, 128, 8, 8])


KeyboardInterrupt: 