In [43]:
import cv2
import numpy as np
import os
from pathlib import Path

import torch
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from MagicPoint.dataset.artificial_dataset import ArtificialDataset, available_modes
from MagicPoint.model.magic_point import MagicPoint
from common.model_utils import detector_loss, detector_metrics
from common.utils import *

%load_ext autoreload
%autoreload 2
%matplotlib inline


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [44]:
# Load the experiment configuration and expose its three top-level sections
# under the names the following cells read.
config = load_config('../configs/config_toy.yaml')
data_config, model_config, experiment_config = (
    config['data'],
    config['model'],
    config['experiment'],
)


In [48]:
# Seed before any dataset or loader is constructed. set_seed comes from
# common.utils; which RNGs it seeds is defined there.
set_seed(experiment_config['seed'])

# available_modes is indexed positionally; the order is assumed to be
# (train, val, test) based on the variable names -- TODO confirm against
# MagicPoint.dataset.artificial_dataset.
train_dataset = ArtificialDataset(available_modes[0], data_config)
val_dataset = ArtificialDataset(available_modes[1], data_config)
test_dataset = ArtificialDataset(available_modes[2], data_config)

# Training batches are reshuffled every epoch.
train_data_loader = DataLoader(train_dataset, batch_size=model_config['batch_size'],
                               collate_fn=collate, shuffle=True, num_workers=4)
# Validation is iterated in a fixed order: the metrics logged per epoch are
# averages of per-batch values, so reshuffling would change the batch
# composition between epochs and add noise unrelated to the model.
val_data_loader = DataLoader(val_dataset, batch_size=model_config['val_batch_size'],
                             collate_fn=collate, shuffle=False, num_workers=2)
# One sample per batch, loaded in the main process (num_workers defaults to 0).
# Shuffling is kept as in the original setup.
test_data_loader = DataLoader(test_dataset, batch_size=1, collate_fn=collate,
                              shuffle=True)


In [49]:
# Fresh training state; overwritten below when a checkpoint is restored.
epoch = 0
model = MagicPoint(model_config)
optimizer = optim.Adam(model.parameters(), lr=model_config['learning_rate'])

if experiment_config['load_checkpoints']:
    checkpoint_path = get_checkpoint_path(experiment_config, model_config,
                                          experiment_config['load_checkpoint_iter'])
    if checkpoint_path.exists():
        # NOTE(review): the training cell resumes with range(epoch, ...) and
        # saves checkpoints under the index of the epoch that just finished,
        # so the restored epoch looks like it is trained a second time --
        # confirm what load_checkpoint returns.
        epoch, model_sd, optimizer_sd = load_checkpoint(checkpoint_path)
        model.load_state_dict(model_sd)
        optimizer.load_state_dict(optimizer_sd)
    else:
        # Resuming was requested but nothing is there to resume from. Say so
        # instead of silently starting over with freshly initialised weights.
        print(f'Checkpoint {checkpoint_path} not found; training starts from scratch at epoch {epoch}.')

# TensorBoard event files are written below this directory; the exact run
# folder is chosen by init_log_dir.
logs_base = '../../logs'
writer = SummaryWriter(log_dir=init_log_dir(logs_base, experiment_config))


In [50]:
# Train and validate for the remaining epochs, logging per-epoch averages of
# loss / precision / recall to TensorBoard and saving periodic checkpoints.
# The loop runs inside try/finally so the SummaryWriter is closed even when
# the cell is interrupted (e.g. KeyboardInterrupt) or an exception is raised.
try:
    for epoch in range(epoch, experiment_config['num_epochs']):
        # ---- training pass ----
        model.train()

        train_loss = 0
        train_precision = 0
        train_recall = 0

        for x, y in train_data_loader:
            optimizer.zero_grad()

            y_pred = model(x)
            loss = detector_loss(y_pred['logits'], y, model_config)

            loss.backward()
            optimizer.step()

            # Metrics come from the same forward pass that produced the loss,
            # i.e. from the weights before this optimizer step.
            metrics = detector_metrics(y_pred['probs'], y)

            train_loss += loss.item()
            train_precision += metrics['precision'].item()
            train_recall += metrics['recall'].item()

        # Mean of per-batch values (every batch weighs the same, whatever its size).
        num_train_batches = len(train_data_loader)
        train_loss /= num_train_batches
        train_precision /= num_train_batches
        train_recall /= num_train_batches

        writer.add_scalar('training/loss', train_loss, epoch)
        writer.add_scalar('training/precision', train_precision, epoch)
        writer.add_scalar('training/recall', train_recall, epoch)

        # ---- validation pass ----
        model.eval()

        with torch.no_grad():
            val_loss = 0
            val_precision = 0
            val_recall = 0

            for x, y in val_data_loader:
                y_pred = model(x)
                loss = detector_loss(y_pred['logits'], y, model_config)

                metrics = detector_metrics(y_pred['probs'], y)

                val_loss += loss.item()
                val_precision += metrics['precision'].item()
                val_recall += metrics['recall'].item()

            num_val_batches = len(val_data_loader)
            val_loss /= num_val_batches
            val_precision /= num_val_batches
            val_recall /= num_val_batches

            writer.add_scalar('validation/loss', val_loss, epoch)
            writer.add_scalar('validation/precision', val_precision, epoch)
            writer.add_scalar('validation/recall', val_recall, epoch)

        # Checkpoint every save_interval epochs, never at epoch 0, and only
        # when keep_checkpoints is non-zero.
        if experiment_config['keep_checkpoints'] != 0 and epoch != 0 and epoch % experiment_config['save_interval'] == 0:
            checkpoint_path = get_checkpoint_path(experiment_config, model_config, epoch)
            save_checkpoint(epoch, model, optimizer, checkpoint_path)
            clear_old_checkpoints(experiment_config)

        # TODO. Early stopping, maybe? Or reduce the learning rate on plateau.
        # TODO. Make a picture of model's performance from train/val/test set each n epochs
finally:
    writer.close()

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29


KeyboardInterrupt: 