Author: Will Blanton

# Imports

In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import random
import itertools

import torch_geometric
from sentence_transformers import SentenceTransformer
from torch.utils.data import Dataset
from torch.nn.functional import cosine_similarity
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import knn_graph

# Helper Functions

# Create Dataset

In [2]:
random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x18e0ace1f30>

In [3]:
print(f"Cuda available: {torch.cuda.is_available()}")

Cuda available: True


In [4]:
word_embedder = SentenceTransformer('all-MiniLM-L6-v2')

In [5]:
import ast
connections_df = pd.read_csv('data/connections.csv', index_col=0)
connections_df.drop(columns='category', inplace=True)

# fix issues with incorrect format... (not worth going in and adding a mechanism to correct...)
connections_df.loc[1298, "connections"] = "['line', 'plane', 'point', 'solid']"
connections_df.loc[1892, "connections"] = "['abyss', 'fly', 'matrix', 'thing']"

# remove april fools samples since they include emojis or other potentially noisy samples
connections_df = connections_df[~connections_df['date'].str.contains("04-01")]
connections_df.reset_index(drop=True, inplace=True)

connections_df['connections'] = connections_df['connections'].apply(ast.literal_eval)
connections_df

Unnamed: 0,date,connections
0,2023-06-12,"[kayak, level, mom, race car]"
1,2023-06-12,"[option, return, shift, tab]"
2,2023-06-12,"[bucks, heat, jazz, nets]"
3,2023-06-12,"[hail, rain, sleet, snow]"
4,2023-06-13,"[are, queue, sea, why]"
...,...,...
2751,2025-05-01,"[pot, prize, purse, reward]"
2752,2025-05-02,"[bottle, break, goose, turtle]"
2753,2025-05-02,"[dog, link, rib, wing]"
2754,2025-05-02,"[brace, post, prop, support]"


In [6]:
class ConnectionsData(Data):
    # needed to adjust the edge group indices for batching
    def __inc__(self, key, value, *args, **kwargs):
        if key == 'group_indices':
            # group_indices refers to node indices → so shift by number of nodes
            return self.x.size(0)
        else:
            return super().__inc__(key, value, *args, **kwargs)

In [7]:
"""

TODO: add more features to handle complicated cases
- phonetic word embeddings
- character-level embeddings (either trained via RNN or pre-trained)
- n-gram?
"""
class ConnectionsGraphDataset(Dataset):
    def __init__(
        self,
        puzzle_df: pd.DataFrame,
        word_emb_model: SentenceTransformer,
        negative_ratio: int = 3,
        include_purple: bool = False
    ):
        super().__init__()
        self.puzzle_df = puzzle_df
        self.negative_ratio = negative_ratio
        self.include_purple = include_purple
        self.data_list = []

        # get lists of words for each puzzle 
        words_per_date = (
            self.puzzle_df
            .groupby('date')['connections']
            .agg(lambda lists: list(itertools.chain.from_iterable(lists)))
        )

        # create data object for each puzzle
        for date, word_list in words_per_date.items():
            if len(word_list) != 16:
                raise ValueError(f'Word list length {len(word_list)} does not match 16 words')
            data = self._build_single_graph(date, word_list, word_emb_model)
            self.data_list.append(data)

    def _build_single_graph(self, date, word_list, word_emb_model):
        
        # create node features 
        x = word_emb_model.encode(word_list, convert_to_tensor=True)
        word2idx = {w: i for i, w in enumerate(word_list)}

        edge_index, edge_attr = self._make_graph(x)

        positives = self._collect_positives(date, word2idx)
        negatives = self._sample_negatives(len(word_list), positives)

        # remove purple pairs (already handled for negatives)
        if not self.include_purple:
            positives = positives[:-1]

        group_indices = torch.tensor(positives + negatives, dtype=torch.long)
        group_labels = torch.tensor(
            [1]*len(positives) + [0]*len(negatives),
            dtype=torch.float
        )
        
        data = ConnectionsData(
            x=x,    # (16, embed_dim)
            edge_index=edge_index,  # (2, num_edges)
            edge_attr=edge_attr, # (num_edges, 1)
            group_indices=group_indices,  # (num_groups, 4)
            group_labels=group_labels     # (num_groups,)
        )

        data.word_list = word_list
        return data

    def _make_graph(self, x, k=6):
        """
        Construct KNN graph for words
        :param x: 
        :param k: 
        :return: 
        """
        
        edge_index = knn_graph(x, k=k, batch=None, loop=False, cosine=True)

        # leave as directed since knn relationship is asymmetrical 
        # edge_index = to_undirected(edge_index)

        # edge attributes (cosine similarity)
        src = x[edge_index[0]]
        dst = x[edge_index[1]]
        edge_attr = cosine_similarity(src, dst).unsqueeze(1)

        return edge_index, edge_attr

    def _collect_positives(self, date, word2idx):
        subset = self.puzzle_df[self.puzzle_df['date'] == date]
        positives = []
        for _, row in subset.iterrows():
            idxs = [word2idx[w] for w in row['connections']]
            positives.append(sorted(idxs))
        return positives

    def _sample_negatives(self, num_nodes, positives):
        positive_set = set(tuple(g) for g in positives)
        all_quads = list(itertools.combinations(range(num_nodes), 4))
        random.shuffle(all_quads)

        negatives = []

        if self.include_purple: 
            max_neg = self.negative_ratio * len(positives)
        else:
            max_neg = self.negative_ratio * (len(positives) - 1)
        
        for quad in all_quads:
            if len(negatives) >= max_neg:
                break
            if tuple(quad) not in positive_set:
                negatives.append(list(quad))
        return negatives

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        return self.data_list[idx]    

In [8]:
# split data
dates = connections_df["date"].unique()
train_idx = random.sample(range(len(dates)), k=int(len(dates) * .8))
train_dates = dates[train_idx]
test_dates = dates[~np.isin(range(len(dates)), train_idx)]

test_dates = np.random.permutation(test_dates)
split_idx = len(test_dates) // 2
val_dates = test_dates[split_idx:]
test_dates = test_dates[:split_idx]

print(f"Total Length: {len(dates)}")
print(f"Train length: {len(train_dates)}")
print(f"Val length: {len(val_dates)}")
print(f"Test length: {len(test_dates)}")

Total Length: 689
Train length: 551
Val length: 69
Test length: 69


In [9]:
BATCH_SIZE = 64
NEGATIVE_RATIO = 3
INCLUDE_PURPLE = False

train = ConnectionsGraphDataset(connections_df.loc[connections_df["date"].isin(train_dates)], word_embedder, negative_ratio=NEGATIVE_RATIO, include_purple=INCLUDE_PURPLE)
train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)

val = ConnectionsGraphDataset(connections_df.loc[connections_df["date"].isin(val_dates)], word_embedder, negative_ratio=NEGATIVE_RATIO, include_purple=INCLUDE_PURPLE)
val_loader = DataLoader(val, batch_size=BATCH_SIZE, shuffle=True)

test = ConnectionsGraphDataset(connections_df.loc[connections_df["date"].isin(test_dates)], word_embedder, negative_ratio=NEGATIVE_RATIO, include_purple=INCLUDE_PURPLE)
test_loader = DataLoader(test, batch_size=BATCH_SIZE, shuffle=True)

In [10]:
del word_embedder

In [11]:
NUM_SAMPLES = (4 if INCLUDE_PURPLE else 3) * (NEGATIVE_RATIO + 1)

In [12]:
for batch in DataLoader(val, batch_size=2, shuffle=True):
    print(batch)
    print(batch.group_indices)

    x = batch.x
    group_embeds = x[batch.group_indices]
    print(group_embeds.view(group_embeds.size(0), -1).shape)
    
    break

ConnectionsDataBatch(x=[32, 384], edge_index=[2, 192], edge_attr=[192, 1], group_indices=[24, 4], group_labels=[24], word_list=[2], batch=[32], ptr=[3])
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [ 1,  3,  5,  8],
        [ 4,  5, 10, 11],
        [ 3,  5,  8, 14],
        [ 0,  2,  4, 15],
        [ 3,  4,  8, 12],
        [ 0,  6,  7,  8],
        [ 5,  7, 12, 14],
        [ 6, 11, 13, 15],
        [ 6,  8, 10, 12],
        [16, 17, 18, 19],
        [20, 21, 22, 23],
        [24, 25, 26, 27],
        [17, 26, 27, 28],
        [21, 22, 30, 31],
        [20, 25, 26, 30],
        [19, 23, 24, 30],
        [18, 21, 22, 27],
        [17, 23, 24, 27],
        [18, 24, 25, 31],
        [17, 23, 28, 31],
        [20, 25, 28, 29]])
torch.Size([24, 1536])


# Model Architecture

In [13]:
from torch_geometric.nn import GATv2Conv

class GraphEncoderLayer(nn.Module):
    def __init__(self, hidden_dim, local_heads, global_heads, dropout, edge_dim):
        super().__init__()
        
        self.local_conv = GATv2Conv(
            in_channels=hidden_dim,
            out_channels=hidden_dim,
            heads=local_heads,
            edge_dim=edge_dim,
            dropout=dropout,
            concat=False
        )
        
        self.global_attn = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=global_heads,
            dropout=dropout,
            batch_first=True
        )
        
        # Dropout + LayerNorm
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, x, edge_index, edge_attr):
        residual = x

        x = self.local_conv(x, edge_index, edge_attr)

        x = x.unsqueeze(0)
        x, _ = self.global_attn(x, x, x)
        x = x.squeeze(0)

        x = self.dropout(x)
        x = self.norm(x + residual)

        return x

In [14]:
# TODO: implement an auto-regressive variant that builds groups instead of validating groups
class ConnectionsGNN(torch.nn.Module):
    def __init__(
            self,
            in_dim,
            num_layers=1,
            hidden_size=256,
            output_size=1,
            attn_heads=8,
            agg_heads=2,
            dropout=0.2,
            edge_dim=1
        ):
        super().__init__()

        self.lin_proj = nn.Linear(in_dim, hidden_size)
        
        self.layers = torch.nn.ModuleList()
        
        for i in range(num_layers):
            local_heads = attn_heads[i] if isinstance(attn_heads, list) else attn_heads
            self.layers.append(GraphEncoderLayer(hidden_size, local_heads, 2, dropout + .5, edge_dim))

        # batch_first so dim 1 is considered as the sequence (even though technically not a batch)
        self.attn_agg = nn.MultiheadAttention(hidden_size, agg_heads, dropout=dropout, batch_first=True)
            
        self.out = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.LayerNorm(hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Linear(hidden_size // 2, hidden_size // 4),
            nn.LayerNorm(hidden_size // 4),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Linear(hidden_size // 4, output_size)
        )

        
    def forward(self, x, edge_index, edge_attr, group_indices):
        
        x = self.lin_proj(x)    # (num_nodes, in_dim) -> (num_nodes, hidden_dim)

        # compute graph embeddings 
        for layer in self.layers:
            x = layer(x, edge_index, edge_attr) # (num_nodes, hidden_dim)
            
        # aggregate groups for final scores
        group_embs = x[group_indices]   # (num_groups, 4, hidden_dim)
        group_embs, _ = self.attn_agg(group_embs, group_embs, group_embs)   # (num_groups, 4, hidden_dim)
        group_embs, _ = group_embs.max(dim=1)   # (num_groups, hidden_dim)
        
        logits = self.out(group_embs)    # (num_groups, 1)
        return logits.squeeze(-1)   # (num_groups,)

# Train Model

In [15]:
from sklearn.metrics import roc_auc_score

def compute_auc(logits, labels):
    """
    logits: (batch_size,) torch.Tensor (raw logits)
    labels: (batch_size,) torch.Tensor (binary labels)
    """

    # Apply sigmoid to get probabilities
    probs = torch.sigmoid(logits)

    # Move to CPU and numpy
    probs = probs.detach().cpu().numpy()
    labels = labels.detach().cpu().numpy()

    # Compute ROC AUC
    auc = roc_auc_score(labels, probs)

    return auc

In [16]:
def evaluate(model, loader, loss_fn, device):
    model.eval()
    
    total_loss = 0
    all_logits = []
    all_labels = []
    
    with torch.no_grad():
        for batch in loader:
            x = batch.x.to(device)
            edge_index = batch.edge_index.to(device)
            edge_attr = batch.edge_attr.to(device)
            group_indices = batch.group_indices.to(device)
            y = batch.group_labels.to(device)

            logits = model(x, edge_index, edge_attr, group_indices)

            # Loss (BCEWithLogits → already mean over batch)
            batch_loss = loss_fn(logits, y)
            total_loss += batch_loss.item()

            # Store for AUC
            all_logits.append(logits.detach().cpu())
            all_labels.append(y.detach().cpu())

    # Compute average loss
    avg_loss = total_loss / len(loader)

    # Compute AUC → concatenate everything
    all_logits = torch.cat(all_logits)
    all_labels = torch.cat(all_labels)

    probs = torch.sigmoid(all_logits).numpy()
    labels = all_labels.numpy()

    auc = roc_auc_score(labels, probs)

    return avg_loss, auc

In [19]:
from tensorboardX import SummaryWriter

writer = SummaryWriter(log_dir="runs/connections_experiment")

EPOCHS = 100

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = ConnectionsGNN(
    in_dim=384,
    num_layers=2,
    hidden_size=1024,
    output_size=1,
    attn_heads=[8, 8, 4, 4, 2, 2],
    agg_heads=2,
    dropout=0.2,
    edge_dim=1 
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10)

loss = nn.BCEWithLogitsLoss()

for epoch in range(EPOCHS):
    model.train()
    train_loss = 0

    for batch in train_loader:
        x = batch.x.to(device)
        edge_index = batch.edge_index.to(device)
        edge_attr = batch.edge_attr.to(device)
        group_indices = batch.group_indices.to(device)
        y = batch.group_labels.to(device)

        logits = model(x, edge_index, edge_attr, group_indices)

        batch_loss = loss(logits, y)
        train_loss += batch_loss.item()

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

    train_loss /= len(train_loader)

    val_loss, val_auc = evaluate(model, val_loader, loss, device)
    test_loss, test_auc = evaluate(model, test_loader, loss, device)

    scheduler.step(val_loss)

    if epoch % 5 == 0:
        # --- Tensorboard logging ---
        writer.add_scalar("loss/train", train_loss, epoch)
        writer.add_scalar("loss/val", val_loss, epoch)
        writer.add_scalar("loss/test", test_loss, epoch)

        writer.add_scalar("auc/val", val_auc, epoch)
        writer.add_scalar("auc/test", test_auc, epoch)

        current_lr = optimizer.param_groups[0]['lr']
        writer.add_scalar("lr", current_lr, epoch)

        for name, param in model.named_parameters():
            writer.add_histogram(f"param_hist/{name}", param, epoch)

            if param.grad is not None:
                writer.add_histogram(f"grad_hist/{name}", param.grad, epoch)

        print(f"Epoch {epoch} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val AUC: {val_auc:.4f} | Test Loss: {test_loss:.4f} | Test AUC: {test_auc:.4f}")

Epoch 0 | Train Loss: 0.6323 | Val Loss: 0.5645 | Val AUC: 0.4869 | Test Loss: 0.5645 | Test AUC: 0.4702
Epoch 5 | Train Loss: 0.5636 | Val Loss: 0.5624 | Val AUC: 0.4610 | Test Loss: 0.5624 | Test AUC: 0.4662
Epoch 10 | Train Loss: 0.5629 | Val Loss: 0.5624 | Val AUC: 0.4605 | Test Loss: 0.5624 | Test AUC: 0.4685
Epoch 15 | Train Loss: 0.5633 | Val Loss: 0.5624 | Val AUC: 0.4627 | Test Loss: 0.5624 | Test AUC: 0.4680
Epoch 20 | Train Loss: 0.5632 | Val Loss: 0.5624 | Val AUC: 0.4648 | Test Loss: 0.5624 | Test AUC: 0.4684
Epoch 25 | Train Loss: 0.5620 | Val Loss: 0.5624 | Val AUC: 0.4674 | Test Loss: 0.5624 | Test AUC: 0.4727
Epoch 30 | Train Loss: 0.5628 | Val Loss: 0.5623 | Val AUC: 0.4663 | Test Loss: 0.5623 | Test AUC: 0.4727
Epoch 35 | Train Loss: 0.5637 | Val Loss: 0.5623 | Val AUC: 0.4656 | Test Loss: 0.5623 | Test AUC: 0.4733
Epoch 40 | Train Loss: 0.5645 | Val Loss: 0.5623 | Val AUC: 0.4659 | Test Loss: 0.5623 | Test AUC: 0.4735
Epoch 45 | Train Loss: 0.5632 | Val Loss: 0.5623

KeyboardInterrupt: 