In [None]:
from torch.utils.data import Dataset
import torch
from torch import optim
from torch import nn as nn
import os
import glob
import cv2
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from torchvision.transforms import ToTensor
import torch.nn.functional as F
from PIL import Image
import torchvision.transforms as T
from utility import img_transform, EmbeddingHead
from datetime import date

In [None]:
# Dataset location and split sizes. Later cells split by the `target` column:
# rows with target < NUM_TRAIN form the training set. 6593 + 26372 = 32965,
# which matches the count in the directory name.
DATA_PATH = "data/pblock-32965-idx_280x175"
NUM_TRAIN, NUM_TEST = 6593, 26372  # 20% / 80% of the 32965 ids
# NOTE(review): NUM_TOTAL is not referenced in the cells shown here and differs
# from NUM_TRAIN + NUM_TEST; its meaning (total image count?) — TODO confirm.
NUM_TOTAL = 35912

In [None]:
df = pd.read_csv(os.path.join(DATA_PATH, "metadata.csv"))
# The dataset cells below index the "path" and "target" columns; fail here with
# a clear message rather than with a KeyError several cells later.
missing_cols = {"path", "target"} - set(df.columns)
assert not missing_cols, f"metadata.csv is missing columns: {sorted(missing_cols)}"
df.head()

In [None]:
class PalletTupleDataset(Dataset):
    """Map-style dataset yielding ``(img_one, img_two, label)`` pairs for Siamese training.

    Even indices pair ``idx`` with its neighbour ``idx + 1`` (intended as a
    genuine pair). Odd indices pair ``idx`` with an image drawn uniformly at
    random from all images whose target differs from ``target[idx]`` (an
    impostor pair).

    The label is derived from the targets themselves: ``0`` when both images
    share a target, ``1`` otherwise. A neighbour pair that does not actually
    share a target is therefore labelled honestly instead of silently as genuine.

    Args:
        data: sequence of image file paths.
        target: sequence of per-image labels aligned with ``data``.
        transform: optional callable applied to each loaded RGB image.
        target_transform: optional callable applied to the integer label.
    """

    def __init__(self, data, target, transform=None, target_transform=None):
        self.data = data
        self.target = target
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        # The last position is excluded so that ``idx + 1`` is always a valid
        # neighbour; clamp at 0 so an empty dataset reports length 0, not -1.
        return max(len(self.data) - 1, 0)

    @staticmethod
    def _load_rgb(path):
        # Open inside a context manager so the file handle is released
        # deterministically; convert() returns a new, fully loaded image.
        with Image.open(path) as img:
            return img.convert('RGB')

    def _sample_impostor(self, idx):
        # Select candidates by label rather than by position, using a vectorised
        # mask instead of rebuilding and editing a Python list per item.
        targets = np.asarray(self.target)
        candidates = np.flatnonzero(targets != targets[idx])
        if candidates.size == 0:
            raise ValueError(
                f"no impostor available for index {idx}: "
                f"all samples share target {targets[idx]!r}"
            )
        return int(np.random.choice(candidates))

    def __getitem__(self, idx):
        idx_one = idx
        if idx % 2 == 0:
            # Genuine pair: assumes images of one target are stored at
            # consecutive (even, odd) positions.
            idx_two = idx + 1
        else:
            idx_two = self._sample_impostor(idx)

        label = 0 if self.target[idx_one] == self.target[idx_two] else 1

        img_one = self._load_rgb(self.data[idx_one])
        img_two = self._load_rgb(self.data[idx_two])

        if self.transform is not None:
            img_one = self.transform(img_one)
            img_two = self.transform(img_two)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img_one, img_two, label

In [None]:
class PalletDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` for single images.

    Args:
        data: sequence of image file paths.
        target: sequence of labels aligned with ``data``.
        transform: optional callable applied to each loaded RGB image.
        target_transform: optional callable applied to the label.
    """

    def __init__(self, data, target, transform=None, target_transform=None):
        self.data = data
        self.target = target
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _load_rgb(path):
        # Context manager releases the file handle deterministically;
        # convert() returns a new, fully loaded image.
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        img = self._load_rgb(self.data[idx])
        label = self.target[idx]
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

In [None]:
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [None]:
# Training rows are those whose target id is below NUM_TRAIN.
train_selector = df["target"] < NUM_TRAIN
train_paths = df.loc[train_selector, "path"].to_numpy()
train_targets = df.loc[train_selector, "target"].to_numpy()

# Pair dataset for Siamese training; labels become int64 tensors.
trainset = PalletTupleDataset(
    train_paths,
    train_targets,
    transform=img_transform,
    target_transform=lambda x: torch.tensor(x, dtype=torch.long),
)
# Single-image view of the same rows, using the is_eval=True transform.
trainevalset = PalletDataset(
    train_paths,
    train_targets,
    transform=lambda x: img_transform(x, is_eval=True),
)

In [None]:
# Shuffled mini-batches of image pairs for training.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=32,
    shuffle=True,
    num_workers=8,
)
# Fixed-order, larger batches of single images for evaluation passes.
trainevalloader = torch.utils.data.DataLoader(
    trainevalset,
    batch_size=128,
    shuffle=False,
    num_workers=8,
)

In [None]:
def calc_energy(emb1, emb2, p=1):
    """Distance ("energy") between paired embeddings, reduced over dim 1.

    Args:
        emb1, emb2: embedding tensors of identical shape. Presumably
            (batch, dim) as produced by the network head — TODO confirm.
        p: order of the Minkowski distance. The default ``1`` is the L1
            distance; ``p=2`` gives the Euclidean distance. Must be > 0.

    Returns:
        Tensor of distances with dim 1 reduced away.
    """
    if p <= 0:
        raise ValueError(f"p must be positive, got {p}")
    diff = torch.abs(emb1 - emb2)
    if p == 1:
        # Keep the exact L1 computation used by default.
        return torch.sum(diff, dim=1)
    return torch.sum(diff ** p, dim=1) ** (1.0 / p)

def criterion(energy, labels, q=np.sqrt(2048)):
    """Batch-mean contrastive energy loss.

    Has the form of the loss in Chopra, Hadsell & LeCun (2005). Genuine pairs
    (label 0) pay ``(2 / Q) * E**2``; impostor pairs (label 1) pay
    ``2 * Q * exp(-2.77 / Q * E)``.

    Args:
        energy: per-pair distances, e.g. the output of ``calc_energy``.
        labels: per-pair labels, 0 for genuine and 1 for impostor
            (any numeric dtype; cast to ``energy``'s dtype).
        q: the constant Q (default ``sqrt(2048)``).

    Returns:
        Scalar tensor: the mean loss over the batch.
    """
    # Plain Python float avoids relying on numpy-scalar x tensor dispatch.
    q = float(q)
    labels = labels.to(energy.dtype)
    genuine = (1 - labels) * (2 / q) * energy ** 2
    impostor = labels * 2 * q * torch.exp(-2.77 / q * energy)
    return torch.mean(genuine + impostor)

In [None]:
# ResNet-50 with pretrained weights, fetched from the torchvision v0.10.0 hub.
net = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
# Swap the final fully connected layer for the project's embedding head.
net.fc = EmbeddingHead()

# Plain SGD. MultiStepLR multiplies the LR by 0.3 after 20 and after 40
# scheduler steps.
optimizer = optim.SGD(params=net.parameters(), lr=0.01)
lr_scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[20, 40],
    gamma=0.3,
)

In [None]:
EPOCHS = 80
SAVE_PATH = "model/"
LOG_EVERY = 200  # mini-batches between loss printouts

# torch.save does not create missing directories.
os.makedirs(SAVE_PATH, exist_ok=True)

net.to(device)
net.train()
for epoch in range(EPOCHS):
    running_loss = 0.0
    for i, (x1, x2, labels) in enumerate(trainloader):
        optimizer.zero_grad()
        # Shared weights: the same network embeds both images of each pair.
        out1 = net(x1.to(device))
        out2 = net(x2.to(device))
        energy = calc_energy(out1, out2)
        loss = criterion(energy, labels.to(device))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % LOG_EVERY == LOG_EVERY - 1:
            # Mean over the LOG_EVERY batches accumulated since the last print.
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / LOG_EVERY:.5f}')
            running_loss = 0.0
    # Advance the MultiStepLR schedule once per epoch; without this call the
    # milestones never trigger and the learning rate never decays.
    lr_scheduler.step()
    torch.save(net.state_dict(), os.path.join(SAVE_PATH, "model_siamese_256x128_" + str(date.today()) + ".pth"))
print('Finished Training')