# Basic Training Scheme development notebook

## Importing section

In [1]:
import torch
from torch import optim, nn
from torch.utils.data import random_split, DataLoader
from etnn.data.ferris_wheel import load_ferris_wheel_dataset
from etnn.data.tree_structure import TreeNode
from etnn.nn.layer_framework import ChiralLayerManagementFramework
from tqdm import tqdm

In [2]:
import numpy as np

## Parameter definition

In [3]:
NUM_GONDOLAS = 10
NUM_PART_PG = 5
DATASET_SIZE = 10_000
DATASET_PATH = "../datasets"

In [4]:
val_perc = 0.1
test_perc = 0.2

In [5]:
INPUT_DIM = 15
HIDDEN_DIM = 32
OUT_DIM = 1
K=2
LEARNING_RATE = 0.001

## data preparation

In [6]:
dataset = load_ferris_wheel_dataset(
    num_gondolas=NUM_GONDOLAS,
    num_part_pg=NUM_PART_PG,
    num_to_generate=DATASET_SIZE,
    dataset_path=DATASET_PATH
)

In [7]:
len(dataset)

10000

In [8]:
generator = torch.Generator().manual_seed(420)
train_ds, val_ds, test_ds = random_split(
    dataset,
    [1-val_perc-test_perc, val_perc, test_perc],
    generator=generator
)

In [9]:
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=32, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=32, shuffle=False)

In [10]:
# todo: add more tree mutated sequences that are in dataset
# todo: add some faulty sequences which disturb this and see how much this affects the training

## Build tree structure

In [11]:
tree_structure = TreeNode(
    node_type="C",
    children=[
        TreeNode("P", [TreeNode("E", NUM_PART_PG)])
        for _ in range(NUM_GONDOLAS)
    ]
)

## Define device

In [24]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [25]:
device

'cuda'

## Define Model

In [26]:
model = ChiralLayerManagementFramework(
    in_dim=15,
    tree=tree_structure,
    hidden_dim=HIDDEN_DIM,
    out_dim=OUT_DIM,
    k=K
).to(device)

## Define Loss and Optimizers

In [27]:
criterion = nn.MSELoss()

In [28]:
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)

## Training

In [41]:
# set model to training mode
model.train()

for batch_data, batch_label in tqdm(train_loader):
    # optimizer zero grad
    optimizer.zero_grad()

    # put data to device
    batch_data = batch_data.to(device)
    batch_label = batch_label.to(device)

    # put through model
    prediction = model(batch_data).flatten()

    # calculate loss
    loss = criterion(prediction, batch_label)

    # backward loss
    loss.backward()

    # optimizer step
    optimizer.step()


100%|██████████| 219/219 [00:04<00:00, 45.77it/s]


In [47]:
torch.mean(torch.stack([loss.detach().cpu()]))

tensor(56.8242)

In [39]:
with torch.no_grad():
    # set model to evaluation mode
    model.eval()

    for batch_data, batch_label in tqdm(test_loader):
        # put data to device
        batch_data = batch_data.to(device)
        batch_label = batch_label.to(device)

        # put through model
        prediction = model(batch_data).flatten()

        # calculate loss
        loss = criterion(prediction, batch_label)

100%|██████████| 63/63 [00:01<00:00, 60.74it/s]


In [40]:
loss

tensor(32.2877, device='cuda:0')

## Test/Debug

In [17]:
for x,y in train_loader:
    break

In [18]:
x.shape

torch.Size([32, 50, 15])

In [19]:
x

tensor([[[31.0000,  2.0000,  7.7000,  ...,  1.0000,  0.0000,  0.0000],
         [43.0000,  6.0000,  7.8000,  ...,  1.0000,  0.0000,  0.0000],
         [43.0000,  6.0000,  7.8000,  ...,  1.0000,  0.0000,  0.0000],
         ...,
         [37.0000,  7.0000,  7.2000,  ...,  0.0000,  1.0000,  0.0000],
         [49.0000,  5.0000,  6.1000,  ...,  0.0000,  1.0000,  0.0000],
         [30.0000,  2.0000,  7.7000,  ...,  1.0000,  0.0000,  0.0000]],

        [[28.0000,  2.0000,  6.2000,  ...,  1.0000,  0.0000,  0.0000],
         [30.0000,  2.0000,  7.7000,  ...,  1.0000,  0.0000,  0.0000],
         [31.0000,  2.0000,  7.7000,  ...,  1.0000,  0.0000,  0.0000],
         ...,
         [32.0000,  2.0000,  6.0000,  ...,  1.0000,  0.0000,  0.0000],
         [54.0000,  6.0000,  8.4000,  ...,  0.0000,  1.0000,  0.0000],
         [43.0000,  6.0000,  7.8000,  ...,  1.0000,  0.0000,  0.0000]],

        [[44.0000, 10.0000,  6.3000,  ...,  1.0000,  0.0000,  0.0000],
         [29.0000,  5.0000,  6.5000,  ...,  0

In [20]:
y.shape

torch.Size([32])

In [21]:
y

tensor([69.6241, 64.6696, 62.5482, 63.7993, 52.9346, 61.3342, 58.7269, 60.9088,
        61.8402, 55.8015, 55.7720, 57.9381, 57.3692, 60.3790, 67.6998, 57.6472,
        64.9572, 66.3945, 62.1344, 62.7825, 69.0934, 55.5725, 62.2126, 65.3915,
        74.2715, 57.2791, 64.5114, 59.2670, 61.2274, 61.1825, 62.1553, 60.1770])

In [22]:
for i in range(1000):
    optimizer.zero_grad()
    prediction = model(x).flatten()
    loss = criterion(prediction, y)
    print(f"Epoch:{i+1} - loss:{loss}")
    loss.backward()
    optimizer.step()

Epoch:1 - loss:428.2895812988281
Epoch:2 - loss:32366.56640625
Epoch:3 - loss:7892.162109375
Epoch:4 - loss:818.5023803710938
Epoch:5 - loss:3544.20068359375
Epoch:6 - loss:2663.4296875
Epoch:7 - loss:1619.7647705078125
Epoch:8 - loss:326.6105651855469
Epoch:9 - loss:86.373779296875
Epoch:10 - loss:255.9167938232422
Epoch:11 - loss:422.157470703125
Epoch:12 - loss:497.94384765625
Epoch:13 - loss:456.3053283691406
Epoch:14 - loss:328.8862609863281
Epoch:15 - loss:178.96530151367188
Epoch:16 - loss:67.61689758300781
Epoch:17 - loss:28.526538848876953
Epoch:18 - loss:58.78349304199219
Epoch:19 - loss:126.64144134521484
Epoch:20 - loss:190.12832641601562
Epoch:21 - loss:217.0679168701172
Epoch:22 - loss:197.32188415527344
Epoch:23 - loss:143.7203826904297
Epoch:24 - loss:82.50564575195312
Epoch:25 - loss:39.556392669677734
Epoch:26 - loss:28.712812423706055
Epoch:27 - loss:46.97333526611328
Epoch:28 - loss:78.14830780029297
Epoch:29 - loss:102.48876953125
Epoch:30 - loss:106.91947937011719

In [23]:
y

tensor([69.6241, 64.6696, 62.5482, 63.7993, 52.9346, 61.3342, 58.7269, 60.9088,
        61.8402, 55.8015, 55.7720, 57.9381, 57.3692, 60.3790, 67.6998, 57.6472,
        64.9572, 66.3945, 62.1344, 62.7825, 69.0934, 55.5725, 62.2126, 65.3915,
        74.2715, 57.2791, 64.5114, 59.2670, 61.2274, 61.1825, 62.1553, 60.1770])

In [27]:
model(x).flatten()

tensor([69.5522, 72.8696, 70.0393, 70.1519, 76.7386, 68.9264, 66.6385, 68.9720,
        70.9857, 63.4863, 69.1004, 70.3080, 70.0417, 69.3402, 67.6243, 69.4924,
        66.8773, 70.9033, 71.7158, 68.0593, 69.7302, 74.8552, 71.8740, 71.4821,
        69.9691, 68.4896, 71.6527, 67.3356, 69.1198, 67.9152, 73.7401, 71.0583],
       grad_fn=<ViewBackward0>)