In [1]:
!rm -fr r_trader out
!mkdir out input
!git clone https://github.com/abreham-atlaw/r_trader
!cd r_trader &&  git checkout deep-reinforcement.training-experiment-cnn.tmp
!pip install cattrs positional-encodings==6.0.1 dropbox pymongo==4.3.3 dependency-injector==4.41.0

mkdir: cannot create directory ‘input’: File exists
Cloning into 'r_trader'...
remote: Enumerating objects: 17124, done.[K
remote: Counting objects: 100% (98/98), done.[K
remote: Compressing objects: 100% (44/44), done.[K
remote: Total 17124 (delta 66), reused 69 (delta 52), pack-reused 17026 (from 3)[K
Receiving objects: 100% (17124/17124), 77.92 MiB | 14.39 MiB/s, done.
Resolving deltas: 100% (12282/12282), done.
Branch 'deep-reinforcement.training-experiment-cnn.tmp' set up to track remote branch 'deep-reinforcement.training-experiment-cnn.tmp' from 'origin'.
Switched to a new branch 'deep-reinforcement.training-experiment-cnn.tmp'


In [2]:
import os
import sys

# Runtime detection: /kaggle/working exists only on Kaggle. Otherwise the
# /content layout (presumably Colab) is assumed.
KAGGLE_ENV = os.path.exists("/kaggle/working")
REPO_PATH = "/kaggle/working/r_trader" if KAGGLE_ENV else "/content/r_trader"

print(f"KAGGLE ENV: {KAGGLE_ENV}")

# Make the cloned repository importable. The guard keeps re-runs of this cell
# from appending duplicate entries to sys.path.
if REPO_PATH not in sys.path:
    sys.path.append(REPO_PATH)

KAGGLE ENV: False


In [3]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Adam, SGD, Adagrad
import matplotlib.pyplot as plt

import os
import signal

from core.utils.research.data.load.dataset import BaseDataset
from core.utils.research.training.trainer import Trainer
from core.utils.research.model.model.cnn.model import CNN
from core.utils.research.model.model.linear.model import LinearModel
from lib.utils.torch_utils.model_handler import ModelHandler
from core.utils.research.training.callbacks.checkpoint_callback import CheckpointCallback, StoreCheckpointCallback
from core.utils.research.training.data.repositories.checkpoint_repository import CheckpointRepository
from lib.utils.file_storage import PCloudClient
from core.utils.research.training.data.state import TrainingState
from core import Config
from core.utils.research.training.callbacks.metric_callback import MetricCallback
from core.utils.research.training.data.repositories.metric_repository import MetricRepository, MongoDBMetricRepository
from core.utils.kaggle import FusedManager
from core.di import init_di, ApplicationContainer
from core.utils.research.training.data.metric import MetricsContainer
from core.utils.research.model.layers import Indicators
from core.di import ServiceProvider
from core.utils.kaggle.data_repository import KaggleDataRepository
from core.utils.research.losses import ProximalMaskedLoss

[0m PID:12955 [2025-01-28 08:24:12.756257]  XLA is not installed. Training using TPU will not be possible. [93m [0m


In [4]:
def download_data(root, datasets, zip_filename, kernel_mode=True, checksums=None):
    """Fetch the given Kaggle datasets into ``root`` via KaggleDataRepository.

    Args:
        root: Directory passed to the repository as ``output_path``.
        datasets: Kaggle identifiers of the form ``owner/slug``.
        zip_filename: Archive name passed to the repository unchanged.
        kernel_mode: Forwarded as ``kernel=`` to ``download_multiple``.
            Presumably selects kernel outputs vs. datasets; TODO confirm.
        checksums: Optional expected checksums, forwarded unchanged.
    """
    repository = KaggleDataRepository(
        output_path=root,
        zip_filename=zip_filename
    )
    repository.download_multiple(datasets, kernel=kernel_mode, checksums=checksums)
    # NOTE: a former `os.system(f"unzip -d root/")` call was removed. It named
    # no archive and used the literal directory "root/" rather than `root`, so
    # unzip only printed its usage text and extracted nothing. Any extraction is
    # presumably handled by KaggleDataRepository (it receives `zip_filename`).
    # Confirm before re-adding an explicit unzip step.

In [5]:
DATA_ROOT = "/kaggle/input" if KAGGLE_ENV else "/content/input"

# Datasets are only downloaded when not on Kaggle (there DATA_ROOT is
# /kaggle/input and no download happens). CHECKSUMS is passed alongside
# DATASETS; presumably one entry per dataset, in the same order (TODO confirm).
DATASETS = [
    "abrehamatlaw0/rtrader-datapreparer-simsim-cum-0-it-10-0-dataset",
]
CHECKSUMS = [
    '99f3a6b3b3c1482d91a7a8f146ff0787981aaafd51db24c6142e7f60b5a46483',
]
KERNEL_MODE = False
ZIP_FILENAME = "out.zip"

if not KAGGLE_ENV:
    download_data(
        DATA_ROOT,
        DATASETS,
        ZIP_FILENAME,
        kernel_mode=KERNEL_MODE,
        checksums=CHECKSUMS,
    )


# Every directory under DATA_ROOT is one dataset container with `out/train` and
# `out/test` sub-directories. Plain files are skipped, and the listing is sorted
# because os.listdir order is arbitrary and would otherwise vary between runs.
CONTAINERS = sorted(
    os.path.join(DATA_ROOT, container)
    for container in os.listdir(DATA_ROOT)
    if os.path.isdir(os.path.join(DATA_ROOT, container))
)
DATA_PATHES, TEST_DATA_PATHES = [
    [os.path.join(container, "out", type_) for container in CONTAINERS]
    for type_ in ["train", "test"]
]

# --- Run identity --------------------------------------------------------------
NOTEBOOK_ID = "abrehamalemu/rtrader-training-exp-0-cnn-150-cum-0-it-10-tot"
# Key used for checkpoints and metrics: the notebook id with "/" replaced.
MODEL_ID = NOTEBOOK_ID.replace("/", "-")

# --- Data loading --------------------------------------------------------------
NUM_FILES = None          # forwarded to BaseDataset as num_files
DATA_CACHE_SIZE = 2       # not referenced by the visible cells
DATALOADER_WORKERS = 4

# --- Optimisation --------------------------------------------------------------
LR = 1e-4
LOSS_P = 1                # forwarded as `p` to ProximalMaskedLoss
BATCH_SIZE = 256
EPOCHS = 100
TIMEOUT = 10 * 60 * 60    # seconds (10 h); armed via signal.alarm

DTYPE = torch.float32
NP_DTYPE = np.float32

# --- Persistence ---------------------------------------------------------------
MODEL_URL = None
SAVE_PATH = os.path.abspath("./out/model.zip")
# NOTE(review): this is the same file as SAVE_PATH. Confirm whether training
# state was meant to go to a separate file.
STATE_SAVE_PATH = os.path.abspath("./out/model.zip")

METRIC_REPOSITORY = MongoDBMetricRepository(Config.MONGODB_URL, MODEL_ID)

CALLBACKS = [
    # Checkpoints go to the directory containing SAVE_PATH. The unit of
    # `interval` is defined by StoreCheckpointCallback (presumably epochs).
    StoreCheckpointCallback(
        path=os.path.dirname(SAVE_PATH),
        interval=5,
        fs=ServiceProvider.provide_file_storage(Config.MODEL_IN_PATH),
        active=True,
    ),
    MetricCallback(METRIC_REPOSITORY),
]

[94m PID:12955 [2025-01-28 08:24:12.836403]  Downloading abrehamatlaw0/rtrader-datapreparer-simsim-cum-0-it-10-0-dataset [0m
[94m PID:12955 [2025-01-28 08:24:12.838317]  Downloading to /content/input/abrehamatlaw0-rtrader-datapreparer-simsim-cum-0-it-10-0-dataset [0m
[94m PID:12955 [2025-01-28 08:24:12.839994]  Checking pre-downloaded for /content/input/abrehamatlaw0-rtrader-datapreparer-simsim-cum-0-it-10-0-dataset [0m
[94m PID:12955 [2025-01-28 08:24:12.841653]  Generating checksum for '/content/input/abrehamatlaw0-rtrader-datapreparer-simsim-cum-0-it-10-0-dataset' [0m


In [6]:
# Checkpoint store on the default file storage. It is used below both to fetch
# the starting model (repository.get) and to upload the result (repository.update).
repository = CheckpointRepository(ServiceProvider.provide_file_storage())

In [7]:
state_model = repository.get(MODEL_ID)
if state_model is None:
    raise ValueError(f"Can't Find Model: {MODEL_ID}")

print("[+]Using loaded model...")
# Only the model weights are reused. The stored training state is immediately
# replaced by a fresh one, so epoch/batch counters restart at 0.
# NOTE(review): confirm that discarding the stored state is intentional.
state, model = state_model
state = TrainingState(
    epoch=0,
    batch=0,
    id=MODEL_ID,
)

[94m PID:12955 [2025-01-28 08:24:15.959003]  Using Storage 0 for abrehamalemu-rtrader-training-exp-0-cnn-150-cum-0-it-10-tot.zip [0m
[+]Using loaded model...


  model.load_state_dict_lazy(torch.load(os.path.join(dirname, 'model_state.pth'), map_location=torch.device('cpu')))


In [8]:
dataset = BaseDataset(root_dirs=DATA_PATHES, out_dtypes=NP_DTYPE, num_files=NUM_FILES)
# No `shuffle=` here. BaseDataset logs "[+]Shuffling dataset..." each epoch,
# so it presumably shuffles itself.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    num_workers=DATALOADER_WORKERS,
    pin_memory=True,
)



In [9]:
# Held-out split, batched exactly like the training loader. It is passed to the
# trainer as val_dataloader.
test_dataset = BaseDataset(root_dirs=TEST_DATA_PATHES, out_dtypes=NP_DTYPE, num_files=NUM_FILES)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    num_workers=DATALOADER_WORKERS,
    pin_memory=True,
)

In [10]:
# Wrap the loaded model. CALLBACKS (StoreCheckpointCallback + MetricCallback)
# were configured in the config cell.
trainer = Trainer(model, callbacks=CALLBACKS)

[94m PID:12955 [2025-01-28 08:24:27.020256]  Using device: 1 [0m


In [11]:
# Classification loss over n = len(bounds) + 1 outputs (presumably one class
# per interval delimited by the bounds; TODO confirm).
trainer.cls_loss_function = ProximalMaskedLoss(
    n=len(Config.AGENT_STATE_CHANGE_DELTA_STATIC_BOUND) + 1,
    p=LOSS_P,
    softmax=True,
    device=trainer.device,
)
# The regression output uses plain mean-squared error.
trainer.reg_loss_function = nn.MSELoss()
trainer.optimizer = Adam(trainer.model.parameters(), lr=LR)

In [12]:
class TimeoutException(Exception):
    """Raised from the SIGALRM handler once the TIMEOUT budget is used up."""


def handle_timeout(*args, **kwargs):
    """SIGALRM handler: interrupt whatever is running by raising TimeoutException."""
    raise TimeoutException()


# Arm a one-shot wall-clock alarm (POSIX-only API). The training cell catches
# the resulting TimeoutException.
signal.signal(signal.SIGALRM, handle_timeout)
signal.alarm(TIMEOUT)

0

In [None]:
# Train until EPOCHS finish or the SIGALRM timeout fires. The alarm is cancelled
# in both cases. Otherwise, if training ends early, the still-pending alarm would
# raise TimeoutException later and abort the save/upload cells below.
try:
    trainer.train(
        dataloader,
        val_dataloader=test_dataloader,
        epochs=EPOCHS,
        progress=True,
        progress_interval=1000,
        state=state,
        cls_loss_only=False,
    )
except TimeoutException:
    print("[!]Training stopped: TIMEOUT reached.")
finally:
    signal.alarm(0)

Model Summary
Layer Name							Number of Parameters
layers.0.weight			3840
layers.0.bias			128
layers.1.weight			49152
layers.1.bias			128
layers.2.weight			49152
layers.2.bias			128
layers.3.weight			49152
layers.3.bias			128
ff_block.layers.0.weight			29162496
ff_block.layers.0.bias			256
ff_block.layers.1.weight			65536
ff_block.layers.1.bias			256
ff_block.layers.2.weight			65536
ff_block.layers.2.bias			256
ff_block.layers.3.weight			65536
ff_block.layers.3.bias			256
ff_block.layers.4.weight			110592
ff_block.layers.4.bias			432
Total Params:29622960
[+]Shuffling dataset...


Epoch 1 loss: 18.54132843017578(cls: 18.540721893310547, reg: 0.0006072913529351354): 100%|██████████| 127/127 [00:16<00:00,  7.53it/s]

Epoch 1 completed, loss: 17.760488510131836(cls: 17.758596420288086, reg: 0.0018954275874421)





Validation loss: loss: 17.657474517822266(cls: 17.65683364868164, reg: 0.0006374060758389533)
[+]Shuffling dataset...


Epoch 2 loss: 16.928913116455078(cls: 16.92829132080078, reg: 0.0006220544455572963): 100%|██████████| 127/127 [00:16<00:00,  7.51it/s]

Epoch 2 completed, loss: 17.681825637817383(cls: 17.68148422241211, reg: 0.0003398231347091496)





Validation loss: loss: 17.637712478637695(cls: 17.637638092041016, reg: 7.956940680742264e-05)
[+]Shuffling dataset...


Epoch 3 loss: 18.093334197998047(cls: 18.093252182006836, reg: 8.232587424572557e-05): 100%|██████████| 127/127 [00:17<00:00,  7.24it/s]

Epoch 3 completed, loss: 17.65068817138672(cls: 17.650449752807617, reg: 0.0002407092833891511)





Validation loss: loss: 17.69122886657715(cls: 17.691171646118164, reg: 5.563030208577402e-05)
[+]Shuffling dataset...


Epoch 4 loss: 18.36224937438965(cls: 18.362201690673828, reg: 4.848768367082812e-05): 100%|██████████| 127/127 [00:17<00:00,  7.33it/s]

Epoch 4 completed, loss: 17.602157592773438(cls: 17.601533889770508, reg: 0.0006246307748369873)





Validation loss: loss: 17.622262954711914(cls: 17.62162971496582, reg: 0.0006268627475947142)
[+]Shuffling dataset...


Epoch 5 loss: 17.63783836364746(cls: 17.63730239868164, reg: 0.0005358900525607169): 100%|██████████| 127/127 [00:17<00:00,  7.30it/s]

Epoch 5 completed, loss: 17.512109756469727(cls: 17.51043701171875, reg: 0.0016767471097409725)





Validation loss: loss: 17.63184356689453(cls: 17.628950119018555, reg: 0.00290865171700716)
[+]Shuffling dataset...


Epoch 6 loss: 16.790069580078125(cls: 16.786815643310547, reg: 0.003253997303545475): 100%|██████████| 127/127 [00:17<00:00,  7.33it/s]

Epoch 6 completed, loss: 17.353126525878906(cls: 17.349441528320312, reg: 0.0036896909587085247)





Validation loss: loss: 17.66045570373535(cls: 17.656946182250977, reg: 0.0035013961605727673)
[+]Shuffling dataset...


Epoch 7 loss: 16.50062370300293(cls: 16.496591567993164, reg: 0.0040316046215593815): 100%|██████████| 127/127 [00:17<00:00,  7.35it/s]

Epoch 7 completed, loss: 17.036039352416992(cls: 17.029069900512695, reg: 0.006969777401536703)





Validation loss: loss: 17.80746841430664(cls: 17.799652099609375, reg: 0.007817111909389496)
[+]Shuffling dataset...


Epoch 8 loss: 16.523761749267578(cls: 16.51764678955078, reg: 0.006114624440670013): 100%|██████████| 127/127 [00:17<00:00,  7.37it/s]

Epoch 8 completed, loss: 16.416847229003906(cls: 16.40235137939453, reg: 0.014499308541417122)





Validation loss: loss: 17.930606842041016(cls: 17.917112350463867, reg: 0.013499931432306767)
[+]Shuffling dataset...


Epoch 9 loss: 15.30254077911377(cls: 15.2899808883667, reg: 0.012559713795781136): 100%|██████████| 127/127 [00:17<00:00,  7.32it/s]

Epoch 9 completed, loss: 15.265317916870117(cls: 15.240419387817383, reg: 0.024899922311306)





Validation loss: loss: 18.453824996948242(cls: 18.43404197692871, reg: 0.01978827826678753)
[+]Shuffling dataset...


Epoch 10 loss: 12.90202808380127(cls: 12.880836486816406, reg: 0.02119145542383194): 100%|██████████| 127/127 [00:17<00:00,  7.27it/s]

Epoch 10 completed, loss: 13.580127716064453(cls: 13.547625541687012, reg: 0.0325060673058033)





Validation loss: loss: 19.115646362304688(cls: 19.091886520385742, reg: 0.023778563365340233)
[+]Shuffling dataset...


Epoch 11 loss: 11.594535827636719(cls: 11.570537567138672, reg: 0.023998580873012543): 100%|██████████| 127/127 [00:17<00:00,  7.35it/s]

Epoch 11 completed, loss: 11.644405364990234(cls: 11.6089506149292, reg: 0.03545261546969414)





Validation loss: loss: 20.4276180267334(cls: 20.399093627929688, reg: 0.0285198874771595)
[+]Shuffling dataset...


Epoch 12 loss: 9.13451862335205(cls: 9.106796264648438, reg: 0.027722032740712166): 100%|██████████| 127/127 [00:17<00:00,  7.38it/s]

Epoch 12 completed, loss: 9.675583839416504(cls: 9.638627052307129, reg: 0.03695785999298096)





Validation loss: loss: 22.014583587646484(cls: 21.984922409057617, reg: 0.02964901737868786)
[+]Shuffling dataset...


Epoch 13 loss: 8.284416198730469(cls: 8.257043838500977, reg: 0.02737237699329853): 100%|██████████| 127/127 [00:17<00:00,  7.37it/s]

Epoch 13 completed, loss: 8.0485200881958(cls: 8.012267112731934, reg: 0.03625164553523064)





Validation loss: loss: 24.14266586303711(cls: 24.11056137084961, reg: 0.03211774677038193)
[+]Shuffling dataset...


Epoch 14 loss: 6.027218818664551(cls: 5.990544319152832, reg: 0.036674417555332184): 100%|██████████| 127/127 [00:17<00:00,  7.31it/s]

Epoch 14 completed, loss: 6.781026840209961(cls: 6.7443647384643555, reg: 0.036662906408309937)





Validation loss: loss: 26.183998107910156(cls: 26.15419578552246, reg: 0.02979424223303795)
[+]Shuffling dataset...


Epoch 15 loss: 6.119129180908203(cls: 6.090217590332031, reg: 0.02891162596642971): 100%|██████████| 127/127 [00:17<00:00,  7.30it/s]

Epoch 15 completed, loss: 5.748045921325684(cls: 5.713979244232178, reg: 0.03406313806772232)





Validation loss: loss: 28.396793365478516(cls: 28.367595672607422, reg: 0.029214875772595406)
[+]Shuffling dataset...


Epoch 16 loss: 5.696603775024414(cls: 5.668907165527344, reg: 0.0276965145021677): 100%|██████████| 127/127 [00:17<00:00,  7.40it/s]

Epoch 16 completed, loss: 4.994128704071045(cls: 4.962336540222168, reg: 0.031791698187589645)





Validation loss: loss: 30.0073184967041(cls: 29.97765350341797, reg: 0.02965725213289261)
[+]Shuffling dataset...


Epoch 17 loss: 4.1571855545043945(cls: 4.129333972930908, reg: 0.027851693332195282): 100%|██████████| 127/127 [00:17<00:00,  7.38it/s]

Epoch 17 completed, loss: 4.485151767730713(cls: 4.455785751342773, reg: 0.029364289715886116)





Validation loss: loss: 30.490089416503906(cls: 30.463232040405273, reg: 0.026870639994740486)
[+]Shuffling dataset...


Epoch 18 loss: 3.524207592010498(cls: 3.495201587677002, reg: 0.029006078839302063): 100%|██████████| 127/127 [00:17<00:00,  7.37it/s]

Epoch 18 completed, loss: 4.085842609405518(cls: 4.058261394500732, reg: 0.02758064679801464)





Validation loss: loss: 32.1374397277832(cls: 32.10975646972656, reg: 0.02768135443329811)
[+]Shuffling dataset...


Epoch 19 loss: 3.233517646789551(cls: 3.206157922744751, reg: 0.027359606698155403): 100%|██████████| 127/127 [00:17<00:00,  7.28it/s]

Epoch 19 completed, loss: 3.750925064086914(cls: 3.724879264831543, reg: 0.026047175750136375)





Validation loss: loss: 32.87677001953125(cls: 32.8546257019043, reg: 0.022131826728582382)
[+]Shuffling dataset...


Epoch 20 loss: 2.963660955429077(cls: 2.944131374359131, reg: 0.019529469311237335): 100%|██████████| 127/127 [00:17<00:00,  7.32it/s]

Epoch 20 completed, loss: 3.4199888706207275(cls: 3.3957250118255615, reg: 0.02426282688975334)





Validation loss: loss: 33.69263458251953(cls: 33.6629524230957, reg: 0.029673704877495766)
[+]Shuffling dataset...


Epoch 21 loss: 4.205653190612793(cls: 4.170458793640137, reg: 0.0351945236325264): 100%|██████████| 127/127 [00:17<00:00,  7.29it/s]

Epoch 21 completed, loss: 3.192621946334839(cls: 3.168642520904541, reg: 0.023979797959327698)





Validation loss: loss: 34.18791198730469(cls: 34.16633224487305, reg: 0.021580830216407776)
[+]Shuffling dataset...


Epoch 22 loss: 2.9422924518585205(cls: 2.922329902648926, reg: 0.01996248960494995): 100%|██████████| 127/127 [00:17<00:00,  7.32it/s]

Epoch 22 completed, loss: 3.1406054496765137(cls: 3.1183598041534424, reg: 0.022245459258556366)





Validation loss: loss: 34.59537124633789(cls: 34.575416564941406, reg: 0.019936975091695786)
[+]Shuffling dataset...


Epoch 23 loss: 2.0504517555236816(cls: 2.0358974933624268, reg: 0.014554229564964771): 100%|██████████| 127/127 [00:17<00:00,  7.34it/s]

Epoch 23 completed, loss: 2.9997236728668213(cls: 2.9785454273223877, reg: 0.02117762714624405)





Validation loss: loss: 35.38886260986328(cls: 35.367576599121094, reg: 0.021283796057105064)
[+]Shuffling dataset...


Epoch 24 loss: 3.927760601043701(cls: 3.9067225456237793, reg: 0.021037956699728966): 100%|██████████| 127/127 [00:17<00:00,  7.36it/s]

Epoch 24 completed, loss: 2.8985133171081543(cls: 2.8788352012634277, reg: 0.019680507481098175)





Validation loss: loss: 35.88325500488281(cls: 35.86418914794922, reg: 0.01904742792248726)
[+]Shuffling dataset...


Epoch 25 loss: 2.4641973972320557(cls: 2.4478960037231445, reg: 0.016301343217492104): 100%|██████████| 127/127 [00:17<00:00,  7.29it/s]

Epoch 25 completed, loss: 2.8324966430664062(cls: 2.8137314319610596, reg: 0.01876521296799183)





Validation loss: loss: 35.645591735839844(cls: 35.625972747802734, reg: 0.019620543345808983)
[+]Shuffling dataset...


Epoch 26 loss: 2.049588918685913(cls: 2.0281662940979004, reg: 0.021422643214464188): 100%|██████████| 127/127 [00:17<00:00,  7.30it/s]

Epoch 26 completed, loss: 2.75925874710083(cls: 2.741060972213745, reg: 0.01819729246199131)





Validation loss: loss: 36.268035888671875(cls: 36.251060485839844, reg: 0.01701582781970501)
[+]Shuffling dataset...


Epoch 27 loss: 2.613513708114624(cls: 2.59639835357666, reg: 0.01711530052125454): 100%|██████████| 127/127 [00:17<00:00,  7.31it/s]

Epoch 27 completed, loss: 2.6758604049682617(cls: 2.6584603786468506, reg: 0.01740012690424919)





Validation loss: loss: 37.49384689331055(cls: 37.47705841064453, reg: 0.016767023131251335)
[+]Shuffling dataset...


Epoch 28 loss: 2.505120277404785(cls: 2.4917612075805664, reg: 0.01335899531841278): 100%|██████████| 127/127 [00:17<00:00,  7.35it/s]

Epoch 28 completed, loss: 2.610546350479126(cls: 2.5939533710479736, reg: 0.0165926031768322)





Validation loss: loss: 37.43688201904297(cls: 37.41938781738281, reg: 0.017517680302262306)
[+]Shuffling dataset...


Epoch 29 loss: 2.2090182304382324(cls: 2.192439317703247, reg: 0.0165789183229208): 100%|██████████| 127/127 [00:17<00:00,  7.37it/s]

Epoch 29 completed, loss: 2.6243624687194824(cls: 2.608564615249634, reg: 0.01579876057803631)





Validation loss: loss: 37.63566970825195(cls: 37.61748504638672, reg: 0.018193433061242104)
[+]Shuffling dataset...


Epoch 30 loss: 2.8528339862823486(cls: 2.833500862121582, reg: 0.01933312974870205): 100%|██████████| 127/127 [00:17<00:00,  7.37it/s]

Epoch 30 completed, loss: 2.6839020252227783(cls: 2.6688902378082275, reg: 0.015010804869234562)





Validation loss: loss: 36.95987319946289(cls: 36.94169616699219, reg: 0.01822681352496147)
[+]Shuffling dataset...


Epoch 31 loss: 2.209418773651123(cls: 2.1939220428466797, reg: 0.015496738255023956): 100%|██████████| 127/127 [00:17<00:00,  7.30it/s]

Epoch 31 completed, loss: 2.6963136196136475(cls: 2.680818796157837, reg: 0.015495088882744312)





Validation loss: loss: 37.345035552978516(cls: 37.33201599121094, reg: 0.013056879863142967)
[+]Shuffling dataset...


Epoch 32 loss: 3.582444429397583(cls: 3.5678653717041016, reg: 0.014579162001609802): 100%|██████████| 127/127 [00:17<00:00,  7.31it/s]

Epoch 32 completed, loss: 2.727651596069336(cls: 2.7134149074554443, reg: 0.014236810617148876)





Validation loss: loss: 37.505828857421875(cls: 37.493831634521484, reg: 0.012001809664070606)
[+]Shuffling dataset...


Epoch 33 loss: 1.7666305303573608(cls: 1.7557744979858398, reg: 0.01085606124252081): 100%|██████████| 127/127 [00:17<00:00,  7.34it/s]

Epoch 33 completed, loss: 2.639477014541626(cls: 2.625868797302246, reg: 0.013607807457447052)





Validation loss: loss: 37.25431442260742(cls: 37.242862701416016, reg: 0.011471662670373917)
[+]Shuffling dataset...


Epoch 34 loss: 2.556572437286377(cls: 2.54709529876709, reg: 0.009477231651544571): 100%|██████████| 127/127 [00:17<00:00,  7.27it/s]

Epoch 34 completed, loss: 2.565337896347046(cls: 2.5522522926330566, reg: 0.01308569498360157)





Validation loss: loss: 37.96877670288086(cls: 37.95556640625, reg: 0.013207362033426762)
[+]Shuffling dataset...


Epoch 35 loss: 2.087059259414673(cls: 2.0767745971679688, reg: 0.0102846072986722): 100%|██████████| 127/127 [00:17<00:00,  7.38it/s]

Epoch 35 completed, loss: 2.4595725536346436(cls: 2.446748733520508, reg: 0.012823420576751232)





Validation loss: loss: 37.89279556274414(cls: 37.87610626220703, reg: 0.01668662391602993)
[+]Shuffling dataset...


Epoch 36 loss: 2.3493738174438477(cls: 2.335143566131592, reg: 0.014230232685804367): 100%|██████████| 127/127 [00:17<00:00,  7.36it/s]

Epoch 36 completed, loss: 2.4760310649871826(cls: 2.463534116744995, reg: 0.01249657105654478)





Validation loss: loss: 38.274574279785156(cls: 38.263397216796875, reg: 0.011162465438246727)
[+]Shuffling dataset...


Epoch 37 loss: 1.920501470565796(cls: 1.9111762046813965, reg: 0.00932527706027031): 100%|██████████| 127/127 [00:17<00:00,  7.25it/s]

Epoch 37 completed, loss: 2.419434070587158(cls: 2.407764434814453, reg: 0.011669846251606941)





Validation loss: loss: 38.952823638916016(cls: 38.94279479980469, reg: 0.010069410316646099)
[+]Shuffling dataset...


Epoch 38 loss: 1.941607117652893(cls: 1.928889274597168, reg: 0.012717824429273605): 100%|██████████| 127/127 [00:17<00:00,  7.32it/s]

Epoch 38 completed, loss: 2.357501268386841(cls: 2.3462610244750977, reg: 0.011239771731197834)





Validation loss: loss: 39.04338073730469(cls: 39.03058624267578, reg: 0.012789801694452763)
[+]Shuffling dataset...


Epoch 39 loss: 3.1953775882720947(cls: 3.181354522705078, reg: 0.014023014344274998): 100%|██████████| 127/127 [00:17<00:00,  7.31it/s]

Epoch 39 completed, loss: 2.343768358230591(cls: 2.332305669784546, reg: 0.011462687514722347)





Validation loss: loss: 39.38865280151367(cls: 39.37605285644531, reg: 0.012619995512068272)
[+]Shuffling dataset...


Epoch 40 loss: 2.17317271232605(cls: 2.163325309753418, reg: 0.00984750036150217): 100%|██████████| 127/127 [00:17<00:00,  7.32it/s]

Epoch 40 completed, loss: 2.36342453956604(cls: 2.3518309593200684, reg: 0.011593344621360302)





Validation loss: loss: 39.50221633911133(cls: 39.490543365478516, reg: 0.01169554702937603)
[+]Shuffling dataset...


Epoch 41 loss: 2.832517147064209(cls: 2.8233234882354736, reg: 0.009193738922476768): 100%|██████████| 127/127 [00:17<00:00,  7.37it/s]

Epoch 41 completed, loss: 2.414231538772583(cls: 2.4022815227508545, reg: 0.011950274929404259)





Validation loss: loss: 39.834228515625(cls: 39.81852722167969, reg: 0.01570173352956772)
[+]Shuffling dataset...


Epoch 42 loss: 2.937509059906006(cls: 2.9212841987609863, reg: 0.016224971041083336): 100%|██████████| 127/127 [00:17<00:00,  7.36it/s]

Epoch 42 completed, loss: 2.460190773010254(cls: 2.449096202850342, reg: 0.011093887500464916)





Validation loss: loss: 39.390113830566406(cls: 39.37956237792969, reg: 0.010577262379229069)
[+]Shuffling dataset...


Epoch 43 loss: 2.9805567264556885(cls: 2.9715476036071777, reg: 0.00900903157889843): 100%|██████████| 127/127 [00:17<00:00,  7.31it/s]

Epoch 43 completed, loss: 2.5136196613311768(cls: 2.5030133724212646, reg: 0.01060653105378151)





Validation loss: loss: 39.465946197509766(cls: 39.453670501708984, reg: 0.012268373742699623)
[+]Shuffling dataset...


Epoch 44 loss: 2.3209750652313232(cls: 2.310345411300659, reg: 0.010629748925566673): 100%|██████████| 127/127 [00:17<00:00,  7.29it/s]

Epoch 44 completed, loss: 2.5201849937438965(cls: 2.5096118450164795, reg: 0.01057238969951868)





Validation loss: loss: 39.16060256958008(cls: 39.15083312988281, reg: 0.009749664925038815)
[+]Shuffling dataset...


Epoch 45 loss: 2.4440879821777344(cls: 2.4352033138275146, reg: 0.008884714916348457): 100%|██████████| 127/127 [00:17<00:00,  7.33it/s]

Epoch 45 completed, loss: 2.5449869632720947(cls: 2.535102605819702, reg: 0.009885010309517384)





Validation loss: loss: 39.88725662231445(cls: 39.875240325927734, reg: 0.012006179429590702)
[+]Shuffling dataset...


Epoch 46 loss: 1.708492398262024(cls: 1.6997261047363281, reg: 0.008766324259340763): 100%|██████████| 127/127 [00:17<00:00,  7.33it/s]

Epoch 46 completed, loss: 2.5514605045318604(cls: 2.5414137840270996, reg: 0.010047419928014278)





Validation loss: loss: 39.11075210571289(cls: 39.10250473022461, reg: 0.00826306827366352)
[+]Shuffling dataset...


Epoch 47 loss: 2.539750814437866(cls: 2.534681558609009, reg: 0.005069190636277199): 100%|██████████| 127/127 [00:17<00:00,  7.36it/s]

Epoch 47 completed, loss: 2.5836470127105713(cls: 2.5739307403564453, reg: 0.009716848842799664)





Validation loss: loss: 39.783329010009766(cls: 39.77321243286133, reg: 0.010120309889316559)
[+]Shuffling dataset...


Epoch 48 loss: 2.7281646728515625(cls: 2.7215003967285156, reg: 0.00666438415646553): 100%|██████████| 127/127 [00:17<00:00,  7.34it/s]

Epoch 48 completed, loss: 2.523624897003174(cls: 2.5145037174224854, reg: 0.009122327901422977)





Validation loss: loss: 39.718990325927734(cls: 39.709964752197266, reg: 0.009036384522914886)
[+]Shuffling dataset...


Epoch 49 loss: 2.407927989959717(cls: 2.3986096382141113, reg: 0.009318238124251366): 100%|██████████| 127/127 [00:17<00:00,  7.27it/s]

Epoch 49 completed, loss: 2.5291635990142822(cls: 2.5194294452667236, reg: 0.009734437800943851)





Validation loss: loss: 40.28422546386719(cls: 40.271785736083984, reg: 0.0124644311144948)
[+]Shuffling dataset...


Epoch 50 loss: 2.0352940559387207(cls: 2.027245283126831, reg: 0.008048726245760918): 100%|██████████| 127/127 [00:17<00:00,  7.29it/s]

Epoch 50 completed, loss: 2.6400504112243652(cls: 2.6308670043945312, reg: 0.009182850830256939)





Validation loss: loss: 39.186397552490234(cls: 39.17841720581055, reg: 0.007992574945092201)
[+]Shuffling dataset...


Epoch 51 loss: 2.2294578552246094(cls: 2.2217066287994385, reg: 0.007751237601041794): 100%|██████████| 127/127 [00:17<00:00,  7.31it/s]

Epoch 51 completed, loss: 2.6493594646453857(cls: 2.6409475803375244, reg: 0.008413257077336311)





Validation loss: loss: 39.16236877441406(cls: 39.153926849365234, reg: 0.008435075171291828)
[+]Shuffling dataset...


Epoch 52 loss: 2.263833522796631(cls: 2.25567626953125, reg: 0.008157188072800636): 100%|██████████| 127/127 [00:17<00:00,  7.35it/s]

Epoch 52 completed, loss: 2.629625082015991(cls: 2.6211986541748047, reg: 0.008426612243056297)





Validation loss: loss: 38.874961853027344(cls: 38.865257263183594, reg: 0.009668393060564995)
[+]Shuffling dataset...


Epoch 53 loss: 2.593611478805542(cls: 2.5814478397369385, reg: 0.012163609266281128): 100%|██████████| 127/127 [00:17<00:00,  7.32it/s]

Epoch 53 completed, loss: 2.561040163040161(cls: 2.5528671741485596, reg: 0.00817190669476986)





Validation loss: loss: 39.185646057128906(cls: 39.17757797241211, reg: 0.008089453913271427)
[+]Shuffling dataset...


Epoch 54 loss: 2.9537343978881836(cls: 2.939229965209961, reg: 0.014504523016512394): 100%|██████████| 127/127 [00:17<00:00,  7.31it/s]

Epoch 54 completed, loss: 2.55318546295166(cls: 2.5440127849578857, reg: 0.009172285906970501)





Validation loss: loss: 39.818817138671875(cls: 39.80769348144531, reg: 0.011117654852569103)
[+]Shuffling dataset...


Epoch 55 loss: 2.676915407180786(cls: 2.666687488555908, reg: 0.01022790651768446): 100%|██████████| 127/127 [00:17<00:00,  7.32it/s]

Epoch 55 completed, loss: 2.457765579223633(cls: 2.4494502544403076, reg: 0.008316298015415668)





Validation loss: loss: 39.63652420043945(cls: 39.62808609008789, reg: 0.008440621197223663)
[+]Shuffling dataset...


Epoch 56 loss: 2.705078601837158(cls: 2.693037986755371, reg: 0.012040693312883377): 100%|██████████| 127/127 [00:17<00:00,  7.29it/s]

Epoch 56 completed, loss: 2.4135782718658447(cls: 2.405817985534668, reg: 0.007760826963931322)





Validation loss: loss: 39.76865768432617(cls: 39.76174545288086, reg: 0.006881995126605034)
[+]Shuffling dataset...


Epoch 57 loss: 1.681739330291748(cls: 1.6754076480865479, reg: 0.006331705488264561): 100%|██████████| 127/127 [00:17<00:00,  7.27it/s]

Epoch 57 completed, loss: 2.3854215145111084(cls: 2.378307580947876, reg: 0.00711420550942421)





Validation loss: loss: 40.843807220458984(cls: 40.836219787597656, reg: 0.007612625602632761)
[+]Shuffling dataset...


Epoch 58 loss: 3.1447083950042725(cls: 3.1376309394836426, reg: 0.007077373098582029): 100%|██████████| 127/127 [00:17<00:00,  7.33it/s]

Epoch 58 completed, loss: 2.4889142513275146(cls: 2.480681896209717, reg: 0.00823278073221445)





Validation loss: loss: 40.39965057373047(cls: 40.391387939453125, reg: 0.008287347853183746)
[+]Shuffling dataset...


Epoch 59 loss: 2.91432523727417(cls: 2.907482624053955, reg: 0.006842613685876131): 100%|██████████| 127/127 [00:17<00:00,  7.25it/s]

Epoch 59 completed, loss: 2.5340232849121094(cls: 2.525916576385498, reg: 0.008106354624032974)





Validation loss: loss: 39.75876235961914(cls: 39.75019454956055, reg: 0.008574148640036583)
[+]Shuffling dataset...


Epoch 60 loss: 2.2614798545837402(cls: 2.2538819313049316, reg: 0.0075979651883244514): 100%|██████████| 127/127 [00:17<00:00,  7.32it/s]

Epoch 60 completed, loss: 2.5370895862579346(cls: 2.529008626937866, reg: 0.008080803789198399)





Validation loss: loss: 40.064693450927734(cls: 40.05867004394531, reg: 0.006031927186995745)
[+]Shuffling dataset...


Epoch 61 loss: 3.7270753383636475(cls: 3.722405433654785, reg: 0.004669905640184879): 100%|██████████| 127/127 [00:17<00:00,  7.34it/s]

Epoch 61 completed, loss: 2.536393165588379(cls: 2.5294101238250732, reg: 0.006983350496739149)





Validation loss: loss: 40.47378158569336(cls: 40.46821975708008, reg: 0.005522145424038172)
[+]Shuffling dataset...


Epoch 62 loss: 2.813311815261841(cls: 2.808041572570801, reg: 0.005270249675959349): 100%|██████████| 127/127 [00:17<00:00,  7.27it/s]

Epoch 62 completed, loss: 2.610046148300171(cls: 2.6029932498931885, reg: 0.0070520080626010895)





Validation loss: loss: 40.06462860107422(cls: 40.055782318115234, reg: 0.008852241560816765)
[+]Shuffling dataset...


Epoch 63 loss: 1.7521257400512695(cls: 1.7438311576843262, reg: 0.00829462893307209): 100%|██████████| 127/127 [00:17<00:00,  7.30it/s]

Epoch 63 completed, loss: 2.5285227298736572(cls: 2.521899700164795, reg: 0.006624344736337662)





Validation loss: loss: 40.07902908325195(cls: 40.07358932495117, reg: 0.005439646542072296)
[+]Shuffling dataset...


Epoch 64 loss: 2.4958324432373047(cls: 2.4903366565704346, reg: 0.005495905876159668): 100%|██████████| 127/127 [00:17<00:00,  7.29it/s]

Epoch 64 completed, loss: 2.5132977962493896(cls: 2.5072124004364014, reg: 0.006085171364247799)





Validation loss: loss: 40.31452941894531(cls: 40.30679702758789, reg: 0.007728015538305044)
[+]Shuffling dataset...


Epoch 65 loss: 3.0695252418518066(cls: 3.0616559982299805, reg: 0.007869217544794083): 100%|██████████| 127/127 [00:17<00:00,  7.27it/s]

Epoch 65 completed, loss: 2.479750633239746(cls: 2.472998857498169, reg: 0.006752354092895985)





Validation loss: loss: 40.2116813659668(cls: 40.204017639160156, reg: 0.007651823107153177)
[+]Shuffling dataset...


Epoch 66 loss: 1.8444650173187256(cls: 1.8376595973968506, reg: 0.006805455312132835): 100%|██████████| 127/127 [00:17<00:00,  7.36it/s]

Epoch 66 completed, loss: 2.3766260147094727(cls: 2.3703062534332275, reg: 0.006320428103208542)





Validation loss: loss: 40.748294830322266(cls: 40.74077224731445, reg: 0.007489433512091637)
[+]Shuffling dataset...


Epoch 67 loss: 1.440264344215393(cls: 1.433964490890503, reg: 0.00629980955272913): 100%|██████████| 127/127 [00:17<00:00,  7.36it/s]

Epoch 67 completed, loss: 2.362126350402832(cls: 2.3556694984436035, reg: 0.006456384435296059)





Validation loss: loss: 40.6761589050293(cls: 40.668861389160156, reg: 0.007299704942852259)
[+]Shuffling dataset...


Epoch 68 loss: 2.59834361076355(cls: 2.591362476348877, reg: 0.00698122987523675): 100%|██████████| 127/127 [00:17<00:00,  7.35it/s]

Epoch 68 completed, loss: 2.3839805126190186(cls: 2.3778433799743652, reg: 0.006137295160442591)





Validation loss: loss: 40.837669372558594(cls: 40.831607818603516, reg: 0.006069088354706764)
[+]Shuffling dataset...


Epoch 69 loss: 2.4648730754852295(cls: 2.459714412689209, reg: 0.005158653017133474): 100%|██████████| 127/127 [00:17<00:00,  7.22it/s]

Epoch 69 completed, loss: 2.3831324577331543(cls: 2.3770692348480225, reg: 0.006062739063054323)





Validation loss: loss: 41.395870208740234(cls: 41.39014434814453, reg: 0.0057535842061042786)
[+]Shuffling dataset...


Epoch 70 loss: 3.3690881729125977(cls: 3.363858699798584, reg: 0.005229353904724121): 100%|██████████| 127/127 [00:17<00:00,  7.18it/s]

Epoch 70 completed, loss: 2.3606011867523193(cls: 2.3549060821533203, reg: 0.005695861764252186)





Validation loss: loss: 41.3485221862793(cls: 41.338558197021484, reg: 0.00997096672654152)
[+]Shuffling dataset...


Epoch 71 loss: 2.6588258743286133(cls: 2.6502981185913086, reg: 0.008527846075594425): 100%|██████████| 127/127 [00:17<00:00,  7.30it/s]

Epoch 71 completed, loss: 2.4210803508758545(cls: 2.414806365966797, reg: 0.006273944396525621)





Validation loss: loss: 41.12207794189453(cls: 41.115028381347656, reg: 0.00706158671528101)
[+]Shuffling dataset...


Epoch 72 loss: 2.2678964138031006(cls: 2.2610421180725098, reg: 0.006854240782558918): 100%|██████████| 127/127 [00:17<00:00,  7.31it/s]

Epoch 72 completed, loss: 2.6107335090637207(cls: 2.6039016246795654, reg: 0.0068329088389873505)





Validation loss: loss: 40.71703338623047(cls: 40.706756591796875, reg: 0.010285728611052036)
[+]Shuffling dataset...


Epoch 73 loss: 3.9050097465515137(cls: 3.895170211791992, reg: 0.009839619509875774): 100%|██████████| 127/127 [00:17<00:00,  7.29it/s]

Epoch 73 completed, loss: 2.6979939937591553(cls: 2.6916401386260986, reg: 0.006354909855872393)





Validation loss: loss: 40.014888763427734(cls: 40.008567810058594, reg: 0.006326151546090841)
[+]Shuffling dataset...


Epoch 74 loss: 2.7492265701293945(cls: 2.7437429428100586, reg: 0.005483583547174931): 100%|██████████| 127/127 [00:17<00:00,  7.27it/s]

Epoch 74 completed, loss: 2.7875585556030273(cls: 2.7814018726348877, reg: 0.0061560566537082195)





Validation loss: loss: 39.81277847290039(cls: 39.80714797973633, reg: 0.0056529175490140915)
[+]Shuffling dataset...


Epoch 75 loss: 2.5490264892578125(cls: 2.5432519912719727, reg: 0.005774587392807007): 100%|██████████| 127/127 [00:17<00:00,  7.28it/s]

Epoch 75 completed, loss: 2.733380079269409(cls: 2.72778058052063, reg: 0.00560083007439971)





Validation loss: loss: 39.746315002441406(cls: 39.73957061767578, reg: 0.006742487195879221)
[+]Shuffling dataset...


Epoch 76 loss: 1.8895882368087769(cls: 1.8833880424499512, reg: 0.006200247444212437): 100%|██████████| 127/127 [00:17<00:00,  7.16it/s]

Epoch 76 completed, loss: 2.68506121635437(cls: 2.679381847381592, reg: 0.005678991787135601)





Validation loss: loss: 38.59113311767578(cls: 38.584102630615234, reg: 0.0070134676061570644)
[+]Shuffling dataset...


Epoch 77 loss: 2.482888698577881(cls: 2.4768853187561035, reg: 0.0060034869238734245):  62%|██████▏   | 79/127 [00:10<00:06,  7.42it/s]

In [None]:
# Write the model to the local SAVE_PATH (./out/model.zip).
# NOTE(review): this saves `model`, whereas the upload cell uses trainer.model.
# Presumably the same object; confirm.
ModelHandler.save(model, SAVE_PATH)

In [None]:
# Push the trainer's final state and model back through the same
# CheckpointRepository the starting model was fetched from.
repository.update(trainer.state, trainer.model)

In [None]:
# Rebuild the metric history from METRIC_REPOSITORY (the repository given to
# MetricCallback) and plot the train (source=0) and validation (source=1)
# curves. The three value indices presumably follow the trainer's log format
# (total loss, cls, reg); TODO confirm against MetricCallback.
metrics = MetricsContainer()
for metric in METRIC_REPOSITORY.get_all():
    metrics.add_metric(metric)

for i, component in enumerate(["loss", "cls", "reg"]):
    train_losses = [metric.value[i] for metric in metrics.filter_metrics(source=0)]
    val_losses = [metric.value[i] for metric in metrics.filter_metrics(source=1)]
    fig, ax = plt.subplots()
    ax.plot(train_losses, label="train")
    ax.plot(val_losses, label="validation")
    ax.set(title=f"{component} [value index {i}]", xlabel="record", ylabel=component)
    ax.legend()
    plt.show()

In [None]:
# Run the trained model on one validation batch. eval() puts any train-only
# layers (dropout, batch norm), if present, into inference mode. no_grad()
# skips building an autograd graph for this forward pass.
X, y = next(iter(test_dataloader))
model.eval()
with torch.no_grad():
    y_hat = model(X.to(trainer.device)).cpu().numpy()
def softmax(x):
    """Numerically stable softmax taken over every element of ``x``.

    The global maximum is subtracted before exponentiating so large values do
    not overflow. The returned array sums to 1 over all its elements.
    """
    shifted = x - np.max(x)
    weights = np.exp(shifted)
    return weights / np.sum(weights)

def scale(x):
    """Softmax ``x``, then rescale so its largest probability equals 1.0.

    Presumably used so predictions can be compared visually with the target
    row they are plotted against.
    """
    probs = softmax(x)
    return probs / np.max(probs)

# Overlay target vs. scaled prediction for the first few samples of the batch.
# The last column is excluded from both; presumably it holds the regression
# output (TODO confirm). Plotting the whole batch would open one figure per
# sample, so only NUM_PREVIEW_PLOTS are drawn and each is closed after display.
NUM_PREVIEW_PLOTS = 8  # raise to y_hat.shape[0] to plot the entire batch
for i in range(min(NUM_PREVIEW_PLOTS, y_hat.shape[0])):
    fig, ax = plt.subplots()
    ax.plot(y[i, :-1], label="target")
    ax.plot(scale(y_hat[i, :-1]), label="prediction (scaled softmax)")
    ax.set(title=f"Sample {i}")
    ax.legend()
    plt.show()
    plt.close(fig)


In [None]:
# Clean up: delete the repository clone created by the first cell.
!rm -fr r_trader