# Trexquant Hangman Challenge Submission – Global Alpha Researcher – India – Dhruv Kumar

## Data Preprocessing

In [1215]:
import numpy as np
import pandas as pd
import torch

In [1216]:
# reading the provided training data
file_location = "words_250000_train.txt"
with open(file_location, "r") as file:
    data = file.read().splitlines()

data[:10]

['aaa',
 'aaaaaa',
 'aaas',
 'aachen',
 'aaee',
 'aag',
 'aahed',
 'aahs',
 'aal',
 'aalesund']

#### Data Augmentation

In [1900]:
# converting the words into all possible masked states and their best predicted letter as the target

from collections import Counter

state = [] # input
guess = [] # output

for word in data:
    word_len = len(word)
    current_state = ["_" for _ in word]
    count_dict = Counter(word).most_common()

    for letter in count_dict:
        state.append("".join(current_state))

        for idx in range(word_len):
            if word[idx] == letter[0]:
                current_state[idx] = letter[0]
        guess.append(letter[0])

state[90:100], guess[90:100]

TypeError: 'module' object is not iterable

In [1218]:
print("Total States:", len(state))
print("Total Guesses:", len(guess))

Total States: 1681209
Total Guesses: 1681209


In [1219]:
longest_word = max(state, key=len)
longest_word

'_____________________________'

In [1220]:
len(longest_word) # to set the max_sequence_ length

29

Removing duplicates

In [1221]:
dataset = [list(x) for x in zip(state, guess)]

Epoch 0:   0%|          | 0/14592 [11:49<?, ?it/s]
Epoch 0:   3%|▎         | 401/14592 [06:29<3:49:34,  1.03it/s, v_num=164, train_loss=0.161, train_accuracy=0.109]
Epoch 0:   6%|▌         | 850/14592 [03:55<1:03:32,  3.60it/s, v_num=165, train_loss=0.235, train_accuracy=0.0781]


In [1222]:
dataset[1230000][0], dataset[1230000][1]

('rushsylva__a', 'n')

In [1223]:
df = pd.DataFrame(dataset, columns=["states", "guess"])
df_unique = df.drop_duplicates(subset=['states'], keep='first')
df_unique

Unnamed: 0,states,guess
0,___,a
1,______,a
2,____,a
3,aaa_,s
5,aa____,c
...,...,...
1681200,zyz__y_,o
1681201,zyzo_y_,m
1681202,zyzomy_,s
1681205,zyzzy__,v


In [1224]:
#converting pandas dataframe back to list
dataset_unique = list(df_unique.itertuples(index=False, name=None))

In [1225]:
len(dataset_unique) # augmented data length = 1,037,610

1037610

In [1226]:
state = [dataset_unique[x][0] for x in range(len(dataset_unique))]
guess = [dataset_unique[x][1] for x in range(len(dataset_unique))]

Final Augmented Data

In [1227]:
len(state), len(guess)

(1037610, 1037610)

# Tokenization of every small case Letters in Alphabet

In [1776]:
#creating the vocab with all the unique characters
chars = sorted(list(set(guess)))
chars.insert(0,"_")
chars.insert(1,"<sos>")
chars.append("<PAD>")
vocab_size = len(chars)


In [1777]:
vocab_size, "".join(chars)

(29, '_<sos>abcdefghijklmnopqrstuvwxyz<PAD>')

In [1778]:
# Mapping from string to integer and vice verca for encoding and decoding sequences
str_to_int = {ch:i for i, ch in enumerate(chars)}
int_to_str = {i:ch for i, ch in enumerate(chars)}

def encode(word):
    return [str_to_int.get(char, str_to_int["_"]) for char in word]

def decode(embed):
    return "".join([int_to_str.get(i, int_to_str[0]) for i in embed])


In [1779]:
print(encode("hello_"))
print(decode(encode("hello_")))

[9, 6, 13, 13, 16, 0]
hello_


In [1780]:
int_to_str

{0: '_',
 1: '<sos>',
 2: 'a',
 3: 'b',
 4: 'c',
 5: 'd',
 6: 'e',
 7: 'f',
 8: 'g',
 9: 'h',
 10: 'i',
 11: 'j',
 12: 'k',
 13: 'l',
 14: 'm',
 15: 'n',
 16: 'o',
 17: 'p',
 18: 'q',
 19: 'r',
 20: 's',
 21: 't',
 22: 'u',
 23: 'v',
 24: 'w',
 25: 'x',
 26: 'y',
 27: 'z',
 28: '<PAD>'}

In [1781]:
import torch

# encoding the dataset using state(s), guess(g)
encoded_states = [torch.tensor(encode(word), dtype=torch.long) for word in state]
encoded_guesses = [torch.tensor(encode(letter), dtype=torch.long) for letter in guess]



In [1782]:
encoded_states[:10], encoded_guesses[:10]

([tensor([0, 0, 0]),
  tensor([0, 0, 0, 0, 0, 0]),
  tensor([0, 0, 0, 0]),
  tensor([2, 2, 2, 0]),
  tensor([2, 2, 0, 0, 0, 0]),
  tensor([2, 2, 4, 0, 0, 0]),
  tensor([2, 2, 4, 9, 0, 0]),
  tensor([2, 2, 4, 9, 6, 0]),
  tensor([2, 2, 0, 0]),
  tensor([2, 2, 0])],
 [tensor([2]),
  tensor([2]),
  tensor([2]),
  tensor([20]),
  tensor([4]),
  tensor([9]),
  tensor([6]),
  tensor([15]),
  tensor([6]),
  tensor([8])])

In [1783]:
from torch.nn.utils.rnn import pad_sequence

padded_encoded_states = pad_sequence(encoded_states, batch_first=True, padding_value=str_to_int["<PAD>"], padding_side="left")
padded_encoded_guesses = pad_sequence(encoded_guesses, batch_first=True, padding_value=str_to_int["<PAD>"])

In [1786]:
my_list = torch.tensor([1]*1037610)
my_list = my_list.unsqueeze(dim=1)

In [1787]:
print(padded_encoded_guesses[:])
print(padded_encoded_guesses[:].shape)
print(padded_encoded_states[0,:].shape)
print(padded_encoded_states[:,1:])
print(padded_encoded_states[0,:].shape)

decoder_type_states = torch.cat((padded_encoded_states[:,:], my_list), 1)
decoder_type_guess = torch.cat((decoder_type_states[:,1:], padded_encoded_guesses[:]),1)

tensor([[ 2],
        [ 2],
        [ 2],
        ...,
        [20],
        [23],
        [ 2]])
torch.Size([1037610, 1])
torch.Size([29])
tensor([[28, 28, 28,  ...,  0,  0,  0],
        [28, 28, 28,  ...,  0,  0,  0],
        [28, 28, 28,  ...,  0,  0,  0],
        ...,
        [28, 28, 28,  ..., 14, 26,  0],
        [28, 28, 28,  ..., 26,  0,  0],
        [28, 28, 28,  ..., 26, 23,  0]])
torch.Size([29])


In [1788]:
print("Input: \n", decoder_type_states[0], "input shape: ", decoder_type_states.shape)
print("Output: \n", decoder_type_guess, "input shape: ", decoder_type_guess.shape)

Input: 
 tensor([28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28,
        28, 28, 28, 28, 28, 28, 28, 28,  0,  0,  0,  1]) input shape:  torch.Size([1037610, 30])
Output: 
 tensor([[28, 28, 28,  ...,  0,  1,  2],
        [28, 28, 28,  ...,  0,  1,  2],
        [28, 28, 28,  ...,  0,  1,  2],
        ...,
        [28, 28, 28,  ...,  0,  1, 20],
        [28, 28, 28,  ...,  0,  1, 23],
        [28, 28, 28,  ...,  0,  1,  2]]) input shape:  torch.Size([1037610, 30])


In [1789]:
# splitting the data into train and test set
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(decoder_type_states, decoder_type_guess, train_size=0.9, random_state=42)

In [1790]:
len(X_train), len(y_train), len(X_test), len(y_test)

(933849, 933849, 103761, 103761)

In [1791]:
# creating a dataloader
from torch.utils.data import TensorDataset, DataLoader
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=True) 

# Transformer Architecture

In [1792]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy
import matplotlib.pyplot as plt

## Attention Class

In [1793]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        # Ensure that the model dimension (d_model) is divisible by the number of heads
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        # Initialize dimensions
        self.d_model = d_model # Model's dimension
        self.num_heads = num_heads # Number of attention heads
        self.d_k = d_model // num_heads # Dimension of each head's key, query, and value
        
        # Linear layers for transforming inputs
        self.W_q = nn.Linear(d_model, d_model) # Query transformation
        self.W_k = nn.Linear(d_model, d_model) # Key transformation
        self.W_v = nn.Linear(d_model, d_model) # Value transformation
        self.W_o = nn.Linear(d_model, d_model) # Output transformation
        
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # Calculate attention scores
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        # Apply mask if provided (useful for preventing attention to certain parts like padding)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        
        # Softmax is applied to obtain attention probabilities
        attn_probs = torch.softmax(attn_scores, dim=-1)
        
        # Multiply by values to obtain the final output
        output = torch.matmul(attn_probs, V)
        return output
        
    def split_heads(self, x):
        # Reshape the input to have num_heads for multi-head attention
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
        
    def combine_heads(self, x):
        # Combine the multiple heads back to original shape
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
        
    def forward(self, Q, K, V, mask=None):
        # Apply linear transformations and split heads
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        
        # Perform scaled dot-product attention
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        
        # Combine heads and apply output transformation
        output = self.W_o(self.combine_heads(attn_output))
        return output

## Position Wise Feed Forward Network Class

In [1794]:
class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

## Positional Encoding Class

In [1795]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()
        
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('pe', pe.unsqueeze(0))
        
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

## Encoder and Decoder Classes Declaration

In [1796]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask):
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x
    
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_output, src_mask, tgt_mask):
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x

## Encoder-Decoder Transformer Class Declaration

In [1797]:
class Transformer(nn.Module):
    def __init__(self, src_vocab_size=29, tgt_vocab_size=29, d_model=512, num_heads=8, num_layers=6, d_ff=2048, max_seq_length=30, dropout=0.3):
        super(Transformer, self).__init__()
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)

        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])

        self.fc = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def generate_mask(self, src, tgt):
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
        seq_length = tgt.size(1)
        nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool()
        tgt_mask = tgt_mask & nopeak_mask
        return src_mask, tgt_mask

    def forward(self, src, tgt):
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)

        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

        output = self.fc(dec_output)
        return output

## Transformer Model Instantiation

In [1927]:
src_vocab_size = 29
tgt_vocab_size = 29
d_model = 512
num_heads = 16
num_layers = 12
d_ff = 2048
max_seq_length = 30
dropout = 0.1

transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)


## Training

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Setting up loss and optimizer.
criterion = nn.CrossEntropyLoss(ignore_index=27)
optimizer = optim.Adam(transformer.parameters(), lr=0.00001, betas=(0.9, 0.98), eps=1e-9)

transformer.to(device)
transformer.train()

num_epochs = 1
train_losses = []
train_accuracies = []

for epoch in range(num_epochs):
    epoch_loss = 0.0
    epoch_accuracy = 0.0
    batch_count = 0
    for batch in train_dataloader:
        src, tgt = batch  
        src = src.long().to(device)  
        tgt = tgt.long().to(device)

        # batch_size = src.size(0)
        # tgt_input = torch.full((batch_size, src.size(1)), 1, dtype=torch.long, device=device)
        
        optimizer.zero_grad()
        # Teacher forcing: pass the full target into the decoder.
        output = transformer(src, tgt[:,:-1])  # output shape: [batch, 29, tgt_vocab_size]
        print(tgt[:,:-1].shape)
        
        # Computing loss only on the final timestep
        loss = criterion(output[:, -1, :], tgt[:, -1])
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        batch_count += 1

        # Compute batch accuracy based on the last token
        preds = output[:, -1, :].argmax(dim=1)
        batch_acc = (preds == tgt[:, -1]).float().mean().item()
        epoch_accuracy += batch_acc

        # Optionally print batch-level info for debugging.
        print(f"Batch {batch_count}: Loss = {loss.item():.4f}, Accuracy = {batch_acc:.4f}")
        print(f"Predicted tokens: {preds.tolist()}")
        print(f"Target tokens:    {tgt[:, -1].tolist()}")
    
    avg_loss = epoch_loss / batch_count
    avg_accuracy = epoch_accuracy / batch_count
    train_losses.append(avg_loss)
    train_accuracies.append(avg_accuracy)
    print(f"Epoch {epoch+1}/{num_epochs} -- Avg Loss: {avg_loss:.4f}, Avg Accuracy: {avg_accuracy:.4f}")


torch.Size([64, 29])
Batch 36: Loss = 3.0156, Accuracy = 0.0469
Predicted tokens: [6, 6, 6, 6, 13, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 13, 13, 13, 13, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 13, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 13, 6, 13, 16, 6, 6]
Target tokens:    [15, 10, 14, 20, 17, 5, 12, 5, 8, 4, 16, 16, 2, 9, 26, 10, 20, 10, 22, 13, 21, 4, 8, 6, 15, 19, 8, 3, 15, 4, 6, 10, 16, 16, 15, 16, 26, 26, 2, 16, 24, 17, 21, 6, 26, 14, 20, 2, 2, 2, 19, 2, 6, 13, 15, 20, 21, 19, 8, 10, 4, 13, 26, 16]
torch.Size([64, 29])
Batch 37: Loss = 2.9686, Accuracy = 0.0625
Predicted tokens: [6, 6, 6, 6, 6, 6, 6, 6, 6, 13, 6, 6, 6, 6, 6, 6, 6, 6, 6, 20, 6, 6, 6, 6, 6, 13, 6, 6, 6, 13, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 13, 6, 6, 6, 6, 6, 6, 13, 13, 6, 6, 6, 6, 6, 6, 6, 13, 6, 6]
Target tokens:    [6, 15, 6, 10, 19, 3, 16, 21, 22, 22, 2, 17, 20, 9, 15, 21, 19, 16, 5, 22, 6, 16, 19, 13, 19, 2, 2, 16, 13, 4, 4, 9, 9, 4, 15, 4, 9, 14, 10, 22, 24, 26, 2, 15, 13

## Testing

In [1925]:
# For example, define device if not already done:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transformer.eval()
# Set up loss and optimizer.
criterion = nn.CrossEntropyLoss(ignore_index=27)

test_losses = []
test_accuracies = []

start_symbol = 1

test_loss = 0.0
test_accuracy = 0.0
batch_count = 0

with torch.no_grad():
    for batch in test_dataloader:
        src, tgt = batch  
        src = src.long().to(device)  # Ensure integer tokens and move to device
        tgt = tgt.long().to(device)
        
        output = transformer(src, tgt[:,:-1])  # output shape: [batch, 30, tgt_vocab_size]
        
        # Compute loss only on the final timestep
        loss = criterion(output[:, -1, :], tgt[:, -1])
        
        test_loss += loss.item()
        batch_count += 1

        # Compute batch accuracy based on the last token
        preds = output[:, -1, :].argmax(dim=1)
        batch_acc = (preds == tgt[:, -1]).float().mean().item()
        test_accuracy += batch_acc

        # Optionally print batch-level info for debugging.
        print(f"Batch {batch_count}: Loss = {loss.item():.4f}, Accuracy = {batch_acc:.4f}")
        print(f"Predicted tokens: {preds.tolist()}")
        print(f"Target tokens:    {tgt[:, -1].tolist()}")
        
avg_loss = test_loss / batch_count
avg_accuracy = test_accuracy / batch_count
test_losses.append(avg_loss)
test_accuracies.append(avg_accuracy)
print(f"Epoch {epoch+1}/{num_epochs} -- Avg Loss: {avg_loss:.4f}, Avg Accuracy: {avg_accuracy:.4f}")


Batch 1: Loss = 1.7146, Accuracy = 0.4688
Predicted tokens: [17, 2, 8, 17, 21, 9, 2, 2, 21, 20, 16, 19, 21, 20, 13, 22, 9, 19, 14, 15, 10, 21, 10, 6, 21, 21, 16, 6, 10, 21, 14, 12, 14, 22, 14, 2, 6, 21, 19, 17, 9, 10, 2, 6, 4, 17, 2, 16, 4, 6, 19, 2, 20, 14, 15, 5, 13, 17, 10, 3, 21, 13, 26, 20]
Target tokens:    [17, 2, 8, 17, 21, 26, 2, 2, 13, 20, 3, 19, 21, 5, 9, 22, 9, 26, 14, 13, 10, 5, 10, 20, 15, 13, 16, 6, 16, 21, 15, 16, 12, 22, 17, 7, 6, 7, 13, 17, 19, 10, 2, 16, 23, 17, 2, 22, 14, 6, 17, 16, 20, 5, 4, 19, 13, 4, 10, 3, 16, 26, 2, 27]
Batch 2: Loss = 1.4635, Accuracy = 0.5000
Predicted tokens: [2, 6, 21, 20, 21, 21, 20, 2, 13, 2, 6, 14, 20, 15, 2, 2, 21, 4, 21, 4, 20, 26, 2, 9, 4, 6, 19, 10, 7, 19, 10, 12, 5, 4, 22, 13, 10, 19, 5, 2, 21, 2, 2, 15, 19, 21, 2, 18, 17, 19, 5, 23, 2, 4, 2, 10, 8, 13, 22, 19, 10, 19, 21, 10]
Target tokens:    [2, 6, 6, 16, 10, 21, 20, 6, 13, 2, 26, 14, 20, 15, 4, 2, 8, 10, 21, 5, 20, 26, 10, 9, 6, 16, 19, 10, 4, 19, 10, 12, 5, 21, 15, 13, 10, 15, 

KeyboardInterrupt: 

In [1932]:
# Save the trained model state.
torch.save(transformer.state_dict(), "transformer_model.pth")
print("Model saved to transformer_model.pth")

Model saved to transformer_model.pth


In [1921]:
#model 
import random

for batch in train_dataloader:
        src_fake, tgt_fake = batch  
        break

x = random.randint(0,63)

start_symbol = 1  # Replace with your actual <sos> token index

src_instance = src_fake[x].unsqueeze(0)  # now shape: [1, 29]
tgt_instance = tgt_fake[x].unsqueeze(0)  # now shape: [1, 29]

output = transformer(src_instance, src_instance[:,1:])

logits = output[:, -1, :]  # shape: [1, tgt_vocab_size]

# Get unique tokens in src_instance.
unique_tokens = src_instance.unique()  # shape: [num_unique]

for token in unique_tokens:
    logits[0, token] = -float('inf')

predictions = logits.argmax(dim=1)
print(predictions.shape)

print(f"Predicted tokens: {predictions.tolist()}")
print(f"Target tokens:    {tgt_fake[:, -1].tolist()[x]}")

torch.Size([1])
Predicted tokens: [23]
Target tokens:    23
