# **Imports**

In [1]:
import torch
import torch.nn.functional as F
import networkx as nx
import plotly.graph_objects as go
from transformers import BertTokenizer, BertModel
from tqdm.notebook import tqdm
import ipywidgets as widgets
from IPython.display import display

# **Define BERTSimilarity Class**

In [2]:
class BERTSimilarity:
    """
    Nearest-neighbour word lookup backed by mean-pooled BERT embeddings.

    On construction the tokenizer vocabulary is filtered, every remaining word
    is embedded once, and the result is stored in ``self.token_embeddings`` so
    that each query needs only a single forward pass.
    """

    def __init__(self, model_name="bert-base-uncased", vocab_size=30522, batch_size=256):
        """
        Load the tokenizer and model, then pre-compute the vocabulary embeddings.

        Args:
            model_name: Name or path handed to ``from_pretrained``.
            vocab_size: Maximum number of words kept; the cut is applied
                *after* filtering, so fewer words may remain.
            batch_size: Number of words embedded per forward pass.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertModel.from_pretrained(model_name).to(self.device)
        self.model.eval()

        raw_vocab = list(self.tokenizer.get_vocab().keys())
        self.vocab = self._filter_vocab(raw_vocab)[:vocab_size]
        self.batch_size = batch_size

        self.token_embeddings = self._compute_vocab_embeddings()

    def _filter_vocab(self, vocab):
        """
        Removes special tokens, subwords, non-alphabetic tokens, short words.

        A token is kept only if it is purely alphabetic, does not start with
        "##", and is longer than two characters. Input order is preserved.
        """
        filtered = [
            token for token in vocab
            if token.isalpha() and not token.startswith("##") and len(token) > 2
        ]
        return filtered

    def _embed(self, texts):
        """
        Embed one string or a list of strings with attention-masked mean pooling.

        Padding positions are excluded from the average, so a word gets the
        same pooled vector whether it is embedded alone or inside a padded
        batch. Returns a tensor with one row per input text.
        """
        inputs = self.tokenizer(texts, return_tensors="pt", padding=True, truncation=True).to(self.device)
        with torch.no_grad():
            hidden = self.model(**inputs).last_hidden_state
            # Zero out padded positions and divide by the real token count
            # instead of the padded sequence length.
            mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
            return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

    def _compute_vocab_embeddings(self):
        """
        Compute embeddings for filtered vocab using batching.

        Returns a tensor with one row per entry of ``self.vocab``; an empty
        vocabulary yields an empty ``(0, hidden_size)`` tensor.
        """
        if not self.vocab:
            # torch.cat rejects an empty list, so short-circuit explicitly.
            return torch.empty((0, self.model.config.hidden_size), device=self.device)

        embeddings = []
        for i in tqdm(range(0, len(self.vocab), self.batch_size), desc="Computing embeddings"):
            batch_tokens = self.vocab[i:i + self.batch_size]
            embeddings.append(self._embed(batch_tokens))

        return torch.cat(embeddings, dim=0)

    def get_similar_words(self, input_word, top_n=10):
        """
        Given an input word, return top N similar words.

        Args:
            input_word: Text to compare against the vocabulary.
            top_n: Number of neighbours requested; capped at the vocabulary size.

        Returns:
            List of ``(word, cosine_similarity)`` tuples, most similar first.
        """
        input_embedding = self._embed(input_word)

        similarities = F.cosine_similarity(input_embedding, self.token_embeddings, dim=1)
        # topk raises when k exceeds the number of candidates.
        k = min(top_n, similarities.numel())
        top = torch.topk(similarities, k)
        return [
            (self.vocab[idx], score)
            for idx, score in zip(top.indices.tolist(), top.values.tolist())
        ]

In [3]:
# Constructing the engine loads BERT and embeds the filtered vocabulary up front.
bert_model = BERTSimilarity(
    vocab_size=30522,
)
print("Model loaded & embeddings computed...")

Using device: cuda


Computing embeddings:   0%|          | 0/84 [00:00<?, ?it/s]

Model loaded & embeddings computed...


# **Similarity Matrix**

In [4]:
def get_sub_matrix(word_list, engine):
    """
    Re-computes embeddings for just the small list of words to find
    connections between them.

    Args:
        word_list: Words to compare with each other.
        engine: Object exposing ``tokenizer``, ``model`` and ``device``.

    Returns:
        Square NumPy array where entry ``[i][j]`` is the cosine similarity
        between ``word_list[i]`` and ``word_list[j]``; shape ``(0, 0)`` for an
        empty list.
    """
    if len(word_list) == 0:
        return torch.empty((0, 0)).numpy()

    inputs = engine.tokenizer(word_list, return_tensors="pt", padding=True, truncation=True).to(engine.device)
    with torch.no_grad():
        hidden = engine.model(**inputs).last_hidden_state
        # Words tokenize to different lengths, so the batch is padded. Average
        # only over real tokens; otherwise padding vectors leak into the
        # embeddings of the shorter words.
        mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        embeddings = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

    # With unit-length rows, the Gram matrix is the cosine-similarity matrix.
    embeddings = F.normalize(embeddings, p=2, dim=1)
    return torch.mm(embeddings, embeddings.t()).cpu().numpy()

# **Visualize**

In [5]:
def _build_similarity_graph(words, sim_matrix, threshold):
    """Return a graph with one node per word and an edge for every pair scoring above ``threshold``."""
    graph = nx.Graph()
    graph.add_nodes_from(words)

    # Only the upper triangle is visited: the matrix is symmetric and the
    # diagonal (self-similarity) is not a connection.
    for i in range(len(words)):
        for j in range(i + 1, len(words)):
            score = sim_matrix[i][j]
            if score > threshold:
                graph.add_edge(words[i], words[j], weight=score)
    return graph


def _build_network_figure(graph, title):
    """Lay the graph out with a seeded spring layout and return it as a Plotly figure."""
    pos = nx.spring_layout(graph, seed=42, k=0.5)

    # Each edge is an (x0, x1, None) run; the None breaks the line so that
    # every edge is drawn as its own segment within a single trace.
    edge_x, edge_y = [], []
    for start, end in graph.edges():
        x0, y0 = pos[start]
        x1, y1 = pos[end]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines')

    node_x, node_y = [], []
    node_text = []
    node_adjacencies = []

    for node in graph.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_text.append(node)
        # Number of neighbours drives the marker colour.
        node_adjacencies.append(len(graph.adj[node]))

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        text=node_text,
        textposition="top center",
        hoverinfo='text',
        marker=dict(showscale=True, colorscale='Plasma', reversescale=True, color=node_adjacencies,
                    size=15, colorbar=dict(thickness=15, title=dict(text='Connections', side='right'),
                                           xanchor='left')))

    return go.Figure(data=[edge_trace, node_trace],
                     layout=go.Layout(
                         title=title,
                         showlegend=False, hovermode='closest',
                         margin=dict(b=20, l=5, r=5, t=40),
                         xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                         yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)))


def visualize_network(seed_word, top_n, threshold):
    """
    Plot the similarity network around ``seed_word``.

    Args:
        seed_word: Word whose neighbours are looked up; surrounding whitespace
            is ignored.
        top_n: Number of similar words requested from ``bert_model``.
        threshold: Two words are connected when their cosine similarity is
            strictly greater than this value.

    Errors are printed rather than raised so that a failing widget callback
    does not break the notebook.
    """
    try:
        seed_word = seed_word.strip()
        if not seed_word:
            print("Enter a search term to build the network.")
            return

        # Get Similar Words
        similar_results = bert_model.get_similar_words(seed_word, top_n=top_n)
        print(f"Displaying top {top_n} words similar to '{seed_word}'")

        # Prepare Data for Graph.
        # The seed is normally its own best match, so prepending it would list
        # it twice; the duplicate row would then become a self-loop edge and
        # inflate the seed's connection count. dict.fromkeys de-duplicates
        # while preserving order.
        found_words = list(dict.fromkeys([seed_word] + [word for word, _ in similar_results]))
        sim_matrix = get_sub_matrix(found_words, bert_model)

        graph = _build_similarity_graph(found_words, sim_matrix, threshold)
        fig = _build_network_figure(
            graph, f"Network for '{seed_word}' (Similarity: {threshold})")
        fig.show()

    except Exception as e:
        print(f"Error: {e}")

In [6]:
# Create Widgets
# 'initial' sizes the label to its text; the default fixed width clips the
# longer descriptions below.
_label_style = {'description_width': 'initial'}

# continuous_update=False fires the callback only when the user commits a value
# (Enter / focus loss for the text box, mouse release for the sliders), so the
# BERT forward pass and the plot are not recomputed on every keystroke or drag tick.
seed_input = widgets.Text(value='science', description='Search Term:',
                          continuous_update=False, style=_label_style)
top_n_slider = widgets.IntSlider(value=30, min=10, max=100, step=5, description='Number of Words:',
                                 continuous_update=False, style=_label_style)
threshold_slider = widgets.FloatSlider(value=0.9, min=0.5, max=0.99, step=0.01, description='Connection Threshold:',
                                       continuous_update=False, style=_label_style)

# Use interactive_output to make it reactive: the plot is re-rendered into
# `out` whenever one of the bound widget values changes.
ui = widgets.VBox([seed_input, top_n_slider, threshold_slider])
out = widgets.interactive_output(visualize_network, {'seed_word': seed_input, 'top_n': top_n_slider, 'threshold': threshold_slider})

display(ui, out)

VBox(children=(Text(value='science', description='Search Term:'), IntSlider(value=30, description='Number of W…

Output()