In [None]:
import pandas as pd

import torch.optim as optim
import torch.nn as nn
import torch.utils.data as data
import torch

import time
import os
import copy
import random

import matplotlib.pyplot as plt
import numpy as np
import wandb

from src.models import MLP, train, evaluate, epoch_time, MyDataset

In [None]:
# Seed every RNG the notebook may touch so that Restart & Run All reproduces the same
# validation/test split, weight initialisation and DataLoader shuffling order.
# NOTE(review): cuDNN kernels on GPU can still be non-deterministic.
seed = 1
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed);  # trailing ';' hides the returned Generator repr in the cell output

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Data selection: how many recipe indices are passed to MyDataset, and the CSV
# file prefix used under data/cali_tmp_2/.
input_size = 3
select = "manual+3"

# Model / optimiser hyperparameters.
# NOTE(review): these look like the result of a tuning sweep — record where they came from.
hidden_size_1, hidden_size_2 = 12, 5
activation = "tanh"
lr, weight_decay = 0.006392255358324788, 0.00024303795827860364

In [None]:
# Load the train/test CSVs for the selected feature set. MyDataset receives the
# first `input_size` recipe indices.
recipes = [*range(input_size)]
dataset_train = MyDataset(pd.read_csv(f"data/cali_tmp_2/{select}_train.csv"), recipes)
dataset_test = MyDataset(pd.read_csv(f"data/cali_tmp_2/{select}_test.csv"), recipes)

# Halve the held-out CSV. `test_size` (rounded down) samples become the validation
# subset and the remainder the final test subset; the split draws from torch's global RNG.
test_size = int(0.5 * len(dataset_test))
split_lengths = [len(dataset_test) - test_size, test_size]
test_data, valid_data = data.random_split(dataset_test, split_lengths)

# Only the training loader shuffles; evaluation order is fixed.
BATCH_SIZE = 256
train_iterator = data.DataLoader(dataset_train, shuffle=True, batch_size=BATCH_SIZE)
valid_iterator, test_iterator = (
    data.DataLoader(subset, batch_size=BATCH_SIZE) for subset in (valid_data, test_data)
)

In [None]:
# Build the MLP and the L1 criterion. Then either restore the best checkpoint from a
# previous run, or train from scratch and checkpoint whenever validation loss improves.
# In both branches `best_model` ends up holding the lowest-validation-loss weights;
# after training, `model` itself holds the final epoch's weights.
model = MLP(3 * len(recipes), [hidden_size_1 * 5, hidden_size_2 * 5], 3, activation).to(device)
criterion = nn.L1Loss()
criterion = criterion.to(device)

# A single path is used for the existence check, the load and the save. Previously the
# check looked at "results/models/best/..." while load/save used "models/best/...", so a
# saved checkpoint was never reused (or a different file was loaded than the one checked).
checkpoint_path = f"models/best/{select}.pt"

if os.path.exists(checkpoint_path):
    # map_location lets a checkpoint saved during a GPU run be restored on a CPU-only host.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    best_model = copy.deepcopy(model)
else:
    optimizer = optim.Adam(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )

    EPOCHS = 50000

    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    best_valid_loss = float('inf')
    # Fallback so `best_model` is always defined, even if no epoch beats the initial
    # value (e.g. the validation loss is NaN throughout).
    best_model = copy.deepcopy(model)

    for epoch in range(EPOCHS):

        start_time = time.monotonic()

        train_loss = train(model, train_iterator, optimizer, criterion, device)
        valid_loss = evaluate(model, valid_iterator, criterion, device)

        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), checkpoint_path)
            best_model = copy.deepcopy(model)

        # wandb.log({"train_loss": train_loss, "valid_loss": valid_loss, "best_valid_loss": best_valid_loss})

        end_time = time.monotonic()

        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s | Train Loss: {train_loss:.5f} | Val. Loss: {valid_loss:.5f}')