In [1]:
import random
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tqdm import tqdm
from collections import deque
from string import ascii_lowercase
from torch.utils.data import Dataset, DataLoader

# Define constants
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using : {device}")

MAX_LENGTH = 25
BATCH_SIZE = 4096  
LEARNING_RATE = 1e-3
UNGUESSED_CHAR = 27
EMBEDDING_SIZE = 128

char_to_idx = {char: i + 1 for i, char in enumerate(ascii_lowercase)}  # input is 1-26 for a-z, 27 for unguessed and 0 for padding
char_to_idx['_'] = 27
idx_to_char = {i: char for i, char in enumerate(ascii_lowercase)}  # output idx is 0-25 for a-z

# Dataset class
class HangmanDataset(Dataset):
    def __init__(self, data):
        self.word_tensors = data['arr_1']
        self.guessed_flags = data['arr_2']
        self.targets = data['arr_3']

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        word_tensor = torch.tensor(self.word_tensors[idx], dtype=torch.long)
        guessed_flags = torch.tensor(self.guessed_flags[idx], dtype=torch.float32)
        target = torch.tensor(self.targets[idx], dtype=torch.long)
        return word_tensor, guessed_flags, target
# Load data
data = np.load("Training_Data.npz")

train_data = {key: data[key] for key in data}
train_dataset = HangmanDataset(train_data)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)


# Model definition
class BiLSTM_Network(nn.Module):
    def __init__(self):
        super(BiLSTM_Network, self).__init__()
        self.embedding = nn.Embedding(num_embeddings=28, embedding_dim=EMBEDDING_SIZE, padding_idx=0)
        self.bilstm = nn.LSTM(input_size=EMBEDDING_SIZE, hidden_size=256, num_layers=4, dropout=0.01, batch_first=True, bidirectional=True)
        self.fcbilstm = nn.Sequential(
            nn.Linear(256 * 2, 64),
            nn.ReLU(),
            nn.Dropout(0.01)
        )
        self.flags_dense = nn.Sequential(
            nn.Linear(26, 32),
            nn.ReLU(),
            nn.Dropout(0.01)
        )
        self.combined_dense = nn.Sequential(
            nn.Linear(64 + 32, 26)
        )

    def forward(self, word, flags):
        embedded = self.embedding(word)
        bilstm_out, _ = self.bilstm(embedded)
        lastbilstm_out = bilstm_out[:, -1, :]
        bilstm_fcout = self.fcbilstm(lastbilstm_out)
        flags_out = self.flags_dense(flags)
        # flags_out = flags_out.view(-1, 32)  # Make sure it has shape [batch_size, 32]

        # print("bilstm_fcout shape:", bilstm_fcout.shape)
        # print("flags_out shape:", flags_out.shape)
        
        combined = torch.cat((bilstm_fcout, flags_out), dim=1)
        return self.combined_dense(combined)


# Initialize model
model = BiLSTM_Network()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
loss_fn = nn.CrossEntropyLoss()
model.to(device)

# Testing function
with open("cleaned_word_list.txt", 'r') as file:
    WORDS = file.read().splitlines()

def test_model(model):

    test_words = random.sample(WORDS, 1000)
    model.eval()
    total_pass = 0
    with torch.no_grad():
        for word in test_words:
            attempt = 0
            input_ = ['_'] * len(word)
            word_char_set = set(word)
            predicted_chars = set()  # Keep track of chars we've already guessed
            input_tensor = torch.tensor([[char_to_idx[char] for char in input_]],dtype=torch.long)
            input_tensor = torch.nn.functional.pad(input_tensor,(0, MAX_LENGTH - len(input_tensor[0])), value=0).to(device)
            guessed_flags = torch.tensor([[0.0 for i in range(26)]], dtype=torch.float32).to(device)
            
            
            while attempt < 6:
                
                output = model(input_tensor,guessed_flags)
                
                # Get probabilities and sort them
                probs = torch.softmax(output, dim=1)
                sorted_probs, sorted_indices = torch.sort(probs, dim=1, descending=True)
                
                # Find the highest probability char that hasn't been guessed yet
                output_char = None
                for idx in sorted_indices[0]:
                    char = idx_to_char[idx.item()]
                    if char not in predicted_chars:
                        output_char = char
                        guessed_flags[0,idx]=1.0
                        break
                
                if output_char is None:  # If we've tried all chars
                    break
                    
                predicted_chars.add(output_char)
                # print(input_, "-->", output_char)
                
                if output_char in word_char_set:
                    # replace the '_' with the output_char at all the positions where the output_char is in the word
                    for i in range(len(word)):
                        if word[i] == output_char:
                            input_tensor[0,i] = char_to_idx[output_char]
                            input_[i] = output_char
                    word_char_set.remove(output_char)
                    if len(word_char_set) == 0:  # Word completed
                        break
                else:
                    attempt += 1
                    
            if len(word_char_set) == 0:
                total_pass += 1
    model.train()
    return total_pass / len(test_words)

# Training loop
NUM_EPOCHS = 10
for epoch in range(NUM_EPOCHS):
    avg_loss = 0
    for i, (word_tensor, guessed_flags, target) in tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}"):
        word_tensor = word_tensor.to(device)
        guessed_flags = guessed_flags.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        output = model(word_tensor, guessed_flags)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        avg_loss += loss.item()

    avg_loss /= len(train_loader)
    test_acc = test_model(model)

    print(f"Epoch {epoch+1}, Average Loss: {avg_loss}, Test Accuracy: {test_acc}")

    if epoch % 1 == 0:
        torch.save(model.state_dict(), f"model_BiLSTM{epoch+1}.pth")



Using : cuda


Epoch 1: 100%|██████████| 8904/8904 [27:57<00:00,  5.31it/s]


Epoch 1, Average Loss: 2.198186398382564, Test Accuracy: 0.642


Epoch 2: 100%|██████████| 8904/8904 [27:50<00:00,  5.33it/s]


Epoch 2, Average Loss: 2.0598500899950447, Test Accuracy: 0.688


Epoch 3: 100%|██████████| 8904/8904 [28:01<00:00,  5.30it/s]


Epoch 3, Average Loss: 2.023886770558914, Test Accuracy: 0.7


Epoch 4: 100%|██████████| 8904/8904 [28:14<00:00,  5.26it/s]


Epoch 4, Average Loss: 2.0031306301005123, Test Accuracy: 0.685


Epoch 5: 100%|██████████| 8904/8904 [28:09<00:00,  5.27it/s]


Epoch 5, Average Loss: 1.991385769704817, Test Accuracy: 0.715


Epoch 6: 100%|██████████| 8904/8904 [28:13<00:00,  5.26it/s]


Epoch 6, Average Loss: 1.9835298040619007, Test Accuracy: 0.684


Epoch 7: 100%|██████████| 8904/8904 [28:13<00:00,  5.26it/s]


Epoch 7, Average Loss: 1.9776660549351468, Test Accuracy: 0.689


Epoch 8: 100%|██████████| 8904/8904 [28:12<00:00,  5.26it/s]


Epoch 8, Average Loss: 1.9730854887123699, Test Accuracy: 0.669


Epoch 9: 100%|██████████| 8904/8904 [28:15<00:00,  5.25it/s]


Epoch 9, Average Loss: 1.9690984959336733, Test Accuracy: 0.709


Epoch 10: 100%|██████████| 8904/8904 [28:10<00:00,  5.27it/s]


Epoch 10, Average Loss: 1.9658560513486116, Test Accuracy: 0.701
