#### Import Libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as du
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import os

#### Define Dataset Class

In [2]:
from torch.utils.data import Dataset
import joblib

class JUND_Dataset(Dataset):
    '''Map-style dataset over one joblib shard of (X, y, w, a) arrays.

    Each array is read from ``<data_dir>/shard-<shard>-<name>.joblib`` and
    converted to a float32 tensor; item ``i`` is ``(X[i], y[i], w[i], a[i])``.

    NOTE(review): from how the notebook uses them these are presumably the
    inputs, binary labels, per-example loss weights and an accessibility
    feature -- confirm against the data source.
    '''

    # file-name suffixes, in the order __getitem__ returns the tensors
    _FIELDS = ('X', 'y', 'w', 'a')

    def __init__(self, data_dir, shard=0):
        '''Load X, y, w, a from ``data_dir``.

        Args:
            data_dir: directory holding the ``shard-*.joblib`` files, resolved
                relative to the current working directory.
            shard: index of the shard to load (default 0, the only shard the
                original implementation read).

        Raises:
            ValueError: if the four arrays do not share the same length.

        Note:
            ``joblib.load`` unpickles its input, which can execute arbitrary
            code -- only point this at trusted files.
        '''
        super().__init__()

        self.path = os.path.join('.', data_dir)
        self.X, self.y, self.w, self.a = (self._load(shard, name) for name in self._FIELDS)

        # __getitem__ indexes every field with the same idx, so a mismatch would
        # otherwise only show up later as an IndexError in the middle of training
        lengths = {name: len(getattr(self, name)) for name in self._FIELDS}
        if len(set(lengths.values())) > 1:
            raise ValueError(f'inconsistent array lengths in {self.path}: {lengths}')

    def _load(self, shard, name):
        '''Read ``shard-<shard>-<name>.joblib`` from ``self.path`` as a float tensor.'''
        filename = f'shard-{shard}-{name}.joblib'
        return torch.tensor(joblib.load(os.path.join(self.path, filename)), dtype=torch.float)

    def __len__(self):
        '''Return the number of examples (length of X).'''
        return len(self.X)

    def __getitem__(self, idx):
        '''Return the (X, y, w, a) values at index ``idx``.'''
        return self.X[idx], self.y[idx], self.w[idx], self.a[idx]

#### Define Model

In [3]:
class MLP(nn.Module):
    '''Two-layer perceptron that injects side features before the output layer.

    ``x`` is flattened and passed through ``fc1`` + ReLU; the side-feature
    tensor ``a`` is concatenated in front of the hidden activations and
    ``fc2`` maps the result to ``out_dim`` logits.
    '''

    def __init__(self, in_dim, hidden_dim, out_dim, acc_dim=1):
        '''
        Args:
            in_dim: size of the flattened input (101*4 = 404 in this notebook).
            hidden_dim: width of the hidden layer.
            out_dim: number of output logits.
            acc_dim: number of side ("accessibility") features concatenated to
                the hidden layer; defaults to 1, the original layout.
        '''
        super().__init__()

        # inputs arrive as (batch, *dims); fc1 needs (batch, in_dim)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        # the output layer sees the hidden activations plus the side features
        self.fc2 = nn.Linear(hidden_dim + acc_dim, out_dim)

    def forward(self, x, a):
        '''Return logits of shape (batch, out_dim).

        Args:
            x: input batch; every dim after dim 0 is flattened.
            a: side features of shape (batch, acc_dim), or a plain (batch,)
               vector when acc_dim == 1.
        '''
        h = F.relu(self.fc1(self.flatten(x)))
        # a (batch,) vector cannot be concatenated along dim 1; lift it to (batch, 1)
        if a.dim() == 1:
            a = a.unsqueeze(1)
        # keep `a` first: fc2's weight columns are laid out as [a, hidden]
        return self.fc2(torch.cat((a, h), dim=1))

#### Set Up Training

In [4]:
# use the first GPU when available, otherwise fall back to CPU
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f"using device: {device}")

# hyperparameters
batch_size = 1000
learning_rate = 0.01
epochs = 5

# seed torch's global RNG so weight init and DataLoader shuffling are reproducible
seed = 0
torch.manual_seed(seed)

# inputs are 101x4 arrays flattened to 404 features, 128-unit hidden layer,
# a single output logit because there are only two classes (0 and 1).
# Move the model to the device *before* constructing the optimizer, as the
# torch.optim docs recommend, so the optimizer tracks the on-device parameters.
model = MLP(101*4, 128, 1).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# load training data in shuffled mini-batches
train_loader = du.DataLoader(dataset=JUND_Dataset('train_dataset'),
                             batch_size=batch_size,
                             shuffle=True)

# switch to training mode; as the last expression it also displays the architecture
model.train()

using device: cuda:0


MLP(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=404, out_features=128, bias=True)
  (fc2): Linear(in_features=129, out_features=1, bias=True)
)

In [6]:
for epoch in range(1, epochs + 1):
    # re-enter training mode in case another cell switched the model to eval()
    model.train()
    sum_loss = 0.
    for X, y, w, a in train_loader:
        # send batch over to device
        X, y, w, a = X.to(device), y.to(device), w.to(device), a.to(device)

        optimizer.zero_grad()
        output = model(X, a)

        # weighted BCE on raw logits; the default reduction='mean' averages over
        # all output elements, i.e. over the batch since output is (batch, 1)
        loss = F.binary_cross_entropy_with_logits(output, y, weight=w)

        # undo the batch mean so a smaller final batch is not over-weighted
        sum_loss += loss.item() * X.size(0)

        # compute gradients and take a step
        loss.backward()
        optimizer.step()

    # weighted loss averaged over every training example
    sum_loss /= len(train_loader.dataset)
    print(f'Epoch: {epoch}, Loss: {sum_loss:.6f}')

Epoch: 1, Loss: 0.696123
Epoch: 2, Loss: 0.688886
Epoch: 3, Loss: 0.678956
Epoch: 4, Loss: 0.666717
Epoch: 5, Loss: 0.649644
