In [5]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import skimage
import skimage.transform
import skimage.filters
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.cuda
import torch.optim as optim
from tqdm.notebook import tqdm
from torch.nn.functional import relu
from tqdm.notebook import tqdm
from SHG import SHG
from utils import *

In [6]:
TRAIN_LABELS_PATH = "D:/bsc_data/train/outputs.txt"
TEST_LABELS_PATH = "D:/bsc_data/test/outputs.txt"
VAL_LABELS_PATH = "D:/bsc_data/validation/outputs.txt"

HEADER = ["ID"]
for i in range(17):
    HEADER.append("x{}".format(i))
    HEADER.append("y{}".format(i))
    HEADER.append("v{}".format(i))

train_labels = pd.read_csv(TRAIN_LABELS_PATH, delimiter = ",", names = HEADER)
test_labels = pd.read_csv(TEST_LABELS_PATH, delimiter = ",", names = HEADER)
val_labels = pd.read_csv(VAL_LABELS_PATH, delimiter = ",", names = HEADER).T

TRAIN_IMGS_PATH = "D:/bsc_data/train/image/"
TEST_IMGS_PATH = "D:/bsc_data/test/image/"
VAL_IMGS_PATH = "D:/bsc_data/validation/image/"

train_imgs = os.listdir(TRAIN_IMGS_PATH)
test_imgs = os.listdir(TEST_IMGS_PATH)
val_imgs = os.listdir(VAL_IMGS_PATH)

In [7]:
LEARNING_RATE = 2.5e-4
NUM_EPOCHS = 100
MINI_BATCH_SIZE = 16
MINI_BATCHES = np.array_split(train_imgs, len(train_imgs)/MINI_BATCH_SIZE)

try:
    average_rgb = np.loadtxt("./average_rgb.npy")
except:
    average_rgb = get_mean_rgb(TRAIN_IMGS_PATH, train_imgs)
    np.savetxt("./average_rgb.npy", average_rgb)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SHG(num_hourglasses=1).to(device)

criterion = nn.MSELoss()
optimizer = optim.RMSprop(model.parameters(), lr = LEARNING_RATE)

In [8]:
torch.cuda.empty_cache()
model.train()

for epoch in tqdm(range(NUM_EPOCHS), desc = "EPOCH"):
    average_loss = 0
    for mini_batch in tqdm(MINI_BATCHES, leave = False, desc = "MINI BATCH"):
        heatmaps = []
        imgs = []

        for img_name in mini_batch:
            heatmaps.append(create_heatmaps(train_labels.loc[train_labels["ID"] == img_name[:-4]].to_numpy()))
            img = plt.imread(TRAIN_IMGS_PATH + img_name)

            if (len(img.shape) == 2):
                img = grey_to_rgb(img)
            
            img -= average_rgb
            
            imgs.append(img)
        
        heatmaps = torch.FloatTensor(heatmaps).to(device)
        imgs = (torch.from_numpy(np.array(imgs)).permute(0, 3, 1, 2)).to(device)

        #predictions = torch.tensor(model(imgs), requires_grad = True)[0, :, :, :, :] # Removing unnecessary extra dimension
        predictions = model(imgs)

        optimizer.zero_grad()

        loss = criterion(predictions.to(device), heatmaps)
        print(loss)

        average_loss += loss/MINI_BATCH_SIZE

        loss.backward()
        optimizer.step()

    print("average loss of epoch {}: {}".format(epoch, average_loss))

    break
        


HBox(children=(FloatProgress(value=0.0, description='EPOCH', style=ProgressStyle(description_width='initial'))…

HBox(children=(FloatProgress(value=0.0, description='MINI BATCH', max=7394.0, style=ProgressStyle(description_…

tensor(0.0077, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(4.9720, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(1.0603e+12, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(5.8531e+09, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(3.6314e+08, device='cuda:0', grad_fn=<MseLossBackward>)



KeyboardInterrupt: 