In [1]:
import warnings, argparse, random, pandas as pd, numpy as np, os
from pathlib import Path
from collections import Counter
from tqdm import tqdm
from sklearn.linear_model import LogisticRegression
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn.decomposition import LatentDirichletAllocation
import torch, torch.nn as nn
from torch.utils.data import DataLoader, Dataset
warnings.filterwarnings('ignore')

In [2]:
# RAND 
def evaluate_rand(df, topn, neg_k=19):
    """RAND baseline: random shuffle candidate set"""
    rng = np.random.default_rng(42); results = {}
    for n in topn:
        hit = tot = 0
        for _, grp in df.groupby('guest_id'):
            pos, neg = grp[grp.label==1], grp[grp.label==0]
            if pos.empty or neg.empty: continue
            for idx in pos.index:
                sel = rng.choice(neg.index, size=min(neg_k, len(neg)), replace=False)
                cand = np.append(sel, idx); rng.shuffle(cand)
                if idx in cand[:n]: hit += 1
                tot += 1
        results[f'HR@{n}'] = hit / tot if tot else 0.0
    return results

#  STL (TF-IDF + LR) 
class STL:
    def __init__(self, max_feat=50):
        self.vec = TfidfVectorizer(max_features=max_feat, stop_words='english', min_df=5, max_df=0.5)
        self.clf = LogisticRegression(max_iter=50, random_state=42)
        self.fit_done = False
    def _txt(self, df, idict):
        return [str(idict.get(lid, '')) for lid in df.listing_id]
    def _X(self, texts):
        X = self.vec.transform(texts) if self.fit_done else self.vec.fit_transform(texts)
        self.fit_done = True; return X
    def fit(self, df, idict):
        self.clf.fit(self._X(self._txt(df, idict)), df.label)
        return self
    def score(self, df, idict):
        return self.clf.predict_proba(self._X(self._txt(df, idict)))[:,1]

def evaluate_stl(tr, te, idict, topn, neg_k=19):
    model = STL().fit(tr, idict)
    tmp = te.copy(); tmp['score'] = model.score(tmp, idict)
    results = {}
    for n in topn:
        hit = tot = 0
        for _, grp in tmp.groupby('guest_id'):
            pos, neg = grp[grp.label==1], grp[grp.label==0]
            if pos.empty or neg.empty: continue
            for idx in pos.index:
                sel = np.random.choice(neg.index, min(neg_k,len(neg)), replace=False)
                cand = pd.concat([neg.loc[sel], pos.loc[[idx]]]).sort_values('score', ascending=False)
                if idx in cand.head(n).index: hit += 1
                tot += 1
        results[f'HR@{n}'] = hit / tot if tot else 0
    return results

# RTM‑G baseline 
class RTM_G:
    def __init__(self, n_topics: int = 60, lda_iter: int = 5):
        self.cv = CountVectorizer(stop_words='english', min_df=5, max_df=0.5)
        self.lda = LatentDirichletAllocation(n_components=n_topics, max_iter=lda_iter,
                                             learning_method='batch', random_state=0)
        self.clf = LogisticRegression(max_iter=50, random_state=42)
        self.cv_fitted = False

    def _txt(self, df: pd.DataFrame, idict: dict[str, str]) -> list[str]:
        return [str(idict.get(lid, '')) for lid in df.listing_id]

    def _theta(self, texts: list[str]):
        X = self.cv.transform(texts) if self.cv_fitted else self.cv.fit_transform(texts)
        self.cv_fitted = True
        return self.lda.transform(X) if hasattr(self.lda, 'components_') else self.lda.fit_transform(X)

    def fit(self, df: pd.DataFrame, idict: dict[str, str]):
        self.clf.fit(self._theta(self._txt(df, idict)), df.label)
        return self

    def score(self, df: pd.DataFrame, idict: dict[str, str]):
        return self.clf.predict_proba(self._theta(self._txt(df, idict)))[:, 1]

def evaluate_rtm_g(tr, te, idict, topn, neg_k=19):
    model = RTM_G().fit(tr, idict)
    tmp = te.copy(); tmp['score'] = model.score(tmp, idict)
    results = {}
    for n in topn:
        hit = tot = 0
        for _, grp in tmp.groupby('guest_id'):
            pos, neg = grp[grp.label==1], grp[grp.label==0]
            if pos.empty or neg.empty: continue
            for idx in pos.index:
                sel = np.random.choice(neg.index, min(neg_k,len(neg)), replace=False)
                cand = pd.concat([neg.loc[sel], pos.loc[[idx]]]).sort_values('score', ascending=False)
                if idx in cand.head(n).index: hit += 1
                tot += 1
        results[f'HR@{n}'] = hit / tot if tot else 0
    return results

# Main Callable Logic 
def run_baselines(data_root='autodl-fs', topn=range(1, 11), neg_k=19):
    root = Path(data_root)
    if not root.exists():
        raise FileNotFoundError(f"Data directory '{root}' not found")
    tr = pd.read_csv(root / 'transaction_train.csv')
    te = pd.read_csv(root / 'transaction_test.csv')
    idict = dict(pd.read_csv(root / 'dj_documents_unique.csv').values)

    all_res = {}
    print('Evaluating RAND...');   all_res['RAND']   = evaluate_rand(te, topn, neg_k)
    print('Evaluating STL...');    all_res['STL']    = evaluate_stl(tr, te, idict, topn, neg_k)
    print('Evaluating RTM-G...');  all_res['RTM-G']  = evaluate_rtm_g(tr, te, idict, topn, neg_k)

    # Output table
    header = 'TopN | ' + ' | '.join(f'{k:7s}' for k in all_res)
    print('\n' + header)
    print('-'*len(header))
    for n in topn:
        row = f' {n:>2}  | ' + ' | '.join(f"{all_res[k][f'HR@{n}']:.3f}" for k in all_res)
        print(row)

    # Save results
    out_dir = Path('brtm_outputs')
    out_dir.mkdir(exist_ok=True)
    pd.DataFrame(all_res).T.to_csv(out_dir / 'baseline_results.csv')
    print(f"\nResults saved to: {out_dir / 'baseline_results.csv'}")
    return all_res


def print_baseline_comparison(final_results: dict, csv_path='brtm_outputs/baseline_results.csv'):
    """Compare RAND / STL / RTM‑G with your model (BRTM-Sample) using HR@1-10"""
    baseline_df = pd.read_csv(csv_path, index_col=0)
    rand = [baseline_df.loc['RAND', f'HR@{n}'] for n in range(1, 11)]
    stl  = [baseline_df.loc['STL',  f'HR@{n}'] for n in range(1, 11)]
    rtmg = [baseline_df.loc['RTM-G',f'HR@{n}'] for n in range(1, 11)]
    brtm = [final_results.get(f'HR@{n}', 0.0) for n in range(1, 11)]
    print("=" * 90)
    print("RAND vs STL vs RTM‑G vs YOURS (Hit Rate)")
    print("=" * 90)
    print("TopN | RAND |  STL  | RTM‑G | BRTM-SAMPLE | ∆ vs RTM‑G")
    print("-" * 90)

    for n in range(1, 11):
        hr_rand = rand[n - 1]
        hr_stl  = stl[n - 1]
        hr_rtmg = rtmg[n - 1]
        hr_ours = brtm[n - 1]
        improve = ((hr_ours - hr_rtmg) / hr_rtmg * 100) if hr_rtmg else 0.0
        print(f" {n:>2}  | {hr_rand:.3f} | {hr_stl:.3f} | {hr_rtmg:.3f} | {hr_ours:.3f} | {improve:6.1f}%")


In [3]:
def load_hit_rates(npz_path='brtm_outputs/brtm_table7_complete_results.npz') -> dict:
    data = np.load(npz_path, allow_pickle=True)
    return data['hit_rates'].item()
final_results = load_hit_rates('brtm_outputs/brtm_table7_complete_results.npz')

In [5]:
# run baseline
baselines = run_baselines(data_root='autodl-fs')

Evaluating RAND...
Evaluating STL...
Evaluating RTM-G...

TopN | RAND    | STL     | RTM-G  
----------------------------------
  1  | 0.060 | 0.291 | 0.335
  2  | 0.115 | 0.383 | 0.469
  3  | 0.155 | 0.461 | 0.548
  4  | 0.206 | 0.534 | 0.595
  5  | 0.225 | 0.592 | 0.623
  6  | 0.291 | 0.652 | 0.673
  7  | 0.356 | 0.702 | 0.702
  8  | 0.391 | 0.751 | 0.746
  9  | 0.450 | 0.786 | 0.798
 10  | 0.469 | 0.818 | 0.840

Results saved to: brtm_outputs/baseline_results.csv


In [6]:
# compare with BRTM-Sample results
print_baseline_comparison(final_results, csv_path='brtm_outputs/baseline_results.csv')

RAND vs STL vs RTM‑G vs YOURS (Hit Rate)
TopN | RAND |  STL  | RTM‑G | BRTM-SAMPLE | ∆ vs RTM‑G
------------------------------------------------------------------------------------------
  1  | 0.060 | 0.291 | 0.335 | 0.300 |  -10.4%
  2  | 0.115 | 0.383 | 0.469 | 0.434 |   -7.5%
  3  | 0.155 | 0.461 | 0.548 | 0.529 |   -3.6%
  4  | 0.206 | 0.534 | 0.595 | 0.566 |   -4.8%
  5  | 0.225 | 0.592 | 0.623 | 0.645 |    3.5%
  6  | 0.291 | 0.652 | 0.673 | 0.696 |    3.5%
  7  | 0.356 | 0.702 | 0.702 | 0.769 |    9.5%
  8  | 0.391 | 0.751 | 0.746 | 0.800 |    7.2%
  9  | 0.450 | 0.786 | 0.798 | 0.843 |    5.7%
 10  | 0.469 | 0.818 | 0.840 | 0.857 |    2.0%
