In [1]:
import json
from pathlib import Path

import joblib
import pandas as pd
from sklearn.neighbors import NearestNeighbors

# Directory for the app's serialized artifacts: the fitted TF-IDF vectorizer is
# read from here, and the kNN index and sample-query JSON are written here.
# The path is relative, so it resolves against the kernel's working directory.
ARTIF_DIR = Path("../app/artifacts")
# Create the directory (and any missing parents); no error if it already exists.
ARTIF_DIR.mkdir(parents=True, exist_ok=True)


In [2]:
# Load the corpus and pull out the raw document texts, coerced to str.
df = pd.read_csv("../data/df_file.csv")
text_column = df["Text"].astype(str)
X_raw = text_column.tolist()

# Normalise each document: newlines become spaces, outer whitespace is trimmed.
X = list(map(lambda doc: doc.replace("\n", " ").strip(), X_raw))

# Load the previously saved vectorizer and only transform with it (no fit here).
# NOTE(review): joblib.load unpickles the file - only load artifacts you trust.
tfidf_path = ARTIF_DIR / "tfidf.pkl"
tfidf = joblib.load(tfidf_path)
X_vec = tfidf.transform(X)
print("Matrix shape:", X_vec.shape)


Matrix shape: (2225, 77505)


In [3]:
# Build a cosine-distance nearest-neighbour index over the TF-IDF matrix
# and persist it next to the other artifacts.
knn_path = ARTIF_DIR / "knn_index.pkl"

index = NearestNeighbors(n_neighbors=10, metric="cosine")
index.fit(X_vec)  # fit() returns the estimator itself, so `index` is the fitted model

joblib.dump(index, knn_path)
print("Saved ->", knn_path)


Saved -> ..\app\artifacts\knn_index.pkl


In [4]:
# Smoke-test the index with one free-text query and display the top 5 hits.
query = "election results and government budget policy"
qv = tfidf.transform([query])
dist, idx = index.kneighbors(qv, n_neighbors=5, return_distance=True)
# The index uses metric="cosine", so similarity = 1 - distance.
sim = 1 - dist[0]

# Fetch all matched rows with a single positional lookup instead of
# materialising df.iloc[i] three times per hit.
hits = df.iloc[idx[0]]

results = []
for s, label, text in zip(sim, hits["Label"], hits["Text"]):
    # Coerce to str (as the vectorization step does) so a non-string cell,
    # e.g. NaN, cannot break slicing / len().
    text = str(text)
    results.append({
        "similarity": float(s),
        "label": int(label),
        # Truncate long documents and mark the truncation with an ellipsis.
        "text_snippet": text[:180] + ("..." if len(text) > 180 else ""),
    })

pd.DataFrame(results)


Unnamed: 0,similarity,label,text_snippet
0,0.2507,0,Lib Dems predict 'best ever poll'\n \n The Lib...
1,0.100645,0,Kennedy looks to election gains\n \n They may ...
2,0.084663,0,Brown names 16 March for Budget\n \n Chancello...
3,0.084194,0,Brown names 16 March for Budget\n \n Chancello...
4,0.082189,0,Blair prepares to name poll date\n \n Tony Bla...


In [5]:
# Run a few representative queries and save their top-5 hits as JSON in the
# artifacts directory.
queries = [
    "election campaign and voting system",
    "sports match results and player transfer",
    "technology companies and product launches",
]
all_samples = {}
for q in queries:
    qv = tfidf.transform([q])
    dist, idx = index.kneighbors(qv, n_neighbors=5, return_distance=True)
    # The index uses metric="cosine", so similarity = 1 - distance.
    sim = 1 - dist[0]
    # One positional lookup per query rather than one df.iloc[i] per field.
    hits = df.iloc[idx[0]]
    all_samples[q] = [
        # str() guards against non-string Text cells (e.g. NaN) before slicing.
        {"sim": float(s), "label": int(label), "snippet": str(text)[:180]}
        for s, label, text in zip(sim, hits["Label"], hits["Text"])
    ]

# ensure_ascii=False keeps non-ASCII characters readable in the file,
# hence the explicit UTF-8 encoding on open().
with open(ARTIF_DIR / "vsm_query_samples.json", "w", encoding="utf-8") as f:
    json.dump(all_samples, f, ensure_ascii=False, indent=2)
print("Saved sample queries -> vsm_query_samples.json")


Saved sample queries -> vsm_query_samples.json
