#### Import Libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as du
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import os

#### Define Dataset Class

In [2]:
from torch.utils.data import Dataset
import joblib

class JUND_Dataset(Dataset):
    def __init__(self, data_dir):
        '''load X, y, w, a from data_dir'''        
        super(JUND_Dataset, self).__init__()

        # load X, y, w, a from given data_dir
        # convert them into torch tensors
        self.path = os.path.join('.', data_dir)
        self.X = torch.tensor(joblib.load(os.path.join(self.path, 'shard-0-X.joblib')), dtype=torch.float)
        self.y = torch.tensor(joblib.load(os.path.join(self.path, 'shard-0-y.joblib')), dtype=torch.float)
        self.w = torch.tensor(joblib.load(os.path.join(self.path, 'shard-0-w.joblib')), dtype=torch.float)
        self.a = torch.tensor(joblib.load(os.path.join(self.path, 'shard-0-a.joblib')), dtype=torch.float)

    def __len__(self):
        '''return len of dataset'''
        return len(self.X)

    def __getitem__(self, idx):
        '''return X, y, w, and a values at index idx'''
        return self.X[idx], self.y[idx], self.w[idx], self.a[idx]

#### Define Model

In [3]:
class LSTM(nn.Module):
    def __init__(self, in_dim, lstm_hidden_dim, mlp_hidden_dim, out_dim):
        '''in_dim: input layer dim
           hidden_dim: hidden layer dim
           out_dim: output layer dim'''
        
        super(LSTM, self).__init__()
        
        # define the LSTM layer
        self.lstm = nn.LSTM(input_size=in_dim, hidden_size=lstm_hidden_dim, batch_first=True)
        
        # fully connected layers
        self.fc1 = nn.Linear(lstm_hidden_dim, mlp_hidden_dim)
        
        # dropout layer to regularize
        self.dp = nn.Dropout()
        
        # add one since we have accessibility value
        self.fc2 = nn.Linear(mlp_hidden_dim + 1, out_dim)
        
    def forward(self, x, a):
        # feed the data into the LSTM RNN
        # hn is the last hidden layer: torch.Size([1, 1000, 128])
        x, (hn, cn) = self.lstm(x)

        # use the last hidden layer of LSTM RNN as input of mlp
        x = F.relu(self.fc1(hn[0]))
        
        # add the dropout
        x = self.dp(x)
        
        # concatenate the accessibility value
        x = torch.cat((a, x), 1)
        
        # compute output layer
        x = self.fc2(x)
        return x

#### Set Up Training

In [4]:
device = f'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f"using device: {device}")
batch_size = 500
learning_rate = 0.08
epochs = 20

# set model and optimizer
# DNA bases has one hot vector with 101x4
# use 128d hidden layer
# output is 1d since there is only 0 and 1 classe
lstm_hidden_dim = 128
mlp_hidden_dim = 256
model = LSTM(4, lstm_hidden_dim, mlp_hidden_dim, 1)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# load training data in batches
train_loader = du.DataLoader(dataset=JUND_Dataset('train_dataset'), 
                        batch_size=batch_size, 
                        shuffle=True)

# send model over to device
model = model.to(device)
model.train()

using device: cuda:0


LSTM(
  (lstm): LSTM(4, 128, batch_first=True)
  (fc1): Linear(in_features=128, out_features=256, bias=True)
  (dp): Dropout(p=0.5, inplace=False)
  (fc2): Linear(in_features=257, out_features=1, bias=True)
)

#### Training Loop Over Batches

In [5]:
for epoch in range(1, epochs + 1):    
    sum_loss = 0.
    for batch_idx, (X, y, w, a) in enumerate(train_loader):
        # send batch over to device
        X, y, w, a = X.to(device), y.to(device), w.to(device), a.to(device)
        
        # zero out prev gradients
        optimizer.zero_grad()
        
        # run the forward pass
        output = model(X, a)
        
        # compute loss/error
        loss = F.binary_cross_entropy_with_logits(output, y, weight = w)
        
        # sum up batch losses
        sum_loss += loss.item()
        
        # compute gradients and take a step
        loss.backward()
        optimizer.step()
    
    # average loss per example
    sum_loss /= len(train_loader)
    print(f'Epoch: {epoch}, Loss: {sum_loss:.6f}')

Epoch: 1, Loss: 1.024034
Epoch: 2, Loss: 0.655439
Epoch: 3, Loss: 0.633099
Epoch: 4, Loss: 0.615557
Epoch: 5, Loss: 0.602025
Epoch: 6, Loss: 0.590489
Epoch: 7, Loss: 0.579700
Epoch: 8, Loss: 0.570996
Epoch: 9, Loss: 0.565364
Epoch: 10, Loss: 0.558405
Epoch: 11, Loss: 0.554096
Epoch: 12, Loss: 0.550191
Epoch: 13, Loss: 0.547394
Epoch: 14, Loss: 0.544297
Epoch: 15, Loss: 0.540657
Epoch: 16, Loss: 0.540469
Epoch: 17, Loss: 0.535822
Epoch: 18, Loss: 0.537230
Epoch: 19, Loss: 0.534035
Epoch: 20, Loss: 0.534040


#### Validation

In [6]:
# load the validation data for validation
valid_loader = du.DataLoader(dataset=JUND_Dataset('valid_dataset'), 
                        batch_size=batch_size, 
                        shuffle=True)

# set model in eval mode, since we are no longer training
model.eval()
valid_loss = 0
correct = 0
weight = 0

# turn off gradient computation, will speed up testing
with torch.no_grad():
    for batch_idx, (X, y, w, a) in enumerate(valid_loader):
        # send batches to device
        X, y, w, a = X.to(device), y.to(device), w.to(device), a.to(device)
        
        # compute forward pass and loss
        output = model(X, a)
        loss = F.binary_cross_entropy_with_logits(output, y, weight = w)
        
        # sum up batch loss
        valid_loss += loss.item()
        
        # compute the validation accuracy
        pred = torch.clone(output)
        pred = torch.sigmoid(pred)
        pred[pred < float(0.5)] = 0
        pred[pred != 0] = 1
        correct += torch.sum((pred == y) * w)
        weight += torch.sum(w)
    
    # valid loss per example
    valid_loss /= len(valid_loader.dataset)
    
    # final test accuracy
    valid_acc = correct / weight
    
    print(f'Valid loss: {valid_loss:.6f}, valid accuracy: {valid_acc:.4f}')

Valid loss: 0.001063, valid accuracy: 0.7392


#### Testing

In [7]:
# load the validation data for validation
test_loader = du.DataLoader(dataset=JUND_Dataset('test_dataset'), 
                        batch_size=batch_size, 
                        shuffle=True)

# set model in eval mode, since we are no longer training
model.eval()
test_loss = 0
correct = 0
weight = 0

# turn off gradient computation, will speed up testing
with torch.no_grad():
    for batch_idx, (X, y, w, a) in enumerate(test_loader):
        # send batches to device
        X, y, w, a = X.to(device), y.to(device), w.to(device), a.to(device)
        
        # compute forward pass and loss
        output = model(X, a)
        loss = F.binary_cross_entropy_with_logits(output, y, weight = w)
        
        # sum up batch loss
        test_loss += loss.item()
        
        # compute the validation accuracy
        pred = torch.clone(output)
        pred = torch.sigmoid(pred)
        pred[pred < float(0.5)] = 0
        pred[pred != 0] = 1
        correct += torch.sum((pred == y) * w)
        weight += torch.sum(w)
    
    # valid loss per example
    test_loss /= len(test_loader.dataset)
    
    # final test accuracy
    test_acc = correct / weight
    
    print(f'Test loss: {test_loss:.6f}, test accuracy: {test_acc:.4f}')

Test loss: 0.001097, test accuracy: 0.7336
