In [1]:
from tbparse import SummaryReader
import torch
# import required module
import os 
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from tqdm import tqdm

In [None]:
# Root directory of the prompt-based model logs. The cells visible below pass
# explicit path literals to TB_DataProcessor rather than reading this variable.
log_dir = "../prompt-based-models/logs/"

In [20]:
class TB_DataProcessor:
    '''
    Load and query the tensorboard events files produced by this repo's experiments.

    On construction every tfevents file found under ``log_dir`` is parsed and the
    scalar logs are stacked into a single DataFrame, ``self.all_logs``, annotated
    with run settings recovered from each file's path.
    '''

    # (substring looked for in the file path, label written to the column).
    # Markers are tried in order and the first one present in the path wins.
    _TEMPLATE_MARKERS = (
        ("tempmanual", "manual"),
        ("tempsoft", "soft"),
        ("tempmixed", "mixed"),
    )
    _VERBALIZER_MARKERS = (
        ("verbmanual", "manual"),
        ("verbsoft", "soft"),
    )

    def __init__(self, log_dir):
        '''
        Args:
            log_dir: directory that is walked recursively for tfevents files.
        '''
        # set the log dir
        self.log_dir = log_dir

        # concatenate all the log files in the directory
        self.all_logs = self.concatenate_all_logs(self.log_dir)

    def concatenate_all_logs(self, log_dir):
        '''
        Read every tensorboard events file under ``log_dir`` and stack the scalar
        logs into one DataFrame, adding columns derived from each file's path -
        as per the setup of the training.

        Args:
            log_dir: directory to walk. The path layout is assumed to look like
                "../<training_type>/logs/<task>/...": components 1 and 3 of the
                "/"-separated file path become ``training_type`` and ``task``.

        Returns:
            DataFrame holding the scalar rows of every events file plus the
            columns ``training_type``, ``task`` and ``plm_frozen``. ``temp_type``
            and ``verb_type`` are only added for files whose path contains one of
            the known markers. The index runs 0..n-1 over the combined rows.

        Raises:
            ValueError: if no tfevents file is found under ``log_dir``.
            IndexError: if a file path has fewer than four "/"-separated parts.
        '''
        print("concatenating all log files!")
        # empty list to fill with one DataFrame per events file
        dfs = []
        for root, _dirs, files in tqdm(os.walk(log_dir), desc = "Finding logs!"):
            for filename in files:
                # tensorboard file names are unwieldy, so just look for the
                # "tfevents" substring. Test the file name rather than the whole
                # path, otherwise a directory called e.g. "tfevents_runs" would
                # make every file beneath it look like an events file.
                if "tfevents" not in filename:
                    continue

                full_path = os.path.join(root, filename)

                # parse the tensorboard summary data
                df = SummaryReader(full_path).scalars

                # training type (e.g. prompt learning or traditional) and the
                # dataset/task name are read from fixed positions in the path.
                # os.path.join may insert a platform-specific separator, so
                # normalise to "/" before splitting (no-op where os.sep is "/").
                path_parts = full_path.replace(os.sep, "/").split("/")
                df["training_type"] = path_parts[1]
                df["task"] = path_parts[3]

                # was the plm frozen?
                df["plm_frozen"] = 'frozen_plm' in full_path

                # template type and verbalizer type; a column is left out
                # entirely when none of its markers appears in the path
                for column, markers in (
                    ("temp_type", self._TEMPLATE_MARKERS),
                    ("verb_type", self._VERBALIZER_MARKERS),
                ):
                    label = next(
                        (name for marker, name in markers if marker in full_path),
                        None,
                    )
                    if label is not None:
                        df[column] = label

                dfs.append(df)

        if not dfs:
            # pd.concat([]) would fail with an opaque "No objects to concatenate"
            raise ValueError(
                "No tfevents files found under log_dir: {!r}".format(log_dir)
            )

        # ignore_index gives the combined frame one unique 0..n-1 index instead
        # of repeating each file's own row numbers
        return pd.concat(dfs, ignore_index=True)

    def extract_metric(self, mode = "all",metrics = "all"):
        '''
        Pull the rows for a given dataset split and/or metric from the full logs.

        Both arguments are matched against the ``tag`` column with
        ``Series.str.contains``, i.e. as patterns found anywhere in the tag.

        Args:
            mode: The dataset you want metrics for e.g. train/valid/test, or
                "all" to not filter on the dataset.
            metrics: The metric you want to look at e.g. f1, precision etc., or
                "all" to not filter on the metric.

        Returns:
            ``self.all_logs`` itself when both arguments are "all"; otherwise a
            filtered DataFrame with a fresh 0..n-1 index.
        '''
        all_logs = self.all_logs

        # build one boolean mask out of whichever filters were requested
        mask = None
        for pattern in (metrics, mode):
            if pattern == "all":
                continue
            hits = all_logs["tag"].str.contains(pattern)
            mask = hits if mask is None else mask & hits

        # nothing to filter on: hand back the full logs untouched
        if mask is None:
            return all_logs

        return all_logs[mask].reset_index(drop=True)
        
        

# load prompt logs

In [21]:
# Build the processor for the prompt-based icd9_triage runs; the constructor
# walks this directory and concatenates every tfevents file it finds.
prompt_tb_data_processor = TB_DataProcessor(
    log_dir="../prompt-based-models/logs/icd9_triage/"
)

concatenating all log files!


Finding logs!: 28it [00:03,  8.48it/s]


In [22]:
# show the concatenated scalar logs of the prompt-based runs
prompt_tb_data_processor.all_logs

Unnamed: 0,step,tag,value,training_type,task,plm_frozen,temp_type,verb_type
0,0,train/batch_loss,2.553897,prompt-based-models,icd9_triage,False,manual,soft
1,1,train/batch_loss,2.464560,prompt-based-models,icd9_triage,False,manual,soft
2,2,train/batch_loss,2.555808,prompt-based-models,icd9_triage,False,manual,soft
3,3,train/batch_loss,2.480603,prompt-based-models,icd9_triage,False,manual,soft
4,4,train/batch_loss,2.439957,prompt-based-models,icd9_triage,False,manual,soft
...,...,...,...,...,...,...,...,...
11999,9,valid/recall,0.938343,prompt-based-models,icd9_triage,True,soft,manual
12000,0,zero_shot/accuracy,0.132030,prompt-based-models,icd9_triage,True,soft,manual
12001,0,zero_shot/f1,0.062982,prompt-based-models,icd9_triage,True,soft,manual
12002,0,zero_shot/precision,0.444971,prompt-based-models,icd9_triage,True,soft,manual


In [23]:
# keep only the rows whose tag contains both "valid" and "f1"
prompt_metrics = prompt_tb_data_processor.extract_metric(
    mode="valid",
    metrics="f1",
)

In [24]:
# highest logged value for each verbalizer type / template type combination
(
    prompt_metrics
    .groupby(["verb_type", "temp_type"])["value"]
    .max()
)

verb_type  temp_type
manual     manual       0.946914
           mixed        0.948727
           soft         0.949100
soft       manual       0.949605
           mixed        0.949256
           soft         0.946888
Name: value, dtype: float64

In [29]:
# logged values for the runs using a manual template and a manual verbalizer
prompt_metrics.loc[
    (prompt_metrics["temp_type"] == "manual")
    & (prompt_metrics["verb_type"] == "manual"),
    "value",
]

30    0.943957
31    0.945584
32    0.934010
33    0.944772
34    0.945066
35    0.944072
36    0.943878
37    0.943819
38    0.946914
39    0.945518
Name: value, dtype: float64

In [None]:
# prompt_metrics[(prompt_metrics["temp_type"]=="soft") &(prompt_metrics["verb_type"]=="soft")]

# same for traditional learning

In [25]:
# Build the processor for the traditional (clinical-longformer) runs; the
# constructor walks this directory and concatenates every tfevents file it finds.
trad_tb_data_processor = TB_DataProcessor(
    log_dir="../clinical-longformer/logs/"
)

concatenating all log files!


Finding logs!: 6it [00:00, 194.94it/s]


In [26]:
# show the concatenated scalar logs of the clinical-longformer runs
trad_tb_data_processor.all_logs

Unnamed: 0,step,tag,value,training_type,task,plm_frozen
0,49,epoch,0.000000,clinical-longformer,icd9_triage,False
1,99,epoch,0.000000,clinical-longformer,icd9_triage,False
2,149,epoch,0.000000,clinical-longformer,icd9_triage,False
3,199,epoch,0.000000,clinical-longformer,icd9_triage,False
4,249,epoch,0.000000,clinical-longformer,icd9_triage,False
...,...,...,...,...,...,...
109,0,valid/f1,0.944510,clinical-longformer,icd9_triage,False
110,0,valid/prec,0.062500,clinical-longformer,icd9_triage,False
111,0,valid/prec,0.948700,clinical-longformer,icd9_triage,False
112,0,valid/recall,0.250000,clinical-longformer,icd9_triage,False


In [27]:
# keep only the rows whose tag contains both "valid" and "f1"
trad_metrics = trad_tb_data_processor.extract_metric(
    mode="valid",
    metrics="f1",
)

In [28]:
# show the filtered valid/f1 rows for the clinical-longformer runs
# NOTE(review): the recorded output has two valid/f1 rows at step 0 (0.1 and
# 0.94451) - presumably a pre-training sanity pass plus the real validation
# result; confirm before comparing against the prompt-based numbers.
trad_metrics

Unnamed: 0,step,tag,value,training_type,task,plm_frozen
0,0,valid/f1,0.1,clinical-longformer,icd9_triage,False
1,0,valid/f1,0.94451,clinical-longformer,icd9_triage,False
