# Inspect Detectron2 training on plankton dataset

In [None]:
import os
import glob
import pickle
import json
import cv2
import random
import matplotlib.pyplot as plt
import torch
import numpy as np
import time

from detectron2 import model_zoo
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.structures import BoxMode
from detectron2.utils.visualizer import Visualizer
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from detectron2.evaluation import COCOEvaluator, inference_on_dataset
from detectron2.data import build_detection_test_loader

import lib.training_functions as training_functions # Custom training functions
import lib.my_visualizer as my_visualizer # Custom Detectron2 visualizer to increase font size

### Choose output directory

List all output directories

In [None]:
# Glob pattern matching every training output directory
output_dir = 'output/output_*'
# Collect the matches in sorted (lexicographic) order
output_dirs = sorted(glob.glob(output_dir))
output_dirs

By default, select the last one

In [None]:
# Fail with an explicit message instead of a bare IndexError when nothing was found
if not output_dirs:
    raise FileNotFoundError("No training output directory found matching 'output/output_*'")
# The list is sorted, so the last entry is the lexicographically greatest name
# (presumably the most recent run if names embed a timestamp — confirm)
output_dir = output_dirs[-1]

### Read training settings

In [None]:
# Settings file stored in the selected output directory
# NOTE(review): unpickling can execute arbitrary code; only open runs from a trusted source
settings_path = os.path.join(output_dir, 'settings.pickle')
with open(settings_path, 'rb') as settings_file:
    settings = pickle.load(settings_file)
settings

### Prepare dataset

In [None]:
# Dataset location recorded in the run settings
data_dir = settings['dataset']

# Register each split with Detectron2 under the name 'plankton_<split>'.
# DatasetCatalog.register rejects a name that is already registered, so names
# that already exist are skipped to keep this cell safe to re-run. The loader
# looks up `data_dir` when it is called, so an existing registration still
# follows the value assigned above.
for split in ['train', 'valid', 'test']:
    name = 'plankton_' + split
    if name not in DatasetCatalog.list():
        DatasetCatalog.register(
            name,
            # Bind `split` as a default argument so each loader keeps its own split
            lambda split=split: training_functions.my_dataset_function(os.path.join(data_dir, split)),
        )
    # Single-class dataset: every annotation is labelled 'plankton'
    MetadataCatalog.get(name).set(thing_classes=['plankton'])
plankton_metadata = MetadataCatalog.get('plankton_train')

### Have a look at a few training frames

In [None]:
# Load the training frames and draw the annotations of a few random ones
dataset_dicts = training_functions.format_bbox(os.path.join(data_dir, 'train'))
# Never request more frames than exist; random.sample raises ValueError otherwise
for d in random.sample(dataset_dicts, min(3, len(dataset_dicts))):
    img = cv2.imread(d['file_name'])  # OpenCV loads images in BGR channel order
    # The Visualizer expects RGB, hence the channel flip on input
    visualizer = Visualizer(img[:, :, ::-1], metadata=plankton_metadata, scale=1)
    out = visualizer.draw_dataset_dict(d)
    plt.figure(figsize=(10, 10))
    # get_image() already returns RGB, which is what imshow expects;
    # flipping the channels again would display red and blue swapped
    plt.imshow(out.get_image())
    plt.show()

## Training

### Read training metrics

In [None]:
# metrics.json holds one JSON document per line; parse each line into a record
metrics_path = os.path.join(output_dir, 'metrics.json')
with open(metrics_path, 'r') as metrics_file:
    training_metrics = [json.loads(record) for record in metrics_file]

### Plot training and validation loss evolution

In [None]:
# Select the records that contain each loss key
train_records = [rec for rec in training_metrics if 'total_loss' in rec]
valid_records = [rec for rec in training_metrics if 'validation_loss' in rec]

plt.figure(figsize=(10, 5))
plt.plot(
    [rec['iteration'] for rec in train_records],
    [rec['total_loss'] for rec in train_records],
)
plt.plot(
    [rec['iteration'] for rec in valid_records],
    [rec['validation_loss'] for rec in valid_records],
)
plt.legend(['total_loss', 'validation_loss'], loc='best')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Iteration')
# Keep the automatic lower bound and cap the upper bound at 1
lower_bound = min(plt.ylim())
plt.ylim([lower_bound, 1])
plt.show()

In [None]:
# Same loss curves as the previous figure, with a tighter upper bound on the y axis
plt.figure(figsize=(10, 5))
for loss_key in ('total_loss', 'validation_loss'):
    # Only the records that contain this loss key contribute to its curve
    records = [rec for rec in training_metrics if loss_key in rec]
    plt.plot(
        [rec['iteration'] for rec in records],
        [rec[loss_key] for rec in records],
    )
plt.legend(['total_loss', 'validation_loss'], loc='best')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Iteration')
# Keep the automatic lower bound and cap the upper bound at 0.45
lower_bound = min(plt.ylim())
plt.ylim([lower_bound, 0.45])
plt.show()

## Testing

### Look at test metrics

In [None]:
# Test results file stored in the selected output directory
# NOTE(review): unpickling can execute arbitrary code; only open runs from a trusted source
results_path = os.path.join(output_dir, 'test_results.pickle')
with open(results_path, 'rb') as results_handle:
    results = pickle.load(results_handle)
results

### Have a look at a few predicted frames

Load model with default predictor

In [None]:
# Files located in the selected output directory
cfg_file = os.path.join(output_dir, 'my_cfg.yaml')
weights_file = os.path.join(output_dir, 'model_final.pth')

# Start from the Detectron2 defaults, then overlay the configuration file of the run
cfg = get_cfg()
cfg.merge_from_file(cfg_file)
# Use the weights stored in the output directory
cfg.MODEL.WEIGHTS = weights_file
# Detection score threshold taken from the run settings (can be played with)
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = settings['test_threshold']

predictor = DefaultPredictor(cfg)

Run prediction

In [None]:
# Load the test frames and show the model predictions on a few random ones
test_dataset_dicts = training_functions.format_bbox(os.path.join(data_dir, 'test'))
# Never request more frames than exist; random.sample raises ValueError otherwise
for d in random.sample(test_dataset_dicts, min(10, len(test_dataset_dicts))):
    im = cv2.imread(d['file_name'])  # OpenCV loads images in BGR channel order
    outputs = predictor(im)
    # The visualizer is given an RGB image, hence the channel flip on input
    v = my_visualizer.Visualizer(
        im[:, :, ::-1],
        metadata=plankton_metadata,
        scale=1,
    )
    # Move the predictions to the CPU before drawing them
    drawn = v.draw_instance_predictions(outputs['instances'].to('cpu'))
    plt.figure(figsize=(10, 10))
    # imshow expects RGB; Detectron2's get_image() returns RGB, so no second flip
    # NOTE(review): assumes the custom my_visualizer keeps that RGB output — confirm
    plt.imshow(drawn.get_image())
    plt.show()