In [149]:
import os
import json
import matplotlib.pyplot as plt
from datetime import datetime

class MetricsAnalyzer:
    """Plot per-file processing times for experiment folders under a root.

    A directory below ``root_dir`` is treated as an experiment folder when it
    contains ``metrics.json``, ``logs.json`` and ``params.json``. For each such
    folder the per-file durations are read from ``logs.json`` and saved as a
    bar chart whose title and file name are derived from ``params.json``.
    """

    # Output directory for the plots, relative to the current working directory.
    _FIG_DIR = 'plot/time_folder'

    # Files that must all be present for a directory to be analysed.
    _REQUIRED_FILES = ('metrics.json', 'logs.json', 'params.json')

    def __init__(self, root_dir):
        """Remember the directory that ``analyze`` will walk recursively."""
        self.root_dir = root_dir

    def read_json(self, file_path):
        """Return the parsed content of the JSON file at ``file_path``."""
        # Explicit encoding so parsing does not depend on the platform locale.
        with open(file_path, 'r', encoding='utf-8') as file:
            return json.load(file)

    @staticmethod
    def _labelled(label, value):
        """Return ``"label: value"``, or ``''`` when ``value`` is ``None``."""
        return '' if value is None else f"{label}: {value}"

    @staticmethod
    def _labelled_options(label, options):
        """Return ``"label: a, b, ..."`` with the options sorted.

        A missing/``null`` list (or the placeholder ``[None]``) yields ``''``.
        """
        if options is None or options == [None]:
            return ''
        joined = ', '.join(str(option) for option in sorted(options))
        return f"{label}: {joined}"

    def get_params_string(self, params):
        """Build a multi-line, human-readable summary of ``params``.

        Args:
            params: Mapping loaded from ``params.json``. The optional
                sub-mappings ``summary``, ``ranking``, ``entity`` and
                ``relation`` are read; absent or ``null`` ones are ignored.

        Returns:
            One ``"Label: value"`` line per parameter that is set, joined with
            newlines; an empty string when nothing is set.
        """
        # ``or {}`` also covers sections that are present but null in the JSON.
        summary = params.get('summary') or {}
        ranking = params.get('ranking') or {}
        entity = params.get('entity') or {}
        relation = params.get('relation') or {}

        parts = [
            self._labelled('Preprocess', params.get('preprocess')),
            self._labelled('Summary Method', summary.get('summary_method')),
            self._labelled('Summary Percentage', summary.get('summary_percentage')),
            self._labelled('Ranking Method', ranking.get('ranking')),
            self._labelled('Ranking Perc Threshold', ranking.get('ranking_perc_threshold')),
            self._labelled_options('Options Ent', entity.get('options_ent')),
            self._labelled_options('Options Rel', relation.get('options_rel')),
            self._labelled('Local RM', relation.get('local_rm')),
            self._labelled('Rebel Model', relation.get('rebel_model')),
        ]
        return '\n'.join(filter(None, parts))

    def calculate_duration(self, start, end):
        """Return the seconds elapsed between two ISO-format timestamps.

        Args:
            start: Timestamp string accepted by ``datetime.fromisoformat``.
            end: Timestamp string accepted by ``datetime.fromisoformat``.

        Returns:
            ``end - start`` in seconds as a float (negative if ``end`` is
            earlier than ``start``).
        """
        start_time = datetime.fromisoformat(start)
        end_time = datetime.fromisoformat(end)
        return (end_time - start_time).total_seconds()

    def normalize_times(self, times, num_files):
        """Return every value of ``times`` divided by ``num_files``."""
        return [time / num_files for time in times]

    @staticmethod
    def _plot_filename(params_str):
        """Return the PNG file name used for the plot of ``params_str``.

        Newlines become spaces. Path separators are replaced with ``_`` because
        a parameter value containing one would otherwise be interpreted as a
        sub-directory and make saving the figure fail.
        """
        flat = params_str.replace('\n', ' ')
        for separator in ('/', '\\'):
            flat = flat.replace(separator, '_')
        return f"{flat}_time_plot.png"

    def plot_times(self, file_ids, normalized_times, params_str):
        """Save a bar chart of ``normalized_times`` per file id as a PNG.

        Args:
            file_ids: Labels for the x axis, one per bar.
            normalized_times: Bar heights, in the same order as ``file_ids``.
            params_str: Parameter summary shown in the title and used to build
                the output file name. Plots with the same summary are written
                to the same file, so a later one overwrites an earlier one.
        """
        os.makedirs(self._FIG_DIR, exist_ok=True)
        figure = plt.figure(figsize=(10, 6))
        try:
            plt.bar(file_ids, normalized_times, color='skyblue')
            plt.xlabel('File ID')
            plt.ylabel('Normalized Time (seconds)')
            plt.title(f'Normalized Time for each file\n{params_str}', fontsize=14, fontweight='bold')
            plt.grid(axis='y', linestyle='--', alpha=0.7)
            plt.xticks(rotation=45)
            plt.tight_layout()
            plt.savefig(os.path.join(self._FIG_DIR, self._plot_filename(params_str)))
        finally:
            # Always release the figure, even when plotting or saving fails.
            plt.close(figure)

    def _collect_durations(self, logs):
        """Return ``(file_ids, durations)`` for the complete entries of ``logs``.

        The ``'finished'`` key is not a file entry and is skipped, as is any
        entry that is not a mapping holding both ``'start'`` and ``'end'``
        (for example a file whose processing never completed).
        """
        file_ids = []
        durations = []
        for file_id, log in logs.items():
            if file_id == 'finished':
                continue
            if not isinstance(log, dict) or 'start' not in log or 'end' not in log:
                continue
            durations.append(self.calculate_duration(log['start'], log['end']))
            file_ids.append(file_id)
        return file_ids, durations

    def analyze(self):
        """Walk ``root_dir`` and save one timing plot per experiment folder.

        Only sub-directories are inspected (``root_dir`` itself is not), and
        only those containing all of ``metrics.json``, ``logs.json`` and
        ``params.json``. ``metrics.json`` must exist but its content is not
        used here.
        """
        for root, dirs, _ in os.walk(self.root_dir):
            for dir_name in dirs:
                folder = os.path.join(root, dir_name)
                if not all(
                    os.path.exists(os.path.join(folder, name))
                    for name in self._REQUIRED_FILES
                ):
                    continue

                logs = self.read_json(os.path.join(folder, 'logs.json'))
                params = self.read_json(os.path.join(folder, 'params.json'))
                params_str = self.get_params_string(params)

                file_ids, times = self._collect_durations(logs)
                # "Normalized" means each duration divided by the file count.
                normalized_times = self.normalize_times(times, len(file_ids))
                self.plot_times(file_ids, normalized_times, params_str)

In [150]:
# Usage
# Directory that is walked recursively for experiment folders.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
root_dir = '/Users/martina/Desktop/concept_map/experiments'
analyzer = MetricsAnalyzer(root_dir)
# Walk root_dir and save a timing bar chart for each experiment folder found.
analyzer.analyze()