In [1]:
 !pip install -q nltk tqdm

In [37]:
import random, numpy as np, torch

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from collections import Counter
import numpy as np
import random
from tqdm import tqdm
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [3]:
from nltk.corpus import stopwords
nltk.download('stopwords')

stop_words = set(stopwords.words('english'))

[nltk_data] Downloading package stopwords to /usr/share/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [4]:
from nltk.corpus import gutenberg
nltk.download('gutenberg')

sentences = []
for fileid in gutenberg.fileids():
    for sent in nltk.sent_tokenize(gutenberg.raw(fileid)):
        tokens = nltk.word_tokenize(sent.lower())
        tokens = [t for t in tokens if t.isalpha() and t not in stop_words]
        if len(tokens) > 3:
            sentences.append(tokens)

print("Number of sentences:", len(sentences))

[nltk_data] Downloading package gutenberg to /usr/share/nltk_data...
[nltk_data]   Package gutenberg is already up-to-date!


Number of sentences: 74973


In [5]:
VOCAB_SIZE = 8000
WINDOW_SIZE = 4

counter = Counter()
for s in sentences:
    counter.update(s)
    
MIN_FREQ = 50
vocab = [w for w, c in counter.items() if c >= MIN_FREQ]
vocab = sorted(vocab, key=lambda w: counter[w], reverse=True)[:VOCAB_SIZE]

word2idx = {w: i for i, w in enumerate(vocab)}
idx2word = {i: w for w, i in word2idx.items()}

In [6]:
total_count = sum(counter.values())

def keep_word(word):
    freq = counter[word] / total_count
    p = min(1.0, np.sqrt(1e-5 / freq))
    return random.random() < max(p, 0.2)
    
sentences = [
    [w for w in s if keep_word(w)]
    for s in sentences
]

In [7]:
pairs = []
for s in sentences:
    ids = [word2idx[w] for w in s if w in word2idx]
    for i, center in enumerate(ids):
        for j in range(max(0, i-WINDOW_SIZE), min(len(ids), i+WINDOW_SIZE+1)):
            if i != j:
                pairs.append((center, ids[j]))

random.shuffle(pairs)
print("Training pairs:", len(pairs))

Training pairs: 486028


In [8]:
# class DenseEmbedding(nn.Module):
#     def __init__(self, vocab_size, dim):
#         super().__init__()
#         self.embed = nn.Embedding(vocab_size, dim)

#     def forward(self, x):
#         return self.embed(x)

In [9]:
class MultiProtoInterpretableModel(nn.Module):
    def __init__(self, vocab_size, dense_dim=100, interp_dim=128, m_proto=4, tie_decoder=True, init_proto_ids=None):
        super().__init__()
        self.center_embed = nn.Embedding(vocab_size, dense_dim)
        self.context_embed = nn.Embedding(vocab_size, dense_dim)

        self.K = interp_dim
        self.M = m_proto

        # Prototypes: (K, M, d)
        self.prototypes = nn.Parameter(torch.randn(interp_dim, m_proto, dense_dim) * 0.01)
        self.tie_decoder = tie_decoder

        if not tie_decoder:
            self.decoder_raw = nn.Parameter(torch.randn(interp_dim, vocab_size) * 0.01)


        if init_proto_ids is None:
            init_proto_ids = [[-1]*m_proto for _ in range(interp_dim)]
        self.proto_ids = init_proto_ids  

    def _proto_matrix(self):
        #  (K, M, d) -> (K*M, d) 
        return self.prototypes.view(self.K * self.M, -1)

    def forward(self, center):
        z = self.center_embed(center)        
        P = self._proto_matrix()                     # (K*M, d)
        sims = cosine_sim(z, P)                      # (B, K*M)
        sims = sims.view(z.shape[0], self.K, self.M) # (B, K, M)
       
        s = torch.relu(sims.max(dim=2).values)       # (B, K)
        logits = s @ self.decoder_matrix()           # (B, V)
        return logits, s ,z

    def decoder_matrix(self):
        if self.tie_decoder:
            Z = self.context_embed.weight              # (V, d)
            P = self._proto_matrix()                 # (K*M, d)
            sims = cosine_sim(Z, P)                  # (V, K*M)
            sims = sims.view(Z.shape[0], self.K, self.M).permute(1,0,2)  # (K, V, M)
            sims = torch.relu(sims).max(dim=2).values # (K, V)
            return sims
        else:
            return torch.softmax(torch.relu(self.decoder_raw), dim=1)

In [10]:
def prototype_diversity_loss(P):
    Pl = P.view(-1, P.shape[-1])
    Pl = Pl / (Pl.norm(dim=1, keepdim=True) + 1e-8)
    G = Pl @ Pl.T
    off = G - torch.diag(torch.diag(G))
    return (off**2).mean()

In [11]:
CE = nn.CrossEntropyLoss()

def sparsity_loss(s):
    return torch.mean(torch.abs(s))  # L1

In [12]:
def cosine_sim(a, b, eps=1e-8):
    a_n = a / (a.norm(dim=1, keepdim=True) + eps)
    b_n = b / (b.norm(dim=1, keepdim=True) + eps)
    return a_n @ b_n.T  # (B, K)

In [13]:
def prototype_pull_loss(z, s, prototypes, m_proto):
    """
    z: (B, d) center embeddings
    s: (B, K) factor activations
    prototypes: (K, M, d)
    """
    B, K = s.shape
    loss = 0.0

    for k in range(K):
        proto_k = prototypes[k]                  # (M, d)
        sim = cosine_sim(z, proto_k)              # (B, M)
        best = sim.max(dim=1).values              # (B,)
        loss += torch.mean(s[:, k] * (1 - best))  # weighted pull

    return loss / K

In [14]:
def dimension_centroid_loss(z, s):
    centroids = (s.T @ z) / (s.sum(dim=0).unsqueeze(1) + 1e-8)
    sims = cosine_sim(centroids, centroids)
    off = sims - torch.diag(torch.diag(sims))
    return (off ** 2).mean()

In [15]:
def prototype_entropy_loss(z, prototypes):
    P = prototypes.view(-1, prototypes.shape[-1])  # (K*M, d)
    sims = cosine_sim(z, P)                         # (B, K*M)
    probs = torch.softmax(sims, dim=1)
    entropy = -torch.sum(probs * torch.log(probs + 1e-8), dim=1)
    return entropy.mean()

In [16]:
def orthogonality_loss(A):
    AtA = A @ A.T
    I = torch.eye(A.shape[0], device=A.device)
    return ((AtA - I) ** 2).mean()

In [17]:
def decorrelation_loss(s):
    s = s - s.mean(dim=0)
    cov = (s.T @ s) / s.shape[0]
    off_diag = cov - torch.diag(torch.diag(cov))
    return (off_diag ** 2).mean()

In [35]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = MultiProtoInterpretableModel(
    vocab_size=len(vocab),
    dense_dim=100,
    interp_dim=128,
    m_proto=4,          
    tie_decoder=True
).to(device)

optimizer = optim.Adam(model.parameters(), lr=0.05)

EPOCHS = 5
BATCH_SIZE = 256
λ_sparse, λ_decorr, λ_div ,λ_pull ,λ_cent ,λ_ent = 0.3, 0.3, 0.4 , 0.2 , 0.1 ,0.05

In [38]:
for epoch in range(EPOCHS):
    total_loss = 0.0
    random.shuffle(pairs)

    for i in tqdm(range(0, len(pairs), BATCH_SIZE)):
        batch = pairs[i : i + BATCH_SIZE]
        center = torch.tensor([p[0] for p in batch], device=device)
        context = torch.tensor([p[1] for p in batch], device=device)

        logits, s ,z = model(center)

        loss = CE(logits, context) \
        + λ_sparse * sparsity_loss(s) \
        + λ_decorr * decorrelation_loss(s) \
        + λ_div * prototype_diversity_loss(model.prototypes) \
        + λ_pull * prototype_pull_loss(z, s, model.prototypes, model.M) \
        + λ_cent * dimension_centroid_loss(z, s)
        + λ_ent * prototype_entropy_loss(z, model.prototypes)


        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}, loss={total_loss:.2f}")

100%|██████████| 1899/1899 [03:43<00:00,  8.48it/s]


Epoch 1, loss=13686.05


100%|██████████| 1899/1899 [03:44<00:00,  8.44it/s]


Epoch 2, loss=13573.54


100%|██████████| 1899/1899 [03:45<00:00,  8.43it/s]


Epoch 3, loss=13505.18


100%|██████████| 1899/1899 [03:44<00:00,  8.44it/s]


Epoch 4, loss=13456.25


100%|██████████| 1899/1899 [03:44<00:00,  8.44it/s]

Epoch 5, loss=13419.76





In [39]:
@torch.no_grad()
def factor_top_words(model, idx2word, k=10):
    D = model.decoder_matrix()        # (K, V)
    top = {}
    for dim in range(D.shape[0]):
        top_ids = torch.topk(D[dim], k).indices.cpu().numpy()
        top[dim] = [idx2word[int(i)] for i in top_ids]
    return top

dimension_words = factor_top_words(model, idx2word, k=10)
for dim in range(20):
    print(f"\nDIM {dim}: {dimension_words[dim]}")


DIM 0: ['staves', 'scarlet', 'linen', 'skin', 'garment', 'ephod', 'hanging', 'garments', 'purple', 'cherubims']

DIM 1: ['whale', 'enormous', 'stern', 'rolling', 'chase', 'swung', 'floating', 'ahead', 'starbuck', 'tail']

DIM 2: ['suburbs', 'begat', 'tribe', 'reuben', 'nine', 'manasseh', 'bashan', 'villages', 'towns', 'gad']

DIM 3: ['son', 'day', 'work', 'next', 'patient', 'cool', 'arrived', 'others', 'pleasant', 'sons']

DIM 4: ['fruit', 'lands', 'welcome', 'fresh', 'east', 'north', 'songs', 'side', 'south', 'desert']

DIM 5: ['court', 'gate', 'purple', 'inner', 'porch', 'cubits', 'measured', 'pull', 'sockets', 'side']

DIM 6: ['offering', 'ram', 'sacrifice', 'lambs', 'burnt', 'bullock', 'goats', 'blemish', 'unleavened', 'sin']

DIM 7: ['thy', 'thou', 'thee', 'hast', 'shalt', 'mayest', 'wast', 'thine', 'forgive', 'art']

DIM 8: ['mrs', 'harriet', 'emma', 'acquaintance', 'elinor', 'musgrove', 'miss', 'lady', 'able', 'advantage']

DIM 9: ['son', 'god', 'eve', 'found', 'human', 'pure',

In [40]:
import pandas as pd
import requests
from scipy.stats import spearmanr

path = "/kaggle/input/wordsim353-crowd/wordsim353crowd.csv"
df = pd.read_csv(path)
df.head()

Unnamed: 0,Word 1,Word 2,Human (Mean)
0,admission,ticket,5.536
1,alcohol,chemistry,4.125
2,aluminum,metal,6.625
3,announcement,effort,2.0625
4,announcement,news,7.1875


In [22]:
def cosine(u, v, eps=1e-8):
    return (u @ v) / (np.linalg.norm(u) * np.linalg.norm(v) + eps)

In [41]:
@torch.no_grad()
def eval_wordsim(model, df, word2idx):
    sims, gold = [], []

    W = model.center_embed.weight.cpu().numpy()

    for _, row in df.iterrows():
        w1, w2, score = row['Word 1'], row['Word 2'], row['Human (Mean)']
        if w1 in word2idx and w2 in word2idx:
            v1, v2 = W[word2idx[w1]], W[word2idx[w2]]
            sims.append(cosine(v1, v2))
            gold.append(score)

    return spearmanr(sims, gold).correlation

eval_wordsim(model,df,word2idx)

0.3859162014435637

In [42]:
from collections import defaultdict
import math

def build_cooc(sentences):
    word_freq = Counter()
    pair_freq = defaultdict(int)

    for s in sentences:
        s = set(s)
        for w in s:
            word_freq[w] += 1
        for w1 in s:
            for w2 in s:
                if w1 < w2:
                    pair_freq[(w1, w2)] += 1

    return word_freq, pair_freq, len(sentences)

In [43]:
word_freq, pair_freq, N = build_cooc(sentences)

In [44]:
def npmi(words, word_freq, pair_freq, N):
    scores = []
    for i in range(len(words)):
        for j in range(i+1, len(words)):
            w1, w2 = words[i], words[j]
            if w1 not in word_freq or w2 not in word_freq:
                continue
            pw1 = word_freq[w1] / N
            pw2 = word_freq[w2] / N
            pair = tuple(sorted([w1, w2]))
            pw12 = pair_freq.get(pair, 0) / N
            if pw12 == 0:
                continue
            pmi = math.log(pw12 / (pw1 * pw2))
            scores.append(pmi / (-math.log(pw12)))
    return np.mean(scores) if scores else 0.0

In [45]:
def eval_coherence(dimension_words):
    scores = []
    for dim, words in dimension_words.items():
        scores.append(npmi(words, word_freq, pair_freq, N))
    return np.mean(scores), scores
eval_coherence(dimension_words)[0]

0.3157198220479913

In [46]:
@torch.no_grad()
def eval_sparsity(model, vocab_size=2000):
    ids = torch.arange(vocab_size, device=device)
    _, s, _ = model(ids)
    return (s > 0).float().sum(dim=1).mean().item()

eval_sparsity(model, vocab_size=2000)

75.21500396728516

In [47]:
@torch.no_grad()
def eval_dim_corr(model, vocab_size=2000):
    ids = torch.arange(vocab_size, device=device)
    _, s, _ = model(ids)
    s = s - s.mean(dim=0)
    cov = (s.T @ s) / s.shape[0]
    off = cov - torch.diag(torch.diag(cov))
    return off.abs().mean().item()

eval_dim_corr(model, vocab_size=2000)

0.0032747178338468075

In [48]:
@torch.no_grad()
def faithfulness_test(model, word, dim, topk=10):
    idx = torch.tensor([word2idx[word]], device=device)
    logits, s, z = model(idx)

    probs_orig = torch.softmax(logits, dim=1)
    s_mod = s.clone()
    s_mod[:, dim] = 0.0
    logits_mod = s_mod @ model.decoder_matrix()
    probs_mod = torch.softmax(logits_mod, dim=1)

    diff = torch.abs(probs_orig - probs_mod).mean().item()
    return diff

faithfulness_test(model, "sentence", dim, topk=10)

5.58105557502131e-06