In [1]:
# import pickle
from pickle5 import pickle
import torch
from torch import nn, optim
from torch.autograd import Variable
from sklearn.metrics import confusion_matrix, classification_report, f1_score, recall_score
import torch.nn.functional as F
import pandas as pd
import numpy as np
from torch.utils.data import TensorDataset, DataLoader, Dataset
from sklearn.metrics import mean_absolute_error, mean_squared_error
import math

from util.test_data_processing import prepare_test_data
from util.train_data_processing import prepare_training_data
from util.train_model import train_model

import csv


def main_run(train_epochs, aug_0, aug_1, aug_2, aug_3, aug_4):
    seed = 32
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True

    sample_length = 42

    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

    df = pd.read_csv('../../processed_data/Detailed_PHQ8_Labels.csv')
    df['PHQ_5way'] = (df['PHQ_8Total'] / 5).astype(int)
    all_labels = dict(zip(df['Participant_ID'], df['PHQ_5way']))

    test_X, test_y, test_user_names = prepare_test_data(sample_length, all_labels)

    tensor_X_train, tensor_y_train = prepare_training_data(all_labels, aug_0, aug_1, aug_2, aug_3, aug_4, sample_length)
    trainDataset = TensorDataset(tensor_X_train, tensor_y_train)
    trainLoader = DataLoader(trainDataset, batch_size=32, shuffle=True)

    trained_model, pred = train_model(trainLoader, train_epochs, test_X, test_y, test_user_names)
    torch.save(trained_model.state_dict(), 'trained_router_model.pth')

    report = classification_report(test_y, pred, zero_division=0, digits=4)
    five_way_f1 = f1_score(test_y, pred, average='macro')

    print('the F1 score for 5-way classification:', five_way_f1)
    return None

params = {'scale': 5, 'train_epochs': 11}
train_epochs = params['train_epochs']
scale = params['scale']
aug_0, aug_1, aug_2, aug_3, aug_4 = [element * scale for element in [1, 2, 3, 4, 12]]
loss = main_run(train_epochs, aug_0, aug_1, aug_2, aug_3, aug_4)



the F1 score for 5-way classification: 0.24574613003095974
