In [1]:
import pandas as pd
import glob
import os
from utils import util
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [2]:
def get_attention_weight_hist(run_name, epoch, legend_list=None,
                              root_dir="/mnt/iot-qnap3/mochida/medical-care/emotionestimation/reports/PIMD_A"):
    """Plot histograms of per-stream attention weights, split by predicted label.

    Every ``*attention_weights.csv`` found (recursively) under
    ``<root_dir>/<run_name>/epoch<epoch>`` is merged on ``img_path`` with the
    ``*pred.csv`` from the same sorted position. The merged rows are
    concatenated. For each ``stream<k>`` column, one overlaid histogram is drawn
    per panel:

    * the left panel ("pred positive") uses rows with ``emo_pred == 1``;
    * the right panel ("pred negative") uses rows with ``emo_pred == 0``.

    Args:
        run_name: Run directory name below ``root_dir``.
        epoch: Epoch number; selects the ``epoch{epoch}`` sub-directory.
        legend_list: Legend label per stream, in stream-index order. When
            ``None``, the column names (``stream0``, ``stream1``, ...) are used.
        root_dir: Reports root directory (defaults to the original hard-coded path).

    Returns:
        pandas.DataFrame: All merged rows, sorted by ``img_path`` with a fresh
        RangeIndex.

    Raises:
        FileNotFoundError: If no attention-weight CSV is found.
        ValueError: If the numbers of weight and prediction CSVs differ, or if
            ``legend_list`` has fewer labels than there are stream columns.
    """
    target_dir = os.path.join(root_dir, run_name, f"epoch{epoch}")

    # glob's result order depends on the filesystem. Sort both lists so the
    # i-th weights file and the i-th pred file come from the same directory.
    weight_files = sorted(glob.glob(os.path.join(target_dir, "**", "*attention_weights.csv"), recursive=True))
    pred_files = sorted(glob.glob(os.path.join(target_dir, "**", "*pred.csv"), recursive=True))
    if not weight_files:
        raise FileNotFoundError(f"No *attention_weights.csv found under {target_dir}")
    if len(weight_files) != len(pred_files):
        # zip() would silently drop the surplus and mis-pair the rest.
        raise ValueError(
            f"Found {len(weight_files)} attention-weight CSVs but {len(pred_files)} pred CSVs under {target_dir}"
        )

    # Weights and preds share the 'img_path' column, so each pair is merged on it.
    # Collect the merged frames first and concatenate once (avoids quadratic growth).
    merged_frames = [
        pd.read_csv(weight_file).merge(pd.read_csv(pred_file), on='img_path')
        for weight_file, pred_file in zip(weight_files, pred_files)
    ]
    att_weights_list = (
        pd.concat(merged_frames)
        .sort_values(by=['img_path'])
        .reset_index(drop=True)
    )

    # Select the attention columns by name (stream0, stream1, ...) instead of
    # assuming a fixed number of non-stream columns.
    stream_cols = sorted(
        (c for c in att_weights_list.columns
         if isinstance(c, str) and c.startswith('stream') and c[len('stream'):].isdigit()),
        key=lambda c: int(c[len('stream'):]),
    )
    if legend_list is None:
        legend_list = list(stream_cols)
    if len(legend_list) < len(stream_cols):
        raise ValueError(f"legend_list has {len(legend_list)} labels but there are {len(stream_cols)} stream columns")

    fig = make_subplots(rows=1, cols=2, subplot_titles=('pred positive', 'pred negative'))
    color_list = ['#636EFA', '#EF553B', '#00CC96', '#AB63FA', '#FFA15A', '#19D3F3', '#FF6692', '#B6E880', '#FF97FF', '#FECB52']

    # Each panel is (emo_pred value, subplot column, show legend entries).
    # Only the first panel adds legend entries. legendgroup links them to the
    # matching trace in the second panel.
    for pred_value, subplot_col, show_legend in ((1, 1, True), (0, 2, False)):
        subset = att_weights_list[att_weights_list['emo_pred'] == pred_value]
        for i, stream_col in enumerate(stream_cols):
            fig.add_trace(
                go.Histogram(
                    x=subset[stream_col],
                    name=legend_list[i],
                    legendgroup=str(legend_list[i]),
                    # Cycle the palette so more than 10 streams do not raise IndexError.
                    marker=dict(color=color_list[i % len(color_list)]),
                    showlegend=show_legend,
                ),
                row=1, col=subplot_col,
            )
    fig.update_traces(opacity=0.5, histnorm='probability', xbins=dict(start=0, end=1, size=0.05))
    fig.update_yaxes(range=[0, 1], title_text='probability')
    fig.update_xaxes(title_text='attention weight')
    fig.update_layout(
        barmode='overlay',
        legend=dict(
            orientation="h"
            ),
        margin=dict(l=0, r=10, t=30, b=0)
    )

    fig.show()

    return att_weights_list

In [3]:
def get_attention_weight_pattern(run_name, epoch, video_name, legend_list,
                                 root_dir="/mnt/iot-qnap3/mochida/medical-care/emotionestimation/reports/PIMD_A"):
    """Scatter-plot the per-frame attention weight of every stream for one video.

    Every ``*attention_weights.csv`` found (recursively) under
    ``<root_dir>/<run_name>/epoch<epoch>`` is merged on ``img_path`` with the
    ``*pred.csv`` from the same sorted position. The merged rows are
    concatenated and restricted to rows whose video name equals ``video_name``.
    Each ``stream<k>`` column is then plotted against the frame number.

    Args:
        run_name: Run directory name below ``root_dir``.
        epoch: Epoch number; selects the ``epoch{epoch}`` sub-directory.
        video_name: Video to plot. It is compared with the first value that
            ``util.get_video_name_and_frame_num`` yields for each ``img_path``.
        legend_list: Legend label per stream, in stream-index order.
        root_dir: Reports root directory (defaults to the original hard-coded path).

    Raises:
        FileNotFoundError: If no attention-weight CSV is found.
        ValueError: If the numbers of weight and prediction CSVs differ, or if
            ``legend_list`` has fewer labels than there are stream columns.
    """
    target_dir = os.path.join(root_dir, run_name, f"epoch{epoch}")

    # glob's result order depends on the filesystem. Sort both lists so the
    # i-th weights file and the i-th pred file come from the same directory.
    weight_files = sorted(glob.glob(os.path.join(target_dir, "**", "*attention_weights.csv"), recursive=True))
    pred_files = sorted(glob.glob(os.path.join(target_dir, "**", "*pred.csv"), recursive=True))
    if not weight_files:
        raise FileNotFoundError(f"No *attention_weights.csv found under {target_dir}")
    if len(weight_files) != len(pred_files):
        # zip() would silently drop the surplus and mis-pair the rest.
        raise ValueError(
            f"Found {len(weight_files)} attention-weight CSVs but {len(pred_files)} pred CSVs under {target_dir}"
        )

    # Collect merged frames and concatenate once (avoids quadratic growth).
    merged = (
        pd.concat([
            pd.read_csv(weight_file).merge(pd.read_csv(pred_file), on='img_path')
            for weight_file, pred_file in zip(weight_files, pred_files)
        ])
        .sort_values(by=['img_path'])
        .reset_index(drop=True)
    )

    # The helper's result is unpacked into (video_name, frame_num). Building the
    # columns from explicit lists also works when `merged` is empty, whereas
    # zip(*...) unpacking would raise.
    name_frame_pairs = list(merged['img_path'].map(util.get_video_name_and_frame_num))
    merged['video_name'] = [pair[0] for pair in name_frame_pairs]
    merged['frame_num'] = [pair[1] for pair in name_frame_pairs]

    video_rows = merged[merged['video_name'] == video_name].reset_index(drop=True)

    # Select the attention columns by name (stream0, stream1, ...) instead of
    # assuming a fixed number of non-stream columns.
    stream_cols = sorted(
        (c for c in video_rows.columns
         if isinstance(c, str) and c.startswith('stream') and c[len('stream'):].isdigit()),
        key=lambda c: int(c[len('stream'):]),
    )
    if len(legend_list) < len(stream_cols):
        raise ValueError(f"legend_list has {len(legend_list)} labels but there are {len(stream_cols)} stream columns")

    # One marker trace per stream: frame number on x, attention weight on y.
    fig = go.Figure()
    for i, stream_col in enumerate(stream_cols):
        fig.add_trace(go.Scatter(x=video_rows['frame_num'], y=video_rows[stream_col], name=legend_list[i], mode='markers'))
    fig.update_yaxes(range=[0, 1], title_text='attention weight')
    fig.update_xaxes(title_text='frame id')
    fig.update_layout(
        legend=dict(
            orientation="h"
            ),
        margin=dict(l=0, r=10, t=30, b=0)
    )
    fig.show()

In [17]:
# Attention-weight histograms for run 4_d_ag_wsum-intlogits-1dcnn_ws300-ss3 at
# epoch 3, with legend labels AU and Gaze. The merged table it returns is kept in `a`.
a = get_attention_weight_hist(
    run_name="4_d_ag_wsum-intlogits-1dcnn_ws300-ss3",
    epoch=3,
    legend_list=['AU', 'Gaze'],
)

In [5]:
# Per-frame attention weights of video21 for run 4_d_agh_wsum at epoch 5,
# with legend labels AU, Gaze and HP.
get_attention_weight_pattern(
    run_name="4_d_agh_wsum",
    epoch=5,
    video_name='video21',
    legend_list=['AU', 'Gaze', 'HP'],
)