# TRAIN AND LEARN OF word_language_model

## Model code

## Uploading files from Google Disk

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
from __future__ import unicode_literals, print_function, division
import os
from io import open
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import math
import time
import random
import argparse
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from collections import Counter
from typing import Dict, List, Union, Tuple, Optional

### data.py

In [48]:
class Dictionary(object):
    def __init__(self) -> None:
        self.word2idx: Dict[str, int] = {}
        self.idx2word: List[str] = []
        self.word2count: Counter = Counter()

    def add_word(self, word: str) -> int:
        self.word2count[word] += 1
        return self.word2idx.get(word, -1)

    def finalize(self, min_freq: int = 5) -> None:
        for word, count in self.word2count.items():
            if count >= min_freq and word not in self.word2idx:
                self.idx2word.append(word)
                self.word2idx[word] = len(self.idx2word) - 1
        self.word2idx['<unk>'] = len(self.idx2word)
        self.idx2word.append('<unk>')

    def __len__(self) -> int:
        return len(self.idx2word)

class Corpus(object):
  """Corpus class for loading and tokenizing text data."""
  def __init__(self, path: str, min_freq: int = 5) -> None:
      self.dictionary: Dictionary = Dictionary()
      self.min_freq: int = min_freq
      self.train: torch.Tensor = self.tokenize(os.path.join(path, 'train.txt'))
      self.valid: torch.Tensor = self.tokenize(os.path.join(path, 'valid.txt'))
      self.test: torch.Tensor = self.tokenize(os.path.join(path, 'test.txt'))

  def tokenize(self, path: str) -> torch.Tensor:
      """Tokenizes a text file."""
      assert os.path.exists(path)
      # Add words to the dictionary
      with open(path, 'r', encoding="utf8") as f:
          for line in f:
              words: List[str] = line.split() + ['<eos>']
              for word in words:
                  self.dictionary.add_word(word)
      self.dictionary.finalize(self.min_freq)
      # Tokenize file content
      with open(path, 'r', encoding="utf8") as f:
          idss: List[torch.Tensor] = []
          for line in f:
              words: List[str] = line.split() + ['<eos>']
              ids: List[int] = []
              for word in words:
                  idx = self.dictionary.word2idx.get(word, self.dictionary.word2idx['<unk>'])
                  ids.append(idx)
              idss.append(torch.tensor(ids, dtype=torch.int64))
          ids_tensor: torch.Tensor = torch.cat(idss)

      return ids_tensor

In [4]:
class Old_Dictionary(object):
    """Dictionary for word-to-index mapping."""

    def __init__(self) -> None:
        self.word2idx: Dict[str, int] = {}
        self.idx2word: List[str] = []

    def add_word(self, word: str) -> int:
        """Add a word to the dictionary."""
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def __len__(self) -> int:
        return len(self.idx2word)


class Old_Corpus(object):
    """Corpus class for loading and tokenizing text data."""

    def __init__(self, path: str) -> None:
        self.dictionary: Old_Dictionary = Old_Dictionary()
        self.train: torch.Tensor = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid: torch.Tensor = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test: torch.Tensor = self.tokenize(os.path.join(path, 'test.txt'))

    def tokenize(self, path: str) -> torch.Tensor:
        """Tokenizes a text file."""
        assert os.path.exists(path), f"Path {path} does not exist"

        # Add words to the dictionary
        with open(path, 'r', encoding="utf8") as f:
            for line in f:
                words: List[str] = line.split() + ['<eos>']
                for word in words:
                    self.dictionary.add_word(word)

        # Tokenize file content
        with open(path, 'r', encoding="utf8") as f:
            idss: List[torch.Tensor] = []
            for line in f:
                words: List[str] = line.split() + ['<eos>']
                ids: List[int] = []
                for word in words:
                    ids.append(self.dictionary.word2idx[word])
                idss.append(torch.tensor(ids, dtype=torch.int64))
            ids_tensor: torch.Tensor = torch.cat(idss)

        return ids_tensor

### model.py

In [49]:
# ===============================
# MODEL ARCHITECTURES (model.py)
# ===============================

class RNNModel(nn.Module):
    """RNN-based language model (LSTM/GRU/RNN)."""

    def __init__(
        self,
        rnn_type: str,
        ntoken: int,
        ninp: int,
        nhid: int,
        nlayers: int,
        dropout: float = 0.5,
        tie_weights: bool = False
    ) -> None:
        super(RNNModel, self).__init__()
        self.ntoken: int = ntoken
        self.rnn_type: str = rnn_type
        self.nhid: int = nhid
        self.nlayers: int = nlayers

        self.drop: nn.Dropout = nn.Dropout(dropout)
        self.encoder: nn.Embedding = nn.Embedding(ntoken, ninp)

        if rnn_type in ['LSTM', 'GRU']:
            self.rnn: nn.Module = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)
        else:
            try:
                nonlinearity: str = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}[rnn_type]
            except KeyError as e:
                raise ValueError(
                    "Invalid option for `--model`. "
                    "Options are ['LSTM', 'GRU', 'RNN_TANH', 'RNN_RELU']"
                ) from e
            self.rnn = nn.RNN(ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout)

        self.decoder: nn.Linear = nn.Linear(nhid, ntoken)

        # Tie weights
        if tie_weights:
            if nhid != ninp:
                raise ValueError('When using tied flag, nhid must be equal to emsize')
            self.decoder.weight = self.encoder.weight

        self.init_weights()

    def init_weights(self) -> None:
        """Initialize weights."""
        initrange: float = 0.1
        nn.init.uniform_(self.encoder.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        emb: torch.Tensor = self.drop(self.encoder(input))
        output: torch.Tensor
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded: torch.Tensor = self.decoder(output)
        decoded = decoded.view(-1, self.ntoken)
        return F.log_softmax(decoded, dim=1), hidden

    def init_hidden(self, bsz: int) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Initialize hidden state."""
        weight: torch.Tensor = next(self.parameters())
        if self.rnn_type == 'LSTM':
            return (
                weight.new_zeros(self.nlayers, bsz, self.nhid),
                weight.new_zeros(self.nlayers, bsz, self.nhid)
            )
        else:
            return weight.new_zeros(self.nlayers, bsz, self.nhid)

# Temporarily leave PositionalEncoding module here. Will be moved somewhere else.
class PositionalEncoding(nn.Module):
    r"""Inject some information about the relative or absolute position of the tokens in the sequence.
        The positional encodings have the same dimension as the embeddings, so that the two can be summed.
        Here, we use sine and cosine functions of different frequencies.
    .. math:
        \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
        \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
        \text{where pos is the word position and i is the embed idx)
    Args:
        d_model: the embed dim (required).
        dropout: the dropout value (default=0.1).
        max_len: the max. length of the incoming sequence (default=5000).
    Examples:
        >>> pos_encoder = PositionalEncoding(d_model)
    """

    """Positional encoding for Transformer."""

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000) -> None:
        super(PositionalEncoding, self).__init__()
        self.dropout: nn.Dropout = nn.Dropout(p=dropout)

        pe: torch.Tensor = torch.zeros(max_len, d_model)
        position: torch.Tensor = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term: torch.Tensor = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Inputs of forward function
        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]
        Examples:
            >>> output = pos_encoder(x)
        """

        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class TransformerModel(nn.Transformer):
    """Transformer-based language model."""

    def __init__(
        self,
        ntoken: int,
        ninp: int,
        nhead: int,
        nhid: int,
        nlayers: int,
        dropout: float = 0.5
    ) -> None:
        super(TransformerModel, self).__init__(
            d_model=ninp,
            nhead=nhead,
            dim_feedforward=nhid,
            num_encoder_layers=nlayers
        )
        self.model_type: str = 'Transformer'
        self.src_mask: Optional[torch.Tensor] = None
        self.pos_encoder: PositionalEncoding = PositionalEncoding(ninp, dropout)

        self.input_emb: nn.Embedding = nn.Embedding(ntoken, ninp)
        self.ninp: int = ninp
        self.decoder: nn.Linear = nn.Linear(ninp, ntoken)

        self.init_weights()

    def _generate_square_subsequent_mask(self, sz: int) -> torch.Tensor:
        """Generate mask for causal attention."""
        return torch.log(torch.tril(torch.ones(sz, sz)))

    def init_weights(self) -> None:
        """Initialize weights."""
        initrange: float = 0.1
        nn.init.uniform_(self.input_emb.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def forward(self, src: torch.Tensor, has_mask: bool = True) -> torch.Tensor:
        if has_mask:
            device: torch.device = src.device
            if self.src_mask is None or self.src_mask.size(0) != len(src):
                mask: torch.Tensor = self._generate_square_subsequent_mask(len(src)).to(device)
                self.src_mask = mask
        else:
            self.src_mask = None

        src = self.input_emb(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output: torch.Tensor = self.encoder(src, mask=self.src_mask)
        output = self.decoder(output)
        return F.log_softmax(output, dim=-1)

In [50]:
# Label smoothing loss
class LabelSmoothingLoss(nn.Module):
    def __init__(self, smoothing: float = 0.0) -> None:
        super(LabelSmoothingLoss, self).__init__()
        self.smoothing: float = smoothing
        self.confidence: float = 1.0 - smoothing

    def forward(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        log_probs: torch.Tensor = output
        n_classes: int = log_probs.size(-1)
        with torch.no_grad():
            true_dist: torch.Tensor = torch.zeros_like(log_probs)
            true_dist.fill_(self.smoothing / (n_classes - 1))
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        return torch.mean(torch.sum(-true_dist * log_probs, dim=-1))

### Test train on original data

In [51]:
# ===============================
# TRAINING & EVALUATION (main.py)
# ===============================

def get_lr(step: float, d_model: float, warmup_steps: int) -> float:
    """Gets the learning rate step."""
    lr: float = d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
    return lr

def batchify(data: torch.Tensor, bsz: int, device: torch.device) -> torch.Tensor:
    """Divide data into batches."""
    nbatch: int = data.size(0) // bsz
    data = data.narrow(0, 0, nbatch * bsz)
    data = data.view(bsz, -1).t().contiguous()
    return data.to(device)


def get_batch(source: torch.Tensor, i: int, bptt: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Get a batch of data."""
    seq_len: int = min(bptt, len(source) - 1 - i)
    data: torch.Tensor = source[i:i+seq_len]
    target: torch.Tensor = source[i+1:i+1+seq_len].view(-1)
    return data, target


def repackage_hidden(h: Union[torch.Tensor, Tuple]) -> Union[torch.Tensor, Tuple]:
    """Detach hidden state from history."""
    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)

def top_k_sampling(logits: torch.Tensor, k: int, temperature: float = 1.0) -> Tuple[torch.Tensor, torch.Tensor]:
    """Top k-sampling in generate function."""
    values, indices = torch.topk(logits, k)
    values = values.div(temperature).exp()
    values = values / values.sum()
    return torch.multinomial(values, 1), indices

def evaluate(
    model: nn.Module,
    data_source: torch.Tensor,
    criterion: nn.Module,
    bptt: int,
    ntokens: int,
    eval_batch_size: int,
    is_transformer: bool
) -> float:
    """Evaluate the model."""
    model.eval()
    total_loss: float = 0.0
    if not is_transformer:
        hidden = model.init_hidden(eval_batch_size)

    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data: torch.Tensor
            targets: torch.Tensor
            data, targets = get_batch(data_source, i, bptt)

            if is_transformer:
                output: torch.Tensor = model(data)
                output = output.view(-1, ntokens)
            else:
                output, hidden = model(data, hidden)
                hidden = repackage_hidden(hidden)

            total_loss += len(data) * criterion(output, targets).item()

    return total_loss / (len(data_source) - 1)


def train_epoch(
    model: nn.Module,
    train_data: torch.Tensor,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    epoch: int,
    bptt: int,
    ntokens: int,
    batch_size: int,
    clip: float,
    log_interval: int,
    is_transformer: bool,
    use_optimizer: bool = True,
    use_warmup: bool = False,
    step: int = 0,
    d_model: int = 512,
    warmup_steps: int = 4000,
    dry_run: bool = False
) -> int:
    """Train for one epoch."""
    model.train()
    total_loss: float = 0.0
    start_time: float = time.time()

    if not is_transformer:
        hidden = model.init_hidden(batch_size)

    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        data: torch.Tensor
        targets: torch.Tensor
        data, targets = get_batch(train_data, i, bptt)

        optimizer.zero_grad() if use_optimizer else model.zero_grad()

        if is_transformer:
            output: torch.Tensor = model(data)
            output = output.view(-1, ntokens)
        else:
            hidden = repackage_hidden(hidden)
            output, hidden = model(data, hidden)

        loss: torch.Tensor = criterion(output, targets)
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        if use_warmup:
            for param_group in optimizer.param_groups:
                param_group['lr'] = get_lr(step + 1, d_model, warmup_steps)
        if use_optimizer:
            optimizer.step()
        else:
            for p in model.parameters():
                p.data.add_(p.grad, alpha=-optimizer.param_groups[0]['lr'])

        total_loss += loss.item()
        step += 1

        if batch % log_interval == 0 and batch > 0:
            cur_loss: float = total_loss / log_interval
            elapsed: float = time.time() - start_time
            print(
                f'| epoch {epoch:3d} | {batch:5d}/{len(train_data) // bptt:5d} batches | '
                f'lr {optimizer.param_groups[0]["lr"]:02.6f} | ms/batch {elapsed * 1000 / log_interval:5.2f} | '
                f'loss {cur_loss:5.2f} | ppl {math.exp(cur_loss):8.2f}'
            )
            total_loss = 0
            start_time = time.time()
        if dry_run:
            break
    return step


def train_model(
    model_type: str = 'LSTM', # RNN_TANH, RNN_RELU, LSTM, GRU, Transformer
    data_path: str = '/content/drive/MyDrive/data_word_train/wikitext-2',
    emsize: int = 200,
    nhid: int = 200,
    nlayers: int = 2,
    lr: float = 0.001,
    clip: float = 0.25,
    epochs: int = 60,
    batch_size: int = 20,
    bptt: int = 35,
    dropout: float = 0.2,
    tied: bool = False,
    nhead: int = 2,
    log_interval: int = 200,
    save_path: str = 'model.pt',
    onnx_export: str = '',
    dry_run: bool = False,
    accel: bool = True,
    use_optimizer: bool = True,
    optimizer_type: str = 'AdamW',
    weight_decay: Optional[float] = None,
    use_betas: bool = False,
    betas: Optional[Tuple[float, float]] = (0.9, 0.98),
    use_eps: bool = False,
    eps: float = 1e-9,
    criterion: Optional[nn.Module] = None,
    use_label_smoothing: bool = False,
    label_smoothing: float = 0.1,
    use_warmup: bool = False,
    warmup_steps: int = 4000,
    min_freq: int = 5,
    seed: int = 1111,
    old_version: bool = True
) -> None:
    """Main training function."""

    if data_path == '/content/drive/MyDrive/data_word_train/wikitext-2':
      if not os.path.exists(data_path):
          print("Downloading Wikitext-2 dataset...")
          !wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip -P /content/
          !unzip /content/wikitext-2-v1.zip -d /content/data_word_train/
          !mkdir -p /content/drive/MyDrive/data_word_train/
          !mv /content/data_word_train/wikitext-2 /content/drive/MyDrive/data_word_train/
          print("Wikitext-2 dataset moved to Google Drive")

    torch.manual_seed(seed)

    # Set device
    device: torch.device = torch.device('cuda' if accel and torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load data
    corpus: Union[Corpus, Old_Corpus] = Old_Corpus(data_path) if old_version else Corpus(data_path, min_freq=min_freq)
    print(f"Vocabulary size: {len(corpus.dictionary)}")

    eval_batch_size: int = 10
    train_data: torch.Tensor = batchify(corpus.train, batch_size, device)
    val_data: torch.Tensor = batchify(corpus.valid, eval_batch_size, device)
    test_data: torch.Tensor = batchify(corpus.test, eval_batch_size, device)

    # Build model
    ntokens: int = len(corpus.dictionary)
    is_transformer: bool = model_type == 'Transformer'

    model: nn.Module = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device) if is_transformer else RNNModel(model_type, ntokens, emsize, nhid, nlayers, dropout, tied).to(device)
    # Loss and optimizer (Adam with weight_decay as in Transformer paper)
    criterion: nn.Module = criterion if criterion is not None else (LabelSmoothingLoss(smoothing=label_smoothing) if use_label_smoothing else nn.NLLLoss())
    if use_betas == True and use_eps == True:
        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay, betas=betas, eps=eps) if optimizer_type == 'AdamW' else optim.Adam(model.parameters(), lr=lr, betas=betas, eps=eps)
    elif use_betas == False and use_eps == True:
        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay, eps=eps) if optimizer_type == 'AdamW' else optim.Adam(model.parameters(), lr=lr, eps=eps)
    elif use_betas == False and use_eps == False:
        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) if optimizer_type == 'AdamW' else optim.Adam(model.parameters(), lr=lr)

    scheduler: optim.lr_scheduler.ReduceLROnPlateau = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2
    ) if not use_warmup else None

    # Training loop
    best_val_loss: Optional[float] = None
    global_step: int = 0

    try:
        for epoch in range(1, epochs + 1):
            epoch_start_time: float = time.time()

            global_step = train_epoch(
                model, train_data, criterion, optimizer, epoch, bptt, ntokens, batch_size, clip, log_interval, is_transformer,
                use_optimizer, use_warmup, global_step, emsize, warmup_steps, dry_run
            )

            val_loss: float = evaluate(
                model, val_data, criterion, bptt, ntokens,
                eval_batch_size, is_transformer
            )

            print('-' * 89)
            print(
                f'| end of epoch {epoch:3d} | time: {time.time() - epoch_start_time:5.2f}s | '
                f'valid loss {val_loss:5.2f} | valid ppl {math.exp(val_loss):8.2f}'
            )
            print('-' * 89)

            # Save best model
            if not best_val_loss or val_loss < best_val_loss:
                with open(save_path, 'wb') as f:
                    torch.save(model, f)
                best_val_loss = val_loss

            # Learning rate scheduling
            # lr /= 4.0
            if use_warmup:
                print(f"Current learning rate: {optimizer.param_groups[0]['lr']:.6f}")
            else:
                scheduler.step(val_loss)
                print(f"Current learning rate: {optimizer.param_groups[0]['lr']:.6f}")

    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early')
    safe_globals: List = [
        RNNModel, TransformerModel, PositionalEncoding, Old_Dictionary, Dictionary, Old_Corpus, Corpus,
        nn.Dropout, nn.Linear, nn.GRU, nn.LSTM, nn.RNN, nn.Embedding,
        nn.TransformerEncoder, nn.TransformerEncoderLayer, nn.MultiheadAttention,
        nn.LayerNorm, F.relu, nn.ModuleList, nn.modules.linear.NonDynamicallyQuantizableLinear
    ]
    with torch.serialization.safe_globals(safe_globals):
        with open(save_path, 'rb') as f:
            model = torch.load(f, map_location=device)
    test_loss: float = evaluate(
        model, test_data, criterion, bptt, ntokens, eval_batch_size, is_transformer
    )
    print('=' * 89)
    print(
        f'| End of training | test loss {test_loss:5.2f} | '
        f'test ppl {math.exp(test_loss):8.2f}'
    )
    print('=' * 89)


In [54]:
# ===============================
# TEXT GENERATION (generate.py)
# ===============================

def generate_text(
    checkpoint: str = 'model.pt',
    data_path: str = '/content/drive/MyDrive/data_word_train/wikitext-2',
    outf: str = 'generated.txt',
    words: int = 1000,
    temperature: float = 1.0,
    top_k: int = 40,
    seed: int = 1111,
    log_interval: int = 100,
    accel: bool = True,
    min_freq: int = 5,
    use_top_k: bool = False,
    old_version: bool = True
) -> None:
    """Generate text from trained model."""

    torch.manual_seed(seed)
    device: torch.device = torch.device('cuda' if accel and torch.cuda.is_available() else 'cpu')
    assert len(corpus.dictionary) == model.ntoken, f"Vocabulary size mismatch: {len(corpus.dictionary)} vs {model.ntoken}"

    # Load model
    safe_globals: List = [
        RNNModel, TransformerModel, PositionalEncoding, Old_Dictionary, Dictionary, Old_Corpus, Corpus,
        nn.Dropout, nn.Linear, nn.GRU, nn.LSTM, nn.RNN, nn.Embedding,
        nn.TransformerEncoder, nn.TransformerEncoderLayer, nn.MultiheadAttention,
        nn.LayerNorm, F.relu, nn.ModuleList, nn.modules.linear.NonDynamicallyQuantizableLinear
    ]
    with torch.serialization.safe_globals(safe_globals):
        with open(checkpoint, 'rb') as f:
            model: nn.Module = torch.load(f, map_location=device)
    model.eval()

    # Load corpus
    corpus: Union[Corpus, Old_Corpus] = Old_Corpus(data_path) if old_version else Corpus(data_path, min_freq=min_freq)
    print(f"Vocabulary size: {len(corpus.dictionary)}")
    print(f"Vocabulary size: {len(corpus.dictionary)}")
    ntokens: int = len(corpus.dictionary)

    is_transformer: bool = hasattr(model, 'model_type') and model.model_type == 'Transformer'
    if not is_transformer:
        hidden = model.init_hidden(1)

    input: torch.Tensor = torch.randint(ntokens, (1, 1), dtype=torch.long).to(device)

    with open(outf, 'w') as outfile:
        with torch.no_grad():
            for i in range(words):
                if is_transformer:
                    output: torch.Tensor = model(input, False)
                    if use_top_k:
                        word_weights: torch.Tensor = output[-1].squeeze().cpu()
                        prob, top_indices = top_k_sampling(word_weights, top_k, temperature)
                        word_idx: int = top_indices[prob.item()].item()
                    else:
                        word_weights: torch.Tensor = output[-1].squeeze().div(temperature).exp().cpu()
                        word_idx: int = torch.multinomial(word_weights, 1)[0].item()
                    word_tensor: torch.Tensor = torch.Tensor([[word_idx]]).long().to(device)
                    input = torch.cat([input, word_tensor], 0)

                else:
                    output, hidden = model(input, hidden)
                    word_weights = output.squeeze().div(temperature).exp().cpu()
                    word_idx = torch.multinomial(word_weights, 1)[0].item()
                    input.fill_(word_idx)

                word: str = corpus.dictionary.idx2word[word_idx]
                if word == '@-@' or word == '@.@' or word == '@,@':
                  word = ' '
                outfile.write(word + ('\n' if i % 20 == 19 else ' '))

                if i % log_interval == 0:
                    print(f'| Generated {i}/{words} words')

In [29]:
# Example 1: Train on WikiText-2
print("Training LSTM on WikiText-2...")
train_model(
    model_type = 'LSTM', # RNN_TANH, RNN_RELU, LSTM, GRU, Transformer
    data_path = '/content/drive/MyDrive/data_word_train/wikitext-2',
    emsize = 400,
    nhid = 400,
    nlayers = 4,
    lr = 0.001,
    clip = 0.25,
    epochs = 60,
    batch_size = 20,
    bptt = 35,
    dropout = 0.2,
    tied = False,
    nhead = 4,
    log_interval = 200,
    save_path = 'model.pt',
    onnx_export = '',
    dry_run = False,
    accel = True,
    use_optimizer = True,
    optimizer_type = 'AdamW',
    weight_decay=1e-5,
    use_betas = False,
    use_eps = False,
    criterion = nn.NLLLoss(),
    use_label_smoothing = False,
    label_smoothing = 0.1,
    use_warmup = False,
    warmup_steps = 4000,
    min_freq = 5,
    seed = 1111,
    old_version = True
)

# Example 2: Generate text
print("\nGenerating text...")
generate_text(
    checkpoint='model.pt',
    data_path='/content/drive/MyDrive/data_word_train/wikitext-2',
    words=1000,
    temperature=1.0,
    old_version=True,
    use_top_k=False,
    accel = True
)

!cat generated.txt

# Example 3: Train on custom names dataset
# First, create the data files (see instructions below)
print("\nTraining on custom names dataset...")
# train_model(
#     model_type='LSTM',
#     data_path='./data/names',
#     emsize=128,
#     nhid=128,
#     nlayers=2,
#     epochs=20,
#     lr=0.001
# )

Training LSTM on WikiText-2...
Using device: cuda
Vocabulary size: 33278
| epoch   1 |   200/ 2983 batches | lr 0.001000 | ms/batch 24.86 | loss  7.33 | ppl  1522.35
| epoch   1 |   400/ 2983 batches | lr 0.001000 | ms/batch 24.81 | loss  7.10 | ppl  1207.10
| epoch   1 |   600/ 2983 batches | lr 0.001000 | ms/batch 24.89 | loss  7.09 | ppl  1196.08
| epoch   1 |   800/ 2983 batches | lr 0.001000 | ms/batch 24.92 | loss  7.09 | ppl  1196.80
| epoch   1 |  1000/ 2983 batches | lr 0.001000 | ms/batch 24.99 | loss  7.10 | ppl  1214.24
| epoch   1 |  1200/ 2983 batches | lr 0.001000 | ms/batch 25.00 | loss  7.11 | ppl  1228.74
| epoch   1 |  1400/ 2983 batches | lr 0.001000 | ms/batch 25.00 | loss  6.94 | ppl  1031.50
| epoch   1 |  1600/ 2983 batches | lr 0.001000 | ms/batch 25.05 | loss  6.63 | ppl   758.50
| epoch   1 |  1800/ 2983 batches | lr 0.001000 | ms/batch 25.11 | loss  6.46 | ppl   636.73
| epoch   1 |  2000/ 2983 batches | lr 0.001000 | ms/batch 25.14 | loss  6.40 | ppl   602.

In [None]:
print("Training Transformer on WikiText-2...")
train_model(
    model_type='Transformer',
    data_path='/content/drive/MyDrive/data_word_train/wikitext-2',
    emsize=512,
    nhid=2048,
    nlayers=6,
    epochs=60,
    lr=0.0003,
    batch_size = 32,
    bptt = 50,
    dropout = 0.1,
    nhead = 8,
    save_path='model_1.pt',
    use_optimizer = True,
    optimizer_type = 'AdamW',
    weight_decay=1e-4,
    use_betas=True,
    betas=(0.9, 0.98),
    use_eps=True,
    eps=1e-9,
    criterion = nn.NLLLoss(),
    use_label_smoothing = False,
    use_warmup = False,
    seed = 1111,
    old_version = True
)

print("\nGenerating text...")
generate_text(
    checkpoint='model_1.pt',
    data_path='/content/drive/MyDrive/data_word_train/wikitext-2',
    words=1000,
    temperature=1.0,
    seed=1111,
    old_version=True,
    use_top_k=False,
    accel = True
)

In [47]:
print("Training Transformer on WikiText-2...")
train_model(
    model_type='Transformer',
    data_path='/content/drive/MyDrive/data_word_train/wikitext-2',
    emsize=512,
    nhid=2048,
    nlayers=6,
    lr=0.0001,
    clip=0.25,
    epochs=60,
    batch_size=32,
    bptt=50,
    dropout=0.1,
    tied=False,
    nhead=8,
    log_interval=200,
    save_path='model_2.pt',
    dry_run=False,
    accel=True,
    use_optimizer=True,
    optimizer_type='AdamW',
    weight_decay=1e-4,
    use_betas=True,
    betas=(0.9, 0.98),
    use_eps=True,
    eps=1e-9,
    criterion=None,
    use_label_smoothing=True,
    label_smoothing=0.1,
    use_warmup=True,
    warmup_steps=4000,
    min_freq=5,
    seed=1111,
    old_version=False
    # batch_size=64,  # Увеличенный батч для лучшей сходимости
    # bptt=64,     # Увеличенная длина последовательности
)

print("\nGenerating text...")
generate_text(
    checkpoint='model_2.pt',
    data_path='/content/drive/MyDrive/data_word_train/wikitext-2',
    outf='generated_2.txt',
    words=1000,
    temperature=1.0,
    top_k=40,
    seed=1111,
    log_interval=100,
    use_top_k=True,  # Use top-k sampling
    min_freq=5
)

!cat generated_2.txt

Training Transformer on WikiText-2...


AcceleratorError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [56]:
# ===============================
# DEBUG SETUP
# ===============================
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # Synchronous CUDA errors
os.environ['TORCH_USE_CUDA_DSA'] = '1'    # Device-side asserts
print("Training Transformer on WikiText-2...")
train_model(
    model_type='Transformer',
    data_path='/content/drive/MyDrive/data_word_train/wikitext-2',
    emsize=512,
    nhid=2048,
    nlayers=6,
    lr=0.0001,
    clip=0.25,
    epochs=1,
    batch_size=32,
    bptt=50,
    dropout=0.1,
    tied=False,
    nhead=8,
    log_interval=200,
    save_path='model_3.pt',
    dry_run=False,
    accel=True,
    use_optimizer=True,
    optimizer_type='AdamW',
    weight_decay=1e-4,
    use_betas=True,
    betas=(0.9, 0.98),
    use_eps=True,
    eps=1e-9,
    criterion=None,
    use_label_smoothing=True,
    label_smoothing=0.1,
    use_warmup=True,
    warmup_steps=4000,
    min_freq=5,
    seed=1111,
    old_version=False
    # batch_size=64,  # Увеличенный батч для лучшей сходимости
    # bptt=64,     # Увеличенная длина последовательности
)

print("\nGenerating text...")
generate_text(
    checkpoint='model_3.pt',
    data_path='/content/drive/MyDrive/data_word_train/wikitext-2',
    outf='generated_3.txt',
    words=1000,
    temperature=1.0,
    top_k=40,
    seed=1111,
    log_interval=100,
    use_top_k=True,  # Use top-k sampling
    min_freq=5,
    old_version=False
)

!cat generated_2.txt

AcceleratorError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


### MAIN PART

In [None]:
# Step 4: Data Preparation for Custom Data (English-French dataset as example)
# Extract French sentences and create train/valid/test.txt
DATA_PATH = '/content/drive/MyDrive/data_word_train/custom'
INPUT_FILE = '/content/drive/MyDrive/data/eng-fra.txt'
os.makedirs(DATA_PATH, exist_ok=True)
random.seed(1111)

with open(INPUT_FILE, 'r', encoding='utf-8') as f:
    lines = f.readlines()
french_sentences = [line.strip().split('\t')[1] for line in lines if len(line.strip().split('\t')) == 2]

random.shuffle(french_sentences)
n = len(french_sentences)
train_end = int(n * 0.8)
valid_end = train_end + int(n * 0.1)
train_data = french_sentences[:train_end]
valid_data = french_sentences[train_end:valid_end]
test_data = french_sentences[valid_end:]

def save_sentences(sentences, filename):
    with open(filename, 'w', encoding='utf-8') as f:
        for sentence in sentences:
            sentence = sentence.replace('.', ' .').replace(',', ' ,').replace('!', ' !').replace('?', ' ?')
            f.write(sentence + ' <eos>\n')

save_sentences(train_data, os.path.join(DATA_PATH, 'train.txt'))
save_sentences(valid_data, os.path.join(DATA_PATH, 'valid.txt'))
save_sentences(test_data, os.path.join(DATA_PATH, 'test.txt'))

print(f"Created datasets: {len(train_data)} train, {len(valid_data)} valid, {len(test_data)} test sentences")

# Step 5: Training Function from main.py (adapted for Colab)
# Define args as a class for Colab
class Args:
    def __init__(self):
        self.data = '/content/drive/MyDrive/data_word_train/custom'  # Custom data path
        self.model = 'LSTM'  # RNN_TANH, RNN_RELU, LSTM, GRU, Transformer
        self.emsize = 200
        self.nhid = 200
        self.nlayers = 2
        self.lr = 0.001
        self.clip = 0.25
        self.epochs = 20  # Reduced for faster training
        self.batch_size = 20
        self.bptt = 35
        self.dropout = 0.2
        self.tied = False
        self.seed = 1111
        self.log_interval = 200
        self.save = 'model.pt'
        self.onnx_export = ''
        self.nhead = 2
        self.dry_run = False
        self.accel = True
        self.use_optimizer = True  # Use AdamW

args = Args()

# Set the random seed manually for reproducibility.

torch.manual_seed(args.seed)

if args.accel and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print("Using device:", device)


###############################################################################
# Load data
###############################################################################

corpus = Corpus(args.data)

# Starting from sequential data, batchify arranges the dataset into columns.
# For instance, with the alphabet as the sequence and batch size 4, we'd get
# ┌ a g m s ┐
# │ b h n t │
# │ c i o u │
# │ d j p v │
# │ e k q w │
# └ f l r x ┘.
# These columns are treated as independent by the model, which means that the
# dependence of e. g. 'g' on 'f' can not be learned, but allows more efficient
# batch processing.

def batchify(data, bsz):
    nbatch = data.size(0) // bsz
    data = data.narrow(0, 0, nbatch * bsz)
    data = data.view(bsz, -1).t().contiguous()
    return data.to(device)

eval_batch_size = 10
train_data = batchify(corpus.train, args.batch_size)
val_data = batchify(corpus.valid, eval_batch_size)
test_data = batchify(corpus.test, eval_batch_size)

###############################################################################
# Build the model
###############################################################################

ntokens = len(corpus.dictionary)
if args.model == 'Transformer':
    model = TransformerModel(ntokens, args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout).to(device)
else:
    model = RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers, args.dropout, args.tied).to(device)

criterion = nn.NLLLoss()
if args.use_optimizer:
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)

###############################################################################
# Training code
###############################################################################

def repackage_hidden(h):
    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)

# get_batch subdivides the source data into chunks of length args.bptt.
# If source is equal to the example output of the batchify function, with
# a bptt-limit of 2, we'd get the following two Variables for i = 0:
# ┌ a g m s ┐ ┌ b h n t ┐
# └ b h n t ┘ └ c i o u ┘
# Note that despite the name of the function, the subdivison of data is not
# done along the batch dimension (i.e. dimension 1), since that was handled
# by the batchify function. The chunks are along dimension 0, corresponding
# to the seq_len dimension in the LSTM.

def get_batch(source, i):
    seq_len = min(args.bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].view(-1)
    return data, target

def evaluate(data_source):
    model.eval()
    total_loss = 0.
    ntokens = len(corpus.dictionary)
    if args.model != 'Transformer':
        hidden = model.init_hidden(eval_batch_size)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, args.bptt):
            data, targets = get_batch(data_source, i)
            if args.model == 'Transformer':
                output = model(data)
                output = output.view(-1, ntokens)
            else:
                output, hidden = model(data, hidden)
                hidden = repackage_hidden(hidden)
            total_loss += len(data) * criterion(output, targets).item()
    return total_loss / (len(data_source) - 1)

def train_func():
    model.train()
    total_loss = 0.
    start_time = time.time()
    ntokens = len(corpus.dictionary)
    if args.model != 'Transformer':
        hidden = model.init_hidden(args.batch_size)
    for batch, i in enumerate(range(0, train_data.size(0) - 1, args.bptt)):
        data, targets = get_batch(train_data, i)
        optimizer.zero_grad() if args.use_optimizer else model.zero_grad()
        if args.model == 'Transformer':
            output = model(data)
            output = output.view(-1, ntokens)
        else:
            hidden = repackage_hidden(hidden)
            output, hidden = model(data, hidden)
        loss = criterion(output, targets)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        if args.use_optimizer:
            optimizer.step()
        else:
            for p in model.parameters():
                p.data.add_(p.grad, alpha=-lr)

        total_loss += loss.item()

        if batch % args.log_interval == 0 and batch > 0:
            cur_loss = total_loss / args.log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                epoch, batch, len(train_data) // args.bptt, lr,
                elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()
        if args.dry_run:
            break

lr = args.lr
best_val_loss = None

try:
    for epoch in range(1, args.epochs + 1):
        epoch_start_time = time.time()
        train_func()
        val_loss = evaluate(val_data)
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
              'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                         val_loss, math.exp(val_loss)))
        print('-' * 89)
        if not best_val_loss or val_loss < best_val_loss:
            with open(args.save, 'wb') as f:
                torch.save(model, f)
            best_val_loss = val_loss
        else:
            lr /= 4.0
except KeyboardInterrupt:
    print('-' * 89)
    print('Exiting from training early')

# Load the best saved model.
with open(args.save, 'rb') as f:
    model = torch.load(f)

# Run on test data.
test_loss = evaluate(test_data)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
    test_loss, math.exp(test_loss)))
print('=' * 89)

# Step 6: Generate Text from generate.py (adapted for Colab)
checkpoint = 'model.pt'  # Your saved model
outf = 'generated.txt'
words = 1000
temperature = 1.0
log_interval = 100
accel = True  # Use CUDA

torch.manual_seed(1111)

if accel and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

with open(checkpoint, 'rb') as f:
    model = torch.load(f, map_location=device)
model.eval()

corpus = Corpus(args.data)
ntokens = len(corpus.dictionary)

is_transformer_model = hasattr(model, 'model_type') and model.model_type == 'Transformer'
if not is_transformer_model:
    hidden = model.init_hidden(1)
input = torch.randint(ntokens, (1, 1), dtype=torch.long).to(device)

with open(outf, 'w') as outf:
    with torch.no_grad():
        for i in range(words):
            if is_transformer_model:
                output = model(input, False)
                word_weights = output[-1].squeeze().div(temperature).exp().cpu()
                word_idx = torch.multinomial(word_weights, 1)[0]
                word_tensor = torch.Tensor([[word_idx]]).long().to(device)
                input = torch.cat([input, word_tensor], 0)
            else:
                output, hidden = model(input, hidden)
                word_weights = output.squeeze().div(temperature).exp().cpu()
                word_idx = torch.multinomial(word_weights, 1)[0]
                input.fill_(word_idx)

            word = corpus.dictionary.idx2word[word_idx]

            outf.write(word + ('\n' if i % 20 == 19 else ' '))

            if i % log_interval == 0:
                print('| Generated {}/{} words'.format(i, words))

# Print generated text
!head -n 20 generated.txt

### Names dataset learn

In [None]:
"""
Improved Word-Level Language Modeling with Full Type Hints
Based on PyTorch RNN/Transformer example with enhancements
"""

import os
import math
import time
from io import open
from typing import List, Dict, Tuple, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader, TensorDataset

# ===============================
# DATA PREPARATION (data.py)
# ===============================

class Dictionary:
    """Dictionary for word-to-index mapping."""

    def __init__(self) -> None:
        self.word2idx: Dict[str, int] = {}
        self.idx2word: List[str] = []

    def add_word(self, word: str) -> int:
        """Add a word to the dictionary."""
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def __len__(self) -> int:
        return len(self.idx2word)


class Corpus:
    """Corpus class for loading and tokenizing text data."""

    def __init__(self, path: str) -> None:
        self.dictionary: Dictionary = Dictionary()
        self.train: torch.Tensor = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid: torch.Tensor = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test: torch.Tensor = self.tokenize(os.path.join(path, 'test.txt'))

    def tokenize(self, path: str) -> torch.Tensor:
        """Tokenizes a text file."""
        assert os.path.exists(path), f"Path {path} does not exist"

        # Add words to the dictionary
        with open(path, 'r', encoding="utf8") as f:
            for line in f:
                words: List[str] = line.split() + ['<eos>']
                for word in words:
                    self.dictionary.add_word(word)

        # Tokenize file content
        with open(path, 'r', encoding="utf8") as f:
            idss: List[torch.Tensor] = []
            for line in f:
                words: List[str] = line.split() + ['<eos>']
                ids: List[int] = []
                for word in words:
                    ids.append(self.dictionary.word2idx[word])
                idss.append(torch.tensor(ids, dtype=torch.int64))
            ids_tensor: torch.Tensor = torch.cat(idss)

        return ids_tensor


# ===============================
# MODEL ARCHITECTURES (model.py)
# ===============================

class PositionalEncoding(nn.Module):
    """Positional encoding for Transformer."""

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000) -> None:
        super(PositionalEncoding, self).__init__()
        self.dropout: nn.Dropout = nn.Dropout(p=dropout)

        pe: torch.Tensor = torch.zeros(max_len, d_model)
        position: torch.Tensor = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term: torch.Tensor = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


class RNNModel(nn.Module):
    """RNN-based language model (LSTM/GRU/RNN)."""

    def __init__(
        self,
        rnn_type: str,
        ntoken: int,
        ninp: int,
        nhid: int,
        nlayers: int,
        dropout: float = 0.5,
        tie_weights: bool = False
    ) -> None:
        super(RNNModel, self).__init__()
        self.ntoken: int = ntoken
        self.rnn_type: str = rnn_type
        self.nhid: int = nhid
        self.nlayers: int = nlayers

        self.drop: nn.Dropout = nn.Dropout(dropout)
        self.encoder: nn.Embedding = nn.Embedding(ntoken, ninp)

        if rnn_type in ['LSTM', 'GRU']:
            self.rnn: nn.Module = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)
        else:
            try:
                nonlinearity: str = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}[rnn_type]
            except KeyError as e:
                raise ValueError(
                    "Invalid option for `--model`. "
                    "Options are ['LSTM', 'GRU', 'RNN_TANH', 'RNN_RELU']"
                ) from e
            self.rnn = nn.RNN(ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout)

        self.decoder: nn.Linear = nn.Linear(nhid, ntoken)

        # Tie weights
        if tie_weights:
            if nhid != ninp:
                raise ValueError('When using tied flag, nhid must be equal to emsize')
            self.decoder.weight = self.encoder.weight

        self.init_weights()

    def init_weights(self) -> None:
        """Initialize weights."""
        initrange: float = 0.1
        nn.init.uniform_(self.encoder.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        emb: torch.Tensor = self.drop(self.encoder(input))
        output: torch.Tensor
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded: torch.Tensor = self.decoder(output)
        decoded = decoded.view(-1, self.ntoken)
        return F.log_softmax(decoded, dim=1), hidden

    def init_hidden(self, bsz: int) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Initialize hidden state."""
        weight: torch.Tensor = next(self.parameters())
        if self.rnn_type == 'LSTM':
            return (
                weight.new_zeros(self.nlayers, bsz, self.nhid),
                weight.new_zeros(self.nlayers, bsz, self.nhid)
            )
        else:
            return weight.new_zeros(self.nlayers, bsz, self.nhid)


class TransformerModel(nn.Transformer):
    """Transformer-based language model."""

    def __init__(
        self,
        ntoken: int,
        ninp: int,
        nhead: int,
        nhid: int,
        nlayers: int,
        dropout: float = 0.5
    ) -> None:
        super(TransformerModel, self).__init__(
            d_model=ninp,
            nhead=nhead,
            dim_feedforward=nhid,
            num_encoder_layers=nlayers
        )
        self.model_type: str = 'Transformer'
        self.src_mask: Optional[torch.Tensor] = None
        self.pos_encoder: PositionalEncoding = PositionalEncoding(ninp, dropout)

        self.input_emb: nn.Embedding = nn.Embedding(ntoken, ninp)
        self.ninp: int = ninp
        self.decoder: nn.Linear = nn.Linear(ninp, ntoken)

        self.init_weights()

    def _generate_square_subsequent_mask(self, sz: int) -> torch.Tensor:
        """Generate mask for causal attention."""
        return torch.log(torch.tril(torch.ones(sz, sz)))

    def init_weights(self) -> None:
        """Initialize weights."""
        initrange: float = 0.1
        nn.init.uniform_(self.input_emb.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def forward(self, src: torch.Tensor, has_mask: bool = True) -> torch.Tensor:
        if has_mask:
            device: torch.device = src.device
            if self.src_mask is None or self.src_mask.size(0) != len(src):
                mask: torch.Tensor = self._generate_square_subsequent_mask(len(src)).to(device)
                self.src_mask = mask
        else:
            self.src_mask = None

        src = self.input_emb(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output: torch.Tensor = self.encoder(src, mask=self.src_mask)
        output = self.decoder(output)
        return F.log_softmax(output, dim=-1)


# ===============================
# TRAINING & EVALUATION (main.py)
# ===============================

def batchify(data: torch.Tensor, bsz: int, device: torch.device) -> torch.Tensor:
    """Divide data into batches."""
    nbatch: int = data.size(0) // bsz
    data = data.narrow(0, 0, nbatch * bsz)
    data = data.view(bsz, -1).t().contiguous()
    return data.to(device)


def get_batch(source: torch.Tensor, i: int, bptt: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Get a batch of data."""
    seq_len: int = min(bptt, len(source) - 1 - i)
    data: torch.Tensor = source[i:i+seq_len]
    target: torch.Tensor = source[i+1:i+1+seq_len].view(-1)
    return data, target


def repackage_hidden(h: Union[torch.Tensor, Tuple]) -> Union[torch.Tensor, Tuple]:
    """Detach hidden state from history."""
    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)


def evaluate(
    model: nn.Module,
    data_source: torch.Tensor,
    criterion: nn.Module,
    bptt: int,
    ntokens: int,
    eval_batch_size: int,
    is_transformer: bool
) -> float:
    """Evaluate the model."""
    model.eval()
    total_loss: float = 0.0
    if not is_transformer:
        hidden = model.init_hidden(eval_batch_size)

    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data: torch.Tensor
            targets: torch.Tensor
            data, targets = get_batch(data_source, i, bptt)

            if is_transformer:
                output: torch.Tensor = model(data)
                output = output.view(-1, ntokens)
            else:
                output, hidden = model(data, hidden)
                hidden = repackage_hidden(hidden)

            total_loss += len(data) * criterion(output, targets).item()

    return total_loss / (len(data_source) - 1)


def train_epoch(
    model: nn.Module,
    train_data: torch.Tensor,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    epoch: int,
    bptt: int,
    ntokens: int,
    batch_size: int,
    clip: float,
    log_interval: int,
    is_transformer: bool
) -> None:
    """Train for one epoch."""
    model.train()
    total_loss: float = 0.0
    start_time: float = time.time()

    if not is_transformer:
        hidden = model.init_hidden(batch_size)

    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        data: torch.Tensor
        targets: torch.Tensor
        data, targets = get_batch(train_data, i, bptt)

        optimizer.zero_grad()

        if is_transformer:
            output: torch.Tensor = model(data)
            output = output.view(-1, ntokens)
        else:
            hidden = repackage_hidden(hidden)
            output, hidden = model(data, hidden)

        loss: torch.Tensor = criterion(output, targets)
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        total_loss += loss.item()

        if batch % log_interval == 0 and batch > 0:
            cur_loss: float = total_loss / log_interval
            elapsed: float = time.time() - start_time
            print(
                f'| epoch {epoch:3d} | {batch:5d}/{len(train_data) // bptt:5d} batches | '
                f'ms/batch {elapsed * 1000 / log_interval:5.2f} | '
                f'loss {cur_loss:5.2f} | ppl {math.exp(cur_loss):8.2f}'
            )
            total_loss = 0
            start_time = time.time()


def train_model(
    model_type: str = 'LSTM',
    data_path: str = './data/wikitext-2',
    emsize: int = 200,
    nhid: int = 200,
    nlayers: int = 2,
    lr: float = 0.001,
    clip: float = 0.25,
    epochs: int = 40,
    batch_size: int = 20,
    bptt: int = 35,
    dropout: float = 0.2,
    tied: bool = False,
    nhead: int = 2,
    log_interval: int = 200,
    save_path: str = 'model.pt',
    use_cuda: bool = True
) -> None:
    """Main training function."""

    # Set device
    device: torch.device = torch.device('cuda' if use_cuda and torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load data
    corpus: Corpus = Corpus(data_path)

    eval_batch_size: int = 10
    train_data: torch.Tensor = batchify(corpus.train, batch_size, device)
    val_data: torch.Tensor = batchify(corpus.valid, eval_batch_size, device)
    test_data: torch.Tensor = batchify(corpus.test, eval_batch_size, device)

    # Build model
    ntokens: int = len(corpus.dictionary)
    is_transformer: bool = model_type == 'Transformer'

    if is_transformer:
        model: nn.Module = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)
    else:
        model = RNNModel(model_type, ntokens, emsize, nhid, nlayers, dropout, tied).to(device)

    # Loss and optimizer (Adam with weight_decay as in Transformer paper)
    criterion: nn.NLLLoss = nn.NLLLoss()
    optimizer: optim.Adam = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler: optim.lr_scheduler.ReduceLROnPlateau = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2
    )

    # Training loop
    best_val_loss: Optional[float] = None

    try:
        for epoch in range(1, epochs + 1):
            epoch_start_time: float = time.time()

            train_epoch(
                model, train_data, criterion, optimizer, epoch,
                bptt, ntokens, batch_size, clip, log_interval, is_transformer
            )

            val_loss: float = evaluate(
                model, val_data, criterion, bptt, ntokens,
                eval_batch_size, is_transformer
            )

            print('-' * 89)
            print(
                f'| end of epoch {epoch:3d} | time: {time.time() - epoch_start_time:5.2f}s | '
                f'valid loss {val_loss:5.2f} | valid ppl {math.exp(val_loss):8.2f}'
            )
            print('-' * 89)

            # Save best model
            if not best_val_loss or val_loss < best_val_loss:
                with open(save_path, 'wb') as f:
                    torch.save(model, f)
                best_val_loss = val_loss

            # Learning rate scheduling
            scheduler.step(val_loss)
            print(f"Current learning rate: {optimizer.param_groups[0]['lr']}")

    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early')

    # Load best model and test
    with open(save_path, 'rb') as f:
        model = torch.load(f, map_location=device)

    test_loss: float = evaluate(
        model, test_data, criterion, bptt, ntokens,
        eval_batch_size, is_transformer
    )

    print('=' * 89)
    print(
        f'| End of training | test loss {test_loss:5.2f} | '
        f'test ppl {math.exp(test_loss):8.2f}'
    )
    print('=' * 89)


# ===============================
# TEXT GENERATION (generate.py)
# ===============================

def generate_text(
    checkpoint: str = 'model.pt',
    data_path: str = './data/wikitext-2',
    outf: str = 'generated.txt',
    words: int = 1000,
    temperature: float = 1.0,
    seed: int = 1111,
    log_interval: int = 100,
    use_cuda: bool = True
) -> None:
    """Generate text from trained model."""

    torch.manual_seed(seed)
    device: torch.device = torch.device('cuda' if use_cuda and torch.cuda.is_available() else 'cpu')

    # Load model
    with open(checkpoint, 'rb') as f:
        model: nn.Module = torch.load(f, map_location=device)
    model.eval()

    # Load corpus
    corpus: Corpus = Corpus(data_path)
    ntokens: int = len(corpus.dictionary)

    is_transformer: bool = hasattr(model, 'model_type') and model.model_type == 'Transformer'
    if not is_transformer:
        hidden = model.init_hidden(1)

    input: torch.Tensor = torch.randint(ntokens, (1, 1), dtype=torch.long).to(device)

    with open(outf, 'w') as outfile:
        with torch.no_grad():
            for i in range(words):
                if is_transformer:
                    output: torch.Tensor = model(input, False)
                    word_weights: torch.Tensor = output[-1].squeeze().div(temperature).exp().cpu()
                    word_idx: int = torch.multinomial(word_weights, 1)[0].item()
                    word_tensor: torch.Tensor = torch.Tensor([[word_idx]]).long().to(device)
                    input = torch.cat([input, word_tensor], 0)
                else:
                    output, hidden = model(input, hidden)
                    word_weights = output.squeeze().div(temperature).exp().cpu()
                    word_idx = torch.multinomial(word_weights, 1)[0].item()
                    input.fill_(word_idx)

                word: str = corpus.dictionary.idx2word[word_idx]
                outfile.write(word + ('\n' if i % 20 == 19 else ' '))

                if i % log_interval == 0:
                    print(f'| Generated {i}/{words} words')


# ===============================
# EXAMPLE USAGE
# ===============================

if __name__ == '__main__':
    # Example 1: Train on WikiText-2
    print("Training LSTM on WikiText-2...")
    train_model(
        model_type='LSTM',
        data_path='./data/wikitext-2',
        emsize=400,
        nhid=400,
        nlayers=4,
        epochs=40,
        lr=0.001
    )

    # Example 2: Generate text
    print("\nGenerating text...")
    generate_text(
        checkpoint='model.pt',
        data_path='./data/wikitext-2',
        words=1000,
        temperature=1.0
    )

    # Example 3: Train on custom names dataset
    # First, create the data files (see instructions below)
    print("\nTraining on custom names dataset...")
    # train_model(
    #     model_type='LSTM',
    #     data_path='./data/names',
    #     emsize=128,
    #     nhid=128,
    #     nlayers=2,
    #     epochs=20,
    #     lr=0.001
    # )

In [None]:
"""
Script to prepare names dataset for Word-Level Language Modeling
Converts multiple text files with names into train/valid/test splits
"""

import os
import random
from typing import List, Tuple
from pathlib import Path


def read_names_from_files(data_dir: str) -> List[str]:
    """
    Read all names from text files in the directory.

    Args:
        data_dir: Directory containing .txt files with names

    Returns:
        List of all names
    """
    all_names: List[str] = []

    for filename in Path(data_dir).glob('*.txt'):
        print(f"Reading {filename.name}...")
        with open(filename, 'r', encoding='utf-8') as f:
            names: List[str] = [line.strip() for line in f if line.strip()]
            all_names.extend(names)
            print(f"  Found {len(names)} names")

    print(f"\nTotal names: {len(all_names)}")
    return all_names


def create_train_valid_test_splits(
    names: List[str],
    train_ratio: float = 0.8,
    valid_ratio: float = 0.1,
    test_ratio: float = 0.1,
    seed: int = 42
) -> Tuple[List[str], List[str], List[str]]:
    """
    Split names into train, validation, and test sets.

    Args:
        names: List of names
        train_ratio: Proportion for training (default 0.8)
        valid_ratio: Proportion for validation (default 0.1)
        test_ratio: Proportion for testing (default 0.1)
        seed: Random seed for reproducibility

    Returns:
        Tuple of (train_names, valid_names, test_names)
    """
    assert abs(train_ratio + valid_ratio + test_ratio - 1.0) < 1e-6, \
        "Ratios must sum to 1.0"

    random.seed(seed)
    names_copy: List[str] = names.copy()
    random.shuffle(names_copy)

    n: int = len(names_copy)
    train_end: int = int(n * train_ratio)
    valid_end: int = train_end + int(n * valid_ratio)

    train_names: List[str] = names_copy[:train_end]
    valid_names: List[str] = names_copy[train_end:valid_end]
    test_names: List[str] = names_copy[valid_end:]

    print(f"\nSplit sizes:")
    print(f"  Train: {len(train_names)} ({len(train_names)/n*100:.1f}%)")
    print(f"  Valid: {len(valid_names)} ({len(valid_names)/n*100:.1f}%)")
    print(f"  Test:  {len(test_names)} ({len(test_names)/n*100:.1f}%)")

    return train_names, valid_names, test_names


def save_names_to_file(names: List[str], filename: str) -> None:
    """
    Save names to file, one per line.

    Args:
        names: List of names
        filename: Output filename
    """
    with open(filename, 'w', encoding='utf-8') as f:
        for name in names:
            # Write each name as a separate "sentence"
            # The model will add <eos> automatically
            f.write(name + '\n')
    print(f"Saved {len(names)} names to {filename}")


def prepare_names_dataset(
    input_dir: str,
    output_dir: str,
    train_ratio: float = 0.8,
    valid_ratio: float = 0.1,
    test_ratio: float = 0.1,
    seed: int = 42
) -> None:
    """
    Main function to prepare names dataset.

    Args:
        input_dir: Directory containing raw name files
        output_dir: Directory to save train/valid/test files
        train_ratio: Proportion for training
        valid_ratio: Proportion for validation
        test_ratio: Proportion for testing
        seed: Random seed
    """
    print("="*60)
    print("Preparing Names Dataset for Language Modeling")
    print("="*60)

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Read all names
    all_names: List[str] = read_names_from_files(input_dir)

    # Split into train/valid/test
    train_names, valid_names, test_names = create_train_valid_test_splits(
        all_names, train_ratio, valid_ratio, test_ratio, seed
    )

    # Save splits
    print("\nSaving splits...")
    save_names_to_file(train_names, os.path.join(output_dir, 'train.txt'))
    save_names_to_file(valid_names, os.path.join(output_dir, 'valid.txt'))
    save_names_to_file(test_names, os.path.join(output_dir, 'test.txt'))

    print("\n" + "="*60)
    print("Dataset preparation complete!")
    print(f"Output directory: {output_dir}")
    print("="*60)

    # Print sample names
    print("\nSample names from train set:")
    for name in train_names[:10]:
        print(f"  {name}")


# ===============================
# ALTERNATIVE: Word-level format
# ===============================

def prepare_wordlevel_format(
    input_dir: str,
    output_dir: str,
    train_ratio: float = 0.8,
    valid_ratio: float = 0.1,
    test_ratio: float = 0.1,
    seed: int = 42
) -> None:
    """
    Prepare dataset where each character is a 'word'.
    This allows character-level language modeling using word LM code.

    Args:
        input_dir: Directory containing raw name files
        output_dir: Directory to save train/valid/test files
        train_ratio: Proportion for training
        valid_ratio: Proportion for validation
        test_ratio: Proportion for testing
        seed: Random seed
    """
    print("="*60)
    print("Preparing Character-Level (as words) Dataset")
    print("="*60)

    os.makedirs(output_dir, exist_ok=True)

    # Read all names
    all_names: List[str] = read_names_from_files(input_dir)

    # Split
    train_names, valid_names, test_names = create_train_valid_test_splits(
        all_names, train_ratio, valid_ratio, test_ratio, seed
    )

    # Convert to character-level
    def names_to_char_words(names: List[str], filename: str) -> None:
        """Convert names to space-separated characters."""
        with open(filename, 'w', encoding='utf-8') as f:
            for name in names:
                # Each character becomes a "word"
                char_sequence: str = ' '.join(list(name))
                f.write(char_sequence + '\n')
        print(f"Saved {len(names)} names (as char sequences) to {filename}")

    print("\nSaving character-level splits...")
    names_to_char_words(train_names, os.path.join(output_dir, 'train.txt'))
    names_to_char_words(valid_names, os.path.join(output_dir, 'valid.txt'))
    names_to_char_words(test_names, os.path.join(output_dir, 'test.txt'))

    print("\n" + "="*60)
    print("Character-level dataset preparation complete!")
    print(f"Output directory: {output_dir}")
    print("="*60)

    # Print sample
    print("\nSample character sequences from train set:")
    for name in train_names[:5]:
        char_seq: str = ' '.join(list(name))
        print(f"  {name} -> {char_seq}")


# ===============================
# EXAMPLE USAGE
# ===============================

if __name__ == '__main__':
    # Example 1: Word-level (each name is a word)
    # Best for: learning name distributions
    prepare_names_dataset(
        input_dir='data/names_raw',  # Your directory with German.txt, Russian.txt, etc.
        output_dir='data/names_word',
        train_ratio=0.8,
        valid_ratio=0.1,
        test_ratio=0.1
    )

    # Example 2: Character-level (each character is a word)
    # Best for: generating new names character-by-character
    prepare_wordlevel_format(
        input_dir='data/names_raw',
        output_dir='data/names_char',
        train_ratio=0.8,
        valid_ratio=0.1,
        test_ratio=0.1
    )

    print("\n" + "="*60)
    print("NEXT STEPS:")
    print("="*60)
    print("1. Use 'data/names_word' for name-level modeling")
    print("2. Use 'data/names_char' for character-level modeling")
    print("\nTo train:")
    print("  python improved_word_lm.py --data data/names_word")
    print("\nTo generate:")
    print("  python improved_word_lm.py --generate --checkpoint model.pt")

In [None]:
"""
Complete example with training, evaluation, and improvements
for Word-Level Language Modeling on Names Dataset
"""

import os
import math
import time
from typing import List, Dict, Tuple
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch import optim

# Assuming improved_word_lm.py is imported
from improved_word_lm import (
    Corpus, RNNModel, TransformerModel,
    batchify, evaluate, train_epoch
)


# ===============================
# EXPERIMENT 1: Compare Architectures
# ===============================

def compare_architectures(
    data_path: str = './data/names_char',
    epochs: int = 20,
    device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
) -> Dict[str, Dict]:
    """Compare different model architectures on names dataset."""

    print("="*80)
    print("EXPERIMENT 1: Comparing Model Architectures")
    print("="*80)

    # Load data
    corpus: Corpus = Corpus(data_path)
    ntokens: int = len(corpus.dictionary)

    batch_size: int = 20
    eval_batch_size: int = 10
    bptt: int = 35

    train_data: torch.Tensor = batchify(corpus.train, batch_size, device)
    val_data: torch.Tensor = batchify(corpus.valid, eval_batch_size, device)
    test_data: torch.Tensor = batchify(corpus.test, eval_batch_size, device)

    # Define architectures to test
    architectures: Dict[str, Dict] = {
        'LSTM-Small': {
            'model_type': 'LSTM',
            'emsize': 128,
            'nhid': 128,
            'nlayers': 2,
            'dropout': 0.2
        },
        'LSTM-Large': {
            'model_type': 'LSTM',
            'emsize': 256,
            'nhid': 256,
            'nlayers': 3,
            'dropout': 0.3
        },
        'GRU': {
            'model_type': 'GRU',
            'emsize': 200,
            'nhid': 200,
            'nlayers': 2,
            'dropout': 0.2
        },
        'Transformer': {
            'model_type': 'Transformer',
            'emsize': 200,
            'nhid': 200,
            'nlayers': 2,
            'nhead': 2,
            'dropout': 0.2
        }
    }

    results: Dict[str, Dict] = {}

    for name, config in architectures.items():
        print(f"\n{'='*80}")
        print(f"Training {name}")
        print(f"{'='*80}")

        # Create model
        is_transformer: bool = config['model_type'] == 'Transformer'

        if is_transformer:
            model: nn.Module = TransformerModel(
                ntokens,
                config['emsize'],
                config['nhead'],
                config['nhid'],
                config['nlayers'],
                config['dropout']
            ).to(device)
        else:
            model = RNNModel(
                config['model_type'],
                ntokens,
                config['emsize'],
                config['nhid'],
                config['nlayers'],
                config['dropout'],
                tie_weights=False
            ).to(device)

        # Optimizer with weight decay (from Transformer paper)
        criterion: nn.NLLLoss = nn.NLLLoss()
        optimizer: optim.Adam = optim.Adam(
            model.parameters(),
            lr=0.001,
            weight_decay=1e-5
        )

        # Training
        train_losses: List[float] = []
        val_losses: List[float] = []
        best_val_loss: float = float('inf')

        start_time: float = time.time()

        for epoch in range(1, epochs + 1):
            train_epoch(
                model, train_data, criterion, optimizer, epoch,
                bptt, ntokens, batch_size, 0.25, 100, is_transformer
            )

            val_loss: float = evaluate(
                model, val_data, criterion, bptt, ntokens,
                eval_batch_size, is_transformer
            )

            train_losses.append(criterion(
                model(train_data[:100])[0] if is_transformer else model(train_data[:100], model.init_hidden(batch_size))[0],
                train_data[1:101].view(-1)
            ).item())
            val_losses.append(val_loss)

            print(f'Epoch {epoch:3d} | val loss {val_loss:5.2f} | val ppl {math.exp(val_loss):8.2f}')

            best_val_loss = min(best_val_loss, val_loss)

        training_time: float = time.time() - start_time

        # Test
        test_loss: float = evaluate(
            model, test_data, criterion, bptt, ntokens,
            eval_batch_size, is_transformer
        )

        results[name] = {
            'model': model,
            'train_losses': train_losses,
            'val_losses': val_losses,
            'best_val_loss': best_val_loss,
            'test_loss': test_loss,
            'test_ppl': math.exp(test_loss),
            'training_time': training_time,
            'config': config
        }

        print(f"\n{name} Results:")
        print(f"  Best Val Loss: {best_val_loss:.4f}")
        print(f"  Test Loss: {test_loss:.4f}")
        print(f"  Test PPL: {math.exp(test_loss):.2f}")
        print(f"  Training Time: {training_time:.2f}s")

    return results


# ===============================
# EXPERIMENT 2: Hyperparameter Tuning
# ===============================

def tune_hyperparameters(
    data_path: str = './data/names_char',
    base_config: Dict = None
) -> Dict[str, List]:
    """Tune hyperparameters for best model."""

    print("="*80)
    print("EXPERIMENT 2: Hyperparameter Tuning")
    print("="*80)

    if base_config is None:
        base_config = {
            'model_type': 'LSTM',
            'emsize': 200,
            'nhid': 200,
            'nlayers': 2,
            'dropout': 0.2,
            'lr': 0.001,
            'batch_size': 20,
            'epochs': 15
        }

    # Parameters to tune
    param_grid: Dict[str, List] = {
        'lr': [0.0001, 0.001, 0.01],
        'dropout': [0.1, 0.2, 0.3],
        'nhid': [128, 200, 256],
        'nlayers': [2, 3, 4]
    }

    results: Dict[str, List] = {}

    for param_name, param_values in param_grid.items():
        print(f"\nTuning {param_name}...")
        results[param_name] = []

        for value in param_values:
            config: Dict = base_config.copy()
            config[param_name] = value

            print(f"  Testing {param_name}={value}")

            # Train and evaluate (simplified version)
            # In practice, you'd call the full training function
            test_ppl: float = train_and_evaluate(data_path, config)

            results[param_name].append({
                'value': value,
                'test_ppl': test_ppl
            })

            print(f"    Test PPL: {test_ppl:.2f}")

    return results


def train_and_evaluate(data_path: str, config: Dict) -> float:
    """Simplified train and evaluate function."""
    # Placeholder - implement actual training
    # This is a simplified version
    return 10.0  # Return test perplexity


# ===============================
# EXPERIMENT 3: Learning Curves
# ===============================

def analyze_learning_curves(results: Dict[str, Dict]) -> None:
    """Plot and analyze learning curves."""

    print("="*80)
    print("EXPERIMENT 3: Learning Curve Analysis")
    print("="*80)

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    for idx, (name, result) in enumerate(results.items()):
        row: int = idx // 2
        col: int = idx % 2

        ax = axes[row, col]

        epochs: List[int] = list(range(1, len(result['train_losses']) + 1))
        ax.plot(epochs, result['train_losses'], label='Train Loss', marker='o')
        ax.plot(epochs, result['val_losses'], label='Val Loss', marker='s')

        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_title(f'{name} Learning Curves')
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('learning_curves.png', dpi=150)
    plt.close()

    print("Learning curves saved to 'learning_curves.png'")


# ===============================
# EXPERIMENT 4: Generate & Analyze Names
# ===============================

def generate_and_analyze_names(
    model: nn.Module,
    corpus: Corpus,
    n_samples: int = 100,
    temperature: float = 1.0,
    device: torch.device = torch.device('cpu')
) -> List[str]:
    """Generate names and analyze quality."""

    print("="*80)
    print("EXPERIMENT 4: Name Generation & Analysis")
    print("="*80)

    model.eval()
    ntokens: int = len(corpus.dictionary)
    is_transformer: bool = hasattr(model, 'model_type')

    generated_names: List[str] = []

    with torch.no_grad():
        for i in range(n_samples):
            if not is_transformer:
                hidden = model.init_hidden(1)

            # Start with random character
            input_tensor: torch.Tensor = torch.randint(ntokens, (1, 1), dtype=torch.long).to(device)
            name_chars: List[str] = []

            max_len: int = 20
            for _ in range(max_len):
                if is_transformer:
                    output: torch.Tensor = model(input_tensor, False)
                    word_weights: torch.Tensor = output[-1].squeeze().div(temperature).exp().cpu()
                else:
                    output, hidden = model(input_tensor, hidden)
                    word_weights = output.squeeze().div(temperature).exp().cpu()

                word_idx: int = torch.multinomial(word_weights, 1)[0].item()
                word: str = corpus.dictionary.idx2word[word_idx]

                if word == '<eos>':
                    break

                name_chars.append(word)
                input_tensor.fill_(word_idx)

            generated_name: str = ''.join(name_chars)
            if generated_name:  # Only add non-empty names
                generated_names.append(generated_name)

    # Analysis
    print(f"\nGenerated {len(generated_names)} names")
    print(f"Sample names:")
    for name in generated_names[:20]:
        print(f"  {name}")

    # Statistics
    avg_length: float = sum(len(name) for name in generated_names) / len(generated_names)
    unique_names: int = len(set(generated_names))

    print(f"\nStatistics:")
    print(f"  Average length: {avg_length:.2f}")
    print(f"  Unique names: {unique_names} ({unique_names/len(generated_names)*100:.1f}%)")

    return generated_names


# ===============================
# EXPERIMENT 5: Temperature Sampling
# ===============================

def experiment_temperature(
    model: nn.Module,
    corpus: Corpus,
    temperatures: List[float] = [0.5, 0.8, 1.0, 1.2, 1.5],
    n_samples: int = 10,
    device: torch.device = torch.device('cpu')
) -> None:
    """Experiment with different temperature values."""

    print("="*80)
    print("EXPERIMENT 5: Temperature Sampling")
    print("="*80)

    for temp in temperatures:
        print(f"\nTemperature = {temp}")
        names: List[str] = generate_and_analyze_names(
            model, corpus, n_samples, temp, device
        )
        print(f"  Sample: {', '.join(names[:5])}")


# ===============================
# MAIN EXECUTION
# ===============================

def run_all_experiments(data_path: str = './data/names_char') -> None:
    """Run all experiments."""

    print("="*80)
    print("COMPLETE EXPERIMENTAL PIPELINE")
    print("Word-Level Language Modeling on Names Dataset")
    print("="*80)

    device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Experiment 1: Compare architectures
    results: Dict = compare_architectures(data_path, epochs=20, device=device)

    # Experiment 3: Analyze learning curves
    analyze_learning_curves(results)

    # Find best model
    best_model_name: str = min(
        results.keys(),
        key=lambda x: results[x]['test_loss']
    )
    best_model: nn.Module = results[best_model_name]['model']

    print(f"\n{'='*80}")
    print(f"Best Model: {best_model_name}")
    print(f"Test PPL: {results[best_model_name]['test_ppl']:.2f}")
    print(f"{'='*80}")

    # Load corpus for generation
    corpus: Corpus = Corpus(data_path)

    # Experiment 4: Generate names
    generated_names: List[str] = generate_and_analyze_names(
        best_model, corpus, n_samples=100, device=device
    )

    # Experiment 5: Temperature sampling
    experiment_temperature(best_model, corpus, device=device)

    # Save results
    with open('experiment_results.txt', 'w') as f:
        f.write("EXPERIMENTAL RESULTS\n")
        f.write("="*80 + "\n\n")

        for name, result in results.items():
            f.write(f"{name}:\n")
            f.write(f"  Test Loss: {result['test_loss']:.4f}\n")
            f.write(f"  Test PPL: {result['test_ppl']:.2f}\n")
            f.write(f"  Training Time: {result['training_time']:.2f}s\n")
            f.write("\n")

        f.write("\nGenerated Names:\n")
        for name in generated_names[:50]:
            f.write(f"  {name}\n")

    print("\nResults saved to 'experiment_results.txt'")
    print("\nExperiments complete!")


if __name__ == '__main__':
    # Run all experiments
    run_all_experiments(data_path='./data/names_char')