In [None]:
from torch.utils.data import Dataset
import torch
from torch import optim
from torch import nn as nn
import os
import cv2
import numpy as np
from matplotlib import pyplot as plt
from torchvision.transforms import ToTensor
import torch.nn.functional as F
import pandas as pd
import glob
from scipy.spatial import distance
from sklearn.neighbors import NearestNeighbors
import torchvision.transforms as T
from PIL import Image
from utility import img_transform, EmbeddingHead
from datetime import date

In [None]:
# Dataset location and split sizes.
# NUM_TRAIN + NUM_TEST == 32965, the number that appears in the DATA_PATH
# directory name; NUM_TRAIN is exactly 20% of it and NUM_TEST exactly 80%.
DATA_PATH = "data/pblock-32965-idx_280x175"
NUM_TRAIN = 6_593   # 20%
NUM_TEST = 26_372   # 80%
# NOTE(review): NUM_TOTAL does not equal NUM_TRAIN + NUM_TEST (32965) and is
# not used in the cells shown here — confirm what it is meant to count.
NUM_TOTAL = 35_912

In [None]:
# Per-image metadata; later cells read its "path" and "target" columns.
metadata_file = os.path.join(DATA_PATH, "metadata.csv")
df = pd.read_csv(metadata_file)
df.head()

In [None]:
class PalletDataset(Dataset):
    """Map-style dataset pairing image files with labels.

    Args:
        data: sequence of image file paths, indexed in parallel with ``target``.
        target: sequence of labels; ``target[idx]`` belongs to ``data[idx]``.
        transform: optional callable applied to the RGB ``PIL.Image``.
        target_transform: optional callable applied to the label.
    """

    def __init__(self, data, target, transform=None, target_transform=None):
        self.data = data
        self.target = target
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Number of samples (length of ``data``)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` as an RGB image and return ``(img, label)`` after transforms."""
        # The context manager releases the file handle even if decoding fails;
        # convert() returns a new, fully loaded image, so it outlives the block.
        with Image.open(self.data[idx]) as raw:
            img = raw.convert('RGB')
        label = self.target[idx]
        if self.transform:
            img = self.transform(img)
        if self.target_transform:
            label = self.target_transform(label)
        return img, label

In [None]:
# Training split: every image whose target id is below NUM_TRAIN.
train_selector = (df["target"] < NUM_TRAIN)
train_paths = df.loc[train_selector, "path"].values
train_targets = df.loc[train_selector, "target"].values


# Named functions instead of lambdas: lambdas cannot be pickled, and DataLoader
# workers must pickle the dataset under the "spawn" start method.
# NOTE(review): functions defined in a notebook's __main__ may still not be
# importable by spawned workers; this fully helps once moved into a module.
def to_long_tensor(label):
    """Wrap a label as an int64 (``torch.long``) tensor."""
    return torch.tensor(label, dtype=torch.long)


def eval_img_transform(img):
    """Apply ``img_transform`` with ``is_eval=True``."""
    return img_transform(img, is_eval=True)


trainset = PalletDataset(train_paths, train_targets, transform=img_transform, target_transform=to_long_tensor)
trainevalset = PalletDataset(train_paths, train_targets, transform=eval_img_transform, target_transform=to_long_tensor)
print("Trainset: ", len(trainset))
print("Trainevalset: ", len(trainevalset))

In [None]:
# Shuffled mini-batches of 32 training samples, loaded by 8 worker processes.
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=32,
    shuffle=True,
    num_workers=8,
)

In [None]:
class ClassifierHead(nn.Module):
    """Single linear layer mapping feature vectors to per-class logits.

    The layer is stored as ``fc1``; its state_dict keys (``fc1.weight``,
    ``fc1.bias``) depend on that name, so keep it stable for saved checkpoints.

    Args:
        in_features: size of each input feature vector.
        n_classes: number of output logits.
    """

    def __init__(self, in_features, n_classes):
        super(ClassifierHead, self).__init__()
        self.fc1 = nn.Linear(in_features=in_features, out_features=n_classes)

    def forward(self, x):
        logits = self.fc1(x)
        return logits

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

In [None]:
# Pretrained ResNet-50 from the pinned torchvision v0.10.0 hub entry point.
net = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)

# Replace the stock final layer with a head that has one output per training target.
backbone_dim = net.fc.in_features
net.fc = ClassifierHead(backbone_dim, NUM_TRAIN)

criterion = torch.nn.CrossEntropyLoss()
# The optimizer is built after the head swap so it also sees the new head's parameters.
optimizer = optim.SGD(net.parameters(), lr=0.01)
# LR is multiplied by 0.3 at milestones 20 and 40 (counted in lr_scheduler.step() calls).
lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[20, 40], gamma=0.3)

In [None]:
EPOCHS = 80
SAVE_PATH = "model/"
LOG_EVERY = 200  # batches between progress lines

# torch.save does not create missing directories; make sure the target exists.
os.makedirs(SAVE_PATH, exist_ok=True)

net.to(device)
net.train()

for epoch in range(EPOCHS):
    running_loss = 0.0
    # Epoch-cumulative counters for the running training accuracy.
    total = torch.tensor(0).to(device)
    correct = torch.tensor(0).to(device)
    for i, data in enumerate(trainloader, 0):
        x, y = data
        x = x.to(device)
        y = y.to(device)  # moved once; reused by the loss and the accuracy count
        optimizer.zero_grad()
        out = net(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()

        pred = out.detach().argmax(dim=1)
        total += y.size(0)
        correct += (pred == y).sum()

        running_loss += loss.item()
        if i % LOG_EVERY == LOG_EVERY - 1:
            # running_loss holds the sum over the last LOG_EVERY batches, so
            # average over LOG_EVERY (the original divided by 199 for 200 batches).
            ratio = (correct / total).to("cpu")
            print(f'[{epoch + 1}, {i + 1:3d}] loss: {running_loss / LOG_EVERY:.3f} acc: ({100*ratio:.2f})')
            running_loss = 0.0
    lr_scheduler.step()
    # NOTE(review): the filename says 256x128 while DATA_PATH says 280x175 —
    # confirm it matches the input size produced by img_transform.
    torch.save(net.state_dict(), os.path.join(SAVE_PATH, "model_classifier_256x128_e" + str(epoch) + "_" + str(date.today()) + ".pth"))
print('Finished Training')

## Eval (accuracy on the training images)

In [None]:
# Unshuffled loader over the training images with the eval-mode transform.
trainevalloader = torch.utils.data.DataLoader(
    dataset=trainevalset,
    batch_size=256,
    shuffle=False,
    num_workers=8,
)

In [None]:
# Top-1 accuracy of the trained classifier over trainevalloader, in eval mode
# and without gradient tracking.
net.eval()
net.to(device)
with torch.no_grad():
    total = torch.tensor(0).to(device)
    correct = torch.tensor(0).to(device)
    for x, y in trainevalloader:
        out = net(x.to(device))
        pred = out.data.argmax(1)
        labels = y.to(device)
        total += labels.size(0)
        correct += (pred == labels).sum().item()
acc = (correct / total * 100).to("cpu")
print("Trainset acc: %.2f" % acc)