In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import logging

from utils.train_pipeline import Trainer, CheckPointArgs, TrainArgs
from utils.metrics import DivergenceLoss
from utils.dataset import rbc_data

from models import *

In [None]:
# Identifiers for this run.
# NOTE(review): the '16_4' suffix appears to mirror input_length=16 /
# output_length=4 set in the data cell — keep the two in sync.
model_name = 'Multigrid'
experiment_name = 'rbc_data_16_4'

In [None]:
# Optimisation hyperparameters consumed by Trainer.
training_args = TrainArgs(
    num_epochs=100,
    batch_size=16,
    learning_rate=2e-5,
)
# Presumably CheckPointArgs uses these names to organise saved checkpoints —
# confirm in utils.train_pipeline.
checkpoint_args = CheckPointArgs(model_name, experiment_name)

In [None]:
# Echo the hyperparameters, one per line: batch size, learning rate, epochs.
print(
    training_args.batch_size,
    training_args.learning_rate,
    training_args.num_epochs,
    sep='\n',
)

In [None]:
# Each dataset item uses `input_length` frames as input and `output_length` frames as target.
input_length = 16
output_length = 4

# Raw tensors loaded from disk.
# NOTE(review): sample_3.pt is not in this list (ids 0, 1, 2, 4) — confirm the gap is intentional.
sample_ids = (0, 1, 2, 4)
data_prep = [torch.load(f'data/sample_{sample_id}.pt') for sample_id in sample_ids]

# Disjoint index ranges: indices 0-2999 for training, 3000-3999 for validation.
# NOTE(review): how these indices map onto the loaded samples is defined by rbc_data — confirm.
train_indices = list(range(3000))
valid_indices = list(range(3000, 4000))
assert not set(train_indices) & set(valid_indices), "train/valid indices overlap"

# The trailing positional `False` is an rbc_data option whose meaning is not visible here — TODO name it.
train_ds = rbc_data(data_prep, train_indices, input_length, output_length, False)
# Validation must use the held-out `valid_indices`; reusing `train_indices`
# would report training-set loss as validation loss.
valid_ds = rbc_data(data_prep, valid_indices, input_length, output_length, False)

In [None]:
# Presumably (channels, height, width) of each frame — TODO confirm against MG's signature.
mg_input_shape = (2, 64, 64)
model = MG(mg_input_shape, input_length, output_length)

In [None]:
# count_parameters is presumably provided by `from models import *` — confirm.
n_params = count_parameters(model)
print(n_params)

In [None]:
# Logging settings passed to Trainer as `logging_config`.
# NOTE(review): the keys match logging.basicConfig keyword arguments; confirm Trainer
# forwards them there.
logging_configs = {
    'filename': 'multigrid_log.log',  # plain string; the old f-string had no placeholders
    'level': logging.INFO,
    # The brace-style format below needs 'style': '{' when given to logging.basicConfig:
    # 'format': "{asctime} {levelname:<8} {message}", 'style': '{',
}

In [None]:
# Per the PyTorch docs, this releases unoccupied memory held by the CUDA caching
# allocator. It does not free memory still referenced by live tensors. Called
# before training so earlier cells' cached GPU memory is returned.
torch.cuda.empty_cache()

In [None]:
# Assemble the training loop from the model, datasets and argument objects
# defined above, then run it. The first five arguments are positional, as Trainer expects.
trainer = Trainer(
    model,
    train_ds,
    valid_ds,
    checkpoint_args,
    training_args,
    logging_config=logging_configs,
)
trainer.train()