In [1]:
import nltk
from nltk.corpus import reuters
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.multioutput import MultiOutputClassifier
from sklearn.metrics import classification_report
from sklearn.preprocessing import MultiLabelBinarizer
import numpy as np

# Download the Reuters dataset
nltk.download('reuters')

# Load the Reuters dataset
documents = reuters.fileids()
labels = [reuters.categories(doc) for doc in documents]
texts = [reuters.raw(doc) for doc in documents]

# Convert labels to binary format
mlb = MultiLabelBinarizer()
y = mlb.fit_transform(labels)

# Filter out underrepresented classes
min_samples_per_class = 10  # Increase the threshold to ensure sufficient samples
valid_classes = np.where(np.sum(y, axis=0) >= min_samples_per_class)[0]
y = y[:, valid_classes]
mlb.classes_ = mlb.classes_[valid_classes]

# Vectorize the text data
vectorizer = TfidfVectorizer(stop_words='english', max_df=0.5, smooth_idf=True)
X = vectorizer.fit_transform(texts)

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

[nltk_data] Downloading package reuters to
[nltk_data]     /home/codespace/nltk_data...


In [3]:
# Train a multi-label classifier using Random Forest
classifier = MultiOutputClassifier(RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1))
classifier.fit(X_train, y_train)

# Make predictions
y_pred = classifier.predict(X_test)

# Evaluate the model
print(classification_report(y_test, y_pred, target_names=mlb.classes_, zero_division=0))

                 precision    recall  f1-score   support

            acq       0.98      0.95      0.96       469
           alum       0.00      0.00      0.00         7
         barley       0.00      0.00      0.00         6
            bop       1.00      0.20      0.33        20
        carcass       0.67      0.13      0.22        15
          cocoa       1.00      0.29      0.45        17
         coffee       1.00      0.36      0.53        28
         copper       1.00      0.33      0.50         9
           corn       0.91      0.19      0.31        53
         cotton       0.00      0.00      0.00        11
            cpi       1.00      0.10      0.18        20
          crude       0.93      0.67      0.78       106
            dlr       0.93      0.41      0.57        34
            dmk       1.00      1.00      1.00         1
           earn       0.99      0.93      0.96       787
           fuel       0.00      0.00      0.00         2
            gas       0.33    