In [1]:
#conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch

import numpy as np
import pandas as pd

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

Using cuda device


In [2]:
from sklearn.metrics import matthews_corrcoef
from sklearn.metrics import f1_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix

#score results
def score_model(true, pred, columns = ""):
    cm = confusion_matrix(true,pred)
    l = np.sum(cm)
    cm = cm/l
    s = [accuracy_score(true, pred), matthews_corrcoef(true, pred), f1_score(true,pred), cm[0,0], cm[1,1],cm[0,1],cm[1,0]]
    return pd.DataFrame(data = s, index = ['accuracy', 'matthew_corr', 'f1', 'tn', 'tp','fp','fn'], columns = [columns])


In [39]:
#prediction

#attention network, can be used on a sequence of any legnth of word embeddings
class AttentionRnn(nn.Module):
    def __init__(self, dim_input, dim_output):
        super(AttentionRnn, self).__init__()
        self.attention = nn.RNN(dim_input, dim_output, batch_first = True)
        self.activation = nn.Sequential(nn.Sigmoid(),
            nn.Softmax(1)
        )
    def forward(self, x):
        weights, hidden = self.attention(x)
        weights = torch.squeeze(weights)
        weights = self.activation(weights)
        weighted_vector = torch.sum(x*weights[:,:,None],1)
        return weighted_vector



class NeuralNetwork(nn.Module):
    def __init__(self, k = 768):
        super(NeuralNetwork, self).__init__()
        self.k = k
        self.attention = AttentionRnn(768,1)
        self.trading_strategy = nn.Sequential(
            nn.Linear(self.k+5,1)
        )
    def forward(self, x):
        returns, topic_vectors = torch.tensor_split(x,[1], dim = 2) #split into word signals and price signals
        weighted_vector = self.attention(topic_vectors)
        U,S,V = torch.pca_lowrank(weighted_vector,self.k)
        weighted_vector = torch.matmul(weighted_vector, V[:, :self.k])
        returns = torch.squeeze(returns)
        x = torch.cat((returns, weighted_vector), dim = 1)
        return torch.squeeze(self.trading_strategy(x))

class LSTM(nn.Module):
    def __init__(self, num_layers=1, dropout = 0):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(768+1,1, num_layers, dropout = dropout)
    def forward(self, x):
        output, hidden = self.lstm(x)
        output = torch.squeeze(output)
        return output

In [64]:
#classification

#attention network, can be used on a sequence of any legnth of word embeddings
class AttentionRnn(nn.Module):
    def __init__(self, dim_input, dim_output):
        super(AttentionRnn, self).__init__()
        self.attention = nn.RNN(dim_input, dim_output, batch_first = True)
        self.activation = nn.Sequential(nn.Sigmoid(),
            nn.Softmax(1)
        )
    def forward(self, x):
        weights, hidden = self.attention(x)
        weights = torch.squeeze(weights)
        weights = self.activation(weights)
        weighted_vector = torch.sum(x*weights[:,:,None],1)
        return weighted_vector



class NeuralNetwork(nn.Module):
    def __init__(self, k = 768):
        super(NeuralNetwork, self).__init__()
        self.k = k
        self.attention = AttentionRnn(768,1)
        self.trading_strategy = nn.Sequential(
            nn.Linear(self.k+5,1),
            nn.Sigmoid()
        )
    def forward(self, x):
        returns, topic_vectors = torch.tensor_split(x,[1], dim = 2) #split into word signals and price signals
        weighted_vector = self.attention(topic_vectors)
        U,S,V = torch.pca_lowrank(weighted_vector,self.k)
        weighted_vector = torch.matmul(weighted_vector, V[:, :self.k])
        returns = torch.squeeze(returns)
        x = torch.cat((returns, weighted_vector), dim = 1)
        return torch.squeeze(self.trading_strategy(x))

class LSTM(nn.Module):
    def __init__(self, num_layers=1, dropout = 0):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(768+1,1, num_layers, dropout = dropout)
        self.trading_strategy = nn.Sigmoid()
    def forward(self, x):
        output, hidden = self.lstm(x)
        output = torch.squeeze(output)
        output = self.trading_strategy(output)
        return torch.squeeze(torch.tensor_split(output,[4], dim = 1)[1])

In [108]:
pred = model(torch.tensor(X_test).float())

In [4]:
#import btc price history
btc_prices = pd.read_csv("btc_prices.csv", index_col = 0)
btc_prices.index = pd.to_datetime(btc_prices.index).strftime('%Y-%m-%d')
btc_prices['log_ret'] = np.log(btc_prices.Close/btc_prices.Close.shift(1))
btc_ret = btc_prices[["log_ret"]].dropna()

#get the right daterange
btc_ret_train = btc_ret[(btc_ret.index >= '2016-01-01') & (btc_ret.index < '2021-11-01')].sort_index()

topic_vectors = pd.read_csv("bitcoin/btc_2016.csv", index_col = 0)
for i in [2017,2018,2019,2020, 2021]:
    topic_vectors = topic_vectors.append(pd.read_csv(f"bitcoin/btc_{i}.csv", index_col = 0))
topic_vectors.index = pd.date_range('2016-01-01', periods=len(topic_vectors)).strftime('%Y-%m-%d')
topic_vectors = btc_ret_train.join(topic_vectors, how= "left").dropna()

In [54]:
#import eth price history
eth_prices = pd.read_csv("eth_prices.csv", index_col = 0)
eth_prices.index = pd.to_datetime(eth_prices.index).strftime('%Y-%m-%d')
eth_prices['log_ret'] = np.log(eth_prices.Close/eth_prices.Close.shift(1))
eth_ret = btc_prices[["log_ret"]].dropna()

#get the right daterange
eth_ret_train = eth_ret[(eth_ret.index >= '2014-01-01') & (eth_ret.index < '2021-11-01')].sort_index()

topic_vectors = pd.read_csv("bitcoin/eth_2014.csv", index_col = 0)
for i in [2015, 2016, 2017,2018,2019,2020, 2021]:
    topic_vectors = topic_vectors.append(pd.read_csv(f"bitcoin/eth_{i}.csv", index_col = 0))
topic_vectors.index = pd.date_range('2014-01-01', periods=len(topic_vectors)).strftime('%Y-%m-%d')
topic_vectors = eth_ret_train.join(topic_vectors, how= "left").dropna()

In [55]:
#get the time series nicely streamlined into the machine learning models
rolling_window = 5

X_data = np.empty([len(topic_vectors)-rolling_window,rolling_window,769])
y_series_data = np.empty([len(topic_vectors)-rolling_window,rolling_window])
for i in range(0,len(topic_vectors)-rolling_window):
    X_data[i] = np.array(topic_vectors.iloc[i:i+rolling_window,:])
    y_series_data[i] = np.array(topic_vectors.iloc[i+1:i+1+rolling_window,0])
y_data = np.array(topic_vectors.iloc[rolling_window:,0])

train_cutoff = -3000
tune_cutoff = -304
test_cutoff = -304

X = X_data[:tune_cutoff]
X_tune = X_data[tune_cutoff:test_cutoff]
X_test = X_data[test_cutoff:]


y = y_data[:tune_cutoff]
y_tune = y_data[tune_cutoff:test_cutoff]
y_test = y_data[test_cutoff:]

y_series = y_series_data[:tune_cutoff]
y_series_test = y_series_data[test_cutoff:]

In [35]:
model_print = NeuralNetwork().to(device)
print(model_print)

NeuralNetwork(
  (attention): AttentionRnn(
    (attention): RNN(768, 1, batch_first=True)
    (activation): Sequential(
      (0): Sigmoid()
      (1): Softmax(dim=1)
    )
  )
  (trading_strategy): Sequential(
    (0): Linear(in_features=773, out_features=1, bias=True)
  )
)


In [6]:
def train_loop(dataloader, model, loss_fn, optimizer, new = True):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [37]:
#lstm
train_dataset = TensorDataset(torch.tensor(X).float(), torch.tensor(y_series).float())
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
#test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

In [61]:
#neural network
train_dataset = TensorDataset(torch.tensor(X).float(), torch.tensor(y>0).float())
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
#test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

In [100]:
def accuracy(pred, actual):
    p = torch.sign(pred)*torch.sign(actual)+1
    return -torch.sum(p)/len(pred)
accuracy(pred, torch.tensor(y_test))

tensor(-0.9000, dtype=torch.float64, grad_fn=<DivBackward0>)

In [9]:
#weighted BCELoss
def BCELoss_weighted(weights):
    def loss(pred, target):
        pred = torch.clamp(pred,min=1e-7,max=1-1e-7)
        bce = - weights[1] * target * torch.log(pred) - (1 - target) * weights[0] * torch.log(1 - pred)
        return torch.mean(bce)
    return loss

In [65]:
model = LSTM(1,0)
#model = NeuralNetwork(20)

learning_rate = 1e-3
batch_size = 64
epochs = 5000

loss_fn = BCELoss_weighted(weights = [2,1])
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)


for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    #test_loop(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 1.088625  [    0/ 2288]
Epoch 2
-------------------------------
loss: 1.116024  [    0/ 2288]
Epoch 3
-------------------------------
loss: 0.865093  [    0/ 2288]
Epoch 4
-------------------------------
loss: 0.907727  [    0/ 2288]
Epoch 5
-------------------------------
loss: 1.033378  [    0/ 2288]
Epoch 6
-------------------------------
loss: 1.057966  [    0/ 2288]
Epoch 7
-------------------------------
loss: 0.999502  [    0/ 2288]
Epoch 8
-------------------------------
loss: 0.998392  [    0/ 2288]
Epoch 9
-------------------------------
loss: 1.076094  [    0/ 2288]
Epoch 10
-------------------------------
loss: 1.020231  [    0/ 2288]
Epoch 11
-------------------------------
loss: 0.957355  [    0/ 2288]
Epoch 12
-------------------------------
loss: 1.041401  [    0/ 2288]
Epoch 13
-------------------------------
loss: 1.008128  [    0/ 2288]
Epoch 14
-------------------------------
loss: 0.916916  [    0/ 2288]
Epoch 15
------

Epoch 117
-------------------------------
loss: 0.994820  [    0/ 2288]
Epoch 118
-------------------------------
loss: 1.042837  [    0/ 2288]
Epoch 119
-------------------------------
loss: 0.951338  [    0/ 2288]
Epoch 120
-------------------------------
loss: 1.105929  [    0/ 2288]
Epoch 121
-------------------------------
loss: 1.111961  [    0/ 2288]
Epoch 122
-------------------------------
loss: 1.068315  [    0/ 2288]
Epoch 123
-------------------------------
loss: 0.965871  [    0/ 2288]
Epoch 124
-------------------------------
loss: 0.970195  [    0/ 2288]
Epoch 125
-------------------------------
loss: 0.967078  [    0/ 2288]
Epoch 126
-------------------------------
loss: 1.052579  [    0/ 2288]
Epoch 127
-------------------------------
loss: 0.993761  [    0/ 2288]
Epoch 128
-------------------------------
loss: 1.002801  [    0/ 2288]
Epoch 129
-------------------------------
loss: 1.006675  [    0/ 2288]
Epoch 130
-------------------------------
loss: 0.965698  [    0

Epoch 231
-------------------------------
loss: 1.058293  [    0/ 2288]
Epoch 232
-------------------------------
loss: 1.022687  [    0/ 2288]
Epoch 233
-------------------------------
loss: 0.943968  [    0/ 2288]
Epoch 234
-------------------------------
loss: 1.064990  [    0/ 2288]
Epoch 235
-------------------------------
loss: 1.099748  [    0/ 2288]
Epoch 236
-------------------------------
loss: 0.949414  [    0/ 2288]
Epoch 237
-------------------------------
loss: 0.995448  [    0/ 2288]
Epoch 238
-------------------------------
loss: 0.991843  [    0/ 2288]
Epoch 239
-------------------------------
loss: 1.022288  [    0/ 2288]
Epoch 240
-------------------------------
loss: 0.990877  [    0/ 2288]
Epoch 241
-------------------------------
loss: 1.077326  [    0/ 2288]
Epoch 242
-------------------------------
loss: 1.067340  [    0/ 2288]
Epoch 243
-------------------------------
loss: 1.018142  [    0/ 2288]
Epoch 244
-------------------------------
loss: 1.014993  [    0

Epoch 345
-------------------------------
loss: 0.985238  [    0/ 2288]
Epoch 346
-------------------------------
loss: 1.078623  [    0/ 2288]
Epoch 347
-------------------------------
loss: 1.060055  [    0/ 2288]
Epoch 348
-------------------------------
loss: 0.994923  [    0/ 2288]
Epoch 349
-------------------------------
loss: 1.034075  [    0/ 2288]
Epoch 350
-------------------------------
loss: 0.963719  [    0/ 2288]
Epoch 351
-------------------------------
loss: 0.911780  [    0/ 2288]
Epoch 352
-------------------------------
loss: 0.929201  [    0/ 2288]
Epoch 353
-------------------------------
loss: 0.941136  [    0/ 2288]
Epoch 354
-------------------------------
loss: 0.978682  [    0/ 2288]
Epoch 355
-------------------------------
loss: 0.960282  [    0/ 2288]
Epoch 356
-------------------------------
loss: 0.976787  [    0/ 2288]
Epoch 357
-------------------------------
loss: 1.064275  [    0/ 2288]
Epoch 358
-------------------------------
loss: 1.039581  [    0

Epoch 459
-------------------------------
loss: 1.085937  [    0/ 2288]
Epoch 460
-------------------------------
loss: 0.995749  [    0/ 2288]
Epoch 461
-------------------------------
loss: 1.004991  [    0/ 2288]
Epoch 462
-------------------------------
loss: 1.033094  [    0/ 2288]
Epoch 463
-------------------------------
loss: 1.007615  [    0/ 2288]
Epoch 464
-------------------------------
loss: 0.974472  [    0/ 2288]
Epoch 465
-------------------------------
loss: 1.011706  [    0/ 2288]
Epoch 466
-------------------------------
loss: 1.052472  [    0/ 2288]
Epoch 467
-------------------------------
loss: 1.001195  [    0/ 2288]
Epoch 468
-------------------------------
loss: 1.002699  [    0/ 2288]
Epoch 469
-------------------------------
loss: 0.985582  [    0/ 2288]
Epoch 470
-------------------------------
loss: 0.985318  [    0/ 2288]
Epoch 471
-------------------------------
loss: 1.003945  [    0/ 2288]
Epoch 472
-------------------------------
loss: 0.973365  [    0

Epoch 573
-------------------------------
loss: 1.058908  [    0/ 2288]
Epoch 574
-------------------------------
loss: 1.039668  [    0/ 2288]
Epoch 575
-------------------------------
loss: 1.040924  [    0/ 2288]
Epoch 576
-------------------------------
loss: 1.032225  [    0/ 2288]
Epoch 577
-------------------------------
loss: 0.971727  [    0/ 2288]
Epoch 578
-------------------------------
loss: 0.944320  [    0/ 2288]
Epoch 579
-------------------------------
loss: 1.070398  [    0/ 2288]
Epoch 580
-------------------------------
loss: 0.945111  [    0/ 2288]
Epoch 581
-------------------------------
loss: 0.923074  [    0/ 2288]
Epoch 582
-------------------------------
loss: 0.991533  [    0/ 2288]
Epoch 583
-------------------------------
loss: 0.992252  [    0/ 2288]
Epoch 584
-------------------------------
loss: 0.996986  [    0/ 2288]
Epoch 585
-------------------------------
loss: 0.989182  [    0/ 2288]
Epoch 586
-------------------------------
loss: 0.982087  [    0

Epoch 687
-------------------------------
loss: 0.966995  [    0/ 2288]
Epoch 688
-------------------------------
loss: 0.952743  [    0/ 2288]
Epoch 689
-------------------------------
loss: 0.973718  [    0/ 2288]
Epoch 690
-------------------------------
loss: 0.992249  [    0/ 2288]
Epoch 691
-------------------------------
loss: 0.998071  [    0/ 2288]
Epoch 692
-------------------------------
loss: 0.958967  [    0/ 2288]
Epoch 693
-------------------------------
loss: 1.056455  [    0/ 2288]
Epoch 694
-------------------------------
loss: 0.895525  [    0/ 2288]
Epoch 695
-------------------------------
loss: 1.013605  [    0/ 2288]
Epoch 696
-------------------------------
loss: 1.018244  [    0/ 2288]
Epoch 697
-------------------------------
loss: 0.996193  [    0/ 2288]
Epoch 698
-------------------------------
loss: 0.967580  [    0/ 2288]
Epoch 699
-------------------------------
loss: 0.913203  [    0/ 2288]
Epoch 700
-------------------------------
loss: 1.059593  [    0

Epoch 801
-------------------------------
loss: 1.001582  [    0/ 2288]
Epoch 802
-------------------------------
loss: 0.981751  [    0/ 2288]
Epoch 803
-------------------------------
loss: 1.046747  [    0/ 2288]
Epoch 804
-------------------------------
loss: 1.034830  [    0/ 2288]
Epoch 805
-------------------------------
loss: 0.994395  [    0/ 2288]
Epoch 806
-------------------------------
loss: 1.011165  [    0/ 2288]
Epoch 807
-------------------------------
loss: 0.988502  [    0/ 2288]
Epoch 808
-------------------------------
loss: 1.016967  [    0/ 2288]
Epoch 809
-------------------------------
loss: 1.036567  [    0/ 2288]
Epoch 810
-------------------------------
loss: 0.971539  [    0/ 2288]
Epoch 811
-------------------------------
loss: 0.987615  [    0/ 2288]
Epoch 812
-------------------------------
loss: 1.045899  [    0/ 2288]
Epoch 813
-------------------------------
loss: 0.985252  [    0/ 2288]
Epoch 814
-------------------------------
loss: 1.036380  [    0

Epoch 915
-------------------------------
loss: 0.999659  [    0/ 2288]
Epoch 916
-------------------------------
loss: 1.002500  [    0/ 2288]
Epoch 917
-------------------------------
loss: 1.012191  [    0/ 2288]
Epoch 918
-------------------------------
loss: 0.990402  [    0/ 2288]
Epoch 919
-------------------------------
loss: 1.072209  [    0/ 2288]
Epoch 920
-------------------------------
loss: 0.984115  [    0/ 2288]
Epoch 921
-------------------------------
loss: 0.964237  [    0/ 2288]
Epoch 922
-------------------------------
loss: 0.901138  [    0/ 2288]
Epoch 923
-------------------------------
loss: 0.993077  [    0/ 2288]
Epoch 924
-------------------------------
loss: 0.969229  [    0/ 2288]
Epoch 925
-------------------------------
loss: 0.919272  [    0/ 2288]
Epoch 926
-------------------------------
loss: 1.008821  [    0/ 2288]
Epoch 927
-------------------------------
loss: 1.016989  [    0/ 2288]
Epoch 928
-------------------------------
loss: 0.971903  [    0

Epoch 1029
-------------------------------
loss: 1.065907  [    0/ 2288]
Epoch 1030
-------------------------------
loss: 1.003676  [    0/ 2288]
Epoch 1031
-------------------------------
loss: 1.008010  [    0/ 2288]
Epoch 1032
-------------------------------
loss: 0.947060  [    0/ 2288]
Epoch 1033
-------------------------------
loss: 0.988935  [    0/ 2288]
Epoch 1034
-------------------------------
loss: 0.897778  [    0/ 2288]
Epoch 1035
-------------------------------
loss: 1.014230  [    0/ 2288]
Epoch 1036
-------------------------------
loss: 1.015967  [    0/ 2288]
Epoch 1037
-------------------------------
loss: 1.027778  [    0/ 2288]
Epoch 1038
-------------------------------
loss: 0.985334  [    0/ 2288]
Epoch 1039
-------------------------------
loss: 0.994260  [    0/ 2288]
Epoch 1040
-------------------------------
loss: 1.010096  [    0/ 2288]
Epoch 1041
-------------------------------
loss: 1.019725  [    0/ 2288]
Epoch 1042
-------------------------------
loss: 1.

Epoch 1142
-------------------------------
loss: 0.993127  [    0/ 2288]
Epoch 1143
-------------------------------
loss: 0.990590  [    0/ 2288]
Epoch 1144
-------------------------------
loss: 1.004173  [    0/ 2288]
Epoch 1145
-------------------------------
loss: 1.001419  [    0/ 2288]
Epoch 1146
-------------------------------
loss: 1.017468  [    0/ 2288]
Epoch 1147
-------------------------------
loss: 1.042589  [    0/ 2288]
Epoch 1148
-------------------------------
loss: 0.988809  [    0/ 2288]
Epoch 1149
-------------------------------
loss: 0.985773  [    0/ 2288]
Epoch 1150
-------------------------------
loss: 1.060427  [    0/ 2288]
Epoch 1151
-------------------------------
loss: 1.063312  [    0/ 2288]
Epoch 1152
-------------------------------
loss: 1.025308  [    0/ 2288]
Epoch 1153
-------------------------------
loss: 1.063064  [    0/ 2288]
Epoch 1154
-------------------------------
loss: 1.037332  [    0/ 2288]
Epoch 1155
-------------------------------
loss: 0.

Epoch 1255
-------------------------------
loss: 0.936812  [    0/ 2288]
Epoch 1256
-------------------------------
loss: 0.982113  [    0/ 2288]
Epoch 1257
-------------------------------
loss: 0.984138  [    0/ 2288]
Epoch 1258
-------------------------------
loss: 1.023487  [    0/ 2288]
Epoch 1259
-------------------------------
loss: 0.882357  [    0/ 2288]
Epoch 1260
-------------------------------
loss: 1.110441  [    0/ 2288]
Epoch 1261
-------------------------------
loss: 1.036829  [    0/ 2288]
Epoch 1262
-------------------------------
loss: 0.946092  [    0/ 2288]
Epoch 1263
-------------------------------
loss: 1.011773  [    0/ 2288]
Epoch 1264
-------------------------------
loss: 0.914486  [    0/ 2288]
Epoch 1265
-------------------------------
loss: 0.943740  [    0/ 2288]
Epoch 1266
-------------------------------
loss: 0.984843  [    0/ 2288]
Epoch 1267
-------------------------------
loss: 0.946403  [    0/ 2288]
Epoch 1268
-------------------------------
loss: 1.

Epoch 1368
-------------------------------
loss: 0.970534  [    0/ 2288]
Epoch 1369
-------------------------------
loss: 0.948792  [    0/ 2288]
Epoch 1370
-------------------------------
loss: 0.975483  [    0/ 2288]
Epoch 1371
-------------------------------
loss: 0.991192  [    0/ 2288]
Epoch 1372
-------------------------------
loss: 1.012436  [    0/ 2288]
Epoch 1373
-------------------------------
loss: 0.974122  [    0/ 2288]
Epoch 1374
-------------------------------
loss: 1.068022  [    0/ 2288]
Epoch 1375
-------------------------------
loss: 0.987842  [    0/ 2288]
Epoch 1376
-------------------------------
loss: 0.904303  [    0/ 2288]
Epoch 1377
-------------------------------
loss: 0.970708  [    0/ 2288]
Epoch 1378
-------------------------------
loss: 0.959809  [    0/ 2288]
Epoch 1379
-------------------------------
loss: 0.950748  [    0/ 2288]
Epoch 1380
-------------------------------
loss: 1.042605  [    0/ 2288]
Epoch 1381
-------------------------------
loss: 1.

Epoch 1481
-------------------------------
loss: 1.022448  [    0/ 2288]
Epoch 1482
-------------------------------
loss: 1.035586  [    0/ 2288]
Epoch 1483
-------------------------------
loss: 1.026288  [    0/ 2288]
Epoch 1484
-------------------------------
loss: 0.983126  [    0/ 2288]
Epoch 1485
-------------------------------
loss: 0.959079  [    0/ 2288]
Epoch 1486
-------------------------------
loss: 0.972712  [    0/ 2288]
Epoch 1487
-------------------------------
loss: 0.950396  [    0/ 2288]
Epoch 1488
-------------------------------
loss: 1.011738  [    0/ 2288]
Epoch 1489
-------------------------------
loss: 1.002308  [    0/ 2288]
Epoch 1490
-------------------------------
loss: 1.030509  [    0/ 2288]
Epoch 1491
-------------------------------
loss: 0.919948  [    0/ 2288]
Epoch 1492
-------------------------------
loss: 0.995896  [    0/ 2288]
Epoch 1493
-------------------------------
loss: 0.995702  [    0/ 2288]
Epoch 1494
-------------------------------
loss: 0.

Epoch 1594
-------------------------------
loss: 1.016388  [    0/ 2288]
Epoch 1595
-------------------------------
loss: 0.994944  [    0/ 2288]
Epoch 1596
-------------------------------
loss: 0.922278  [    0/ 2288]
Epoch 1597
-------------------------------
loss: 0.991130  [    0/ 2288]
Epoch 1598
-------------------------------
loss: 1.040013  [    0/ 2288]
Epoch 1599
-------------------------------
loss: 1.001566  [    0/ 2288]
Epoch 1600
-------------------------------
loss: 0.928652  [    0/ 2288]
Epoch 1601
-------------------------------
loss: 0.985609  [    0/ 2288]
Epoch 1602
-------------------------------
loss: 0.993377  [    0/ 2288]
Epoch 1603
-------------------------------
loss: 1.019310  [    0/ 2288]
Epoch 1604
-------------------------------
loss: 0.942700  [    0/ 2288]
Epoch 1605
-------------------------------
loss: 0.947492  [    0/ 2288]
Epoch 1606
-------------------------------
loss: 1.005814  [    0/ 2288]
Epoch 1607
-------------------------------
loss: 1.

KeyboardInterrupt: 

In [16]:
x_test = torch.tensor(X_test).float()
y_value = y

In [28]:
x_test = torch.tensor(X_test).float()
y_value = y_test

In [68]:
#score, single output, classification
x_test = torch.tensor(X).float()
pred = model(x_test).detach().numpy()
score_model(pred>0.5, y>0)

Unnamed: 0,Unnamed: 1
accuracy,0.453671
matthew_corr,-0.027013
f1,0.153117
tn,0.404283
tp,0.049388
fp,0.498252
fn,0.048077


In [69]:
sum(pred>0.5)

223

In [70]:
#score, single output, classification
x_test = torch.tensor(X_test).float()
pred = model(x_test).detach().numpy()
score_model(pred>0.5, y_test>0)

Unnamed: 0,Unnamed: 1
accuracy,0.476974
matthew_corr,-0.009346
f1,0.070175
tn,0.457237
tp,0.019737
fp,0.503289
fn,0.019737


In [63]:
#score, single output, prediction
x_test = torch.tensor(X_test).float()
pred = model(x_test).detach().numpy()
score_model(pred>0, y_test>0)

Unnamed: 0,Unnamed: 1
accuracy,0.523026
matthew_corr,0.0
f1,0.686825
tn,0.0
tp,0.523026
fp,0.0
fn,0.476974


In [59]:
#lstm score, sequence output, prediction
x_test = torch.tensor(X_test).float()
pred = torch.squeeze(torch.tensor_split(model(x_test),[4], dim = 1)[1]).detach().numpy()
score_model(pred>0, y_test>0)

Unnamed: 0,Unnamed: 1
accuracy,0.532895
matthew_corr,0.080379
f1,0.477941
tn,0.319079
tp,0.213816
fp,0.309211
fn,0.157895


In [100]:
torch.save(model.state_dict(), "LSTM_weights_ETH.pt")

In [72]:
model.load_state_dict(torch.load("LSTM_weights.pt"))

<All keys matched successfully>

In [130]:
pred

array([-0.01067148,  0.00631406,  0.00704519,  0.03657668, -0.05143879,
        0.02181377,  0.04506073, -0.01136261,  0.01926613,  0.16838945,
        0.03360152,  0.05642891,  0.03385413,  0.00406338,  0.29805204,
        0.08665578,  0.01160478,  0.00852502, -0.05063779, -0.03675518,
        0.29312295,  0.03495735,  0.06120349, -0.00055414,  0.0021507 ,
        0.01060115,  0.043214  ,  0.12682067,  0.07407265,  0.00692551,
       -0.00707689, -0.09045514,  0.00346867, -0.03288126,  0.00295664,
        0.05423941,  0.00290809,  0.03392464,  0.08043066,  0.01311603,
        0.12408525, -0.10385028, -0.15686685,  0.00627755,  0.10512847,
        0.02558482,  0.05249054,  0.24260618,  0.02013379,  0.18393825,
        0.05151614,  0.03678934,  0.02269081, -0.02488468, -0.0233004 ,
        0.00809217, -0.00168681, -0.04773178, -0.01354207,  0.07085342,
        0.07419498,  0.21423754,  0.03976449,  0.30776474,  0.01885422,
       -0.00775475, -0.00070749, -0.10312967, -0.03355606, -0.00