# In [1]:
import re
import plotly.graph_objects as go

# A loss value as written in the logs: optionally signed, integer or decimal,
# with an optional exponent (e.g. "0.3712", "2", "1.5e-3", "-0.25").
_LOSS_NUMBER = r"[-+]?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?"

# One compiled pattern per field, in the order (train, validation, test).
_LOSS_PATTERNS = tuple(
    re.compile(rf"{label} Loss: ({_LOSS_NUMBER})")
    for label in ("Train", "Vali", "Test")
)


def parse_loss_file(file_path):
    """Extract per-line train/validation/test losses from a training log.

    Only lines containing all three labels ``"Train Loss:"``, ``"Vali Loss:"``
    and ``"Test Loss:"`` are considered.  A considered line contributes one
    value to each list, or nothing at all, so the three returned lists always
    have the same length and stay aligned by index.

    Args:
        file_path: Path of the log file to read.

    Returns:
        A tuple ``(train_losses, vali_losses, test_losses)`` of lists of
        floats, in the order the matching lines appear in the file.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    train_losses = []
    vali_losses = []
    test_losses = []

    # errors="replace" keeps stray non-UTF-8 bytes in a log from aborting the
    # parse; the numeric fields themselves are plain ASCII.
    with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
        for line in f:
            if not ("Train Loss:" in line and "Vali Loss:" in line and "Test Loss:" in line):
                continue

            matches = [pattern.search(line) for pattern in _LOSS_PATTERNS]
            if any(match is None for match in matches):
                # Reject the whole line before touching any list, so a single
                # unparsable field (e.g. "nan") cannot leave the lists with
                # different lengths.
                print(f"Skipping line due to parsing error: {line.strip()} - could not parse all three loss values")
                continue

            train_loss, vali_loss, test_loss = (float(match.group(1)) for match in matches)
            train_losses.append(train_loss)
            vali_losses.append(vali_loss)
            test_losses.append(test_loss)
    return train_losses, vali_losses, test_losses

if __name__ == "__main__":
    # Training logs to compare; each readable, non-empty log contributes three
    # traces (train / validation / test) to one shared figure.
    file_paths = ["ETTm1(informer).txt", "ETTm1(crossformer).txt", "crossformer_output.txt", "informer.txt"]

    fig = go.Figure()

    for file_path in file_paths:
        try:
            train_losses, vali_losses, test_losses = parse_loss_file(file_path)
        except OSError as e:
            # One missing or unreadable log should not abort the whole
            # comparison; report it and keep plotting the others.
            print(f"Could not read {file_path}: {e}")
            continue

        if not train_losses:
            print(f"No data found to plot in {file_path}. Please check the log file format.")
            continue

        # Strip only the final extension so names containing extra dots keep
        # the rest of their stem in the legend.
        model_name = file_path.rsplit(".", 1)[0]

        series = (
            ("Train Loss", train_losses),
            ("Validation Loss", vali_losses),
            ("Test Loss", test_losses),
        )
        for label, losses in series:
            # x is the 0-based index of each parsed log line, sized from the
            # series actually being plotted so x and y always match in length.
            fig.add_trace(go.Scatter(
                x=list(range(len(losses))),
                y=losses,
                mode='lines+markers',
                name=f'{model_name}: {label}',
            ))

    # NOTE(review): the y-axis range is hard-coded; presumably tuned to these
    # particular logs -- confirm before reusing with other data.
    fig.update_layout(
        title='Training, Validation, and Test Losses Over Epochs',
        xaxis_title='Epoch',
        yaxis_title='Loss',
        legend_title='Loss Type',
        yaxis_range=[0, 0.37]
    )
    fig.show()
    fig.write_html("plot.html")