In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [3]:
import pandas as pd

In [4]:
import os

In [22]:
batch_size = 4
device = 'cuda:0'

# Conv + LSTM 

In [6]:
class ConvLSTM(nn.Module):
    def __init__(self, n_features, n_hidden, seq_len, n_layers):
        super(ConvLSTM, self).__init__()
        self.n_hidden = n_hidden
        self.seq_len = seq_len
        self.n_layers = n_layers
        self.c1 = nn.Conv1d(in_channels=1, out_channels=1, kernel_size = 2, stride = 1) # Add a 1D CNN layer
        self.lstm = nn.LSTM(
            input_size=n_features,
            hidden_size=n_hidden,
            num_layers=n_layers
        )
        self.linear = nn.Linear(in_features=n_hidden, out_features=1)
    def reset_hidden_state(self):
        self.hidden = (
            torch.zeros(self.n_layers, self.seq_len-1, self.n_hidden),
            torch.zeros(self.n_layers, self.seq_len-1, self.n_hidden)
        )
    def forward(self, sequences):
        sequences = self.c1(sequences.view(len(sequences), 1, -1))
        self.hidden = self.reset_hidden_state()
        lstm_out, self.hidden = self.lstm(
            sequences.view(len(sequences), self.seq_len-1, -1),
            self.hidden
        )
        last_time_step = lstm_out.view(self.seq_len-1, len(sequences), self.n_hidden)[-1]
        y_pred = self.linear(last_time_step)
        return y_pred

# Read Data

In [7]:
#train data
virus_data = []
for idx, file in enumerate(os.listdir('data/training/')):
    array = pd.read_pickle(os.path.join('data/training',file))
    for i in array:
        virus_data.append([idx,i])

#val data

# virus_data_val = []
# for idx, file in enumerate(os.listdir('data/validation/')):
#     array = pd.read_pickle(os.path.join('data/validation',file))
#     for i in array:
#         virus_data_val.append([idx,i])

In [8]:
max_len = max([len(i[1]) for i in virus_data])

In [9]:
for idx,data in enumerate(virus_data):
    virus_data[idx][1].extend([-1]*(max_len - len(virus_data[idx][1])))


# Dataset

In [10]:
class VirusDataset(Dataset):
    def __init__(self,virus_data):
        self.data = [i[1] for i in virus_data]
        self.label = [i[0] for i in virus_data]
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self,idx):
        return torch.Tensor(self.data[idx]),self.label[idx]

In [11]:
train_dataset = VirusDataset(virus_data)
# val_dataset = VirusDataset(virus_data_val)

In [23]:
train_loader = DataLoader(train_dataset, batch_size=batch_size,shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=batch_size,shuffle=True)

# Training

In [13]:
model = ConvLSTM(n_features = 1 , n_hidden = 64, seq_len = max_len, n_layers = 2)

In [14]:
model.to(device)

ConvLSTM(
  (c1): Conv1d(1, 1, kernel_size=(2,), stride=(1,))
  (lstm): LSTM(1, 64, num_layers=2)
  (linear): Linear(in_features=64, out_features=1, bias=True)
)

In [15]:
learning_rate = 0.00001
num_epochs = 20
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [20]:
def train_epoch(model, dataloader, criterion, optimizer):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(dataloader):
        inputs = inputs.to(device)
        targets = targets.to(device).float()  # Convert targets to Float data type

        # Forward pass
        outputs = model(inputs).squeeze()
#         outputs.squeeze()
        loss = criterion(outputs, targets)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        total += targets.size(0)
        correct += ((outputs > 0.5) == targets).sum().item()

    return running_loss / (batch_idx + 1), correct / total

In [24]:
for epoch in range(num_epochs):
    train_loss, train_accuracy = train_epoch(model, train_loader, criterion, optimizer)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {train_loss:.4f}")

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
