# In [None]:
import sys
import os
sys.path.append(os.path.abspath("../utils"))

from preprocessing import load_dataset, preprocess_corpus, vectorize_corpus
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Newsgroup categories to load; this list is passed to load_dataset and is
# also what gets written out to categories_used.txt.
categories = [
    "comp.graphics",
    "rec.sport.baseball",
    "sci.med",
    "talk.politics.misc",
]

# Load the raw corpus (at most 300 documents per category) and pull out the
# three fields used downstream: document texts, integer labels, and the
# label -> name lookup. The key layout presumably mirrors sklearn's
# 20-newsgroups Bunch — confirm against utils/preprocessing.py.
data = load_dataset(categories, limit_per_category=300)
raw_texts, labels, label_names = (
    data["data"],
    data["target"],
    data["target_names"],
)

print("✅ Loaded", len(raw_texts), "samples")

# Clean the raw documents, then build a feature matrix limited to 3000
# features; the vectorizer returned alongside X is kept in scope.
# NOTE(review): the DataFrame built from cleaned_texts and labels assumes
# preprocess_corpus returns exactly one entry per input document — confirm
# it never drops empty docs.
cleaned_texts = preprocess_corpus(raw_texts)
X, vectorizer = vectorize_corpus(
    cleaned_texts,
    max_features=3000,
)

print("✅ Vectorized text into shape:", X.shape)

# Save category info: write the category list to disk, one name per line.
# The target directory is created if it is missing, so a fresh checkout does
# not fail with FileNotFoundError. UTF-8 is explicit so the file's bytes do
# not depend on the platform's default encoding.
categories_file = '../data/categories_used.txt'
os.makedirs(os.path.dirname(categories_file), exist_ok=True)
with open(categories_file, 'w', encoding='utf-8') as f:
    f.writelines(cat + '\n' for cat in categories)

# Tabulate documents per category. Each integer label is translated to its
# name by positional lookup into label_names (label i -> label_names[i]),
# then occurrences of each name are counted.
df = pd.DataFrame({'text': cleaned_texts, 'label': labels})
df['label_name'] = [label_names[code] for code in df['label']]
category_counts = df['label_name'].value_counts()

# 📊 Visualize per-category sample counts as a bar chart.
fig, ax = plt.subplots(figsize=(8, 5))
sns.barplot(x=category_counts.index, y=category_counts.values, ax=ax)
ax.set_title("Sample Count per Category")
ax.set_xlabel("Category")
ax.set_ylabel("Count")
# Right-align the rotated labels so each one ends under its own bar instead
# of being centred on the tick (which makes rotated text look shifted).
plt.xticks(rotation=30, ha='right')
# Only horizontal gridlines help read bar heights; vertical ones would just
# run through the categorical bars. Matplotlib's default axes.axisbelow
# ('line') draws gridlines above patches, so push them behind the bars.
ax.grid(True, axis='y')
ax.set_axisbelow(True)
fig.tight_layout()
plt.show()
