In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [2]:
import pandas as pd

In [3]:
import os

In [5]:
batch_size = 16
device = 'cuda:0'

# Conv + LSTM 

In [6]:
class ConvLSTM(nn.Module):
    def __init__(self, n_features, n_hidden, seq_len, n_layers):
        super(ConvLSTM, self).__init__()
        self.n_hidden = n_hidden
        self.seq_len = seq_len
        self.n_layers = n_layers
        self.c1 = nn.Conv1d(in_channels=1, out_channels=1, kernel_size = 2, stride = 1) # Add a 1D CNN layer
        self.lstm = nn.LSTM(
            input_size=n_features,
            hidden_size=n_hidden,
            num_layers=n_layers
        )
        self.linear = nn.Linear(in_features=n_hidden, out_features=1)
    def reset_hidden_state(self):
        self.hidden = (
            torch.zeros(self.n_layers, self.seq_len-1, self.n_hidden),
            torch.zeros(self.n_layers, self.seq_len-1, self.n_hidden)
        )
    def forward(self, sequences):
        sequences = self.c1(sequences.view(len(sequences), 1, -1))
        lstm_out, self.hidden = self.lstm(
            sequences.view(len(sequences), self.seq_len-1, -1),
            self.hidden
        )
        last_time_step = lstm_out.view(self.seq_len-1, len(sequences), self.n_hidden)[-1]
        y_pred = self.linear(last_time_step)
        return y_pred

# Read Data

In [8]:
#train data
virus_data = []
for idx, file in enumerate(os.listdir('data/training/')):
    array = pd.read_pickle(os.path.join('data/training',file))
    for i in array:
        virus_data.append([idx,i])

#val data

# virus_data_val = []
# for idx, file in enumerate(os.listdir('data/validation/')):
#     array = pd.read_pickle(os.path.join('data/validation',file))
#     for i in array:
#         virus_data_val.append([idx,i])

In [9]:
max_len = max([len(i[1]) for i in virus_data])

In [10]:
for idx,data in enumerate(virus_data):
    virus_data[idx][1].extend([-1]*(max_len - len(virus_data[idx][1])))


# Dataset

In [11]:
class VirusDataset(Dataset):
    def __init__(self,virus_data):
        self.data = [i[1] for i in virus_data]
        self.label = [i[0] for i in virus_data]
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self,idx):
        return torch.Tensor(self.data[idx]),self.label[idx]

In [12]:
train_dataset = VirusDataset(virus_data)
# val_dataset = VirusDataset(virus_data_val)

In [13]:
train_loader = DataLoader(train_dataset, batch_size=batch_size,shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=batch_size,shuffle=True)

# Training

In [14]:
model = ConvLSTM(n_features = 3 , n_hidden = 64, seq_len = max_len, n_layers = 2)

In [13]:
for idx,data in enumerate(train_loader):
    model.train()
    seq, label = data
    seq, label = seq.to(device),label.to(device)
    predict = model

RuntimeError: stack expects each tensor to be equal size, but got [60198] at entry 0 and [60111] at entry 1

In [15]:
for i in train_dataset:
    print(len(i[0]))

30698
32118
31966
31966
31966
31966
41130
31957
32118
40645
41134
32464
30783
31966
30832
41085
30874
31966
41074
41074
31966
32464
30969
31966
32118
41146
41106
30698
41134
31954
40645
31966
31966
34795
40645
41087
31966
31966
32464
30918
34795
31966
40645
40645
31201
41074
30873
30988
31966
31936
40645
31966
40645
40645
31966
31966
31966
41075
31966
31966
31966
40645
31201
41087
40917
41104
31966
30713
31966
31966
32421
41118
31966
40645
31834
40645
41134
32464
31966
41078
31966
31966
40645
31966
31966
40645
31966
30783
41131
41086
31966
41132
31966
31966
40645
31936
31966
40648
31966
41134
31966
31513
31966
31966
31966
31966
31966
31966
31966
40645
41087
41130
31966
30941
31966
40831
40645
30881
32464
30916
40645
40938
31948
32421
31966
31966
40831
31205
41130
31513
30950
31966
41074
40840
30783
30778
40645
31690
31966
40642
31966
31966
40645
41057
31966
31966
30832
30522
30936
31966
31966
31966
31966
31966
30960
31966
31966
31966
31966
32118
40645
30677
31966
34795
41134
30932
4064

40986
31966
40645
31121
40645
31966
31966
31966
31966
41204
31966
31966
31966
31778
41130
31966
31966
41133
30782
30763
41085
40645
30874
31966
32118
40645
31966
32021
34795
31966
31768
31966
31966
32251
31966
40645
30886
40645
30937
41148
31966
41085
41100
32464
31030
31966
31966
30937
31966
40907
31966
31966
31966
41085
31966
41078
31966
31966
32464
31966
31966
31966
30781
40645
30884
31966
30551
31966
31585
30905
41074
30874
31696
31477
41130
30937
31966
31966
30734
31966
31966
31966
31951
40645
41100
31966
31966
31966
31965
31967
40645
31513
41076
40645
41074
31966
30770
31243
31270
30713
30847
31774
31966
31966
30874
31966
30998
31966
31966
30883
31966
30838
40644
30915
41130
31966
30922
41079
31966
34795
31039
41086
41130
31966
31954
32803
30697
31103
31966
40645
30697
34795
32421
41142
31966
31966
40645
31966
31513
31966
31966
31966
40645
41085
31318
30866
30870
32464
31966
31966
41075
31966
30887
40645
30921
31966
40803
31966
41130
41074
30847
40645
30870
40645
41088
31966
4107

49368
49402
49352
31068
48732
49209
49440
49228
49074
30529
49172
30517
49440
30513
49308
49361
49427
32512
49153
49352
49402
49405
49375
49425
38608
48955
49414
30518
49352
49077
49308
48819
49311
49308
49364
49391
48732
49403
49352
49439
49352
49295
49403
38608
49283
48732
48921
49347
49324
49352
49394
49352
49229
32511
30540
49217
49308
49425
49317
49298
49273
49386
30501
31019
30506
49438
49362
49353
42951
30541
49439
49352
49413
49308
49354
49173
49302
49401
49220
49362
30746
49372
49440
49352
49282
49417
49351
31851
30818
49439
49361
49308
49363
49440
49245
49396
30501
49282
48920
49353
48556
49362
49406
30499
49435
30661
49341
49291
48075
49230
49433
30746
49240
49220
49284
49296
49438
49395
49369
38608
49361
49390
49295
48035
49352
49294
49267
48732
32701
49378
49074
32654
48737
49436
49439
49352
49207
49352
49352
49228
49247
48823
49423
49356
48963
48845
49352
49295
49007
49254
48885
49352
49341
49290
30528
49356
49362
49362
49309
49021
49229
49285
49359
49413
49440
30541
4941

49372
49109
49350
32841
31099
49074
49307
49373
49352
49353
49394
49424
49372
49392
48732
49291
48732
30499
49391
49362
49308
49440
49074
49179
49353
49440
49439
48732
49370
30499
49352
49401
48732
49074
49440
49352
49362
31099
49218
32512
38608
49322
49422
49440
49440
49352
49393
49423
49429
30500
49440
49370
49355
30964
49320
49284
49076
49364
30551
49364
30506
49417
32518
49075
49163
49423
49413
30504
49226
49334
49343
49404
38608
49412
49308
49233
49355
38608
49393
49392
49293
30505
49291
48732
48135
49295
37263
49233
49394
49303
49323
49227
49440
49356
49349
49440
49366
49394
49364
49294
49277
48823
49382
49091
49293
49352
49338
30501
30818
49417
49393
49428
49114
49425
30488
49440
30501
49257
49074
49426
49367
49352
49292
49227
49362
49428
49317
49389
30541
49232
48732
49283
49413
49394
49362
30506
30499
49429
49392
49352
49364
49297
49359
49070
49075
49285
30746
49395
49439
49360
49393
49075
49248
49372
49283
49440
49284
49378
49362
49384
38608
49292
49387
49351
49296
49440
4942

31588
30789
32374
31588
32711
31588
30817
30666
32711
31588
31837
30795
30842
31018
31057
31015
31837
30666
31015
32711
32374
32374
31837
31588
30774
31588
31543
31585
30770
31837
30638
32711
30817
30851
32711
31015
31588
32711
30810
32374
32711
31837
31018
31837
32711
30837
31837
31543
30697
32711
30774
31015
30774
32711
31837
31015
31588
30718
31015
30770
31588
32711
30817
30804
30851
30774
30697
31837
32711
31588
32711
32711
30736
31837
32711
31588
32374
30810
31015
32711
30795
30851
31588
31015
30615
32711
32359
31588
32711
31018
31015
30638
30842
31837
31015
31837
31837
31588
30795
30786
30810
30770
31837
30842
31588
31837
31015
31588
32711
31015
31837
32374
31588
32353
30842
30842
30770
31588
32374
32711
32711
30725
31543
31837
31018
30914
31837
31588
31837
30816
30775
30638
32711
31588
30795
32711
31057
32374
32711
31015
30819
31588
31444
32711
30792
31015
31588
31588
30770
32374
32711
31837
30732
32536
30936
31588
31588
31837
30770
31054
32711
39283
30665
31837
30842
32711
3183

31588
32374
31057
31537
30779
30770
30936
31044
32374
31588
31588
31588
30810
30638
31588
32711
30697
32711
31018
31588
30774
32374
31837
31588
31837
31586
31588
31015
31594
30813
30638
30817
31018
32711
30770
30817
38911
31588
30703
31015
32711
30795
31588
31588
31588
32711
30817
30936
32374
31684
32711
32711
30769
31018
31588
31018
39958
30842
31588
30810
31588
32374
31588
32353
31837
31837
30810
31588
31579
30665
30774
31015
32711
31837
30898
30842
31837
30936
30643
31837
32711
30778
30796
32711
32711
32711
31588
32711
30842
32711
31588
31588
31837
31586
31588
30789
31837
31837
30638
31015
31044
31110
30732
31837
31588
32374
31588
31588
31837
32374
31018
30936
30619
31492
30719
31588
31588
30775
30903
32711
30780
31837
30817
30781
31588
31588
31018
31837
30862
30666
30670
31588
31837
31837
40135
30619
31837
31588
32711
30810
30842
31837
31015
32711
30862
31588
31540
31015
31018
32711
30817
32374
30770
31837
30655
32359
30795
30936
32374
30729
31588
31837
31561
31837
30816
30770
3077

32138
32047
32632
31319
31463
31978
31371
32632
32182
31891
32182
32182
31947
31319
31508
32182
31978
31891
31891
31978
32756
31319
31891
32243
32761
32182
32243
31883
31978
32822
32822
32822
32761
31891
31948
32243
32047
31463
32243
32822
31371
32182
31319
32755
32714
32182
32182
31891
32010
32822
32714
32182
31463
32243
31978
32243
31371
31978
31344
31978
32182
31463
32182
31891
32714
32182
31508
32182
32176
31948
31891
31463
32755
31891
31978
32047
32632
31891
31319
32714
31371
32182
31463
32182
31978
31978
32182
31891
31947
31891
32243
31463
31319
32243
31891
31463
31463
32182
31468
32179
31319
32632
31371
31947
32755
31319
32182
32761
32182
32047
32714
32182
31371
31978
31371
32822
31978
31371
32822
32761
32761
32714
31319
32243
31463
32811
31948
32182
32182
32047
32632
32761
31463
31891
31891
32243
32822
31948
32243
31371
31891
31978
31891
31319
32182
31891
32755
31978
32761
31856
31468
31891
32761
31371
32182
32243
31978
32632
32761
31344
31319
32632
31463
31465
32632
32243
3276

152020
152047
90963
121452
125986
34543
89425
31096
65024
91066
61209
121566
67635
152020
243463
61193
60604
121539
94083
121361
61323
91064
62008
125986
156467
121539
60332
34543
91058
91085
60479
62008
65024
30721
92511
121473
60399
92149
125986
91073
31527
90793
65024
35481
91080
95505
60593
60571
122970
92489
91085
60307
152046
34543
91074
91037
91058
60604
30883
91075
128034
91084
38153
61130
125986
34543
61598
152011
60604
90336
60604
60600
91690
61195
61577
152047
61191
92685
60597
62008
32338
92489
92184
34543
121539
31168
95505
60536
31078
91874
122531
34543
109988
121566
90971
65024
124112
121566
91081
31723
31268
60562
62008
60596
65024
93943
61540
90880
151928
182516
122539
213009
65024
61577
123957
121566
30897
61369
121566
91080
91085
61540
62230
121557
92489
69220
156467
92487
121274
122539
97612
122825
121554
31527
62008
91085
33152
91085
151979
109943
31347
91001
90863
65024
65024
31096
121566
65024
90742
33855
91663
90880
92487
122531
90717
61364
121511
65024
31527
12

KeyboardInterrupt: 

In [16]:
max_len

30483