In [1]:
import re
import matplotlib.pyplot as plt
from pathlib import Path
import glob
import os

def extract_metrics(line):
    """
    Extract epoch number and test loss from a log line.
    Returns a tuple of (epoch_number, test_loss) if found, None otherwise.
    """
    pattern = r'Epoch (\d+):.*test_loss=(\d+\.\d+)'
    match = re.search(pattern, line)
    
    if match:
        epoch = int(match.group(1))
        test_loss = float(match.group(2))
        return (epoch, test_loss)
    return None

def process_err_file(file_path):
    """
    Process an .err file and return lists of epochs and corresponding test losses.
    Returns None if no valid metrics are found.
    """
    epochs = []
    test_losses = []
    has_valid_data = False
    
    try:
        with open(file_path, 'r') as f:
            for line in f:
                result = extract_metrics(line)
                if result:
                    has_valid_data = True
                    epoch, test_loss = result
                    epochs.append(epoch)
                    test_losses.append(test_loss)
    except Exception as e:
        print(f"Error processing {file_path}: {str(e)}")
        return None
    
    # Return None if no valid data was found
    if not has_valid_data:
        return None
    
    return epochs, test_losses

def plot_metrics(all_data, output_dir):
    """
    Create a plot of test loss vs epochs for multiple files.
    
    Parameters:
    all_data: dict of format {filename: (epochs, test_losses)}
    output_dir: directory to save the plot
    """
    plt.figure(figsize=(12, 8))
    
    # Plot each file's data with different color/style
    for filename, (epochs, test_losses) in all_data.items():
        # Use just the base filename without path and extension for the legend
        label = Path(filename).stem
        plt.plot(epochs, test_losses, marker='o', label=label, linewidth=2, markersize=4)
        plt.ylim(0, 1)
    
    plt.xlabel('Epoch')
    plt.ylabel('Test Loss')
    plt.title('Test Loss vs. Epoch - All Files')
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    
    # Adjust layout to prevent legend cutoff
    plt.tight_layout()
    
    # Save the plot
    output_path = os.path.join(output_dir, 'combined_loss_plot.png')
    plt.savefig(output_path, bbox_inches='tight', dpi=300)
    plt.close()
    
    print(f"Plot saved as: {output_path}")

def process_directory(directory_path):
    """
    Process all .err files in the specified directory.
    Returns a dictionary of valid files and their data.
    """
    # Convert to Path object and resolve to absolute path
    dir_path = Path(directory_path).resolve()
    
    # Create output directory for plots if it doesn't exist
    output_dir = dir_path / 'plots'
    output_dir.mkdir(exist_ok=True)
    
    # Dictionary to store data from all valid files
    valid_files_data = {}
    
    # Find all .err files in the directory
    err_files = list(dir_path.glob('*.err'))
    
    if not err_files:
        print(f"No .err files found in {dir_path}")
        return
    
    print(f"Found {len(err_files)} .err files")
    
    # Process each file
    for file_path in err_files:
        print(f"\nProcessing: {file_path.name}")
        
        result = process_err_file(file_path)
        if result:
            epochs, test_losses = result
            print(f"✓ Valid data found: {len(epochs)} epochs")
            valid_files_data[str(file_path)] = (epochs, test_losses)
        else:
            print(f"✗ No valid data found in {file_path.name}")
    
    return valid_files_data, output_dir

def print_summary(valid_files_data):
    """
    Print a summary of the processed data.
    """
    print("\n=== Summary ===")
    print(f"Total valid files processed: {len(valid_files_data)}")
    
    for filepath, (epochs, losses) in valid_files_data.items():
        filename = Path(filepath).name
        min_loss = min(losses)
        min_loss_epoch = epochs[losses.index(min_loss)]
        print(f"\n{filename}:")
        print(f"  Number of epochs: {len(epochs)}")
        print(f"  Best test loss: {min_loss:.6f} (Epoch {min_loss_epoch})")

def main():
    # Get the current directory if no argument is provided
    directory_path = '/home/xj2173'
    
    # Process all .err files in the directory
    valid_files_data = process_directory(directory_path)
    
    if valid_files_data:
        files_data, output_dir = valid_files_data
        
        if files_data:
            # Create the combined plot
            plot_metrics(files_data, output_dir)
            
            # Print summary of processed data
            print_summary(files_data)
        else:
            print("\nNo valid data found in any of the .err files.")

In [2]:
main()

Found 14 .err files

Processing: slurm_53939750.err
✓ Valid data found: 24 epochs

Processing: slurm_53847581.err
✗ No valid data found in slurm_53847581.err

Processing: slurm_53825437.err
✗ No valid data found in slurm_53825437.err

Processing: slurm_53965604.err
✓ Valid data found: 17 epochs

Processing: slurm_53940220.err
✓ Valid data found: 30 epochs

Processing: slurm_53939362.err
✓ Valid data found: 25 epochs

Processing: slurm_53875906.err
✓ Valid data found: 29 epochs

Processing: slurm_53942661.err
✓ Valid data found: 19 epochs

Processing: slurm_53847579.err
✗ No valid data found in slurm_53847579.err

Processing: slurm_53912964.err
✓ Valid data found: 35 epochs

Processing: slurm_53944641.err
✓ Valid data found: 30 epochs

Processing: slurm_53847562.err
✗ No valid data found in slurm_53847562.err

Processing: slurm_53948431.err
✓ Valid data found: 25 epochs

Processing: slurm_53998446.err
✗ No valid data found in slurm_53998446.err
Plot saved as: /home/xj2173/plots/combined