# Basic Training Scheme development notebook

## Importing section

In [1]:
import torch
from torch import optim, nn
from torch.utils.data import random_split, DataLoader
from etnn.data.ferris_wheel import load_ferris_wheel_dataset
from etnn.data.tree_structure import TreeNode
from etnn.nn.layer_framework import ChiralLayerManagementFramework
from tqdm import tqdm

from etnn.tools.training import train_epoch, eval_epoch

In [2]:
import numpy as np

## Parameter definition

In [3]:
NUM_GONDOLAS = 10
NUM_PART_PG = 5
DATASET_SIZE = 10_000
DATASET_PATH = "../datasets"

In [4]:
val_perc = 0.1
test_perc = 0.2
BATCH_SIZE = 32

In [5]:
INPUT_DIM = 15
HIDDEN_DIM = 32
OUT_DIM = 1
K=2
LEARNING_RATE = 0.001
NUM_MAX_EPOCHS = 10

## data preparation

In [6]:
dataset = load_ferris_wheel_dataset(
    num_gondolas=NUM_GONDOLAS,
    num_part_pg=NUM_PART_PG,
    num_to_generate=DATASET_SIZE,
    dataset_path=DATASET_PATH
)

In [7]:
len(dataset)

10000

In [8]:
generator = torch.Generator().manual_seed(420)
train_ds, val_ds, test_ds = random_split(
    dataset,
    [1-val_perc-test_perc, val_perc, test_perc],
    generator=generator
)

In [9]:
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

In [10]:
# todo: add more tree mutated sequences that are in dataset
# todo: add some faulty sequences which disturb this and see how much this affects the training

## Build tree structure

In [11]:
tree_structure = TreeNode(
    node_type="C",
    children=[
        TreeNode("P", [TreeNode("E", NUM_PART_PG)])
        for _ in range(NUM_GONDOLAS)
    ]
)

## Define device

In [12]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [13]:
device

'cuda'

## Define Model

In [14]:
model = ChiralLayerManagementFramework(
    in_dim=15,
    tree=tree_structure,
    hidden_dim=HIDDEN_DIM,
    out_dim=OUT_DIM,
    k=K
).to(device)

## Define Loss and Optimizers

In [15]:
criterion = nn.MSELoss()

In [16]:
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)

## Training

In [0]:
# init loss history
loss_history_train = []
loss_history_test = []

In [20]:
# train for N epochs
for epoch in tqdm(range(NUM_MAX_EPOCHS)):
    mean_train_loss = train_epoch(
        model,
        train_loader,
        optimizer,
        device,
        criterion
    )

    loss_history_train += [mean_train_loss]

    loss_history_test += [eval_epoch(
        model,
        test_loader,
        device,
        criterion
    )]

100%|██████████| 10/10 [01:05<00:00,  6.54s/it]


In [21]:
loss_history_train

[tensor(2414.2751),
 tensor(40.6598),
 tensor(43.2114),
 tensor(43.0136),
 tensor(45.4412),
 tensor(40.9359),
 tensor(49.3088),
 tensor(41.5498),
 tensor(41.9255),
 tensor(44.7525),
 tensor(41.5806),
 tensor(44.6666),
 tensor(40.8808),
 tensor(40.4938),
 tensor(39.5798),
 tensor(39.4360),
 tensor(40.7334),
 tensor(40.7510),
 tensor(40.0163),
 tensor(38.7285)]

In [22]:
loss_history_test

[tensor(36.8590),
 tensor(40.4110),
 tensor(53.3749),
 tensor(55.8331),
 tensor(46.7400),
 tensor(40.8851),
 tensor(36.2696),
 tensor(42.4216),
 tensor(40.0277),
 tensor(44.7888),
 tensor(49.9517),
 tensor(35.9098),
 tensor(60.9032),
 tensor(40.9720),
 tensor(35.8627),
 tensor(36.9581),
 tensor(37.4648),
 tensor(37.3593),
 tensor(37.8715),
 tensor(38.1591)]

## Test/Debug

for x,y in train_loader:
    break

x.shape

x

y.shape

y

for i in range(1000):
    optimizer.zero_grad()
    prediction = model(x).flatten()
    loss = criterion(prediction, y)
    print(f"Epoch:{i+1} - loss:{loss}")
    loss.backward()
    optimizer.step()

y

model(x).flatten()