In [4]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torch.optim import Adam
import numpy as np
import os
import sentencepiece as spm
from tqdm import tqdm
import time
from contextlib import contextmanager

class TextDataset(Dataset):
    def __init__(self, file_path, tokenizer, sequence_length=64):
        self.sequence_length = sequence_length
        self.tokenizer = tokenizer

        # Read and tokenize full text
        with open(file_path, 'r', encoding='utf-8') as f:
            text = f.read()

        # Encode full text
        self.tokens = self.tokenizer.encode_as_ids(text)
        self.num_sequences = len(self.tokens) - sequence_length

    def __len__(self):
        return self.num_sequences

    def __getitem__(self, idx):
        sequence = self.tokens[idx:idx + self.sequence_length]
        target = self.tokens[idx + 1:idx + self.sequence_length + 1]
        return torch.tensor(sequence), torch.tensor(target)

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = np.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

class TransformerBlock(nn.Module):
    def __init__(self, hidden_size=128, num_heads=4, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size)
        self.attention = nn.MultiheadAttention(
            hidden_size, num_heads, dropout=dropout, batch_first=True
        )
        self.norm2 = nn.LayerNorm(hidden_size)
        self.feedforward = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size * 4, hidden_size),
            nn.Dropout(dropout)
        )

    def forward(self, x, padding_mask=None):
        seq_len = x.shape[1]
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), 1).bool()

        attended = self.attention(
            self.norm1(x), self.norm1(x), self.norm1(x),
            attn_mask=causal_mask,
            key_padding_mask=padding_mask
        )[0]
        x = x + attended

        x = x + self.feedforward(self.norm2(x))
        return x

class TextTransformer(nn.Module):
    def __init__(self, vocab_size, hidden_size=256, num_layers=6, num_heads=8, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.pos_emb = SinusoidalPosEmb(hidden_size)
        self.dropout = nn.Dropout(dropout)

        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(hidden_size, num_heads, dropout)
            for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(hidden_size)
        self.output = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, padding_mask=None):
        b, s = x.shape
        token_emb = self.embedding(x)
        pos_ids = torch.arange(s, device=x.device)
        pos_emb = self.pos_emb(pos_ids).unsqueeze(0).expand(b, -1, -1)
        x = self.dropout(token_emb + pos_emb)

        for block in self.transformer_blocks:
            x = block(x, padding_mask)

        x = self.norm(x)
        return self.output(x)

@contextmanager
def nullcontext():
    yield

def train_tokenizer(file_path, vocab_size=8000, model_prefix="spm_model"):
    """Train a SentencePiece tokenizer on the input data."""
    print("Training SentencePiece tokenizer...")

    # Create temporary file with one sentence per line
    with open(file_path, 'r', encoding='utf-8') as f:
        text = f.read()

    temp_file = "temp_training_data.txt"
    with open(temp_file, 'w', encoding='utf-8') as f:
        # Simple sentence splitting on periods
        sentences = text.replace('\n', ' ').split('.')
        for sentence in sentences:
            if sentence.strip():
                f.write(sentence.strip() + '.\n')

    # Train SentencePiece model
    spm.SentencePieceTrainer.train(
        f'--input={temp_file} --model_prefix={model_prefix} '
        f'--vocab_size={vocab_size} --character_coverage=1.0 '
        '--model_type=bpe --pad_id=0 --unk_id=1 --bos_id=2 --eos_id=3 '
        '--pad_piece=<pad> --unk_piece=<unk> --bos_piece=<s> --eos_piece=</s>'
    )

    # Clean up temporary file
    os.remove(temp_file)

    # Load the trained model
    sp = spm.SentencePieceProcessor()
    sp.load(f"{model_prefix}.model")
    return sp

def train_transformer(
    file_path,
    sequence_length=64,
    hidden_size=256,
    num_layers=6,
    num_heads=8,
    batch_size=32,
    learning_rate=3e-4,
    num_epochs=10,
    device=None
):
    """Train the transformer with SentencePiece tokenization."""

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Train and load tokenizer
    tokenizer = train_tokenizer(file_path)
    vocab_size = tokenizer.get_piece_size()

    # Create dataset and dataloader
    dataset = TextDataset(file_path, tokenizer, sequence_length)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        pin_memory=True if device.type == 'cuda' else False,
        num_workers=4 if device.type == 'cuda' else 0
    )

    # Initialize model
    model = TextTransformer(
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        num_heads=num_heads
    ).to(device)

    # Initialize optimizer and training components
    optimizer = Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None

    print(f"\nTraining on {device}")
    print(f"Vocabulary size: {vocab_size}")
    model.train()

    for epoch in range(num_epochs):
        epoch_loss = 0
        progress_bar = tqdm(
            total=len(dataloader),
            desc=f"Epoch {epoch + 1}/{num_epochs}",
            unit="batch"
        )

        for batch_idx, (sequences, targets) in enumerate(dataloader):
            sequences = sequences.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            if device.type == 'cuda':
                with torch.cuda.amp.autocast():
                    outputs = model(sequences)
                    loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))
            else:
                outputs = model(sequences)
                loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))

            optimizer.zero_grad()
            if device.type == 'cuda':
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

            epoch_loss += loss.item()
            avg_loss = epoch_loss / (batch_idx + 1)

            progress_bar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'avg_loss': f'{avg_loss:.4f}'
            })
            progress_bar.update(1)

        progress_bar.close()
        print(f"\nEpoch {epoch + 1} completed - Average Loss: {avg_loss:.4f}\n")

    return model, tokenizer

def generate_text(
    model,
    tokenizer,
    prompt,
    max_length=100,
    temperature=0.7,
    device=None
):
    """Generate text using the trained model and SentencePiece tokenizer."""
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model.eval()

    # Encode prompt
    tokens = tokenizer.encode_as_ids(prompt)
    tokens = torch.tensor(tokens).unsqueeze(0).to(device)

    # Generate text
    with torch.no_grad():
        with torch.cuda.amp.autocast() if device.type == 'cuda' else nullcontext():
            for _ in range(max_length):
                outputs = model(tokens)
                next_token_logits = outputs[:, -1, :] / temperature
                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

                if next_token.item() == tokenizer.eos_id():
                    break

                tokens = torch.cat([tokens, next_token], dim=1)

    # Decode and return generated text
    generated_tokens = tokens.squeeze().cpu().tolist()
    return tokenizer.decode_ids(generated_tokens)

In [None]:
model, tokenizer = train_transformer(
    file_path="train.txt",
    sequence_length=128,  # Increased from 64 to capture more context
    hidden_size=512,      # Increased from 256 for more capacity
    num_layers=8,         # Increased from 6 for deeper processing
    num_heads=8,          # Keep same number of heads
    batch_size=96,        # Reduced to help with larger model
    num_epochs=20,        # Increased training time
    learning_rate=1e-4    # Slightly lower learning rate for stability
)
# Generate text (function remains the same but uses the BPE tokenizer)
generated_text = generate_text(
    model,
    tokenizer,
    prompt="who is this",
    max_length=100,
    temperature=0.7
)
print(generated_text)

Training SentencePiece tokenizer...


Epoch 1/20:   1%|          | 127/23068 [00:43<2:11:15,  2.91batch/s, loss=6.7151, avg_loss=7.1964]
  scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None



Training on cuda
Vocabulary size: 8000


  with torch.cuda.amp.autocast():
Epoch 1/20:  19%|█▉        | 1486/7689 [06:27<27:23,  3.77batch/s, loss=4.2185, avg_loss=5.2177]

In [16]:
generated_text = generate_text(
    model,
    tokenizer,
    prompt="Once upon a time",
    max_length=100,
    temperature=0.7
)
print(generated_text)

  with torch.cuda.amp.autocast() if device.type == 'cuda' else nullcontext():


 ⁇ nce upon a time that i was able to capture my first show for my study on the site. ian may suggest reading the theme of the bengali dim, without more information, it is impossible to determine whether it leisure became a man. the man is on the moon and the hanksb. the doorway, and the three cruise of bread for a cabinet is the projecting on the project and smiling a computer. this cultural exchange, and a 31,000 a 3


In [21]:
# Save model and tokenizer
torch.save({'model': model.state_dict(), 'tokenizer': tokenizer}, 'model.pt')

# Load model and tokenizer
checkpoint = torch.load('model.pt')
model.load_state_dict(checkpoint['model'])
tokenizer = checkpoint['tokenizer']

  checkpoint = torch.load('model.pt')
