In [1]:
import glob
import os
import re
from typing import Dict, List, Optional

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

In [None]:
# Configure seaborn once for the whole notebook: white-grid style, "paper" context, fonts scaled by 1.2.
sns.set_theme(style="whitegrid", context="paper", font_scale=1.2)

In [7]:
# Unsigned decimal or scientific-notation literal, e.g. 0.001, 5, .5, 1e-05.
_UNSIGNED_NUMBER = r'(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?'
# Values written to the log may additionally carry a sign or be non-finite.
_LOGGED_VALUE = rf'[-+]?(?:{_UNSIGNED_NUMBER}|nan|inf)'

_DECAY_RATE_RE = re.compile(rf'run_param_({_UNSIGNED_NUMBER})\.log')
_TRAIN_STAT_RE = re.compile(
    r'Training average (positive_sample_loss|negative_sample_loss|loss|current_temperature)'
    rf' at step (\d+): ({_LOGGED_VALUE})'
)
_EVAL_METRIC_RE = re.compile(
    rf'(Valid|Test) (MRR|MR|HITS@10|HITS@3|HITS@1) at step (\d+): ({_LOGGED_VALUE})'
)

# Log statistic name -> column name in the per-step table.
_TRAIN_STAT_KEYS = {
    'positive_sample_loss': 'pos_loss',
    'negative_sample_loss': 'neg_loss',
    'loss': 'avg_loss',
    'current_temperature': 'temp',
}
# Log metric name -> key in the returned metric dicts.
_EVAL_METRIC_KEYS = {
    'MRR': 'mrr',
    'MR': 'mr',
    'HITS@1': 'hits1',
    'HITS@3': 'hits3',
    'HITS@10': 'hits10',
}


def parse_log_file(filepath: str, final_step: int = 4999) -> Optional[Dict]:
    """
    Parse a single log file to extract decay rate, training statistics over steps, and final metrics.

    The decay rate is taken from the file name (``run_param_<rate>.log``). Each
    ``Training average <name> at step <n>: <value>`` entry is merged into one record
    per step, and ``Valid``/``Test`` metric entries reported at ``final_step`` are
    collected as the final evaluation results.

    Args:
        filepath: Path to the log file.
        final_step: Step whose ``Valid``/``Test`` metrics are treated as final (default 4999).

    Returns:
        A dict with keys ``'decay_rate'`` (float), ``'losses'`` (DataFrame with a ``step``
        column plus whichever of ``pos_loss``, ``neg_loss``, ``avg_loss``, ``temp`` were
        logged), ``'valid_metrics'`` and ``'test_metrics'`` (dicts keyed by ``mrr``, ``mr``,
        ``hits1``, ``hits3``, ``hits10``). Returns ``None``, after printing a warning, when
        the file name carries no decay rate or no test metric is found at ``final_step``.
    """
    # Extract decay rate from filename
    filename = os.path.basename(filepath)
    decay_match = _DECAY_RATE_RE.search(filename)
    if not decay_match:
        print(f"Warning: Could not extract decay rate from filename {filename}. Skipping.")
        return None
    decay_rate = float(decay_match.group(1))

    # Keyed by step so statistics for the same step merge in O(1), in whatever
    # order the log happens to emit them; insertion order keeps the log's step order.
    stats_by_step = {}
    valid_metrics = {}
    test_metrics = {}

    # Only ASCII patterns are parsed, so undecodable bytes are replaced rather than
    # letting a platform-specific default encoding abort the whole file.
    with open(filepath, 'r', encoding='utf-8', errors='replace') as f:
        for line in f:
            for stat_name, step_text, value_text in _TRAIN_STAT_RE.findall(line):
                step = int(step_text)
                record = stats_by_step.setdefault(step, {'step': step})
                record[_TRAIN_STAT_KEYS[stat_name]] = float(value_text)

            for split, metric, step_text, value_text in _EVAL_METRIC_RE.findall(line):
                if int(step_text) != final_step:
                    continue
                target = valid_metrics if split == 'Valid' else test_metrics
                target[_EVAL_METRIC_KEYS[metric]] = float(value_text)

    # Ensure we have valid data before returning
    if not test_metrics:
        print(f"Warning: No test metrics found in {filename}. Skipping.")
        return None

    return {
        'decay_rate': decay_rate,
        'losses': pd.DataFrame(list(stats_by_step.values())),
        'valid_metrics': valid_metrics,
        'test_metrics': test_metrics
    }

In [8]:
def collect_data(log_files: List[str]) -> List[Dict]:
    """
    Run the log parser over each path and keep only the results that came back usable.

    A path whose parse result is empty or None is reported on stdout and omitted
    from the returned list; the order of the remaining results follows ``log_files``.
    """
    parsed_runs = []
    for log_path in log_files:
        result = parse_log_file(log_path)
        if not result:
            print(f"Skipping file {log_path} due to parsing issues.")
            continue
        parsed_runs.append(result)
    return parsed_runs

In [9]:
def _decay_axis_scale(decay_rates: pd.Series):
    """Choose an x-axis scale for decay rates: 'log' if all are positive, 'symlog' if some are not."""
    positive = decay_rates[decay_rates > 0]
    if positive.empty:
        return 'linear', {}
    if len(positive) < len(decay_rates):
        # A log axis has no position for values <= 0, so such runs would vanish from
        # the plot. symlog is linear inside [-linthresh, linthresh] and logarithmic
        # outside, which keeps them visible next to the smallest positive rate.
        # NOTE: `linthresh` is the keyword in matplotlib >= 3.3.
        return 'symlog', {'linthresh': float(positive.min())}
    return 'log', {}


def _save_figure(fig, output_dir: str, filename: str) -> None:
    """Write the figure to ``output_dir/filename`` with a tight bounding box, then close it."""
    fig.savefig(os.path.join(output_dir, filename), bbox_inches='tight')
    plt.close(fig)


def _plot_per_step_curves(runs, column: str, ylabel: str, title: str, output_dir: str, filename: str) -> None:
    """Draw one ``column``-vs-step curve per run that logged ``column`` and save the figure."""
    fig, ax = plt.subplots(figsize=(12, 6))
    for run in runs:
        losses_df = run['losses']
        if column in losses_df.columns:
            sns.lineplot(x='step', y=column, data=losses_df,
                         label=f'Decay: {run["decay_rate"]:.6f}', ax=ax)
    ax.set_xlabel('Training Step')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    handles, _ = ax.get_legend_handles_labels()
    if handles:  # Skip the legend (and its warning) when no run logged this column.
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    _save_figure(fig, output_dir, filename)


def create_visualizations(data: List[Dict], output_dir: str = 'visualizations') -> None:
    """
    Create recommended visualizations and save them as PDF.

    Six figures are written to ``output_dir``: test metrics vs decay rate (line, bar and
    heatmap), average training loss vs step, temperature vs step, and a test MRR vs MR
    scatter plot. The final-metrics table is also printed.

    Args:
        data: Parsed runs; each item is a dict with keys ``'decay_rate'``, ``'losses'``
            (DataFrame), ``'valid_metrics'`` and ``'test_metrics'``.
        output_dir: Directory for the PDF files; created if it does not exist.

    Raises:
        ValueError: If no item in ``data`` has test metrics. Raised before anything is
            written, so no output directory is created in that case.
    """
    nan = float('nan')  # Missing metrics stay numeric instead of turning columns into object dtype.

    # Prepare DataFrame rows for final metrics (only runs that have test metrics).
    metrics_data = [
        {
            'decay_rate': d['decay_rate'],
            'test_mrr': d['test_metrics'].get('mrr', nan),
            'test_mr': d['test_metrics'].get('mr', nan),
            'test_hits1': d['test_metrics'].get('hits1', nan),
            'test_hits3': d['test_metrics'].get('hits3', nan),
            'test_hits10': d['test_metrics'].get('hits10', nan),
            'valid_mrr': d['valid_metrics'].get('mrr', nan),
        }
        for d in data
        if d['test_metrics']
    ]

    if not metrics_data:
        raise ValueError("No valid metrics data found to create visualizations.")

    os.makedirs(output_dir, exist_ok=True)

    metrics_df = pd.DataFrame(metrics_data)
    print("Final Metrics Table:")
    print(metrics_df)

    metrics_df = metrics_df.sort_values('decay_rate')

    # Column -> legend label for the ranking metrics shared by plots 1-3.
    test_metric_labels = {
        'test_mrr': 'Test MRR',
        'test_hits1': 'Test HITS@1',
        'test_hits3': 'Test HITS@3',
        'test_hits10': 'Test HITS@10',
    }
    test_metric_columns = list(test_metric_labels)

    # 1. Line Plot: Metrics vs Decay Rate
    fig, ax = plt.subplots(figsize=(10, 6))
    for column, label in test_metric_labels.items():
        sns.lineplot(x='decay_rate', y=column, data=metrics_df, marker='o', label=label, ax=ax)
    ax.set_xlabel('Decay Rate')
    ax.set_ylabel('Metric Value')
    ax.set_title('Test Metrics vs Decay Rate')
    # Logarithmic spacing for small decay rates, without losing a zero-decay baseline run.
    scale, scale_kwargs = _decay_axis_scale(metrics_df['decay_rate'])
    ax.set_xscale(scale, **scale_kwargs)
    ax.legend()
    _save_figure(fig, output_dir, 'metrics_vs_decay_line.pdf')

    # 2. Bar Chart: Clustered bars for metrics per decay rate
    melted_df = pd.melt(metrics_df, id_vars=['decay_rate'], value_vars=test_metric_columns)
    fig, ax = plt.subplots(figsize=(12, 6))
    sns.barplot(x='decay_rate', y='value', hue='variable', data=melted_df, ax=ax)
    ax.set_xlabel('Decay Rate')
    ax.set_ylabel('Metric Value')
    ax.set_title('Test Metrics Across Decay Rates')
    ax.tick_params(axis='x', labelrotation=45)
    ax.legend(title='Metric')
    _save_figure(fig, output_dir, 'metrics_bar_chart.pdf')

    # 3. Heatmap: Metrics across decay rates
    heatmap_df = metrics_df.set_index('decay_rate')[test_metric_columns]
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(heatmap_df, annot=True, cmap='viridis', fmt='.4f', ax=ax)
    ax.set_title('Heatmap of Test Metrics vs Decay Rate')
    _save_figure(fig, output_dir, 'metrics_heatmap.pdf')

    # Per-run curves are drawn in ascending decay-rate order so legends are ordered
    # consistently regardless of the order in which the log files were discovered.
    runs = sorted(data, key=lambda d: d['decay_rate'])

    # 4. Loss Curves: Average Loss vs Step for each decay rate
    _plot_per_step_curves(
        runs, 'avg_loss', 'Average Loss',
        'Average Training Loss vs Step Across Decay Rates',
        output_dir, 'loss_vs_step.pdf',
    )

    # 5. Temperature Decay: Current Temperature vs Step
    _plot_per_step_curves(
        runs, 'temp', 'Current Temperature',
        'Adversarial Temperature vs Step Across Decay Rates',
        output_dir, 'temperature_vs_step.pdf',
    )

    # 6. Scatter Plot: Trade-off between MRR and MR
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.scatterplot(x='test_mrr', y='test_mr', data=metrics_df, hue='decay_rate',
                    palette='viridis', size='decay_rate', sizes=(50, 300), ax=ax)
    ax.set_xlabel('Test MRR')
    ax.set_ylabel('Test MR')
    ax.set_title('Trade-off: Test MRR vs MR (Sized by Decay Rate)')
    _save_figure(fig, output_dir, 'mrr_vs_mr_scatter.pdf')

In [None]:
if __name__ == "__main__":
    log_pattern = 'logs/run_param_*.log'  # Adjust path if needed
    # glob yields matches in filesystem-dependent order; sorting makes the printed
    # file list and the metrics table reproducible across runs and platforms.
    log_files = sorted(glob.glob(log_pattern))
    if not log_files:
        # Report the pattern actually searched, including its directory.
        raise FileNotFoundError(f"No log files found matching '{log_pattern}'. Check the directory.")
    print(f"Found {len(log_files)} log files: {log_files}")
    data = collect_data(log_files)
    if not data:
        raise ValueError("No valid data parsed from log files.")
    create_visualizations(data)

Found 9 log files: ['logs\\run_param_0.0001.log', 'logs\\run_param_0.00025.log', 'logs\\run_param_0.0005.log', 'logs\\run_param_0.00075.log', 'logs\\run_param_0.001.log', 'logs\\run_param_0.0025.log', 'logs\\run_param_0.005.log', 'logs\\run_param_0.0075.log', 'logs\\run_param_0.log']
Final Metrics Table:
   decay_rate  test_mrr      test_mr  test_hits1  test_hits3  test_hits10  \
0     0.00010  0.249737   771.457124    0.177147    0.274382     0.390648   
1     0.00025  0.242231   823.963794    0.170087    0.268176     0.382732   
2     0.00050  0.232334   900.238420    0.162220    0.255912     0.368269   
3     0.00075  0.230468   929.886812    0.160974    0.253665     0.366852   
4     0.00100  0.226376   961.894288    0.156919    0.249707     0.362650   
5     0.00250  0.224012   998.459323    0.155771    0.247826     0.354857   
6     0.00500  0.222537   989.006523    0.153670    0.245431     0.356420   
7     0.00750  0.221864  1004.425169    0.152961    0.246433     0.356420   
8