In [1]:
# import pickle
from pickle5 import pickle
import torch
from torch import nn, optim
from torch.autograd import Variable
from sklearn.metrics import confusion_matrix, classification_report, f1_score, recall_score
import torch.nn.functional as F
import pandas as pd
import numpy as np
from torch.utils.data import TensorDataset, DataLoader, Dataset
from sklearn.metrics import mean_absolute_error, mean_squared_error
import math
from collections import Counter

from util.test_data_processing import prepare_test_data
from util.train_data_processing import prepare_training_data
from util.train_model import train_model
from util.evaluation import cnn_evaluate

def main_run(train_epochs, scale):
    """Train one expert model per PHQ-8 score band and print combined test metrics.

    For each expert ``e`` in ``0..4`` the participants whose PHQ-8 total lies
    in ``[5*e, 5*e + 4]`` are selected, per-class augmentation factors are
    derived from the class imbalance inside that band, a model is trained
    with ``train_model`` and its weights are written to ``"<e>.pth"`` in the
    current working directory. Each expert is evaluated on the shared test
    set with ``cnn_evaluate``; the per-expert predictions are summed
    element-wise to form the final PHQ-8 prediction, which is binarised
    (``< 10`` -> 0, otherwise 1) for the classification report.

    Args:
        train_epochs: Forwarded unchanged to ``train_model``.
        scale: Integer multiplier applied to every augmentation factor.

    Returns:
        None. Results are printed to stdout.

    Raises:
        ValueError: If no participant has a label inside an expert's band,
            so there is nothing to train that expert on.
    """
    seed = 32
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True

    sample_length = 42
    num_experts = 5   # size of the router one-hot vector / number of models
    band_width = 5    # number of consecutive PHQ-8 scores handled per expert

    # NOTE(review): unpickling executes arbitrary code embedded in the file;
    # only load router files produced by a trusted pipeline.
    with open('../router/group_router_prob.pkl', 'rb') as f:
        original_dict = pickle.load(f)

    # Turn each stored expert index into a one-hot list of length num_experts.
    # presumably keys are participant ids -- verify against cnn_evaluate.
    group_router_dic = {}
    for key, value in original_dict.items():
        one_hot = [0] * num_experts
        one_hot[value] = 1
        group_router_dic[key] = one_hot

    df = pd.read_csv('../../processed_data/Detailed_PHQ8_Labels.csv')
    all_labels = dict(zip(df['Participant_ID'], df['PHQ_8Total']))

    test_X, test_y, user_names = prepare_test_data(sample_length, all_labels)

    all_out = []

    for expert in range(num_experts):
        min_score = expert * band_width
        max_score = min_score + band_width - 1

        keys_in_range = []
        values_in_range = []
        for key, value in all_labels.items():
            if min_score <= value <= max_score:
                keys_in_range.append(key)
                # Shift the score so every expert sees classes 0..band_width-1.
                values_in_range.append(value - min_score)

        if not values_in_range:
            raise ValueError(
                f'no participants with a PHQ-8 total in '
                f'[{min_score}, {max_score}] for expert {expert}'
            )

        distribution = Counter(values_in_range)
        most_frequent = float(distribution.most_common(1)[0][1])

        # Rarer classes get proportionally larger factors (ratio to the most
        # common class, truncated to int, times `scale`); absent classes get 0.
        augmented_values = [
            int(most_frequent / distribution[i]) * scale if distribution[i] != 0 else 0
            for i in range(band_width)
        ]

        tensor_X_train, tensor_y_train = prepare_training_data(
            all_labels, augmented_values, sample_length, keys_in_range, expert
        )

        trainDataset = TensorDataset(tensor_X_train, tensor_y_train)
        trainLoader = DataLoader(trainDataset, batch_size=32, shuffle=True)

        trained_model = train_model(trainLoader, train_epochs)
        torch.save(trained_model.state_dict(), str(expert) + '.pth')

        # Only the predictions are needed here; the binary ground truth is
        # derived from test_y after the loop.
        _, y_pred = cnn_evaluate(
            trained_model, test_X, test_y, expert, group_router_dic, user_names
        )
        all_out.append(y_pred)

    # Binary labels: totals of 10 or more map to class 1.
    y_true = np.where(np.array(test_y) < 10, 0, 1)
    # Combine the experts by summing their per-sample predictions.
    phq_pred = np.sum(np.array(all_out), axis=0)
    final_pred = np.where(phq_pred < 10, 0, 1)
    print('pred phq score:', phq_pred)
    print('MAE', mean_absolute_error(test_y, phq_pred), 'RMSE', np.sqrt(mean_squared_error(test_y, phq_pred)))
    print(classification_report(y_true, final_pred, digits=4, zero_division=0.0))
    return None

# Hyper-parameters for this run: augmentation multiplier and epoch count.
params = dict(scale=5, train_epochs=11)
train_epochs, scale = params['train_epochs'], params['scale']

# Train all experts and print the evaluation report.
main_run(train_epochs, scale)

pred phq score: [ 7  7 11 20 10  4 20 10  4 10 11  7  7  7  4 10 11  7  7  4 11  4  7 11
  4 11  4  7 11  7  7 11  7  7  4]
MAE 4.6571428571428575 RMSE 5.774327419090222
              precision    recall  f1-score   support

           0     0.8571    0.7826    0.8182        23
           1     0.6429    0.7500    0.6923        12

    accuracy                         0.7714        35
   macro avg     0.7500    0.7663    0.7552        35
weighted avg     0.7837    0.7714    0.7750        35



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
