# Cell 1: Install Dependencies (Stable Versions)
Installs pinned library versions to avoid dependency conflicts.

In [None]:
# Install specific versions to ensure compatibility between beir, pyserini and transformers
!pip uninstall -y faiss-gpu faiss-cpu sentence-transformers transformers huggingface_hub
!pip install faiss-cpu
!pip install huggingface-hub==0.23.0 transformers==4.36.2 sentence-transformers==2.2.2 pyserini beir pandas matplotlib seaborn scipy

# Install Java 21 for Lucene (required by Pyserini)
!apt-get -y install -qq openjdk-21-jdk-headless || true
print("‚úÖ Dependencies installed successfully")

# Cell 2: Imports & Setup
Restart the kernel/runtime before running this cell so that the newly installed package versions are the ones imported.

In [None]:
import os
import json
import time
import shutil
import pathlib
import subprocess
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
from scipy.sparse import csr_matrix

import glob

# Configure Java 21 for Lucene. This runs *before* the pyserini import below so that
# JAVA_HOME/PATH already point at Java 21 when Pyserini loads Lucene.
# The Debian/Ubuntu JDK directory name embeds the CPU architecture (…-amd64, …-arm64),
# so search for it instead of hard-coding amd64.
_jdk_dirs = sorted(glob.glob("/usr/lib/jvm/java-21-openjdk-*"))
java_home = _jdk_dirs[0] if _jdk_dirs else "/usr/lib/jvm/java-21-openjdk-amd64"
if os.path.exists(java_home):
    os.environ["JAVA_HOME"] = java_home
    os.environ["PATH"] = f"{java_home}/bin:" + os.environ.get("PATH", "")
else:
    print("⚠️ Java 21 JDK not found under /usr/lib/jvm; using the existing JAVA_HOME/PATH")

from sentence_transformers import SentenceTransformer
from beir import util
from beir.datasets.data_loader import GenericDataLoader
from pyserini.search.lucene import LuceneSearcher
import faiss

sns.set_style('whitegrid')
print("✅ Libraries imported and Java configured")

# Cell 3: Dataset Loading
Handles standard BEIR datasets and CQADupStack sub-forums (select a sub-forum as `cqadupstack/<name>`, e.g. `cqadupstack/android`).

In [None]:
# =================================================================
# SELECT DATASET
# =================================================================
# Either a key of `public_datasets` below, or 'cqadupstack/<sub>' for a
# CQADupStack sub-forum (e.g. 'cqadupstack/android').
dataset_name = 'scifact'  # Change to: scifact, trec-covid, fiqa, etc.

# =================================================================
# DATASET CONFIGURATION
# =================================================================
_BEIR_BASE_URL = 'https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets'
public_datasets = {
    name: f'{_BEIR_BASE_URL}/{name}.zip'
    for name in (
        'nfcorpus', 'scifact', 'arguana', 'scidocs', 'fiqa', 'trec-covid',
        'webis-touche2020', 'quora', 'dbpedia-entity', 'nq', 'cqadupstack',
    )
}

cqa_sub_datasets = {
    'android': '23K docs', 'english': '41K docs', 'gaming': '46K docs',
    'gis': '38K docs', 'mathematica': '17K docs', 'physics': '39K docs',
    'programmers': '33K docs', 'stats': '42K docs', 'tex': '71K docs',
    'unix': '48K docs', 'webmasters': '17K docs', 'wordpress': '49K docs'
}

# Download Logic: archives are unpacked under ./datasets
out_dir = os.path.join(os.getcwd(), "datasets")

if dataset_name.startswith('cqadupstack/'):
    sub_name = dataset_name.split('/', 1)[1]
    if sub_name not in cqa_sub_datasets:
        raise ValueError(f"Invalid CQA sub-dataset '{sub_name}'")
    print(f"--- Processing CQADupStack: {sub_name} ---")
    url = public_datasets['cqadupstack']
    # One archive holds every sub-forum; each lives in its own sub-folder.
    base_path = util.download_and_unzip(url, out_dir)
    data_path = os.path.join(base_path, sub_name)
elif dataset_name in public_datasets:
    print(f"--- Processing {dataset_name} ---")
    url = public_datasets[dataset_name]
    data_path = util.download_and_unzip(url, out_dir)
else:
    raise ValueError(f"Dataset '{dataset_name}' not found.")

# Load Data
print(f"Loading data from: {data_path}")
try:
    corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")
except Exception as e:
    # Report, then re-raise: every later cell needs corpus/queries/qrels, and swallowing
    # the error would only resurface later as a confusing NameError.
    print(f"\n❌ Error loading dataset: {e}")
    raise

# Prepare lists for encoding (CRITICAL STEP): list positions are the document/query
# indices used by every index and search function downstream.
print("Preparing data lists...")
doc_ids = list(corpus.keys())
# Title and body joined by a space. `or ''` tolerates missing/None fields, and strip()
# avoids a leading space for untitled documents.
doc_texts = [
    ((corpus[did].get('title') or '') + ' ' + (corpus[did].get('text') or '')).strip()
    for did in doc_ids
]
query_ids = list(queries.keys())
query_texts = [queries[qid] for qid in query_ids]

# qrels is {query_id: {doc_id: relevance}}; count individual (query, doc) judgments.
n_judgments = sum(len(judged) for judged in qrels.values())

print(f"\n✅ Dataset Loaded: {dataset_name}")
print(f"   Documents: {len(corpus):,}")
print(f"   Queries: {len(queries):,}")
print(f"   Relevance judgments: {n_judgments:,} (over {len(qrels):,} queries)")

# Cell 4: Dense Retrieval (BGE Model)
Encodes documents and queries using the BGE model

In [None]:
# Load the BGE bi-encoder used for dense retrieval.
model_name = 'BAAI/bge-base-en-v1.5'
print(f"Loading BGE model: {model_name}")
model = SentenceTransformer(model_name)
dimension = model.get_sentence_embedding_dimension()

# Encode Documents
# The batch size bounds the memory of a single forward pass only; it does not change the
# embeddings. Using a smaller batch for very large corpora is a conservative heuristic.
batch_size = 32 if len(doc_texts) <= 100_000 else 16
print(f"Encoding {len(doc_texts):,} documents (batch_size={batch_size})...")

# normalize_embeddings=True produces unit-length vectors, so the inner-product FAISS
# indexes built later rank documents by cosine similarity.
doc_embeddings = model.encode(
    doc_texts,
    batch_size=batch_size,
    show_progress_bar=True,
    convert_to_numpy=True,
    normalize_embeddings=True
)

# Encode Queries (same normalization, so query/doc scores are comparable)
print(f"Encoding {len(query_texts):,} queries...")
query_embeddings = model.encode(
    query_texts,
    batch_size=32,
    show_progress_bar=True,
    convert_to_numpy=True,
    normalize_embeddings=True
)

print(f"✅ Dense encoding complete. Doc shape: {doc_embeddings.shape}")

# Cell 5: Build Dense & BM25 Indexes
Builds a Lucene index for BM25 and two FAISS indexes (approximate HNSW and exact Flat) for dense retrieval.

In [None]:
import sys

# Parameters matching the paper
M = 16                 # HNSW graph degree
ef_construction = 100  # HNSW build-time candidate list size
ef_search = 1000       # HNSW query-time candidate list size
base_dir = f'indexes_{dataset_name}'
os.makedirs(base_dir, exist_ok=True)

# ---------------------------------------------------------
# 1. BM25 Index (Lucene)
# ---------------------------------------------------------
bm25_docs_dir = os.path.join(base_dir, 'bm25_docs')
bm25_index_dir = os.path.join(base_dir, 'bm25_index')
os.makedirs(bm25_docs_dir, exist_ok=True)

print("Building BM25 Index...")
# Write JSONL for Pyserini's JsonCollection: one {"id", "contents"} object per line.
with open(os.path.join(bm25_docs_dir, 'docs.jsonl'), 'w', encoding='utf-8') as f:
    for did, text in zip(doc_ids, doc_texts):
        f.write(json.dumps({'id': did, 'contents': text}) + "\n")

# Run the indexer with the kernel's own interpreter: a bare 'python' on PATH may be a
# different environment without pyserini (or with other versions). The subprocess
# inherits os.environ, including the JAVA_HOME set during setup.
bm25_threads = min(16, os.cpu_count() or 1)
subprocess.run([
    sys.executable, '-m', 'pyserini.index.lucene',
    '--collection', 'JsonCollection',
    '--input', bm25_docs_dir,
    '--index', bm25_index_dir,
    '--generator', 'DefaultLuceneDocumentGenerator',
    '--threads', str(bm25_threads),
    '--storePositions', '--storeDocvectors', '--storeRaw'
], check=True)
print("✅ BM25 Index built.")

# ---------------------------------------------------------
# 2. Dense Index (FAISS HNSW & Flat)
# ---------------------------------------------------------
# Build times (construction + add + write_index) are recorded for the FP32 rows of
# Table 3, measured the same way as the INT8 builds in Cell 5b.
print("Building FAISS Indexes...")

# HNSW (Approximate)
start_t = time.perf_counter()
hnsw_index = faiss.IndexHNSWFlat(dimension, M, faiss.METRIC_INNER_PRODUCT)
hnsw_index.hnsw.efConstruction = ef_construction
hnsw_index.hnsw.efSearch = ef_search
hnsw_index.add(doc_embeddings)
faiss.write_index(hnsw_index, os.path.join(base_dir, 'hnsw_index.faiss'))
time_hnsw_fp32 = time.perf_counter() - start_t

# Flat (Exact)
start_t = time.perf_counter()
flat_index = faiss.IndexFlatIP(dimension)
flat_index.add(doc_embeddings)
faiss.write_index(flat_index, os.path.join(base_dir, 'flat_index.faiss'))
time_flat_fp32 = time.perf_counter() - start_t

print(f"✅ FAISS Indexes built. HNSW Time: {time_hnsw_fp32:.2f}s, Flat Time: {time_flat_fp32:.2f}s")

# Cell 5b: Build INT8 Quantized Indexes (For Tables 3 & 4)

In [None]:
print("Building INT8 Quantized Indexes...")


def quantize_int8(x, scale=127):
    """Simulate symmetric INT8 scalar quantization of (unit-normalized) embeddings.

    Values are multiplied by `scale`, rounded to the nearest integer, clipped to
    [-scale, scale], cast to int8, and then de-quantized back to float32 by dividing by
    `scale`. Rounding (not truncation toward zero) keeps the quantization error unbiased.

    Args:
        x: float array of embeddings.
        scale: number of integer levels per unit; 127 uses the full int8 range symmetrically.

    Returns:
        float32 array with the same shape as `x`, holding the de-quantized values.
    """
    q = np.clip(np.rint(x * scale), -scale, scale).astype(np.int8)
    return q.astype(np.float32) / scale


# 1. Quantize Embeddings (Float32 -> Int8 -> Float32)
# NOTE(review): the "INT8" indexes below store the de-quantized *float32* vectors in
# IndexHNSWFlat / IndexFlatIP. They therefore measure the quality impact of int8
# quantization only; memory, build time and QPS are those of a float32 index.
# TODO: consider faiss.IndexScalarQuantizer / IndexHNSWSQ (QT_8bit) to measure real
# int8 storage and speed.
doc_embeddings_int8 = quantize_int8(doc_embeddings)
query_embeddings_int8 = quantize_int8(query_embeddings)

# 2. Build INT8 Indexes (timed as construction + add + write_index)
# HNSW INT8
start_t = time.perf_counter()
hnsw_int8_index = faiss.IndexHNSWFlat(dimension, M, faiss.METRIC_INNER_PRODUCT)
hnsw_int8_index.hnsw.efConstruction = ef_construction
hnsw_int8_index.hnsw.efSearch = ef_search
hnsw_int8_index.add(doc_embeddings_int8)
faiss.write_index(hnsw_int8_index, os.path.join(base_dir, 'hnsw_int8_index.faiss'))
time_hnsw_int8 = time.perf_counter() - start_t

# Flat INT8
start_t = time.perf_counter()
flat_int8_index = faiss.IndexFlatIP(dimension)
flat_int8_index.add(doc_embeddings_int8)
faiss.write_index(flat_int8_index, os.path.join(base_dir, 'flat_int8_index.faiss'))
time_flat_int8 = time.perf_counter() - start_t

print(f"✅ INT8 Indexes built. HNSW Time: {time_hnsw_int8:.2f}s, Flat Time: {time_flat_int8:.2f}s")

# Cell 6: SPLADE - Manual Encoding & Matrix Construction
Replaces the Pyserini SPLADE implementation with manual encoding and a sparse document–term matrix (CSR); queries are scored by sparse dot products against this matrix.

In [None]:
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

# Configuration
# SPLADE++ ED (EnsembleDistil) checkpoint, i.e. the model reported as 'SPLADE++ ED' in the
# results table. ('naver/splade-cocondenser-selfdistil' is the SPLADE++ SD variant.)
SPLADE_MODEL = 'naver/splade-cocondenser-ensembledistil'
# One cache directory per checkpoint, so changing SPLADE_MODEL never silently reuses
# document vectors produced by a different model.
splade_manual_dir = os.path.join(base_dir, 'splade_encoded_manual', SPLADE_MODEL.replace('/', '__'))
os.makedirs(splade_manual_dir, exist_ok=True)
splade_jsonl_path = os.path.join(splade_manual_dir, 'docs.jsonl')

print(f"🚀 SPLADE Matrix Preparation using: {SPLADE_MODEL}")

# Load HF Model for inference (eval() disables dropout).
tokenizer = AutoTokenizer.from_pretrained(SPLADE_MODEL)
splade_model = AutoModelForMaskedLM.from_pretrained(SPLADE_MODEL)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
splade_model.to(device)
splade_model.eval()

# Helper: Manual Encoding with explicit quantization (x100)
def encode_batch_manual(texts, tokenizer, model, device):
    """Encode one batch of texts into SPLADE sparse vectors with integer weights.

    For each text, the weight of vocabulary entry j is the max over its non-padding
    positions of log(1 + ReLU(logit_j)). That value is multiplied by 100 and truncated
    to int.

    Args:
        texts: list of strings forming one batch (truncated to 512 tokens).
        tokenizer: Hugging Face tokenizer matching `model`.
        model: masked-LM model; its output `.logits` is reduced over the sequence axis.
        device: torch device the tokenized inputs are moved to.

    Returns:
        One dict per text. Each maps a token *id* written as a decimal string
        (e.g. "2003") to its quantized weight. Only entries with a positive float weight
        are kept.

    Keys are token ids rather than decoded token strings because downstream code first
    tries `int(key)` and only then falls back to a vocabulary lookup. A vocabulary
    entry that is itself a number (e.g. the token "1996") would therefore be read as
    token id 1996, and its weight would land in the wrong column. Decoding is also not
    guaranteed to round-trip through `convert_tokens_to_ids`.
    NOTE: an existing docs.jsonl cache written with string keys is still reused by the
    encoding step; delete it so documents are re-encoded with id keys.
    """
    inputs = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
        # Zero out padding positions before pooling so they never contribute a weight.
        mask = inputs['attention_mask'].unsqueeze(-1)
        weights = torch.max(torch.log(1 + torch.relu(logits)) * mask, dim=1).values
    weights_np = weights.cpu().numpy()

    batch_vectors = []
    for row in weights_np:
        active_ids = np.flatnonzero(row > 0)
        batch_vectors.append({str(int(t_id)): int(row[t_id] * 100) for t_id in active_ids})
    return batch_vectors

# 1. Encode Documents to JSONL (skipped when the cache file already exists).
# NOTE: the cache is reused as-is; delete it after changing how documents are encoded.
if not os.path.exists(splade_jsonl_path):
    print("Encoding documents...")
    # Separate name, so this cell does not overwrite `batch_size` from the dense-encoding cell.
    splade_batch_size = 32
    # Write to a temporary file and rename only on success. An interrupted run must not
    # leave a partial docs.jsonl that the existence check above would treat as complete.
    partial_path = splade_jsonl_path + '.partial'
    with open(partial_path, 'w', encoding='utf-8') as f:
        for i in tqdm(range(0, len(doc_texts), splade_batch_size), desc="Encoding"):
            batch_t = doc_texts[i:i + splade_batch_size]
            batch_i = doc_ids[i:i + splade_batch_size]
            vectors = encode_batch_manual(batch_t, tokenizer, splade_model, device)
            for did, vec in zip(batch_i, vectors):
                f.write(json.dumps({'id': did, 'vector': vec}) + '\n')
    os.replace(partial_path, splade_jsonl_path)
else:
    print("Found existing encoded file, skipping encoding step.")

# 2. Build Sparse Matrix (CSR): one row per document (in file order), one column per
# vocabulary entry. doc_ids_ordered[r] is the doc id of row r.
print("Building Sparse Matrix from JSONL...")
vocab_size = tokenizer.vocab_size
data, rows, cols, doc_ids_ordered = [], [], [], []
row_idx = 0

with open(splade_jsonl_path, 'r', encoding='utf-8') as f:
    for line in tqdm(f, desc="Matrix Build"):
        entry = json.loads(line)
        doc_ids_ordered.append(entry['id'])
        for token_key, weight in entry['vector'].items():
            # Keys may be token ids (decimal strings) or token strings: try the id
            # first, then fall back to a vocabulary lookup.
            try:
                col_idx = int(token_key)
            except ValueError:
                col_idx = tokenizer.convert_tokens_to_ids(token_key)
            if col_idx is not None and 0 <= col_idx < vocab_size:
                rows.append(row_idx)
                cols.append(col_idx)
                data.append(weight)
        row_idx += 1

# Every corpus document needs a row; otherwise missing documents would silently never
# be retrieved by SPLADE (e.g. a cache left over from a different corpus or run).
if row_idx != len(doc_ids):
    raise RuntimeError(
        f"SPLADE cache {splade_jsonl_path} holds {row_idx:,} documents but the corpus has "
        f"{len(doc_ids):,}; delete the cache file and re-run this cell."
    )

doc_matrix = csr_matrix((data, (rows, cols)), shape=(row_idx, vocab_size))
print(f"✅ SPLADE Matrix ready: {doc_matrix.shape}")

# Cell 7: Search Functions (BM25, Dense, SPLADE Matrix)
Defines the search logic for all three methods

In [None]:
# Lucene BM25 searcher over the index built earlier; k1=0.9, b=0.4 are the paper's
# BM25 parameters.
bm25_searcher = LuceneSearcher(bm25_index_dir)
bm25_searcher.set_bm25(k1=0.9, b=0.4)

# Doc id -> position in `doc_ids`. That position is also the row of `doc_embeddings`,
# and therefore the FAISS id assigned by `add`.
doc_id_to_idx = dict(zip(doc_ids, range(len(doc_ids))))

def run_bm25_search(queries, k=1000):
    """Retrieve the top-k BM25 hits for each query string with the Lucene searcher.

    Args:
        queries: sequence of query strings (not the `queries` dict from dataset loading).
        k: number of hits requested per query.

    Returns:
        (results, qps). results[i] is (indices, scores) for queries[i], where indices are
        positions in `doc_ids`. qps is queries per second over the whole search loop,
        timed with the monotonic `time.perf_counter` (`time.time` is wall-clock time and
        can jump).
    """
    print("Running BM25 search...")
    all_res = []
    start = time.perf_counter()
    for q in tqdm(queries, desc="BM25"):
        hits = bm25_searcher.search(q, k)
        all_res.append(([doc_id_to_idx[h.docid] for h in hits], [h.score for h in hits]))
    elapsed = time.perf_counter() - start
    # Guard against a zero-length measurement on tiny inputs / coarse clocks.
    qps = len(queries) / elapsed if elapsed > 0 else float('inf')
    return all_res, qps

def run_dense_search(index, query_embs, k=1000):
    """Batch-search a FAISS index with all query embeddings at once.

    Args:
        index: FAISS index (anything exposing `search(x, k) -> (scores, indices)`).
        query_embs: 2-D array of query embeddings. It is converted to C-contiguous
            float32 (the layout FAISS works on) *before* timing starts, so the
            conversion is not counted in QPS.
        k: neighbours per query.

    Returns:
        (results, qps). results[i] is (indices, scores) for query i. FAISS pads missing
        neighbours with index -1. qps is queries per second of the search call, timed
        with the monotonic `time.perf_counter`.
    """
    print("Running Dense search...")
    query_embs = np.ascontiguousarray(query_embs, dtype=np.float32)
    start = time.perf_counter()
    scores, indices = index.search(query_embs, k)
    elapsed = time.perf_counter() - start
    # A single batched call can finish below the clock resolution on small query sets.
    qps = len(query_embs) / elapsed if elapsed > 0 else float('inf')
    return list(zip(indices, scores)), qps

def run_splade_matrix_search(query_texts, k=1000, batch_size=32):
    """Score every query against the SPLADE document matrix and return the top-k docs.

    Queries are encoded with `encode_batch_manual` (the same encoder as the documents).
    Each is turned into a 1 x vocab_size sparse vector and scored by a sparse dot
    product with `doc_matrix`.

    Args:
        query_texts: list of query strings.
        k: documents returned per query.
        batch_size: queries encoded per forward pass. Encoding every query in a single
            batch materializes logits for all (query, position, vocab entry) triples at
            once, which can exhaust GPU/CPU memory on datasets with many queries.

    Returns:
        (results, qps). results[i] = (doc indices into `doc_ids`, scores) for query i,
        best first. qps covers scoring and top-k selection only; query encoding is not
        timed.
    """
    print("Running SPLADE Matrix search...")
    # 1. Encode queries in fixed-size batches.
    q_vectors = []
    for batch_start in range(0, len(query_texts), batch_size):
        q_vectors.extend(encode_batch_manual(
            query_texts[batch_start:batch_start + batch_size], tokenizer, splade_model, device))

    # Loop-invariant: matrix row r holds doc_ids_ordered[r]; map rows to `doc_ids` positions once.
    row_to_doc_idx = np.array([doc_id_to_idx[did] for did in doc_ids_ordered], dtype=np.int64)

    # 2. Sparse dot product + top-k selection (timed).
    all_res = []
    start = time.perf_counter()
    for q_vec in tqdm(q_vectors, desc="Matrix Search"):
        # Keys may be token ids (decimal strings) or token strings: try the id first.
        q_data, q_cols = [], []
        for token_key, weight in q_vec.items():
            try:
                cid = int(token_key)
            except ValueError:
                cid = tokenizer.convert_tokens_to_ids(token_key)
            if cid is not None and 0 <= cid < vocab_size:
                q_data.append(weight)
                q_cols.append(cid)

        q_sparse = csr_matrix((q_data, ([0] * len(q_data), q_cols)), shape=(1, vocab_size))
        scores = doc_matrix.dot(q_sparse.T).toarray().ravel()

        # Top-k: argpartition selects the k best in O(n); only those k get fully sorted.
        if k < len(scores):
            top_k = np.argpartition(scores, -k)[-k:]
            top_k = top_k[np.argsort(scores[top_k])[::-1]]
        else:
            top_k = np.argsort(scores)[::-1]

        all_res.append((row_to_doc_idx[top_k].tolist(), scores[top_k]))

    elapsed = time.perf_counter() - start
    qps = len(query_texts) / elapsed if elapsed > 0 else float('inf')
    return all_res, qps

# Cell 8: Evaluation Logic
Standard BEIR evaluation metrics (nDCG@10, Recall@10)

In [None]:
def evaluate_results(results_list, qrels, query_ids, k=10, doc_id_list=None, gain="linear"):
    """Compute mean Recall@k and nDCG@k over the queries that have relevance judgments.

    Args:
        results_list: one (indices, scores) pair per query, aligned with `query_ids`.
            Indices are positions in the document-id list, ranked best first. Negative
            indices (FAISS pads missing neighbours with -1) are ignored instead of being
            mapped to the last document by Python's negative indexing.
        qrels: {query_id: {doc_id: graded relevance}}.
        query_ids: query ids aligned with `results_list`.
        k: rank cutoff for both metrics.
        doc_id_list: index -> doc id mapping; defaults to the notebook-level `doc_ids`.
        gain: "linear" (gain = relevance grade, the trec_eval/pytrec_eval nDCG used by
            BEIR) or "exponential" (gain = 2**rel - 1).

    Returns:
        (mean_recall, mean_ndcg) as floats. Only judgments with relevance > 0 count as
        relevant, so explicit "not relevant" (0) judgments do not inflate the recall
        denominator or the ideal DCG. A metric is nan when no query contributes to it.

    Raises:
        ValueError: if `gain` is not "linear" or "exponential".
    """
    if doc_id_list is None:
        doc_id_list = doc_ids  # notebook-level list built when loading the dataset
    if gain == "linear":
        def gain_of(rel):
            return rel
    elif gain == "exponential":
        def gain_of(rel):
            return 2 ** rel - 1
    else:
        raise ValueError(f"gain must be 'linear' or 'exponential', got {gain!r}")

    recalls, ndcgs = [], []
    for i, (indices, _scores) in enumerate(results_list):
        qid = query_ids[i]
        if qid not in qrels:
            continue

        judged = qrels[qid]  # {doc_id: graded relevance}
        top_docs = [doc_id_list[idx] for idx in indices if idx >= 0][:k]

        # Recall@k over positively-judged documents only.
        positives = {did for did, rel in judged.items() if rel > 0}
        if positives:
            recalls.append(len(positives.intersection(top_docs)) / len(positives))

        # nDCG@k: unjudged documents have gain 0; the ideal ranking uses positive grades only.
        dcg = sum(gain_of(judged.get(did, 0)) / np.log2(rank + 1)
                  for rank, did in enumerate(top_docs, start=1))
        ideal = sorted((rel for rel in judged.values() if rel > 0), reverse=True)[:k]
        idcg = sum(gain_of(rel) / np.log2(rank + 1)
                   for rank, rel in enumerate(ideal, start=1))
        ndcgs.append(dcg / idcg if idcg > 0 else 0.0)

    mean_recall = float(np.mean(recalls)) if recalls else float("nan")
    mean_ndcg = float(np.mean(ndcgs)) if ndcgs else float("nan")
    return mean_recall, mean_ndcg

# Cell 9: Execution & Results Aggregation
Runs all searches and compiles the final DataFrame

In [None]:
# Metric cutoff (Recall/nDCG @k_eval) and retrieval depth (candidates per query).
k_eval, k_ret = 10, 1000

# --- Retrieval runs, in order: BM25, BGE-HNSW, BGE-Flat, SPLADE ---
res_bm25, qps_bm25 = run_bm25_search(query_texts, k_ret)
res_hnsw, qps_hnsw = run_dense_search(hnsw_index, query_embeddings, k_ret)
res_flat, qps_flat = run_dense_search(flat_index, query_embeddings, k_ret)
res_splade, qps_splade = run_splade_matrix_search(query_texts, k_ret)

# --- Scoring of every run at the same cutoff ---
rec_bm25, ndcg_bm25 = evaluate_results(res_bm25, qrels, query_ids, k_eval)
rec_hnsw, ndcg_hnsw = evaluate_results(res_hnsw, qrels, query_ids, k_eval)
rec_flat, ndcg_flat = evaluate_results(res_flat, qrels, query_ids, k_eval)
rec_splade, ndcg_splade = evaluate_results(res_splade, qrels, query_ids, k_eval)

# --- Summary table (row order: sparse methods first, then dense) ---
summary_rows = [
    ('BM25', 'Sparse (Baseline)', rec_bm25, ndcg_bm25, qps_bm25),
    ('SPLADE++ ED', 'Sparse (Learned)', rec_splade, ndcg_splade, qps_splade),
    ('BGE-HNSW', 'Dense (HNSW)', rec_hnsw, ndcg_hnsw, qps_hnsw),
    ('BGE-Flat', 'Dense (Flat)', rec_flat, ndcg_flat, qps_flat),
]
results_df = pd.DataFrame(summary_rows, columns=['Method', 'Type', 'Recall@10', 'nDCG@10', 'QPS'])

rule = "=" * 80
print("\n" + rule)
print(f"FINAL RESULTS: {dataset_name.upper()}")
print(rule)
print(results_df.to_string(index=False))
print(rule)

# Cell 9b: Run INT8 Searches & Generate Tables 3 and 4

In [None]:
print("Running INT8 Searches...")

# Run Searches (same k_ret / k_eval as the FP32 runs, against the INT8 indexes and queries)
res_hnsw_int8, qps_hnsw_int8 = run_dense_search(hnsw_int8_index, query_embeddings_int8, k_ret)
rec_hnsw_int8, ndcg_hnsw_int8 = evaluate_results(res_hnsw_int8, qrels, query_ids, k_eval)

res_flat_int8, qps_flat_int8 = run_dense_search(flat_int8_index, query_embeddings_int8, k_ret)
rec_flat_int8, ndcg_flat_int8 = evaluate_results(res_flat_int8, qrels, query_ids, k_eval)

# --- TABLE 3: Indexing Time Comparison ---
# FP32 build times are taken from the index-building cell when it recorded them
# (`time_hnsw_fp32` / `time_flat_fp32`). Otherwise they are reported as NaN
# ("not measured") rather than a misleading 0 s.
fp32_hnsw_time = globals().get('time_hnsw_fp32', float('nan'))
fp32_flat_time = globals().get('time_flat_fp32', float('nan'))
table3_df = pd.DataFrame([
    {'Method': 'BGE-HNSW', 'Quantization': 'FP32', 'Index Time (s)': fp32_hnsw_time},
    {'Method': 'BGE-HNSW', 'Quantization': 'int8', 'Index Time (s)': time_hnsw_int8},
    {'Method': 'BGE-Flat', 'Quantization': 'FP32', 'Index Time (s)': fp32_flat_time},
    {'Method': 'BGE-Flat', 'Quantization': 'int8', 'Index Time (s)': time_flat_int8},
])

# --- TABLE 4: INT8 Performance & Quality ---
table4_df = pd.DataFrame([
    {'Method': 'BGE-HNSW', 'Quantization': 'FP32', 'QPS': qps_hnsw, 'nDCG@10': ndcg_hnsw, 'Recall@10': rec_hnsw},
    {'Method': 'BGE-HNSW', 'Quantization': 'int8', 'QPS': qps_hnsw_int8, 'nDCG@10': ndcg_hnsw_int8, 'Recall@10': rec_hnsw_int8},
    {'Method': 'BGE-Flat', 'Quantization': 'FP32', 'QPS': qps_flat, 'nDCG@10': ndcg_flat, 'Recall@10': rec_flat},
    {'Method': 'BGE-Flat', 'Quantization': 'int8', 'QPS': qps_flat_int8, 'nDCG@10': ndcg_flat_int8, 'Recall@10': rec_flat_int8},
])

print("\n" + "="*80)
print("TABLE 3: INDEXING TIME (INT8 vs FP32)")
print("="*80)
print(table3_df.to_string(index=False))

print("\n" + "="*80)
print("TABLE 4: PERFORMANCE & QUALITY (INT8 vs FP32)")
print("="*80)
print(table4_df.to_string(index=False))

# Cell 10: Save & Visualize
Generates the plots and saves CSVs

In [None]:
# Cell 10: Save All Results & Generate Plots
output_dir = f'results_{dataset_name}'
os.makedirs(output_dir, exist_ok=True)

# 1. Save CSVs (Tables 1, 3, 4)
results_df.to_csv(os.path.join(output_dir, f'{dataset_name}_results.csv'), index=False)
table3_df.to_csv(os.path.join(output_dir, f'{dataset_name}_table3_indexing.csv'), index=False)
table4_df.to_csv(os.path.join(output_dir, f'{dataset_name}_table4_int8.csv'), index=False)

print("✅ CSV tables saved.")

# 2. PLOT 1: Speed vs Quality (Scatter Plot)
# QPS can differ by orders of magnitude between methods (one batched FAISS call vs
# per-query loops), so the x-axis is logarithmic by default. Set False for a linear axis.
qps_log_scale = True
colors = {'Sparse (Baseline)': 'orange', 'Sparse (Learned)': 'red',
          'Dense (HNSW)': 'steelblue', 'Dense (Flat)': 'lightblue'}

fig, ax = plt.subplots(figsize=(10, 8))
for _, row in results_df.iterrows():
    # Points are labelled by annotation, so no legend entry is needed.
    ax.scatter(row['QPS'], row['nDCG@10'], s=200, color=colors[row['Type']])
    ax.annotate(row['Method'], (row['QPS'], row['nDCG@10']),
                xytext=(0, 10), textcoords='offset points', fontsize=11, fontweight='bold')

if qps_log_scale:
    ax.set_xscale('log')
ax.set_xlabel('QPS (Queries Per Second)', fontsize=12)
ax.set_ylabel('nDCG@10', fontsize=12)
ax.set_title(f'Speed vs Quality - {dataset_name}', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3)

# Save Plot 1, then close the figure to free memory
plot1_path = os.path.join(output_dir, f'{dataset_name}_speed_vs_quality.pdf')
fig.savefig(plot1_path, bbox_inches='tight')
plt.close(fig)
print(f"✅ Plot 1 saved: {plot1_path}")

# 3. PLOT 2: Metrics Comparison (grouped bar chart, one group per method)
fig, ax = plt.subplots(figsize=(12, 6))
x = np.arange(len(results_df))
width = 0.35

ax.bar(x - width/2, results_df['nDCG@10'], width, label='nDCG@10', alpha=0.8, color='steelblue')
ax.bar(x + width/2, results_df['Recall@10'], width, label='Recall@10', alpha=0.8, color='orange')

ax.set_xlabel('Method', fontsize=12)
ax.set_ylabel('Score', fontsize=12)
ax.set_title(f'Metrics Comparison - {dataset_name}', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(results_df['Method'], rotation=15)
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

# Save Plot 2, then close the figure
plot2_path = os.path.join(output_dir, f'{dataset_name}_metrics_comparison.pdf')
fig.savefig(plot2_path, bbox_inches='tight')
plt.close(fig)
print(f"✅ Plot 2 saved: {plot2_path}")

# Cell 11: Export & Download Results

This final step gathers the generated outputs (CSV tables, plots, and any JSON/text files in the results folder) into a single directory and compresses them into a ZIP archive.

The code automatically detects your running environment to provide the appropriate download method:
* **Google Colab**: Triggers a browser download automatically.
* **Kaggle**: Provides a clickable download link.
* **Local Machine**: Saves the ZIP file in the current directory.

In [None]:
# Cell 11: Export & Download Results (Colab / Kaggle / local)
import os
import shutil
import sys
from IPython.display import FileLink, display

# 1. Configuration (fall back to a default when the dataset cell has not run in this kernel)
if 'dataset_name' not in locals():
    dataset_name = 'nfcorpus'

source_dir = f'results_{dataset_name}'
export_dir = f'FINAL_OUTPUT_{dataset_name}'
zip_filename = f'{export_dir}'  # shutil.make_archive appends ".zip" automatically

print(f"📦 PREPARING EXPORT FOR: {dataset_name}")

# 2. Create a clean export folder
if os.path.exists(export_dir):
    shutil.rmtree(export_dir)
os.makedirs(export_dir)

# 3. Copy the result files (tables, plots, and any JSON/text logs)
file_count = 0
if os.path.exists(source_dir):
    for filename in os.listdir(source_dir):
        if filename.endswith(('.csv', '.pdf', '.png', '.json', '.txt')):
            shutil.copy2(os.path.join(source_dir, filename), os.path.join(export_dir, filename))
            file_count += 1
else:
    print(f"⚠️ Warning: Source folder '{source_dir}' not found.")

print(f"✅ Collected {file_count} files.")

# 4. Create the ZIP
print(f"📚 Compressing to {zip_filename}.zip...")
shutil.make_archive(zip_filename, 'zip', export_dir)
# Relative path, so FileLink resolves it against the notebook's working directory.
zip_name_ext = f"{zip_filename}.zip"

# 5. Offer the archive in the way the current environment supports
if 'google.colab' in sys.modules:
    runtime_env = 'COLAB'
elif os.environ.get('KAGGLE_KERNEL_RUN_TYPE') or os.path.isdir('/kaggle/working'):
    runtime_env = 'KAGGLE'
else:
    runtime_env = 'LOCAL'

print(f"\n🌍 ENVIRONMENT: {runtime_env}")
if runtime_env == 'COLAB':
    from google.colab import files  # only importable inside Colab
    print("✅ Archive ready. Starting browser download...")
    files.download(zip_name_ext)
elif runtime_env == 'KAGGLE':
    print("✅ Archive ready. Click the link below to download:")
    display(FileLink(zip_name_ext))
    print("\n(Note: You can also find this file in the 'Output' tab of the Kaggle viewer)")
else:
    print(f"✅ Archive saved to: {os.path.abspath(zip_name_ext)}")
    display(FileLink(zip_name_ext))