In [1]:
import site
from time import time

from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer

In [2]:
# Newsgroup names handed to fetch_20newsgroups(categories=...) by load_dataset,
# restricting both the train and the test subset to these topics.
categories = ["alt.atheism", "talk.religion.misc", "comp.graphics", "sci.space"]

In [3]:
def size_mb(docs):
    """Return the total UTF-8 encoded size of the strings in ``docs``, in MB.

    One megabyte is taken to be 1e6 bytes. ``docs`` is any iterable of
    ``str``; an empty iterable yields ``0.0``.
    """
    total_bytes = 0
    for text in docs:
        # Count encoded bytes, not characters: non-ASCII text takes more
        # than one byte per character in UTF-8.
        total_bytes += len(text.encode('utf-8'))
    return total_bytes / 1e6

In [4]:
def load_dataset(verbose=False, remove=()):
    """Load and vectorize the 20 newsgroups dataset.

    Fetches the train and test subsets restricted to the module-level
    ``categories`` list, fits a ``TfidfVectorizer`` on the training
    documents, and applies the fitted vectorizer to the test documents.

    Parameters
    ----------
    verbose : bool, default=False
        When true, print document counts, raw text sizes (via ``size_mb``),
        the number of categories, and the time and throughput of each
        vectorization step.
    remove : tuple, default=()
        Forwarded unchanged as the ``remove`` argument of
        ``fetch_20newsgroups`` for both subsets.

    Returns
    -------
    tuple
        ``(X_train, X_test, y_train, y_test, feature_names, target_names)``:
        the vectorized training and test documents, their target arrays,
        the vectorizer's output feature names, and the target names reported
        by the training subset.
    """
    # Same shuffle seed for both subsets so repeated runs see identical order.
    data_train = fetch_20newsgroups(
        subset='train', categories=categories, shuffle=True, random_state=42,
        remove=remove,
    )

    data_test = fetch_20newsgroups(
        subset='test', categories=categories, shuffle=True, random_state=42,
        remove=remove,
    )

    target_names = data_train.target_names

    y_train, y_test = data_train.target, data_test.target

    # Time vectorizer construction + fitting on the training documents.
    t0 = time()
    vectorizer = TfidfVectorizer(
        sublinear_tf=True, max_df=0.5, stop_words='english', min_df=5
    )
    X_train = vectorizer.fit_transform(data_train.data)
    duration_train = time() - t0

    # Restart the clock so the test figure covers only transform(); reusing
    # the earlier start time would fold the training duration into it.
    t0 = time()
    # transform (not fit_transform): reuse what was fitted on the training
    # documents so both matrices share one feature space.
    X_test = vectorizer.transform(data_test.data)
    duration_test = time() - t0

    feature_names = vectorizer.get_feature_names_out()

    if verbose:
        data_train_size_mb = size_mb(data_train.data)
        data_test_size_mb = size_mb(data_test.data)

        print(
            f"{len(data_train.data)} documents - "
            f"{data_train_size_mb:.2f}MB (training set)"
        )
        print(f"{len(data_test.data)} documents - {data_test_size_mb:.2f}MB (test set)")
        print(f"{len(target_names)} categories")
        print(
            f"vectorize training done in {duration_train:.3f}s "
            f"at {data_train_size_mb / duration_train:.3f}MB/s"
        )
        print(f"n_samples: {X_train.shape[0]}, n_features: {X_train.shape[1]}")
        print(
            f"vectorize testing done in {duration_test:.3f}s "
            f"at {data_test_size_mb / duration_test:.3f}MB/s"
        )
        print(f"n_samples: {X_test.shape[0]}, n_features: {X_test.shape[1]}")

    return X_train, X_test, y_train, y_test, feature_names, target_names
