In [1]:
%load_ext autoreload
%autoreload 2

Cell 1 — Imports

In [2]:
from pathlib import Path
import pandas as pd
import sys

# Make sure project root is on sys.path
PROJECT_ROOT = Path().resolve().parents[0]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

from src.data.load import load_yaml, load_all_sources, add_broad_category
from src.models.classifier import load_classifier
from src.models.rewriter import build_rewriter
from src.retrieval.indexer import (
    load_retrieval_index,
    save_retrieval_index,
    build_retrieval_index,
    retrieve_similar_articles,
    retrieval_enabled,
)

2025-12-16 22:31:09.868856: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-12-16 22:31:09.961300: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-12-16 22:31:12.369172: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


Cell 2 — Load config + data

In [3]:
cfg = load_yaml(PROJECT_ROOT / 'configs' / 'base.yaml')

df = load_all_sources(cfg, root=PROJECT_ROOT)
df = add_broad_category(df, cfg, root=PROJECT_ROOT)

df.shape, df.columns

INFO:src.data.load:Loading source pakistan_today from /home/spark/NUST/Semester 5/Data Mining/Project/data/raw/pakistan_today(full-data).csv (encoding=utf-8)
INFO:src.data.load:Loading source tribune from /home/spark/NUST/Semester 5/Data Mining/Project/data/raw/tribune(full-data).csv (encoding=latin1)
INFO:src.data.load:Loading source dawn from /home/spark/NUST/Semester 5/Data Mining/Project/data/raw/dawn (full-data).csv (encoding=latin1)
INFO:src.data.load:Loading source daily_times from /home/spark/NUST/Semester 5/Data Mining/Project/data/raw/daily_times(full-data).csv (encoding=utf-8)
INFO:src.data.load:Loading preprocessed business_reorder from /home/spark/NUST/Semester 5/Data Mining/Project/data/interim/business_reorder_clean.parquet
INFO:src.data.load:Filtered invalid sources: (625905, 7) -> (624642, 7)
INFO:src.data.load:Combined dataset shape: (624642, 7)
INFO:src.data.load:Sampling up to 10000 rows per source (__file__ column).
  .apply(lambda g: g.sample(min(len(g), per_sourc

((40000, 8),
 Index(['headline', 'date', 'link', 'source', 'categories', 'description',
        '__file__', 'broad_category'],
       dtype='object'))

Cell 3 — Load classifier

In [4]:
clf = load_classifier(cfg, root=PROJECT_ROOT)
rewriter = build_rewriter(cfg)
print('Loaded classifier + rewriter')


Loaded classifier + rewriter


Cell 4 — Load or build retrieval index (offline RAG)

In [5]:
retrieval_index = None
if retrieval_enabled(cfg):
    try:
        retrieval_index = load_retrieval_index(cfg, root=PROJECT_ROOT)
        print('Loaded retrieval index.')
    except FileNotFoundError:
        print('Retrieval index missing. Building now (one-time step)...')
        retrieval_index = build_retrieval_index(df, cfg)
        path = save_retrieval_index(retrieval_index, cfg, root=PROJECT_ROOT)
        print('Saved retrieval index to:', path)


Retrieval index missing. Building now (one-time step)...
Saved retrieval index to: /home/spark/NUST/Semester 5/Data Mining/Project/data/processed/retrieval_index/tfidf_retrieval.joblib


Cell 5 — Run demo on 10 random samples

In [6]:
sample = df.sample(10, random_state=cfg['project']['random_seed']).reset_index(drop=True)

rows = []
for r in sample.to_dict('records'):
    headline = (r.get('headline') or '').strip()
    desc = (r.get('description') or '').strip()

    text = (headline + ' ' + desc).strip()
    pred_cat = clf.predict([text])[0]

    retrieved = []
    if retrieval_index is not None:
        retrieved = retrieve_similar_articles(
            text,
            retrieval_index,
            top_k=cfg['retrieval'].get('top_k', 3),
            min_similarity=cfg['retrieval'].get('min_similarity', 0.0),
        )

    out = rewriter.rewrite(
        headline=headline,
        description=desc,
        category=pred_cat,
        retrieved=retrieved,
    )

    rows.append({
        'headline': headline[:200],
        'true_cat': r.get('broad_category', ''),
        'pred_cat': pred_cat,
        'original': desc[:900],
        'expanded_article': out.compose_text(),
        'retrieved_1': (retrieved[0].get('headline') if len(retrieved) > 0 else ''),
        'retrieved_1_score': (retrieved[0].get('score') if len(retrieved) > 0 else ''),
    })

out_df = pd.DataFrame(rows)
out_df[['headline', 'true_cat', 'pred_cat']].head(10)


Unnamed: 0,headline,true_cat,pred_cat
0,Private firms setting up LNG terminals seek wa...,Business,Business
1,President Alvi tests positive for Covid-19 for...,Pakistan,Pakistan
2,OPEC and allies likely to cut production if US...,Business,Business
3,Pacific leaders struggle to keep focus on clim...,World,World
4,Police register fraud case against 2-year-old ...,Pakistan,Pakistan
5,International Criminal Court rules it has juri...,World,World
6,MoST to set auto parts quality standards,Business,Business
7,Economic terrorists will not be allowed to fle...,Pakistan,Pakistan
8,ECC meeting to clear Rs200bn for daily wage ea...,Business,Business
9,US could be next 'virus epicentre' as India lo...,World,World


Cell 6 — Length comparison table (nice for report)

In [7]:
out_df['orig_words'] = out_df['original'].apply(lambda x: len(str(x).split()))
out_df['expanded_words'] = out_df['expanded_article'].apply(lambda x: len(str(x).split()))
out_df['length_ratio'] = (out_df['expanded_words'] / out_df['orig_words'].replace(0, 1)).round(2)

out_df[['orig_words', 'expanded_words', 'length_ratio']].describe()


Unnamed: 0,orig_words,expanded_words,length_ratio
count,10.0,10.0,10.0
mean,139.6,703.2,5.063
std,4.501851,212.918764,1.614889
min,134.0,384.0,2.76
25%,136.5,505.25,3.465
50%,138.5,772.0,5.62
75%,142.75,889.5,6.4675
max,147.0,932.0,6.85


Cell 7 — Save outputs for report appendix

In [8]:
out_dir = PROJECT_ROOT / 'experiments' / 'results'
out_dir.mkdir(parents=True, exist_ok=True)

path = out_dir / 'article_expansion_demo.csv'
out_df.to_csv(path, index=False)
print('Saved:', path)


Saved: /home/spark/NUST/Semester 5/Data Mining/Project/experiments/results/article_expansion_demo.csv


Cell 1 — Imports
