#### Import Libraries

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as du
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import os

#### Define Dataset Class

In [25]:
from torch.utils.data import Dataset
import joblib

class JUND_Dataset(Dataset):
    '''Map-style dataset over one joblib shard of the JUND data.

    Items are ``(X[idx], y[idx], w[idx], a[idx])`` tuples of tensors.
    '''

    def __init__(self, data_dir, shard=0):
        '''Load X, y, w and a from ``data_dir`` and convert them to tensors.

        Args:
            data_dir: directory containing ``shard-<shard>-X.joblib``,
                ``-y``, ``-w`` and ``-a`` files; a relative path is resolved
                against the current working directory.
            shard: index of the shard to read (default 0).

        Raises:
            ValueError: if X, y, w and a do not share the same length, which
                ``__getitem__`` relies on.
        '''
        super().__init__()

        self.path = os.path.join('.', data_dir)

        def load(field):
            # NOTE: joblib.load unpickles the file -- only load trusted data.
            filename = os.path.join(self.path, f'shard-{shard}-{field}.joblib')
            return torch.tensor(joblib.load(filename))

        self.X = load('X')
        self.y = load('y')
        self.w = load('w')
        self.a = load('a')
        self._check_lengths()

    def _check_lengths(self):
        '''Raise ValueError unless X, y, w and a have equal first dimension.'''
        lengths = {'X': len(self.X), 'y': len(self.y),
                   'w': len(self.w), 'a': len(self.a)}
        if len(set(lengths.values())) > 1:
            raise ValueError(f'shard arrays have mismatched lengths: {lengths}')

    def __len__(self):
        '''Return the number of samples (the length of X).'''
        return len(self.X)

    def __getitem__(self, idx):
        '''Return the (X, y, w, a) values at index ``idx``.'''
        return self.X[idx], self.y[idx], self.w[idx], self.a[idx]

#### Define Model

In [None]:
class MLP(nn.Module):
    '''Two-layer perceptron: flatten -> Linear -> ReLU -> Linear (logits).

    Args:
        in_dim: size of one flattened input sample. The notebook describes
            its inputs as 101x4, i.e. 404 after flattening -- confirm.
        hidden_dim: width of the hidden layer.
        out_dim: number of output scores.
    '''

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        # Creation order (flatten, fc1, fc2) is preserved so parameter
        # registration order and seeded initialisation stay the same.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        '''Map a batch ``x`` to unnormalised scores of shape (batch, out_dim).

        nn.Flatten keeps the batch dimension and merges the rest. No softmax
        is applied here; a loss such as nn.CrossEntropyLoss does that itself.
        '''
        hidden = F.relu(self.fc1(self.flatten(x)))
        return self.fc2(hidden)

#### Load Data

In [27]:
# Shuffled loader over the training shard; one batch is pulled immediately
# as a sanity check (it holds the collated X, y, w and a values).
d_train = du.DataLoader(
    dataset=JUND_Dataset('train_dataset'),
    shuffle=True,
    batch_size=64,
)
train_features = next(iter(d_train))