In [10]:
import cv2
import numpy as np
import os
from pathlib import Path

import torch
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from MagicPoint.dataset.artificial_dataset import ArtificialDataset, available_modes
from MagicPoint.model.magic_point import MagicPoint
from common.model_utils import detector_loss
from common.utils import *

%load_ext autoreload
%autoreload 2
%matplotlib inline


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [6]:
# Load the experiment configuration file and split it into its three sections.
config = load_config('../configs/config_toy.yaml')
data_config, model_config, experiment_config = (
    config['data'],
    config['model'],
    config['experiment'],
)


In [8]:
# Seed first, before any dataset or model is constructed in later cells.
set_seed(experiment_config['seed'])

# Training length and checkpointing settings read from the experiment section.
num_epochs, save_interval, keep_checkpoints = (
    experiment_config['num_epochs'],
    experiment_config['save_interval'],
    experiment_config['keep_checkpoints'],
)


In [9]:
# One dataset per split; available_modes[0..2] are used as the train/val/test modes.
train_dataset = ArtificialDataset(available_modes[0], data_config)
val_dataset = ArtificialDataset(available_modes[1], data_config)
test_dataset = ArtificialDataset(available_modes[2], data_config)

# All loaders use the project's `collate` function and shuffle their data.
# NOTE(review): shuffling val/test does not affect loss/metric values — confirm
# it is intended (e.g. to draw random samples for visualisation).
train_data_loader = DataLoader(
    train_dataset,
    batch_size=model_config['batch_size'],
    shuffle=True,
    num_workers=4,
    collate_fn=collate,
)
val_data_loader = DataLoader(
    val_dataset,
    batch_size=model_config['val_batch_size'],
    shuffle=True,
    num_workers=2,
    collate_fn=collate,
)
test_data_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=collate,
)


In [None]:
# Build the model and its Adam optimizer, then optionally restore both from a
# saved checkpoint when `experiment.load_checkpoints` is enabled.
model = MagicPoint(model_config)
optimizer = optim.Adam(model.parameters(), lr=model_config['learning_rate'])

if experiment_config['load_checkpoints']:
    # <checkpoints_path>/<experiment name>/ — created if it does not exist yet.
    base_path = Path(experiment_config['checkpoints_path'], experiment_config['name'])
    base_path.mkdir(parents=True, exist_ok=True)

    checkpoint_path = base_path.joinpath(
        get_checkpoint_name(model_config['name'], experiment_config['checkpoint_iter'])
    )
    if checkpoint_path.exists():
        model_sd, optimizer_sd = load_checkpoint(checkpoint_path)
        model.load_state_dict(model_sd)
        optimizer.load_state_dict(optimizer_sd)
        print(f'Loaded checkpoint: {checkpoint_path}')
    else:
        # Report the fallback explicitly: otherwise a wrong path/iteration in
        # the config silently results in training from scratch.
        print(f'Checkpoint not found: {checkpoint_path} -- training from scratch.')


In [None]:
# Train for `num_epochs` epochs with a validation pass after each one.
# Per-epoch mean losses are stored in `history` and printed, so the validation
# pass has an observable result instead of computing a loss and discarding it.
# NOTE(review): neither the model nor the batches are moved to a device here;
# this assumes CPU training or that tensors are already on the model's device — TODO confirm.
history = {'train_loss': [], 'val_loss': []}

for epoch in range(num_epochs):

    model.train()
    train_loss_sum, train_batches = 0.0, 0
    for x, y in train_data_loader:
        optimizer.zero_grad()

        y_pred = model(x)
        loss = detector_loss(y_pred['logits'], y, model_config)

        loss.backward()
        optimizer.step()

        # `.item()` converts the scalar loss to a Python float, so the
        # accumulator does not keep the autograd graph alive.
        train_loss_sum += loss.item()
        train_batches += 1
        # TODO: log the batch loss and a detection metric to TensorBoard
        # (SummaryWriter is imported but not used yet).

    model.eval()
    val_loss_sum, val_batches = 0.0, 0
    with torch.no_grad():
        for x, y in val_data_loader:
            y_pred = model(x)
            loss = detector_loss(y_pred['logits'], y, model_config)

            val_loss_sum += loss.item()
            val_batches += 1
            # TODO: calculate a validation metric and log it to TensorBoard.

    # An empty loader yields NaN rather than raising ZeroDivisionError.
    train_loss = train_loss_sum / train_batches if train_batches else float('nan')
    val_loss = val_loss_sum / val_batches if val_batches else float('nan')
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    print(f'epoch {epoch + 1}/{num_epochs}: '
          f'train loss {train_loss:.4f}, val loss {val_loss:.4f}')

    # TODO (every `save_interval` epochs):
    #   - save the model (keeping at most `keep_checkpoints` checkpoints)
    #   - add early stopping based on `history['val_loss']`
    #   - plot the model's predictions on train/val/test samples
          