In [2]:
import tensorflow as tf
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import train_test_split
import numpy as np

# Sample dataset (texts with multiple tags)

In [3]:
texts = [
    "The stock market saw major gains today.",               # finance
    "New iPhone features cutting-edge technology.",          # tech
    "Google and Microsoft both released earnings reports.",  # tech, finance
    "The Lakers won the basketball championship.",           # sports
    "Tesla's new AI chips are revolutionizing cars."         # tech, auto
]
 
labels = [
    ["finance"],
    ["tech"],
    ["tech", "finance"],
    ["sports"],
    ["tech", "auto"]
]

# Binarize multilabels

mlb = MultiLabelBinarizer()
y = mlb.fit_transform(labels)

# TF-IDF vectorization

In [4]:
vectorizer = TfidfVectorizer()
X = vectorizer.fit_transform(texts).toarray()

# Split data

In [5]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Build multilabel classification model

In [6]:
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(X.shape[1],)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(len(mlb.classes_), activation='sigmoid')  # One sigmoid per label
])

# Compile and train

In [7]:
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.fit(X_train, y_train, epochs=10, verbose=0)

<keras.src.callbacks.history.History at 0x21906a64e60>

# Predict on new sample

In [8]:
new_text = ["Apple's quarterly earnings beat expectations."]
new_X = vectorizer.transform(new_text).toarray()
pred = model.predict(new_X)[0]
thresholded = [mlb.classes_[i] for i, p in enumerate(pred) if p > 0.5]
 
print("📝 New Text:\n", new_text[0])
print("\n📌 Predicted Tags:", thresholded)

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 58ms/step
📝 New Text:
 Apple's quarterly earnings beat expectations.

📌 Predicted Tags: ['auto', 'finance']
