In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to the local project modules take
# effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the CC BY-NC license found in the
LICENSE.md file in the root directory of this source tree.
"""
import pickle
import random
import time
import gym
import numpy as np

import sys
import os
from munch import Munch
import yaml

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from PIL import Image
import os
import pandas as pd


from IPython import display
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
import math

from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from PIL import Image


# Run on the first CUDA GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)


In [None]:
# Names of the ten EMA survey item columns: 'ema_1' ... 'ema_10'.
ema_items = [f'ema_{i}' for i in range(1, 11)]


def _epoch_hours(stamps):
    """Convert a column of date-time values to hours since the Unix epoch.

    Values are parsed with ``pd.to_datetime`` and read back as int64, which
    is divided by 10**9 to get whole seconds (this assumes the parsed column
    has nanosecond resolution -- TODO confirm for the pandas version in use),
    then by 3600 to get fractional hours.
    """
    return (pd.to_datetime(stamps).astype('int64') // (10**9)) / 3600


# --- EMA survey responses ---------------------------------------------------
df = pd.read_csv('ema.csv', sep=',')
df = df.sort_values(['subid', 'dttm_obs'])[['dttm_obs', 'subid'] + ema_items]

# Missing values become the sentinel -6.0. Rows whose ema_4 is then negative
# (i.e. was missing, assuming real answers are non-negative) are dropped.
df = df.fillna(-6.0)
df = df[df['ema_4'] > -1]

# ema_8..ema_10 belong to the morning-only survey: ema_dummy records whether
# ema_8 was answered, and unanswered morning items are reset to 0.
df['ema_dummy'] = df['ema_8'] > -1
df.loc[df['ema_8'] < 0, ['ema_8', 'ema_9', 'ema_10']] = 0

# Encode answers: 'No' -> 0, 'Yes' -> 1.
df = df.replace('No', 0).replace('Yes', 1)

# Observation time in hours since the epoch.
df['date'] = _epoch_hours(df['dttm_obs'])
df2 = df[['subid'] + ema_items + ['ema_dummy', 'date']]


# --- Lapse labels -----------------------------------------------------------
time_df = pd.read_csv('labels_1day.csv', sep=',')
time_df = time_df.sort_values(['subid', 'dttm_label'])
time_df['date'] = _epoch_hours(time_df['dttm_label'])

# Encode labels: 'no' -> 0, 'yes' -> 1.
time_df = time_df.replace('no', 0).replace('yes', 1)

time_df = time_df[['subid', 'lapse', 'date']]

In [None]:
def input_data(org_data, time_label):
    """Merge one subject's surveys and label queries into a single timeline.

    Both inputs are 2-D arrays sorted by time whose LAST column is the
    time stamp. Times are shifted so that the first label sits at time 0.
    The shift is applied to private copies; the caller's arrays are left
    untouched.

    Args:
        org_data: survey rows. Column 0 is dropped from the emitted
            observation (it is the subject id in this notebook); the remaining
            columns, including the shifted time, form ``'obs'``.
        time_label: label rows. Column 1 is the lapse indicator; any value
            > 0 is encoded as 1, everything else as 0.

    Returns:
        A list of events in time order, each one of
        ``{'type': 'survey', 'obs': <row without column 0>, 'time': t}`` or
        ``{'type': 'query', 'out': 0 or 1, 'time': t}``.
        Returns an empty list when either input has no rows.

    Notes:
        * When several labels fall before the next survey, only the first of
          them produces a query event; the rest are skipped.
        * The walk stops as soon as either input is exhausted, so trailing
          surveys after the last label (and trailing labels after the last
          survey) are not emitted.
    """
    n_labels, n_surveys = time_label.shape[0], org_data.shape[0]
    if n_labels == 0 or n_surveys == 0:
        return []

    # Work on copies so the time shift below does not leak into the caller's data.
    org_data = org_data.copy()
    time_label = time_label.copy()

    # Offset the starting date-time to 0 (anchored at the first label).
    offset_time = time_label[0, -1] + 0.0
    org_data[:, -1] -= offset_time
    time_label[:, -1] -= offset_time

    dataset = []
    tq = ts = 0  # cursors into time_label and org_data respectively
    while tq < n_labels and ts < n_surveys:
        next_survey_time = org_data[ts, -1]
        if time_label[tq, -1] < next_survey_time:
            y = 1 if time_label[tq, 1] > 0 else 0
            dataset.append({'type': 'query', 'out': y, 'time': time_label[tq, -1]})

            # Skip every label that precedes the upcoming survey.
            while tq < n_labels and time_label[tq, -1] < next_survey_time:
                tq += 1
        else:
            # "+ 0.0" yields a fresh array rather than a view into org_data.
            dataset.append({'type': 'survey', 'obs': org_data[ts, 1:] + 0.0, 'time': next_survey_time})
            ts += 1

    return dataset


In [None]:
import xgboost as xgb
from sklearn import metrics
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

def input_xgboost_data(org_data, time_label, windows=(12, 24, 48, 72, 168)):
    """Build hand-crafted feature vectors for the XGBoost baseline.

    For every label row that has at least one survey strictly before it, a
    feature vector is assembled from the surveys preceding the label time.

    Args:
        org_data: survey rows sorted by time; last column is the time stamp,
            column 0 is excluded from the per-survey features.
        time_label: label rows sorted by time; last column is the time stamp
            and column 1 the lapse indicator (> 0 is encoded as 1).
        windows: look-back widths, in the same unit as the time columns
            (hours in this notebook). One (mean, min, max) triple is emitted
            per width. Defaults to the original five widths.

    Returns:
        A list of ``{'in': 1-D feature array, 'out': 0 or 1}`` dicts. The
        feature layout is, in order: for each window ``mean | min | max`` of
        columns 1.. over the surveys in that window, then the mean of
        columns 1.. over all preceding surveys, then columns 1.. of the most
        recent preceding survey.

    Notes:
        Each window always contains at least the most recent preceding
        survey, even when that survey is older than the window width.
        Labels with no preceding survey are skipped.
    """
    n_surveys = org_data.shape[0]
    starts = [0] * len(windows)  # left edge (row index) of every window
    dataset = []
    ts = 0  # number of surveys strictly before the current label time

    for tq in range(time_label.shape[0]):
        query_time = time_label[tq, -1]

        while ts < n_surveys and org_data[ts, -1] < query_time:
            ts += 1

        if ts == 0:
            # No survey observed yet: nothing to build features from.
            continue

        # Slide each left edge forward, but never past the latest survey.
        for k, width in enumerate(windows):
            while starts[k] < ts - 1 and org_data[starts[k], -1] < query_time - width:
                starts[k] += 1

        parts = []
        for start in starts:
            chunk = org_data[start:ts, 1:]
            parts.append(np.mean(chunk, axis=0))
            parts.append(np.min(chunk, axis=0))
            parts.append(np.max(chunk, axis=0))

        parts.append(np.mean(org_data[:ts], axis=0)[1:])  # mean over all preceding surveys
        parts.append(org_data[ts - 1, 1:])                # latest survey
        X = np.concatenate(parts)

        y = 1 if time_label[tq, 1] > 0 else 0
        dataset.append({'in': X, 'out': y})

    return dataset


def _collect_xgboost_arrays(indices):
    """Stack the XGBoost samples of the subjects at the given list positions.

    ``indices`` are positions into the notebook-level ``subid_list``; the
    subject's rows are selected from ``raw_data`` / ``time_raw_data`` by the
    1-based subject id stored in column 0.

    Returns:
        ``(X, y)``: a 2-D feature array with one row per sample and a 1-D
        array of 0/1 labels.

    Raises:
        ValueError: if the selected subjects yield no samples at all.
    """
    samples = []
    for m in indices:
        subid = subid_list[m] + 1
        sub_data = raw_data[raw_data[:, 0] == subid]
        time_label = time_raw_data[time_raw_data[:, 0] == subid]
        samples.extend(input_xgboost_data(sub_data, time_label))

    if not samples:
        raise ValueError("no XGBoost samples could be built for the given subject indices")

    X = np.stack([sample['in'] for sample in samples])
    y = np.array([sample['out'] for sample in samples])
    return X, y


def xgboost_main(train_index, test_index):
    """Train the XGBoost baseline on the training subjects and report test AUC.

    All samples are stacked directly into NumPy arrays. (The previous version
    went through a ``DataLoader`` and kept only its first batch, which would
    silently discard samples beyond the batch size and shuffled the training
    rows without a seed.) As a consequence the shapes are printed as plain
    tuples.

    Args:
        train_index: positions into ``subid_list`` used for training.
        test_index: positions into ``subid_list`` used for evaluation.

    Returns:
        The ROC AUC of the predictions on the test samples (also printed).

    Raises:
        ValueError: if either split produces no samples.
    """
    X, y = _collect_xgboost_arrays(train_index)
    X_test, y_test = _collect_xgboost_arrays(test_index)

    # Create regression matrices
    dtrain_reg = xgb.DMatrix(X, y)
    dtest_reg = xgb.DMatrix(X_test, y_test)

    # Squared-error regression on the 0/1 labels; the raw predictions are only
    # used as ranking scores for the AUC. Fall back to CPU when no GPU is
    # visible, mirroring how the notebook selects its torch device.
    params = {
        "objective": "reg:squarederror",
        "tree_method": "hist",
        "device": "cuda" if torch.cuda.is_available() else "cpu",
    }

    print(X.shape, X_test.shape)

    n = 10
    xgmodel = xgb.train(
        params=params,
        dtrain=dtrain_reg,
        num_boost_round=n,
    )

    preds = xgmodel.predict(dtest_reg)
    # positives in the test labels, predictions above 0.1, total predictions
    print(len(y_test[y_test > 0]), len(preds[preds > 0.1]), len(preds))

    auc = roc_auc_score(y_test, preds)
    print('XgBoost: AUC = %.2f' % auc)  #1
    return auc


In [None]:
# Dense float32 copies of the survey table and the label table.
# NOTE(review): 'date' is stored as hours since the epoch, so casting to
# float32 before any offset is applied limits its resolution (on the order of
# minutes for present-day dates) -- confirm this is acceptable.
raw_data = df2.to_numpy().astype(np.float32)
time_raw_data = time_df.to_numpy().astype(np.float32)

# Action encoding (4 discrete actions):
#   0 - no survey, predict No
#   1 - yes survey, predict No
#   2 - no survey, predict Yes
#   3 - yes survey, predict Yes
action_max = 4

# Feature (prediction target) encoding:
#   0 - no lapse within 24hr
#   1 - yes lapse within 24hr
feature_max = 2

state_dim, embed_dim, feature_dim = 12, 256, 1

# Build one event timeline per subject. Candidate ids are 1..270; ids that
# lack either survey rows or label rows are skipped. subid_list keeps the
# zero-based id of every subject that was kept, aligned with env_list.
env_list, subid_list = [], []
for subid in range(270):
    subject_key = subid + 1
    sub_data = raw_data[raw_data[:, 0] == subject_key]
    time_label = time_raw_data[time_raw_data[:, 0] == subject_key]
    if len(sub_data) and len(time_label):
        subid_list.append(subid)
        env_list.append(input_data(sub_data, time_label))

print(len(subid_list))

# Retrieve Saved Results

In [None]:
from env import Env, VecEnv

# num_process: module-level process count (test() does not read it; it derives
#              its own count from venv.test_ids).
# window_size: number of most recent steps kept in the state/action history
#              buffers that test() feeds to the model.
num_process, window_size = 50, 25

In [None]:

def test(venv, model, agent, mode='policy', max_steps=1000):
    """Roll out the test environments and report survey usage and lapse AUC.

    Reads the notebook-level globals ``window_size``, ``device``,
    ``feature_max`` and ``state_dim``.

    Args:
        venv: vectorised environment exposing ``test_ids``, ``venv`` (list of
            environments), ``reset(mode='test')`` and ``step(actions)``.
        model: summary model; called as ``model(states, actions,
            output_embedding=True)`` and ``model.predict_forward(embedding)``.
        agent: policy exposing ``select_action(embedding)``; only used when
            ``mode == 'policy'``.
        mode: ``'policy'`` lets the agent pick actions; any other value sends
            the constant action 1 to every environment.
        max_steps: upper bound on the number of rollout steps. The rollout
            ends earlier once every environment reports ``-1`` as its info;
            it also ends at this bound even if some are still running.

    Returns:
        ``(venv.test_ids, n_measure, n_hit, cnt, auc_score)`` where the three
        lists hold, per test environment, the number of survey steps taken
        with an odd action, the number of survey steps whose first
        observation entry was positive, and the total number of survey steps.

    Side effects:
        Prints per-environment counts, shows a scatter plot and prints the
        average survey ratio and the AUC.
    """
    # Imported locally so the function does not depend on another cell having run.
    from sklearn.metrics import roc_auc_score

    scatter = []
    num_process = len(venv.test_ids)

    cnt = [0 for i in range(num_process)]
    n_measure = [0 for i in range(num_process)]
    n_hit = [0 for i in range(num_process)]
    y, outy = [], []

    env = venv.venv[0]
    obs, info = venv.reset(mode='test')

    # Rolling history of the last `window_size` states / actions per environment.
    ts = torch.zeros(num_process, window_size, env.state_dim).to(device)
    ta = torch.zeros(num_process, window_size, env.action_dim).to(int).to(device)

    bf = model(ts, ta, output_embedding=True)
    bf = bf.detach()

    for t in range(max_steps):
        if( mode == 'policy' ):
            actions, _, _ = agent.select_action(bf)
            actions = actions.detach().cpu()
        else:
            actions = torch.zeros(num_process, env.action_dim).to(int) + 1

        obs, rewards, dones, infos = venv.step(actions)

        # Prediction made from the embedding computed BEFORE this step.
        predict_features = model.predict_forward(bf).detach()

        all_done = True
        for i in range(num_process):
            # -1 is skipped entirely (presumably a finished environment -- confirm in env.py).
            if( infos[i] == -1 ):
                continue

            all_done = False
            if (infos[i] != 'survey'):
                # Query step: infos[i] is recorded as the ground-truth label and
                # the softmax probability of class 1 as the score.
                check_y = nn.Softmax(dim=-1)(predict_features[i]).reshape(feature_max).detach().cpu()

                y.append(infos[i])
                outy.append(check_y[1])

                # Refresh the last feature of the newest history row
                # (presumably the current time -- confirm in env.py).
                ts[i][-1, -1] = obs[i, -1]
            else:
                # Survey step: push the new observation/action into the history.
                ts[i] = torch.cat([ts[i][1:,:], obs[i].view(-1, env.state_dim).to(device)], dim=0)
                ta[i] = torch.cat([ta[i][1:,:], actions[i].view(-1, 1).to(device)], dim=0)

                cnt[i] += 1
                if( actions[i] % 2 == 1 ):
                    n_measure[i] += 1
                if( obs[i,0] > 0 ):
                    n_hit[i] += 1

        if( all_done ):
            break

        bf = model(ts.reshape(num_process,-1,state_dim), ta.reshape(num_process,-1, 1), output_embedding=True).detach()

    for i in range(num_process):
        if cnt[i] > 0:
            scatter.append([n_measure[i]/cnt[i], n_hit[i]/cnt[i]])
        else:
            # No survey step was recorded: the ratios are undefined, so store
            # NaN instead of dividing by zero.
            scatter.append([float('nan'), float('nan')])
        print(venv.test_ids[i], ':', n_measure[i], n_hit[i], cnt[i])

    scatter = np.array(scatter)
    plt.figure(figsize=(4, 3))

    plt.scatter(scatter[:,1], scatter[:,0])
    plt.xlabel("Average Lapse Count")
    plt.ylabel("Survey Ratio")
    plt.show()

    # nanmean ignores environments without survey steps; it equals the plain
    # mean whenever every environment has at least one.
    print('Average Survey Ratio:', np.nanmean(scatter[:,0]))

    auc_score = roc_auc_score(y, outy)
    print('AUC = %.2f' % auc_score)  #1

    return venv.test_ids, n_measure, n_hit, cnt, auc_score
    

In [None]:
import random
from SACD import init_SACD_agent
from model import build_model
from buffer import ReplayBuffer

import torch.nn.functional as F
import time
import copy
import csv

# For every run and survey penalty, evaluate the 10 saved folds and write one
# CSV with a row per test subject: id, surveys taken ("revealed"), lapse count
# ("lapses"), survey steps ("replied") and the fold AUC.
for run_id in [1, 2, 3]:
    for penalty in [-1, 0.02, 0.05, 0.08, 0.12]:
        out_filename = f"freq_data_{run_id}_{penalty}.csv"

        with open(out_filename, 'w', newline='') as fp:
            writer = csv.writer(fp)
            writer.writerow(["id", "revealed", "lapses", "replied", "AUC"])

            for kfold in range(10):
                filename = os.path.join(f"./results/run_{run_id}/", f"run_{run_id}_kfold_{kfold}_penalty_{penalty}.pt")
                # NOTE: torch.load unpickles the file; only load checkpoints
                # from a trusted source.
                loaded_data = torch.load(filename, map_location=device)

                test_index = loaded_data['test_ids']
                # Every environment not held out for testing is a training
                # environment; derive the population from env_list instead of
                # hard-coding its size.
                train_index = list(set(range(len(env_list))) - set(test_index))
                # xgboost_main(train_index, test_index)

                venv = VecEnv(env_list, train_index, test_index)
                env = venv.venv[0]
                model = build_model(embed_dim, feature_dim, feature_max, model_family='gpt2').to(device)
                agent = init_SACD_agent(env, device, belief_dim=embed_dim, summary_model=model, \
                                        lr_actor=1e-4, lr_critic=2e-4, lr_summary=3e-5, entropy_regularizer=0.03)

                # Restore the trained weights saved for this fold.
                model_dict = loaded_data['model_state_dict']  # rm_orig_mod(state["model_state_dict"])
                agent_dict = loaded_data['agent_state_dict']

                model.load_state_dict(model_dict)
                agent.load_dict(agent_dict)

                # penalty == -1 marks the always-survey baseline ('full' mode).
                mode = 'full' if( penalty == -1 ) else 'policy'
                test_id, n_measure, n_hit, cnt, auc_score = test(venv, model, agent, mode=mode)

                # The fold AUC is stored on the first row only; other rows get 0.
                for i in range(len(test_id)):
                    c = auc_score if ( i == 0 ) else 0
                    writer.writerow([test_id[i], n_measure[i], n_hit[i], cnt[i], c])

            
        
            