In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.cuda
import torch.optim as optim
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
import time
import re
from PIL import Image
from tqdm.notebook import tqdm
from torch.nn.functional import relu
from torch.utils.data import DataLoader, Dataset
from torchsummary import summary
from sklearn.utils import shuffle
from SHG import SHG
from utils import *

In [2]:
# PATHS FOR DATA
# Every path keeps its trailing "/" so it also works with code that builds file
# paths by plain string concatenation (directory + file name).

# Root of the local data folder; all data paths below are derived from it.
DATA_ROOT = "C:/Users/André/OneDrive 2/OneDrive/Skrivebord/bsc_data/"

# Input images
TRAIN_IMGS_PATH = DATA_ROOT + "train/image/"
VAL_IMGS_PATH = DATA_ROOT + "validation/image/"

# Ground truth heatmaps
TRAIN_HEATMAPS_PATH = DATA_ROOT + "train/heatmaps/"
# Forward slash: the previous "validation\heatmaps" only worked because "\h" is an
# invalid escape sequence (SyntaxWarning on Python 3.12+), and it mixed separators.
VAL_HEATMAPS_PATH = DATA_ROOT + "validation/heatmaps/"

In [9]:
# Checkpoint to resume from ("<run dir>/epoch_<N>.pth"); set to None to start a fresh run.
SAVED_MODEL_PATH = "C:/Users/André/OneDrive 2/OneDrive/Skrivebord/bsc_data/models/Wed_Mar_24_15-27-01_2021/epoch_0.pth"
cur_model_path = None
start_epoch = 0
train_loss = np.array([])
val_loss = np.array([])
scheduler = None

# Read the mean rgb if it has been calculated previously, otherwise compute and cache it.
# NOTE: despite the ".npy" suffix this cache is a *text* file (np.savetxt / np.loadtxt).
try:
    average_rgb = np.loadtxt("./average_rgb.npy")
except OSError:
    # Only a missing/unreadable cache triggers a recompute; a bare `except:` would also
    # swallow KeyboardInterrupt and genuine bugs.
    average_rgb = get_mean_rgb(TRAIN_IMGS_PATH)
    np.savetxt("./average_rgb.npy", average_rgb)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Define model
model = SHG(num_hourglasses=1).to(device)

# If we decide to use a saved model
if SAVED_MODEL_PATH is not None:
    # map_location lets a checkpoint written on a GPU machine load on a CPU-only one.
    model.load_state_dict(torch.load(SAVED_MODEL_PATH, map_location=device))
    # ".../epoch_<N>.pth" -> continue training at epoch N + 1
    start_epoch = int(re.findall(r"(?<=epoch_)(.*)(?=\.pth)", SAVED_MODEL_PATH)[0]) + 1
    # Directory of the saved model, including its trailing "/"
    cur_model_path = re.findall("^(.*)(?=epoch)", SAVED_MODEL_PATH)[0]
    # Drop history recorded after this checkpoint, so index i keeps meaning "epoch i"
    # when resuming from a checkpoint older than the latest one.
    train_loss = np.load(cur_model_path + "/loss.npy")[:start_epoch] # average training loss per epoch
    val_loss = np.load(cur_model_path + "/val_loss.npy")[:start_epoch] # validation loss per epoch
    # NOTE(review): the scheduler is pickled as a whole object, including the optimizer it was
    # created with, so it must be re-attached to the newly created optimizer before use.
    # PyTorch >= 2.6 defaults torch.load to weights_only=True, which rejects such objects;
    # there, pass weights_only=False (only for trusted files).
    scheduler = torch.load(cur_model_path + "scheduler.pth", map_location=device)

Device: cuda


In [10]:
# Fresh run (no checkpoint loaded): create a timestamped directory for checkpoints and losses.
if cur_model_path is None:
    # ":" is not allowed in Windows file names, hence the replacement in the timestamp.
    cur_model_path = "C:/Users/André/OneDrive 2/OneDrive/Skrivebord/bsc_data/models/" + time.asctime().replace(" ", "_").replace(":", "-")
    # makedirs also creates the parent "models/" folder on a fresh setup; because exist_ok
    # stays False, an existing run directory is still never silently reused.
    os.makedirs(cur_model_path)
    print("Created directory at", cur_model_path)

In [11]:
class dataset(Dataset):
    """Pairs every input image with its stack of ground-truth heatmaps.

    Parameters
    ----------
    X_path : str
        Directory containing the input images.
    y_path : str
        Directory containing one sub-directory per image, named after the image file
        without its extension, holding one ``.npy`` heatmap per file.
    average_rgb : sequence of float
        Per-channel mean subtracted from each image (3 values, matching the fixed
        ``std = [1, 1, 1]``).
    """

    def __init__(self, X_path, y_path, average_rgb):
        self.X_path = X_path
        self.y_path = y_path
        # Sorted so the sample order does not depend on the file system's listing order.
        self.X_data = sorted(os.listdir(self.X_path))
        self.average_rgb = average_rgb
        self.norm = transforms.Normalize(mean = self.average_rgb, std = [1, 1, 1])

    def __len__(self):
        return len(self.X_data)

    def __getitem__(self, i):
        """Return ``(x, y)``: the mean-subtracted RGB image tensor and its stacked heatmaps."""
        ID = self.X_data[i]
        # splitext instead of ID[:-4], so extensions such as ".jpeg" are handled too.
        stem = os.path.splitext(ID)[0]

        # convert("RGB") casts gray-scale (and palette / CMYK / RGBA) images to exactly three
        # channels, so the 3-channel normalisation always applies; `with` closes the file.
        with Image.open(os.path.join(self.X_path, ID)) as img:
            x = TF.to_tensor(img.convert("RGB"))

        x = self.norm(x) # Subtracts mean rgb

        # Sorted: heatmap k must always become target channel k, but os.listdir order is
        # arbitrary. NOTE(review): confirm this matches the joint order the model was trained with.
        heatmap_dir = os.path.join(self.y_path, stem)
        y = torch.stack([
            torch.from_numpy(np.load(os.path.join(heatmap_dir, name)))
            for name in sorted(os.listdir(heatmap_dir))
        ])
        return x, y

# Both splits are normalised with the same average_rgb.
train_data, val_data = (
    dataset(TRAIN_IMGS_PATH, TRAIN_HEATMAPS_PATH, average_rgb),
    dataset(VAL_IMGS_PATH, VAL_HEATMAPS_PATH, average_rgb),
)

In [12]:
LEARNING_RATE = 2.5e-4
NUM_EPOCHS = 100
MINI_BATCH_SIZE = 16

criterion = nn.MSELoss()
optimizer = optim.RMSprop(model.parameters(), lr = LEARNING_RATE)

# Always build the scheduler around *this* optimizer. A scheduler unpickled from a
# checkpoint still references the optimizer it was saved with, so stepping it directly
# would lower the learning rate of a stale copy and never affect training.
# (Re-running this cell likewise carries over the previous optimizer/scheduler state.)
saved_scheduler = scheduler
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor = 0.2, patience=2, verbose=True)
if saved_scheduler is not None:
    # Restore the optimizer (current learning rate and RMSprop running averages) ...
    optimizer.load_state_dict(saved_scheduler.optimizer.state_dict())
    # ... and the plateau bookkeeping (best loss, bad-epoch counter, cooldown).
    scheduler.load_state_dict(saved_scheduler.state_dict())

train_dataloader = DataLoader(train_data, batch_size = MINI_BATCH_SIZE, shuffle = True)
val_dataloader = DataLoader(val_data, batch_size = 1)

In [13]:
# TODO: still need to make sure the weights do not start out as 0.

torch.cuda.empty_cache()

for epoch in tqdm(range(start_epoch, NUM_EPOCHS), desc = "EPOCH"):
    # ---- Training ----
    model.train()
    train_loss = np.append(train_loss, 0)
    for x, y in tqdm(train_dataloader, leave = False, desc = "MINI BATCH", total = len(train_dataloader)):
        x = x.to(device, dtype = torch.float)
        y = y.to(device, dtype = torch.float)

        # Forward pass and backpropagation
        optimizer.zero_grad()
        predictions = model(x)
        loss = criterion(predictions, y)
        loss.backward()
        optimizer.step()

        train_loss[-1] += loss.item()

    # Mean of the per-batch losses of this epoch
    train_loss[-1] /= len(train_dataloader)
    print("Average train loss of epoch {}: {}".format(epoch, train_loss[-1]))

    # ---- Validation ----
    model.eval()
    val_loss = np.append(val_loss, 0)
    with torch.no_grad():
        for x, y in tqdm(val_dataloader, leave = False, desc = "VALIDATION", total = len(val_dataloader)):
            x = x.to(device, dtype = torch.float)
            y = y.to(device, dtype = torch.float)
            val_loss[-1] += criterion(model(x), y).item()

    val_loss[-1] /= len(val_dataloader)
    print("Validation loss at epoch {}: {}".format(epoch, val_loss[-1]))

    # Step the scheduler *before* checkpointing, so the saved scheduler is the state that
    # belongs with epoch_<epoch>.pth, which is where a resumed run continues from.
    scheduler.step(val_loss[-1])

    # ---- Checkpoint ----
    # os.path.join works whether or not cur_model_path ends with "/".
    torch.save(model.state_dict(), os.path.join(cur_model_path, "epoch_{}.pth".format(epoch)))
    np.save(os.path.join(cur_model_path, "loss.npy"), train_loss)
    np.save(os.path.join(cur_model_path, "val_loss.npy"), val_loss)
    # Saved every epoch (not only from a manual cell after interrupting), so an interrupted
    # or crashed run can still resume its learning-rate schedule.
    torch.save(scheduler, os.path.join(cur_model_path, "scheduler.pth"))

HBox(children=(FloatProgress(value=0.0, description='EPOCH', max=99.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='MINI BATCH', max=7394.0, style=ProgressStyle(description_…

Average train loss of epoch 1: 8.40267806122319e-06


HBox(children=(FloatProgress(value=0.0, description='VALIDATION', max=5064.0, style=ProgressStyle(description_…

Validation loss at epoch 1: 5.780111503067433e-06


HBox(children=(FloatProgress(value=0.0, description='MINI BATCH', max=7394.0, style=ProgressStyle(description_…

Average train loss of epoch 2: 5.891555685059114e-06


HBox(children=(FloatProgress(value=0.0, description='VALIDATION', max=5064.0, style=ProgressStyle(description_…

Validation loss at epoch 2: 7.287495846922344e-06


HBox(children=(FloatProgress(value=0.0, description='MINI BATCH', max=7394.0, style=ProgressStyle(description_…

Average train loss of epoch 3: 4.86096452049718e-06


HBox(children=(FloatProgress(value=0.0, description='VALIDATION', max=5064.0, style=ProgressStyle(description_…

Validation loss at epoch 3: 4.084584529423761e-06


HBox(children=(FloatProgress(value=0.0, description='MINI BATCH', max=7394.0, style=ProgressStyle(description_…

Average train loss of epoch 4: 4.273288074222126e-06


HBox(children=(FloatProgress(value=0.0, description='VALIDATION', max=5064.0, style=ProgressStyle(description_…

Validation loss at epoch 4: 5.110260739832945e-06


HBox(children=(FloatProgress(value=0.0, description='MINI BATCH', max=7394.0, style=ProgressStyle(description_…

Average train loss of epoch 5: 4.00880305640285e-06


HBox(children=(FloatProgress(value=0.0, description='VALIDATION', max=5064.0, style=ProgressStyle(description_…

Validation loss at epoch 5: 3.0270126332000992e-06


HBox(children=(FloatProgress(value=0.0, description='MINI BATCH', max=7394.0, style=ProgressStyle(description_…

Average train loss of epoch 6: 3.632546007835594e-06


HBox(children=(FloatProgress(value=0.0, description='VALIDATION', max=5064.0, style=ProgressStyle(description_…

Validation loss at epoch 6: 2.666389273838143e-06


HBox(children=(FloatProgress(value=0.0, description='MINI BATCH', max=7394.0, style=ProgressStyle(description_…

Average train loss of epoch 7: 3.5863851185247565e-06


HBox(children=(FloatProgress(value=0.0, description='VALIDATION', max=5064.0, style=ProgressStyle(description_…

Validation loss at epoch 7: 2.4834426275447006e-06


HBox(children=(FloatProgress(value=0.0, description='MINI BATCH', max=7394.0, style=ProgressStyle(description_…




KeyboardInterrupt: 

In [19]:
# Saving scheduler inside the run directory. A freshly created cur_model_path has no
# trailing "/", so plain concatenation wrote ".../<run>scheduler.pth" next to the run
# folder, where the resume code never looks for it.
torch.save(scheduler, os.path.join(cur_model_path, "scheduler.pth"))