In [1]:
import sys
from pathlib import Path
# Make project-level packages (e.g. `graph`, `path_config`) importable.
# Assumes the notebook lives one directory below the project root.
project_root = Path.cwd().resolve().parent
sys.path.insert(0, f"{project_root}")

In [13]:
import sqlite3

import os, argparse, path_config, shutil
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
import pandas as pd
from loguru import logger

from tqdm.notebook import tqdm
import time

import torch
import torch.nn.functional as F
from torch_geometric.utils import to_networkx, k_hop_subgraph
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

from graph.train_val import train_gat, validation_gat, FocalLoss

from graph._multimodal_model_bilstm.GAT_explanation import GATJKClassifier as BiLSTMV2GAT
from graph.multimodal_topic_bilstm_proxy.dataset_explanation import make_graph as TopicProxyBiLSTM_make_graph

# Notebook-wide matplotlib defaults: a font with Hangul glyphs (presumably for
# Korean labels), and an ASCII hyphen for negative tick labels because such
# fonts may lack the U+2212 minus glyph.
plt.rcParams.update({
  'font.family': 'Malgun Gothic',
  'axes.unicode_minus': False,
})

In [7]:
# Send loguru output to stdout in a compact, colourised format. All existing
# handlers are removed first, so re-running this cell does not duplicate lines.
LOG_FORMAT = "<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | <level>{message}</level>"
logger.remove()
logger.add(sys.stdout, colorize=True, format=LOG_FORMAT)

# Registries keyed by `mode`: the v2 model class to instantiate, and the
# function that turns participant ids/labels into graphs for that model.
V2_MODEL = dict(multimodal_topic_bilstm_proxy=BiLSTMV2GAT)
MAKE_GRAPH = dict(multimodal_topic_bilstm_proxy=TopicProxyBiLSTM_make_graph)

In [8]:
def fetch_from_db(db_path):
  """Read the best trial's model hyperparameters and score from an Optuna SQLite DB.

  The best trial is the one with the highest ``value`` in ``trial_values``.
  Training-only parameters (batch size, focal-loss settings, lr, optimizer,
  weight decay) are dropped; only model-construction arguments are kept.
  ``use_text_proj``/``use_attention`` are converted to bool and
  ``num_layers``/``bilstm_num_layers`` to int. Every other parameter is
  returned exactly as stored.

  Args:
    db_path: Path to the SQLite file written by the Optuna study.

  Returns:
    (best_hyperparams_dict, best_f1): the filtered parameter dict and the
    best trial's objective value.

  Raises:
    ValueError: if the database holds no trial values.
  """
  training_only = {'batch_size', 'focal_alpha', 'focal_gamma', 'lr', 'optimizer', 'weight_decay'}
  bool_params = {'use_text_proj', 'use_attention'}
  int_params = {'num_layers', 'bilstm_num_layers'}

  con = sqlite3.connect(db_path)
  try:
    cursor = con.cursor()
    # Resolve the best trial once, so the returned score is guaranteed to
    # belong to the same trial as the returned parameters.
    cursor.execute('''
      SELECT trial_id, value
      FROM trial_values
      ORDER BY value DESC
      LIMIT 1
    ''')
    best_row = cursor.fetchone()
    if best_row is None:
      raise ValueError(f"No trial values found in Optuna DB: {db_path}")
    best_trial_id, best_f1 = best_row

    cursor.execute(
      'SELECT param_name, param_value FROM trial_params WHERE trial_id = ?',
      (best_trial_id,),
    )
    best_hyperparams_list = cursor.fetchall()
  finally:
    # The sqlite3 connection context manager only commits/rolls back; it does
    # not close the connection, so close it explicitly.
    con.close()

  best_hyperparams_dict = {}
  for k, v in best_hyperparams_list:
    if k in training_only:
      continue
    if k in bool_params:
      # Categorical params are stored as a float index into the choices;
      # index 0 is mapped to True. NOTE(review): assumes the study declared
      # the choices as [True, False] — confirm against the tuning script.
      best_hyperparams_dict[k] = v == 0.0
    elif k in int_params:
      best_hyperparams_dict[k] = int(v)
    else:
      best_hyperparams_dict[k] = v

  return best_hyperparams_dict, best_f1

In [9]:
# Run configuration: which Optuna run to load (model_dir/model_dir_), an output
# location (save_dir/save_dir_, not used in the cells shown here), the
# `mode` key into V2_MODEL / MAKE_GRAPH, and the model registry version.
model_dir = 'checkpoints_optuna'
model_dir_ = 'multimodal_topic_bilstm_proxy_v2'
save_dir = 'graph_visualization'
save_dir_ = 'multimodal_topic_bilstm_proxy_v2_id_405_ipynb'
mode = 'multimodal_topic_bilstm_proxy'
version = 2

run_dir = os.path.join(path_config.ROOT_DIR, model_dir, model_dir_)
best_model_path = os.path.join(run_dir, 'best_model.pth')
db_path = os.path.join(run_dir, 'logs', 'optuna_study.db')

# Explicit check instead of `assert`: asserts are stripped under `python -O`,
# and `logger.error(...)` used as an assert message evaluates to None.
missing_paths = [p for p in (best_model_path, db_path) if not os.path.exists(p)]
if missing_paths:
  logger.error("Model path is wrong. Try again.")
  raise FileNotFoundError(f"Missing required file(s): {missing_paths}")

In [10]:
logger.info(f"Processing data (Mode: {mode})")

test_df = pd.read_csv(os.path.join(path_config.DATA_DIR, 'full_test_split.csv'))
test_id = test_df.Participant_ID.tolist()
test_label = test_df.PHQ_Binary.tolist()

if mode not in MAKE_GRAPH:
  raise KeyError(f"No graph builder registered for mode '{mode}'. Available: {list(MAKE_GRAPH)}")

# Reset modality dims so values from a previous run with another mode cannot
# leak into the model-construction cell (hidden kernel state).
t_dim = v_dim = a_dim = None

if "multimodal" in mode:
  logger.info(f"Doing with multimodal mode")
  test_graphs, dim_list = MAKE_GRAPH[mode](
    ids = test_id,
    labels = test_label,                   # ground-truth PHQ_Binary labels of the test split
    model_name = 'sentence-transformers/all-MiniLM-L6-v2',
    use_summary_node = True,
    t_t_connect = False,
    v_a_connect = False
  )

  # dim_list is ordered (text, vision, audio)
  t_dim = dim_list[0]
  v_dim = dim_list[1]
  a_dim = dim_list[2]

else:
  logger.info(f"Doing with non-multimodal mode")
  # Build graphs for the whole test split and store them as `test_graphs`,
  # which is what the evaluation cells read.
  # NOTE(review): assumes the non-multimodal builders accept a list of labels
  # like the multimodal one — confirm against their signatures.
  test_graphs, dim_list, extras = MAKE_GRAPH[mode](
    ids = test_id,
    labels = test_label,
    model_name = 'sentence-transformers/all-MiniLM-L6-v2',
    use_summary_node = True,
    t_t_connect = False,
    explanation = True
  )

  t_dim = dim_list[0]
  if 'bimodal' in mode:
    v_dim = dim_list[1]

[32m13:25:39[0m | [1mINFO    [0m | [1mProcessing data (Mode: multimodal_topic_bilstm_proxy)[0m
[32m13:25:39[0m | [1mINFO    [0m | [1mDoing with multimodal mode[0m
[32m13:25:39[0m | [1mINFO    [0m | [1mGetting your model[0m
[32m13:25:42[0m | [1mINFO    [0m | [1mModel loaded[0m
[32m13:25:42[0m | [1mINFO    [0m | [1mSwitching CSV into Graphs[0m


Dataframe -> Graph: 100%|██████████| 45/45 [01:04<00:00,  1.43s/it]


In [11]:
best_hyperparams_dict, best_f1 = fetch_from_db(db_path)

# Emit a small report: a header, one indented line per tuned model
# hyperparameter, then the study's best objective value.
report_lines = [
  "Best Params",
  *(f"  - {name}: {value}" for name, value in best_hyperparams_dict.items()),
  f"=> F1-score: {best_f1}",
]
for line in report_lines:
  logger.info(line)

[32m13:26:46[0m | [1mINFO    [0m | [1mBest Params[0m
[32m13:26:46[0m | [1mINFO    [0m | [1m  - a_dropout: 0.39294858998728843[0m
[32m13:26:46[0m | [1mINFO    [0m | [1m  - bilstm_num_layers: 2[0m
[32m13:26:46[0m | [1mINFO    [0m | [1m  - g_dropout: 0.24654580705928375[0m
[32m13:26:46[0m | [1mINFO    [0m | [1m  - num_layers: 3[0m
[32m13:26:46[0m | [1mINFO    [0m | [1m  - t_dropout: 0.25237807640094945[0m
[32m13:26:46[0m | [1mINFO    [0m | [1m  - use_attention: True[0m
[32m13:26:46[0m | [1mINFO    [0m | [1m  - use_text_proj: False[0m
[32m13:26:46[0m | [1mINFO    [0m | [1m  - v_dropout: 0.3430548105111857[0m
[32m13:26:46[0m | [1mINFO    [0m | [1m=> F1-score: 0.7586206896551724[0m


In [12]:
logger.info("==============================")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Loading your model (Device: {device})")

# Explicit checks instead of `assert` (stripped under -O, and an assert
# message of `logger.error(...)` evaluates to None).
if version not in (1, 2):
  logger.error("Version should be int type 1 or 2")
  raise ValueError(f"Unsupported version: {version!r}")
# Only the v2 registry is defined in this notebook. Fail clearly for v1
# instead of hitting a NameError on `model_dict` below.
if version != 2:
  raise NotImplementedError(f"No model registry is defined for version {version} in this notebook")
model_dict = V2_MODEL

# Per-branch dropout rates from the Optuna study; 0.0 when a rate was not tuned.
dropout_dict = {
  'text_dropout':best_hyperparams_dict.get('t_dropout', 0.0),
  'graph_dropout':best_hyperparams_dict.get('g_dropout', 0.0),
  'vision_dropout':best_hyperparams_dict.get('v_dropout', 0.0),
  'audio_dropout':best_hyperparams_dict.get('a_dropout', 0.0)
}

# Architecture arguments must match those used at training time, or
# load_state_dict below will report missing/unexpected keys.
model = model_dict[mode](
  text_dim=t_dim,
  vision_dim=v_dim,
  audio_dim=a_dim,
  # 256 hidden channels when the text projection is enabled, otherwise the raw text dim
  hidden_channels=256 if best_hyperparams_dict['use_text_proj'] else t_dim,
  num_layers=best_hyperparams_dict['num_layers'],
  bilstm_num_layers=best_hyperparams_dict['bilstm_num_layers'],
  num_classes=2,
  dropout_dict=dropout_dict,
  heads=8,
  use_attention=best_hyperparams_dict['use_attention'],
  use_summary_node=True,
  use_text_proj=best_hyperparams_dict['use_text_proj']
).to(device)

# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (consider weights_only=True if the installed torch version supports it).
best_model_state_dict = torch.load(best_model_path, map_location=device)
model.load_state_dict(best_model_state_dict)

[32m13:26:49[0m | [1mINFO    [0m | [1mLoading your model (Device: cuda)[0m


<All keys matched successfully>

In [17]:
# Evaluate the tuned model on the test split. shuffle=False keeps batch order
# deterministic across runs.
test_loader = DataLoader(test_graphs, batch_size=8, shuffle=False)

model.eval()
val_acc, val_f1 = validation_gat(
  val_loader=test_loader,
  model=model,
  device=device,
  num_classes=2
)
# These metrics come from full_test_split.csv, so label them as test metrics
# and report accuracy as well. Variable names are kept for later cells.
# NOTE(review): this test F1 equals the Optuna study's best value. If trials
# were selected on this same split, it is not an unbiased held-out estimate.
logger.info(f"Test_Accuracy: {val_acc}")
logger.info(f"Test_F1_Score: {val_f1}")

Validation: 100%|█████████████████████████████████████████████████████████████| 6/6 [00:01<00:00,  4.65it/s, Acc=0.8444]

[32m13:30:02[0m | [1mINFO    [0m | [1mValidation_F1_Score: 0.7586206896551724[0m



