In [1]:
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.graph_objects as go

In [2]:
# Evaluation metrics on the WMT14 test set, one CSV per model variant
# (presumably one row per evaluated example — TODO confirm).
# NOTE(review): absolute, user-specific path; point RESULTS_DIR at your own checkout.
RESULTS_DIR = Path('/Users/adam.amster/seq2seq_translation/results')

attention = pd.read_csv(RESULTS_DIR / 'eval_metrics_attention_wmt14_test_wmt14_bleu.csv')
no_attention = pd.read_csv(RESULTS_DIR / 'eval_metrics_no_attention_wmt14_test_wmt14_bleu.csv')

# Fail early, with a clear message, if a column used by the plotting cell is missing.
for _metrics_name, _metrics in (('attention', attention), ('no_attention', no_attention)):
    _missing = {'input_length', 'wmt14_bleu'} - set(_metrics.columns)
    assert not _missing, f"{_metrics_name} metrics are missing columns: {sorted(_missing)}"

In [16]:
# Create the figure that every trace in this cell is added to.
fig = go.Figure()

# Softer colors for scatter points, stronger colors for the trend lines
attention_color = "mediumaquamarine"  # soft green-blue
no_attention_color = "coral"          # gentle coral
line_color_attention = "teal"         # darker line for visibility
line_color_no_attention = "tomato"    # stronger coral for the trend line

# Plot the same BLEU column the trend lines are fit to ('wmt14_bleu'), so the
# points and the fitted lines describe the same metric.
bleu_column = 'wmt14_bleu'

# Semi-transparent scatter points so the trend lines stand out on top of them.
# The legend groups ('attention' / 'no_attention') identify each model variant.
for scatter_df, scatter_label, scatter_group, scatter_color in (
    (attention, 'Attention', 'attention', attention_color),
    (no_attention, 'No Attention', 'no_attention', no_attention_color),
):
    fig.add_trace(go.Scatter(
        x=scatter_df['input_length'],
        y=scatter_df[bleu_column],
        mode='markers',
        name=scatter_label,
        marker=dict(color=scatter_color, size=6, opacity=0.5),
        legendgroup=scatter_group,
    ))
# Fit and draw a least-squares straight-line trend for each model's BLEU scores.
def plot_interpolated_line(x, y, color, name, legendgroup=None, figure=None, n_points=100):
    """Add a degree-1 least-squares trend line of ``y`` against ``x`` to a Plotly figure.

    Despite the name, this does not interpolate: it fits ``y ~ slope * x + intercept``
    with ``np.polyfit`` and draws that line across the observed range of ``x``.

    Parameters
    ----------
    x, y : array-like
        Paired samples of equal length (here pandas Series). Pairs in which either
        value is NaN or infinite are dropped first, since they would break the fit.
    color : str
        Plotly color for the line.
    name : str
        Legend label of the trace.
    legendgroup : str, optional
        Plotly legend group. Defaults to ``name``. Pass the group of the matching
        scatter trace so a single legend click toggles the points and the line together.
    figure : plotly.graph_objects.Figure, optional
        Figure to add the trace to. Defaults to the notebook-level ``fig``.
    n_points : int, default 100
        Number of evenly spaced x positions used to draw the line.

    Raises
    ------
    ValueError
        If fewer than two finite (x, y) pairs remain after filtering.
    """
    if figure is None:
        figure = fig  # keep the original behaviour of drawing on the global figure
    if legendgroup is None:
        legendgroup = name

    x_values = np.asarray(x, dtype=float)
    y_values = np.asarray(y, dtype=float)
    finite = np.isfinite(x_values) & np.isfinite(y_values)
    x_values, y_values = x_values[finite], y_values[finite]
    if x_values.size < 2:
        raise ValueError(
            f"Need at least two finite (x, y) pairs to fit a trend line for {name!r}, "
            f"got {x_values.size}."
        )

    coefficients = np.polyfit(x_values, y_values, 1)  # [slope, intercept]
    x_fit = np.linspace(x_values.min(), x_values.max(), n_points)
    y_fit = np.polyval(coefficients, x_fit)

    figure.add_trace(go.Scatter(
        x=x_fit,
        y=y_fit,
        mode='lines',
        name=name,
        line=dict(color=color, width=4),  # thick solid line so the trend stands out over the points
        legendgroup=legendgroup,
    ))

# Each trend line joins the legend group of its scatter points ('attention' /
# 'no_attention'). It gets a distinct label, so the legend no longer shows two
# indistinguishable "Attention" / "No Attention" entries.
plot_interpolated_line(
    x=attention['input_length'],
    y=attention['wmt14_bleu'],
    color=line_color_attention,
    name='Attention (linear fit)',
    legendgroup='attention',
)

plot_interpolated_line(
    x=no_attention['input_length'],
    y=no_attention['wmt14_bleu'],
    color=line_color_no_attention,
    name='No Attention (linear fit)',
    legendgroup='no_attention',
)

# Axis labels, title, a fixed 900x600 canvas, and fully transparent backgrounds
# (presumably so the figure blends into the page that embeds it).
fig.update_layout(
    title='BLEU Score vs Num. Input Tokens',
    xaxis=dict(title='Num. Input Tokens'),
    yaxis=dict(title='BLEU Score'),
    width=900,
    height=600,
    autosize=True,
    paper_bgcolor='rgba(0,0,0,0)',  # transparent area around the plot
    plot_bgcolor='rgba(0,0,0,0)',   # transparent plotting area
    font=dict(color='black'),       # global font color: titles, tick labels and legend text
    template='plotly_white',        # clean white base theme; explicit settings above override it
)

# Export the figure as Plotly JSON for the blog post. This only writes the file;
# it does not render the figure in the notebook (evaluate `fig` to display it).
# NOTE(review): absolute, user-specific output path; adjust for your checkout.
OUTPUT_JSON_PATH = Path(
    '/Users/adam.amster/aamster.github.io/assets/plotly/'
    '2024-10-03-sequence_to_sequence_translation/attention_vs_no_attention_sent_len.json'
)
fig.write_json(OUTPUT_JSON_PATH)