In [1]:
import os
# If the kernel was started from a directory whose path ends in "Nets", move one
# level up (presumably so the relative paths used further down resolve — confirm).
current_dir = os.getcwd()
if current_dir.endswith("Nets"):
    os.chdir(os.pardir)

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms.functional as TF
from torch.utils.data import TensorDataset, random_split, DataLoader
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import multivariate_normal
import pandas as pd
import cv2
import PIL.Image
from tqdm import tqdm

from utils.FrameUtils import remove_distortion
from utils.ControlUtils import calc_distance
from Params import *
import CueNetV2

In [7]:
def train_epoch(model, train_loader, device, optimiser, loss_fn):
    """Run one optimisation pass over ``train_loader`` and return the mean loss.

    Args:
        model: Network to train; it is switched to training mode.
        train_loader: Iterable yielding ``(features, labels)`` batches.
        device: Device each batch is moved to before the forward pass.
        optimiser: Optimiser used to update the model after each batch.
        loss_fn: Callable ``loss_fn(predictions, labels)`` returning a scalar
            tensor.

    Returns:
        float: Sum of all batch losses divided by the number of samples seen,
        or ``nan`` if the loader yields no samples. This is a per-sample loss
        only when ``loss_fn`` sums (rather than averages) over the batch.
    """
    model.train()
    total_loss, count = 0.0, 0
    for features, labels in tqdm(train_loader):
        features, labels = features.to(device), labels.to(device)
        predictions = model(features)
        loss = loss_fn(predictions, labels)

        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

        # Accumulate a plain Python float. Adding the loss tensor itself would
        # chain every batch's autograd history onto the running total and keep
        # it alive until the end of the epoch.
        total_loss += loss.item()
        count += len(labels)
    # An empty loader previously crashed (``int`` has no ``.item()``).
    return total_loss / count if count else float("nan")


def test_epoch(model, test_loader, device, loss_fn):
    model.eval()
    total_loss, count = 0, 0
    with torch.no_grad():
        for features, labels in test_loader:
            features, labels = features.to(device), labels.to(device)
            predictions = model(features)
            
            total_loss += loss_fn(predictions, labels)
            count += len(labels)
    return total_loss.item() / count

def train(model, train_loader, test_loader, device, optimiser, loss_fn, epochs=10):
    """Alternate training and evaluation passes for ``epochs`` epochs.

    After every epoch the train and test losses are printed and recorded.

    Returns:
        dict: ``{'train_loss': [...], 'test_loss': [...]}`` with one entry per
        epoch, in order.
    """
    train_history, test_history = [], []
    for ep in range(epochs):
        train_loss = train_epoch(model, train_loader, device, optimiser, loss_fn)
        test_loss = test_epoch(model, test_loader, device, loss_fn)

        print(f"Epoch {ep:2}, Train loss={train_loss:.7f}, test loss={test_loss:.7f}")
        train_history.append(train_loss)
        test_history.append(test_loss)
    return {'train_loss': train_history, 'test_loss': test_history}


In [8]:
# Load the label index. The file has no header row, so columns are addressed by
# integer position; the loading loop further down reads column 0 as the image
# path and columns 1 and 2 as the x / y position.
data_index = pd.read_csv("MLData/labels.csv", header=None)
# Bare last expression: rendered as the cell's output.
data_index

In [9]:
def create_label(x_pos, y_pos, variance=10, img_width=240, img_height=180):
    """Render an isotropic 2-D Gaussian density centred on ``(x_pos, y_pos)``.

    The density is evaluated at every integer pixel coordinate of an
    ``img_height`` x ``img_width`` grid.

    Args:
        x_pos: Column (x) coordinate of the Gaussian mean.
        y_pos: Row (y) coordinate of the Gaussian mean.
        variance: Variance used for both axes (diagonal covariance).
        img_width: Number of grid columns.
        img_height: Number of grid rows.

    Returns:
        torch.Tensor: Density values indexed as ``[row, column]``.
    """
    grid_x, grid_y = np.meshgrid(np.arange(img_width), np.arange(img_height))
    # The last axis holds the (x, y) coordinate of each pixel.
    coords = np.stack((grid_x, grid_y), axis=-1).astype(np.float64)
    centre = np.array([x_pos, y_pos])
    covariance = np.array([[variance, 0], [0, variance]])
    density = multivariate_normal(mean=centre, cov=covariance).pdf(coords)
    return torch.from_numpy(density)

In [10]:
def load_img(path):
    """Read the image at ``path`` and return its pixel data as a ``uint8`` array.

    Args:
        path: Location of the image file, passed straight to ``PIL.Image.open``.

    Returns:
        numpy.ndarray: The decoded pixels, converted to ``np.uint8``.
    """
    # The context manager releases the underlying file handle even if decoding
    # fails, instead of relying on ``load()`` to close it as a side effect.
    with PIL.Image.open(path) as img:
        img.load()
        return np.asarray(img, dtype=np.uint8)

def preprocess_img(img):
    """Undistort, equalise, crop and rescale a frame into a float tensor.

    Steps, in order: ``remove_distortion``, ``cv2.equalizeHist``, crop to the
    window described by the ``PROCESSING_*`` constants, cast to ``float32`` and
    divide by 255, convert with ``TF.to_tensor`` and squeeze singleton
    dimensions.
    """
    undistorted = remove_distortion(img)
    equalised = cv2.equalizeHist(undistorted)
    # Rows span the Y window, columns span the X window.
    row_end = PROCESSING_Y + PROCESSING_SIZE_HEIGHT
    col_end = PROCESSING_X + PROCESSING_SIZE_WIDTH
    window = equalised[PROCESSING_Y:row_end, PROCESSING_X:col_end]
    scaled = window.astype(np.float32) / 255
    return torch.squeeze(TF.to_tensor(scaled))

def flip_h(img_stack):
    """Return ``img_stack`` reversed along dimension 2."""
    return img_stack.flip([2])

def flip_v(img_stack):
    """Return ``img_stack`` reversed along dimension 1."""
    return img_stack.flip([1])

def flip_hv(img_stack):
    """Return ``img_stack`` reversed along both dimension 1 and dimension 2."""
    return img_stack.flip([1, 2])

def get_noise(size, mean=0, std=0.1):
    """Draw Gaussian noise of shape ``size`` with the given mean and std.

    A standard-normal sample from ``torch.randn`` is scaled by ``std`` and then
    shifted by ``mean``.
    """
    standard_normal = torch.randn(size)
    return standard_normal.mul(std).add(mean)

def create_variations(sample, label):
    """Build three flipped, noise-perturbed copies of a sample and its label.

    The sample is flipped with ``flip_h`` / ``flip_v`` / ``flip_hv`` and fresh
    noise from ``get_noise`` is added to each copy; the label is flipped along
    the matching axes (one dimension lower than the sample) and left noise-free.

    Returns:
        tuple: ``(s_h, l_h, s_v, l_v, s_hv, l_hv)`` — sample/label pairs for the
        horizontal, vertical and combined flips, in that order.
    """
    shape = sample.size()
    # (sample flip function, label dimensions to flip). Order matters: it
    # fixes both the return order and the order in which noise is drawn.
    recipes = (
        (flip_h, [1]),
        (flip_v, [0]),
        (flip_hv, [0, 1]),
    )
    variations = []
    for flip_sample, label_dims in recipes:
        variations.append(flip_sample(sample) + get_noise(shape))
        variations.append(torch.flip(label, label_dims))
    return tuple(variations)

In [11]:
def _append_augmented(sample, label, data, labels):
    """Append a noisy copy of ``sample`` and its three flipped variants.

    ``data`` and ``labels`` are extended in lock-step: the noisy original first,
    then the horizontal, vertical and combined flips from ``create_variations``.
    """
    data.append(sample + get_noise(sample.size()))
    labels.append(label)
    s_h, l_h, s_v, l_v, s_hv, l_hv = create_variations(sample, label)
    data.extend((s_h, s_v, s_hv))
    labels.extend((l_h, l_v, l_hv))


# The two most recently visited frames, oldest first. ``None`` marks a row that
# was skipped, so a history sample is only built from two consecutive frames.
frame_queue = [None] * 2

data = []
labels = []

for _row_idx, row in tqdm(data_index.iterrows(), total=data_index.shape[0]):
    img_name, pos_x, pos_y = row[0], row[1], row[2]
    if pos_x == 0 and pos_y == 0:
        # Rows with position (0, 0) are skipped (presumably "no label" —
        # confirm); the gap is still recorded in the history.
        frame_queue.pop(0)
        frame_queue.append(None)
        continue

    frame = preprocess_img(load_img(img_name))
    # Shift the label position into the coordinates of the cropped window.
    label = create_label(pos_x - PROCESSING_X, pos_y - PROCESSING_Y)

    if frame_queue[0] is not None and frame_queue[1] is not None:
        # Channel order is (current, two frames back, previous frame).
        # NOTE(review): confirm inference code stacks frames in the same order.
        history_sample = torch.stack((frame, frame_queue[0], frame_queue[1]))
        _append_augmented(history_sample, label, data, labels)

    # Every labelled frame also yields a sample with the frame in all 3 channels.
    static_sample = torch.stack((frame, frame, frame))
    _append_augmented(static_sample, label, data, labels)

    frame_queue.pop(0)
    frame_queue.append(frame)

print(f"Loaded {len(data)} training samples")

In [16]:
# Stack the per-sample tensors along a new leading axis and pair samples with labels.
stacked_samples = torch.stack(data, dim=0)
stacked_labels = torch.stack(labels, dim=0)
dataset = TensorDataset(stacked_samples, stacked_labels)

In [17]:
# Preview the first dataset entry: its three input channels followed by the label.
fig, ax = plt.subplots(ncols=4)
features, label = next(iter(dataset))
print(features.shape)
for panel, channel in zip(ax[:3], features):
    panel.imshow(channel, cmap='gray')
ax[3].imshow(label, cmap='Reds')

In [18]:
def plot(features, prediction, label):
    """Show a sample's first three channels next to its prediction and label.

    Creates a new five-panel figure: ``features[0..2]`` in grayscale, then
    ``prediction`` and ``label`` as red heat maps with titles.
    """
    _fig, axes = plt.subplots(ncols=5)
    for channel_idx in range(3):
        axes[channel_idx].imshow(features[channel_idx], cmap='gray')
    titled_panels = ((axes[3], prediction, "prediction"), (axes[4], label, "label"))
    for panel, heatmap, title in titled_panels:
        panel.imshow(heatmap, cmap='Reds')
        panel.set_title(title)

def plot_sample(data_loader, model):
    """Plot the first sample of the next batch together with its prediction.

    Uses the notebook-level ``device`` for the forward pass and switches
    ``model`` to evaluation mode.
    """
    batch_features, batch_labels = next(iter(data_loader))
    first_features = batch_features[0]
    model.eval()
    with torch.no_grad():
        # Re-add the batch dimension the model expects for a single sample.
        model_input = torch.unsqueeze(first_features.to(device), 0)
        prediction = model(model_input).cpu().detach()
    plot(first_features, prediction[0], batch_labels[0])

In [19]:
batch_size = 16
# 80 % train / 10 % test (evaluated after each epoch) / 10 % validation.
split_fractions = [0.8, 0.1, 0.1]
train_dataset, test_dataset, validation_dataset = random_split(dataset, split_fractions)
# One shuffled loader per subset, created in the same order as the subsets.
train_loader, test_loader, validation_loader = (
    DataLoader(subset, batch_size=batch_size, shuffle=True)
    for subset in (train_dataset, test_dataset, validation_dataset)
)
print(f"Train set size:     {len(train_dataset)}")
print(f"Test set size:       {len(test_dataset)}")
print(f"Validation set size: {len(validation_dataset)}")

In [20]:
# Instantiate the network via the project's loader.
# NOTE(review): nothing in this notebook moves `model` to `device` explicitly —
# presumably `load_cue_net_v2` handles device placement; confirm.
model = CueNetV2.load_cue_net_v2()

In [21]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

In [22]:
def plot_all(data_loader, model):
    """Plot every sample of the first batch of ``data_loader`` with its prediction.

    Switches ``model`` to evaluation mode, runs the batch through it on the
    notebook-level ``device`` and shows one figure per sample. Does nothing if
    the loader yields no batches.
    """
    model.eval()
    with torch.no_grad():
        for features, labels in data_loader:
            predictions = model(features.to(device)).cpu().detach()
            # Iterate over what the batch actually contains. Indexing up to the
            # global ``batch_size`` raised IndexError whenever the first batch
            # was shorter (e.g. a small validation split).
            for sample, prediction, label in zip(features, predictions, labels):
                plot(sample, prediction, label)
                plt.show()
            # Only the first batch is visualised.
            break

In [23]:
# Visualise one validation batch with the model as loaded, before the training cell below runs.
plot_all(validation_loader, model)

In [24]:
def find_center(pdf, template_size=13):
    """Locate a Gaussian blob in ``pdf`` by template matching.

    A ``template_size`` x ``template_size`` Gaussian built with ``create_label``
    is slid over ``pdf`` using squared-difference matching, and the best
    (lowest-score) position is returned.

    Args:
        pdf: 2-D NumPy array to search; it is cast to ``float32``.
        template_size: Side length of the square template.

    Returns:
        tuple: ``(x, y)`` of the matched template's centre pixel.

    NOTE(review): the match map is smaller than ``pdf``, so the result can never
    be closer than ``template_size // 2`` pixels to the top/left border —
    confirm this is acceptable for blobs near the image edge.
    """
    half = int(template_size / 2)
    kernel = create_label(half, half, img_width=template_size, img_height=template_size)
    kernel = kernel.numpy().astype(np.float32)
    scores = cv2.matchTemplate(pdf.astype(np.float32), kernel, method=cv2.TM_SQDIFF)
    # With TM_SQDIFF the best match is the minimum; minMaxLoc gives it as (x, y).
    _min_val, _max_val, best_loc, _max_loc = cv2.minMaxLoc(scores)
    best_x, best_y = best_loc
    return best_x + half, best_y + half

In [25]:
# Sanity check: a label generated at (40, 50) should be located at (40, 50).
find_center(create_label(40, 50).numpy().astype(np.float32))

In [26]:
# Use cell-local names for the batch. Binding it to `features, labels` replaced
# the module-level `labels` list built by the data-loading cell with a batch
# tensor, so re-running the TensorDataset cell afterwards would fail.
val_features, _val_labels = next(iter(validation_loader))
model.eval()
with torch.no_grad():
    # Predict the first sample of the batch and mark the detected centre on it.
    prediction = model(torch.unsqueeze(val_features[0].to(device), 0)).cpu().detach()
    plt.imshow(prediction[0], cmap="Reds")
    x, y = find_center(prediction[0].numpy())
    plt.scatter(x, y)

In [27]:
def se_distance(data_loader, model):
    """Squared distance between predicted and labelled centres, per sample.

    For every sample in ``data_loader`` the centre of the prediction and of the
    label heat map is located with ``find_center``; the squared
    ``calc_distance`` between them is collected.

    Args:
        data_loader: Iterable yielding ``(features, labels)`` batches.
        model: Network to evaluate; it is switched to evaluation mode, as the
            other evaluation helpers in this notebook do.

    Returns:
        list: One squared distance per sample, in iteration order.
    """
    # Make the metric independent of whatever mode the model was left in.
    model.eval()
    errors = []
    with torch.no_grad():
        for features, labels in data_loader:
            predictions = model(features.to(device)).detach().cpu().numpy()
            for prediction, label in zip(predictions, labels):
                predicted_pos = np.array(find_center(prediction))
                label_pos = np.array(find_center(label.numpy()))
                errors.append(calc_distance(predicted_pos, label_pos) ** 2)
    return errors

In [28]:
# Histogram of squared centre errors on the validation set, before the training cell below runs.
baseline_sq_errors = se_distance(validation_loader, model)
counts, bins = np.histogram(baseline_sq_errors)
plt.hist(bins[:-1], bins, weights=counts)

In [59]:
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Summed (not averaged) L1 loss: the epoch helpers divide the accumulated sum
# by the number of samples themselves.
loss_fn = nn.L1Loss(reduction='sum')

num_epochs = 2

res = train(model, train_loader, test_loader, device, optimizer, loss_fn, num_epochs)

# Loss curves over the epochs of this run.
epoch_axis = np.arange(num_epochs)
plt.plot(epoch_axis, res['train_loss'], label='train loss')
plt.plot(epoch_axis, res['test_loss'], label='test loss')
plt.xlabel("epoch")
plt.ylabel("loss")
# Without a legend the `label=` arguments above are never rendered.
plt.legend()

In [60]:
# Visualise one validation batch again, now that the training cell has run.
plot_all(validation_loader, model)

In [61]:
# Squared centre errors on the validation set after training: print the mean,
# then plot a log-scaled histogram.
s_distances = se_distance(validation_loader, model)
mean_sq_error = np.mean(s_distances)
print(mean_sq_error)
counts, bins = np.histogram(s_distances)
plt.hist(bins[:-1], bins, weights=counts, log=True)

In [58]:
# Persist only the model weights, wrapped in a dict under the 'state_dict' key.
# NOTE(review): the file name says "4ep" while the training cell sets
# num_epochs = 2 — presumably the training cell was run more than once; confirm
# before relying on the name.
torch.save({'state_dict': model.state_dict()}, 'Nets/4ep.pt')