# Imports

In [None]:
from os import path, listdir
from copy import deepcopy
import stlearn as st
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.optim as optim
import matplotlib.pyplot as plt

%load_ext autoreload
%autoreload 2

import trainer_nmf as trainer
import data_nmf as get_data
from models import get_model
import tester_nmf as tester
from loss import *

In [None]:
# Global matplotlib font size for every figure in this notebook.
plt.rcParams.update({'font.size': 12})

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)

# Load Data 

In [None]:
# Data-loading configuration consumed by get_data.main in the next cell
# (and batch_size is reused in model_params below).
apply_log, batch_size = False, 128

In [None]:
# Build the train / validation / test loaders. get_data.main returns four
# values; the fourth is not used in this notebook and is discarded into `_`.
dl_train, dl_valid, dl_test, _ = get_data.main(
    apply_log=apply_log,
    batch_size=batch_size,
    device=device,
)

# Modelling

## Set Hyperparameters

In [None]:
# Identifier passed to get_model / trainer.train.
model_name = 'NMF'
# Upper bound on training epochs, and the early_stopping value handed to
# trainer.train (presumably a patience in epochs — confirm in trainer_nmf).
max_epochs, early_stopping = 300, 15
# Model / optimizer settings. 'optimizer' must name a class in torch.optim.
model_params = dict(
    learning_rate=0.1,
    optimizer="SGD",
    latent_dim=40,
    batch_size=batch_size,
)

## Build Model 

In [None]:
# Instantiate the model; get_model also receives the training loader.
model = get_model(model_name, model_params, dl_train)
# Look up the optimizer class by name in torch.optim (e.g. optim.SGD) and
# bind it to the model's parameters with the configured learning rate.
optimizer = getattr(optim, model_params['optimizer'])(
    model.parameters(),
    lr=model_params['learning_rate'],
)
# RMSELoss comes from the project's `loss` module (star-imported above).
criterion = RMSELoss()

## Train Model 

In [None]:
# Train the model. The validation loader is deliberately passed through the
# trainer's `dl_test` parameter so that early stopping watches validation
# loss; the real test loader is only used in the Test cell below.
model, valid_loss = trainer.train(
    model=model, optimizer=optimizer, criterion=criterion,
    max_epochs=max_epochs, early_stopping=early_stopping,
    dl_train=dl_train, dl_test=dl_valid,
    device=device, model_name=model_name,
)

## Test 

In [None]:
# Evaluate the trained model on the held-out test loader built in the
# "Load Data" cell (`dl_test` from get_data.main). tester.test returns the
# test loss and a DataFrame of predictions.
test_loss, df_test_preds = tester.test(
    model=model,
    criterion=criterion,
    dl_test=dl_test,
    device=device,
)
print(f'Test loss = {test_loss}')

# Results Analysis 