In [135]:
import os
from matplotlib.lines import Line2D
import matplotlib.pyplot as plt
import numpy as np
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

import sys
sys.path.insert(1, os.getenv("NOVA_HOME"))

from src.figures.plotting_utils import FONT_PATH
from matplotlib import font_manager as fm
import matplotlib

fm.fontManager.addfont(FONT_PATH)
matplotlib.rcParams['font.family'] = 'Arial'
font_size = 14
plt.rcParams.update({'font.size': font_size})

%matplotlib inline

def extract_scalar_from_event_file(event_file, tag):
    """Read a single scalar series from one TensorBoard event file.

    Args:
        event_file: Path handed to ``EventAccumulator``.
        tag: Name of the scalar tag to read.

    Returns:
        A ``(steps, values)`` pair of lists, in the order yielded by
        ``EventAccumulator.Scalars``. If ``tag`` is not among the file's
        scalar tags, a message is printed and ``(None, None)`` is returned.
    """
    accumulator = EventAccumulator(event_file)
    accumulator.Reload()

    if tag in accumulator.Tags()["scalars"]:
        records = list(accumulator.Scalars(tag))
        return [record.step for record in records], [record.value for record in records]

    print(f"Tag {tag} not found in {event_file}")
    return None, None

def extract_and_combine_scalars(event_files, tag):
    """Merge one scalar tag across several event files, averaging duplicate steps.

    Each file is read with ``extract_scalar_from_event_file``. Files that lack
    ``tag`` are skipped; that function prints the "not found" message.

    Args:
        event_files: Iterable of event-file paths.
        tag: Scalar tag name to read from every file.

    Returns:
        ``(unique_steps, averaged_values)``. ``unique_steps`` is a sorted NumPy
        array of the distinct steps seen in any file. ``averaged_values`` is a
        list holding, for each of those steps, the mean of every value logged
        at that step. Both are empty when no file contains the tag.
    """
    combined_steps = []
    combined_values = []

    for event_file in event_files:
        # Reuse the single-file reader so parsing and the "not found" message live in one place.
        steps, values = extract_scalar_from_event_file(event_file, tag)
        if steps is None:
            continue
        combined_steps.extend(steps)
        combined_values.extend(values)

    combined_steps = np.asarray(combined_steps)
    combined_values = np.asarray(combined_values, dtype=float)

    # Group by step in one pass: `inverse` maps each sample to its unique-step
    # slot, so per-step sums and counts come from two bincounts instead of one
    # boolean mask over all samples for every unique step.
    unique_steps, inverse = np.unique(combined_steps, return_inverse=True)
    sums = np.bincount(inverse, weights=combined_values, minlength=len(unique_steps))
    counts = np.bincount(inverse, minlength=len(unique_steps))
    averaged_values = list(sums / counts)

    return unique_steps, averaged_values

def _mark_index_is_valid(n_points, index):
    """Return True when ``index`` (positive or negative) addresses one of ``n_points`` items."""
    return -n_points <= index < n_points


def plot_multiple_event_files(event_files, tags, epoch_to_mark=-10, custom_labels=None, custom_colors=None, title='Loss Curves', output_file=None,
                              box_y_offset_frac=0.5, swatch_y_start=0.8, swatch_y_step=0.35):
    """Plot several scalar tags on one figure and annotate their values at one index.

    For each tag, the series returned by ``extract_and_combine_scalars`` is
    drawn, and the point at index ``epoch_to_mark`` is marked in red. A single
    text box lists every tag's value at that point, with one arrow per marker.
    Short coloured line segments next to the box act as an in-figure legend.

    Args:
        event_files: Event-file paths forwarded to ``extract_and_combine_scalars``.
        tags: Scalar tag names to plot.
        epoch_to_mark: Index into each tag's step array; negative values count
            from the end (``-1`` is the last logged step).
        custom_labels: Optional ``{tag: label}``; defaults to the tag itself.
        custom_colors: Optional ``{tag: colour}``. Tags missing from it (or all
            tags when ``None``) get matplotlib's automatic colour.
        title: Figure title.
        output_file: If ``None`` the figure is shown; otherwise it is saved in
            EPS format to this path and the figure is closed.
        box_y_offset_frac: The text box is placed at
            ``max_marked_value * (1 + box_y_offset_frac)``. The notebook used
            0.5 for the pretrained model and 0.1 for the finetuned one.
        swatch_y_start: Data-unit y offset, relative to the box anchor, of the
            first legend swatch (pretrained 0.8, finetuned 0.07).
        swatch_y_step: Data-unit spacing between consecutive swatches
            (pretrained 0.35, finetuned 0.03).

    Raises:
        ValueError: If no tag has any data, or ``epoch_to_mark`` is out of
            range for a tag's series.
    """
    if custom_labels is None:
        custom_labels = {tag: tag for tag in tags}

    # Load and validate every series before drawing, so a bad index does not
    # leave a half-built figure behind.
    series = {}
    for tag in tags:
        steps, values = extract_and_combine_scalars(event_files, tag)
        if len(steps) == 0:
            print(f"No data found for tag {tag} across the event files.")
            continue
        if not _mark_index_is_valid(len(steps), epoch_to_mark):
            raise ValueError(
                f"Tag {tag!r} has {len(steps)} epochs; cannot mark index {epoch_to_mark}"
            )
        series[tag] = (steps, values)

    if not series:
        raise ValueError("No data found for any of the requested tags.")

    plt.figure(figsize=(10, 6))

    tag_values_at_epoch = {}  # label -> value at epoch_to_mark
    label_colors = {}         # label -> colour actually used for that label's curve
    marker_position = None
    max_y_value = -np.inf

    for tag, (steps, values) in series.items():
        label = custom_labels[tag]
        marker_position = steps[epoch_to_mark]
        value = values[epoch_to_mark]
        tag_values_at_epoch[label] = value
        max_y_value = max(max_y_value, value)

        color = custom_colors[tag] if custom_colors is not None and tag in custom_colors else None

        (line,) = plt.plot(steps, values, label=label, color=color, linewidth=2, alpha=0.8)
        # Record the resolved colour (auto-assigned when `color` is None) so the
        # legend swatches below always match the drawn curves.
        label_colors[label] = line.get_color()

        plt.scatter(marker_position, value, marker='o', s=50, zorder=5, color='red')  # Add marker at epoch_to_mark

    # Horizontal extent of the last plotted series, used to offset the text box.
    last_steps = list(series.values())[-1][0]

    # Display a single box above the highest marked value at marker_position.
    box_text = "\n".join([f"{label} (epoch {marker_position}): {value:.4f}" for label, value in tag_values_at_epoch.items()])
    xytext_position = (marker_position - 0.2 * (last_steps[-1] - last_steps[0]),
                       max_y_value + box_y_offset_frac * max_y_value)

    plt.gca().annotate(box_text, xy=(marker_position, max_y_value),
                       xytext=(xytext_position[0], xytext_position[1]),
                       textcoords='data',
                       zorder=6)

    # Draw arrows from the annotation box to each marker.
    # NOTE(review): the arrow tail offset scales with the box coordinates
    # themselves (not with the axis span); kept as-is from the original tuning.
    for label, value in tag_values_at_epoch.items():
        plt.gca().annotate("", xy=(marker_position, value), xytext=(xytext_position[0] + 0.2*xytext_position[0], xytext_position[1]-0.005*xytext_position[1]),
                           arrowprops=dict(arrowstyle="->", color='black', lw=1, zorder=5))

    # Manually add coloured line segments next to the annotation box to mimic a legend.
    for idx, (label, swatch_color) in enumerate(label_colors.items()):
        swatch_y = xytext_position[1] + swatch_y_start - idx * swatch_y_step
        plt.gca().add_line(Line2D([xytext_position[0] - 2.5, xytext_position[0] - 1.5],
                                  [swatch_y, swatch_y],
                                  color=swatch_color, lw=2, zorder=7))

    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.title(title)
    plt.grid(True)

    if output_file is None:
        plt.show()
    else:
        # Save the plot as an EPS file
        plt.savefig(output_file, format='eps')
        plt.close()

title = 'Protein Localization Learning'
# title = "Phenotypic Learning"
output_file = f"/home/projects/hornsteinlab/Collaboration/MOmaps_Sagy/MOmaps/tools/loss_curves_pretrained_model_font{font_size}__embededLegend.eps"
# output_file = f"/home/projects/hornsteinlab/Collaboration/MOmaps_Sagy/MOmaps/tools/loss_curves_finetuned_model_font{font_size}_embededLegend.eps"

# TensorBoard event files for the pretrained model run.
# NOTE(review): these paths use /home/projects/... while the finetuned list
# below uses /home/labs/... — confirm which root is current.
event_files = [
    "/home/projects/hornsteinlab/Collaboration/MOmaps/outputs/vit_models/pretrained_model/tensorboard/120924_114835_515225_JID604820",
    "/home/projects/hornsteinlab/Collaboration/MOmaps/outputs/vit_models/pretrained_model/tensorboard/120924_115000_347596_JID604820",
    "/home/projects/hornsteinlab/Collaboration/MOmaps/outputs/vit_models/pretrained_model/tensorboard/120924_122927_128696_JID604820"
]

# event_files = [
#     "/home/labs/hornsteinlab/Collaboration/MOmaps/outputs/vit_models/finetuned_model/tensorboard/170924_101338_007438_JID836765",
#     "/home/labs/hornsteinlab/Collaboration/MOmaps/outputs/vit_models/finetuned_model/tensorboard/170924_104113_380636_JID836765",
#     "/home/labs/hornsteinlab/Collaboration/MOmaps/outputs/vit_models/finetuned_model/tensorboard/170924_105108_906212_JID836765",
#     "/home/labs/hornsteinlab/Collaboration/MOmaps/outputs/vit_models/finetuned_model/tensorboard/170924_105620_929382_JID836765",
#     "/home/labs/hornsteinlab/Collaboration/MOmaps/outputs/vit_models/finetuned_model/tensorboard/170924_114822_893465_JID836765",
#     "/home/labs/hornsteinlab/Collaboration/MOmaps/outputs/vit_models/finetuned_model/tensorboard/170924_204416_466457_JID836765"
# ]

# Number of logged points before the last one to annotate.
# pre-trained
epoch_to_mark = 10
# fine-tuned
# epoch_to_mark = 9

# Scalar tags to plot, with their display labels and line colours (same order).
tags = ["1. Loss/Train Epochs", "1. Loss/Val Epochs", "1. Loss/Test Epochs"]
tag_labels = ["Train", "Val    ", "Test  "]  # trailing spaces presumably pad labels so the box text lines up
tag_colors = ['green', 'blue', 'turquoise']

if not (len(tags) == len(tag_labels) == len(tag_colors)):
    raise ValueError("tags, tag_labels and tag_colors must have the same length")

labels = dict(zip(tags, tag_labels))
colors = dict(zip(tags, tag_colors))

# plot_multiple_event_files indexes each step array directly, so "N points
# before the last" becomes the negative index -(N + 1).
# Plot the loss curves and save as an EPS file.
plot_multiple_event_files(event_files, tags, epoch_to_mark=-1-epoch_to_mark, custom_labels=labels, custom_colors=colors, title=title, output_file=output_file)


Tag 1. Loss/Train Epochs not found in /home/labs/hornsteinlab/Collaboration/MOmaps/outputs/vit_models/pretrained_model/tensorboard/120924_114835_515225_JID604820
Tag 1. Loss/Train Epochs not found in /home/labs/hornsteinlab/Collaboration/MOmaps/outputs/vit_models/pretrained_model/tensorboard/120924_114835_515225_JID604820
Tag 1. Loss/Val Epochs not found in /home/labs/hornsteinlab/Collaboration/MOmaps/outputs/vit_models/pretrained_model/tensorboard/120924_114835_515225_JID604820
Tag 1. Loss/Test Epochs not found in /home/labs/hornsteinlab/Collaboration/MOmaps/outputs/vit_models/pretrained_model/tensorboard/120924_114835_515225_JID604820


The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
