In [46]:
# Reload all imported modules (e.g. the local args_parser) before each cell
# executes, so edits to project code are picked up without restarting the
# kernel; render matplotlib figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [47]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import os
import sys
# TODO: update this path
sys.path.append('/nfshomes/vla/low_bit_vision/lavis_clone/lavis')

# args parser for blip-2 runs
from args_parser import args_parser

import ast
from glob import glob
import re
import mmap

In [48]:
def load_result(file_path):
    """Parse a BLIP-2 evaluation log into a metrics dict.

    The last non-blank line of the log must contain ``[INFO] `` followed by a
    Python dict literal of metrics. The first ``[Model Size]: <number>`` line
    anywhere in the file supplies the model size.

    Args:
        file_path: path to the log file.

    Returns:
        dict of the parsed metrics plus a ``'model_size'`` (float) entry.

    Raises:
        ValueError: if the trailing metrics line or the model-size line is
            missing.
    """
    # Plain read-only text mode: the previous 'r+' + mmap combination required
    # write permission on a log we only read, and never closed the mapping.
    with open(file_path, 'r') as f:
        text = f.read()

    # Ignore trailing blank lines so a final newline or two doesn't break parsing.
    lines = [line for line in text.splitlines() if line.strip()]
    if not lines or '[INFO] ' not in lines[-1]:
        raise ValueError(f"no '[INFO] <metrics dict>' line at end of {file_path}")
    # literal_eval only accepts Python literals, so it is safe on log text.
    result = ast.literal_eval(lines[-1].split('[INFO] ', 1)[1])

    match = re.search(r'\[Model Size\]: (.*)', text)
    if match is None:
        raise ValueError(f"no '[Model Size]: ...' line in {file_path}")
    result['model_size'] = float(match.group(1))

    return result

In [49]:
# Baseline run: parsed once here and merged into the sweep table further down.
baseline_log = os.path.join('..', 'results', 'blip2_flickr', 'blip2_flickr_baseline')
baseline_result = load_result(baseline_log)
baseline_result

{'txt_r1': 97.6,
 'txt_r5': 100.0,
 'txt_r10': 100.0,
 'txt_r_mean': 99.2,
 'img_r1': 89.74,
 'img_r5': 98.18,
 'img_r10': 98.94,
 'img_r_mean': 95.62,
 'r_mean': 97.41,
 'agg_metrics': 99.2,
 'model_size': 4782.180084}

In [50]:
# Parser for the BLIP-2 command line; used to turn each now.txt line back into args.
parser = args_parser()
# Sweep root: one sub-folder per launch, each holding now.txt and numbered result logs.
results_dir = os.path.join(*('..', 'results', 'blip2_flickr', 'uniform_quant'))
results_dir

'../results/blip2_flickr/uniform_quant'

In [54]:
# Number of leading tokens on each now.txt line that belong to the
# torch.distributed launcher rather than to the BLIP-2 args parser.
# NOTE(review): assumes every launch line has exactly this prefix length — confirm.
N_LAUNCHER_TOKENS = 5


def load_run_args(run_dir, parser, n_skip=N_LAUNCHER_TOKENS):
    """Parse each non-blank line of ``run_dir/now.txt`` into a dict of CLI args."""
    with open(os.path.join(run_dir, 'now.txt'), 'r') as f:
        # blank lines would otherwise parse to an all-defaults row
        return [vars(parser.parse_args(line.split()[n_skip:]))
                for line in f if line.strip()]


def load_run_metrics(run_dir):
    """Load every numbered result log in ``run_dir``, indexed by its leading number.

    Returns an empty DataFrame when the folder has no result files yet.
    """
    records = []
    for result_path in glob(os.path.join(run_dir, '[0-9]*.txt')):
        result = load_result(result_path)
        # the glob guarantees the name starts with digits; they index the launch
        result['index'] = int(re.match(r'\d+', os.path.basename(result_path)).group(0))
        records.append(result)
    if not records:
        return pd.DataFrame()
    return pd.DataFrame(records).set_index('index').sort_index()


def load_sweep(results_dir, parser):
    """Combine CLI args and metrics for every run folder under ``results_dir``.

    Raises:
        ValueError: if ``results_dir`` contains no run folders.
    """
    frames = []
    # sorted() gives a deterministic row order across machines
    for folder in sorted(os.listdir(results_dir)):
        run_dir = os.path.join(results_dir, folder)
        if not os.path.isdir(run_dir):
            continue
        df_args = pd.DataFrame(load_run_args(run_dir, parser))
        df_metrics = load_run_metrics(run_dir)
        # concat(axis=1) aligns on index labels: row i of now.txt pairs with
        # the result file numbered i; missing results show up as NaN metrics.
        frames.append(pd.concat([df_args, df_metrics], axis=1))
    if not frames:
        raise ValueError(f'no run folders found in {results_dir}')
    # single concat instead of growing a frame inside the loop
    return pd.concat(frames, axis=0)


df_results = load_sweep(results_dir, parser).drop(columns=['cfg_path', 'options'])

In [60]:
# Attach the baseline run to the sweep table. Without `on=`, merge joins on
# every column the two frames share (the metric columns and model_size), so
# the baseline lands in its own row unless a sweep row matches it on all of them.
df_results = pd.DataFrame([baseline_result]).merge(df_results, how='outer')

# Make leftover None entries uniform np.nan missing values.
df_results = df_results.fillna(np.nan)

df_results.sort_values(by='model_size', ascending=False)

Unnamed: 0,txt_r1,txt_r5,txt_r10,txt_r_mean,img_r1,img_r5,img_r10,img_r_mean,r_mean,agg_metrics,...,qformer_text_ff_weight_bits,qformer_img_ff_modules,qformer_img_ff_weight_bits,qformer_cls_modules,qformer_cls_transform_weight_bits,qformer_cls_decoder_weight_bits,output_modules,vision_proj_weight_bits,text_proj_weight_bits,itm_head_weight_bits
10,97.6,100.0,100.0,99.2,89.74,98.18,98.94,95.62,97.41,99.2,...,,,,,,,,,,
13,97.8,100.0,100.0,99.266667,89.62,98.2,98.96,95.593333,97.43,99.266667,...,,,,,,,,,,
15,97.8,100.0,100.0,99.266667,89.82,98.16,98.96,95.646667,97.456667,99.266667,...,,,,,,,,,,
8,97.5,100.0,100.0,99.166667,89.24,98.08,98.86,95.393333,97.28,99.166667,...,,,,,,,,,,
5,15.1,26.3,33.4,24.933333,19.88,33.8,40.4,31.36,28.146667,24.933333,...,,,,,,,,,,
1,0.1,0.5,0.8,0.466667,0.14,0.64,1.2,0.66,0.563333,0.466667,...,,,,,,,,,,
11,97.6,100.0,100.0,99.2,89.74,98.2,98.96,95.633333,97.416667,99.2,...,,,,,,,,,,
9,97.6,100.0,100.0,99.2,89.42,98.16,99.02,95.533333,97.366667,99.2,...,,,,,,,,,,
7,63.4,82.1,87.9,77.8,62.98,82.98,88.88,78.28,78.04,77.8,...,,,,,,,,,,
0,0.1,0.3,0.6,0.333333,0.14,0.76,1.54,0.813333,0.573333,0.333333,...,,,,,,,,,,


In [59]:
# Export the results table without its row index.
df_results.to_csv('blip2_flickr_results.csv', index=False)