In [8]:
import pandas as pd
import re
import json
from collections import *


class Node(object):
  def __init__(self, depth, code, descr=None):
    self.depth = depth
    self.descr = descr or code
    self.code = code
    self.parent = None
    self.children = []

  def add_child(self, child):
    if child not in self.children:
      self.children.append(child)

  def search(self, code):
    if code == self.code: return [self]
    ret = []
    for child in self.children:
      ret.extend(child.search(code))
    return ret

  def find(self, code):
    nodes = self.search(code)
    if nodes:
      return nodes[0]
    return None

  @property
  def root(self):
    return self.parents[0]

  @property
  def description(self):
    return self.descr

  @property
  def codes(self):
    return map(lambda n: n.code, self.leaves)

  @property
  def parents(self):
    n = self
    ret = []
    while n:
      ret.append(n)
      n = n.parent
    ret.reverse()
    return ret


  @property
  def leaves(self):
    leaves = set()
    if not self.children:
      return [self]
    for child in self.children:
      leaves.update(child.leaves)
    return list(leaves)

  # return all leaf notes with a depth of @depth
  def leaves_at_depth(self, depth):
    return filter(lambda n: n.depth == depth, self.leaves)

  @property
  def siblings(self):
    parent = self.parent
    if not parent:
      return []
    return list(parent.children)

  def __str__(self):
    return '%s\t%s' % (self.depth, self.code)

  def __hash__(self):
    return hash(str(self))


class ICD9(Node):
  def __init__(self, codesfname):
    # dictionary of depth -> dictionary of code->node
    self.depth2nodes = defaultdict(dict)
    super(ICD9, self).__init__(-1, 'ROOT')

    with open("/content/codes.json", 'r') as f:
      allcodes = json.loads(f.read())
      self.process(allcodes)

  def process(self, allcodes):
    for hierarchy in allcodes:
      self.add(hierarchy)

  def get_node(self, depth, code, descr):
    d = self.depth2nodes[depth]
    if code not in d:
      d[code] = Node(depth, code, descr)
    return d[code]

  def add(self, hierarchy):
    prev_node = self
    for depth, link in enumerate(hierarchy):
      if not link['code']: continue

      code = link['code']
      descr = 'descr' in link and link['descr'] or code
      node = self.get_node(depth, code, descr)
      node.parent = prev_node
      prev_node.add_child(node)
      prev_node = node



class SecondLevelCodes(object):
    def __init__(self,icd_jon):
        tree = ICD9(icd_jon)
        # list of top level codes (e.g., '001-139', ...)
        top_level_nodes = tree.children

        # second level
        self.second_level_codes = []
        for node in top_level_nodes:
            children = node.children
            codes = [node.code for node in children]
            self.second_level_codes.extend(codes)

    def second_level_codes_icd9(self, dxStr):
        code_3digit = convert_to_3digit_icd9(dxStr)
        for code in self.second_level_codes:
            if len(code) > 4:
                codes = code.split('-')
                if codes[0] <= code_3digit <= codes[1]:
                    return code
            elif code == code_3digit:
                return code


def convert_to_3digit_icd9(dxStr):
    if dxStr.startswith('E'):
        if len(dxStr) > 4:
            return dxStr[:4]
        else:
            return dxStr
    else:
        if len(dxStr) > 3:
            return dxStr[:3]
        else:
            return dxStr


def convert_to_high_level_icd9(dxStr):

    k = dxStr[:3]
    if '001' <= k <= '139':
        return 0
    elif '140' <= k <= '239':
        return 1
    elif '240' <= k <= '279':
        return 2
    elif '280' <= k <= '289':
        return 3
    elif '290' <= k <= '319':
        return 4
    elif '320' <= k <= '389':
        return 5
    elif '390' <= k <= '459':
        return 6
    elif '460' <= k <= '519':
        return 7
    elif '520' <= k <= '579':
        return 8
    elif '580' <= k <= '629':
        return 9
    elif '630' <= k <= '679':
        return 10
    elif '680' <= k <= '709':
        return 11
    elif '710' <= k <= '739':
        return 12
    elif '740' <= k <= '759':
        return 13
    elif '760' <= k <= '779':
        return 14
    elif '780' <= k <= '799':
        return 15
    elif '800' <= k <= '999':
        return 16
    elif 'E00' <= k <= 'E99':
        return 17
    elif 'V01' <= k <= 'V90':
        return 18


class ICD_Ontology():
    
    def __init__(self,icd_file, dx_flag):
        self.icd_file = icd_file
        self.dx_flag = dx_flag
        self.df = pd.read_csv(self.icd_file, index_col=0, dtype=object)
        self.rootLevel()



    def rootLevel(self):

        dxs = self.df.ICD9_CODE.tolist()
        dxMaps = dict()
        
        if self.dx_flag:
            
            for dx in dxs:
                dxMaps.setdefault(dx[0:3], 0)

            for k in dxMaps.keys():
                if '001' <= k <= '139':
                    dxMaps[k] = 1
                if '140' <= k <= '239':
                    dxMaps[k] = 2
                if '240' <= k <= '279':
                    dxMaps[k] = 3
                if '280' <= k <= '289':
                    dxMaps[k] = 4
                if '290' <= k <= '319':
                    dxMaps[k] = 5
                if '320' <= k <= '389':
                    dxMaps[k] = 6
                if '390' <= k <= '459':
                    dxMaps[k] = 7
                if '460' <= k <= '519':
                    dxMaps[k] = 8
                if '520' <= k <= '579':
                    dxMaps[k] = 9
                if '580' <= k <= '629':
                    dxMaps[k] = 10
                if '630' <= k <= '679':
                    dxMaps[k] = 11
                if '680' <= k <= '709':
                    dxMaps[k] = 12
                if '710' <= k <= '739':
                    dxMaps[k] = 13
                if '740' <= k <= '759':
                    dxMaps[k] = 14
                if '760' <= k <= '779':
                    dxMaps[k] = 15
                if '780' <= k <= '799':
                    dxMaps[k] = 16
                if '800' <= k <= '999':
                    dxMaps[k] = 17
                if 'E00' <= k <= 'E99':
                    dxMaps[k] = 18
                if 'V01' <= k <= 'V90':
                    dxMaps[k] = 19
            self.rootMaps = dxMaps
            
        else:
            
            for dx in dxs:
                dxMaps.setdefault(dx[0:2], 0)

            for k in dxMaps.keys():
                if k == '00':
                    dxMaps[k] = 1
                if '01' <= k <= '05':
                    dxMaps[k] = 2
                if '06' <= k <= '07':
                    dxMaps[k] = 3
                if '08' <= k <= '16':
                    dxMaps[k] = 4
                if k == '17':
                    dxMaps[k] = 5
                if '18' <= k <= '20':
                    dxMaps[k] = 6
                if '21' <= k <= '29':
                    dxMaps[k] = 7
                if '30' <= k <= '34':
                    dxMaps[k] = 8
                if '35' <= k <= '39':
                    dxMaps[k] = 9
                if '40' <= k <= '41':
                    dxMaps[k] = 10
                if '42' <= k <= '54':
                    dxMaps[k] = 11
                if '55' <= k <= '59':
                    dxMaps[k] = 12
                if '60' <= k <= '64':
                    dxMaps[k] = 13
                if '65' <= k <= '71':
                    dxMaps[k] = 14
                if '72' <= k <= '75':
                    dxMaps[k] = 15
                if '76' <= k <= '84':
                    dxMaps[k] = 16
                if '85' <= k <= '86':
                    dxMaps[k] = 17
                if '87' <= k <= '99':
                    dxMaps[k] = 18
            dxMaps['E'] = 19
            dxMaps['V'] = 20
            self.rootMaps = dxMaps

    def getRootLevel(self,code):
        
        if self.dx_flag:
            root = code[0:3]
        else:
            if code.startswith('E'):
                root = 'E'
            elif code.startswith('V'):
                root = 'V'
            else:
                root = code[0:2]
        return self.rootMaps[root]


class CCS_Ontology(object):
    
    def __init__(self, ccs_file):
        self.ccs_file = ccs_file
        self.rootLevel()
        
    def rootLevel(self):
        # ccs_file = '../data/CCS/SingleDX-edit.txt'
        with open(self.ccs_file) as f:
            content = f.readlines()

        pattern_code = '^\w+'  # match code line in file
        pattern_newline = '^\n'  # match new line '\n'

        prog_code = re.compile(pattern_code)
        prog_newline = re.compile(pattern_newline)

        catIndex = 0
        catMap = dict()  # store index:code list
        codeList = list()
        for line in content:
            
            # if the current line is code line, parse codes to a list and add to existing code list.
            result_code = prog_code.match(line)
            if result_code:
                codes = line.split()
                codeList.extend(codes)

            # if current line is a new line, add new index and corresponding code list to the catMap dict.
            result_newline = prog_newline.match(line)
            if result_newline:
                catMap[catIndex] = codeList
                codeList = list()  # initualize the code list to empty
                catIndex += 1  # next index
                
        code2CatMap = dict()
        for key, value in catMap.items():
            for code in value:
                code2CatMap.setdefault(code, key)

        self.rootMaps = code2CatMap
    
    def getRootLevel(self, code):
        return self.rootMaps[code]


---

In [9]:
import os
import datetime
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from abc import ABCMeta
import collections

In [10]:
# set seed
seed = 24
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

DATA_PATH = "/content/"

In [11]:
import json

visit_threshold = 2

with open(DATA_PATH+"patients_mimic3_full.json") as read_file:
    patients = json.load(read_file)
patients = [patient for patient in patients if len(patient['visits']) >= visit_threshold]

In [12]:

total_visits = 0

dx_only = False
min_freq = 5

In [13]:
all_codes = []  # store all diagnosis codes
all_cpt_codes = []

for patient in patients:
    for visit in patient['visits']:
        total_visits += 1
        dxs = visit['DXs']
        for dx in dxs:
            all_codes.append('D_' + str(dx))
        if not dx_only:
            txs = visit['CPTs']
            all_cpt_codes.extend(txs)
            for tx in txs:
                all_codes.append('T_' + str(tx))

In [14]:
# store all codes and corresponding counts
count_org = []
count_org.extend(collections.Counter(all_codes).most_common())

In [15]:
count = []
words_count = 0

# store filtering codes and counts
for word, c in count_org:
    word_tuple = [word, c]
    if c >= min_freq:
        count.append(word_tuple)
        words_count += c

In [None]:
slc = SecondLevelCodes("/content/codes_pretty_printed.json")
second_level_codes = []
# add padding
dictionary = dict()
code_to_second_level_code_dict=dict()
dictionary_3digit = dict()
dictionary['PAD'] = 0

for word, cnt in count:
    index = len(dictionary)
    dictionary[word] = index
    if word[:2] == 'D_':
        digit = slc.second_level_codes_icd9(word[2:])
        second_level_codes.append(digit)
        code_to_second_level_code_dict[word] = digit

count_second_level_codes = []
count_second_level_codes.extend(collections.Counter(second_level_codes).most_common())
for word, cnt in count_second_level_codes:
    index = len(dictionary_3digit)
    dictionary_3digit[word] = index

In [None]:

max_len_visit = 0
max_visits = 0


for patient in patients:
    visits = patient['visits']
    len_visits = len(visits)
    if len_visits > max_visits:
        max_visits = len_visits
    for visit in visits:
        visit['Drugs'] = []
        dxs = visit['DXs']
        if len(dxs) == 0:
            continue
        else:
            visit['DXs'] = [dictionary['D_' + str(dx)] for dx in dxs if 'D_' + str(dx) in dictionary]
        len_current_visit = len(visit['DXs'])

        if not dx_only:
            txs = visit['CPTs']
            visit['CPTs'] = [dictionary['T_' + str(tx)] for tx in txs if 'T_' + str(tx) in dictionary]
            # len_current_visit = len(visit['DXs']+visit['CPTs'])
        if len_current_visit > max_len_visit:
            max_len_visit = len_current_visit
reverse_dictionary = dict(zip(dictionary.values(), dictionary.keys()))

with open("/content/mimic_dx.json", 'w') as fp:
    json.dump(reverse_dictionary, fp)

# with open(self.patients_codes_file + '.json', 'w') as fp:
#     json.dump(self.patients, fp)


In [18]:
train_context_codes = None
train_labels_1 = None
train_labels_2 = None
train_context_onehots = None
dev_context_codes = None
dev_labels_1 = None
dev_labels_2 = None
dev_context_onehots = None
test_context_codes = None
test_labels_1 = None
test_labels_2 = None
test_context_onehots = None
train_size = 0
train_pids = None
dev_pids = None
test_pids = None
train_intervals = None
dev_intervals = None
test_intervals = None

In [None]:
cfg_valid_visits = 10

if not dx_only:
    max_len_visit += 10
    print(f"max_len_visit:{max_len_visit}")
batches = []
n_zeros = 0

for patient in patients:
    pid = patient['pid']
    # get patient's visits
    visits = patient['visits']
    # sorting visits by admission date
    sorted_visits = sorted(visits, key=lambda visit: visit['admsn_dt'])
    valid_visits = []
    for v in sorted_visits:
        if len(v['DXs']) > 0 and sum(v['DXs']) > 0:
            valid_visits.append(v)

    if (len(valid_visits)) < 2:
        continue

    # number of visits and only use 10 visits to predict last one if number of visits is larger than 11
    no_visits = len(valid_visits)
    last_visit = valid_visits[no_visits - 1]
    second_last_visit = valid_visits[no_visits - 2]

    ls_codes = []
    ls_intervals = []
    # only use 10 visits to predict last one if number of visits is larger than 11
    if no_visits > cfg_valid_visits+1:
        feature_visits = valid_visits[no_visits-(cfg_valid_visits+1):no_visits-1]
    else:
        feature_visits = valid_visits[0:no_visits - 1]

    n_visits = len(feature_visits)
    # if n_visits == 0:
    #     n_zeros += 1
    first_valid_visit_dt = datetime.datetime.strptime(feature_visits[0]['admsn_dt'], "%Y%m%d")
    for i in range(n_visits):
        visit = feature_visits[i]
        codes = visit['DXs']
        if not dx_only:
            length = len(visit['CPTs'])
            if length < 11:
                codes.extend(visit['CPTs'])
            else:
                codes.extend(visit['CPTs'][:10])

        if sum(codes) == 0:
            n_zeros += 1

        current_dt = datetime.datetime.strptime(visit['admsn_dt'], "%Y%m%d")
        interval = (current_dt - first_valid_visit_dt).days + 1
        ls_intervals.append(interval)
        code_size = len(codes)
        # code padding
        if code_size < max_len_visit:
            list_zeros = [0] * (max_len_visit - code_size)
            codes.extend(list_zeros)
        ls_codes.append(codes)

    # visit padding
    if n_visits < cfg_valid_visits:
        for i in range(cfg_valid_visits - n_visits):
            list_zeros = [0] * max_len_visit
            ls_codes.append(list_zeros)
            ls_intervals.append(0)


    last_dt = datetime.datetime.strptime(last_visit['admsn_dt'], "%Y%m%d")
    second_last_dt = datetime.datetime.strptime(second_last_visit['admsn_dt'], "%Y%m%d")
    days = (last_dt - second_last_dt).days
    if days <= 30:
        adm_label = 1
    else:
        adm_label = 0
    # --------- end readmission label --------------------

    # --------- second level category --------------------
    one_hot_labels = np.zeros(len(dictionary_3digit)).astype(int)
    last_codes = last_visit['DXs']
    for code in last_codes:
        code_str = reverse_dictionary[code]
        cat_code = code_to_second_level_code_dict[code_str]
        index = dictionary_3digit[cat_code]
        one_hot_labels[index] = 1
    # --------- end diagnosis label --------------------

    # --------- high level icd9 diagnosis label --------------------
    # one_hot_labels = np.zeros(19).astype(int)
    # last_codes = last_visit['DXs']
    # for code in last_codes:
    #     code_str = self.reverse_dictionary[code]
    #     index = convert_to_high_level_icd9(code_str[2:])
    #     one_hot_labels[index] = 1
    # --------- end diagnosis label --------------------
    batches.append(
        [np.array(ls_codes, dtype=np.int32), one_hot_labels, np.array([adm_label], dtype=np.int32), pid,
            np.array(ls_intervals, dtype=np.int32)])

print('number of non-context ', n_zeros)
codes = []
dx_labels = []
re_labels = []
pids = []
intervals = []
for batch in batches:
    codes.append(batch[0])
    dx_labels.append(batch[1])
    re_labels.append(batch[2])
    pids.append(batch[3])
    intervals.append(batch[4])




---

In [23]:
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    
    def __init__(self, seqs, hfs):
        self.x = seqs
        self.y = hfs
    
    def __len__(self):
        
        """
        TODO: Return the number of samples (i.e. patients).
        """
        
        # your code here
        return len(self.x)
    
    def __getitem__(self, index):
        
        """
        TODO: Generates one sample of data.
        
        Note that you DO NOT need to covert them to tensor as we will do this later.
        """
        
        # your code here
        return self.x[index], self.y[index]

In [24]:

dataset = CustomDataset(codes, np.array(re_labels).reshape(-1,))

In [25]:
def collate_fn(data):
    """
    TODO: Collate the the list of samples into batches. For each patient, you need to pad the diagnosis
        sequences to the sample shape (max # visits, max # diagnosis codes). The padding infomation
        is stored in `mask`.
    
    Arguments:
        data: a list of samples fetched from `CustomDataset`
        
    Outputs:
        x: a tensor of shape (# patiens, max # visits, max # diagnosis codes) of type torch.long
        masks: a tensor of shape (# patiens, max # visits, max # diagnosis codes) of type torch.bool
        rev_x: same as x but in reversed time. This will be used in our RNN model for masking 
        rev_masks: same as mask but in reversed time. This will be used in our RNN model for masking
        y: a tensor of shape (# patiens) of type torch.float
        
    Note that you can obtains the list of diagnosis codes and the list of hf labels
        using: `sequences, labels = zip(*data)`
    """

    sequences, labels = zip(*data)
    sequences = list(sequences)
    y = torch.tensor(labels, dtype=torch.float)
    num_patients = len(y)
    num_visits = [len(patient) for patient in sequences]
    num_codes = [len(visit) for patient in sequences for visit in patient]
    max_num_visits = max(num_visits)
    max_num_codes = max(num_codes)
    for i in range(len(sequences)):
        sequences[i] = sequences[i][~np.all(sequences[i] == 0, axis=1)]
    seqs = []
    for i in sequences:
        seq = []
        for j in i:
            seq.append([mm for mm in j if mm != 0])
        seqs.append(seq)
    sequences = seqs

    x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    rev_x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
    rev_masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)

    for i_patient, patient in enumerate(sequences):
            n_visit = len(patient)
            for j_visit, visit in enumerate(patient):
                pad_visit = F.pad(torch.tensor(visit), pad=(0, max_num_codes-len(visit)), mode='constant', value=0)
                pad_mask = F.pad(torch.full((len(visit),), 1), pad=(0, max_num_codes-len(visit)), mode='constant', value=0)
                x[i_patient, j_visit, : ] = pad_visit
                rev_x[i_patient, n_visit-1-j_visit, : ] = pad_visit
                masks[i_patient, j_visit, : ] = pad_mask
                rev_masks[i_patient, n_visit-1-j_visit, : ] = pad_mask
    return x, masks, rev_x, rev_masks, y

In [26]:
from torch.utils.data.dataset import random_split

split = int(len(dataset)*0.8)

lengths = [split, len(dataset) - split]
train_dataset, val_dataset = random_split(dataset, lengths)


Length of train dataset: 5992
Length of val dataset: 1498


In [27]:
from torch.utils.data import DataLoader

def load_data(train_dataset, val_dataset, collate_fn):
    
    
    batch_size = 32
    # your code here
    train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn)
    
    return train_loader, val_loader


train_loader, val_loader = load_data(train_dataset, val_dataset, collate_fn)

In [28]:
class AlphaAttention(torch.nn.Module):

    def __init__(self, hidden_dim):
        super().__init__()

        
        self.a_att = nn.Linear(hidden_dim, 1)

    def forward(self, g):
        
        # your code here
        alpha = self.a_att(g)
        alpha = F.softmax(alpha, dim=1)
        return alpha

In [30]:
class BetaAttention(torch.nn.Module):

    def __init__(self, hidden_dim):
        super().__init__()
        self.b_att = nn.Linear(hidden_dim, hidden_dim)


    def forward(self, h):
        # your code here
        beta = self.b_att(h)
        beta = F.tanh(beta)
        return beta

In [32]:
def attention_sum(alpha, beta, rev_v, rev_masks):
    # your code here
    mask = torch.any(rev_masks, dim=-1).unsqueeze(-1)
    y = torch.sum(alpha*beta*rev_v*mask, dim=1)
    return y

In [34]:
def sum_embeddings_with_mask(x, masks):
    
    x = x * masks.unsqueeze(-1)
    x = torch.sum(x, dim = -2)
    return x

In [35]:
class RETAIN(nn.Module):
    
    def __init__(self, num_codes, embedding_dim=150):
        super().__init__()
        # Define the embedding layer using `nn.Embedding`. Set `embDimSize` to 128.
        self.embedding = nn.Embedding(num_codes, embedding_dim)
        # Define the RNN-alpha using `nn.GRU()`; Set `hidden_size` to 128. Set `batch_first` to True.
        self.rnn_a = nn.GRU(embedding_dim, embedding_dim, batch_first=True)
        # Define the RNN-beta using `nn.GRU()`; Set `hidden_size` to 128. Set `batch_first` to True.
        self.rnn_b = nn.GRU(embedding_dim, embedding_dim, batch_first=True)
        # Define the alpha-attention using `AlphaAttention()`;
        self.att_a = AlphaAttention(embedding_dim)
        # Define the beta-attention using `BetaAttention()`;
        self.att_b = BetaAttention(embedding_dim)
        # Define the linear layers using `nn.Linear()`;
        self.fc = nn.Linear(embedding_dim, 1)
        # Define the final activation layer using `nn.Sigmoid().
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x, masks, rev_x, rev_masks):
        # 1. Pass the reversed sequence through the embedding layer;
        rev_x = self.embedding(rev_x)
        # 2. Sum the reversed embeddings for each diagnosis code up for a visit of a patient.
        rev_x = sum_embeddings_with_mask(rev_x, rev_masks)
        # 3. Pass the reversed embegginds through the RNN-alpha and RNN-beta layer separately;
        g, _ = self.rnn_a(rev_x)
        h, _ = self.rnn_b(rev_x)
        # 4. Obtain the alpha and beta attentions using `AlphaAttention()` and `BetaAttention()`;
        alpha = self.att_a(g)
        beta = self.att_b(h)
        # 5. Sum the attention up using `attention_sum()`;
        c = attention_sum(alpha, beta, rev_x, rev_masks)
        # 6. Pass the context vector through the linear and activation layers.
        logits = self.fc(c)
        probs = self.sigmoid(logits)
        return probs.squeeze()
    

retain = RETAIN(num_codes = np.array(codes).max())
retain

RETAIN(
  (embedding): Embedding(3184, 150)
  (rnn_a): GRU(150, 150, batch_first=True)
  (rnn_b): GRU(150, 150, batch_first=True)
  (att_a): AlphaAttention(
    (a_att): Linear(in_features=150, out_features=1, bias=True)
  )
  (att_b): BetaAttention(
    (b_att): Linear(in_features=150, out_features=150, bias=True)
  )
  (fc): Linear(in_features=150, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

## 3 Training and Inferencing [10 points]

Then, let us implement the `eval()` function first.

In [36]:
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score, average_precision_score


def eval(model, val_loader):
    
    model.eval()
    y_pred = torch.LongTensor()
    y_score = torch.Tensor()
    y_true = torch.LongTensor()
    model.eval()
    for x, masks, rev_x, rev_masks, y in val_loader:
        y_logit = model(x, masks, rev_x, rev_masks)
        y_hat = y_logit > 0.5
        y_score = torch.cat((y_score,  y_logit.detach().to('cpu')), dim=0)
        y_pred = torch.cat((y_pred,  y_hat.detach().to('cpu')), dim=0)
        y_true = torch.cat((y_true, y.detach().to('cpu')), dim=0)
    
    p, r, f, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')
    pr_auc = average_precision_score(y_true, y_score)
    return p, r, f, pr_auc

Now let us implement the `train()` function. Note that `train()` should call `eval()` at the end of each training epoch to see the results on the validation dataset.

In [37]:
def train(model, train_loader, val_loader, n_epochs):
    """
    Train the model.
    
    Arguments:
        model: the RNN model
        train_loader: training dataloder
        val_loader: validation dataloader
        n_epochs: total number of epochs
    """
    
    for epoch in range(n_epochs):
        model.train()
        train_loss = 0
        for x, masks, rev_x, rev_masks, y in train_loader:
            optimizer.zero_grad()
            y_hat = model(x, masks, rev_x, rev_masks)
            """ 
            TODO: calculate the loss using `criterion`, save the output to loss.
            """
            
            # your code here
            loss = criterion(y_hat,y)
            
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss = train_loss / len(train_loader)
        print('Epoch: {} \t Training Loss: {:.6f}'.format(epoch+1, train_loss))
        p, r, f, roc_auc = eval(model, val_loader)
        print('Epoch: {} \t Validation p: {:.2f}, r:{:.2f}, f: {:.2f}, pr_auc: {:.2f}'.format(epoch+1, p, r, f, roc_auc))
    return round(roc_auc, 2)

In [None]:
# load the model
retain = RETAIN(num_codes = np.array(codes).max()+1)

# load the loss function
criterion = nn.BCELoss()
# load the optimizer
# optimizer = torch.optim.Adam(retain.parameters(), lr=1e-4)
optimizer = torch.optim.RMSprop(retain.parameters(), lr=1e-4)

n_epochs = 10
train(retain, train_loader, val_loader, n_epochs)

In [None]:
# for name, module in retain.named_modules():
#      print(name, sum(param.numel() for param in module.parameters()))

In [None]:
lr_hyperparameter = [1e-1, 1e-3]
embedding_dim_hyperparameter = [8, 128]
n_epochs = 5
results = {}

for lr in lr_hyperparameter:
    for embedding_dim in embedding_dim_hyperparameter:
        print ('='*50)
        print ({'learning rate': lr, "embedding_dim": embedding_dim})
        print ('-'*50)

        # your code here
        retain = RETAIN(num_codes = len(types), embedding_dim=embedding_dim)
        criterion = nn.BCELoss()
        # load the optimizer
        optimizer = torch.optim.Adam(retain.parameters(), lr=lr)
        
        
        
        pr_auc = train(retain, train_loader, val_loader, n_epochs)
        results['lr:{},emb:{}'.format(str(lr), str(embedding_dim))] =  pr_auc