In [1]:
# char_transformer_masked_train_expert.py
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import random

# --- Constants ---
MASK_IDX = 26
PAD_IDX = 27
VOCAB_SIZE = 28
CHAR_VOCAB = list("abcdefghijklmnopqrstuvwxyz")
char2idx = {ch: i for i, ch in enumerate(CHAR_VOCAB)}
char2idx["_"] = MASK_IDX

# --- Dataset ---
class MultiMaskCharDataset(Dataset):
    def __init__(self, words, min_mask=1, max_mask=2):
        self.data = []
        for word in words:
            if not word.isalpha() or len(word) < 4:
                continue
            word = word.lower()
            n_mask = random.randint(min_mask, min(max_mask, len(word)))
            mask_indices = random.sample(range(len(word)), n_mask)

            input_seq = [char2idx.get(ch, MASK_IDX) for ch in word]
            label_seq = [-100] * len(word)

            for i in mask_indices:
                input_seq[i] = MASK_IDX
                label_seq[i] = char2idx[word[i]]

            self.data.append((input_seq, label_seq))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return torch.tensor(self.data[idx][0]), torch.tensor(self.data[idx][1])

# --- Collate Function ---
def collate_fn(batch):
    inputs, labels = zip(*batch)
    max_len = max(len(seq) for seq in inputs)
    padded_inputs = torch.full((len(inputs), max_len), PAD_IDX, dtype=torch.long)
    padded_labels = torch.full((len(labels), max_len), -100, dtype=torch.long)
    for i in range(len(batch)):
        padded_inputs[i, :len(inputs[i])] = inputs[i]
        padded_labels[i, :len(labels[i])] = labels[i]
    return padded_inputs, padded_labels

# --- Transformer Model ---
class CharTransformer(nn.Module):
    def __init__(self, vocab_size=VOCAB_SIZE, embed_dim=128, nhead=4, num_layers=4, ff_dim=256, dropout=0.2, max_len=20):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=PAD_IDX)
        self.pos_embedding = nn.Embedding(max_len, embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=nhead, dim_feedforward=ff_dim, dropout=dropout, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.ln = nn.LayerNorm(embed_dim)
        self.fc = nn.Linear(embed_dim, 26)

    def forward(self, x):
        B, T = x.shape
        positions = torch.arange(0, T, device=x.device).unsqueeze(0).expand(B, T)
        emb = self.embedding(x) + self.pos_embedding(positions)
        out = self.transformer(emb)
        out = self.ln(out)
        return self.fc(out)

# --- Training Loop ---
def train_transformer_expert(all_words, expert_name, min_mask, max_mask, epochs=25, batch_size=64):
    print(f"\n--- Training TRANSFORMER Expert: {expert_name} ---")
    train_words, val_words = train_test_split(all_words, test_size=0.01, random_state=42)

    train_dataset = MultiMaskCharDataset(train_words, min_mask=min_mask, max_mask=max_mask)
    val_dataset = MultiMaskCharDataset(val_words, min_mask=min_mask, max_mask=max_mask)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    model = CharTransformer().to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss(ignore_index=-100)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.2, patience=1)

    best_val_loss = float('inf')
    patience_counter = 0
    patience_limit = 3

    for epoch in range(epochs):
        model.train()
        total_train_loss = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            logits = model(inputs)
            loss = criterion(logits.view(-1, 26), labels.view(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_loader)

        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                logits = model(inputs)
                loss = criterion(logits.view(-1, 26), labels.view(-1))
                total_val_loss += loss.item()

        avg_val_loss = total_val_loss / len(val_loader)
        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

        scheduler.step(avg_val_loss)
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            save_name = f"transformer_{expert_name}.pth"
            torch.save(model.state_dict(), save_name)
            print(f"Validation improved. Saved to {save_name}")
        else:
            patience_counter += 1
            print(f"No improvement. Patience {patience_counter}/{patience_limit}")

        if patience_counter >= patience_limit:
            print("Early stopping.")
            break

    print(f"✅ Training complete for expert: {expert_name}")

# --- Execute ---
if __name__ == '__main__':
    with open("words_250000_train.txt") as f:
        words = [line.strip().lower() for line in f if line.strip().isalpha()]

    # Expert 1: Late-game (1–2 masks)
    train_transformer_expert(words, expert_name="late", min_mask=1, max_mask=2)

    # Expert 2: Mid-game (2–4 masks)
    train_transformer_expert(words, expert_name="mid", min_mask=2, max_mask=4)

    # Expert 3: Early-game (4–6 masks)
    train_transformer_expert(words, expert_name="early", min_mask=4, max_mask=6)


--- Training TRANSFORMER Expert: late ---
Epoch 1/25 | Train Loss: 2.2672 | Val Loss: 2.0123
Validation improved. Saved to transformer_late.pth
Epoch 2/25 | Train Loss: 2.0798 | Val Loss: 1.9158
Validation improved. Saved to transformer_late.pth
Epoch 3/25 | Train Loss: 2.0300 | Val Loss: 1.8564
Validation improved. Saved to transformer_late.pth
Epoch 4/25 | Train Loss: 2.0003 | Val Loss: 1.8312
Validation improved. Saved to transformer_late.pth
Epoch 5/25 | Train Loss: 1.9809 | Val Loss: 1.8183
Validation improved. Saved to transformer_late.pth
Epoch 6/25 | Train Loss: 1.9714 | Val Loss: 1.7977
Validation improved. Saved to transformer_late.pth
Epoch 7/25 | Train Loss: 1.9549 | Val Loss: 1.7912
Validation improved. Saved to transformer_late.pth
Epoch 8/25 | Train Loss: 1.9455 | Val Loss: 1.7781
Validation improved. Saved to transformer_late.pth
Epoch 9/25 | Train Loss: 1.9373 | Val Loss: 1.7886
No improvement. Patience 1/3
Epoch 10/25 | Train Loss: 1.9297 | Val Loss: 1.7769
Validatio