In [1]:
import os
import time
import datetime
import numpy as np
import pandas as pd
import sklearn
import torch

FILE_NAME_SUFFIX = ".multi_admission" #"_1000"
DATASET_PATH = "Datasets/"

def read_samples(input_file, file_name_suffix):
    file_name = DATASET_PATH + input_file + file_name_suffix
    return pd.read_csv(file_name, error_bad_lines=False, keep_default_na=False)


In [2]:
patients = read_samples("PATIENTS.csv", FILE_NAME_SUFFIX)
patients['birth_date'] = patients["DOB"].apply(lambda s: (datetime.datetime.strptime(s, '%Y-%m-%d %H:%M:%S') - datetime.datetime(1970,1,1)).days)

admissions = read_samples("ADMISSIONS.csv", FILE_NAME_SUFFIX)
admissions["admit_date"] = admissions["ADMITTIME"].apply(lambda s: (datetime.datetime.strptime(s, '%Y-%m-%d %H:%M:%S') - datetime.datetime(1970,1,1)).days)
admissions["discharge_date"] = admissions["DISCHTIME"].apply(lambda s: (datetime.datetime.strptime(s, '%Y-%m-%d %H:%M:%S') - datetime.datetime(1970,1,1)).days)
admissions['LOS'] = admissions["discharge_date"] - admissions["admit_date"]

diagnoses = read_samples("DIAGNOSES_ICD.csv", FILE_NAME_SUFFIX)
procedures = read_samples("PROCEDURES_ICD.csv", FILE_NAME_SUFFIX)
icu_stays = read_samples("ICUSTAYS.csv", FILE_NAME_SUFFIX)
notes = read_samples("NOTEEVENTS.csv", FILE_NAME_SUFFIX)


  if self.run_code(code, result):


In [3]:
patients = patients.set_index("SUBJECT_ID", drop=False)
patients["num_admissions"] = admissions.groupby("SUBJECT_ID").size().to_frame("num_admissions")
patients = patients[patients.num_admissions > 1]
print("Limit the patients to the ones with more than 1 admission:\n", patients.num_admissions.describe())
print("Prevalence of patients with more than 1 admission: ", len(patients[patients["num_admissions"] > 1]) / len(patients))
print("Prevalence of patients with more than 2 admissions: ", len(patients[patients["num_admissions"] > 2]) / len(patients))
print("Prevalence of patients with more than 3 admissions: ", len(patients[patients["num_admissions"] > 3]) / len(patients))
print("Prevalence of patients with more than 4 admissions: ", len(patients[patients["num_admissions"] > 4]) / len(patients))
print("Prevalence of patients with more than 5 admissions: ", len(patients[patients["num_admissions"] > 5]) / len(patients))

Limit the patients to the ones with more than 1 admission:
 count    7537.000000
mean        2.652647
std         1.621112
min         2.000000
25%         2.000000
50%         2.000000
75%         3.000000
max        42.000000
Name: num_admissions, dtype: float64
Prevalence of patients with more than 1 admission:  1.0
Prevalence of patients with more than 2 admissions:  0.31537747114236436
Prevalence of patients with more than 3 admissions:  0.13732254212551412
Prevalence of patients with more than 4 admissions:  0.06992171951704923
Prevalence of patients with more than 5 admissions:  0.03728273849011543


In [4]:
admissions = admissions[admissions.SUBJECT_ID.isin(patients.SUBJECT_ID)]
procedures = procedures[procedures.SUBJECT_ID.isin(patients.SUBJECT_ID)]
diagnoses = diagnoses[diagnoses.SUBJECT_ID.isin(patients.SUBJECT_ID)]
icu_stays = icu_stays[icu_stays.SUBJECT_ID.isin(patients.SUBJECT_ID)]
notes = notes[notes.SUBJECT_ID.isin(patients.SUBJECT_ID)]

In [5]:
last_admission = admissions[admissions.groupby(['SUBJECT_ID'])['admit_date'].transform(max) == admissions['admit_date']]
previous_admissions = admissions[admissions.groupby(['SUBJECT_ID'])['admit_date'].transform(max) != admissions['admit_date']]
patients["record_start_date"] = previous_admissions.groupby("SUBJECT_ID").admit_date.agg(['min'])
patients["record_end_date"] = previous_admissions.groupby("SUBJECT_ID").discharge_date.agg(['max'])
patients["record_length"] = patients.record_end_date - patients.record_start_date
patients["final_admission_date"] = admissions.groupby("SUBJECT_ID").admit_date.agg(['max'])
patients["final_admission_interval"] = patients.final_admission_date - patients.record_end_date

print("Record length before the final admission\n", patients.record_length.describe())

print("Interval before the final admission\n", patients.final_admission_interval.describe())

print("Prevalence of readmission in 30 days = ", len(patients[patients.final_admission_interval < 30]) / len(patients))


Record length before the final admission
 count    7533.000000
mean      242.149608
std       579.288149
min         0.000000
25%         6.000000
50%        13.000000
75%        84.000000
max      4145.000000
Name: record_length, dtype: float64
Interval before the final admission
 count    7533.000000
mean      445.034382
std       678.907514
min       -19.000000
25%        25.000000
50%       131.000000
75%       570.000000
max      4108.000000
Name: final_admission_interval, dtype: float64
Prevalence of readmission in 30 days =  0.27517579938967757


In [6]:
icu_stays["admit_date"] = icu_stays["INTIME"].apply(lambda s: (datetime.datetime.strptime(s, '%Y-%m-%d %H:%M:%S') - datetime.datetime(1970,1,1)).days)
final_icu_admission = icu_stays[icu_stays.groupby(['SUBJECT_ID'])['admit_date'].transform(max) != icu_stays['admit_date']]
patients["final_icu_admission_date"] = final_icu_admission.groupby("SUBJECT_ID").admit_date.agg(['min'])
patients["final_icu_admission_interval"] = patients.final_icu_admission_date - patients.record_end_date
print("Prevalence of ICU admission in 30 days = ", len(patients[(patients["final_icu_admission_interval"] >= 0) & (patients["final_icu_admission_interval"] < 30)]) / len(patients))


Prevalence of ICU admission in 30 days =  0.006103224094467295


In [7]:
print(procedures.groupby("ICD9_CODE").size().to_frame("procedure_freq_by_icd9_code").describe())
print(procedures.groupby("SUBJECT_ID").size().to_frame("procedure_freq_by_patient").describe())
print(procedures.groupby("HADM_ID").size().to_frame("procedure_freq_by_admission").describe())

print(procedures.groupby("ICD9_CODE").size().to_frame("freq").sort_values("freq", ascending=False).head(20))

       procedure_freq_by_icd9_code
count                  1513.000000
mean                     54.035030
std                     275.114179
min                       1.000000
25%                       1.000000
50%                       4.000000
75%                      19.000000
max                    6505.000000
       procedure_freq_by_patient
count                7364.000000
mean                   11.101983
std                     8.891169
min                     1.000000
25%                     5.000000
50%                     9.000000
75%                    15.000000
max                    98.000000
       procedure_freq_by_admission
count                 17393.000000
mean                      4.700454
std                       3.961372
min                       1.000000
25%                       2.000000
50%                       3.000000
75%                       6.000000
max                      40.000000
           freq
ICD9_CODE      
3893       6505
9604       3440
966      

In [8]:
print(diagnoses.groupby("ICD9_CODE").size().to_frame("diagnosis_freq_by_icd9_code").describe())
print(diagnoses.groupby("SUBJECT_ID").size().to_frame("diagnosis_freq_by_patient").describe())
print(diagnoses.groupby("HADM_ID").size().to_frame("diagnosis_freq_by_admission").describe())

print(diagnoses.groupby("ICD9_CODE").size().to_frame("freq").sort_values("freq", ascending=False).head(20))

       diagnosis_freq_by_icd9_code
count                  4894.000000
mean                     53.192889
std                     258.042880
min                       1.000000
25%                       1.000000
50%                       4.000000
75%                      20.000000
max                    7183.000000
       diagnosis_freq_by_patient
count                7537.000000
mean                   34.539737
std                    28.731059
min                     2.000000
25%                    18.000000
50%                    28.000000
75%                    41.000000
max                   540.000000
       diagnosis_freq_by_admission
count                 19993.000000
mean                     13.020857
std                       6.860812
min                       1.000000
25%                       9.000000
50%                      11.000000
75%                      17.000000
max                      39.000000
           freq
ICD9_CODE      
4019       7183
4280       6588
42731    

In [9]:
notes['text_len'] = notes['TEXT'].apply(lambda s: len(s.split()))
print(notes["text_len"].describe())
print(notes.groupby("SUBJECT_ID").size().to_frame("notes_freq_by_patient").describe())
print(notes.groupby("HADM_ID").size().to_frame("notes_freq_by_admission").describe())

count    739127.000000
mean        280.506953
std         381.865602
min           0.000000
25%          72.000000
50%         160.000000
75%         318.000000
max        7980.000000
Name: text_len, dtype: float64
       notes_freq_by_patient
count            7535.000000
mean               98.092502
std               117.930373
min                 1.000000
25%                34.000000
50%                62.000000
75%               116.000000
max              1420.000000
       notes_freq_by_admission
count             19758.000000
mean                 37.408999
std                 948.520727
min                   1.000000
25%                   8.000000
50%                  15.000000
75%                  31.000000
max              133139.000000


In [10]:
discharge_summaries = notes[notes.CATEGORY == "Discharge summary"]
print(discharge_summaries["text_len"].describe())
print(discharge_summaries.groupby("SUBJECT_ID").size().to_frame("notes_freq_by_patient").describe())
print(discharge_summaries.groupby("HADM_ID").size().to_frame("notes_freq_by_admission").describe())

count    21740.000000
mean      1603.497148
std        884.371393
min          9.000000
25%       1008.000000
50%       1527.000000
75%       2111.000000
max       7980.000000
Name: text_len, dtype: float64
       notes_freq_by_patient
count            7451.000000
mean                2.917729
std                 1.935627
min                 1.000000
25%                 2.000000
50%                 2.000000
75%                 3.000000
max                47.000000
       notes_freq_by_admission
count             19050.000000
mean                  1.141207
std                   0.438223
min                   1.000000
25%                   1.000000
50%                   1.000000
75%                   1.000000
max                   7.000000


In [11]:
patients["death_date"] = patients["DOD"].apply(lambda s: (datetime.datetime.strptime(s, '%Y-%m-%d %H:%M:%S') - datetime.datetime(1970,1,1)).days if s != '' else np.nan)
patients['death_interval'] = patients.death_date - patients.record_end_date
print(patients['death_interval'].describe())
print("Prevalence of death in 30 days = ", len(patients[patients.death_interval < 30]) / len(patients))
print("Prevalence of death = ", len(patients[patients.death_interval >= 0]) / len(patients))

count    3902.000000
mean      663.113275
std       811.485606
min         0.000000
25%        80.250000
50%       316.500000
75%       937.750000
max      4328.000000
Name: death_interval, dtype: float64
Prevalence of death in 30 days =  0.05532705320419265
Prevalence of death =  0.5177126177524214


In [12]:
# Limit the procedures to the most common procedures
NUM_PROCEDURE_CODES = 1024
top_procedures = procedures.groupby("ICD9_CODE").size().to_frame("freq").sort_values("freq", ascending=False).head(NUM_PROCEDURE_CODES).index.tolist()
procedures = procedures[procedures.ICD9_CODE.isin(top_procedures)]
print(procedures.groupby("ICD9_CODE").size().to_frame("procedure_freq_by_icd9_code").describe())
print(procedures.groupby("SUBJECT_ID").size().to_frame("procedure_freq_by_patient").describe())
print(procedures.groupby("HADM_ID").size().to_frame("procedure_freq_by_admission").describe())

       procedure_freq_by_icd9_code
count                  1024.000000
mean                     79.286133
std                     331.499928
min                       2.000000
25%                       4.000000
50%                      10.500000
75%                      38.000000
max                    6505.000000
       procedure_freq_by_patient
count                7361.000000
mean                   11.029616
std                     8.854558
min                     1.000000
25%                     5.000000
50%                     9.000000
75%                    15.000000
max                    98.000000
       procedure_freq_by_admission
count                 17355.000000
mean                      4.678133
std                       3.937743
min                       1.000000
25%                       2.000000
50%                       3.000000
75%                       6.000000
max                      39.000000


In [13]:
# Limit the diagnoses to the most common diagnoses
NUM_DIAGNOSIS_CODES = 4096
top_diagnoses = diagnoses.groupby("ICD9_CODE").size().to_frame("freq").sort_values("freq", ascending=False).head(NUM_DIAGNOSIS_CODES).index.tolist()
diagnoses = diagnoses[diagnoses.ICD9_CODE.isin(top_diagnoses)]
print(diagnoses.groupby("ICD9_CODE").size().to_frame("diagnosis_freq_by_icd9_code").describe())
print(diagnoses.groupby("SUBJECT_ID").size().to_frame("diagnosis_freq_by_patient").describe())
print(diagnoses.groupby("HADM_ID").size().to_frame("diagnosis_freq_by_admission").describe())

       diagnosis_freq_by_icd9_code
count                  4096.000000
mean                     63.361328
std                     280.940588
min                       1.000000
25%                       2.000000
50%                       7.000000
75%                      28.000000
max                    7183.000000
       diagnosis_freq_by_patient
count                7537.000000
mean                   34.433860
std                    28.696845
min                     2.000000
25%                    18.000000
50%                    27.000000
75%                    41.000000
max                   539.000000
       diagnosis_freq_by_admission
count                 19990.000000
mean                     12.982891
std                       6.851016
min                       1.000000
25%                       9.000000
50%                      11.000000
75%                      17.000000
max                      39.000000


In [14]:
from torch.utils.data import Dataset

other_admission_info_dim = 2 # age and LOS

class CustomDataset(Dataset):
    
    def __init__(self, patients, admissions, procedures, top_procedures, diagnoses, top_diagnoses, prediction_window):
        top_procedures_dict = dict(zip(top_procedures, range(len(top_procedures))))
        top_diagnoses_dict = dict(zip(top_diagnoses, range(len(top_procedures), len(top_procedures) + len(top_diagnoses))))
        self.x = []
        self.y = []
        for _, patient in patients.iterrows():
            patient_admissions = []
            for _, admission in admissions[admissions.SUBJECT_ID == patient.SUBJECT_ID].iterrows():
                icd9_codes = []
                for _, admission_procedure in procedures[procedures.HADM_ID == admission.HADM_ID].iterrows():
                    icd9_codes.append(top_procedures_dict[admission_procedure.ICD9_CODE])
                for _, admission_diagnosis in diagnoses[diagnoses.HADM_ID == admission.HADM_ID].iterrows():
                    icd9_codes.append(top_diagnoses_dict[admission_diagnosis.ICD9_CODE])
                other_info = [(admission.admit_date - patient.birth_date) / 36500.0, admission.LOS / 100.0]
                patient_admissions.append((other_info, icd9_codes))
            self.x.append(patient_admissions)
            self.y.append([patient.final_admission_interval < prediction_window, patient.death_interval < prediction_window])
            # self.y.append(patient.final_admission_interval < prediction_window)
    
    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, index):
        return self.x[index], self.y[index]
        

dataset = CustomDataset(patients, previous_admissions, procedures, top_procedures, diagnoses, top_diagnoses, 30)
print(len(dataset))
# for i in range(len(dataset)):
#     x, y = dataset[i]
#     print(y)

7537


In [15]:
def collate_fn(data):
    """
    TODO: Collate the the list of samples into batches. For each patient, you need to pad the diagnosis
        sequences to the sample shape (max # visits, encoding size). The padding infomation
        is stored in `mask`.
    
    Arguments:
        data: a list of samples fetched from `CustomDataset`
        
    Outputs:
        x: a tensor of shape (# patiens, max # visits, encoding size) of type torch.long
        masks: a tensor of shape (# patiens, max # visits, encoding size) of type torch.bool
        y: a tensor of shape (# patiens) of type torch.float
    """

    sequences, labels = zip(*data)
    num_patients = len(sequences)
    max_admissions = 0
    max_icd9_codes = 0
    for sequence in sequences:
        max_admissions = max(max_admissions, len(sequence))
        for admission in sequence:
            (other_info, icd9_codes) = admission
            max_icd9_codes = max(max_icd9_codes, len(icd9_codes))
    
    dim = (num_patients, max_admissions, max_icd9_codes)
    x_data = np.zeros(dim)
    x_other_data = np.zeros((num_patients, max_admissions, other_admission_info_dim))
    masks_data = np.full(dim, False)
    
    for i, sequence in enumerate(sequences):
        num_admissions = len(sequence)
        for j, admission in enumerate(sequence):
            (other_info, icd9_codes) = admission
            for k, icd9_code in enumerate(icd9_codes):
                x_data[i][j][k] = icd9_code
                masks_data[i][j][k] = True
            for k, other_admission_info in enumerate(other_info):
                x_other_data[i][j][k] = other_admission_info
       
    x = torch.tensor(x_data, dtype=torch.long)
    x_other = torch.tensor(x_other_data, dtype=torch.float)
    masks = torch.tensor(masks_data, dtype=torch.bool)
    y = torch.tensor(labels, dtype=torch.float)
    
    return x, x_other, masks, y

In [16]:
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=10, collate_fn=collate_fn)
loader_iter = iter(loader)
x, x_other, masks, y = next(loader_iter)
print(x.shape, x_other.shape, masks.shape, y.shape)

torch.Size([10, 5, 39]) torch.Size([10, 5, 2]) torch.Size([10, 5, 39]) torch.Size([10, 2])


In [17]:
from torch.utils.data.dataset import random_split

split = int(len(dataset)*0.75)

lengths = [split, len(dataset) - split]
train_dataset, val_dataset = random_split(dataset, lengths)

print("Length of train dataset:", len(train_dataset))
print("Length of val dataset:", len(val_dataset))

Length of train dataset: 5652
Length of val dataset: 1885


In [18]:
from torch.utils.data import DataLoader

def load_data(train_dataset, val_dataset, collate_fn):
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)
    
    return train_loader, val_loader


train_loader, val_loader = load_data(train_dataset, val_dataset, collate_fn)
combined_loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

In [19]:
def sum_embeddings_with_mask(x, masks):
    (batch_size, visits, diags, embedding_dim) = x.shape
    masks = masks.unsqueeze(3).expand(batch_size, visits, diags, embedding_dim)
    output = torch.sum(x * masks, dim=2)
    return output

In [20]:
def get_last_visit(hidden_states, masks):
    batch_size, visits, embedding_dim = hidden_states.shape
    masks = torch.sum(masks, 2)
    masks = torch.min(masks, torch.ones_like(masks))
    masks = torch.sum(masks, 1)
    masks = masks - torch.ones_like(masks)
    masks = masks.unsqueeze(1).expand(batch_size, embedding_dim).unsqueeze(1)
    masks = torch.max(masks, torch.zeros_like(masks)) # FIXME: data cleaning problem! some patients have no admission
    last_visit = torch.gather(hidden_states, 1, masks)
    last_visit = torch.flatten(last_visit, 1, 2)
    return last_visit

In [29]:
class NaiveRNN(torch.nn.Module):
    def __init__(self, num_embeddings, embedding_size, other_admission_info_dim, hidden_state_size, output_size):
        super().__init__()
        self.hidden_state_size = hidden_state_size
        self.embedding = torch.nn.Embedding(num_embeddings, embedding_size)
        self.rnn = torch.nn.GRU(embedding_size + other_admission_info_dim, hidden_state_size, batch_first=True)
        self.linear1 = torch.nn.Linear(hidden_state_size, hidden_state_size)
        self.activation1 = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(p=0.5)
        self.linear2 = torch.nn.Linear(hidden_state_size, output_size)
        self.activation2 = torch.nn.Sigmoid()
    
    
    def forward(self, x, x_other, masks):
        (batch_size, num_admissions, _) = x.shape
        hidden_state = torch.zeros(1, batch_size, self.hidden_state_size)
        hidden_states = []
        embeddings = self.embedding(x)
        sum_embeddings = sum_embeddings_with_mask(embeddings, masks)
        combined_admission_info = torch.cat((sum_embeddings, x_other), 2)
        output, _ = self.rnn(combined_admission_info)
        output = get_last_visit(output, masks)
        output = self.activation1(self.linear1(output))
        output = self.dropout(output)
        output = self.activation2(self.linear2(output))
        return output.squeeze()
    

# load the model here
naive_rnn = NaiveRNN(num_embeddings=len(top_procedures)+len(top_diagnoses), embedding_size = 128, other_admission_info_dim=other_admission_info_dim, hidden_state_size=64, output_size=2)
naive_rnn

NaiveRNN(
  (embedding): Embedding(5120, 128)
  (rnn): GRU(130, 64, batch_first=True)
  (linear1): Linear(in_features=64, out_features=64, bias=True)
  (activation1): ReLU()
  (dropout): Dropout(p=0.5, inplace=False)
  (linear2): Linear(in_features=64, out_features=2, bias=True)
  (activation2): Sigmoid()
)

In [30]:
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(naive_rnn.parameters(), lr=0.001)

In [31]:
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score


def eval_model(model, val_loader):
    model.eval()
    Y1_pred = []
    Y1_true = []
    Y2_pred = []
    Y2_true = []
    for x, x_other, masks, y in val_loader:
        with torch.no_grad():
            pred = model(x, x_other, masks)
            # Y1_true.extend(y.detach().numpy().tolist())
            # Y1_pred.extend(pred.detach().numpy().reshape(-1).tolist())
            Y1_true.extend(y.detach().numpy()[:,0].tolist())
            Y1_pred.extend(pred.detach().numpy()[:,0].reshape(-1).tolist())
            Y2_true.extend(y.detach().numpy()[:,1].tolist())
            Y2_pred.extend(pred.detach().numpy()[:,1].reshape(-1).tolist())
    
    # print(len(Y1_true), len(Y1_pred), len(Y2_true), len(Y2_pred))
    precision1, recall1, f11, _ = precision_recall_fscore_support(Y1_true, np.array(Y1_pred)>0.5, average='binary')
    roc_auc1 = roc_auc_score(Y1_true, Y1_pred)
    precision2, recall2, f12, _ = precision_recall_fscore_support(Y2_true, np.array(Y2_pred)>0.5, average='binary')
    roc_auc2 = roc_auc_score(Y2_true, Y2_pred)
    
    return precision1, recall1, f11, roc_auc1, precision2, recall2, f12, roc_auc2, len(Y1_pred), sum(Y1_true), sum(Y2_true)

In [32]:
# precision1, recall1, f11, roc_auc1 = eval_model(naive_rnn, val_loader)
# print('Task1: P={:.3f} R={:.3f} F1={:.3f} ROC AUC={:.3f}'.format(precision1, recall1, f11, roc_auc1))

precision1, recall1, f11, roc_auc1, precision2, recall2, f12, roc_auc2, n, p1, p2 = eval_model(naive_rnn, val_loader)
print('Task1: N={} Prevalence={:.3f} P={:.3f} R={:.3f} F1={:.3f} ROC AUC={:.3f}'.format(n, p1/n, precision1, recall1, f11, roc_auc1))
print('Task2: N={} Prevalence={:.3f} P={:.3f} R={:.3f} F1={:.3f} ROC AUC={:.3f}'.format(n, p2/n, precision2, recall2, f12, roc_auc2))

Task1: N=1885 Prevalence=0.266 P=0.262 R=0.853 F1=0.401 ROC AUC=0.501
Task2: N=1885 Prevalence=0.055 P=0.056 R=0.883 F1=0.106 ROC AUC=0.516


In [33]:
def train(model, train_loader, val_loader, combined_loader, n_epochs):
    model.train()
    for epoch in range(n_epochs):
        train_loss = 0
        for x, x_other, masks, y in train_loader:
            optimizer.zero_grad()
            y_pred = model.forward(x, x_other, masks)
            # loss = criterion(torch.flatten(y_pred), torch.flatten(y))
            all_linear1_params = torch.cat([x.view(-1) for x in model.linear1.parameters()])
            all_linear2_params = torch.cat([x.view(-1) for x in model.linear2.parameters()])
            l1_regularization = 0.01 * torch.norm(all_linear1_params, 1) + 0.01 * torch.norm(all_linear2_params, 1)
            # l2_regularization = 0.01 * torch.norm(all_linear2_params, 2)
            loss = criterion(torch.flatten(y_pred), torch.flatten(y)) + l1_regularization
            loss.backward()
            optimizer.step()
            train_loss+=loss.item()
        val_loss = 0
        for x, x_other, masks, y in val_loader:
            optimizer.zero_grad()
            y_pred = model.forward(x, x_other, masks)
            loss = criterion(torch.flatten(y_pred), torch.flatten(y))
            loss.backward()
            val_loss+=loss.item()
        print('Epoch {}: training loss = {}  validation loss = {}'.format(epoch, train_loss, val_loss))
        precision1, recall1, f11, roc_auc1, precision2, recall2, f12, roc_auc2, n, p1, p2 = eval_model(model, val_loader)
        print('  Validation Set N={}\n\tTask1: P={:.3f} R={:.3f} F1={:.3f} ROC AUC={:.3f} Prevalence={:.3f} \n\tTask2: P={:.3f} R={:.3f} F1={:.3f} ROC AUC={:.3f} Prevalence={:.3f}'.format(n, precision1, recall1, f11, roc_auc1, p1/n, precision2, recall2, f12, roc_auc2, p2/n))
        precision1, recall1, f11, roc_auc1, precision2, recall2, f12, roc_auc2, n, p1, p2 = eval_model(model, train_loader)
        print('  Training Set N={}\n\tTask1: P={:.3f} R={:.3f} F1={:.3f} ROC AUC={:.3f} Prevalence={:.3f} \n\tTask2: P={:.3f} R={:.3f} F1={:.3f} ROC AUC={:.3f} Prevalence={:.3f}'.format(n, precision1, recall1, f11, roc_auc1, p1/n, precision2, recall2, f12, roc_auc2, p2/n))
        # precision1, recall1, f11, roc_auc1, precision2, recall2, f12, roc_auc2, n, p1, p2 = eval_model(model, combined_loader)
        # print('  Combined Set N={}\n\tTask1: P={:.3f} R={:.3f} F1={:.3f} ROC AUC={:.3f} Prevalence={:.3f} \n\tTask2: P={:.3f} R={:.3f} F1={:.3f} ROC AUC={:.3f} Prevalence={:.3f}'.format(n, precision1, recall1, f11, roc_auc1, p1/n, precision2, recall2, f12, roc_auc2, p2/n))
        
# number of epochs to train the model
n_epochs = 50
train(naive_rnn, train_loader, val_loader, combined_loader, n_epochs)

Epoch 0: training loss = 230.64357247948647  validation loss = 26.269061714410782


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.000 R=0.000 F1=0.000 ROC AUC=0.501 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.434 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.000 R=0.000 F1=0.000 ROC AUC=0.542 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.511 Prevalence=0.056
Epoch 1: training loss = 95.10192468762398  validation loss = 24.014312148094177


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.000 R=0.000 F1=0.000 ROC AUC=0.506 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.434 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.000 R=0.000 F1=0.000 ROC AUC=0.568 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.559 Prevalence=0.056
Epoch 2: training loss = 88.61974358558655  validation loss = 23.77272340655327


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.000 R=0.000 F1=0.000 ROC AUC=0.516 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.441 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.000 R=0.000 F1=0.000 ROC AUC=0.590 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.600 Prevalence=0.056
Epoch 3: training loss = 85.26673913002014  validation loss = 23.76401622593403


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.000 R=0.000 F1=0.000 ROC AUC=0.525 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.441 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.000 R=0.000 F1=0.000 ROC AUC=0.611 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.633 Prevalence=0.056
Epoch 4: training loss = 82.75725424289703  validation loss = 23.572050377726555


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.000 R=0.000 F1=0.000 ROC AUC=0.533 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.476 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.000 R=0.000 F1=0.000 ROC AUC=0.632 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.668 Prevalence=0.056
Epoch 5: training loss = 80.2565279006958  validation loss = 23.582590341567993


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.333 R=0.002 F1=0.004 ROC AUC=0.537 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.472 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=1.000 R=0.014 F1=0.028 ROC AUC=0.652 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.693 Prevalence=0.056
Epoch 6: training loss = 77.87889194488525  validation loss = 23.510186448693275


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.500 R=0.004 F1=0.008 ROC AUC=0.546 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.491 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=1.000 R=0.017 F1=0.033 ROC AUC=0.670 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.726 Prevalence=0.056
Epoch 7: training loss = 75.88044986128807  validation loss = 23.515614107251167


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.444 R=0.008 F1=0.016 ROC AUC=0.547 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.506 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=1.000 R=0.029 F1=0.057 ROC AUC=0.692 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.748 Prevalence=0.056
Epoch 8: training loss = 74.27484956383705  validation loss = 23.57992395758629


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.333 R=0.008 F1=0.016 ROC AUC=0.553 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.531 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=1.000 R=0.041 F1=0.078 ROC AUC=0.716 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.779 Prevalence=0.056
Epoch 9: training loss = 72.7012457549572  validation loss = 23.536003917455673


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.312 R=0.010 F1=0.019 ROC AUC=0.568 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.558 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=1.000 R=0.055 F1=0.105 ROC AUC=0.743 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.805 Prevalence=0.056
Epoch 10: training loss = 71.19003573060036  validation loss = 23.568525731563568


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.316 R=0.012 F1=0.023 ROC AUC=0.566 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.567 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=1.000 R=0.064 F1=0.121 ROC AUC=0.769 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.831 Prevalence=0.056
Epoch 11: training loss = 69.71780171990395  validation loss = 23.62417048215866


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.333 R=0.016 F1=0.030 ROC AUC=0.564 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.570 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=1.000 R=0.080 F1=0.148 ROC AUC=0.801 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.858 Prevalence=0.056
Epoch 12: training loss = 68.34005512297153  validation loss = 23.566364377737045


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.250 R=0.012 F1=0.023 ROC AUC=0.566 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.596 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.993 R=0.095 F1=0.174 ROC AUC=0.825 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.879 Prevalence=0.056
Epoch 13: training loss = 66.74182112514973  validation loss = 23.80485664308071


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.312 R=0.020 F1=0.037 ROC AUC=0.571 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.575 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.994 R=0.104 F1=0.189 ROC AUC=0.848 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.900 Prevalence=0.056
Epoch 14: training loss = 65.24767562747002  validation loss = 23.90582524240017


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.373 R=0.038 F1=0.069 ROC AUC=0.578 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.579 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.991 R=0.133 F1=0.234 ROC AUC=0.863 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.910 Prevalence=0.056
Epoch 15: training loss = 63.42444898188114  validation loss = 24.000963270664215


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.316 R=0.036 F1=0.064 ROC AUC=0.584 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.594 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.992 R=0.159 F1=0.274 ROC AUC=0.877 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.919 Prevalence=0.056
Epoch 16: training loss = 61.785952657461166  validation loss = 24.67009499669075


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.386 R=0.044 F1=0.079 ROC AUC=0.585 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.561 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.992 R=0.168 F1=0.287 ROC AUC=0.892 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.930 Prevalence=0.056
Epoch 17: training loss = 60.17084077000618  validation loss = 24.812362283468246


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.337 R=0.064 F1=0.107 ROC AUC=0.581 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.570 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.986 R=0.270 F1=0.424 ROC AUC=0.901 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.938 Prevalence=0.056
Epoch 18: training loss = 58.59080372750759  validation loss = 24.99785514175892


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.394 R=0.082 F1=0.135 ROC AUC=0.583 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.573 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.995 R=0.351 F1=0.518 ROC AUC=0.909 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.941 Prevalence=0.056
Epoch 19: training loss = 57.14198195934296  validation loss = 25.55365665256977


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.402 R=0.106 F1=0.167 ROC AUC=0.581 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.559 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.995 R=0.406 F1=0.577 ROC AUC=0.915 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.944 Prevalence=0.056
Epoch 20: training loss = 55.87356086075306  validation loss = 25.78784702718258


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.451 R=0.145 F1=0.220 ROC AUC=0.586 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.559 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.986 R=0.523 F1=0.683 ROC AUC=0.921 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.949 Prevalence=0.056
Epoch 21: training loss = 54.83954468369484  validation loss = 26.563229754567146


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.421 R=0.127 F1=0.196 ROC AUC=0.582 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.570 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.997 R=0.490 F1=0.658 ROC AUC=0.928 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.952 Prevalence=0.056
Epoch 22: training loss = 53.562115862965584  validation loss = 26.640188366174698


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.394 R=0.137 F1=0.204 ROC AUC=0.578 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.561 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.984 R=0.598 F1=0.744 ROC AUC=0.934 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.954 Prevalence=0.056
Epoch 23: training loss = 52.5171223282814  validation loss = 26.844998836517334


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.406 R=0.159 F1=0.229 ROC AUC=0.580 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.557 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.978 R=0.667 F1=0.793 ROC AUC=0.939 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.956 Prevalence=0.056
Epoch 24: training loss = 51.562500461936  validation loss = 26.77024607360363


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.395 R=0.213 F1=0.277 ROC AUC=0.585 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.579 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.977 R=0.742 F1=0.843 ROC AUC=0.946 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.958 Prevalence=0.056
Epoch 25: training loss = 50.66047412157059  validation loss = 27.640842631459236


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.408 R=0.177 F1=0.247 ROC AUC=0.579 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.554 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.979 R=0.728 F1=0.835 ROC AUC=0.946 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.960 Prevalence=0.056
Epoch 26: training loss = 49.85533806681633  validation loss = 27.379454597830772


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.413 R=0.217 F1=0.285 ROC AUC=0.582 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.579 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.980 R=0.751 F1=0.851 ROC AUC=0.950 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.963 Prevalence=0.056
Epoch 27: training loss = 49.16287752985954  validation loss = 27.834381073713303


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.395 R=0.201 F1=0.266 ROC AUC=0.577 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.577 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.982 R=0.747 F1=0.848 ROC AUC=0.954 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.966 Prevalence=0.056
Epoch 28: training loss = 48.47843787074089  validation loss = 28.28089453279972


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.387 R=0.217 F1=0.278 ROC AUC=0.576 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.565 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.981 R=0.773 F1=0.865 ROC AUC=0.955 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.966 Prevalence=0.056
Epoch 29: training loss = 47.85174612700939  validation loss = 28.139240324497223


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.407 R=0.217 F1=0.283 ROC AUC=0.578 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.575 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.982 R=0.779 F1=0.869 ROC AUC=0.960 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.968 Prevalence=0.056
Epoch 30: training loss = 47.32309092581272  validation loss = 28.69754333794117
  Validation Set N=1885
	Task1: P=0.377 R=0.217 F1=0.276 ROC AUC=0.581 Prevalence=0.266 
	Task2: P=0.096 R=0.049 F1=0.065 ROC AUC=0.582 Prevalence=0.055
  Training Set N=5652
	Task1: P=0.983 R=0.785 F1=0.873 ROC AUC=0.957 Prevalence=0.278 
	Task2: P=0.864 R=0.768 F1=0.813 ROC AUC=0.967 Prevalence=0.056
Epoch 31: training loss = 46.793889075517654  validation loss = 28.80153077840805
  Validation Set N=1885
	Task1: P=0.394 R=0.227 F1=0.288 ROC AUC=0.572 Prevalence=0.266 
	Task2: P=0.106 R=0.049 F1=0.067 ROC AUC=0.567 Prevalence=0.055
  Training Set N=5652
	Task1: P=0.985 R=0.798 F1=0.882 ROC AUC=0.967 Prevalence=0.278 
	Task2: P=0.863 R=0.764 F1=0.811 ROC AUC=0.969 Prevalence=0.056
Epoch 32: training loss = 46.2990747243166  validation loss = 29.3097938597202

  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.381 R=0.213 F1=0.273 ROC AUC=0.572 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.563 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.986 R=0.797 F1=0.881 ROC AUC=0.969 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.971 Prevalence=0.056
Epoch 33: training loss = 45.77777822315693  validation loss = 29.58971095085144


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.369 R=0.215 F1=0.272 ROC AUC=0.573 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.571 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.987 R=0.814 F1=0.892 ROC AUC=0.971 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.971 Prevalence=0.056
Epoch 34: training loss = 45.273187443614006  validation loss = 29.610144302248955


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.370 R=0.227 F1=0.281 ROC AUC=0.570 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.579 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.983 R=0.833 F1=0.902 ROC AUC=0.976 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.973 Prevalence=0.056
Epoch 35: training loss = 44.909795597195625  validation loss = 29.93149395287037


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.365 R=0.245 F1=0.293 ROC AUC=0.565 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.562 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.982 R=0.842 F1=0.906 ROC AUC=0.978 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.973 Prevalence=0.056
Epoch 36: training loss = 44.368709683418274  validation loss = 30.010183334350586


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.360 R=0.235 F1=0.284 ROC AUC=0.567 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.573 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.985 R=0.846 F1=0.910 ROC AUC=0.981 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.975 Prevalence=0.056
Epoch 37: training loss = 43.995497331023216  validation loss = 30.69420798122883


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.361 R=0.245 F1=0.292 ROC AUC=0.562 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.558 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.987 R=0.849 F1=0.913 ROC AUC=0.982 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.974 Prevalence=0.056
Epoch 38: training loss = 43.443069353699684  validation loss = 30.692401811480522


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.359 R=0.253 F1=0.297 ROC AUC=0.562 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.566 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.987 R=0.853 F1=0.915 ROC AUC=0.984 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.976 Prevalence=0.056
Epoch 39: training loss = 42.97400060296059  validation loss = 30.99540400505066


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.366 R=0.217 F1=0.273 ROC AUC=0.558 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.544 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.989 R=0.838 F1=0.907 ROC AUC=0.987 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.979 Prevalence=0.056
Epoch 40: training loss = 42.62888444960117  validation loss = 31.640383511781693


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.369 R=0.229 F1=0.283 ROC AUC=0.556 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.553 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.991 R=0.848 F1=0.914 ROC AUC=0.989 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.980 Prevalence=0.056
Epoch 41: training loss = 42.1186348348856  validation loss = 31.53229396045208
  Validation Set N=1885
	Task1: P=0.355 R=0.233 F1=0.281 ROC AUC=0.557 Prevalence=0.266 
	Task2: P=0.128 R=0.058 F1=0.080 ROC AUC=0.555 Prevalence=0.055
  Training Set N=5652
	Task1: P=0.987 R=0.873 F1=0.927 ROC AUC=0.990 Prevalence=0.278 
	Task2: P=0.933 R=0.793 F1=0.857 ROC AUC=0.982 Prevalence=0.056
Epoch 42: training loss = 41.75564715266228  validation loss = 32.14565922319889


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.347 R=0.243 F1=0.286 ROC AUC=0.556 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.548 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.987 R=0.885 F1=0.933 ROC AUC=0.992 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.984 Prevalence=0.056
Epoch 43: training loss = 41.26949505507946  validation loss = 32.45234189927578
  Validation Set N=1885
	Task1: P=0.370 R=0.259 F1=0.305 ROC AUC=0.557 Prevalence=0.266 
	Task2: P=0.174 R=0.078 F1=0.107 ROC AUC=0.554 Prevalence=0.055
  Training Set N=5652
	Task1: P=0.989 R=0.891 F1=0.938 ROC AUC=0.993 Prevalence=0.278 
	Task2: P=0.912 R=0.822 F1=0.864 ROC AUC=0.983 Prevalence=0.056
Epoch 44: training loss = 40.720056891441345  validation loss = 32.50116318464279


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.366 R=0.253 F1=0.299 ROC AUC=0.552 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.538 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.987 R=0.912 F1=0.948 ROC AUC=0.994 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.985 Prevalence=0.056
Epoch 45: training loss = 40.452274695038795  validation loss = 33.471924886107445


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.350 R=0.257 F1=0.296 ROC AUC=0.550 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.543 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.990 R=0.912 F1=0.950 ROC AUC=0.994 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.985 Prevalence=0.056
Epoch 46: training loss = 40.02595727145672  validation loss = 33.87572979927063


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.354 R=0.249 F1=0.292 ROC AUC=0.551 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.543 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.993 R=0.909 F1=0.949 ROC AUC=0.995 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.986 Prevalence=0.056
Epoch 47: training loss = 39.76027502119541  validation loss = 33.69415009021759


  _warn_prf(average, modifier, msg_start, len(result))


  Validation Set N=1885
	Task1: P=0.357 R=0.263 F1=0.303 ROC AUC=0.554 Prevalence=0.266 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.553 Prevalence=0.055


  _warn_prf(average, modifier, msg_start, len(result))


  Training Set N=5652
	Task1: P=0.993 R=0.913 F1=0.951 ROC AUC=0.996 Prevalence=0.278 
	Task2: P=0.000 R=0.000 F1=0.000 ROC AUC=0.986 Prevalence=0.056
Epoch 48: training loss = 39.27618044614792  validation loss = 33.92325368523598
  Validation Set N=1885
	Task1: P=0.372 R=0.267 F1=0.311 ROC AUC=0.555 Prevalence=0.266 
	Task2: P=0.125 R=0.058 F1=0.079 ROC AUC=0.539 Prevalence=0.055
  Training Set N=5652
	Task1: P=0.993 R=0.927 F1=0.959 ROC AUC=0.997 Prevalence=0.278 
	Task2: P=0.953 R=0.834 F1=0.890 ROC AUC=0.987 Prevalence=0.056
Epoch 49: training loss = 38.977672919631004  validation loss = 34.410685300827026
  Validation Set N=1885
	Task1: P=0.372 R=0.251 F1=0.300 ROC AUC=0.557 Prevalence=0.266 
	Task2: P=0.098 R=0.039 F1=0.056 ROC AUC=0.539 Prevalence=0.055
  Training Set N=5652
	Task1: P=0.994 R=0.925 F1=0.958 ROC AUC=0.997 Prevalence=0.278 
	Task2: P=0.950 R=0.787 F1=0.861 ROC AUC=0.988 Prevalence=0.056
