### TRAINING DATA CREATION

In [None]:
import os
import random
import string
from collections import Counter, defaultdict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import time
from joblib import Parallel, delayed
from tqdm import tqdm
import pickle
import multiprocessing

# Seed every random number generator this cell touches (Python, NumPy, PyTorch).
SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)

# Game / model hyper-parameters.
MAX_WRONG = 6        # upper bound on wrong guesses sampled per generated state
BATCH_SIZE = 128
EMBED_DIM = 32
HIDDEN_DIM = 128
LR = 1e-3
EPOCHS = 2000
MAX_WORD_LEN = 30    # words are filtered to, and masks padded to, this length
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Letter vocabulary: 26 upper-case letters, indexed 0..25.
# TODO(author): the upper-case-only alphabet was marked as something to revisit.
ALPHABET = [*string.ascii_uppercase]
LETTER_TO_IDX = dict(zip(ALPHABET, range(len(ALPHABET))))
IDX_TO_LETTER = dict(enumerate(ALPHABET))

# Character vocabulary used for encoding masked words: index 0 is the
# placeholder symbol, followed by the letters (so "A" -> 1 ... "Z" -> 26).
PAD_CHAR = "_"
CHAR_VOCAB = [PAD_CHAR, *ALPHABET]
CHAR_TO_IDX = {symbol: position for position, symbol in enumerate(CHAR_VOCAB)}
VOCAB_SIZE = len(CHAR_VOCAB)

def load_words(path="words.txt", min_len=2, max_len=MAX_WORD_LEN):
    """Read a word list and return cleaned, unique, upper-case A-Z words.

    Each non-empty line of ``path`` is upper-cased and stripped of every
    non-alphabetic character (digits, hyphens, apostrophes, ...).  Words that
    still contain a letter outside A-Z after that are discarded: the rest of
    the pipeline indexes letters through ``LETTER_TO_IDX``, which only knows
    the 26 ASCII upper-case letters, so such a word would raise ``KeyError``
    later on.

    Args:
        path: UTF-8 text file with one word per line.
        min_len: minimum accepted word length (inclusive).
        max_len: maximum accepted word length (inclusive).

    Returns:
        Alphabetically sorted list of unique words, truncated to the first
        200000 entries.
    """
    with open(path, "r", encoding="utf-8") as f:
        raw_words = [line.strip() for line in f if line.strip()]
    print(f"Loaded {len(raw_words)} words from {path}")

    unique_words = set()
    for raw in raw_words:
        # Drop punctuation/digits but keep the letters, as before.
        word = "".join(c for c in raw.upper() if c.isalpha())
        # isalpha() also accepts accented / non-Latin letters; those have no
        # slot in the 26-letter alphabet, so skip the whole word.
        if not word or not word.isascii():
            continue
        if min_len <= len(word) <= max_len:
            unique_words.add(word)

    words = sorted(unique_words)
    # TODO(author): subset kept because the kernel died on the full list;
    # switch to the complete dictionary once that is resolved.
    return words[:200000]

def mask_word(word, revealed_letters):
    """Render ``word`` as a Hangman board string.

    Letters contained in ``revealed_letters`` (any container supporting
    ``in``) are kept; every other position is replaced by ``PAD_CHAR``.
    """
    board = []
    for letter in word:
        board.append(letter if letter in revealed_letters else PAD_CHAR)
    return "".join(board)

def encode_masked(masked, max_len=MAX_WORD_LEN):
    """Convert a masked word into ``max_len`` vocabulary indices.

    The string is cut to ``max_len`` characters and right-filled with
    ``PAD_CHAR``; characters missing from ``CHAR_TO_IDX`` map to index 0.

    NOTE(author): the right-padding uses the same symbol as an unrevealed
    letter, which the author flagged as something the model may handle badly.
    """
    fixed_width = masked[:max_len].ljust(max_len, PAD_CHAR)
    return [CHAR_TO_IDX.get(symbol, 0) for symbol in fixed_width]

def encode_wrong_vec(wrong_set):
    """Return a length-26 float32 multi-hot vector of wrongly guessed letters.

    Entries of ``wrong_set`` that are not keys of ``LETTER_TO_IDX`` are ignored.
    """
    vec = np.zeros(26, dtype=np.float32)
    for position in map(LETTER_TO_IDX.get, wrong_set):
        if position is not None:
            vec[position] = 1.0
    return vec

def set_to_mask(s):
    """Pack a collection of letters into one integer bitmask.

    Bit ``LETTER_TO_IDX[c]`` is set for every letter ``c`` in ``s``; this is a
    compact stand-in for a 26-wide one-hot vector.  Raises ``KeyError`` for a
    character that is not in ``LETTER_TO_IDX``.
    """
    # Summing distinct powers of two is the same as OR-ing them together.
    distinct_bits = {LETTER_TO_IDX[letter] for letter in s}
    return sum(1 << bit for bit in distinct_bits)

def words_matching_mask_fast(masked, candidates, wrong_mask):
    """Return the candidate words that are consistent with a Hangman board.

    A word is kept when all of the following hold:
      * it has the same length as ``masked``;
      * it agrees with every revealed (non-``PAD_CHAR``) position;
      * none of its blank positions holds a letter that is already revealed
        elsewhere in ``masked`` -- guessing a letter uncovers *all* of its
        occurrences, so such a word cannot be the answer;
      * it contains no letter whose bit is set in ``wrong_mask``.

    Args:
        masked: board string, ``PAD_CHAR`` marking unknown positions.
        candidates: iterable of words to filter.
        wrong_mask: integer bitmask of wrong guesses (bit ``LETTER_TO_IDX[c]``).

    Returns:
        List of matching words, in the order they appear in ``candidates``.
    """
    length = len(masked)
    # Split the pattern once instead of re-scanning it for every candidate.
    fixed = [(i, ch) for i, ch in enumerate(masked) if ch != PAD_CHAR]
    blanks = [i for i, ch in enumerate(masked) if ch == PAD_CHAR]
    revealed = {ch for _, ch in fixed}

    res = []
    for w in candidates:
        if len(w) != length:
            continue
        if any(w[i] != ch for i, ch in fixed):
            continue
        if any(w[i] in revealed for i in blanks):
            continue
        if any((wrong_mask >> LETTER_TO_IDX[c]) & 1 for c in w):
            continue
        res.append(w)
    return res

def _target_distribution(candidates, excluded_mask):
    """Return the normalised letter-frequency target for one game state.

    Each candidate word contributes one count per distinct letter whose bit is
    not set in ``excluded_mask``.  If nothing is counted, the mass is spread
    uniformly over the non-excluded letters (all zeros if none remain).
    """
    target_dist = np.zeros(26, dtype=np.float32)
    for cw in candidates:
        for c in set(cw):
            idx = LETTER_TO_IDX[c]
            if ((excluded_mask >> idx) & 1) == 0:
                target_dist[idx] += 1
    total = target_dist.sum()
    if total > 0:
        target_dist /= total
    else:
        rem = [i for i in range(26) if ((excluded_mask >> i) & 1) == 0]
        if rem:
            target_dist[rem] = 1.0 / len(rem)
    return target_dist


def process_word_batch(word_batch, words_by_len, samples_per_word, max_wrong):
    """Generate random Hangman states (with soft targets) for a batch of words.

    For every word, ``samples_per_word`` states are drawn: a random subset of
    the word's distinct letters is revealed (never all of them) and up to
    ``max_wrong`` letters that are not in the word are marked as wrong guesses.
    The target is the distribution of still-unguessed letters over all
    same-length dictionary words compatible with that state.

    Args:
        word_batch: words to generate states for.
        words_by_len: mapping ``length -> list of words`` used as candidates.
        samples_per_word: number of states drawn per word.
        max_wrong: maximum number of wrong letters sampled per state.

    Returns:
        List of dicts with keys ``"masked"``, ``"wrong_set"``, ``"target"``
        (float32 array of length 26) and ``"word"``.
    """
    all_letters = set(ALPHABET)
    all_states = []

    for w in word_batch:
        L = len(w)
        candidates_len = words_by_len[L]
        unique_letters = sorted(set(w))
        k = len(unique_letters)
        # Letters that are not in the word.  This depends only on the word, so
        # build it once per word; sorting removes the dependence on set
        # iteration order (which varies with the hash seed).
        wrong_pool = sorted(all_letters - set(w))

        # Per-word memo: identical states share one target computation.
        cache_target = {}

        for _ in range(samples_per_word):
            # How many, and which, distinct letters are revealed.
            reveal_count = random.randint(0, max(0, k - 1))
            revealed = set(random.sample(unique_letters, reveal_count)) if reveal_count > 0 else set()
            # How many, and which, wrong guesses have been made.
            wrong_count = random.randint(0, max_wrong)
            wrong_set = set(random.sample(wrong_pool, min(wrong_count, len(wrong_pool)))) if wrong_count > 0 else set()

            masked = mask_word(w, revealed)
            wrong_mask = set_to_mask(wrong_set)
            revealed_mask = set_to_mask(revealed)
            key = (masked, wrong_mask, revealed_mask)

            target_dist = cache_target.get(key)
            if target_dist is None:
                candidates = words_matching_mask_fast(masked, candidates_len, wrong_mask)
                if not candidates:
                    # Nothing matches the pattern: fall back to every
                    # same-length word that avoids the wrong letters.
                    candidates = [
                        x for x in candidates_len
                        if all(((wrong_mask >> LETTER_TO_IDX[c]) & 1) == 0 for c in x)
                    ]
                target_dist = _target_distribution(candidates, wrong_mask | revealed_mask)
                cache_target[key] = target_dist

            all_states.append({
                "masked": masked,
                "wrong_set": wrong_set,
                # astype() copies, so states never alias the cached array.
                "target": target_dist.astype(np.float32),
                "word": w
            })

    return all_states

def generate_training_states(words, max_wrong=MAX_WRONG, samples_per_word=40, n_jobs=-1):
    """Generate Hangman training states for ``words`` in parallel.

    The words are grouped by length (the candidate pool for each state), cut
    into batches and handed to ``process_word_batch`` through joblib.

    Args:
        words: list of dictionary words.
        max_wrong: maximum number of wrong guesses sampled per state.
        samples_per_word: number of states generated per word.
        n_jobs: joblib-style worker count; ``-1`` uses every core, other
            negative values and ``None`` follow joblib's own convention.

    Returns:
        Flat list of state dicts as produced by ``process_word_batch``.
    """
    # Local import: ``joblib`` is already a dependency of this file, only this
    # helper is new.
    from joblib import effective_n_jobs

    print(f"generating training states with {samples_per_word} samples per word...")

    # Candidate pools for the target computation, keyed by word length.
    words_by_len = defaultdict(list)
    for w in words:
        words_by_len[len(w)].append(w)

    # Resolve the joblib spec to a positive worker count.  Only -1 used to be
    # translated, so e.g. -2 produced a negative divisor (batch size 1) and
    # None raised a TypeError in the batch-size arithmetic below.
    n_jobs = effective_n_jobs(n_jobs)

    # About four batches per worker: large enough to amortise dispatch
    # overhead, small enough to balance load.
    batch_size = max(1, len(words) // (n_jobs * 4))
    word_batches = [words[i:i + batch_size] for i in range(0, len(words), batch_size)]

    print(f"Using {n_jobs} CPU cores with {len(word_batches)} batches...")

    # tqdm wraps the task generator, so the bar advances as batches are handed
    # to joblib, which may run ahead of their actual completion.
    results = Parallel(n_jobs=n_jobs, backend='loky', verbose=0)(
        delayed(process_word_batch)(batch, words_by_len, samples_per_word, max_wrong)
        for batch in tqdm(word_batches, desc="Processing batches")
    )

    # One list of states comes back per batch; flatten them.
    all_states = [state for batch_states in results for state in batch_states]

    print(f"Generated {len(all_states)} total training states")
    return all_states

def main():
    """Build the training states for the whole dictionary and pickle them.

    NOTE(review): an 80/20 train/test word split is computed and reported, but
    the states are generated from the complete word list; the output name
    ``states_all.pkl`` suggests this is intentional -- confirm.
    """
    vocabulary = load_words("words.txt", min_len=2, max_len=MAX_WORD_LEN)
    print("words length", len(vocabulary))

    # Lower this to trade data volume for generation time.
    per_word = 40
    random.shuffle(vocabulary)
    cut = int(0.8 * len(vocabulary))
    train_part, test_part = vocabulary[:cut], vocabulary[cut:]

    print(f"Train words: {len(train_part)}, Test words: {len(test_part)}")

    t0 = time.time()
    states = generate_training_states(
        vocabulary,
        max_wrong=MAX_WRONG,
        samples_per_word=per_word,
        n_jobs=-1,
    )
    elapsed = time.time() - t0

    print(f"Time to generate training states: {elapsed:.2f} seconds")

    # Persist the generated states for the training cell.
    out_path = "states_all.pkl"
    with open(out_path, "wb") as fh:
        pickle.dump(states, fh)
        print(f"Saved {len(states)} training states to {out_path}")


if __name__ == "__main__":
    main()

### TRAINING

In [None]:
import pickle, torch, os, random, time
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler


# Reproducibility: seed Python's and PyTorch's RNGs.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Train on the GPU when one is visible, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Training hyper-parameters.
MAX_WORD_LEN = 30   # fixed length every masked word is padded to
BATCH_SIZE = 2048   # large batches to speed up each epoch
EPOCHS = 300        # TODO(author): raise if the loss has not plateaued
LR = 1e-4
MAX_WRONG = 6       # TODO(author): idea to lower this so the model learns faster (untested)

import string
# 26 upper-case letters and their 0-based indices (both directions).
ALPHABET = [*string.ascii_uppercase]
LETTER_TO_IDX = dict(zip(ALPHABET, range(len(ALPHABET))))
IDX_TO_LETTER = dict(enumerate(ALPHABET))

# Input vocabulary for masked words: the placeholder at index 0, then A..Z.
PAD_CHAR = "_"
CHAR_VOCAB = [PAD_CHAR, *ALPHABET]
CHAR_TO_IDX = {symbol: position for position, symbol in enumerate(CHAR_VOCAB)}
VOCAB_SIZE = len(CHAR_VOCAB)

def encode_masked(masked, max_len=MAX_WORD_LEN):
    """Map a masked word to a list of ``max_len`` vocabulary indices.

    The input is truncated to ``max_len`` symbols and right-filled with the
    index of ``PAD_CHAR``; symbols not in ``CHAR_TO_IDX`` become index 0.
    """
    ids = [CHAR_TO_IDX.get(symbol, 0) for symbol in masked[:max_len]]
    # A non-positive repeat count yields an empty list, i.e. no padding.
    ids += [CHAR_TO_IDX.get(PAD_CHAR, 0)] * (max_len - len(masked))
    return ids

def encode_wrong_vec(wrong_set):
    """Encode wrongly guessed letters as a length-26 float32 multi-hot vector.

    Symbols that are not keys of ``LETTER_TO_IDX`` are skipped.
    """
    import numpy as np
    hits = [LETTER_TO_IDX[c] for c in wrong_set if c in LETTER_TO_IDX]
    vec = np.zeros(26, dtype=np.float32)
    if hits:
        vec[hits] = 1.0
    return vec

class HangmanDatasetPreEncoded(Dataset):
    """Dataset over states that were already converted to model inputs.

    ``encoded_states`` is an indexable collection of mappings holding the keys
    ``'X_mask'``, ``'X_wrong'`` and ``'y'``; each item is returned as the
    tuple ``(X_mask, X_wrong, y)``.
    """

    # Order in which the fields of a record are handed to the training loop.
    _FIELDS = ('X_mask', 'X_wrong', 'y')

    def __init__(self, encoded_states):
        self.encoded_states = encoded_states

    def __len__(self):
        return len(self.encoded_states)

    def __getitem__(self, idx):
        record = self.encoded_states[idx]
        return tuple(record[field] for field in self._FIELDS)

class HangmanModel(nn.Module):
    """GRU encoder over the masked word plus an MLP head over the letters.

    The final GRU hidden state is concatenated with the wrong-guess vector
    and passed through two linear layers; the output is a log-softmax over
    ``out_dim`` classes.

    NOTE(review): ``padding_idx=0`` pins index 0 to a zero embedding, and
    index 0 is ``PAD_CHAR`` -- the symbol used both for right-padding and for
    unrevealed letters.  Confirm that treating the two identically is intended.
    """

    def __init__(self, char_vocab_size=VOCAB_SIZE, emb_dim=32, hidden_dim=128, out_dim=26):
        super().__init__()
        # Character embedding followed by a single-layer GRU (batch first).
        self.emb = nn.Embedding(char_vocab_size, emb_dim, padding_idx=0)
        self.gru = nn.GRU(emb_dim, hidden_dim, batch_first=True)
        # The head sees the GRU state plus ``out_dim`` wrong-guess features.
        self.fc1 = nn.Linear(hidden_dim + out_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x_mask, x_wrong):
        """Return per-class log-probabilities for a batch.

        ``x_mask`` holds character indices with the batch dimension first;
        ``x_wrong`` is joined to the GRU state along dim 1.
        """
        _, final_hidden = self.gru(self.emb(x_mask))
        # Drop the leading (layer) dimension of the final hidden state.
        features = torch.cat((final_hidden.squeeze(0), x_wrong), dim=1)
        hidden = F.relu(self.fc1(features))
        return F.log_softmax(self.fc2(hidden), dim=1)

def soft_target_loss(log_probs, target_probs):
    """Cross-entropy between soft targets and predicted log-probabilities.

    Computes ``-sum_c target[c] * log_prob[c]`` per row (dim 1) and averages
    over the rows.  A class-weighted variant (each row scaled by the weight of
    its arg-max target class) was sketched by the author but is not enabled.
    """
    per_sample = torch.sum(target_probs * log_probs, dim=1)
    return -per_sample.mean()

# Load the pickled states written by the data-creation cell and shuffle them.
print("Loading states...")
with open("states_all.pkl", "rb") as f:
    train_states = pickle.load(f)
random.shuffle(train_states)

# 90/10 train/validation split over the shuffled states.
split_idx = int(0.9 * len(train_states))
train_data, val_data = train_states[:split_idx], train_states[split_idx:]


def _encode_state(state):
    """Turn one raw state dict into the tensors consumed by the model."""
    return {
        'X_mask': torch.tensor(encode_masked(state["masked"], MAX_WORD_LEN), dtype=torch.long),
        'X_wrong': torch.tensor(encode_wrong_vec(state["wrong_set"]), dtype=torch.float32),
        'y': torch.tensor(state["target"], dtype=torch.float32),
    }


# Encoding once up front avoids redoing it every epoch.
# TODO(author): the encoded states are not cached on disk because pickling
# them took too long; saving NumPy arrays instead was left as an idea.
print(f"pre-encoding {len(train_data)} train + {len(val_data)} val states...")
train_encoded = []
for i, s in enumerate(train_data):
    if i % 100000 == 0:
        print(f"  Encoded train {i}/{len(train_data)}")
    train_encoded.append(_encode_state(s))

val_encoded = [_encode_state(s) for s in val_data]

print("Loading pre-encoded states from file...")
train_dataset = HangmanDatasetPreEncoded(train_encoded)
val_dataset = HangmanDatasetPreEncoded(val_encoded)

# Options shared by both loaders; only the training loader shuffles and
# keeps its workers (and their prefetch queue) alive between epochs.
_loader_options = dict(batch_size=BATCH_SIZE, num_workers=8, pin_memory=True)

train_loader = DataLoader(
    train_dataset,
    shuffle=True,
    persistent_workers=True,
    prefetch_factor=2,
    **_loader_options,
)

val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)
print(f"Dataset ready: {len(train_data)} train, {len(val_data)} val, batch_size={BATCH_SIZE}")

# Build the network and wrap it in DataParallel.  The wrapper is always
# applied so the state-dict keys of saved checkpoints keep the same layout.
# Up to three GPUs are used (the original setup); hard-coding [0, 1, 2]
# failed on machines with fewer devices.
model = HangmanModel().to(DEVICE)
_gpu_ids = list(range(min(3, torch.cuda.device_count())))
model = nn.DataParallel(model, device_ids=_gpu_ids or None)

# TODO(author): AdamW was chosen over Adam; compare against Adam if time allows.
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)

# Checkpoint to resume from.  The existence check and the load now use the
# SAME path (they previously named two different files).  The default points
# at a file that does not normally exist, so training starts from scratch;
# set it to a saved "hangman_model<epoch>.pth" to resume.
CHECKPOINT_PATH = "hangman_model3150.pth"
checkpoint = None
if os.path.exists(CHECKPOINT_PATH):
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)

    # Strip the "_orig_mod." key prefix (presumably left by a model that was
    # saved while wrapped by torch.compile).
    state_dict = checkpoint["model_state"]
    new_state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

    model.load_state_dict(new_state_dict)
    start_epoch = checkpoint.get("epoch", 550) + 1
    print(f"Resuming from epoch {start_epoch}")
else:
    start_epoch = 1

# Pause for 30 seconds before continuing (author's workaround for suspected
# kernel crashes).
time.sleep(30)
# torch.compile(model, mode='max-autotune') used to run here; it is disabled
# because the kernel crashed on the complete dataset.
print("Starting compile...")

time.sleep(20)
print("starting scaler...")
# Gradient scaler used together with autocast() in the training loop.
scaler = GradScaler()

# Cosine annealing with warm restarts: first cycle of 50 epochs, each later
# cycle twice as long, learning rate floored at 1e-6.
print("Starting scheduler")
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=50, T_mult=2, eta_min=1e-6
)

# Finish the resume: the checkpoints also store optimizer and scheduler state,
# which used to be ignored (so AdamW moments and the LR schedule restarted).
# They are restored after the scheduler exists so that its construction does
# not overwrite the restored learning rate.
if checkpoint is not None:
    if "optimizer_state" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer_state"])
    if "scheduler_state" in checkpoint:
        scheduler.load_state_dict(checkpoint["scheduler_state"])


print("Starting training...")
for epoch in range(start_epoch, EPOCHS+1):
    # ---------------- training pass ----------------
    model.train()
    total_loss = 0.0
    epoch_start = time.time()
    data_time = 0    # seconds spent waiting for batches (accumulated, not printed)
    train_time = 0   # seconds spent in forward/backward/step (accumulated, not printed)
    all_logits = []  # model outputs (log-probabilities) of the first batch only
    all_targets = []

    batch_start = time.time()

    for batch_idx, (X_mask, X_wrong, y) in enumerate(train_loader):
        data_time += time.time() - batch_start
        train_start = time.time()

        # non_blocking=True requests an asynchronous host-to-device copy so
        # the transfer does not stall the Python thread.
        X_mask = X_mask.to(DEVICE, non_blocking=True)
        X_wrong = X_wrong.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with autocast():  # mixed-precision forward pass
            logp = model(X_mask, X_wrong)
            loss = soft_target_loss(logp, y)

        # Scale the loss for the backward pass, then unscale the gradients so
        # clipping operates on their true magnitudes.
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)

        # Clip the global gradient norm to 1.0; the returned norm is kept for
        # logging (only the last batch's value survives the loop).
        total_grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        scaler.step(optimizer)
        scaler.update()

        # Weight by batch size so the epoch average is per sample.
        total_loss += loss.item() * X_mask.size(0)

        if batch_idx == 0:
            all_logits.append(logp.detach().cpu())
            all_targets.append(y.detach().cpu())

        train_time += time.time() - train_start
        batch_start = time.time()

    avg_train_loss = total_loss / len(train_dataset)

    # ---------------- validation pass ----------------
    model.eval()
    val_loss = 0.0
    correct_per_class = torch.zeros(26)
    total_per_class = torch.zeros(26)

    with torch.no_grad():
        for X_mask, X_wrong, y in val_loader:
            X_mask = X_mask.to(DEVICE, non_blocking=True)
            X_wrong = X_wrong.to(DEVICE, non_blocking=True)
            y = y.to(DEVICE, non_blocking=True)

            with autocast():
                logp = model(X_mask, X_wrong)
                loss = soft_target_loss(logp, y)

            val_loss += loss.item() * X_mask.size(0)

            # Per-class accuracy against the arg-max of the soft target (a
            # rough indicator only, since the targets are distributions).
            # bincount replaces a 26-iteration loop of boolean masks.
            preds = logp.argmax(dim=1).cpu()
            targets = y.argmax(dim=1).cpu()
            total_per_class += torch.bincount(targets, minlength=26)
            correct_per_class += torch.bincount(targets[preds == targets], minlength=26)

    avg_val_loss = val_loss / len(val_dataset)
    scheduler.step()
    epoch_time = time.time() - epoch_start

    # ---------------- logging ----------------
    current_lr = optimizer.param_groups[0]['lr']
    print(f"\nEpoch {epoch}/{EPOCHS}")
    print(f"  Train Loss: {avg_train_loss:.6f}  Val Loss: {avg_val_loss:.6f}  LR: {current_lr:.6f}")
    print(f"  Total Grad Norm: {total_grad_norm:.4f}")

    # Every 10 epochs, print per-parameter gradient norms (from the last
    # training batch) to help tune the learning-rate schedule.
    if epoch % 10 == 0:
        for name, param in model.named_parameters():
            if param.grad is not None:
                grad_norm = param.grad.norm().item()
                print(f"    {name}: {grad_norm:.4f}")

    # Range of the first batch's outputs (these are log-probabilities).
    if len(all_logits) > 0:
        logits_sample = torch.cat(all_logits, dim=0)
        print(f"stats: min={logits_sample.min():.3f}, max={logits_sample.max():.3f}, mean={logits_sample.mean():.3f}")

    # Per-class accuracy, best and worst five.  Only classes that actually
    # occur as a validation target are ranked; previously classes with no
    # support sorted to the bottom with accuracy 0 and were then skipped,
    # which could leave the "Bottom 5" line empty.
    per_class_acc = correct_per_class / (total_per_class + 1e-8)
    sorted_idx = per_class_acc.argsort(descending=True)
    ranked = [idx for idx in sorted_idx.tolist() if total_per_class[idx] > 0]
    print(f"  Top 5 classes: ", end="")
    for idx in ranked[:5]:
        print(f"{ALPHABET[idx]}:{per_class_acc[idx]:.3f} ", end="")
    print(f"\n  Bottom 5 classes: ", end="")
    for idx in ranked[-5:]:
        print(f"{ALPHABET[idx]}:{per_class_acc[idx]:.3f} ", end="")
    print()

    # Save a checkpoint every 50 epochs.
    if epoch % 50 == 0:
        torch.save({
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "CHAR_VOCAB": CHAR_VOCAB,
            "LETTER_TO_IDX": LETTER_TO_IDX
        }, f"hangman_model{epoch}.pth")
        print(f"  Checkpoint saved: hangman_model{epoch}.pth")

        

16
