In [2]:
import pandas as pd
import numpy as np

In [3]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl
from pytorch_lightning.metrics import F1
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.loggers import TensorBoardLogger

In [4]:
N_CURRENCIES = 1
INPUT_FEATURE_SIZE = 1
WINDOW_SIZE = 50

In [5]:
btc = pd.read_csv("../data/0_raw/BTC_USD_2013-10-01_2021-04-21-CoinDesk.csv")

In [6]:
btc.columns

Index(['Currency', 'Date', 'Closing Price (USD)', '24h Open (USD)',
       '24h High (USD)', '24h Low (USD)'],
      dtype='object')

In [7]:
btc.Date = btc.Date.apply(pd.Timestamp)

In [8]:
btc = btc.sort_values("Date", ascending=True)

In [9]:
btc

Unnamed: 0,Currency,Date,Closing Price (USD),24h Open (USD),24h High (USD),24h Low (USD)
0,BTC,2013-10-01,123.654990,124.304660,124.751660,122.563490
1,BTC,2013-10-02,125.455000,123.654990,125.758500,123.633830
2,BTC,2013-10-03,108.584830,125.455000,125.665660,83.328330
3,BTC,2013-10-04,118.674660,108.584830,118.675000,107.058160
4,BTC,2013-10-05,121.338660,118.674660,121.936330,118.005660
...,...,...,...,...,...,...
2754,BTC,2021-04-17,61965.782598,63225.093917,63520.325374,60033.534667
2755,BTC,2021-04-18,60574.444728,61444.232503,62534.028498,59802.889267
2756,BTC,2021-04-19,56850.830166,60191.525406,60531.988848,52148.983544
2757,BTC,2021-04-20,56224.101588,56335.389141,57609.368118,54449.245330


In [10]:
btc = btc[btc.Date >= pd.Timestamp("2018")]

In [11]:
len(btc)

1206

In [12]:
btc[(btc.Date < pd.Timestamp("2020"))]

Unnamed: 0,Currency,Date,Closing Price (USD),24h Open (USD),24h High (USD),24h Low (USD)
1553,BTC,2018-01-01,13439.417500,13062.145000,14213.441250,12587.603750
1554,BTC,2018-01-02,13337.621250,13439.417500,13892.242500,12859.802500
1555,BTC,2018-01-03,14881.545000,13337.621250,15216.756250,12955.965000
1556,BTC,2018-01-04,15104.450000,14881.545000,15394.986250,14588.595000
1557,BTC,2018-01-05,14953.852500,15104.450000,15194.406250,14225.166250
...,...,...,...,...,...,...
2277,BTC,2019-12-27,7183.706536,7212.808361,7427.472280,7105.723864
2278,BTC,2019-12-28,7227.293712,7183.706083,7251.381246,7065.278308
2279,BTC,2019-12-29,7311.560644,7227.294388,7348.789794,7217.079597
2280,BTC,2019-12-30,7385.464848,7315.151548,7520.637034,7272.593791


In [13]:
int(len(btc)*7/10)

844

In [14]:
btc = btc[['24h Open (USD)', '24h High (USD)', '24h Low (USD)', 'Closing Price (USD)']]

In [15]:
btc.columns = ["open", "high", "low", "close"]

In [16]:
btc["mid"] = (btc["high"] +btc["low"]) / 2.0

In [17]:
btc

Unnamed: 0,open,high,low,close,mid
1553,13062.145000,14213.441250,12587.603750,13439.417500,13400.522500
1554,13439.417500,13892.242500,12859.802500,13337.621250,13376.022500
1555,13337.621250,15216.756250,12955.965000,14881.545000,14086.360625
1556,14881.545000,15394.986250,14588.595000,15104.450000,14991.790625
1557,15104.450000,15194.406250,14225.166250,14953.852500,14709.786250
...,...,...,...,...,...
2754,63225.093917,63520.325374,60033.534667,61965.782598,61776.930021
2755,61444.232503,62534.028498,59802.889267,60574.444728,61168.458883
2756,60191.525406,60531.988848,52148.983544,56850.830166,56340.486196
2757,56335.389141,57609.368118,54449.245330,56224.101588,56029.306724


In [18]:
btc.close

1553    13439.417500
1554    13337.621250
1555    14881.545000
1556    15104.450000
1557    14953.852500
            ...     
2754    61965.782598
2755    60574.444728
2756    56850.830166
2757    56224.101588
2758    56608.769748
Name: close, Length: 1206, dtype: float64

In [19]:
time_series = btc.close.to_numpy()

In [20]:
train_index =  int(len(btc)*70/100)
val_index = int(len(btc)*85/100)
test_index = len(btc)
train_index, val_index, test_index

(844, 1025, 1206)

In [21]:
class TimeSeriesDataset(Dataset):
    def __init__(self, x: np.ndarray, seq_len = WINDOW_SIZE):
        self.x = torch.tensor(x).float()
        self.seq_len = seq_len
        
    def __len__(self):
        #return len(self.x) - ( self.seq_len -1 ) #sliding window count
        return len(self.x) - ( self.seq_len)
    
    def __getitem__(self, index):
        #return (self.x[index:index+self.seq_len], self.x[index+self.seq_len]) # regression
        
        window = self.x[index:index+self.seq_len]
        price_change = self.x[index+self.seq_len] - self.x[index+self.seq_len-1]
        price_change = 0 if price_change == 0 else 1 if price_change>0 else 2 #2 düşüş
        return (window, price_change)

In [31]:
class LSTM_based_classification_model(pl.LightningModule):
    def __init__(self,
                 data = time_series,
                 num_classes = 3,
                 window_size = WINDOW_SIZE,
                 input_size = 1,
                 batch_size=8,
                 lstm_hidden_size = 64,
                 lstm_stack_size = 2,
                 lstm_dropout = 0.5,
                 bidirectional = False,
                 ):
        
        super().__init__()
        self.data = time_series
        self.num_classes = num_classes
        self.window_size = window_size
        self.input_size = input_size
        self.batch_size = batch_size
        
        self.lstm_hidden_size = lstm_hidden_size
        self.lstm_stack_size = lstm_stack_size
        self.lstm_dropout = lstm_dropout
        self.bidirectional = bidirectional 
        
        self.stack_lstm = nn.LSTM(input_size = self.input_size, 
                hidden_size = self.lstm_hidden_size, 
                num_layers= self.lstm_stack_size,
                dropout = self.lstm_dropout,
                bidirectional = self.bidirectional, 
                batch_first=True,)
        
#         self.linear1 = nn.Linear(self.lstm_hidden_size, 128)
        
#         self.linear2 = nn.Linear(128, 64)
        
#         self.activation = nn.ReLU()
        self.output_layer = nn.Linear(64, self.num_classes)
        
        self.f1_score = pl.metrics.F1(num_classes=self.num_classes)
        self.accuracy_score = pl.metrics.Accuracy()
        
        self.train_dl = DataLoader(TimeSeriesDataset(self.data[:train_index], self.window_size), 
                                   batch_size=self.batch_size)
        
        self.val_dl = DataLoader(TimeSeriesDataset(self.data[train_index:val_index], self.window_size),
                                 batch_size=self.batch_size)
        
        self.test_dl = DataLoader(TimeSeriesDataset(self.data[val_index:], self.window_size),
                                  batch_size=self.batch_size)
    
    def forward(self, x):
        
        x = x.view(x.size()[0], x.size()[1], self.input_size) #(batch, window_len, feature_size)
        
        x, _=  self.stack_lstm(x)
        x = x[:, -1, :] # equivalent to return sequence = False on keras :)
        
#         x = self.linear1(x)
#         x = self.activation(x)
        
#         x = self.linear2(x)
#         x = self.activation(x)
        
        output = self.output_layer(x)
        #print("output1", output[0])
        output = F.log_softmax(output, dim = 1)
        #output = F.softmax(output)
        #print ("output", output.size())
        return output
    
    def training_step(self, batch, batch_nb):
        x, y = batch
        output = self(x)
        loss = F.nll_loss(output, y)
        self.log('train_loss', loss, on_step=True, prog_bar=True)
        
        acc = self.accuracy_score(torch.max(output, dim=1)[1], y)
        self.log('train_acc', acc, on_step=True, prog_bar=True)
        
        f1 = self.f1_score(torch.max(output, dim=1)[1], y)
        self.log('train_f1', f1, on_step=True, prog_bar=True)
        
        return loss
    
    def validation_step(self, batch, batch_nb):
        x, y = batch
        output = self(x)
        loss = F.nll_loss(x, y)
        self.log('val_loss', loss, on_epoch=True, reduce_fx=torch.mean, prog_bar=True)
        
        #print(torch.max(output, dim=1)[1])
        acc = self.accuracy_score(torch.max(output, dim=1)[1], y)
        self.log('val_acc', acc, on_epoch=True, reduce_fx=torch.mean, prog_bar=True)
        
        f1 = self.f1_score(torch.max(output, dim=1)[1], y)
        self.log('val_f1', f1, on_epoch=True, reduce_fx=torch.mean, prog_bar=True)
        
    def test_step(self, batch, batch_nb):
        x, y = batch
        output = self(x)
        
        print(y, torch.max(output, dim=1)[1])
        
        loss = F.nll_loss(x, y)
        self.log('test_loss', loss, on_epoch=True, reduce_fx=torch.mean)
        
        acc = self.accuracy_score(torch.max(output, dim=1)[1], y)
        self.log('test_acc', acc, on_epoch=True, reduce_fx=torch.mean)
        
        f1 = self.f1_score(torch.max(output, dim=1)[1], y)
        self.log('test_f1', f1, on_epoch=True, reduce_fx=torch.mean)
        
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(model.parameters(), lr=6e-4)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.7)
        #weight and biases
        return [optimizer], [scheduler]
    
    def train_dataloader(self):
        return self.train_dl

    def val_dataloader(self):
        return self.val_dl

    def test_dataloader(self):
        return self.test_dl

In [33]:
!rm -rf ../output/models/lstm_v1/version_*

In [34]:
#logger = WandbLogger(name='lstm.v4',project='pytorchlightning')
logger = TensorBoardLogger("../output/models/lstm_model_logs", name="lstm_v1")

In [41]:
model = LSTM_based_classification_model(batch_size=128)
trainer = pl.Trainer(gpus=-1, 
                     logger = logger)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [42]:
trainer.fit(model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type     | Params
--------------------------------------------
0 | stack_lstm     | LSTM     | 50.4 K
1 | output_layer   | Linear   | 195   
2 | f1_score       | F1       | 0     
3 | accuracy_score | Accuracy | 0     
--------------------------------------------
50.6 K    Trainable params
0         Non-trainable params
50.6 K    Total params
0.203     Total estimated model params size (MB)


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




1

In [43]:
trainer.test()

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

tensor([2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        2, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2, 1, 2, 1, 2, 2, 1, 2, 1, 1, 2, 2,
        1, 1, 2, 2, 1, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 2, 1, 1, 2, 1, 2,
        1, 2, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 2, 1,
        1, 2, 1, 1, 2, 2, 1, 2, 2, 1, 1, 2, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 1,
        1, 1, 1, 1, 2, 1, 2, 2], device='cuda:0') tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([2, 2, 1], device='cuda:0') tensor([2, 2, 2], device='cuda:0')

-----------------------------

[{'test_loss': -28952.986328125,
  'test_acc': 0.40458014607429504,
  'test_f1': 0.40458014607429504}]

In [47]:
train = TimeSeriesDataset(time_series[:train_index])
labels = [train[i][1] for i in range(train.__len__())]

In [55]:
pd.DataFrame({"label": labels}).groupby("label").size()

label
1    413
2    381
dtype: int64

[34m[1mwandb[0m: 500 encountered ({"errors":[{"message":"Error 1040: Too many connections","path":["project"]}],"data":{"project":null}}), retrying request
[34m[1mwandb[0m: Network error resolved after 0:00:43.808244, resuming normal operation.


In [259]:
#dropout, batch normalization 

False

In [115]:
model.test_dataloader()

<torch.utils.data.dataloader.DataLoader at 0x7f46e05073d0>