In [3]:
# PyTorch
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# For data preprocess
import numpy as np
import csv
import os

# For plotting
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure

myseed = 42069  # fixed seed shared by every RNG below, for reproducible runs

# Trade cuDNN autotuning speed for run-to-run determinism.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Seed NumPy's global RNG, the torch CPU generator, and every CUDA device if present.
np.random.seed(myseed)
torch.manual_seed(myseed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(myseed)

In [4]:
def get_device():
    ''' Return the device string to run on: 'cuda' when CUDA is available, otherwise 'cpu'. '''
    if torch.cuda.is_available():
        return 'cuda'
    return 'cpu'

def plot_learning_curve(loss_record, title=''):
    ''' Plot learning curve of your DNN (train & dev loss).

    Args:
        loss_record: dict with a 'train' sequence (one loss per training step)
            and a 'dev' sequence (losses recorded less often than train).
        title: text inserted into the plot title.

    Dev losses are drawn at evenly spaced training-step positions so both
    curves share the x-axis. When len(train) is a multiple of len(dev) the
    positions are 0, k, 2k, ... with k = len(train) // len(dev).
    '''
    train_loss = loss_record['train']
    dev_loss = loss_record['dev']
    total_steps = len(train_loss)
    n_dev = len(dev_loss)
    # Exactly one x position per dev value. A slice like
    # range(total_steps)[::total_steps // n_dev] yields too many positions when
    # total_steps is not a multiple of n_dev, and fails when n_dev is 0 or
    # exceeds total_steps.
    dev_steps = [i * total_steps // n_dev for i in range(n_dev)]

    figure(figsize=(6, 4))
    plt.plot(range(total_steps), train_loss, c='tab:red', label='train')
    if n_dev:
        plt.plot(dev_steps, dev_loss, c='tab:cyan', label='dev')
    plt.ylim(0.0, 5.)
    plt.xlabel('Training steps')
    plt.ylabel('MSE loss')
    plt.title('Learning curve of {}'.format(title))
    plt.legend()
    plt.show()


def plot_pred(dv_set, model, device, lim=35., preds=None, targets=None):
    ''' Scatter-plot predictions of your DNN against the ground truth.

    Args:
        dv_set: iterable of (x, y) batches, e.g. the dev DataLoader. Only used
            when `preds` or `targets` is None.
        model: torch module applied to each x batch.
        device: device each x batch is moved to before calling `model`.
        lim: upper bound of both axes (the lower bound is fixed at -0.2).
        preds, targets: precomputed arrays; when both are given, `dv_set` and
            `model` are not run.

    The model is switched to eval mode for inference and restored to its
    previous train/eval mode afterwards, so callers' training state is kept.
    '''
    if preds is None or targets is None:
        was_training = model.training
        model.eval()
        pred_batches, target_batches = [], []
        try:
            with torch.no_grad():
                for x, y in dv_set:
                    pred_batches.append(model(x.to(device)).cpu())
                    # Targets are only plotted, so they never need to visit `device`.
                    target_batches.append(y.cpu())
        finally:
            model.train(was_training)
        preds = torch.cat(pred_batches, dim=0).numpy()
        targets = torch.cat(target_batches, dim=0).numpy()

    figure(figsize=(5, 5))
    plt.scatter(targets, preds, c='r', alpha=0.5)
    # y = x reference line: points on it are perfect predictions.
    plt.plot([-0.2, lim], [-0.2, lim], c='b')
    plt.xlim(-0.2, lim)
    plt.ylim(-0.2, lim)
    plt.xlabel('ground truth value')
    plt.ylabel('predicted value')
    plt.title('Ground Truth v.s. Prediction')
    plt.show()

In [24]:
class COVID19Dataset(Dataset):
    ''' Dataset for loading and preprocessing a COVID19 CSV file.

    Args:
        path: CSV file with a header row; the first column is dropped
            (presumably a row id — confirm against the file).
        mode: 'train' or 'dev' — rows are split 9:1 by row index and the last
            column is the regression target; 'test' — features only, no target.
        target_only: if True keep only feature columns 0-39 plus 57 and 75;
            otherwise keep the first 93 columns.

    Items are `(features, target)` in 'train'/'dev' mode and `features` in
    'test' mode.
    '''
    def __init__(self, path, mode = 'train', target_only = False):
        self.mode = mode

        # newline='' is the csv module's recommended way to open input files.
        with open(path, 'r', newline='') as fp:
            rows = list(csv.reader(fp))
        # Drop the header row and the first column.
        data = np.array(rows[1:])[:, 1:].astype(float)

        if not target_only:
            feats = list(range(93))
        else:
            # NOTE(review): 57 and 75 are presumably the tested_positive columns
            # of the earlier days — confirm against the CSV header.
            feats = list(range(40)) + [57, 75]

        if self.mode == 'test':
            self.data = torch.FloatTensor(data[:, feats])
        else:
            # Read the target from the full table BEFORE selecting feature
            # columns; after selection, column -1 would be a feature, not the label.
            target = data[:, -1]
            data = data[:, feats]

            # Every 10th row goes to 'dev', the rest to 'train'.
            if self.mode == 'train':
                indices = [i for i in range(len(data)) if i % 10 != 0]
            else:
                indices = [i for i in range(len(data)) if i % 10 == 0]

            self.data = torch.FloatTensor(data[indices])
            # `target` is 1-D, so it is indexed by row only.
            self.target = torch.FloatTensor(target[indices])

        # Standardize every column from index 40 on (columns 0-39 are left as is).
        # NOTE(review): mean/std come from this split only, so dev/test are not
        # normalized with training statistics — confirm this is intended.
        tail = self.data[:, 40:]
        self.data[:, 40:] = (tail - tail.mean(dim = 0, keepdim = True)) / tail.std(dim = 0, keepdim = True)

        self.dim = self.data.shape[1]

        print('Finished reading the {} set of COVID19 Dataset ({} samples found, each dim = {})'
              .format(mode, len(self.data), self.dim))

    def __getitem__(self, index):
        ''' Return (features, target) for 'train'/'dev', features only otherwise. '''
        if self.mode in ['train', 'dev']:
            return self.data[index], self.target[index]
        else:
            return self.data[index]

    def __len__(self):
        ''' Number of samples in this split. '''
        return len(self.data)

In [25]:
def prep_dataloader(path, mode, batch_size, n_jobs = 0, target_only = False):
    ''' Build a COVID19Dataset for `mode` and wrap it in a DataLoader.

    Args:
        path: CSV file passed to COVID19Dataset.
        mode: 'train', 'dev' or 'test'; only the 'train' loader shuffles.
        batch_size: samples per batch (the last, smaller batch is kept).
        n_jobs: number of DataLoader worker processes.
        target_only: forwarded to COVID19Dataset.

    Pinned host memory only speeds up host-to-GPU copies, so it is enabled
    only when CUDA is available (avoids a warning and wasted page-locked
    memory on CPU-only machines).
    '''
    dataset = COVID19Dataset(path, mode, target_only)
    return DataLoader(
        dataset,
        batch_size,
        shuffle=(mode == 'train'),
        num_workers=n_jobs,
        drop_last=False,
        pin_memory=torch.cuda.is_available(),
    )

In [None]:
class NeuralNet(nn.Module):
    def __init__(self, input_dim):
        super(NeuralNet, self).__init__()
        self.net = nn.Sequential(
            
        )