In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import polars as pl
import seaborn as sns

In [None]:
# The loss log has no header row; these are its columns in file order.
LOSS_COLUMNS = ["Batch", "Learning_Rate", "Train Loss", "Val Loss", "Val MSE"]

df = pl.read_csv("../models/losses.txt", has_header=False)
# Assigning via the setter (rather than `new_columns=`) keeps the strict check:
# polars raises if the file's column count differs from LOSS_COLUMNS.
df.columns = LOSS_COLUMNS
df

In [None]:
# Optional: uncomment to restrict the plots below to the first 4,000 batches.
# df = df.filter(pl.col("Batch") < 4000)

In [None]:
# Plot Train Loss, Val Loss and Val MSE together on a single y-axis.
# The seaborn style is set *before* creating the figure so it applies to the new axes.
sns.set_style("whitegrid")
fig, ax = plt.subplots(figsize=(12, 6))

# Convert the Polars columns to plain Python lists for seaborn.
batches = df["Batch"].to_list()
for metric, marker in [("Train Loss", "o"), ("Val Loss", "s"), ("Val MSE", "^")]:
    sns.lineplot(x=batches, y=df[metric].to_list(), marker=marker, label=metric, ax=ax)

# Titles and labels, set on the explicit axes rather than via the pyplot state machine.
ax.set_title("Training and Validation Metrics over Batches", fontsize=14)
ax.set_xlabel("Batch", fontsize=12)
ax.set_ylabel("Loss / MSE", fontsize=12)
ax.legend(fontsize=10)

fig.tight_layout()
plt.show()

In [None]:
# Plot Train Loss and Val Loss on separate y-axes (they may differ in scale) and
# save the figure to ../.github/losses.png.
fig, ax1 = plt.subplots(figsize=(10, 6))
batches = df["Batch"].to_list()

# Train Loss on the left y-axis.
color1 = 'tab:blue'
ax1.set_xlabel('Batches', fontsize=12)
ax1.set_ylabel('Train Loss', fontsize=12)
line1 = ax1.plot(batches, df["Train Loss"].to_list(),
                 marker='o', color=color1, linewidth=2, label="Train Loss")
ax1.tick_params(axis='y', labelcolor='black')  # Keep tick labels black

# Val Loss on a right y-axis that shares the same x-axis.
ax2 = ax1.twinx()
color2 = 'tab:red'
ax2.set_ylabel('Val Loss', fontsize=12)
line2 = ax2.plot(batches, df["Val Loss"].to_list(),
                 marker='s', color=color2, linewidth=2, label="Val Loss")
ax2.tick_params(axis='y', labelcolor='black')  # Keep tick labels black

# Grid on one axis only, to avoid two overlapping grids.
ax1.grid(True, alpha=0.3)

ax1.set_title("Training and Validation Loss over Batches", fontsize=14)

# A single legend for both lines. It is attached to ax2 because the twin axes is
# drawn on top of ax1; a legend on ax1 can end up hidden behind the Val Loss line.
lines = line1 + line2
ax2.legend(lines, [line.get_label() for line in lines], loc='upper right')

fig.tight_layout()

# Make sure the output directory exists before saving.
output_path = Path("../.github/losses.png")
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path, dpi=300)