## Assignment 2 - Fashion Segmentation

#### Author: Joaquim Marset Alsina

### Imports

In [None]:
from train import train_model
from inference import perform_inference
import os
import glob
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from plot_utils import plot_metrics

### Constants

In [None]:
# Every location below is relative to the notebook's working directory.
root_path = './'

# Seed forwarded to train_model.
seed = 1412

# Config file handed to both train_model and perform_inference.
config_path = os.path.join(root_path, 'config.py')

# Input tree: the 'val' split under ./fashionpedia is what this notebook
# evaluates on (images plus their ground-truth segmentations).
data_path = os.path.join(root_path, 'fashionpedia')
test_images_path = os.path.join(data_path, 'images', 'val')
test_segmentations_path = os.path.join(data_path, 'segmentations', 'val')

# Output tree:
#   ./results/train              <- training artefacts (read by later cells)
#   ./results/test/predictions   <- saved test predictions
results_path = os.path.join(root_path, 'results')
train_results_path = os.path.join(results_path, 'train')
test_results_path = os.path.join(results_path, 'test')
test_predictions_path = os.path.join(test_results_path, 'predictions')

### Create required folders

In [None]:
# Create the whole output tree used by the later cells.
# Each directory is created unconditionally: guarding on results_path alone
# would skip the sub-folders whenever results/ already exists (e.g. left over
# from an interrupted earlier run), and inference would then have nowhere to
# write. exist_ok=True makes re-running this cell a harmless no-op, and
# os.makedirs also creates any missing parent directories.
for required_dir in (results_path, train_results_path,
                     test_results_path, test_predictions_path):
    os.makedirs(required_dir, exist_ok=True)

### Train

In [None]:
# Train using the config file and seed defined above. train_model presumably
# writes its checkpoints and logs into train_results_path: the following
# cells read 'latest.pth' and '*.json' from that folder.
# Kept as the cell's final expression so any return value is still displayed.
train_model(
    config_path,
    seed,
    train_results_path,
)

### Perform Inference

In [None]:
# Checkpoint expected to have been produced by the training cell.
checkpoint_path = os.path.join(train_results_path, 'latest.pth')

# Run the trained model over the test images, writing outputs under
# test_results_path / test_predictions_path. plot_segmentations=False
# presumably disables per-image plotting inside perform_inference
# (NOTE(review): confirm in inference.py).
perform_inference(
    config_path,
    test_results_path,
    test_predictions_path,
    test_images_path,
    test_segmentations_path,
    checkpoint_path,
    plot_segmentations=False,
)

### Plot training and validation statistics

In [None]:
# Locate the most recent JSON training log in train_results_path.
# glob.glob returns matches in an arbitrary, filesystem-dependent order, so
# taking [-1] of its raw output does not reliably select the latest run when
# several logs exist; choose by modification time instead.
log_files = glob.glob(os.path.join(train_results_path, '*.json'))
if not log_files:
    # Fail with an actionable message instead of a bare IndexError.
    raise FileNotFoundError(
        f'No *.json training log found in {train_results_path!r}; '
        'run the training cell first.'
    )
log_path = max(log_files, key=os.path.getmtime)

# Plot the training/validation curves from that log; plot_metrics presumably
# saves its figures under results_path (NOTE(review): confirm in plot_utils.py).
plot_metrics(log_path, results_path)

### Display some results

In [None]:
# Display a random sample of the saved test predictions.
# os.listdir order is arbitrary, so sort for a deterministic index -> file map.
predictions = sorted(os.listdir(test_predictions_path))
num_test_images = len(predictions)

if num_test_images > 0:

    # np.random.choice samples WITH replacement by default, which could show
    # the same image twice and always drew 10 even when fewer exist. Sample
    # without replacement and cap the count at what is available.
    num_to_show = min(10, num_test_images)
    indices = np.random.choice(num_test_images, num_to_show, replace=False)

    for index in indices:
        path = os.path.join(test_predictions_path, predictions[index])
        image = mpimg.imread(path)
        plt.figure()
        plt.imshow(image)
        # Label each figure with its file name so results can be traced back.
        plt.title(predictions[index])
        plt.axis('off')