## Imports

In [None]:
import torch
import wandb

import torch.nn as nn

from datetime import datetime

from torch import optim

In [None]:
from utils.model_utils import train, evaluate
from utils.corpus_utils import Corpus
from models.SAGE import CustomSAGE

## Data loading

In [None]:
# Locations of the serialized train / test corpora read by torch.load below.
train_data_path, test_data_path = (
    'data/train_oldtorch.pickle',
    'data/test_oldtorch.pickle',
)

In [None]:
# Deserialize both splits (train first, then test).
# NOTE(review): torch.load is pickle-based, so these paths should only ever
# point at files from a trusted source.
train_data, test_data = (
    torch.load(path) for path in (train_data_path, test_data_path)
)

In [None]:
# Pull the graph object out of each loaded split.
train_graph, test_graph = train_data.graph, test_data.graph

In [None]:
# Both vocabulary mappings are taken from the training split.
token2idx, idx2token = train_data.token2idx, train_data.idx2token

In [None]:
vocab_size = max(idx2token.keys())+1

## Config

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
device

In [None]:
# Optimisation settings: number of passes of the training loop and the
# learning rate handed to the optimizer.
epochs = 100
lr = 0.001

# Model settings forwarded to the CustomSAGE constructor.
hidden_dim = 64
num_conv_layers = 3

In [None]:
# Model hyperparameters, collected so they can be merged into the run config.
hparams = dict(
    hidden_dim=hidden_dim,
    num_conv_layers=num_conv_layers,
)

In [None]:
# Run identifier: local wall-clock time formatted as YYYY-MM-DD-HH-MM-SS.
timestamp = f"{datetime.now():%Y-%m-%d-%H-%M-%S}"
print(timestamp)

In [None]:
# Run configuration sent to wandb.init: dataset paths and training settings
# first, followed by every entry of hparams.
config = {
    "train_dataset_file_path": train_data_path,
    "val_dataset_file_path": test_data_path,
    "epochs": epochs,
    "lr": lr,
    'timestamp': timestamp,
    **hparams,
}

In [None]:
wandb.init(project="master_thesis", entity="kesha_humonen", config=config)

## Model

In [None]:
# Build the network sized by the vocabulary and the hyperparameters above.
model = CustomSAGE(
    vocab_size,
    hidden_dim=hidden_dim,
    num_conv_layers=num_conv_layers,
)
# if multigpu_available():
#     model = DataParallel(model)

# AdamW over all model parameters.
optimizer = optim.AdamW(model.parameters(), lr=lr)

# Place the model and the loss module on the selected device.
model = model.to(device)
criterion = nn.CrossEntropyLoss().to(device)

## Training loop

In [None]:
# Per-epoch loss histories; the plotting cells below read them by these names.
loss = []
loss_eval = []

# print(f'Start training model {str(model)}')
for epoch in range(epochs):
    # Both graphs are re-read from disk at the start of every epoch.
    # NOTE(review): presumably train()/evaluate() modify the graph in place;
    # if they do not, these loads can be hoisted out of the loop — confirm.
    train_graph = torch.load(train_data_path).graph
    # Validation must use the held-out split: this is the file recorded as
    # "val_dataset_file_path" in the run config. Loading train_data_path here
    # would make 'val_bce' (and any checkpoint selection inside evaluate) a
    # second measurement of the training data.
    test_graph = torch.load(test_data_path).graph

    loss.append(train(model, train_graph, idx2token, optimizer, criterion, epoch, device))
    loss_eval.append(evaluate(model, test_graph, idx2token, criterion, epoch, device, timestamp, save_checkpoints=True))

    # NOTE: the metric keys say 'bce' although the criterion is
    # CrossEntropyLoss; the keys are left unchanged so existing W&B charts
    # keep lining up with new runs.
    wandb.log({'train_bce': loss[-1], 'val_bce' : loss_eval[-1]})

In [None]:
wandb.finish()

## Learning curves (also see [wandb](https://wandb.ai/kesha_humonen/master_thesis))

In [None]:
import plotly.graph_objects as go

In [None]:
# Interactive learning curves: one trace per loss history, indexed by epoch.
fig = go.Figure()
fig.add_trace(go.Scatter(y=loss, name='train'))
fig.add_trace(go.Scatter(y=loss_eval, name='val'))
fig.update_layout(
    title='ChataboxModel',
    xaxis_title='epoch',
    # Label the y-axis with the loss class name. Slicing str(criterion) with
    # [:-3] removed "s()" from "CrossEntropyLoss()" and left "CrossEntropyLos".
    yaxis_title=type(criterion).__name__,
)
fig.show()

In [None]:
from matplotlib import pyplot as plt

In [None]:
# Static learning curves: training loss in red, evaluation loss in blue.
plt.plot(loss, color='red', label='train')
plt.plot(loss_eval, color='blue', label='eval')
plt.legend()
plt.xlabel('epoch')
# Label the y-axis with the loss class name. Slicing str(criterion) with
# [:-3] removed "s()" from "CrossEntropyLoss()" and left "CrossEntropyLos".
plt.ylabel(type(criterion).__name__)
plt.title('ChataboxModel')
plt.show()