In [8]:
# ===============================================================
# FAST Joint Sentimentâ€“Topic Model (JST) using Gibbs Sampling
# Optimized for Google Colab (100 topics)
# ===============================================================

!pip install datasets sentence-transformers scikit-learn numpy tqdm

import re
import numpy as np
from tqdm import tqdm
import random
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
import pickle

# -------------------------
# PARAMETERS
# -------------------------
K = 100               # topics
S = 2                 # sentiments
V = 3000
MAX_DOCS = 1000
DOC_MAX_WORDS = 100
SEED = 42
ITERATIONS = 100 # Define ITERATIONS here

alpha_base = 0.1
alpha_boost = 4.0
beta = 0.01
gamma = 0.1

np.random.seed(SEED)
random.seed(SEED)

# -------------------------
# TOKENIZER
# -------------------------
def tokenize(text):
    return re.findall(r"[a-zA-Z']+", text.lower())

# -------------------------
# 1. Load Data
# -------------------------
print("Loading IMDB...")
dataset = load_dataset("imdb", split="train")
if MAX_DOCS:
    dataset = dataset.select(range(min(MAX_DOCS, len(dataset))))

docs = [tokenize(x["text"])[:DOC_MAX_WORDS] for x in dataset]

# -------------------------
# 2. Build Vocabulary
# -------------------------
print("Building vocabulary...")
from collections import Counter
wc = Counter()
for d in docs:
    wc.update(d)

vocab_words = [w for w, _ in wc.most_common(V)]
vocab = {w:i for i,w in enumerate(vocab_words)}
inv_vocab = {i:w for w,i in vocab.items()}

UNK = V
V_eff = V + 1

docs_w = [[vocab[w] if w in vocab else UNK for w in d] for d in docs]
D = len(docs_w)
print(f"Docs: {D}, Vocab size: {V_eff}")

# -------------------------
# 3. Embeddings + Clustering
# -------------------------
print("Embedding documents...")
model = SentenceTransformer("all-MiniLM-L6-v2")
emb = model.encode([" ".join(d) for d in docs], batch_size=64)

print("Clustering...")
kmeans = KMeans(n_clusters=K, n_init=10)
cluster_id = kmeans.fit_predict(emb)

# -------------------------
# 4. Initialize Counts
# -------------------------
alpha_d = np.full((D, K), alpha_base)
for d in range(D):
    alpha_d[d, cluster_id[d]] += alpha_boost

gamma_vec = np.full(S, gamma)

n_d_s = np.zeros((D, S))
n_d_s_z = np.zeros((D, S, K))
n_s_z_w = np.zeros((S, K, V_eff))
n_s_z = np.zeros((S, K))

assign_s = []
assign_z = []

print("Random initialization...")
for d, wids in enumerate(docs_w):
    As = []
    Az = []
    for w in wids:
        s = np.random.randint(S)
        z = np.random.randint(K)
        As.append(s)
        Az.append(z)
        n_d_s[d, s] += 1
        n_d_s_z[d, s, z] += 1
        n_s_z_w[s, z, w] += 1
        n_s_z[s, z] += 1
    assign_s.append(As)
    assign_z.append(Az)

# -------------------------
# 5. FAST GIBBS SAMPLING
# -------------------------
print("Starting FAST Gibbs Sampling...")

for it in range(ITERATIONS): # Use ITERATIONS here
    for d in range(D):
        for i, w in enumerate(docs_w[d]):

            s_old = assign_s[d][i]
            z_old = assign_z[d][i]

            # remove old
            n_d_s[d, s_old] -= 1
            n_d_s_z[d, s_old, z_old] -= 1
            n_s_z_w[s_old, z_old, w] -= 1
            n_s_z[s_old, z_old] -= 1

            # probability computation (vectorized)
            probs_s = (n_d_s[d] + gamma_vec)  # shape (S)

            probs = np.zeros((S, K))

            for s in range(S):
                term2 = n_d_s_z[d, s] + alpha_d[d]
                term3 = (n_s_z_w[s, :, w] + beta) / (n_s_z[s] + V_eff*beta)
                probs[s] = probs_s[s] * term2 * term3

            flat = probs.ravel()
            flat /= flat.sum()

            idx = np.random.choice(S*K, p=flat)
            s_new = idx // K
            z_new = idx % K

            assign_s[d][i] = s_new
            assign_z[d][i] = z_new

            n_d_s[d, s_new] += 1
            n_d_s_z[d, s_new, z_new] += 1
            n_s_z_w[s_new, z_new, w] += 1
            n_s_z[s_new, z_new] += 1

    if (it+1) % 10 == 0:
        global_sent = (n_d_s.sum(axis=0) + 1e-9) / (n_d_s.sum() + 1e-9)
        print(f"Iter {it+1}/{ITERATIONS}: sentiment={global_sent}") # Use ITERATIONS here

print("Sampling complete.")


# -------------------------
# 6. Compute phi, etc.
# -------------------------
phi = (n_s_z_w + beta) / (n_s_z[:,:,None] + V_eff*beta)

print("Top words in topic 0:")
top = phi[1,0].argsort()[-10:]
print([inv_vocab.get(i,"UNK") for i in top])

Loading IMDB...
Building vocabulary...
Docs: 1000, Vocab size: 3001
Embedding documents...
Clustering...
Random initialization...
Starting FAST Gibbs Sampling...
Iter 10/100: sentiment=[0.50430673 0.49569327]
Iter 20/100: sentiment=[0.50403789 0.49596211]
Iter 30/100: sentiment=[0.50412061 0.49587939]
Iter 40/100: sentiment=[0.50404823 0.49595177]
Iter 50/100: sentiment=[0.50401721 0.49598279]
Iter 60/100: sentiment=[0.50392414 0.49607586]
Iter 70/100: sentiment=[0.50398619 0.49601381]
Iter 80/100: sentiment=[0.5039655 0.4960345]
Iter 90/100: sentiment=[0.50411027 0.49588973]
Iter 100/100: sentiment=[0.50409993 0.49590007]
Sampling complete.
Top words in topic 0:
['people', 'until', 'there', 'movies', 'great', 'only', 'watching', 'thing', 'like', "it's"]


In [9]:
print("Top words in topic 0:")
top = phi[0,1].argsort()[-10:]
print([inv_vocab.get(i,"UNK") for i in top])

Top words in topic 0:
['them', 'shows', 'like', 'completely', 'stupid', 'goes', 'is', 'that', 'with', 'are']


In [10]:
TOP_N = 10

for z in range(100):
    print(f"\nTopic {z}:")
    print("  Negative:", [inv_vocab.get(i,"UNK") for i in phi[0,z].argsort()[-TOP_N:]])
    print("  Positive:", [inv_vocab.get(i,"UNK") for i in phi[1,z].argsort()[-TOP_N:]])


Topic 0:
  Negative: ['life', 'is', 'then', 'of', 'the', 'by', 'her', 'and', 'in', 'UNK']
  Positive: ['people', 'until', 'there', 'movies', 'great', 'only', 'watching', 'thing', 'like', "it's"]

Topic 1:
  Negative: ['them', 'shows', 'like', 'completely', 'stupid', 'goes', 'is', 'that', 'with', 'are']
  Positive: ['close', 'getting', 'from', 'out', 'comes', 'one', 'by', 'but', 'it', 'not']

Topic 2:
  Negative: ['but', 'from', 'than', 'good', 'like', 'more', 'UNK', 'to', 'the', 'it']
  Positive: ['story', 'for', "it's", 'in', 'not', 'that', 'and', 'a', 'is', 'the']

Topic 3:
  Negative: ["it's", 'for', 'with', 'see', 'if', 'this', 'you', 'to', 'and', 'UNK']
  Positive: ['science', 'fiction', 'yet', 'time', 'when', "it's", 'about', 'that', 'is', 'the']

Topic 4:
  Negative: ['york', 'two', 'life', 'was', 'his', 'new', 'very', 'the', 'UNK', 'in']
  Positive: ['dialogue', 'seems', 'off', 'old', 'both', 'around', 'an', 'on', 'and', 'UNK']

Topic 5:
  Negative: ['working', 'however', 'mor