Author: Will Blanton

# Imports

In [33]:
import pandas as pd
import torch
import random
import itertools

from sentence_transformers import SentenceTransformer
from torch.utils.data import Dataset
from torch_geometric import edge_index
from torch_geometric.data import Data

# Helper Functions

# Create Dataset

In [32]:
# Report whether a CUDA device is visible to torch in this kernel.
print("Cuda available:", torch.cuda.is_available())

Cuda available: True


In [34]:
# Sentence encoder whose embeddings become the node features of each puzzle graph.
EMBEDDING_MODEL_NAME = 'all-MiniLM-L6-v2'
model = SentenceTransformer(EMBEDDING_MODEL_NAME)

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


README.md:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

Unnamed: 0,date,connections
2636,2025-04-01,"[, , ,]"
2637,2025-04-01,"[, , ,]"
2638,2025-04-01,"[, , ,]"
2639,2025-04-01,"[, , ,]"


In [98]:
import ast

# Hand-corrected `connections` strings for rows that are malformed in the source CSV
# (not worth building a general repair mechanism for two rows). Keys are labels of the
# CSV's own index column (index_col=0), so the patch must run before any re-indexing.
MANUAL_FIXES = {
    1298: "['line', 'plane', 'point', 'solid']",
    1892: "['abyss', 'fly', 'matrix', 'thing']",
}

# The per-group `category` column is not used downstream.
example_df = pd.read_csv('data/connections.csv', index_col=0).drop(columns='category')

# `.loc[label, col] = value` silently appends a new row when `label` is missing,
# so verify every patched label exists before writing.
missing_labels = set(MANUAL_FIXES) - set(example_df.index)
assert not missing_labels, f"manual-fix labels not found in CSV: {sorted(missing_labels)}"
for label, fixed_connections in MANUAL_FIXES.items():
    example_df.loc[label, "connections"] = fixed_connections

# Remove April Fools puzzles: they include emojis or other potentially noisy samples.
# Dates are YYYY-MM-DD strings, so anchoring on the "-04-01" suffix matches exactly
# April 1st, unlike a bare substring search (which would also hit e.g. "2004-01-...").
is_april_fools = example_df['date'].str.endswith('-04-01')
example_df = example_df[~is_april_fools].reset_index(drop=True)

# Each `connections` cell is a stringified Python list; ast.literal_eval only parses
# literals, so it never executes code coming from the CSV.
example_df['connections'] = example_df['connections'].apply(ast.literal_eval)
example_df

Unnamed: 0,date,connections
0,2023-06-12,"[kayak, level, mom, race car]"
1,2023-06-12,"[option, return, shift, tab]"
2,2023-06-12,"[bucks, heat, jazz, nets]"
3,2023-06-12,"[hail, rain, sleet, snow]"
4,2023-06-13,"[are, queue, sea, why]"
...,...,...
2751,2025-05-01,"[pot, prize, purse, reward]"
2752,2025-05-02,"[bottle, break, goose, turtle]"
2753,2025-05-02,"[dog, link, rib, wing]"
2754,2025-05-02,"[brace, post, prop, support]"


In [99]:
# Flatten each date's four group lists into one list holding all of that puzzle's words.
example_df.groupby('date').agg(
    connections=('connections', lambda groups: [word for group in groups for word in group])
)

Unnamed: 0_level_0,connections
date,Unnamed: 1_level_1
2023-06-12,"[kayak, level, mom, race car, option, return, ..."
2023-06-13,"[are, queue, sea, why, essence, people, time, ..."
2023-06-14,"[amigo, king, stooge, tenor, lab, peke, pit, p..."
2023-06-15,"[bat, iron, spider, super, dust, mop, sweep, v..."
2023-06-16,"[green, mustard, plum, scarlet, blue, down, gl..."
...,...
2025-04-28,"[bore, drain, exhaust, tire, fiber, fingerprin..."
2025-04-29,"[gemstone, infield, rhombus, suit, ladder, mou..."
2025-04-30,"[dynasty, engross, gimmick, mildew, face, imag..."
2025-05-01,"[size"" to mean small - bite, fun, pocket, trav..."


In [100]:
# Inspect the four groups of a single puzzle.
example_df.loc[example_df['date'].eq('2023-06-21')].reset_index(drop=True)

Unnamed: 0,date,connections
0,2023-06-21,"[fish, goat, scales, twins]"
1,2023-06-21,"[crane, jay, swallow, turkey]"
2,2023-06-21,"[chad, georgia, jordan, togo]"
3,2023-06-21,"[date, kiwi, lemon, orange]"


In [112]:
import itertools
import random
import torch
from torch_geometric.data import Data
from torch.nn.functional import cosine_similarity
from torch.utils.data import Dataset

class ConnectionsGraphDataset(Dataset):
    """One fully connected word graph per Connections puzzle (one per ``date``).

    For every distinct ``date`` in ``puzzle_df`` the ``connections`` rows of that
    date are concatenated into a 16-word list. Each word is embedded with
    ``word_emb_model`` (node features ``x``), and every ordered pair of distinct
    nodes is joined by an edge whose single feature is the cosine similarity of
    the two embeddings. Each graph also carries candidate 4-word groups
    (``group_indices``) with binary labels (``group_labels``):

    - The true groups are positives.
    - Randomly sampled 4-node combinations that match no true group are negatives.

    Args:
        puzzle_df: one row per solved group, with a ``date`` column and a
            ``connections`` column holding a list of words.
        word_emb_model: sentence encoder used to embed each word.
        negative_ratio: number of negatives sampled per kept positive group.
        include_purple: when False, the group at ``purple_position`` is removed
            from the positives. It is still excluded from negative sampling.
        seed: optional seed for negative sampling. ``None`` (default) uses the
            global ``random`` module state, so ``random.seed(...)`` keeps working.
        purple_position: row position, within each date's rows, of the group
            treated as "purple". The default (-1) is the last row.
            NOTE(review): confirm against the CSV ordering. In the rows shown
            for 2023-06-12 the palindrome group (kayak, level, mom, race car)
            comes first, which may mean the hardest group sits at position 0.

    Raises:
        ValueError: if a date does not contribute exactly 16 words.

    TODO: add more features to handle complicated cases
    - phonetic word embeddings
    - character-level embeddings (either trained via RNN or pre-trained)
    - n-gram?
    """

    def __init__(
        self,
        puzzle_df: pd.DataFrame,
        word_emb_model: SentenceTransformer,
        negative_ratio: int = 3,
        include_purple: bool = False,
        seed: "int | None" = None,
        purple_position: int = -1,
    ):
        super().__init__()
        self.puzzle_df = puzzle_df
        self.negative_ratio = negative_ratio
        self.include_purple = include_purple
        self.purple_position = purple_position
        # A private Random instance makes sampling reproducible when seeded; it is
        # picklable (unlike the `random` module), so DataLoader workers can copy it.
        self._rng = random.Random(seed) if seed is not None else None
        self.data_list = []

        # date -> list of that date's word groups, in row order. Grouping once here
        # avoids re-filtering the whole frame for every puzzle.
        self._groups_per_date = {
            date: list(rows['connections'])
            for date, rows in self.puzzle_df.groupby('date')
        }

        # create data object for each puzzle
        for date, groups in self._groups_per_date.items():
            word_list = list(itertools.chain.from_iterable(groups))
            if len(word_list) != 16:
                raise ValueError(f'Word list length {len(word_list)} does not match 16 words')
            data = self._build_single_graph(date, word_list, word_emb_model)
            self.data_list.append(data)

    def _build_single_graph(self, date, word_list, word_emb_model):
        """Build the graph object for the puzzle of ``date`` from its ``word_list``."""
        # Node features: one embedding row per word, in word_list order.
        x = word_emb_model.encode(word_list, convert_to_tensor=True)
        # A word repeated within one puzzle would keep only its last position here.
        word2idx = {w: i for i, w in enumerate(word_list)}

        edge_index, edge_attr = self._make_fully_connected(x)

        positives = self._collect_positives(date, word2idx)
        # Negatives are sampled against *all* true groups (purple included), so the
        # purple group can never appear as a negative even when it is dropped below.
        negatives = self._sample_negatives(len(word_list), positives)

        if not self.include_purple:
            del positives[self.purple_position]

        group_indices = torch.tensor(positives + negatives, dtype=torch.long)
        group_labels = torch.tensor(
            [1] * len(positives) + [0] * len(negatives),
            dtype=torch.float,
        )

        # ConnectionsGraphData is defined right after this class in the same cell;
        # it is only looked up when a graph is built.
        data = ConnectionsGraphData(
            x=x,                          # (16, embed_dim)
            edge_index=edge_index,        # (2, num_edges)
            edge_attr=edge_attr,          # (num_edges, 1)
            group_indices=group_indices,  # (num_groups, 4)
            group_labels=group_labels,    # (num_groups,)
        )
        data.word_list = word_list
        return data

    def _make_fully_connected(self, x):
        """Return a complete directed graph over the rows of ``x``.

        Every ordered pair (i, j) with i != j appears exactly once, so n nodes give
        n * (n - 1) edges. Each edge carries the cosine similarity of its endpoint
        embeddings as its single feature.

        Returns:
            edge_index: LongTensor of shape (2, n * (n - 1)).
            edge_attr: tensor of shape (n * (n - 1), 1).
        """
        num_nodes = x.size(0)
        # all i<j pairs; reshape keeps a (0, 2) shape when there are fewer than 2 nodes
        pairs = list(itertools.combinations(range(num_nodes), 2))
        edge_index = torch.tensor(pairs, dtype=torch.long).reshape(-1, 2).t().contiguous()

        # add reverse direction (j -> i)
        edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)

        # NOTE(review): edge_index is built on the CPU, while x (and so edge_attr)
        # stays on whatever device the encoder returned, presumably the GPU when CUDA
        # is available. Move whole graphs with `.to(device)` before feeding a model.
        src = x[edge_index[0]]
        dst = x[edge_index[1]]
        edge_attr = cosine_similarity(src, dst).unsqueeze(1)

        return edge_index, edge_attr

    def _collect_positives(self, date, word2idx):
        """Return each true group of ``date`` as an ascending list of node indices."""
        return [
            sorted(word2idx[w] for w in group)
            for group in self._groups_per_date[date]
        ]

    def _sample_negatives(self, num_nodes, positives):
        """Randomly pick 4-node groups that coincide with no true group.

        Returns up to ``negative_ratio`` negatives per kept positive group. It
        returns fewer only if the pool of non-positive 4-subsets runs out. Each
        negative is an ascending list of node indices.
        """
        positive_set = {tuple(g) for g in positives}
        all_quads = list(itertools.combinations(range(num_nodes), 4))
        shuffle = self._rng.shuffle if self._rng is not None else random.shuffle
        shuffle(all_quads)

        # When purple is excluded one positive is dropped later, so the negative
        # budget is scaled by the number of positives actually kept.
        kept_positives = len(positives) if self.include_purple else len(positives) - 1
        max_neg = self.negative_ratio * kept_positives

        negatives = []
        for quad in all_quads:
            if len(negatives) >= max_neg:
                break
            # combinations() yields ascending tuples, matching the sorted positives
            if quad not in positive_set:
                negatives.append(list(quad))
        return negatives

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        return self.data_list[idx]


class ConnectionsGraphData(Data):
    """``Data`` subclass that offsets ``group_indices`` when graphs are batched.

    PyG's default ``__inc__`` only shifts attributes whose name contains
    ``"index"`` by the node count of the preceding graphs. ``group_indices`` does
    not match, so without this override the group indices of a mini-batch would
    all point into the first graph's nodes.
    """

    def __inc__(self, key, value, *args, **kwargs):
        if key == 'group_indices':
            return self.num_nodes
        return super().__inc__(key, value, *args, **kwargs)

In [113]:
# Number of distinct puzzle dates (NaN dates are not counted).
example_df['date'].value_counts().size

689

In [115]:
# Build graphs for every puzzle and keep the first one (earliest date) to inspect.
data = ConnectionsGraphDataset(
    example_df,
    model,
    negative_ratio=3,
    include_purple=False,
)[0]
data

Data(x=[16, 384], edge_index=[2, 240], edge_attr=[240, 1], group_indices=[12, 4], group_labels=[12], word_list=[16])

In [118]:
# Labels of the candidate groups: 1 = true group, 0 = sampled negative.
data['group_labels']

tensor([1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [117]:
# The 16 words of this puzzle, in node-index order.
data['word_list']

['kayak',
 'level',
 'mom',
 'race car',
 'option',
 'return',
 'shift',
 'tab',
 'bucks',
 'heat',
 'jazz',
 'nets',
 'hail',
 'rain',
 'sleet',
 'snow']