In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import skimage
import skimage.transform
import skimage.filters
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.cuda
import torch.optim as optim
import time
import cv2
import re
from tqdm.notebook import tqdm
from torch.nn.functional import relu
from torchsummary import summary
from sklearn.utils import shuffle
from SHG import SHG
from utils import *

In [None]:
TRAIN_LABELS_PATH = "D:/bsc_data/train/outputs.txt"
VAL_LABELS_PATH = "D:/bsc_data/validation/test.txt"
SEED = 0  # makes the training-image shuffle reproducible across runs

# Column names: image ID followed by an (x, y, v) triplet for each of the 17 keypoints
# (v is presumably a COCO-style visibility flag -- confirm against the label files).
HEADER = ["ID"] + ["{}{}".format(coord, i) for i in range(17) for coord in ("x", "y", "v")]

# IDs are read as strings because they are later compared with file names (str); letting
# pandas infer a numeric dtype would make those comparisons silently match nothing.
train_labels = pd.read_csv(TRAIN_LABELS_PATH, delimiter = ",", names = HEADER, dtype = {"ID": str})
val_labels = pd.read_csv(VAL_LABELS_PATH, delimiter = ",", names = HEADER, dtype = {"ID": str})
print(val_labels.shape)

TRAIN_IMGS_PATH = "D:/bsc_data/train/image/"
VAL_IMGS_PATH = "D:/bsc_data/validation/image/"

# os.listdir order is arbitrary; sorting makes the listing (and the seeded shuffle) deterministic.
train_imgs = sorted(os.listdir(TRAIN_IMGS_PATH))
val_imgs = sorted(os.listdir(VAL_IMGS_PATH))

# Labels are looked up by ID, so only the image list needs shuffling. Shuffling labels and
# images together also required both to have exactly the same length.
train_imgs = shuffle(train_imgs, random_state = SEED)

In [None]:
LEARNING_RATE = 2.5e-4
NUM_EPOCHS = 100
MINI_BATCH_SIZE = 16
# np.array_split takes a number of sections, not a batch size. Rounding up keeps every
# batch at most MINI_BATCH_SIZE images; max(1, ...) avoids array_split's
# "number sections must be larger than 0" error on datasets smaller than one batch.
NUM_MINI_BATCHES = max(1, int(np.ceil(len(train_imgs) / MINI_BATCH_SIZE)))
MINI_BATCHES = np.array_split(train_imgs, NUM_MINI_BATCHES)
# Checkpoint to resume from, expected as ".../<run_dir>/epoch_<N>.pth"; set to None for a fresh run.
SAVED_MODEL_PATH = "D:/bsc_data/models/Wed_Mar_17_16-05-12_2021/epoch_0.pth"
# Cache of the per-channel training mean. NOTE: despite the .npy suffix this is a text
# file written by np.savetxt and read by np.loadtxt (not np.save / np.load).
AVERAGE_RGB_PATH = "./average_rgb.npy"
cur_model_path = None
start_epoch = 0
average_loss = []
validation_loss = []

try:
    average_rgb = np.loadtxt(AVERAGE_RGB_PATH)
except (OSError, ValueError):
    # Missing or unreadable cache: recompute from the training images and cache the result.
    average_rgb = get_mean_rgb(TRAIN_IMGS_PATH, train_imgs)
    np.savetxt(AVERAGE_RGB_PATH, average_rgb)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SHG(num_hourglasses=1).to(device)

if SAVED_MODEL_PATH is not None:
    # map_location lets a checkpoint written on a GPU be restored on a CPU-only machine.
    model.load_state_dict(torch.load(SAVED_MODEL_PATH, map_location=device))

    checkpoint_name = os.path.basename(SAVED_MODEL_PATH)
    epoch_match = re.fullmatch(r"epoch_(\d+)\.pth", checkpoint_name)
    if epoch_match is None:
        raise ValueError("Cannot parse the epoch number from checkpoint name {!r}".format(checkpoint_name))
    start_epoch = int(epoch_match.group(1)) + 1
    cur_model_path = os.path.dirname(SAVED_MODEL_PATH) or "."

    # ndmin=1 keeps a one-epoch history as a 1-D array rather than a 0-d scalar. Slicing to
    # start_epoch drops entries for epochs after the resumed checkpoint, so the history
    # stays aligned with the epoch numbers the training loop appends next.
    average_loss = np.loadtxt(os.path.join(cur_model_path, "loss.npy"), delimiter = ",", ndmin = 1)[:start_epoch]

    try:
        # .tolist(): the training loop appends to this history, which an ndarray cannot do.
        validation_loss = np.loadtxt(os.path.join(cur_model_path, "val_loss.npy"), delimiter = ",", ndmin = 1)[:start_epoch].tolist()
    except OSError:
        pass  # this run has no saved validation history yet

# NOTE(review): only model weights are checkpointed, so RMSprop's running averages restart
# from scratch when resuming.
criterion = nn.MSELoss()
optimizer = optim.RMSprop(model.parameters(), lr = LEARNING_RATE)

In [None]:
if cur_model_path is None:
    # Fresh run (no checkpoint to resume): one output directory per run, named after the
    # start time with ":" replaced because it is not allowed in Windows paths.
    cur_model_path = "D:/bsc_data/models/" + time.asctime().replace(" ", "_").replace(":", "-")
    # makedirs also creates a missing "models/" parent and tolerates an existing directory.
    os.makedirs(cur_model_path, exist_ok = True)

In [None]:
torch.cuda.empty_cache()

# Optional cap on mini-batches per epoch for quick smoke tests (e.g. 5); None trains on all batches.
MAX_TRAIN_BATCHES = None
LOG_EVERY_N_BATCHES = 1000  # print the current mini-batch loss this often


def _load_sample(imgs_path, labels, img_name):
    """Load one image and its target heatmaps, preprocessed the same way for train and val.

    Looks up the label row(s) whose ID equals `img_name` without its 4-character extension,
    turns them into heatmaps with `create_heatmaps`, and reads the image from
    `imgs_path + img_name`. Greyscale images are expanded with `grey_to_rgb`.

    Returns:
        (img, heatmaps): `img` is a float32 HxWxC array with `average_rgb` subtracted;
        `heatmaps` is whatever `create_heatmaps` returns.
    """
    heatmaps = create_heatmaps(labels.loc[labels["ID"] == img_name[:-4]].to_numpy())
    img = plt.imread(imgs_path + img_name)
    if len(img.shape) == 2:
        img = grey_to_rgb(img)
    # Subtract into a new float32 array: an in-place `-=` fails on integer (e.g. uint8 JPEG)
    # images, and float32 matches the model's parameter dtype.
    img = np.asarray(img, dtype=np.float32) - np.asarray(average_rgb, dtype=np.float32)
    return img, heatmaps


train_batches = MINI_BATCHES if MAX_TRAIN_BATCHES is None else MINI_BATCHES[:MAX_TRAIN_BATCHES]

for epoch in tqdm(range(start_epoch, NUM_EPOCHS), desc = "EPOCH"):
    # ---- Training ----
    model.train()
    epoch_loss = 0.0
    for i, mini_batch in tqdm(enumerate(train_batches), leave = False, desc = "MINI BATCH", total = len(train_batches)):
        samples = [_load_sample(TRAIN_IMGS_PATH, train_labels, img_name) for img_name in mini_batch]
        # Stack HWC images into NHWC, then permute to NCHW for the network.
        imgs = torch.from_numpy(np.stack([img for img, _ in samples])).permute(0, 3, 1, 2).to(device)
        heatmaps = torch.from_numpy(np.asarray([hm for _, hm in samples], dtype=np.float32)).to(device)

        predictions = model(imgs)

        optimizer.zero_grad()
        loss = criterion(predictions, heatmaps)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        if i % LOG_EVERY_N_BATCHES == 0:
            print(loss.item())

    # Average over the batches actually processed, not len(MINI_BATCHES), so a capped
    # epoch does not report an artificially small loss.
    average_loss = np.append(average_loss, epoch_loss / len(train_batches) if len(train_batches) else float("nan"))

    # Saving model and loss history
    torch.save(model.state_dict(), "{}/epoch_{}.pth".format(cur_model_path, epoch))
    np.savetxt(cur_model_path + "/loss.npy", average_loss)

    print("average loss of epoch {}: {}".format(epoch, average_loss[-1]))

    # ---- Validation ----
    model.eval()
    total_val_loss = 0.0
    with torch.no_grad():  # evaluation only: no autograd graph, much less GPU memory
        for val_img in tqdm(val_imgs, leave = False, desc = "Validation", total = len(val_imgs)):
            img, heatmaps = _load_sample(VAL_IMGS_PATH, val_labels, val_img)
            heatmaps = torch.from_numpy(np.asarray(heatmaps, dtype=np.float32)).unsqueeze(0).to(device)
            # HWC -> CHW with permute, exactly as in training; a plain reshape would scramble pixels.
            img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(device)
            total_val_loss += criterion(model(img), heatmaps).item()
    average_val_loss = total_val_loss / len(val_imgs) if len(val_imgs) else float("nan")
    # np.append works whether the history is a list or an ndarray loaded when resuming.
    validation_loss = np.append(validation_loss, average_val_loss)
    print("Validation loss at epoch {}: {}".format(epoch, average_val_loss))
    np.savetxt(cur_model_path + "/val_loss.npy", validation_loss)

In [None]:
# Qualitative check on the first validation image: overlay predicted and ground-truth keypoints.
sample_name = val_imgs[0]
x = plt.imread(VAL_IMGS_PATH + sample_name)

# Ground-truth keypoints: the matching label row without its leading "ID" column.
gt_kp = val_labels.loc[val_labels["ID"] == sample_name[:-4]].to_numpy()[0][1:]

# Preprocess like the training/validation loop (grey -> RGB, float32, mean-RGB subtraction)
# so the model sees inputs from the distribution it was trained on; `x` stays raw for drawing.
x_model = grey_to_rgb(x) if len(x.shape) == 2 else x
x_model = np.asarray(x_model, dtype=np.float32) - np.asarray(average_rgb, dtype=np.float32)

model.eval()
with torch.no_grad():
    # HWC -> CHW, then add a batch dimension of 1.
    x_tensor = torch.from_numpy(x_model).permute(2, 0, 1).unsqueeze(0).to(device)
    pred = model(x_tensor).cpu().numpy()[0]
print(pred.shape)
print(np.mean(pred))

img, pred_keypoints = draw_predicitions_and_gt(x, gt_kp, pred)
plt.imshow(img)
plt.show()