In [10]:
import ray.tune
from pathlib import Path
from dataclasses import dataclass, field
from functools import lru_cache
import pandas as pd
from typing import List
from functools import partial

In [11]:
# pd.set_option('display.max_colwidth', None)
# pd.set_option('display.max_rows', None)

In [12]:
from etr_fr_expes import metric

In [13]:
# Root directory holding one sub-directory per experiment run (resolved to an absolute path).
EXPE_DIR = (Path("..") / ".." / "experimentations").resolve()

In [14]:
@dataclass
class Expe:
    """Results of one Ray Tune experiment, viewed through one evaluation task.

    The results directory is
    ``expe_dir / expe_name / "results" / f"{expe_name}_{step}"`` where
    ``expe_name == f"{model}.{method}.{train_tasks}"``. The Ray analysis object is
    loaded on first access to :attr:`expe_analysis` and cached on the instance.

    Fields:
        model, method, train_tasks: identify the run (together they form ``expe_name``).
        task: prefix of the task's metric/text columns (``<split>_<task>_...``).
        metric: selection metric without split prefix; rows are ranked by ``eval_<metric>``.
        mode: ranking direction; ``"min"`` ranks ascending, any other value descending.
        adapter_name: substring replaced by ``"adapter"`` in column names.
        expe_dir: root directory of all experiments (``str`` or ``Path``).
        step: name of the result step to load (e.g. ``"hp_search"``).
    """

    model: str
    method: str
    task: str
    train_tasks: str
    metric: str
    mode: str
    adapter_name: str
    # Accepts str or Path; coerced to Path when the results directory is built.
    expe_dir: Path = EXPE_DIR
    step: str = "hp_search"
    # Deliberately un-annotated: a plain class attribute (not a dataclass field) that serves
    # as the default for the per-instance ExperimentAnalysis cache set in `expe_analysis`.
    _expe_analysis = None

    @property
    def base_columns(self):
        """Columns moved into the index of :attr:`dataframe`."""
        return [
            # "expe",
            "trial_id",
            "model",
            "method",
            "task",
            "train_tasks",
            # "metric",
        ]

    @property
    def expe_name(self):
        """Run identifier ``"<model>.<method>.<train_tasks>"``, also its directory name."""
        return f"{self.model}.{self.method}.{self.train_tasks}"

    @property
    def expe_analysis(self):
        """Lazily load (and cache) the ``ray.tune.ExperimentAnalysis`` for this run/step."""
        if self._expe_analysis is None:
            results_dir = (
                Path(self.expe_dir) / self.expe_name / "results" / f"{self.expe_name}_{self.step}"
            )
            self._expe_analysis = ray.tune.ExperimentAnalysis(results_dir)
        return self._expe_analysis

    @property
    def dataframe(self):
        """Every reported result row of every trial, best row first.

        Rows are ranked by ``eval_<metric>`` (descending unless ``mode == "min"``),
        ``adapter_name`` is replaced by ``"adapter"`` in column names, and
        :attr:`base_columns` become the index. Recomputed on every access.

        Raises:
            ValueError: if the experiment has no trial results.
            KeyError: if the ranking column ``eval_<metric>`` is missing.
        """
        trial_dfs = self.expe_analysis.trial_dataframes
        if not trial_dfs:
            raise ValueError(f"No trial results found for {self.expe_name!r} (step {self.step!r})")
        # NOTE(review): with ignore_index=True pandas labels rows 0..n-1, so reset_index(level=0)
        # appears to only add an "index" column (the trial keys are not kept) — confirm intent.
        trials = pd.concat(trial_dfs, ignore_index=True).reset_index(level=0)
        sort_column = f"eval_{self.metric}"
        if sort_column not in trials.columns:
            available = [col for col in trials.columns if str(col).startswith("eval_")]
            raise KeyError(
                f"{sort_column!r} not in results of {self.expe_name!r} (step {self.step!r}); "
                f"available eval columns: {available}"
            )
        res = (
            trials
            .assign(
                expe=self.expe_name,
                model=self.model,
                metric=self.metric,
                method=self.method,
                task=self.task,
                train_tasks=self.train_tasks,
            )
            .sort_values(by=sort_column, ascending=self.mode == "min")
            .rename(columns=lambda col: col.replace(self.adapter_name, "adapter"))
        )
        return res.set_index(self.base_columns)

    def _metric_regex(self, _type):
        # "<type>_<task>_..." columns, excluding text-output columns: those contain "texts"
        # somewhere after the prefix (e.g. "test_etr_fr_politic_texts/inputs"), not only
        # immediately after it.
        return f"{_type}_{self.task}_(?!.*texts)"

    def metric_columns(self, _type="test"):
        """Names of the ``<_type>_<task>_*`` metric columns (text columns excluded)."""
        return self.dataframe.filter(regex=self._metric_regex(_type)).columns

    def text_columns(self, _type="test"):
        """Names of the ``<_type>...texts`` columns."""
        return self.dataframe.filter(
            regex=f"{_type}.*texts"
        ).columns

    @property
    def test_metric_df(self):
        """Test-split metric columns (builds :attr:`dataframe` once)."""
        return self.dataframe.filter(regex=self._metric_regex("test"))

    @property
    def eval_metric_df(self):
        """Eval-split metric columns (builds :attr:`dataframe` once)."""
        return self.dataframe.filter(regex=self._metric_regex("eval"))

    @property
    def best_model(self):
        """First row of :attr:`dataframe`, i.e. the best-ranked reported result."""
        return self.dataframe.iloc[0]

    def get_texts_df(self, row_idx, _type="test"):
        """Long-format view of the ``<_type>...texts`` columns of row ``row_idx``.

        Each text column is assumed to hold a list (one entry per example); the lists
        are exploded together so entry ``i`` of every column lands on row ``i``, then
        stacked into a single-column frame indexed by (example, column).
        """
        row_df = self.dataframe.iloc[row_idx].filter(regex=f"{_type}.*texts").to_frame().T
        row_df = row_df.explode(list(row_df.columns)).reset_index(drop=True).stack().to_frame()
        return row_df

    @property
    def best_trial(self):
        """Best trial according to Ray's own bookkeeping, over all reported iterations."""
        # NOTE(review): this passes `metric` without the "eval_" prefix, whereas `dataframe`
        # ranks by f"eval_{metric}" — confirm which key the trials actually report to Ray.
        return self.expe_analysis.get_best_trial(metric=self.metric, mode=self.mode, scope="all")
    
# Task-specific Expe constructors: each pins the evaluation task plus the metric and
# direction used to rank results.
ETRFrExpe = partial(
    Expe,
    task="etr_fr",
    metric="etr_fr_srb",
    mode="max",
    adapter_name="lora_etr_fr",
)
OrangesumExpe = partial(
    Expe,
    task="orangesum",
    metric="orangesum_rougeL",
    mode="max",
    adapter_name="lora_orangesum",
)
# adapter_name is not pinned here, so every WikilargeExpe(...) call must pass it explicitly.
WikilargeExpe = partial(
    Expe,
    task="wikilarge_fr",
    metric="wikilarge_fr_sari",
    mode="max",
)

In [15]:
@dataclass
class Analysis:
    """Compare several ``Expe`` runs side by side.

    Attributes:
        expes: experiments to compare, in display order.
        metrics: metric-name fragments used to pick metric columns in
            :meth:`best_models`. They are matched as regex alternatives, so e.g.
            "rougeL" also selects "rougeLsum".
    """

    # String forward reference: the class can be defined without `Expe` in scope.
    expes: List["Expe"]
    metrics: List[str] = field(default_factory=lambda: [
        "rouge1",
        "rouge2",
        "rougeL",
        "sari",
        "bertscore_f1",
        "srb",
        "compression_ratio",
        "novelty",
        "kmre",
        "lix",
    ])

    @property
    def dataframe(self):
        """All rows of every experiment's ``dataframe``, concatenated.

        Returns an empty DataFrame when there are no experiments, since ``pd.concat``
        rejects an empty list.
        """
        if not self.expes:
            return pd.DataFrame()
        return pd.concat([expe.dataframe for expe in self.expes])

    def best_models(self, _type=None, texts=False):
        """Best row (``best_model``) of each experiment, one row per experiment.

        Args:
            _type: ``"test"`` or ``"eval"`` to keep only that split's columns; any
                other value (default ``None``) returns every column.
            texts: when a split is selected, return its ``*texts*`` columns instead
                of the metric columns matching :attr:`metrics`.

        Text columns are selected from the combined frame. A text column present in
        any experiment is kept, not only those of the first one, and no experiment
        results are reloaded.
        """
        best = pd.DataFrame([expe.best_model for expe in self.expes])
        if _type not in ("test", "eval"):
            return best
        if texts:
            return best.filter(regex=f"{_type}.*texts")
        return best.filter(regex=f"{_type}.*({'|'.join(self.metrics)})")

    @property
    def test_metrics(self):
        """Test-split metric columns of each experiment's best row."""
        return self.best_models(_type="test")

    @property
    def eval_metrics(self):
        """Eval-split metric columns of each experiment's best row."""
        return self.best_models(_type="eval")

    @property
    def test_texts(self):
        """Test-split text columns of each experiment's best row."""
        return self.best_models(_type="test", texts=True)

In [16]:
# Model identifiers (first component of an experiment's directory name).
MBARTHEZ = "mbarthez"
MBART = "mbart"
MISTRAL = "mistral"
LLAMA3 = "llama3"

# Fine-tuning methods.
LORA = "lora"
MTLLORA = "mtllora"

# Training-task identifiers; multi-task combinations join single tasks with "+".
ETR_FR = "etrfr"
ORANGESUM = "orangesum"
WIKILARGE = "wikilarge"
ETR_FR_ORANGESUM = "+".join((ETR_FR, ORANGESUM))
ETR_FR_WIKILARGE = "+".join((ETR_FR, WIKILARGE))
ALL = "+".join((ETR_FR, ORANGESUM, WIKILARGE))

In [17]:
# ETR-FR experiments: for each model, one single-task LoRA run plus three multi-task
# (MTLLORA) runs. mBART runs are currently left out:
#   ETRFrExpe(model=MBART, method=LORA, train_tasks=ETR_FR)
#   ETRFrExpe(model=MBART, method=MTLLORA, train_tasks=ALL)
etr_fr_analysis = Analysis(
    expes=[
        ETRFrExpe(model=model, method=method, train_tasks=train_tasks)
        for model, lora_train_tasks in (
            # mBARThez's single-task run directory uses the "etrfr.reprod" train_tasks name.
            (MBARTHEZ, ETR_FR + ".reprod"),
            (LLAMA3, ETR_FR),
            (MISTRAL, ETR_FR),
        )
        for method, train_tasks in (
            (LORA, lora_train_tasks),
            (MTLLORA, ALL),
            (MTLLORA, ETR_FR_ORANGESUM),
            (MTLLORA, ETR_FR_WIKILARGE),
        )
    ]
)

# WikiLarge-FR single-task LoRA runs. The adapter name (and, for mBARThez, the
# train_tasks directory name) differs per model, so both are given explicitly.
wikilarge_analysis = Analysis(
    expes=[
        WikilargeExpe(model=model, method=LORA, train_tasks=train_tasks, adapter_name=adapter_name)
        for model, train_tasks, adapter_name in (
            (MBARTHEZ, "wikilarge-fr", "lora_wikilarge_fr"),
            (LLAMA3, WIKILARGE, "lora_wikilarge"),
            (MISTRAL, WIKILARGE, "lora_wikilarge"),
        )
    ]
)

# OrangeSum single-task LoRA runs, one per model.
orangesum_analysis = Analysis(
    expes=[
        OrangesumExpe(model=model, method=LORA, train_tasks=ORANGESUM)
        for model in (MBARTHEZ, LLAMA3, MISTRAL)
    ]
)

In [64]:
# Best-ranked row of each ETR-FR experiment, eval-split metrics only.
# NOTE(review): the saved output lists 10 experiments (no Mistral etrfr+orangesum /
# etrfr+wikilarge rows) while etr_fr_analysis configures 12 — it may be stale; re-run.
etr_fr_analysis.best_models(_type="eval")

Unnamed: 0,Unnamed: 1,Unnamed: 2,Unnamed: 3,Unnamed: 4,eval_etr_fr_rouge1,eval_etr_fr_rouge2,eval_etr_fr_rougeL,eval_etr_fr_rougeLsum,eval_etr_fr_sari,eval_etr_fr_bertscore_f1_rescaled,eval_etr_fr_bertscore_f1,eval_etr_fr_kmre,eval_etr_fr_lix,eval_etr_fr_compression_ratio,eval_etr_fr_novelty,eval_etr_fr_srb
8bc23_00000,mbarthez,lora,etr_fr,etrfr.reprod,39.6319,16.5283,29.4832,38.1354,40.929,33.3839,75.0356,96.1426,29.5607,40.5335,11.9967,41.8543
0e965_00009,mbarthez,mtllora,etr_fr,etrfr+orangesum+wikilarge,42.8477,19.5226,32.1822,40.9599,44.2588,37.9894,76.7615,98.5026,27.7983,46.3988,20.5952,44.981
dc655_00009,mbarthez,mtllora,etr_fr,etrfr+orangesum,40.7273,18.803,30.9719,39.0222,43.2281,35.1582,75.7005,98.5344,27.2011,47.7912,14.3958,43.7124
de84c_00009,mbarthez,mtllora,etr_fr,etrfr+wikilarge,41.6206,18.7326,31.1182,39.4755,43.2177,37.1391,76.4429,98.5139,27.9076,48.8784,17.72,43.888
8a9fd_00003,llama3,lora,etr_fr,etrfr,42.172,20.2764,32.7097,40.527,45.9699,36.7264,76.2882,95.3582,29.0824,45.7141,22.3861,45.8481
a6a20_00009,llama3,mtllora,etr_fr,etrfr+orangesum+wikilarge,43.3292,21.9895,33.9438,41.9159,49.4661,39.8132,77.445,97.4047,27.8156,49.011,35.2361,47.932
e160e_00010,llama3,mtllora,etr_fr,etrfr+orangesum,44.5968,22.5648,34.6758,42.8946,50.6565,39.9974,77.514,98.3807,26.6145,50.0887,36.4972,48.7962
f2da7_00011,llama3,mtllora,etr_fr,etrfr+wikilarge,43.9325,20.8832,33.548,41.8511,48.8374,39.0993,77.1775,98.3391,26.967,50.2579,35.2471,47.4373
142d8_00003,mistral,lora,etr_fr,etrfr,42.2985,20.2579,32.1511,40.1984,47.4291,37.1022,76.429,97.4827,27.5085,47.3911,28.1271,45.962
fc28b_00009,mistral,mtllora,etr_fr,etrfr+orangesum+wikilarge,44.364,21.5359,33.9146,42.3962,48.8829,39.1615,77.2008,97.6679,28.3413,48.0869,30.127,47.6977


In [61]:
# NOTE(review): when last run this cell raised KeyError: 'eval_etr_fr_srb', i.e. the
# test_best_model results of this run apparently lack the ranking column.
# TODO: investigate the run or drop this cell.
ETRFrExpe(
    model=LLAMA3,
    method=MTLLORA,
    train_tasks=ALL,
    step="test_best_model",
).dataframe

KeyError: 'eval_etr_fr_srb'

In [70]:
# ETR-FR experiments loaded from the "test_best_model" result step.
etr_fr_analysis_test_best_model = Analysis(
    expes=[
        ETRFrExpe(step="test_best_model", model=LLAMA3, method=LORA, train_tasks=ETR_FR),
        # Left out (presumably because loading it raised KeyError: 'eval_etr_fr_srb'):
        # ETRFrExpe(model=LLAMA3, method=MTLLORA, train_tasks=ALL, step="test_best_model")
    ]
)

In [87]:
# Contents of the "test_etr_fr_politic_texts/inputs" column for the best-ranked row.
etr_fr_analysis_test_best_model.expes[0].dataframe["test_etr_fr_politic_texts/inputs"].iloc[0]

['Travailleuses, travailleurs,  En avril 2020, pendant le premier confinement, Macron expliquait qu’il faudrait se rappeler que le pays avait tenu grâce à celles et ceux «\u2009que nos économies reconnaissent et rémunèrent si mal\u2009». Deux ans plus tard, le personnel des hôpitaux et des Ehpad, les aides à la personne, les ouvriers de l’agroalimentaire et de la logistique, les travailleurs des transports, les agents de nettoyage et ceux du gardiennage, l’armée des livreurs et des coursiers, les travailleurs de l’Éducation nationale, continuent d’être sous-payés et méprisés. Des travailleurs de la santé sont licenciés de fait et privés de salaire, en particulier en Guadeloupe, en Martinique et en Guyane. Tous ceux dont les conditions d’existence se dégradent doivent faire entendre leurs intérêts.',
 'PROFITS EN HAUSSE ET POUVOIR D’ACHAT EN BAISSE: ÇA SUFFIT! Suppressions d’emplois, cadences infernales, salaires insuffisants : plus les grandes entreprises sont rentables, plus leurs act