In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import math
import matplotlib.pyplot as plt
from dataclasses import dataclass
from datasets import load_dataset
from tqdm import tqdm

@dataclass
class LCModelConfig:
    max_seq_len: int = 2048
    model_dim: int = 1024
    model_output_dim: int = 256  # Set to vocab_size for LM.
    frontend_dropout: float = 0.1
    decoder_layers: int = 6
    decoder_heads: int = 8
    decoder_ff_dim: int = 4096
    decoder_dropout: float = 0.1

# -----------------------
# 2. Frontend Module: Linear projection + Positional Embedding
# -----------------------
class SimpleFrontend(nn.Module):
    def __init__(self, input_dim, model_dim, max_seq_len, dropout=0.1):
        super().__init__()
        self.proj = nn.Linear(input_dim, model_dim)
        self.pos_embed = nn.Embedding(max_seq_len, model_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        # x: [B, L, input_dim]
        B, L, _ = x.shape
        positions = torch.arange(L, device=x.device).unsqueeze(0).expand(B, L)
        x = self.proj(x) + self.pos_embed(positions)
        return self.dropout(x)

# -----------------------
# 3. Transformer Decoder using PyTorch's built-in modules
# -----------------------
class SimpleTransformerDecoder(nn.Module):
    def __init__(self, model_dim, num_layers, n_heads, ff_dim, dropout):
        super().__init__()
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=model_dim, nhead=n_heads, dim_feedforward=ff_dim, dropout=dropout
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        
    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        # tgt: [L, B, model_dim], memory: [S, B, model_dim]
        return self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask)

# -----------------------
# 4. Postnet: Projection to desired output dimension
# -----------------------
class SimplePostnet(nn.Module):
    def __init__(self, model_dim, output_dim):
        super().__init__()
        self.proj = nn.Linear(model_dim, output_dim)
        
    def forward(self, x):
        return self.proj(x)

# -----------------------
# 5. LCModel: Combining Frontend, Decoder, and Postnet
# -----------------------
class LCModel(nn.Module):
    def __init__(self, config: LCModelConfig):
        super().__init__()
        self.config = config
        self.frontend = SimpleFrontend(
            input_dim=config.model_dim,
            model_dim=config.model_dim,
            max_seq_len=config.max_seq_len,
            dropout=config.frontend_dropout
        )
        self.decoder = SimpleTransformerDecoder(
            model_dim=config.model_dim,
            num_layers=config.decoder_layers,
            n_heads=config.decoder_heads,
            ff_dim=config.decoder_ff_dim,
            dropout=config.decoder_dropout
        )
        self.postnet = SimplePostnet(
            model_dim=config.model_dim,
            output_dim=config.model_output_dim
        )
        
    def forward(self, x):
        # x: [B, L, model_dim] (input embeddings)
        x = self.frontend(x)       # [B, L, model_dim]
        x = x.permute(1, 0, 2)     # [L, B, model_dim]
        out = self.decoder(x, x)   # For simplicity, using x as both target and memory (causal masking should be applied in practice)
        out = out.permute(1, 0, 2) # [B, L, model_dim]
        out = self.postnet(out)    # [B, L, model_output_dim]
        return out

# -----------------------
# 6. LCModelForLM: Wrap LCModel with a Token Embedding Layer for Language Modeling
# -----------------------
class LCModelForLM(nn.Module):
    def __init__(self, config: LCModelConfig, vocab_size: int):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, config.model_dim)
        self.lcmodel = LCModel(config)
        
    def forward(self, tokens):
        # tokens: [B, L] integer token indices
        embeddings = self.token_embed(tokens)  # [B, L, model_dim]
        logits = self.lcmodel(embeddings)        # [B, L, vocab_size]
        return logits

# -----------------------
# 7. WikiText Dataset for Byte-level LM
# -----------------------
class WikiTextDataset(Dataset):
    def __init__(self, data, seq_len):
        self.data = data
        self.seq_len = seq_len
        
    def __len__(self):
        return len(self.data) - self.seq_len
    
    def __getitem__(self, idx):
        input_seq = self.data[idx:idx+self.seq_len]
        target_seq = self.data[idx+1:idx+self.seq_len+1]
        return input_seq, target_seq

# -----------------------
# 8. Training Utilities
# -----------------------
def plot_losses(train_losses, val_losses):
    plt.figure(figsize=(10,6))
    plt.plot(range(1, len(train_losses)+1), train_losses, label='Train Loss', marker='o')
    plt.plot(range(1, len(val_losses)+1), val_losses, label='Val Loss', marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid()
    plt.show()

def save_checkpoint(epoch, model, optimizer, train_loss, val_loss, best_val_loss, filepath="checkpoint.pth"):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
        'best_val_loss': best_val_loss
    }
    torch.save(checkpoint, filepath)
    print(f"Checkpoint saved at epoch {epoch}!")

def load_checkpoint(filepath="checkpoint.pth"):
    if os.path.exists(filepath):
        checkpoint = torch.load(filepath, map_location=device)
        print(f"Checkpoint loaded from epoch {checkpoint['epoch']}!")
        return checkpoint
    else:
        print("No checkpoint found!")
        return None

def generate_text_for_sample(model, input_seq, device, length, temperature=1.0):
    model.eval()
    generated_text = "".join([chr(x.item()) for x in input_seq[0]])
    for _ in range(length):
        with torch.no_grad():
            output = model(input_seq)
            logits = output[0, -1] / temperature
            probs = torch.softmax(logits, dim=-1)
            next_byte = torch.multinomial(probs, num_samples=1).item()
            generated_text += chr(next_byte)
            input_seq = torch.cat([input_seq[:, 1:], torch.tensor([[next_byte]], device=device)], dim=1)
    return generated_text

def train_model(model, train_loader, val_loader, epochs, device, learning_rate):
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()
    train_losses = []
    val_losses = []
    
    for epoch in range(epochs):
        model.train()
        tot_loss = 0.0
        print(f"Epoch {epoch+1}/{epochs}")
        for tokens, targets in tqdm(train_loader, desc="Training"):
            tokens = tokens.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            logits = model(tokens)  # [B, L, vocab_size]
            loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
            loss.backward()
            optimizer.step()
            tot_loss += loss.item()
        avg_train_loss = tot_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        
        model.eval()
        tot_loss = 0.0
        with torch.no_grad():
            for tokens, targets in tqdm(val_loader, desc="Validation"):
                tokens = tokens.to(device)
                targets = targets.to(device)
                logits = model(tokens)
                loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
                tot_loss += loss.item()
        avg_val_loss = tot_loss / len(val_loader)
        val_losses.append(avg_val_loss)
        perplexity = math.exp(avg_val_loss)
        print(f"Epoch {epoch+1}: Train Loss = {avg_train_loss:.4f}, Val Loss = {avg_val_loss:.4f}, Perplexity = {perplexity:.2f}")
        
        # Generate sample text at the end of each epoch.
        sample_prompt = "The "  # Common prompt
        sample_input = torch.tensor([ord(c) for c in sample_prompt], dtype=torch.long).unsqueeze(0).to(device)
        sample_generated = generate_text_for_sample(model, sample_input, device, length=seq_len//2, temperature=0.8)
        print("Sample Generated Text:")
        print(sample_generated)
        
        save_checkpoint(epoch+1, model, optimizer, train_losses, val_losses, avg_val_loss)
    
    plot_losses(train_losses, val_losses)
    return model

# -----------------------
# 9. Main Training Script
# -----------------------
if __name__ == '__main__':
    # Load WikiText-2 dataset (using a small subset for faster training)
    wiki = load_dataset("wikitext", "wikitext-2-raw-v1")
    # Select a small subset: 100 samples from train and 20 from validation.
    train_subset = wiki["train"].select(range(1000))
    val_subset = wiki["validation"].select(range(200))
    
    train_text = "\n".join(train_subset["text"])
    val_text = "\n".join(val_subset["text"])
    
    # Convert text to raw bytes using UTF-8 encoding
    train_data = torch.tensor(list(train_text.encode('utf-8')), dtype=torch.long)
    val_data = torch.tensor(list(val_text.encode('utf-8')), dtype=torch.long)

    seq_len = 128
    batch_size = 128
    
    train_dataset = WikiTextDataset(train_data, seq_len)
    val_dataset = WikiTextDataset(val_data, seq_len)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    
    # Create model configuration and instantiate LCModelForLM
    config = LCModelConfig(
        max_seq_len=128,
        model_dim=256,
        model_output_dim=256,
        frontend_dropout=0.1,
        decoder_layers=6,
        decoder_heads=4,
        decoder_ff_dim=1024,
        decoder_dropout=0.1
    )
    vocab_size = 256
    model_for_lm = LCModelForLM(config, vocab_size)
    
    # Train the model
    trained_model = train_model(
        model_for_lm,
        train_loader,
        val_loader,
        epochs=10,  # Use fewer epochs for quick training
        device="cuda",
        learning_rate=1e-4
    )


Epoch 1/10


Training: 100%|█████████████████████████████| 2257/2257 [07:25<00:00,  5.07it/s]
Validation: 100%|█████████████████████████████| 476/476 [00:26<00:00, 17.74it/s]


Epoch 1: Train Loss = 0.3795, Val Loss = 0.0211, Perplexity = 1.02
Sample Generated Text:
The eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee
Checkpoint saved at epoch 1!
Epoch 2/10


Training: 100%|█████████████████████████████| 2257/2257 [07:24<00:00,  5.08it/s]
Validation: 100%|█████████████████████████████| 476/476 [00:26<00:00, 17.74it/s]


Epoch 2: Train Loss = 0.0188, Val Loss = 0.0177, Perplexity = 1.02
Sample Generated Text:
The eeeeeeeeeeeeeeeeeezezezezezezezezezezezezezezezezezezezezezezeze
Checkpoint saved at epoch 2!
Epoch 3/10


Training: 100%|█████████████████████████████| 2257/2257 [07:26<00:00,  5.06it/s]
Validation: 100%|█████████████████████████████| 476/476 [00:26<00:00, 17.67it/s]


Epoch 3: Train Loss = 0.0164, Val Loss = 0.0166, Perplexity = 1.02
Sample Generated Text:
The eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee
Checkpoint saved at epoch 3!
Epoch 4/10


Training: 100%|█████████████████████████████| 2257/2257 [07:26<00:00,  5.06it/s]
Validation: 100%|█████████████████████████████| 476/476 [00:26<00:00, 17.67it/s]


Epoch 4: Train Loss = 0.0150, Val Loss = 0.0157, Perplexity = 1.02
Sample Generated Text:
The eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeegeeeeeeeeeeeeeeeeeeeeeeeeeeeee
Checkpoint saved at epoch 4!
Epoch 5/10


Training: 100%|█████████████████████████████| 2257/2257 [07:26<00:00,  5.06it/s]
Validation: 100%|█████████████████████████████| 476/476 [00:26<00:00, 17.70it/s]


Epoch 5: Train Loss = 0.0140, Val Loss = 0.0152, Perplexity = 1.02
Sample Generated Text:
The e e e epe epeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee
Checkpoint saved at epoch 5!
Epoch 6/10


Training:  65%|██████████████████▉          | 1473/2257 [04:51<02:35,  5.06it/s]


KeyboardInterrupt: 