In [None]:
import matplotlib.pyplot as plt
import json
import argparse
import os

def parse_log_file(file_path):
    """Parse a JSON-lines log file (one JSON value per line) into a list.

    Blank lines are skipped. Lines that are not valid JSON are reported with a
    warning printed to stdout and then skipped, so one corrupt line does not
    discard the rest of the log.

    Args:
        file_path: Path to the log file.

    Returns:
        A list of the decoded JSON values, in file order (empty if the file
        has no parseable lines).
    """
    data = []
    # Explicit UTF-8 so decoding does not depend on the platform's locale encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:  # Skip empty lines
                continue
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError:
                print(f"Warning: Could not parse line: {line}")
    return data

def _find_metric_key(data, metric, prefixes=('', 'train_', 'test_', 'val_')):
    """Return the first ``prefix + metric`` key present in the log, or None.

    Entries are scanned in order and, within an entry, prefixes are tried in
    the given order. When the first entry contains any candidate this matches a
    lookup on ``data[0]`` alone; otherwise later entries are consulted, so a
    metric that is not logged on the first line is still found.
    """
    candidates = [f"{prefix}{metric}" for prefix in prefixes]
    for entry in data:
        for key in candidates:
            if key in entry:
                return key
    return None


def plot_metrics(log_files, metrics, output_path=None, title=None, labels=None):
    """Plot specified metrics from one or more log files on a single figure.

    Args:
        log_files: A path, or a list of paths, to JSON-lines logs readable by
            ``parse_log_file``.
        metrics: A metric name, or a list of metric names. Each name is looked
            up as-is and then with the prefixes ``train_``, ``test_``, ``val_``;
            only the first variant found is plotted.
        output_path: If given, the figure is saved there; otherwise it is
            shown with ``plt.show()``.
        title: Figure title; defaults to the comma-joined metric names.
        labels: Legend labels, one per log file; defaults to the file basenames.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(log_files, (str, os.PathLike)):
        log_files = [log_files]
    if isinstance(metrics, str):
        metrics = [metrics]

    plt.figure(figsize=(12, 8))

    if labels is None:
        labels = [os.path.basename(f) for f in log_files]

    for file_idx, file_path in enumerate(log_files):
        data = parse_log_file(file_path)

        if not data:
            print(f"Warning: No data found in {file_path}")
            continue

        # Fall back to the line index when an entry carries no 'epoch' field.
        epochs = [entry.get('epoch', row) for row, entry in enumerate(data)]

        for metric in metrics:
            key = _find_metric_key(data, metric)
            if key is None:
                print(f"Warning: Metric '{metric}' not found in {file_path}")
                continue
            # Entries that lack the metric become NaN gaps in the line.
            values = [entry.get(key, float('nan')) for entry in data]
            plt.plot(epochs, values, marker='o', label=f"{labels[file_idx]} - {key}")

    plt.xlabel('Epoch')
    plt.ylabel('Value')
    plt.grid(True, linestyle='--', alpha=0.7)

    # Avoid matplotlib's "No artists with labels" warning when nothing was plotted.
    handles, _ = plt.gca().get_legend_handles_labels()
    if handles:
        plt.legend(loc='best')

    plt.title(title if title else ', '.join(metrics))

    plt.tight_layout()

    if output_path:
        plt.savefig(output_path)
        print(f"Plot saved to {output_path}")
    else:
        plt.show()


: 

In [None]:
### Finetune labram
filepath = "checkpoints/finetune_dtu_labram1/log.txt"
metrics = ["loss"]
title = "Finetune LaBraM"
labels = None

# Parsed entries, kept for inspection in later cells.
log = parse_log_file(filepath)
# plot_metrics parses the files itself, so it takes paths, not parsed entries.
# title/labels go by keyword: positionally, title would land in output_path.
plot_metrics([filepath], metrics, title=title, labels=labels)