In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell runs, so edits to the local helper modules are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import skimage.io as io
import torch
import torchvision.transforms as TT
from torch.utils.data import DataLoader
import cnn
from data_augment import *
from dataloader import NoseKeypointDataset
from display import *
from learn import test, train

In [None]:
ROOT_DIR = Path("imm_face_db")

# Initialize Datasets

transform = part1_augment

# Person IDs for each split. The two ranges must not share an ID, otherwise
# images of the same person end up in both training and validation and the
# validation loss is no longer an honest estimate.
# NOTE(review): assumes NoseKeypointDataset takes 1-based person IDs (1-40),
# as the split descriptions below state — confirm against dataloader.py.
train_person_ids = np.arange(1, 33)   # persons 1..32
valid_person_ids = np.arange(33, 41)  # persons 33..40
assert np.intersect1d(train_person_ids, valid_person_ids).size == 0

# Use all 6 images of the first 32 persons (index 1-32) as the training set
# (total 32 x 6 = 192 images)
training_set = NoseKeypointDataset(
    idxs=train_person_ids, root_dir=ROOT_DIR, transform=transform
)
assert len(training_set) == 192

# Use images of the remaining 8 persons (index 33-40) as the validation set
# (total 8 * 6 = 48 images)
validation_set = NoseKeypointDataset(
    idxs=valid_person_ids, root_dir=ROOT_DIR, transform=transform
)
assert len(validation_set) == 48

# Initialize Dataloaders
batch_size = 16
# Shuffle only the training data; validation order stays fixed.
train_loader = DataLoader(training_set, batch_size, shuffle=True)
test_loader = DataLoader(validation_set, batch_size, shuffle=False)

In [None]:
# Fetch the sample once: every training_set[i] access runs the dataset's
# __getitem__ again, so indexing twice could pair an image with keypoints from
# a different pass through the transform.
# NOTE(review): only matters if part1_augment is randomised — confirm.
sample = training_set[2]
show_keypoints(sample[0], sample[1])

In [None]:
# Training and Testing

epochs = 20
learn_rate = 0.001
show_every = 3
# One [train_loss, valid_loss] pair is appended per epoch.
loss_per_epoch = []

# Build the network once, outside the loop, so its weights carry over from one
# epoch to the next. Constructing a new NoseFinder inside the loop would
# restart from freshly initialised weights every epoch.
trained_model = cnn.NoseFinder()
for ep in range(epochs):
    print(f"========== Epoch {ep} ==========")
    # train() hands back the model together with this epoch's training loss.
    trained_model, train_loss = train(train_loader, trained_model, learn_rate)
    # Only the second value returned by test() (the validation loss) is kept.
    _, valid_loss = test(test_loader, trained_model, show_every)

    print_epoch(ep, train_loss, valid_loss)
    loss_per_epoch.append([train_loss, valid_loss])