# Visualization Functions Library

Notebook này chứa tất cả các hàm visualization từ `visualize.py`, được chuyển đổi thành dạng interactive và có thêm documentation chi tiết.

## Nội dung:
1. Basic Plotting Functions
   - ROUGE scores visualization
   - BLEU scores visualization
   - Training metrics visualization
2. Advanced Comparisons
   - Radar charts
   - Comprehensive comparisons
3. Interactive Widgets
4. Export Functions

In [None]:
# Import các thư viện cần thiết
import os
import sys
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from ipywidgets import interact, interactive, fixed, interact_manual, widgets

# Make the project root importable.
# Jupyter does not define ``__file__``, so the *string* '__file__' is resolved
# against the current working directory. Stripping two path components
# therefore yields the parent of the notebook's working directory. That parent
# is presumably the project root containing ``src/`` (TODO confirm layout).
BASE_DIR = os.path.dirname(
    os.path.dirname(os.path.abspath('__file__'))
)
_base_dir_missing = BASE_DIR not in sys.path
if _base_dir_missing:
    # Put it first so project modules shadow any same-named installed packages.
    sys.path.insert(0, BASE_DIR)

from src.config import FIGURES_DIR, TABLES_DIR, ROUGE_METRICS, REPORT_TITLE

# Matplotlib styling.
# The bare 'seaborn' style name was deprecated in matplotlib 3.6 and removed in
# 3.8, where it was renamed 'seaborn-v0_8'. Use whichever name this
# installation provides so the setup cell does not fail with OSError.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')
sns.set_style("whitegrid")
plt.rcParams['figure.dpi'] = 100     # on-screen resolution
plt.rcParams['savefig.dpi'] = 300    # resolution used by matplotlib's savefig
plt.rcParams['savefig.dpi'] = 300

In [None]:
# Basic plotting functions
def plot_rouge_scores(results_df: pd.DataFrame, save_fig: bool = False, filename: str = "rouge_scores.png"):
    """
    Build a grouped plotly bar chart comparing ROUGE scores across methods.

    Args:
        results_df: DataFrame with columns ['Method', 'ROUGE-1', 'ROUGE-2', 'ROUGE-L'].
        save_fig: If True, also export the figure to ``FIGURES_DIR/filename``.
        filename: Output file name. Its extension selects the image format;
            static image export requires the ``kaleido`` package.

    Returns:
        The plotly ``Figure``, or None when ``results_df`` is empty.
    """
    if results_df.empty:
        print("No data to plot for ROUGE scores.")
        return None

    fig = go.Figure()
    methods = results_df['Method'].values

    # One bar series per ROUGE variant, grouped by method on the x-axis.
    for metric in ('ROUGE-1', 'ROUGE-2', 'ROUGE-L'):
        fig.add_trace(go.Bar(
            name=metric,
            x=methods,
            y=results_df[metric],
            text=results_df[metric].round(3),
            textposition='auto',
        ))

    fig.update_layout(
        title='ROUGE Scores Comparison',
        xaxis_title='Method',
        yaxis_title='Score',
        barmode='group',
        width=900,
        height=500
    )

    if save_fig:
        os.makedirs(FIGURES_DIR, exist_ok=True)
        # Export the plotly figure itself. ``plt.savefig`` would write whatever
        # matplotlib figure happens to be current, not this chart.
        fig.write_image(os.path.join(FIGURES_DIR, filename))
        print(f"Figure saved as {filename}")

    return fig

def plot_training_metrics(results_df: pd.DataFrame, save_fig: bool = False, filename: str = "training_metrics.png"):
    """
    Build a 1x3 plotly subplot comparing training time, VRAM usage and trainable parameters.

    Args:
        results_df: DataFrame with columns 'Method', 'Training Time (min)',
            'VRAM (GB)' and 'Trainable Params (num)'.
        save_fig: If True, also export the figure to ``FIGURES_DIR/filename``.
        filename: Output file name. Its extension selects the image format;
            static image export requires the ``kaleido`` package.

    Returns:
        The plotly ``Figure``, or None when ``results_df`` is empty.
    """
    if results_df.empty:
        print("No data to plot for training metrics.")
        return None

    fig = make_subplots(rows=1, cols=3,
                        subplot_titles=('Training Time (minutes)', 'VRAM Usage (GB)',
                                        'Trainable Parameters (M)'))

    # Trace name -> source column, in left-to-right panel order.
    metrics = {
        'Training Time': 'Training Time (min)',
        'VRAM': 'VRAM (GB)',
        'Parameters': 'Trainable Params (num)'
    }

    for col_idx, (name, column) in enumerate(metrics.items(), start=1):
        values = results_df[column]
        if name == 'Parameters':
            values = values / 1e6  # raw parameter counts -> millions, matching the panel title

        fig.add_trace(
            go.Bar(name=name, x=results_df['Method'], y=values,
                   text=values.round(1), textposition='auto'),
            row=1, col=col_idx
        )

    fig.update_layout(height=400, width=1200, showlegend=False,
                      title_text="Training Metrics Comparison")
    fig.update_xaxes(tickangle=45)

    if save_fig:
        os.makedirs(FIGURES_DIR, exist_ok=True)
        # Export the plotly figure itself. ``plt.savefig`` would write the current
        # matplotlib figure, which is unrelated to this chart.
        fig.write_image(os.path.join(FIGURES_DIR, filename))
        print(f"Figure saved as {filename}")

    return fig

In [None]:
# Advanced plotting functions
def plot_radar_comparison(results_df: pd.DataFrame, metrics: list = None, save_fig: bool = False,
                          filename: str = "radar_comparison.png"):
    """
    Build a plotly radar (spider) chart comparing methods across several metrics.

    Args:
        results_df: DataFrame with a 'Method' column plus one column per metric.
        metrics: Metric column names to plot (any iterable).
            Defaults to ['ROUGE-1', 'ROUGE-2', 'ROUGE-L', 'BLEU'].
        save_fig: If True, also export the figure to ``FIGURES_DIR/filename``.
        filename: Output file name. Its extension selects the image format;
            static image export requires the ``kaleido`` package.

    Returns:
        The plotly ``Figure``, or None when ``results_df`` is empty.
    """
    if metrics is None:
        metrics = ['ROUGE-1', 'ROUGE-2', 'ROUGE-L', 'BLEU']
    else:
        metrics = list(metrics)  # a tuple would be treated as a single column label

    if results_df.empty:
        print("No data to plot for radar comparison.")
        return None

    fig = go.Figure()
    # One trace per row. Walking the rows directly avoids re-filtering the whole
    # frame for every method, and gives duplicate method names their own values.
    for method, scores in zip(results_df['Method'], results_df[metrics].to_numpy()):
        fig.add_trace(go.Scatterpolar(
            r=scores,
            theta=metrics,
            fill='toself',
            name=method
        ))

    # Radial axis spans 0 .. the largest value over all plotted metrics (NaNs skipped).
    radial_max = results_df[metrics].max().max()

    fig.update_layout(
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, radial_max]
            )),
        showlegend=True,
        title='Method Comparison Across Metrics',
        width=800,
        height=600
    )

    if save_fig:
        os.makedirs(FIGURES_DIR, exist_ok=True)
        # Export the plotly figure itself. ``plt.savefig`` would write the current
        # matplotlib figure, which is unrelated to this chart.
        fig.write_image(os.path.join(FIGURES_DIR, filename))
        print(f"Figure saved as {filename}")

    return fig

def plot_comprehensive_comparison(results_df: pd.DataFrame, save_fig: bool = False,
                                  filename: str = "comprehensive_comparison.png"):
    """
    Build a combined plotly chart: ROUGE scores as grouped bars (left y-axis)
    and training time as a line (right y-axis).

    Args:
        results_df: DataFrame with columns 'Method', 'ROUGE-1', 'ROUGE-2',
            'ROUGE-L' and 'Training Time (min)'.
        save_fig: If True, also export the figure to ``FIGURES_DIR/filename``.
        filename: Output file name. Its extension selects the image format;
            static image export requires the ``kaleido`` package.

    Returns:
        The plotly ``Figure``, or None when ``results_df`` is empty.
    """
    if results_df.empty:
        print("No data to plot for comprehensive comparison.")
        return None

    # Single panel with a secondary y-axis for the differently-scaled training time.
    fig = make_subplots(specs=[[{"secondary_y": True}]])

    for metric in ('ROUGE-1', 'ROUGE-2', 'ROUGE-L'):
        fig.add_trace(
            go.Bar(name=metric, x=results_df['Method'], y=results_df[metric],
                   text=results_df[metric].round(3), textposition='auto'),
            secondary_y=False,
        )

    fig.add_trace(
        go.Scatter(name='Training Time', x=results_df['Method'],
                   y=results_df['Training Time (min)'],
                   line=dict(color='red', width=2),
                   mode='lines+markers'),
        secondary_y=True,
    )

    fig.update_layout(
        title='Comprehensive Method Comparison',
        barmode='group',
        width=1000,
        height=600
    )

    fig.update_yaxes(title_text="ROUGE Scores", secondary_y=False)
    fig.update_yaxes(title_text="Training Time (minutes)", secondary_y=True)

    if save_fig:
        os.makedirs(FIGURES_DIR, exist_ok=True)
        # Export the plotly figure itself. ``plt.savefig`` would write the current
        # matplotlib figure, which is unrelated to this chart.
        fig.write_image(os.path.join(FIGURES_DIR, filename))
        print(f"Figure saved as {filename}")

    return fig

## Example Usage

Dưới đây là các ví dụ sử dụng các hàm visualization với dữ liệu mẫu từ `evaluation_results.csv`.

In [None]:
# Load evaluation results
results_path = os.path.join(TABLES_DIR, 'evaluation_results.csv')
df = pd.read_csv(results_path)

# Display sample of the data.
# ``display`` is not imported here; it is assumed to be the IPython/Jupyter builtin.
print("Sample of evaluation results:")
display(df.head())

# Create and display plots. Each figure is checked for None before ``.show()``:
# plot_rouge_scores returns None for an empty DataFrame, and the other helpers
# may do the same. Without the check the whole cell would abort on AttributeError.
print("\nROUGE Scores Comparison:")
fig_rouge = plot_rouge_scores(df)
if fig_rouge is not None:
    fig_rouge.show()

print("\nTraining Metrics Comparison:")
fig_metrics = plot_training_metrics(df)
if fig_metrics is not None:
    fig_metrics.show()

print("\nRadar Chart Comparison:")
fig_radar = plot_radar_comparison(df)
if fig_radar is not None:
    fig_radar.show()

print("\nComprehensive Comparison:")
fig_comprehensive = plot_comprehensive_comparison(df)
if fig_comprehensive is not None:
    fig_comprehensive.show()

## Interactive Widgets

Dưới đây là một số widget tương tác để tùy chỉnh các visualization.

In [None]:
def interactive_plot(metrics=None, methods=None, plot_type='bar',
                     figure_width=800, figure_height=500):
    """
    Render a bar or radar chart for the selected metrics and methods.

    Intended to be driven by ``ipywidgets.interact``. It reads the module-level
    ``df`` loaded in the example-usage cell.

    Args:
        metrics: A metric column name, or a list/tuple of them (a ``SelectMultiple``
            widget supplies a tuple). Defaults to the three ROUGE metrics.
        methods: A method name, or a list/tuple of names, to include. Defaults
            to every method in ``df``.
        plot_type: 'bar' for grouped bars; any other value draws a radar chart.
        figure_width: Figure width in pixels.
        figure_height: Figure height in pixels.
    """
    def _as_list(selection):
        # SelectMultiple hands over a *tuple*. Wrapping it as a single element,
        # as a plain ``isinstance(x, list)`` check does, turns the whole tuple
        # into one bogus column label / method name.
        if isinstance(selection, (list, tuple)):
            return list(selection)
        return [selection]

    metrics = ['ROUGE-1', 'ROUGE-2', 'ROUGE-L'] if metrics is None else _as_list(metrics)
    methods = df['Method'].tolist() if methods is None else _as_list(methods)

    if not metrics or not methods:
        print("Select at least one metric and one method.")
        return

    df_filtered = df[df['Method'].isin(methods)]

    fig = go.Figure()
    if plot_type == 'bar':
        for metric in metrics:
            fig.add_trace(go.Bar(
                name=metric,
                x=df_filtered['Method'],
                y=df_filtered[metric],
                text=df_filtered[metric].round(3),
                textposition='auto',
            ))
        fig.update_layout(barmode='group')
    else:  # radar plot
        # One trace per selected row. Methods missing from ``df`` are simply
        # absent instead of raising IndexError.
        for method, scores in zip(df_filtered['Method'], df_filtered[metrics].to_numpy()):
            fig.add_trace(go.Scatterpolar(
                r=scores,
                theta=metrics,
                fill='toself',
                name=method
            ))

    fig.update_layout(
        title=f'Comparison of {", ".join(metrics)}',
        width=figure_width,
        height=figure_height
    )
    fig.show()

# Widgets that drive ``interactive_plot``.

# Every metric column the plots know about; the first three (ROUGE) start selected.
metric_options = [
    'ROUGE-1', 'ROUGE-2', 'ROUGE-L', 'BLEU',
    'Training Time (min)', 'VRAM (GB)', 'Trainable Params (num)',
]
# One entry per distinct method in the loaded results (``df`` from the example cell).
method_options = df['Method'].unique().tolist()

metrics_widget = widgets.SelectMultiple(
    description='Metrics:',
    options=metric_options,
    value=metric_options[:3],
    disabled=False,
)

methods_widget = widgets.SelectMultiple(
    description='Methods:',
    options=method_options,
    value=method_options[:2],  # preselect (at most) the first two methods
    disabled=False,
)

plot_type_widget = widgets.RadioButtons(
    description='Plot Type:',
    options=['bar', 'radar'],
    value='bar',
    disabled=False,
)

# Figure-size sliders in pixels, moving in steps of 100.
width_widget = widgets.IntSlider(description='Width:', value=800, min=400, max=1200, step=100)
height_widget = widgets.IntSlider(description='Height:', value=500, min=300, max=800, step=100)

# Bind each widget to the matching keyword of interactive_plot. The trailing ';'
# stops the notebook from echoing interact's return value.
interact(
    interactive_plot,
    metrics=metrics_widget,
    methods=methods_widget,
    plot_type=plot_type_widget,
    figure_width=width_widget,
    figure_height=height_widget,
);