In [1]:
import json
import os
import requests
import random
import string
import secrets
import time
import re
import collections
import torch
import torch.nn as nn
import math
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
import copy
import numpy as np
from collections import defaultdict
from collections import Counter
from tqdm import tqdm
import re

In [2]:
class HangmanDataset1(Dataset):
    """Synthetic, partially revealed hangman states for supervised training.

    Every item picks a word, reveals a random subset of its positions and
    returns the encoded state together with the hidden positions that touch at
    least one revealed neighbour (the positions the model is trained on).

    Token ids: 0 = blank, 1..26 = 'a'..'z', 27 = padding. Words are expected to
    consist of the letters a-z only.

    Args:
        words: iterable of words; entries longer than ``max_word_length`` are
            dropped, the rest are lower-cased.
        max_word_length: fixed length every encoded state is padded to.
        reveal_ratio: fraction of a word's positions that are revealed.
        samples_per_word: how many (randomised) items each word contributes
            to ``len(dataset)``.
        max_targets: number of target slots returned per item; unused slots
            hold position -1 / char 0, surplus targets are dropped.
    """

    def __init__(self, words, max_word_length=45, reveal_ratio=0.5,
                 samples_per_word=80, max_targets=10):
        self.words = [word.lower() for word in words if len(word) <= max_word_length]
        self.max_length = max_word_length
        self.reveal_ratio = reveal_ratio
        self.samples_per_word = samples_per_word
        self.max_targets = max_targets
        self.char_to_idx = {char: i+1 for i, char in enumerate(string.ascii_lowercase)}
        self.char_to_idx['_'] = 0  # blank
        self.char_to_idx['PAD'] = 27

    def __len__(self):
        return len(self.words) * self.samples_per_word

    def __getitem__(self, idx):
        """Return one randomly masked training example as a dict of tensors."""
        word = self.words[idx % len(self.words)]
        n = len(word)
        # Clamp so that a reveal_ratio above 1.0 cannot ask random.sample for
        # more positions than the word has.
        reveal_count = min(n, int(n * self.reveal_ratio))
        revealed = set(random.sample(range(n), reveal_count)) if reveal_count > 0 else set()

        blank = self.char_to_idx['_']
        # Positions past the end of the word carry the PAD id (not the blank
        # id), matching how HangmanSolver1 pads states at inference time.
        word_state = [self.char_to_idx['PAD']] * self.max_length
        for pos in range(n):
            word_state[pos] = self.char_to_idx[word[pos]] if pos in revealed else blank

        vowels = set('aeiou')
        target_pos, target_chars = [], []
        position_context = [0] * self.max_length
        blank_vowel_next = [0] * self.max_length
        for i in range(n):
            if i in revealed:
                continue
            left_known = i > 0 and word_state[i-1] != blank
            right_known = i < n - 1 and word_state[i+1] != blank

            # Context code: 1 = revealed left neighbour, 2 = revealed right
            # neighbour, 3 = both. Only such positions become targets.
            ctx = (1 if left_known else 0) + (2 if right_known else 0)
            if ctx:
                target_pos.append(i)
                target_chars.append(self.char_to_idx[word[i]])
                position_context[i] = ctx

            # Only *revealed* neighbours may contribute: looking at the hidden
            # letters would leak information that is unavailable at inference.
            if (left_known and word[i-1] in vowels) or (right_known and word[i+1] in vowels):
                blank_vowel_next[i] = 1

        count_blanks = word_state[:n].count(blank)

        # Pad the target lists to a fixed size so items can be batched.
        max_targets = self.max_targets
        while len(target_pos) < max_targets:
            target_pos.append(-1)
            target_chars.append(0)

        return {
            'word_state': torch.tensor(word_state, dtype=torch.long),
            'position_context': torch.tensor(position_context, dtype=torch.long),
            'target_positions': torch.tensor(target_pos[:max_targets], dtype=torch.long),
            'target_chars': torch.tensor(target_chars[:max_targets], dtype=torch.long),
            'word_length': torch.tensor(n, dtype=torch.long),
            'blank_count': torch.tensor(count_blanks, dtype=torch.long),
            'next_to_vowel': torch.tensor(blank_vowel_next, dtype=torch.float)
        }

class EnhancedHangmanModel1(nn.Module):
    """Context-aware letter predictor for hangman blanks.

    A character embedding feeds a small 1-D CNN (local pattern features) and,
    concatenated with a context-type embedding, a bidirectional LSTM. A small
    MLP turns per-position features into a 26-way prior, and one of three
    decoders (left / right / both neighbours revealed) produces the final
    26 letter logits for each position that has a non-zero context code.

    Args:
        vocab_size: size of the character embedding table.
        max_len: accepted for compatibility; not used by the layers.
        emb_dim: character embedding width.
        hidden_dim: LSTM hidden size per direction.
        ablate: optional dict stored on ``self.ablate``; the forward pass does
            not read it. A fresh dict is created when omitted.
    """

    def __init__(self, vocab_size=28, max_len=45, emb_dim=128, hidden_dim=1024, ablate=None):
        super().__init__()
        # A `{}` default would be one dict object shared by every instance.
        self.ablate = {} if ablate is None else ablate
        self.char_emb = nn.Embedding(vocab_size, emb_dim)
        self.ctx_emb = nn.Embedding(4, 32)

        self.pattern_cnn = nn.Sequential(
            nn.Conv1d(emb_dim, 64, 3, padding=1), nn.ReLU(), nn.Dropout(0.2),
            nn.Conv1d(64, 64, 3, padding=1), nn.ReLU(), nn.Dropout(0.2)
        )

        self.encoder = nn.LSTM(emb_dim + 32, hidden_dim, bidirectional=True, batch_first=True)

        # Input: is-blank flag (1) + blank fraction (1) + CNN features (64).
        self.pos_prior_mlp = nn.Sequential(
            nn.Linear(1 + 1 + 64, 32), nn.ReLU(), nn.Dropout(0.2), nn.Linear(32, 26)
        )

        def decoder():
            return nn.Sequential(
                nn.Linear(hidden_dim*2 + 26, hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(hidden_dim // 2, 26)
            )

        self.left_decoder = decoder()
        self.right_decoder = decoder()
        self.both_decoder = decoder()

    def forward(self, word_state, position_context, word_length, blank_count, next_to_vowel):
        """Compute letter logits for every position.

        Args:
            word_state: LongTensor [B, L] of token ids (0 marks a blank).
            position_context: LongTensor [B, L] with values 0..3 selecting the
                decoder (1 = left, 2 = right, 3 = both); 0 means "no output".
            word_length: accepted for interface compatibility; unused here.
            blank_count: tensor with one blank count per batch row.
            next_to_vowel: accepted for interface compatibility; unused here.

        Returns:
            FloatTensor [B, L, 26]. Rows whose context code is 0 stay all-zero.
        """
        B, L = word_state.size()
        emb = self.char_emb(word_state)                                    # [B, L, emb_dim]
        cnn_feat = self.pattern_cnn(emb.transpose(1, 2)).transpose(1, 2)   # [B, L, 64]
        ctx = self.ctx_emb(position_context)                               # [B, L, 32]
        encoded, _ = self.encoder(torch.cat([emb, ctx], -1))               # [B, L, 2*hidden_dim]

        # The prior MLP only acts on the last dimension, so all positions can
        # be pushed through it in one call instead of one call per position.
        is_blank = (word_state == 0).float().unsqueeze(-1)                 # [B, L, 1]
        blank_frac = (blank_count.float() / L).reshape(B, 1, 1).expand(B, L, 1)
        priors = self.pos_prior_mlp(torch.cat([is_blank, blank_frac, cnn_feat], -1))  # [B, L, 26]

        decoder_in = torch.cat([encoded, priors], -1)
        out = torch.zeros(B, L, 26, device=word_state.device)
        # Route every (row, position) pair to the decoder matching its context
        # code; one masked call per decoder replaces 3*L tiny calls.
        routes = ((1, self.left_decoder), (2, self.right_decoder), (3, self.both_decoder))
        for ctx_type, dec in routes:
            mask = position_context == ctx_type
            out[mask] = dec(decoder_in[mask])

        return out

def build_lengthwise_frequencies(word_list):
    """Group letter document-frequencies by word length.

    Each word is lower-cased and contributes every distinct letter it contains
    exactly once to the Counter stored under its length.
    """
    per_length = defaultdict(Counter)
    for entry in word_list:
        lowered = entry.lower()
        # set(...) so that repeated letters inside one word count only once.
        per_length[len(lowered)].update(set(lowered))
    return per_length

def get_best_first_guess(word_length, guessed_letters, length_freq):
    """
    Pick the most frequent not-yet-guessed letter for words of `word_length`.

    When no statistics exist for that length, the counters of all lengths are
    summed and used instead. Returns None once every counted letter has
    already been guessed.
    """
    if word_length in length_freq:
        freq = length_freq[word_length]
    else:
        # No data for this length: fall back to the global letter frequency.
        freq = Counter()
        for per_length_counter in length_freq.values():
            freq += per_length_counter

    for letter, _count in freq.most_common():
        if letter not in guessed_letters:
            return letter
    return None

def word_matches_pattern(word, pattern, wrong_letters):
    """Check whether `word` could still be the answer for a hangman state.

    Args:
        word: candidate dictionary word.
        pattern: current state as a string or list of characters, with '_'
            marking unrevealed positions.
        wrong_letters: letters already guessed that are not in the answer
            (may be empty or None).

    Returns:
        True when the word has the pattern's length, contains none of the
        wrong letters, agrees with every revealed position, and does not put
        an already revealed letter into a blank position.
    """
    if len(word) != len(pattern):
        return False

    # Wrong letters rule a word out regardless of how much is revealed; this
    # must run before any shortcut for an all-blank pattern.
    if wrong_letters and any(letter in word for letter in wrong_letters):
        return False

    # A correct guess uncovers every occurrence of that letter, so a revealed
    # letter can never hide behind a remaining blank.
    revealed_letters = {p_char for p_char in pattern if p_char != '_'}

    for w_char, p_char in zip(word, pattern):
        if p_char == '_':
            if w_char in revealed_letters:
                return False
        elif w_char != p_char:
            return False

    return True


def get_dictionary_filtered_multipliers(word_pattern, wrong_letters, dictionary):
    """
    Derive per-letter multipliers from the dictionary words that still fit.

    A letter gets 1.1 when at least one word accepted by
    `word_matches_pattern` has that letter in a position that is still blank
    ('_') in `word_pattern`; every other letter gets 0.9.

    Returns:
        (multipliers, matching_count, eliminated_count) where `multipliers`
        maps each lowercase ASCII letter to 0.9 or 1.1, `matching_count` is
        the number of accepted words and `eliminated_count` is the rest of
        the dictionary.
    """
    blank_positions = [i for i, p_char in enumerate(word_pattern) if p_char == '_']

    # Single pass: count the matching words and collect every letter they
    # place in a blank position (matching words share the pattern's length,
    # so indexing by blank position is safe).
    matching_count = 0
    candidate_letters = set()
    for word in dictionary:
        if word_matches_pattern(word, word_pattern, wrong_letters):
            matching_count += 1
            candidate_letters.update(word[i] for i in blank_positions)

    multipliers = {
        letter: (1.1 if letter in candidate_letters else 0.9)
        for letter in string.ascii_lowercase
    }
    return multipliers, matching_count, len(dictionary) - matching_count

def clean_state_dict(state_dict):
    """Return a copy of `state_dict` whose keys have one leading 'module.' stripped.

    Keys without that prefix are carried over unchanged.
    """
    prefix = 'module.'
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): tensor
        for key, tensor in state_dict.items()
    }

class HangmanSolver1:
    """Letter predictor combining EnhancedHangmanModel1 with dictionary hints.

    Args:
        model_path: path of the state_dict checkpoint to load.
        dictionary_path: text file with one dictionary word per line.
    """

    # Fixed length states are padded to before being fed to the model.
    MAX_LENGTH = 45
    # Token id used for padding and for characters missing from char_to_idx.
    PAD_IDX = 27
    # The dictionary prior is only applied while less than this fraction of
    # the word has been revealed.
    DICT_PRIOR_MAX_REVEAL = 0.35

    def __init__(self, model_path, dictionary_path="words_250000_train.txt"):
        self.model = EnhancedHangmanModel1()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Context manager so the file handle is closed deterministically.
        with open(dictionary_path) as dictionary_file:
            self.dictionary = dictionary_file.read().splitlines()
        self.char_to_idx = {c: i+1 for i, c in enumerate(string.ascii_lowercase)}
        self.char_to_idx['_'] = 0
        self.idx_to_char = {v: k for k, v in self.char_to_idx.items() if v != 0}
        self.length_freq = build_lengthwise_frequencies(self.dictionary)
        if torch.cuda.device_count() > 1:
            self.model = nn.DataParallel(self.model)
        self.model = self.model.to(self.device)
        # Dictionary-based multipliers; neutral until the first update.
        self.dict_multipliers = {letter: 1.0 for letter in string.ascii_lowercase}

        checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)
        cleaned_checkpoint0 = clean_state_dict(checkpoint)
        if isinstance(self.model, nn.DataParallel):
            self.model.module.load_state_dict(cleaned_checkpoint0)
        else:
            self.model.load_state_dict(cleaned_checkpoint0)
        self.model.eval()

    def update_dict_multipliers(self, word_pattern, wrong_letters):
        """Recompute `self.dict_multipliers` for the given state.

        Returns:
            (matching_count, eliminated_count) as reported by
            `get_dictionary_filtered_multipliers`.
        """
        multipliers, matching_count, eliminated_count = get_dictionary_filtered_multipliers(
            word_pattern, wrong_letters, self.dictionary
        )
        self.dict_multipliers = multipliers
        return matching_count, eliminated_count

    def predict_letter(self, word_state, guessed_letters=None):
        """Suggest the next letter to guess.

        Args:
            word_state: current state such as ``"a__le"``; spaces are ignored.
            guessed_letters: letters guessed so far (defaults to none).

        Returns:
            A lowercase letter that is not in `guessed_letters`, or None when
            no letter is left to propose.

        Note:
            `self.dict_multipliers` is refreshed only when the dictionary prior
            is actually used (state not fully blank and revealed fraction
            below `DICT_PRIOR_MAX_REVEAL`); this avoids a full dictionary scan
            on calls that would discard the result.
        """
        if guessed_letters is None:
            guessed_letters = set()
        word_state = word_state.replace(' ', '')
        n = len(word_state)
        blanks = word_state.count('_')

        # Nothing revealed yet: use the length-conditioned letter frequency.
        if blanks == n:
            return get_best_first_guess(n, guessed_letters, self.length_freq)

        wrong_letters = {ch for ch in guessed_letters if ch not in word_state and ch.isalpha()}
        use_dict_prior = (n - blanks) / n < self.DICT_PRIOR_MAX_REVEAL
        if use_dict_prior:
            self.update_dict_multipliers(list(word_state), wrong_letters)

        # Keep states longer than MAX_LENGTH intact instead of indexing past
        # the end of the fixed-size feature lists.
        padded_length = max(n, self.MAX_LENGTH)
        state_indices = [self.PAD_IDX] * padded_length
        position_context = [0] * padded_length
        blank_vowel_next = [0] * padded_length

        for i, char in enumerate(word_state):
            if char != '_':
                state_indices[i] = self.char_to_idx.get(char, self.PAD_IDX)
                continue
            state_indices[i] = 0
            left = word_state[i-1] if i > 0 else None
            right = word_state[i+1] if i < n - 1 else None
            # Context code: 1 = left neighbour revealed, 2 = right, 3 = both.
            ctx = 0
            if left is not None and left != '_':
                ctx += 1
            if right is not None and right != '_':
                ctx += 2
            position_context[i] = ctx
            if (left is not None and left in 'aeiou') or (right is not None and right in 'aeiou'):
                blank_vowel_next[i] = 1

        word_tensor = torch.tensor([state_indices], dtype=torch.long, device=self.device)
        context_tensor = torch.tensor([position_context], dtype=torch.long, device=self.device)
        length_tensor = torch.tensor([n], dtype=torch.long, device=self.device)
        blank_count_tensor = torch.tensor([blanks], dtype=torch.long, device=self.device)
        blank_vowel_tensor = torch.tensor([blank_vowel_next], dtype=torch.float, device=self.device)

        with torch.no_grad():
            model_out = self.model(
                word_tensor, context_tensor, length_tensor,
                blank_count_tensor, blank_vowel_tensor
            )

        logits = model_out[0, :n, :]
        if use_dict_prior:
            multipliers = torch.tensor(
                [self.dict_multipliers[letter] for letter in string.ascii_lowercase],
                dtype=logits.dtype, device=logits.device
            )
            # Adding log(m) to a logit multiplies the letter's probability by m
            # before renormalisation. Multiplying the raw logit instead would
            # *lower* a negative logit when m > 1, i.e. penalise the very
            # letters the dictionary supports.
            logits = logits + torch.log(multipliers)
        probs = torch.softmax(logits, dim=-1).tolist()

        # Highest probability over all blank positions and unguessed letters;
        # strict '>' keeps the earliest candidate on ties.
        # NOTE(review): blanks with context code 0 receive all-zero model
        # output, so their distribution reflects only the dictionary prior.
        best_letter, best_prob = None, -1.0
        for i, char in enumerate(word_state):
            if char != '_':
                continue
            for j, prob in enumerate(probs[i]):
                letter = string.ascii_lowercase[j]
                if letter not in guessed_letters and prob > best_prob:
                    best_letter, best_prob = letter, prob

        if best_letter is not None:
            return best_letter

        # Fallback: first unguessed letter in alphabetical order.
        available_letters = [c for c in string.ascii_lowercase if c not in guessed_letters]
        if available_letters:
            return available_letters[0]
        return None

def simulate_hangman_game(solver1, solver2, word, max_wrong=6, verbose=False):
    """Play one hangman game against `word` and report the outcome.

    On every turn `solver1` is consulted once more than 65% of the letters are
    revealed, otherwise `solver2`. The game stops when the word is complete,
    when `max_wrong` wrong guesses have been made, or when the chosen solver
    returns None or repeats a letter.

    Returns:
        dict with keys 'success', 'word', 'guesses', 'wrong_guesses' and
        'final_state'.
    """
    true_word = word.lower()
    total_letters = len(true_word)
    board = ['_' for _ in range(total_letters)]
    tried = set()
    misses = set()
    miss_total = 0

    def shown():
        # Current board rendered as a single string.
        return ''.join(board)

    if verbose:
        print(f"Word: {true_word}")
        print(f"Initial state: {shown()}")

    while '_' in board and miss_total < max_wrong:
        # Pick the solver for this turn from the revealed fraction.
        revealed_fraction = (total_letters - board.count('_')) / total_letters
        if revealed_fraction > 0.65:
            active = solver1
        else:
            active = solver2
        guess = active.predict_letter(shown(), tried)

        # A missing or repeated proposal ends the game early.
        if guess is None or guess in tried:
            break

        tried.add(guess)

        if guess not in true_word:
            misses.add(guess)
            miss_total += 1
            if verbose:
                print(f"Wrong guess '{guess}' ({miss_total}/{max_wrong}): {shown()}")
            continue

        # Correct guess: uncover every occurrence of the letter.
        for pos, actual in enumerate(true_word):
            if actual == guess:
                board[pos] = actual
        if verbose:
            print(f"Correct guess '{guess}': {shown()}")

    success = '_' not in board
    if verbose:
        print(f"Game result: {'WIN' if success else 'LOSE'}")
        print(f"Final state: {shown()}")
        print(f"Wrong guesses: {sorted(misses)}")

    return {
        'success': success,
        'word': true_word,
        'guesses': len(tried),
        'wrong_guesses': miss_total,
        'final_state': shown()
    }

def _mean_target_loss(out, target_positions, target_chars, loss_fn):
    """Mean loss over the valid target slots of one batch.

    A slot is valid when its position is >= 0 and its character id is > 0
    (padding slots hold -1 / 0). Character ids are shifted by -1 to obtain the
    0..25 class labels.

    Returns:
        (loss, count): `loss` is the mean of `loss_fn` over all valid slots,
        or None when the batch has no valid slot; `count` is the number of
        valid slots.
    """
    valid = (target_positions >= 0) & (target_chars > 0)
    count = int(valid.sum())
    if count == 0:
        return None, 0
    rows, slots = valid.nonzero(as_tuple=True)
    logits = out[rows, target_positions[rows, slots]]
    labels = target_chars[rows, slots] - 1
    return loss_fn(logits, labels), count


def train_model1(words, epochs=10, early_stopping_patience=5, save_path="best_model1.pth"):
    """Train EnhancedHangmanModel1 with a reveal-ratio curriculum.

    Args:
        words: training vocabulary passed to HangmanDataset1.
        epochs: maximum number of epochs.
        early_stopping_patience: epochs without validation improvement after
            which training stops.
        save_path: file the best state_dict (lowest validation loss) is
            written to.

    Returns:
        The model as it is after the last executed epoch (not necessarily the
        saved best checkpoint).
    """
    print("Updated...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = EnhancedHangmanModel1()
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model.to(device)

    opt = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
    scheduler = ReduceLROnPlateau(opt, patience=2)
    loss_fn = nn.CrossEntropyLoss()

    best = float('inf')
    patience_counter = 0

    # Curriculum: the reveal ratio *decreases* with the epoch, i.e. training
    # starts with mostly revealed words (easy) and hides more letters later.
    reveal_schedule = [0.93, 0.87, 0.80, 0.80, 0.73, 0.73, 0.67, 0.67, 0.67, 0.60, 0.60, 0.53, 0.53, 0.47, 0.47, 0.40, 0.40, 0.33, 0.33, 0.27]

    for ep in range(epochs):
        reveal_ratio = reveal_schedule[ep] if ep < len(reveal_schedule) else 0.20
        ds = HangmanDataset1(words, reveal_ratio=reveal_ratio)
        train_len = int(0.9 * len(ds))
        # NOTE(review): the split is over item indices and every word owns many
        # indices, so the same words appear in both train and validation.
        tr, val = random_split(ds, [train_len, len(ds)-train_len])
        dl = DataLoader(tr, shuffle=True, pin_memory=True, batch_size=256, num_workers=4)
        vl = DataLoader(val, pin_memory=True, batch_size=256, num_workers=4)

        print(f"\n--- Epoch {ep+1} | Reveal Ratio: {reveal_ratio:.2f} ---", flush=True)
        model.train()
        total_loss = 0
        batch_count = 0

        for i, batch in enumerate(tqdm(dl, desc="Training", ncols=100)):
            opt.zero_grad()
            out = model(batch['word_state'].to(device), batch['position_context'].to(device),
                        batch['word_length'].to(device), batch['blank_count'].to(device),
                        batch['next_to_vowel'].to(device))
            loss, count = _mean_target_loss(
                out, batch['target_positions'].to(device), batch['target_chars'].to(device), loss_fn
            )
            if count > 0:
                loss.backward()
                opt.step()
                total_loss += loss.item()
                batch_count += 1
            if i % 20 == 0:
                if loss is not None:
                    print(f"  Batch {i}/{len(dl)} | Loss: {loss.item():.4f}", flush=True)
                else:
                    print(f"  Batch {i}/{len(dl)} | Loss: N/A (no valid targets)", flush=True)

        train_loss = total_loss / batch_count if batch_count > 0 else 0
        model.eval()
        val_loss = 0
        val_batches = 0
        with torch.no_grad():
            for batch in tqdm(vl, desc="Validation", ncols=100):
                out = model(batch['word_state'].to(device), batch['position_context'].to(device),
                            batch['word_length'].to(device), batch['blank_count'].to(device),
                            batch['next_to_vowel'].to(device))
                loss, count = _mean_target_loss(
                    out, batch['target_positions'].to(device), batch['target_chars'].to(device), loss_fn
                )
                if count > 0:
                    val_loss += loss.item()
                    val_batches += 1

        val_loss = val_loss / val_batches if val_batches > 0 else 0
        # NOTE(review): each epoch validates at its own reveal ratio, so the
        # losses compared by the scheduler and by early stopping come from
        # tasks of different difficulty - confirm this is intended.
        scheduler.step(val_loss)

        if val_loss < best:
            best = val_loss
            patience_counter = 0
            torch.save(model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict(), save_path)
            print("✅ Model improved and saved.", flush=True)
        else:
            patience_counter += 1
            print(f"⚠️ No improvement. Patience: {patience_counter}/{early_stopping_patience}", flush=True)
            if patience_counter >= early_stopping_patience:
                print(f"🛑 Early stopping at epoch {ep+1}", flush=True)
                break

    return model

In [3]:
# Build two solvers from separately saved checkpoints; simulate_hangman_game
# chooses between them on each turn based on how much of the word is revealed.
# NOTE(review): no code visible in this notebook writes "best_model2.pth" -
# confirm where that checkpoint comes from.
solver1 = HangmanSolver1(model_path= "best_model1.pth")
solver2 = HangmanSolver1(model_path= "best_model2.pth")

In [4]:
# Smoke test: play one verbose game on a fixed word and keep the result dict.
result = simulate_hangman_game(solver1, solver2, "love", verbose=True)

Word: love
Initial state: ____
Wrong guess 'a' (1/6): ____
Correct guess 'e': ___e
Wrong guess 'r' (2/6): ___e
Correct guess 'l': l__e
Wrong guess 'i' (3/6): l__e
Correct guess 'o': lo_e
Wrong guess 'n' (4/6): lo_e
Correct guess 'v': love
Game result: WIN
Final state: love
Wrong guesses: ['a', 'i', 'n', 'r']
