#### Import Libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as du
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import os

#### Define Dataset Class

In [2]:
from torch.utils.data import Dataset
import joblib

class JUND_Dataset(Dataset):
    def __init__(self, data_dir):
        '''load X, y, w, a from data_dir'''        
        super(JUND_Dataset, self).__init__()

        # load X, y, w, a from given data_dir
        # convert them into torch tensors
        self.path = os.path.join('.', data_dir)
        self.X = torch.tensor(joblib.load(os.path.join(self.path, 'shard-0-X.joblib')), dtype=torch.float)
        self.y = torch.tensor(joblib.load(os.path.join(self.path, 'shard-0-y.joblib')), dtype=torch.float)
        self.w = torch.tensor(joblib.load(os.path.join(self.path, 'shard-0-w.joblib')), dtype=torch.float)
        self.a = torch.tensor(joblib.load(os.path.join(self.path, 'shard-0-a.joblib')), dtype=torch.float)

    def __len__(self):
        '''return len of dataset'''
        return len(self.X)

    def __getitem__(self, idx):
        '''return X, y, w, and a values at index idx'''
        return self.X[idx], self.y[idx], self.w[idx], self.a[idx]

#### Define Model

In [3]:
class CNN(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, mlp_hidden_dim, out_dim):
        
        super(CNN, self).__init__()
        
        # define the CNN layer
        self.cnn = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)
        
        # dropout layer to regularize
        self.dp1 = nn.Dropout()
        
        # define the MaxPooling Layer
        self.maxpool = nn.MaxPool1d(kernel_size, stride, padding)
        
        # flatten (layer L + 2 * padding - 1) / stride + 1 d vec
        self.flatten = nn.Flatten()
        
        # define the fully connected layer
        # calculate the dim of output of CNN
        in_dim = int((101 + 2 * padding - kernel_size) / stride) + 1
        in_dim *= out_channels
        self.fc1 = nn.Linear(in_dim, mlp_hidden_dim)
        
        # another dropout layer
        self.dp2 = nn.Dropout()
        
        # the last fully connected layer
        self.fc2 = nn.Linear(mlp_hidden_dim + 1, out_dim)
        
    def forward(self, x, a):
        # swap the axis
        x = torch.swapaxes(x, 1, 2)
        
        # feed the data to CNN
        x = self.cnn(x)
        
        # apply relu
        x = F.relu(x)
        
        # add the dropout
        x = self.dp1(x)
        
        # MaxPool layer
        x = self.maxpool(x)
        
        # flatten out the last CNN layer
        x = self.flatten(x)
        
        # apply relu again
        x = F.relu(self.fc1(x))
        
        # add another dropout
        x = self.dp2(x)
        
        # concatenate the accessibility value
        x = torch.cat((a, x), 1)
        
        # compute output layer
        x = self.fc2(x)
        return x

#### Set Up Training

In [4]:
device = f'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f"using device: {device}")
batch_size = 500
learning_rate = 0.05
epochs = 30

out_channels = 100
kernel_size = 5
stride = 1
padding = 2
mlp_hidden_dim = 256
model = CNN(4, out_channels, kernel_size, stride, padding, mlp_hidden_dim, 1)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# load training data in batches
train_loader = du.DataLoader(dataset=JUND_Dataset('train_dataset'), 
                        batch_size=batch_size, 
                        shuffle=True)

# send model over to device
model = model.to(device)
model.train()

using device: cuda:0


CNN(
  (cnn): Conv1d(4, 100, kernel_size=(5,), stride=(1,), padding=(2,))
  (dp1): Dropout(p=0.5, inplace=False)
  (maxpool): MaxPool1d(kernel_size=5, stride=1, padding=2, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=10100, out_features=256, bias=True)
  (dp2): Dropout(p=0.5, inplace=False)
  (fc2): Linear(in_features=257, out_features=1, bias=True)
)

#### Training Loop Over Batches

In [5]:
for epoch in range(1, epochs + 1):    
    sum_loss = 0.
    for batch_idx, (X, y, w, a) in enumerate(train_loader):
        # send batch over to device
        X, y, w, a = X.to(device), y.to(device), w.to(device), a.to(device)
        
        # zero out prev gradients
        optimizer.zero_grad()
        
        # run the forward pass
        output = model(X, a)
        
        # compute loss/error
        loss = F.binary_cross_entropy_with_logits(output, y, weight = w)
        
        # sum up batch losses
        sum_loss += loss.item()
        
        # compute gradients and take a step
        loss.backward()
        optimizer.step()
    
    # average loss per example
    sum_loss /= len(train_loader)
    print(f'Epoch: {epoch}, Loss: {sum_loss:.6f}')

Epoch: 1, Loss: 2.639094
Epoch: 2, Loss: 0.668354
Epoch: 3, Loss: 0.652059
Epoch: 4, Loss: 0.638180
Epoch: 5, Loss: 0.625494
Epoch: 6, Loss: 0.615118
Epoch: 7, Loss: 0.604973
Epoch: 8, Loss: 0.596842
Epoch: 9, Loss: 0.589022
Epoch: 10, Loss: 0.582821
Epoch: 11, Loss: 0.576784
Epoch: 12, Loss: 0.571529
Epoch: 13, Loss: 0.566216
Epoch: 14, Loss: 0.562194
Epoch: 15, Loss: 0.558102
Epoch: 16, Loss: 0.555573
Epoch: 17, Loss: 0.552401
Epoch: 18, Loss: 0.548509
Epoch: 19, Loss: 0.546166
Epoch: 20, Loss: 0.544736
Epoch: 21, Loss: 0.543105
Epoch: 22, Loss: 0.540954
Epoch: 23, Loss: 0.540117
Epoch: 24, Loss: 0.538454
Epoch: 25, Loss: 0.537913
Epoch: 26, Loss: 0.535464
Epoch: 27, Loss: 0.535065
Epoch: 28, Loss: 0.533405
Epoch: 29, Loss: 0.533139
Epoch: 30, Loss: 0.533407


#### Validation

In [6]:
# load the validation data for validation
valid_loader = du.DataLoader(dataset=JUND_Dataset('valid_dataset'), 
                        batch_size=batch_size, 
                        shuffle=True)

# set model in eval mode, since we are no longer training
model.eval()
valid_loss = 0
correct = 0
weight = 0

# turn off gradient computation, will speed up testing
with torch.no_grad():
    for batch_idx, (X, y, w, a) in enumerate(valid_loader):
        # send batches to device
        X, y, w, a = X.to(device), y.to(device), w.to(device), a.to(device)
        
        # compute forward pass and loss
        output = model(X, a)
        loss = F.binary_cross_entropy_with_logits(output, y, weight = w)
        
        # sum up batch loss
        valid_loss += loss.item()
        
        # compute the validation accuracy
        pred = torch.clone(output)
        pred = torch.sigmoid(pred)
        pred[pred < float(0.5)] = 0
        pred[pred != 0] = 1
        correct += torch.sum((pred == y) * w)
        weight += torch.sum(w)
    
    # valid loss per example
    valid_loss /= len(valid_loader.dataset)
    
    # final test accuracy
    valid_acc = correct / weight
    
    print(f'Valid loss: {valid_loss:.6f}, valid accuracy: {valid_acc:.4f}')

Valid loss: 0.001085, valid accuracy: 0.7182


#### Testing

In [7]:
# load the validation data for validation
test_loader = du.DataLoader(dataset=JUND_Dataset('test_dataset'), 
                        batch_size=batch_size, 
                        shuffle=True)

# set model in eval mode, since we are no longer training
model.eval()
test_loss = 0
correct = 0
weight = 0

# turn off gradient computation, will speed up testing
with torch.no_grad():
    for batch_idx, (X, y, w, a) in enumerate(test_loader):
        # send batches to device
        X, y, w, a = X.to(device), y.to(device), w.to(device), a.to(device)
        
        # compute forward pass and loss
        output = model(X, a)
        loss = F.binary_cross_entropy_with_logits(output, y, weight = w)
        
        # sum up batch loss
        test_loss += loss.item()
        
        # compute the validation accuracy
        pred = torch.clone(output)
        pred = torch.sigmoid(pred)
        pred[pred < float(0.5)] = 0
        pred[pred != 0] = 1
        correct += torch.sum((pred == y) * w)
        weight += torch.sum(w)
    
    # valid loss per example
    test_loss /= len(test_loader.dataset)
    
    # final test accuracy
    test_acc = correct / weight
    
    print(f'Test loss: {test_loss:.6f}, test accuracy: {test_acc:.4f}')

Test loss: 0.001071, test accuracy: 0.7407
