# Imports

In [None]:
from os import path, listdir
from copy import deepcopy
import stlearn as st
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import matplotlib.pyplot as plt
import cv2
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder
from torch import tensor
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim

%load_ext autoreload
%autoreload 2

from trainer_ae import trainer_ae
from data import get_data
from models import get_model
from tester_ae import tester_ae
from loss import *

In [None]:
# Notebook-wide setup: larger default font for plots, and run on the GPU when
# torch reports one as available, otherwise fall back to the CPU.
plt.rcParams['font.size'] = 12
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)

# Load Data 

In [3]:
# Experiment selection: which dataset and model to load, and the mini-batch size.
dataset_name, model_name, batch_size = 'Visium_Mouse_Olfactory_Bulb', 'AE', 128

In [None]:
# Build the train / validation / test loaders for the selected dataset and model.
dl_train, dl_valid, dl_test = get_data(
    model_name=model_name,
    dataset_name=dataset_name,
    # Previously `best_params`, a name never defined in this notebook (NameError).
    # Use the batch size configured above, the same value stored in params['batch_size'].
    batch_size=batch_size,
    device=device,
)

# Load Model

In [None]:
# Hyper-parameters: passed whole to get_model, and the 'optimizer' /
# 'learning_rate' entries are read when the optimizer is constructed.
params = dict(
    learning_rate=0.1,
    optimizer="SGD",
    latent_dim=40,
    batch_size=batch_size,
)

In [None]:
# Instantiate the model. dl_train is handed to the factory as well, presumably
# so it can size the model from the data — confirm in models.get_model.
model = get_model(
    model_name=model_name,
    params=params,
    dl_train=dl_train,
)

# Train Model 

In [None]:
# Training budget: a hard cap on epochs, plus the early_stopping value passed to
# trainer_ae (presumably a patience in epochs — confirm in trainer_ae).
max_epochs, early_stopping = 300, 15

## Load Optimizer 

In [None]:
# Training loss, from the project's loss module.
criterion = NON_ZERO_RMSELoss()
# Look up the optimizer class by name on torch.optim (e.g. "SGD" -> optim.SGD)
# and bind it to the model's parameters with the configured learning rate.
optimizer = getattr(optim, params['optimizer'])(
    model.parameters(),
    lr=params['learning_rate'],
)

## Train

In [None]:
# Fit the model. Note that the *validation* loader is passed as `dl_test`, so
# `valid_loss` is presumably the loss trainer_ae reports on dl_valid, not on
# the held-out test split.
model, valid_loss = trainer_ae(
    # what to train and how
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    # data
    dl_train=dl_train,
    dl_test=dl_valid,
    # schedule
    max_epochs=max_epochs,
    early_stopping=early_stopping,
    # run context
    device=device,
    dataset_name=dataset_name,
    model_name=model_name,
)

In [None]:
# Final train/valid losses, typed in as constants (not taken from `valid_loss`
# above), and the same values mapped back with exp().
# NOTE(review): these go stale whenever the model is retrained — confirm whether
# they should be derived from trainer_ae's output instead.
train_res, valid_res = 4.55, 4.58
for _label, _res in (('Train', train_res), ('Valid', valid_res)):
    print(f'{_label} final results (after log transform) = {_res}')
    print(f'{_label} final results = {np.exp(_res)}')

## Test 

In [None]:
# Evaluate on the held-out test split.
# NOTE(review): this uses RMSELoss, while training used NON_ZERO_RMSELoss —
# confirm the two metrics are meant to differ.
loss_fn = RMSELoss()

test_loss = tester_ae(
    model=model,
    device=device,
    dl_test=dl_test,
    loss_fn=loss_fn,
)
print(f'Test loss (after log transform) = {test_loss}')
print(f'Test loss = {np.exp(test_loss)}')