In [111]:
import ast
import json
import re
from pathlib import Path

import pandas as pd

from utils import read_last_line_from_file

In [112]:
# Root of the experiment logs, relative to the working directory; each model
# gets its own sub-directory holding `<model>_<dataset>_<seed>.log` files.
OUTPUT_DIR = Path() / "output"

In [127]:
class Run:
    """A single training run, backed by one log file.

    The file name is expected to end in ``_<seed>`` before its extension
    (e.g. ``BERT4Rec_LastFM_42.log``). The last line of the file is expected to
    look like ``<timestamp> - <Python dict repr of the final metrics>``.
    """

    def __init__(self, log_file: Path):
        """
        Args:
            log_file: Path to the run's log file.

        Raises:
            ValueError: If ``log_file`` does not exist.
        """
        if not log_file.exists():
            raise ValueError(f"log file `{log_file}` does not exist")
        self.log_file = log_file

    def seed(self) -> int:
        """Return the seed encoded as the last ``_<digits>`` group of the file name.

        Only the file name is inspected, not the parent directories. The digits
        must be followed by ``.`` or the end of the name, so any extension
        (``.log``, ``.txt``, ``.log.txt``, ...) is accepted.

        Raises:
            ValueError: If the file name contains no such ``_<digits>`` group.
        """
        # The greedy `.*` makes the match land on the *last* `_<digits>` group.
        match = re.match(r".*_(\d+)(?:\.|\Z)", self.log_file.name)
        if match is None:
            # An explicit raise, unlike `assert`, is not stripped under `python -O`.
            raise ValueError(f"cannot parse a seed from log file name `{self.log_file.name}`")
        return int(match.group(1))

    def metrics(self) -> dict[str, float]:
        """Parse the final metrics from the last line of the log file.

        Returns:
            Mapping of metric name to value as ``float``, with the ``Epoch``
            entry (if present) removed.
        """
        last_line = read_last_line_from_file(self.log_file)
        _timestamp, metrics_string = last_line.split(" - ", maxsplit=1)
        # The payload appears to be a Python dict repr (single-quoted keys).
        # Parse it as a Python literal instead of rewriting quotes into JSON,
        # which breaks as soon as a key or value contains an apostrophe or quote.
        raw_metrics = ast.literal_eval(metrics_string.strip())
        raw_metrics.pop("Epoch", None)
        # Values may be logged as formatted strings (e.g. '0.0239'); normalise to float.
        return {name: float(value) for name, value in raw_metrics.items()}


class SeededRuns:
    """Runs of one model on one dataset that differ only in their random seed."""

    def __init__(
        self,
        runs: list[Run],
        model_name = "Model",
        dataset_name = None,
    ):
        """
        Args:
            runs: The individual runs to aggregate.
            model_name: Model name shown in the summary-table caption.
            dataset_name: Dataset name shown in the caption; stored as ``self.dataset``.
        """
        self.runs = runs
        self.model_name = model_name
        self.dataset = dataset_name

    @classmethod
    def from_model_name_and_dataset(
        cls, model_name: str, dataset: str, output_dir: Path = OUTPUT_DIR, log_file_extension: str = "log",
    ):
        """Load every run of ``model_name`` on ``dataset`` from ``output_dir / model_name``.

        Log files must be named ``<model_name>_<dataset>_<seed>.<log_file_extension>``.

        Raises:
            ValueError: If ``output_dir / model_name`` does not exist, or it holds
                no log file following that naming scheme.
        """
        model_output_path = output_dir / model_name
        if not model_output_path.exists():
            raise ValueError(f"model output path `{model_output_path}` does not exist")

        # The glob's `*` also matches datasets that merely share this prefix
        # (e.g. `<model>_<dataset>_small_42.log`), so keep only names whose
        # remainder after `<model>_<dataset>_` is purely the numeric seed.
        name_pattern = re.compile(
            rf"{re.escape(model_name)}_{re.escape(dataset)}_\d+\.{re.escape(log_file_extension)}"
        )
        # Glob order is unspecified; sort so run order (and metrics_df row
        # order) is reproducible across machines.
        log_files = sorted(
            path
            for path in model_output_path.glob(f"{model_name}_{dataset}_*.{log_file_extension}")
            if name_pattern.fullmatch(path.name)
        )
        if not log_files:
            raise ValueError(
                f"no `{model_name}_{dataset}_<seed>.{log_file_extension}` log files "
                f"found in `{model_output_path}`"
            )
        return cls([Run(path) for path in log_files], model_name, dataset)

    def seeds(self) -> list[int]:
        """Return the seed of each run, in the same order as ``self.runs``."""
        return [run.seed() for run in self.runs]

    def metrics_df(self) -> pd.DataFrame:
        """Return one row of final metrics per run, one column per metric name."""
        metrics = [run.metrics() for run in self.runs]
        return pd.DataFrame(metrics)

    def describe_metrics_df(self, *, caption: bool = False):
        """Summarise each metric across runs as ``mean``/``std``/``min``/``max`` rows.

        Args:
            caption: If true, return a pandas ``Styler`` captioned with the model
                name, dataset and number of seeds instead of a plain DataFrame.
        """
        summary = self.metrics_df().describe().loc[["mean", "std", "min", "max"]]
        if not caption:
            return summary
        # Separate name, so the boolean `caption` flag is not rebound to a string.
        caption_text = (
            f"{self.model_name} performance on {self.dataset} over {len(self.runs)} seeds"
        )
        return summary.style.set_caption(caption_text)

In [128]:
# Load every (model, dataset) combination once, keyed by (model_name, dataset_name).
MODEL_NAMES = ["BERT4Rec", "SASRec", "DuoRec", "FEARec", "BSARec"]
DATASET_NAMES = ["LastFM", "Diginetica"]

seeded_runs = {
    (model_name, dataset_name): SeededRuns.from_model_name_and_dataset(model_name, dataset_name)
    for dataset_name in DATASET_NAMES
    for model_name in MODEL_NAMES
}

# Per-combination names, kept so later cells can refer to them directly.
# Models on the LastFM dataset
bert4rec_lastfm = seeded_runs["BERT4Rec", "LastFM"]
sasrec_lastfm = seeded_runs["SASRec", "LastFM"]
duorec_lastfm = seeded_runs["DuoRec", "LastFM"]
fearec_lastfm = seeded_runs["FEARec", "LastFM"]
bsarec_lastfm = seeded_runs["BSARec", "LastFM"]

# Models on the Diginetica dataset
bert4rec_diginetica = seeded_runs["BERT4Rec", "Diginetica"]
sasrec_diginetica = seeded_runs["SASRec", "Diginetica"]
duorec_diginetica = seeded_runs["DuoRec", "Diginetica"]
fearec_diginetica = seeded_runs["FEARec", "Diginetica"]
bsarec_diginetica = seeded_runs["BSARec", "Diginetica"]

In [134]:
# Captioned mean/std/min/max summary for each model on LastFM, in a fixed model order.
for lastfm_runs in (bert4rec_lastfm, sasrec_lastfm, duorec_lastfm, fearec_lastfm, bsarec_lastfm):
    display(lastfm_runs.describe_metrics_df(caption=True))

Unnamed: 0,HR@5,NDCG@5,HR@10,NDCG@10,HR@20,NDCG@20
mean,0.02974,0.0187,0.04972,0.02512,0.07982,0.03266
std,0.005387,0.00313,0.005335,0.002687,0.005258,0.002608
min,0.0239,0.0152,0.0431,0.0215,0.0706,0.0285
max,0.0349,0.0216,0.0578,0.0285,0.0835,0.0349


Unnamed: 0,HR@5,NDCG@5,HR@10,NDCG@10,HR@20,NDCG@20
mean,0.03944,0.02682,0.05984,0.03336,0.08862,0.04064
std,0.002057,0.001705,0.002285,0.001547,0.004764,0.002239
min,0.0367,0.0252,0.056,0.0311,0.0807,0.0374
max,0.0422,0.0289,0.0615,0.035,0.0927,0.0428


Unnamed: 0,HR@5,NDCG@5,HR@10,NDCG@10,HR@20,NDCG@20
mean,0.0424,0.03204,0.05796,0.03702,0.08954,0.04488
std,0.002094,0.000666,0.004505,0.000912,0.005773,0.001112
min,0.0404,0.0312,0.0541,0.0358,0.0835,0.0434
max,0.0459,0.0329,0.0651,0.0382,0.0963,0.046


Unnamed: 0,HR@5,NDCG@5,HR@10,NDCG@10,HR@20,NDCG@20
mean,0.042,0.0291,0.05892,0.03452,0.08532,0.04108
std,0.008903,0.005377,0.011804,0.006351,0.014559,0.007007
min,0.0284,0.0199,0.0404,0.0238,0.0624,0.0293
max,0.0505,0.0329,0.0725,0.0406,0.1018,0.0479


Unnamed: 0,HR@5,NDCG@5,HR@10,NDCG@10,HR@20,NDCG@20
mean,0.04992,0.0339,0.0734,0.04144,0.1077,0.04998
std,0.002643,0.001687,0.00427,0.001372,0.0052,0.001728
min,0.0468,0.0319,0.0679,0.0394,0.1028,0.0474
max,0.0532,0.0359,0.078,0.0429,0.1165,0.0519


In [135]:
# Captioned mean/std/min/max summary for each model on Diginetica, in a fixed model order.
for diginetica_runs in (
    bert4rec_diginetica,
    sasrec_diginetica,
    duorec_diginetica,
    fearec_diginetica,
    bsarec_diginetica,
):
    display(diginetica_runs.describe_metrics_df(caption=True))

Unnamed: 0,HR@5,NDCG@5,HR@10,NDCG@10,HR@20,NDCG@20
mean,0.121,0.07748,0.1902,0.09974,0.27726,0.12172
std,0.001042,0.000936,0.002084,0.001228,0.001397,0.000942
min,0.1196,0.0764,0.1872,0.0985,0.276,0.1206
max,0.1222,0.0787,0.1924,0.1013,0.2793,0.1228


Unnamed: 0,HR@5,NDCG@5,HR@10,NDCG@10,HR@20,NDCG@20
mean,0.10394,0.06652,0.1657,0.08636,0.25048,0.10774
std,0.001282,0.000303,0.003043,0.001272,0.004563,0.001626
min,0.103,0.066,0.1621,0.085,0.2467,0.1063
max,0.1062,0.0667,0.1695,0.0879,0.2581,0.1103


Unnamed: 0,HR@5,NDCG@5,HR@10,NDCG@10,HR@20,NDCG@20
mean,0.15196,0.09944,0.2279,0.12386,0.32208,0.14762
std,0.00115,0.000876,0.001042,0.000623,0.001117,0.000756
min,0.1507,0.0983,0.2265,0.1229,0.3205,0.1464
max,0.1532,0.1002,0.229,0.1245,0.3234,0.1482


Unnamed: 0,HR@5,NDCG@5,HR@10,NDCG@10,HR@20,NDCG@20
mean,0.15346,0.10056,0.23176,0.12574,0.3271,0.1498
std,0.001339,0.000695,0.00167,0.000832,0.002192,0.000316
min,0.1515,0.0998,0.2299,0.1249,0.3243,0.1495
max,0.1552,0.1014,0.2342,0.1267,0.3292,0.1503


Unnamed: 0,HR@5,NDCG@5,HR@10,NDCG@10,HR@20,NDCG@20
mean,0.15628,0.10196,0.23338,0.1267,0.32562,0.15
std,0.001555,0.000805,0.001316,0.000624,0.001695,0.000791
min,0.1546,0.1008,0.2319,0.1258,0.3236,0.149
max,0.1588,0.1029,0.2353,0.1275,0.3274,0.1508
