Author: Will Blanton

# Imports

In [1]:
import ast
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import itertools

from sentence_transformers import SentenceTransformer

from gensim.models.fasttext import load_facebook_vectors
from sklearn.feature_extraction.text import CountVectorizer
from typing import List
from torch.utils.data import Dataset, DataLoader

In [2]:
from pytorchltr.loss import LambdaNDCGLoss2
from pytorchltr.evaluation import ndcg

# Constants

In [3]:
M = 1820

# Helper Functions

In [4]:
def extract_features(word):
    return pd.DataFrame({word:{
        'length': len(word),
        'num_vowels': sum(c in 'aeiou' for c in word.lower()),
        'num_consonants': sum(c.isalpha() and c not in 'aeiou' for c in word.lower()),
        'ends_with_ing': int(word.lower().endswith('ing')),
        'ends_with_ed': int(word.lower().endswith('ed')),
        'is_palindrome': int(word.lower() == word[::-1].lower())
    }})

def extract_structural_features(words):
    """
    Extract basic structural features for the given words.
    """
    return torch.tensor(pd.concat([extract_features(w) for w in words], axis=1).T.values) 

In [5]:

def construct_groups_and_labels():
    """
    Produce index groups with graded relevance scores for each combination of words.

    They're the same across all splits as long as word order is maintained. Model is order invariant, so this should be acceptable.
    """

    # consecutive groups of 4 are groups (0-3, 4-7, ...)
    groups = {
        "yellow":{0, 1, 2, 3}, 
        "green": {4, 5, 6, 7},
        "blue": {8, 9, 10, 11},
        "purple": {12, 13, 14, 15}
    }

    quads = list(itertools.combinations(range(16), 4))

    # produce graded relevance scores based on the max number of items in a group ([0,3] -> 4)
    rel_scores = [max([len(set(q).intersection(g)) for _, g in groups.items()]) for q in quads]

    return torch.tensor(quads), torch.tensor(rel_scores)

In [6]:
from torch.utils.data._utils.collate import default_collate


def collate_keep_wordlists(batch):
    # batch is a list of dicts
    keys = batch[0].keys()
    out = {}
    for k in keys:
        vals = [b[k] for b in batch]
        if k == "word_list":
            # keep as list-of-lists grouped by sample: [ [w1..w16], ... ] (len = B)
            out[k] = vals
        else:
            out[k] = default_collate(vals)
    return out

# Create Dataset

In [7]:
random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7fff927f7b10>

In [8]:
print(f"Cuda available: {torch.cuda.is_available()}")

Cuda available: True


In [9]:
# TODO: try prompt with sentence transformer
word_embedder = SentenceTransformer('all-MiniLM-L6-v2')

In [10]:
test_emb = word_embedder.encode("Hello")

In [11]:
D = len(test_emb)
D

384

In [12]:
# ft = load_facebook_vectors('pretrained/cc.en.300.bin')

In [13]:
class FastTextEmbedder:
    """
    Wrap gensim embeddings for simpler conversion to tensors and to fit with sentence-transformers api for easy swapping
    """
    def __init__(self, ft_model):
        self.ft = ft_model
        self.dim = ft_model.vector_size

    def encode(self, word_list, convert_to_tensor=True):
        vectors = []
        for word in word_list:
            vectors.append(self.ft[word])

        vectors = np.stack(vectors)

        if convert_to_tensor:
            return torch.tensor(vectors, dtype=torch.float)
        else:
            return vectors

In [14]:
# word_embedder = FastTextEmbedder(ft)

In [15]:
connections_df = pd.read_csv('data/connections.csv', index_col=0)
# connections_df.drop(columns='category', inplace=True)

# fix issues with incorrect format... (not worth going in and adding a mechanism to correct...)
connections_df.loc[1298, "connections"] = "['line', 'plane', 'point', 'solid']"
connections_df.loc[1892, "connections"] = "['abyss', 'fly', 'matrix', 'thing']"

# remove april fools samples since they include emojis or other potentially noisy samples
connections_df = connections_df[~connections_df['date'].str.contains("04-01")]
connections_df.reset_index(drop=True, inplace=True)

connections_df['connections'] = connections_df['connections'].apply(ast.literal_eval)
connections_df.head()

Unnamed: 0,date,category,connections
0,2023-06-12,palindromes,"[kayak, level, mom, race car]"
1,2023-06-12,keyboard keys,"[option, return, shift, tab]"
2,2023-06-12,nba teams,"[bucks, heat, jazz, nets]"
3,2023-06-12,wet weather,"[hail, rain, sleet, snow]"
4,2023-06-13,letter homophones,"[are, queue, sea, why]"


In [16]:
all_words = list(itertools.chain.from_iterable(
    connections_df
    .groupby('date')['connections']
    .agg(lambda lists: list(itertools.chain.from_iterable(lists)))
    .values
))
all_words[:10]

['kayak',
 'level',
 'mom',
 'race car',
 'option',
 'return',
 'shift',
 'tab',
 'bucks',
 'heat']

In [17]:
# [w for w in all_words if ('_' in w) or ('-' in w)]

In [18]:
vectorizer = CountVectorizer(analyzer='char', ngram_range=(1, 3))
vectorizer.fit(all_words)

0,1,2
,input,'content'
,encoding,'utf-8'
,decode_error,'strict'
,strip_accents,
,lowercase,True
,preprocessor,
,tokenizer,
,stop_words,
,token_pattern,'(?u)\\b\\w\\w+\\b'
,ngram_range,"(1, ...)"


In [19]:
"""

TODO: add more features to handle complicated cases
- phonetic word embeddings
- character-level embeddings (either trained via RNN or pre-trained)
- n-gram?
"""
class ConnectionsDataset(Dataset):
    def __init__(
        self,
        puzzle_df: pd.DataFrame,
        word_emb_model,
        char_vectorizer,
    ):
        super().__init__()
        self.puzzle_df = puzzle_df
        self.puzzle_list = []

        # get lists of words for each puzzle 
        words_per_date = (
            self.puzzle_df
            .groupby('date')['connections']
            .agg(lambda lists: list(itertools.chain.from_iterable(lists)))
        )

        # create data object for each puzzle
        for date, word_list in words_per_date.items():
            if len(word_list) != 16:
                raise ValueError(f'Word list length {len(word_list)} does not match 16 words')

            # create node features 
            x = word_emb_model.encode(word_list, convert_to_tensor=True).cpu()

            # don't need labels, as all labels are the same...
            puzzle = {
                "x": x,
                "x_char": torch.tensor(char_vectorizer.transform(word_list).toarray()),
                # "x_struct": extract_structural_features(word_list),
                "word_list": word_list
            } 

            self.puzzle_list.append(puzzle)
 
    def __len__(self):
        return len(self.puzzle_list)

    def __getitem__(self, idx):
        return self.puzzle_list[idx]    

In [20]:
idx, y = construct_groups_and_labels()

In [21]:
idx[:10]

tensor([[ 0,  1,  2,  3],
        [ 0,  1,  2,  4],
        [ 0,  1,  2,  5],
        [ 0,  1,  2,  6],
        [ 0,  1,  2,  7],
        [ 0,  1,  2,  8],
        [ 0,  1,  2,  9],
        [ 0,  1,  2, 10],
        [ 0,  1,  2, 11],
        [ 0,  1,  2, 12]])

In [22]:
idx.unique()

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])

In [23]:
y[:10]

tensor([4, 3, 3, 3, 3, 3, 3, 3, 3, 3])

In [24]:
y.unique()

tensor([1, 2, 3, 4])

In [25]:
# split data
dates = connections_df["date"].unique()
train_idx = random.sample(range(len(dates)), k=int(len(dates) * .8))
train_dates = dates[train_idx]
test_dates = dates[~np.isin(range(len(dates)), train_idx)]

test_dates = np.random.permutation(test_dates)
split_idx = len(test_dates) // 2
val_dates = test_dates[split_idx:]
test_dates = test_dates[:split_idx]

print(f"Total Length: {len(dates)}")
print(f"Train length: {len(train_dates)}")
print(f"Val length: {len(val_dates)}")
print(f"Test length: {len(test_dates)}")

Total Length: 853
Train length: 682
Val length: 86
Test length: 85


In [26]:
BATCH_SIZE = 16

train = ConnectionsDataset(connections_df.loc[connections_df["date"].isin(train_dates)], word_embedder, vectorizer)
train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_keep_wordlists)

val = ConnectionsDataset(connections_df.loc[connections_df["date"].isin(val_dates)], word_embedder, vectorizer)
val_loader = DataLoader(val, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_keep_wordlists)

test = ConnectionsDataset(connections_df.loc[connections_df["date"].isin(test_dates)], word_embedder, vectorizer)
test_loader = DataLoader(test, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_keep_wordlists)

In [27]:
idx[0]

tensor([0, 1, 2, 3])

In [28]:
idx.shape

torch.Size([1820, 4])

In [29]:
for batch in DataLoader(val, batch_size=2, shuffle=True, collate_fn=collate_keep_wordlists):
    print(batch["word_list"][0])

    x = batch["x"]
    print(f"x: {x.shape}")
    print(f"x type: {x.dtype}")

    
    x_char = batch["x_char"]
    print(f"x_char: {x_char.shape}")

    idx_flat = idx.reshape(-1)  # (1820*4,)
    group_embeds = torch.index_select(x, dim=1, index=idx_flat).view(x.shape[0], 1820, 4, x.shape[2])
    print(group_embeds.shape)
    print(group_embeds[0][0])
    
    break

['are', 'radius', 'reverse', 'right', 'babe', 'fox', 'snack', 'ten', 'bear', 'generate', 'produce', 'yield', 'aside', 'detour', 'digression', 'tangent']
x: torch.Size([2, 16, 384])
x type: torch.float32
x_char: torch.Size([2, 16, 4800])
torch.Size([2, 1820, 4, 384])
tensor([[-0.0079,  0.0449,  0.0086,  ..., -0.0353,  0.1046, -0.0347],
        [ 0.0729,  0.0323, -0.1690,  ...,  0.0757,  0.0324,  0.0148],
        [-0.0048,  0.0567,  0.0470,  ...,  0.0278,  0.0530, -0.0354],
        [-0.0549,  0.0476, -0.0326,  ...,  0.0379,  0.0735,  0.0799]])


In [30]:
idx.shape

torch.Size([1820, 4])

# Model Architectures

## Components

In [31]:
class MultiheadAttentionBlock(torch.nn.Module):
    def __init__(
            self,
            in_dim,
            out_dim,
            attn_heads=8
        ):
        super().__init__()

        # following the notation from the set transformer paper
        self.attn = torch.nn.MultiheadAttention(in_dim, num_heads=attn_heads, batch_first=True, dropout=.3)
        self.h_norm = torch.nn.LayerNorm(in_dim)
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(in_dim, out_dim),
            torch.nn.ReLU()
        )
        self.mab_norm = torch.nn.LayerNorm(in_dim)
 
    def forward(self, x, y):
        # H: self-attention with residual
        x = self.h_norm(x + self.attn(x, y, y)[0])

        # MAB 
        return self.mab_norm(x + self.ff(x))

In [32]:
class SetEncoder(nn.Module):
    def __init__(
        self,
        in_dim: int,
        out_dims: List[int],
        num_mabs: int,
        attn_heads: List[int],
    ):
        super().__init__()
        assert num_mabs == len(out_dims) == len(attn_heads), "num_mabs, out_dims, attn_heads must match in length"

        mabs = []
        cur_in = in_dim
        for i in range(num_mabs):
            mabs.append(MultiheadAttentionBlock(
                in_dim=cur_in,
                out_dim=out_dims[i],
                attn_heads=attn_heads[i],
            ))
            cur_in = out_dims[i]
        self.mabs = nn.ModuleList(mabs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for mab in self.mabs:
            x = mab(x, x)
        return x

In [33]:
class SetDecoder(nn.Module):
    def __init__(
            self,
            in_dim,
            h1,
            out_dim,
            num_seeds,
            h2 = None,
            mab_attn_heads=8,
            sab_attn_heads=8,
        ):
        """
            Args:
                k: number of seed vectors
        """
        super().__init__()

        self.num_seeds = num_seeds

        # TODO: use Xavier intiializations?
        self.seed_vectors = nn.Parameter(torch.randn(num_seeds, in_dim))        
        
        self.ff1 = torch.nn.Sequential(
            torch.nn.Linear(in_dim, h1),
            torch.nn.ReLU()
        )
        
        self.mab = MultiheadAttentionBlock(
                in_dim=h1,
                out_dim=h2 if num_seeds > 1 else out_dim,
                attn_heads=mab_attn_heads,
        )

        # only need to use the set encoder when more than 1 vector is produced (model interactions between)
        if num_seeds > 1:
            self.sab = SetEncoder(h2, [h2], [sab_attn_heads])

            self.ff2 = torch.nn.Sequential(
                torch.nn.Linear(h2, out_dim),
                torch.nn.ReLU()
            )

    def forward(self, x):

        x = self.ff1(x)
        x = self.mab(self.source_vectors, x)

        if self.num_seeds > 1:
            x = self.sab(x)
            x = self.ff2(x)

        return x

In [34]:
class MLPScorer(nn.Module):
    def __init__(self, in_dim: int, hidden_dims=[256, 128], dropout: float = 0.1):
        """
        Generic configurable MLP scorer.

        Args:
            input_dim:  size of input feature vector (e.g., 4*D)
            hidden_dims: list/tuple of hidden layer sizes
            output_dim: size of output (1 for scalar score)
            dropout: dropout probability between hidden layers
        """
        super().__init__()
        layers = []
        prev_dim = in_dim

        for h in hidden_dims:
            layers.append(nn.Linear(prev_dim, h))
            layers.append(nn.LayerNorm(h))
            layers.append(nn.ReLU())
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            prev_dim = h

        layers.append(nn.Linear(prev_dim, 1))
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        # x: (batch, input_dim)
        return self.mlp(x)

In [35]:
class Grouper(nn.Module):
    def __init__(self, group_idx: torch.Tensor):
        super().__init__()
        # Ensure long dtype and register as a buffer so .to(device)/.cuda() moves it
        self.num_groups, self.group_size = group_idx.shape
        self.register_buffer("idx", group_idx.to(dtype=torch.long), persistent=False)

    def forward(self, x):  #
        """
        Group items in x into all combinations of 4 items
        """
        B, _, D = x.shape
        idx_flat = self.idx.reshape(-1)                      # (1820*4,)
        out = torch.index_select(x, dim=1, index=idx_flat)   # (B, 1820*4, D)
        return out.view(B, self.num_groups, self.group_size, D)

## Baseline (MLP) - No Global Context

In [36]:
# start with a simple model and build our way up
class Baseline(nn.Module):
    def __init__(self, in_dim, group_idx, layers=[512, 256, 256]):
        super().__init__()

        self.proj = nn.Sequential(
            nn.Linear(in_dim, 256),
            nn.ReLU()
        )

        self.grouper = Grouper(group_idx)
        self.scorer = MLPScorer(256 * self.grouper.group_size, hidden_dims=layers, dropout=.3)

    def forward(self, x):

        x = self.proj(x)
        x = self.grouper(x) # (B, 1820, 4, D)

        # concatenate for scoring
        x = x.flatten(2, 3) # (B, 1820, D * 4)

        return self.scorer(x)  # (B, 1820, 1)
 

## Attention Model

In [53]:

# start with a simple model and build our way up
class AttentionModel(nn.Module):
    def __init__(self, in_dim, group_idx, layers=[512, 256, 256], attn_layers=1):
        super().__init__()

        self.proj = nn.Sequential(
            nn.Linear(in_dim, 256),
            nn.ReLU()
        )

        self.grouper = Grouper(group_idx)
        self.encoder = SetEncoder(
            256,
            [256] * attn_layers,
            attn_layers,
            [8] * attn_layers
        )
        self.scorer = MLPScorer(256 * self.grouper.group_size, hidden_dims=layers, dropout=.3)

    def forward(self, x):

        x = self.proj(x)
        x = self.encoder(x)
        x = self.grouper(x) # (B, 1820, 4, D)

        # concatenate for scoring
        x = x.flatten(2, 3) # (B, 1820, D * 4)

        return self.scorer(x)  # (B, 1820, 1)

## Set Transformer

In [38]:

# class SetTransformer(torch.nn.Module):
#     def __init__(
#             self,
#             in_dim,
#         ):
#         super().__init__() 

        
#     def forward(self, x):
 
#         # combine input features into a single representation

#         # encoder layers

#         # decode into cluster centers


In [39]:
y

tensor([4, 3, 3,  ..., 3, 3, 4])

In [40]:
y.shape

torch.Size([1820])

# Train Model

In [41]:
def topk_overlap(pred_scores, relevance, k=4):
    """
    Computes fraction of relevant items appearing in the model's top-k predictions.
    Args:
        pred_scores (torch.Tensor): (B, N) predicted scores
        relevance (torch.Tensor): (B, N) binary relevance labels (0 or 1)
        k (int): number of top items to consider
    Returns:
        torch.Tensor: mean overlap fraction across batch
    """
    _, pred_idx = torch.topk(pred_scores, k, dim=1)
    _, true_idx = torch.topk(relevance, k, dim=1)

    overlap = (pred_idx.unsqueeze(2) == true_idx.unsqueeze(1)).any(dim=2).float().sum(dim=1)
    return (overlap / k).mean()

In [42]:
def evaluate(model, loader, loss_fn, n_scalar, device, k=4):
    model.eval()
    
    total_loss = 0
    total_ndcg = 0
    total_top4_acc = 0

    with torch.no_grad():
        for batch in loader:
            x = batch["x"].to(device)
            
            B, _, D = x.shape
 
            scores = model(x)

            # expand to use in batch
            relevance = y.unsqueeze(0).expand(B, len(y)) # (B, 1820)

            n = n_scalar.expand(B)

            total_loss += loss_fn(scores, relevance, n).mean().item()
            total_ndcg += ndcg(scores, relevance, n, k=k).mean().item()
            total_top4_acc = topk_overlap(scores.squeeze(-1), relevance, k=4).item()

    # Compute average loss
    avg_loss = total_loss / len(loader)
    avg_ndcg = total_ndcg / len(loader)
    avg_top4_acc = total_top4_acc / len(loader)

    return avg_loss, avg_ndcg, avg_top4_acc

In [57]:

from tensorboardX import SummaryWriter
from datetime import datetime
def train(model, y, num_epochs, model_name, device, lr=1e-4):
    current_time = datetime.now().strftime("%Y-%m-%d_%H-%M")
    writer = SummaryWriter(log_dir=f"runs/{model_name}/connections_experiment_LTR-{current_time}")

    scaler = torch.amp.GradScaler(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5)
    y = y.to(device)
    n_scalar = torch.tensor(M, device=device, dtype=torch.long)

    loss = LambdaNDCGLoss2().to(device)

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        train_ndcg = 0
        train_top4_acc = 0

        for batch in train_loader:
            optimizer.zero_grad(set_to_none=True)
            x = batch["x"].to(device)
            
            B, _, D = x.shape
            
            # expand to use in batch
            relevance = y.unsqueeze(0).expand(B, len(y)) # (B, 1820)
            n = n_scalar.expand(B)

            with torch.amp.autocast(dtype=torch.float16, device_type=str(device)):
                scores = model(x)
                batch_loss = loss(scores, relevance, n).mean()

            scaler.scale(batch_loss).backward()
            # TODO: check if gradient clipping is needed
            scaler.step(optimizer)
            scaler.update()

            # eval train
            with torch.no_grad():
                train_ndcg += ndcg(scores, relevance, n, k=4).mean().item()
                train_top4_acc += topk_overlap(scores.squeeze(-1), relevance, k=4).item()
                
            train_loss += batch_loss.item()

        train_loss /= len(train_loader)
        train_ndcg /= len(train_loader)
        train_top4_acc /= len(train_loader)

        val_loss, val_ndcg, val_top4_acc = evaluate(model, val_loader, loss, n_scalar, device)
        test_loss, test_ndcg, test_top4_acc = evaluate(model, test_loader, loss, n_scalar, device)

        scheduler.step(val_loss)

        if epoch % 5 == 0 or epoch == num_epochs - 1:
            # --- Tensorboard logging ---
            writer.add_scalar("loss/train", train_loss, epoch)
            writer.add_scalar("loss/val", val_loss, epoch)
            writer.add_scalar("loss/test", test_loss, epoch)

            writer.add_scalar("ndcg/train", train_ndcg, epoch)
            writer.add_scalar("ndcg/val", val_ndcg, epoch)
            writer.add_scalar("ndcg/test", test_ndcg, epoch)

            writer.add_scalar("top4_acc/train", train_top4_acc, epoch)
            writer.add_scalar("top4_acc/val", val_top4_acc, epoch)
            writer.add_scalar("top4_acc/test", test_top4_acc, epoch)

            current_lr = optimizer.param_groups[0]['lr']
            writer.add_scalar("lr", current_lr, epoch)

            for name, param in model.named_parameters():
                writer.add_histogram(f"param_hist/{name}", param, epoch)

                if param.grad is not None:
                    writer.add_histogram(f"grad_hist/{name}", param.grad, epoch)

            print(
                f"Epoch {epoch} | "
                f"Train Loss: {train_loss:.4f} | "
                f"Train NDCG: {train_ndcg:.4f} | "
                f"Train Top-4 Acc: {train_top4_acc:.4f} | "
                f"Val Loss: {val_loss:.4f} | "
                f"Val NDCG: {val_ndcg:.4f} | "
                f"Val Top-4 Acc: {val_top4_acc:.4f} | "
                f"Test Loss: {test_loss:.4f} | "
                f"Test NDCG: {test_ndcg:.4f} | "
                f"Test Top-4 Acc: {test_top4_acc:.4f}"
            )
    
    return model.to('cpu')

In [None]:

EPOCHS = 100
M = 1820

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Baseline Model

In [58]:
model = Baseline(
    in_dim=D,
    group_idx=idx,
    layers=[1024, 512, 256]
).to(device)

_ = train(model, y, num_epochs=100, model_name="baseline", device=device, lr=1e-4)



Epoch 0 | Train Loss: 1.7046 | Train NDCG: 0.2236 | Train Top-4 Acc: 0.0040 | Val Loss: 1.6662 | Val NDCG: 0.2492 | Val Top-4 Acc: 0.0000 | Test Loss: 1.6666 | Test NDCG: 0.2714 | Test Top-4 Acc: 0.0000
Epoch 5 | Train Loss: 1.5732 | Train NDCG: 0.3479 | Train Top-4 Acc: 0.0247 | Val Loss: 1.6518 | Val NDCG: 0.2972 | Val Top-4 Acc: 0.0000 | Test Loss: 1.6577 | Test NDCG: 0.2925 | Test Top-4 Acc: 0.0000
Epoch 10 | Train Loss: 1.3414 | Train NDCG: 0.4478 | Train Top-4 Acc: 0.0563 | Val Loss: 1.6419 | Val NDCG: 0.3171 | Val Top-4 Acc: 0.0069 | Test Loss: 1.6646 | Test NDCG: 0.2834 | Test Top-4 Acc: 0.0000
Epoch 15 | Train Loss: 1.1265 | Train NDCG: 0.4860 | Train Top-4 Acc: 0.0674 | Val Loss: 1.6447 | Val NDCG: 0.3107 | Val Top-4 Acc: 0.0069 | Test Loss: 1.6571 | Test NDCG: 0.2903 | Test Top-4 Acc: 0.0000
Epoch 20 | Train Loss: 0.9425 | Train NDCG: 0.5072 | Train Top-4 Acc: 0.0802 | Val Loss: 1.6524 | Val NDCG: 0.3027 | Val Top-4 Acc: 0.0000 | Test Loss: 1.6658 | Test NDCG: 0.2830 | Test 

## Attention Model

In [64]:

model = AttentionModel(
    in_dim=D,
    group_idx=idx,
    layers=[1024, 512, 256],
    attn_layers=5
).to(device)

_ = train(model, y, num_epochs=300, model_name="attention", device=device, lr=5e-5)

Epoch 0 | Train Loss: 1.7061 | Train NDCG: 0.2185 | Train Top-4 Acc: 0.0025 | Val Loss: 1.6727 | Val NDCG: 0.2641 | Val Top-4 Acc: 0.0000 | Test Loss: 1.6701 | Test NDCG: 0.2638 | Test Top-4 Acc: 0.0000
Epoch 5 | Train Loss: 1.6053 | Train NDCG: 0.3204 | Train Top-4 Acc: 0.0185 | Val Loss: 1.6515 | Val NDCG: 0.2862 | Val Top-4 Acc: 0.0069 | Test Loss: 1.6670 | Test NDCG: 0.2924 | Test Top-4 Acc: 0.0167
Epoch 10 | Train Loss: 1.4271 | Train NDCG: 0.4082 | Train Top-4 Acc: 0.0338 | Val Loss: 1.6475 | Val NDCG: 0.3116 | Val Top-4 Acc: 0.0000 | Test Loss: 1.6550 | Test NDCG: 0.3100 | Test Top-4 Acc: 0.0000
Epoch 15 | Train Loss: 1.2325 | Train NDCG: 0.4558 | Train Top-4 Acc: 0.0413 | Val Loss: 1.6395 | Val NDCG: 0.3190 | Val Top-4 Acc: 0.0000 | Test Loss: 1.6495 | Test NDCG: 0.3093 | Test Top-4 Acc: 0.0083
Epoch 20 | Train Loss: 1.0701 | Train NDCG: 0.4774 | Train Top-4 Acc: 0.0504 | Val Loss: 1.6415 | Val NDCG: 0.2942 | Val Top-4 Acc: 0.0000 | Test Loss: 1.6513 | Test NDCG: 0.3040 | Test 

## Set Transformer