In [1]:
from experiment import Experiment
from utils.behavior_data import BehaviorData
from visuals import Plotter
import torch
import numpy as np
from utils.state_data import StateData
import matplotlib.pyplot as plt

In [None]:
# --- Run configuration --------------------------------------------------------
# Other model / learning-rate pairings that have been tried:
#   "BasicNNSplit" @ .006, "BasicNN" @ .0054, "LogisticRegressor" @ .003
model, learning_rate = "AdaptableLSTM", .1
epochs = 1000
seed = 2

# Feature / model switches; the trailing comment names the Experiment kwarg each feeds.
include_state = True   # data_kw "include_state"
estate = True          # data_kw "expanded_states"
fullq = False          # data_kw "full_questionnaire"
respond_perc = .50     # data_kw "top_respond_perc"
fullseq = False        # data_kw "full_sequence"
insertpreds = True     # data_kw "insert_predictions"
noise = 0.07           # data_kw "response_feature_noise" AND model_kw "gaussianNoiseStd"
smooth = 0             # model_kw "labelSmoothPerc"
splitQs = True         # data_kw "split_weekly_questions", model_kw "splitWeeklyQuestions"
splitModel = True      # modelSplit, data_kw "split_model_features", model_kw "splitModel"

# Per-category training schedules, written as [onEpoch, offEpoch, onEpoch, offEpoch, ...]:
# training of that category is toggled at every epoch listed.
knowSched = [30]
physSched = [40, 100]
conSched = [100]

# Seed NumPy and PyTorch before the Experiment is built so anything it draws
# from those RNGs is repeatable across runs.
np.random.seed(seed)
torch.manual_seed(seed)

e = Experiment(
    modelSplit=splitModel,
    numValFolds=5,
    epochsToUpdateLabelMods=10,
    knowSchedule=knowSched,
    consumpSchedule=conSched,
    physSchedule=physSched,
    data_kw=dict(
        minw=2,
        maxw=31,
        include_state=include_state,
        include_pid=False,
        expanded_states=estate,
        top_respond_perc=respond_perc,
        full_questionnaire=fullq,
        full_sequence=fullseq,
        insert_predictions=insertpreds,
        one_hot_response_features=True,
        response_feature_noise=noise,
        max_state_week=1,
        split_model_features=splitModel,
        split_weekly_questions=splitQs,
    ),
    model=model,
    model_kw=dict(
        lossfn="MSELoss",  # alternatives tried: "NDCG", "CrossEntropyLoss"
        hidden_size=10,
        lr_step_mult=.9,
        lr_step_epochs=100,
        opt_kw=dict(lr=learning_rate),
        splitModel=splitModel,
        splitWeeklyQuestions=splitQs,
        labelSmoothPerc=smooth,
        gaussianNoiseStd=noise,
    ),
    train_kw=dict(epochs=epochs, n_subj=500, rec_every=5),
)

# Dataset split sizes: test first, then train.
print(len(e.bd.test))
print(len(e.bd.train))

# Uncomment to have autograd report the op that produced a bad gradient:
# torch.autograd.set_detect_anomaly(True)

report = e.run()






['state' 'state' 'state' 'state' 'state' 'state' 'state' 'state' 'state'
 'state' 'state' 'state' 'state' 'state' 'state' 'state' 'state' 'state'
 'state' 'state' 'state' 'state' 'response_last_q1' 'response_last_q1'
 'response_last_q1' 'paction_sids_q1' 'paction_sids_q1' 'paction_sids_q1'
 'paction_sids_q1' 'paction_sids_q1' 'pmsg_ids_q1' 'pmsg_ids_q1'
 'pmsg_ids_q1' 'pmsg_ids_q1' 'pmsg_ids_q1' 'pmsg_ids_q1' 'qids_q1'
 'qids_q1' 'qids_q1' 'qids_q1' 'qids_q1' 'qids_q1' 'q1_cat' 'q1_cat']
54
214
LSTM(44, 10)
0	 train loss: 0.1856 train acc: 54.196% test acc: 52.125% train exerAcc: 47.303% test exerAcc: 46.154%
LSTM(44, 10)
LSTM(44, 10)
LSTM(44, 10)
LSTM(44, 10)
LSTM(44, 10)
5	 train loss: 0.1594 train acc: 56.905% test acc: 54.781% train exerAcc: 47.303% test exerAcc: 46.154%
LSTM(44, 10)
LSTM(44, 10)
LSTM(44, 10)
LSTM(44, 10)
LSTM(44, 10)
10	 train loss: 0.1555 train acc: 58.152% test acc: 55.644% train exerAcc: 47.303% test exerAcc: 48.077%
LSTM(44, 10)
LSTM(44, 10)
LSTM(44, 10)
LSTM(

In [None]:

# Summarize the finished run held in `report` (the dict returned by e.run()).

# Mean of every recorded train metric across all recorded epochs.
print(np.mean(report["train_metrics"], axis=0))

# Column positions of the metrics we report/plot, looked up once by name.
labels = report["metric_labels"]
acc_col = labels.index("Acc")
mse_col = labels.index("MSE")

# Final recorded train / test accuracy.
print(report["train_metrics"][-1, acc_col])
print(report["test_metrics"][-1, acc_col])

# Draw on a dedicated figure rather than pyplot's implicit "current" figure,
# so nothing plotted earlier in the kernel session bleeds into this one.
fig, ax = plt.subplots()
curves = (
    ("train_metrics", acc_col, "Train Acc."),
    ("test_metrics", acc_col, "Test Acc."),
    ("train_metrics", mse_col, "Train MSE"),
    ("test_metrics", mse_col, "Test MSE"),
)
for metrics_key, col, label in curves:
    ax.plot(report["rec_epochs"], report[metrics_key][:, col], label=label)
# NOTE(review): Acc and MSE share one y-axis; if Acc is stored in percent rather
# than as a fraction, the MSE curves will look flat near zero — confirm units.
ax.set_title("Train/Test Performance Over Training")
ax.set_xlabel("Training Epoch")
ax.set_ylabel("Metric")
ax.legend()
fig.savefig("simpleNotebookAccPlot.png")

# Release the figure once it is on disk. plt.clf() only emptied it, leaving a
# blank figure alive that Jupyter echoes as "<Figure ... with 0 Axes>".
plt.close(fig)



In [None]:
# Compare the feature array's shape with the feature-name list's shape
# (presumably one name per feature column — confirm against BehaviorData).
feature_matrix_shape = e.bd.features.shape
feature_list_shape = e.bd.featureList.shape
print(feature_matrix_shape, feature_list_shape)