In [4]:
import sys

# Make the project-level modules imported in the next cell (Datas, Model,
# CustomTrainer, ...) importable; presumably they live two directories up.
# NOTE: the path is relative to the kernel's current working directory.
# Guarded so that re-running this cell does not append duplicate entries.
if '../..' not in sys.path:
    sys.path.append('../..')

In [None]:
import pathlib

import Datas
import Model

import torch
import torch.optim

import ignite.engine

import CustomTrainer

In [6]:
# Experiment configuration, consumed below by Datas.from_config and
# Model.TotalVariation.from_config.
# TODO: later cells also read config['output'], config['train'] and
# config['dataset'], none of which are defined here, so Restart & Run All
# currently stops with a KeyError. Add those sections.
# NOTE(review): the key is spelled 'lamda' (not 'lambda'); presumably that is
# what the model reads, so do not rename it without checking Model.
config = dict(
    model=dict(
        device='cpu',
        kernel_size=(3, 3),
        nb_iterations=5,
        # presumably: initial value of a scalar and whether it is trained
        lamda=dict(initialize=5e-2, requires_grad=True),
        sigma=dict(initialize=5e-3, requires_grad=True),
    ),
)

In [None]:
# Training and validation dataloaders, both built from the same config dict.
# NOTE(review): Datas.from_config presumably reads a 'dataset' section, which
# the visible config dict does not define yet — TODO confirm.
(dataloader_train,
 dataloader_validation) = Datas.from_config(config)

In [8]:
# Build the unrolled TotalVariation network from the configuration.
# The captured repr shows one Layer per iteration (5 here, matching
# nb_iterations), each with two 3x3 Conv2d modules.
model = Model.TotalVariation.from_config(config)
# Display the architecture; without this trailing expression a fresh run
# would not reproduce the output saved with this cell.
model

TotalVariation(
  (layers): ModuleList(
    (0-4): 5 x Layer(
      (inv): Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), padding=same)
      (HT): Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), padding=same)
    )
  )
)

In [None]:
train_folder = '.'

# This cell reads the 'output', 'train' and 'dataset' sections of `config`.
# Fail early with a readable message (and before touching the filesystem)
# instead of an opaque KeyError halfway through.
missing_sections = [
    section for section in ('output', 'train', 'dataset') if section not in config
]
if missing_sections:
    raise KeyError(f'config is missing required section(s): {missing_sections}')

# Output directory tree. mkdir(parents=True, exist_ok=True) replaces the
# exists()/mkdir() pairs: it also creates missing parent folders and is
# safe to re-run.
output_path = pathlib.Path(config['output'].get('folder', train_folder))
models_save_path = output_path / config['output']['models_save']['path']
imgs_save_path = output_path / config['output']['imgs_save']['path']
path_imgs_train = imgs_save_path / 'train_datas'
path_imgs_eval = imgs_save_path / 'eval_datas'
for directory in (output_path, models_save_path, imgs_save_path,
                  path_imgs_train, path_imgs_eval):
    directory.mkdir(parents=True, exist_ok=True)

# Saving frequencies (in epochs) used by the event handlers below.
models_save_every = config['output']['models_save']['every']
imgs_save_every = config['output']['imgs_save']['every']

# Training hyper-parameters.
nb_epochs = config['train']['nb_epochs']
learning_rate = config['train']['learning_rate']
# SECURITY(review): eval() executes arbitrary code taken from the config;
# only use trusted config files.
# NOTE(review): as written, the training cell reassigns `criterion` to
# torch.nn.MSELoss(), so this config-driven value is not what gets used.
criterion = eval(config['train']['loss'])

# Paths for metric / loss outputs (loss_path is handed to
# CustomTrainer.save_epoch_loss in the training cell).
df_training_path = output_path / config['output']['metrics']['train']
df_validation_path = output_path / config['output']['metrics']['validation']
loss_path = output_path / config['output']['loss']

# Device strings for the data and the model (e.g. 'cpu').
datas_device = config['dataset']['device']
model_device = config['model']['device']


In [None]:
# Adam over every parameter returned by model.parameters().
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# NOTE(review): this overwrites the `criterion` built from
# config['train']['loss'] in the previous cell, so that config entry is
# currently ignored. TODO decide which one should be authoritative.
criterion = torch.nn.MSELoss()

# Per-batch training step, driven by the project's custom engine.
train_step = CustomTrainer.create_train_step(
    model, model_device, datas_device, optimizer, criterion
)
trainer = CustomTrainer.CustomEngine(train_step)

# Loss bookkeeping / logging handlers, in registration order. Keep this
# order: the EPOCH_COMPLETED handlers presumably rely on compute_epoch_loss
# running before save_epoch_loss and print_logs (ignite calls handlers of the
# same event in the order they were added).
loss_logging_handlers = (
    (ignite.engine.Events.ITERATION_COMPLETED, CustomTrainer.update_epoch_loss, ()),
    (ignite.engine.Events.EPOCH_COMPLETED, CustomTrainer.compute_epoch_loss, ()),
    (ignite.engine.Events.EPOCH_COMPLETED, CustomTrainer.save_epoch_loss, (loss_path,)),
    (ignite.engine.Events.EPOCH_COMPLETED, CustomTrainer.print_logs, ()),
)
for event, handler, handler_args in loss_logging_handlers:
    trainer.add_event_handler(event, handler, *handler_args)

In [None]:
# Checkpoint the model every `models_save_every` epochs and once more when
# training completes.
# NOTE(review): if the final epoch is a multiple of `models_save_every`, both
# EPOCH_COMPLETED and COMPLETED fire at the end, so the last save runs twice.
trainer.add_event_handler(
    ignite.engine.Events.EPOCH_COMPLETED(every=models_save_every)
    | ignite.engine.Events.COMPLETED,
    CustomTrainer.save_model,
    model,
    models_save_path,
)


import Evaluator
# Evaluation engine shared by both image-dump handlers below.
evaluate_function = Evaluator.create_evaluate_function(
    model, model_device, datas_device
)
evaluator = ignite.engine.Engine(evaluate_function)

# Dump model outputs for the training set, then the validation set, on the
# same schedule (every `imgs_save_every` epochs and at completion). A loop
# replaces the two copy-pasted registrations; registration order is kept.
for dataloader_to_dump, imgs_dump_path in (
    (dataloader_train, path_imgs_train),
    (dataloader_validation, path_imgs_eval),
):
    trainer.add_event_handler(
        ignite.engine.Events.EPOCH_COMPLETED(every=imgs_save_every)
        | ignite.engine.Events.COMPLETED,
        Evaluator.evaluate_dataloader_with_variable_size_image,
        # Positional arguments forwarded to the callback:
        evaluator,
        model,
        model_device,
        datas_device,
        dataloader_to_dump,
        imgs_dump_path,
    )