In [1]:
import os
from torch.utils.data import Dataset
import torch
from torch import optim
from torch import nn as nn
import numpy as np
from matplotlib import pyplot as plt
from torchvision.transforms import ToTensor
import torch.nn.functional as F
import torchvision.transforms as T
import pandas as pd
import glob
from PIL import Image
from sklearn.neighbors import NearestNeighbors
from typing import Callable, Sequence
from utility import img_transform, EmbeddingHead
from datetime import date

In [2]:
# Root directory of the dataset, relative to the notebook's working directory.
DATA_PATH = "data/pblock-32965-idx_280x175"
# Cut-off applied to the "target" column when the training split is selected below.
NUM_TRAIN = 6593 # 20%
NUM_TEST = 26372 # 80%
# NOTE(review): NUM_TRAIN + NUM_TEST = 32965 (the number that also appears in DATA_PATH),
# but NUM_TOTAL is 35912 -- looks like a possible digit transposition; confirm.
NUM_TOTAL = 35912

In [3]:
# Load the per-image metadata table; later cells read its "path" and "target" columns.
# NOTE(review): the preview shows an "Unnamed: 0" column, which usually means the CSV
# was written together with its index -- consider index_col=0; confirm before changing.
df = pd.read_csv(os.path.join(DATA_PATH, "metadata.csv"))
df.head()

Unnamed: 0,pallet_id,path,camera,frame,target,frame_rel
0,1001000000000002948,/home/nils/Documents/ude/pallet/data/pblock-32...,1,1009,0,0
1,1001000000000002948,/home/nils/Documents/ude/pallet/data/pblock-32...,1,1012,0,1
2,1001000000000002948,/home/nils/Documents/ude/pallet/data/pblock-32...,2,1007,0,0
3,1001000000000002948,/home/nils/Documents/ude/pallet/data/pblock-32...,2,1010,0,1
4,1001000000000002949,/home/nils/Documents/ude/pallet/data/pblock-32...,1,1234,1,0


In [4]:
class PalletTripletDataset(Dataset):
    """Map-style dataset that returns all images of one ``target`` group per item.

    Item ``idx`` corresponds to the ``idx``-th unique value of the ``"target"``
    column of ``img_df``. Every row with that value is loaded from its
    ``"path"`` column, converted to RGB, optionally transformed, and stacked.

    Args:
        img_df: DataFrame with at least the columns ``"target"`` and ``"path"``.
        transform: Optional callable applied to every loaded PIL image. It must
            return equally shaped tensors, because the results are stacked.
        target_transform: Stored on the instance for interface compatibility.
            It is NOT applied by ``__getitem__``; labels are always the float
            tensor described below.
    """

    def __init__(self, img_df, transform=None, target_transform=None):
        self.img_df = img_df
        self.transform = transform
        self.target_transform = target_transform
        # Despite the attribute name, groups are keyed on the "target" column.
        self.pallet_ids = img_df["target"].unique()

    def __len__(self):
        """Return the number of distinct ``target`` values."""
        return len(self.pallet_ids)

    def __getitem__(self, idx):
        """Load every image of the ``idx``-th group.

        Returns:
            A tuple ``(images, labels)`` where ``images`` is the stack of the
            transformed images and ``labels`` is a float tensor with one entry
            per image, each equal to ``idx``.
        """
        pid = self.pallet_ids[idx]
        paths = self.img_df.loc[self.img_df["target"] == pid, "path"]

        imgs = []
        for path in paths:
            # convert() returns an independent, fully loaded image, so the
            # underlying file can be closed right away instead of waiting for
            # garbage collection (matters with many DataLoader workers).
            with Image.open(path) as raw:
                img = raw.convert("RGB")
            if self.transform is not None:
                img = self.transform(img)
            imgs.append(img)

        return torch.stack(imgs), torch.ones(len(imgs)) * idx

In [5]:
# Keep only the rows whose target id lies below the training cut-off.
train_selector = df["target"].lt(NUM_TRAIN)
trainset = PalletTripletDataset(
    df.loc[train_selector],
    transform=img_transform,
    target_transform=lambda x: torch.tensor(x, dtype=torch.long),
)
print("Trainset: ", len(trainset))

Trainset:  6593


In [6]:
# A batch holds 8 dataset items; each item is itself a stack of images.
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=8,
    shuffle=True,
    num_workers=8,
)

In [7]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda


In [8]:
def _pairwise_distances(embeddings, squared=False):
    dot_product = torch.matmul(embeddings, torch.transpose(embeddings, 0, 1))
    square_norm = torch.diagonal(dot_product)
    distances = torch.unsqueeze(square_norm, 0) - 2.0 * dot_product + torch.unsqueeze(square_norm, 1)
    distances = torch.maximum(distances, torch.tensor(0.0))

    if not squared:
        mask = torch.eq(distances, 0.0).type(torch.float)
        distances = torch.sqrt(distances)
        distances = distances * (1.0 - mask)

    return distances

In [9]:
def _get_anchor_positive_triplet_mask(labels):
    mask = torch.zeros((len(labels), len(labels)))
    for i in range(len(labels)):
        for j in range(len(labels)):
            if(i != j and labels[i] == labels[j]):
                mask[i,j] = 1
    return mask

def _get_anchor_negative_triplet_mask(labels):
    mask = torch.zeros((len(labels), len(labels)))
    for i in range(len(labels)):
        for j in range(len(labels)):
            if(labels[i] != labels[j]):
                mask[i,j] = 1
    return mask

def batch_hard_triplet_loss(labels, embeddings, margin, squared=False):
    """Batch-hard triplet loss.

    For every anchor (row of ``embeddings``) the largest distance to a sample
    with the same label and the smallest distance to a sample with a different
    label are selected; the loss is the mean over anchors of
    ``max(hardest_positive - hardest_negative + margin, 0)``.

    Args:
        labels: One label per row of ``embeddings``.
        embeddings: 2-D tensor with one embedding per row.
        margin: Margin added before the hinge.
        squared: Forwarded to ``_pairwise_distances``.

    Returns:
        Scalar tensor holding the mean loss.
    """
    pairwise_dist = _pairwise_distances(embeddings, squared=squared)
    # Place the masks next to the distance matrix instead of relying on a
    # notebook-global `device` variable.
    mask_kwargs = {"device": pairwise_dist.device, "dtype": pairwise_dist.dtype}

    # Hardest positive: zero out invalid pairs, then take the row-wise maximum.
    mask_anchor_positive = _get_anchor_positive_triplet_mask(labels).to(**mask_kwargs)
    anchor_positive_dist = torch.multiply(mask_anchor_positive, pairwise_dist)
    hardest_positive_dist = torch.max(anchor_positive_dist, dim=1, keepdim=True).values

    # Hardest negative: add the row maximum to every invalid pair so it cannot
    # win the row-wise minimum.
    mask_anchor_negative = _get_anchor_negative_triplet_mask(labels).to(**mask_kwargs)
    max_anchor_negative_dist = torch.max(pairwise_dist, dim=1, keepdim=True).values
    anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative)
    hardest_negative_dist = torch.min(anchor_negative_dist, dim=1, keepdim=True).values

    # Hinge; clamp avoids building a scalar tensor on another device.
    triplet_loss = torch.clamp(hardest_positive_dist - hardest_negative_dist + margin, min=0.0)
    triplet_loss = torch.mean(triplet_loss)
    return triplet_loss

In [10]:
# Pretrained ResNet-50 backbone whose classifier layer is swapped for the embedding head.
net = torch.hub.load("pytorch/vision:v0.10.0", "resnet50", pretrained=True)
net.fc = EmbeddingHead()

# Plain SGD; the learning rate is multiplied by 0.3 after epochs 20 and 40.
optimizer = optim.SGD(params=net.parameters(), lr=0.01)
lr_scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer=optimizer,
    milestones=[20, 40],
    gamma=0.3,
)

Using cache found in /home/nils/.cache/torch/hub/pytorch_vision_v0.10.0


In [11]:
EPOCHS = 80
SAVE_PATH = "model/"
LOG_EVERY = 200  # number of batches averaged into one progress line

# Create the checkpoint directory up front so torch.save cannot fail only
# after a whole epoch has already been trained.
os.makedirs(SAVE_PATH, exist_ok=True)

net.to(device)
net.train()
for epoch in range(EPOCHS):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        x, y = data
        # Merge the batch dimension with the images-per-item dimension so the
        # network sees one image per row and the loss one label per image.
        x = torch.flatten(x, end_dim=1)
        y = torch.flatten(y, end_dim=1)
        optimizer.zero_grad()
        out = net(x.to(device))
        loss = batch_hard_triplet_loss(y.to(device), out, margin=0.5, squared=False)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % LOG_EVERY == LOG_EVERY - 1:
            # The window holds exactly LOG_EVERY losses, so divide by
            # LOG_EVERY (previously 199, which inflated the reported mean).
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / LOG_EVERY:.3f}')
            running_loss = 0.0
    lr_scheduler.step()
    # The checkpoint name only contains the date, so each epoch overwrites
    # the file written earlier on the same day.
    torch.save(net.state_dict(), os.path.join(SAVE_PATH, "model_triplet_256x128_" + str(date.today()) + ".pth"))
print('Finished Training')

[1,   200] loss: 0.459
[1,   400] loss: 0.349
[1,   600] loss: 0.249
[1,   800] loss: 0.192
[2,   200] loss: 0.142
[2,   400] loss: 0.128
[2,   600] loss: 0.106
[2,   800] loss: 0.095
[3,   200] loss: 0.086
[3,   400] loss: 0.071
[3,   600] loss: 0.070
[3,   800] loss: 0.059
[4,   200] loss: 0.062
[4,   400] loss: 0.055
[4,   600] loss: 0.046
[4,   800] loss: 0.047
[5,   200] loss: 0.051
[5,   400] loss: 0.052
[5,   600] loss: 0.048
[5,   800] loss: 0.044
[6,   200] loss: 0.041
[6,   400] loss: 0.044
[6,   600] loss: 0.037
[6,   800] loss: 0.040
[7,   200] loss: 0.035
[7,   400] loss: 0.035
[7,   600] loss: 0.035
[7,   800] loss: 0.036
[8,   200] loss: 0.030
[8,   400] loss: 0.031
[8,   600] loss: 0.030
[8,   800] loss: 0.031
[9,   200] loss: 0.031
[9,   400] loss: 0.029
[9,   600] loss: 0.028
[9,   800] loss: 0.027
[10,   200] loss: 0.022
[10,   400] loss: 0.025
[10,   600] loss: 0.029
[10,   800] loss: 0.026
[11,   200] loss: 0.027
[11,   400] loss: 0.029
[11,   600] loss: 0.026
[11,