In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [8]:
import sys

# Extra import root for this notebook (presumably where the `src` package
# imported further down lives — confirm).
# NOTE(review): absolute, machine-specific path; consider making it configurable.
PDM_ROOT = '/capstor/users/cscs/xyixuan/PDM/'

# Guard the append so that re-running this cell does not add duplicate entries.
if PDM_ROOT not in sys.path:
    sys.path.append(PDM_ROOT)

In [20]:
from ignite.metrics import RougeL
from src.verbatim_eval.utils import load_inference_data
from tqdm import tqdm
import numpy as np

In [42]:
def batch_rouge_l_calc(batch, true_key="true_suffix", gen_key="generated_suffix", len_suffix=50):
    """
    Calculate one ROUGE-L score per example for a batch of true and generated sequences.

    Args:
        batch (dict): Batch of data containing true and generated sequences,
            one entry per example under each key.
        true_key (str, optional): Key for true sequences. Defaults to "true_suffix".
        gen_key (str, optional): Key for generated sequences. Defaults to "generated_suffix".
        len_suffix (int, optional): Both sequences are truncated to their first
            ``len_suffix`` elements before scoring. Defaults to 50.

    Returns:
        dict: ``{"Rouge-L": [...]}`` holding the ``Rouge-L-F`` value of every
        example, in batch order.
    """
    rouge_metric = RougeL(multiref="best")
    rouge_scores = []

    for true_seq, gen_seq in zip(batch[true_key], batch[gen_key]):
        # ignite's Rouge metrics consume ``(candidates, references)``: the
        # generated sequence is the candidate being scored, and the true
        # sequence is its single reference (hence the extra nesting level).
        # Passing them the other way round swaps precision and recall, which
        # changes the score whenever the truncated sequences differ in length.
        candidates = [gen_seq[:len_suffix]]
        references = [[true_seq[:len_suffix]]]

        # Reset so that each example is scored on its own rather than averaged
        # with the examples seen before it.
        rouge_metric.reset()
        rouge_metric.update((candidates, references))
        rouge_scores.append(rouge_metric.compute()['Rouge-L-F'])

    return {
        "Rouge-L": rouge_scores,
    }

In [43]:
def eval_rougeL(base_path: str, expr: str, repetitions: np.ndarray,
                batch_size: int = 10, num_proc: int = 50, return_results: bool = False):
    """
    Evaluate the ROUGE-L metric for a given experiment and repetitions.

    For every repetition the inference data is loaded from
    ``{base_path}/{expr}/inference``, scored with ``batch_rouge_l_calc`` and
    summarised; a mean ± std line per repetition is printed at the end.

    Args:
        base_path (str): Directory that contains the experiment directories.
        expr (str): Name of the experiment.
        repetitions (np.ndarray): Array of repetition numbers.
        batch_size (int, optional): Batch size forwarded to ``data.map``. Defaults to 10.
        num_proc (int, optional): Number of worker processes forwarded to
            ``data.map``. Defaults to 50.
        return_results (bool, optional): When True, return the per-repetition
            results instead of ``None``. Defaults to False so that calling the
            function as the last expression of a notebook cell keeps showing
            only the printed summary.

    Returns:
        dict | None: If ``return_results`` is True, a dictionary mapping each
        repetition to ``{'scores': np.ndarray, 'mean': ..., 'std': ...}``;
        otherwise ``None``.
    """
    # The inference directory does not depend on the repetition: build it once.
    data_path = f"{base_path}/{expr}/inference"
    results_by_rep = {}

    pbar = tqdm(repetitions, desc="Processing repetition set")
    for r in pbar:
        pbar.set_description(f"Processing repetition set {r}")
        data = load_inference_data(data_path, rep=r)

        # Calculate ROUGE-L scores
        data_rouge = data.map(
            batch_rouge_l_calc,
            batched=True,
            batch_size=batch_size,
            num_proc=num_proc,
            desc=f"Calculating ROUGE-L for rep={r}",
        )

        # Collect the per-example scores and their summary statistics.
        scores = np.array([item['Rouge-L'] for item in data_rouge])
        results_by_rep[r] = {
            'scores': scores,
            'mean': np.mean(scores),
            'std': np.std(scores),
        }

    print(f"\nSummary of {expr}:")
    for rep in repetitions:
        print(f"Repetition {rep:3d}: Mean = {results_by_rep[rep]['mean']:.3f} ± {results_by_rep[rep]['std']:.3f}")

    if return_results:
        return results_by_rep
    return None

In [44]:
# Directory under which each experiment has its own sub-directory.
base_path = "/iopsstor/scratch/cscs/xyixuan/experiment"

# Repetition numbers to evaluate for every experiment.
repetitions = np.array(
    [
        1, 2, 3, 4,
        8, 16, 24, 32,
        48, 64, 96, 128,
    ]
)

In [45]:
# Print the per-repetition ROUGE-L mean ± std for this experiment's inference outputs.
eval_rougeL(base_path, "llama_1.5B_Sparse_Gutenberg_K_50_H_13_GBS_60", repetitions)

Calculating ROUGE-L for rep=1 (num_proc=50): 100%|██████████| 500/500 [00:00<00:00, 1422.67 examples/s]
Calculating ROUGE-L for rep=2 (num_proc=50): 100%|██████████| 500/500 [00:00<00:00, 1400.50 examples/s]
Calculating ROUGE-L for rep=3 (num_proc=50): 100%|██████████| 500/500 [00:00<00:00, 1252.58 examples/s]
Calculating ROUGE-L for rep=4 (num_proc=50): 100%|██████████| 500/500 [00:00<00:00, 1403.65 examples/s]
Calculating ROUGE-L for rep=8 (num_proc=50): 100%|██████████| 500/500 [00:00<00:00, 1385.65 examples/s]
Calculating ROUGE-L for rep=16 (num_proc=50): 100%|██████████| 500/500 [00:00<00:00, 1365.92 examples/s]
Calculating ROUGE-L for rep=24 (num_proc=50): 100%|██████████| 500/500 [00:00<00:00, 1399.09 examples/s]
Calculating ROUGE-L for rep=32 (num_proc=50): 100%|██████████| 500/500 [00:00<00:00, 1355.15 examples/s]
Calculating ROUGE-L for rep=48 (num_proc=50): 100%|██████████| 500/500 [00:00<00:00, 1402.50 examples/s]
Calculating ROUGE-L for rep=64 (num_proc=50): 100%|█████████


Summary of llama_1.5B_Sparse_Gutenberg_K_50_H_13_GBS_60:
Repetition   1: Mean = 0.185 ± 0.065
Repetition   2: Mean = 0.187 ± 0.073
Repetition   3: Mean = 0.186 ± 0.057
Repetition   4: Mean = 0.188 ± 0.066
Repetition   8: Mean = 0.193 ± 0.074
Repetition  16: Mean = 0.195 ± 0.066
Repetition  24: Mean = 0.198 ± 0.079
Repetition  32: Mean = 0.222 ± 0.101
Repetition  48: Mean = 0.232 ± 0.119
Repetition  64: Mean = 0.244 ± 0.129
Repetition  96: Mean = 0.265 ± 0.156
Repetition 128: Mean = 0.278 ± 0.168





In [46]:
# Print the per-repetition ROUGE-L mean ± std for this experiment's inference outputs.
eval_rougeL(base_path, "llama_1.5B_Sparse_Gutenberg_Standard_GBS_60", repetitions)

Calculating ROUGE-L for rep=1 (num_proc=50): 100%|██████████| 500/500 [00:00<00:00, 1414.28 examples/s]
Calculating ROUGE-L for rep=2 (num_proc=50): 100%|██████████| 500/500 [00:00<00:00, 1413.97 examples/s]
Calculating ROUGE-L for rep=3 (num_proc=50): 100%|██████████| 500/500 [00:00<00:00, 1425.47 examples/s]
Calculating ROUGE-L for rep=4 (num_proc=50): 100%|██████████| 500/500 [00:00<00:00, 1528.96 examples/s]
Calculating ROUGE-L for rep=8 (num_proc=50): 100%|██████████| 500/500 [00:00<00:00, 1418.32 examples/s]
Calculating ROUGE-L for rep=16 (num_proc=50): 100%|██████████| 500/500 [00:00<00:00, 1495.56 examples/s]
Calculating ROUGE-L for rep=24 (num_proc=50): 100%|██████████| 500/500 [00:00<00:00, 1437.37 examples/s]
Calculating ROUGE-L for rep=32 (num_proc=50): 100%|██████████| 500/500 [00:00<00:00, 1515.75 examples/s]
Calculating ROUGE-L for rep=48 (num_proc=50): 100%|██████████| 500/500 [00:00<00:00, 1342.79 examples/s]
Calculating ROUGE-L for rep=64 (num_proc=50): 100%|█████████


Summary of llama_1.5B_Sparse_Gutenberg_Standard_GBS_60:
Repetition   1: Mean = 0.180 ± 0.060
Repetition   2: Mean = 0.185 ± 0.069
Repetition   3: Mean = 0.181 ± 0.062
Repetition   4: Mean = 0.190 ± 0.069
Repetition   8: Mean = 0.186 ± 0.066
Repetition  16: Mean = 0.190 ± 0.067
Repetition  24: Mean = 0.193 ± 0.063
Repetition  32: Mean = 0.197 ± 0.066
Repetition  48: Mean = 0.201 ± 0.076
Repetition  64: Mean = 0.197 ± 0.071
Repetition  96: Mean = 0.203 ± 0.074
Repetition 128: Mean = 0.207 ± 0.088





In [29]:
# NOTE(review): in the cells shown here `data_rouge` is only ever assigned as a
# local variable inside `eval_rougeL`, so this cell relies on a notebook-level
# variable left over from an earlier execution (its execution count is lower
# than that of the function definitions) and would presumably raise NameError
# after Restart & Run All — confirm, then either recreate it or drop this cell.
data_rouge['Rouge-L']

[0.172,
 0.208,
 0.238,
 0.144,
 0.156,
 0.176,
 0.192,
 0.154,
 0.146,
 0.174,
 0.17,
 0.156,
 0.218,
 0.156,
 0.18,
 0.198,
 0.17,
 0.178,
 0.18,
 0.13,
 0.158,
 0.19600000000000004,
 0.234,
 0.19600000000000004,
 0.17,
 0.188,
 0.134,
 0.134,
 0.188,
 0.184,
 0.208,
 0.20999999999999996,
 0.202,
 0.168,
 0.16,
 0.15,
 0.156,
 0.158,
 0.138,
 0.19,
 0.114,
 0.188,
 0.062,
 0.18,
 0.18,
 0.164,
 0.19,
 0.188,
 0.184,
 0.128,
 0.152,
 0.202,
 0.134,
 0.222,
 0.278,
 0.186,
 0.216,
 0.148,
 0.152,
 0.204,
 0.174,
 0.182,
 0.172,
 0.126,
 0.156,
 0.204,
 0.194,
 0.258,
 0.208,
 0.206,
 0.138,
 0.186,
 0.19,
 0.138,
 0.314,
 0.18,
 0.14,
 0.18,
 0.174,
 0.206,
 0.148,
 0.156,
 0.158,
 0.176,
 0.166,
 0.178,
 0.184,
 0.128,
 0.178,
 0.186,
 0.178,
 0.212,
 0.204,
 0.202,
 0.168,
 0.166,
 0.176,
 0.194,
 0.18,
 0.19,
 0.156,
 0.19600000000000004,
 0.222,
 0.154,
 0.206,
 0.168,
 0.146,
 0.246,
 0.124,
 0.156,
 0.168,
 0.15,
 0.184,
 0.148,
 0.17,
 0.202,
 0.144,
 0.154,
 0.192,
 0.18,
 0.18