In [None]:
!pip install torch transformers datasets

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from datasets import load_dataset
import time

# Prefer CUDA when PyTorch reports it is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Define a larger Transformer-based language model
class LargeLanguageModel(nn.Module):
    """Causal Transformer language model.

    Token embeddings and learned position embeddings are summed, passed
    through a stack of ``nn.TransformerEncoderLayer`` blocks under a causal
    attention mask, and projected to per-position vocabulary logits.

    Args:
        vocab_size: Number of rows in the token embedding table and width of
            the output logits.
        embed_dim: Embedding / model width (``d_model``).
        num_heads: Attention heads per layer.
        num_layers: Number of stacked encoder layers.
        hidden_dim: Width of each layer's feed-forward sub-network.
        max_seq_len: Number of learned positions, i.e. the longest sequence
            ``forward`` accepts.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, hidden_dim, max_seq_len):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(max_seq_len, embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=hidden_dim, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """Return next-token logits for every position of ``x``.

        Args:
            x: Integer token ids indexed as ``(batch, seq)``.

        Returns:
            Logits with a trailing dimension of ``vocab_size``, one vector
            per input position.

        Raises:
            ValueError: If the sequence is longer than the position table.
        """
        seq_len = x.size(1)
        max_len = self.position_embedding.num_embeddings
        if seq_len > max_len:
            # Fail early with a readable message instead of an opaque
            # out-of-range embedding lookup.
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {max_len}"
            )

        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
        hidden = self.token_embedding(x) + self.position_embedding(positions)

        # True above the diagonal = "may not attend". Without this mask every
        # position can see the tokens it is being trained to predict, which
        # makes the next-token objective trivial and the model useless for
        # generation.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
            diagonal=1,
        )
        hidden = self.transformer(hidden, mask=causal_mask)
        return self.fc_out(hidden)

# Dataset preparation
class TextDataset(Dataset):
    """Sliding-window next-token dataset over a flat token sequence.

    Item ``i`` is the pair ``(tokens[i : i + block_size],
    tokens[i + 1 : i + block_size + 1])``: a window and the same window
    shifted one position to the right.

    Args:
        tokens: Sliceable sequence of token ids (for example a 1-D tensor or
            a list).
        block_size: Length of every input and target window.
    """

    def __init__(self, tokens, block_size):
        self.tokens = tokens
        self.block_size = block_size

    def __len__(self):
        # A sequence no longer than block_size holds no complete
        # (input, target) pair; clamp so len() never sees a negative value.
        return max(0, len(self.tokens) - self.block_size)

    def __getitem__(self, idx):
        """Return the ``(input_ids, target_ids)`` windows starting at ``idx``.

        Negative indices count from the end. Raises ``IndexError`` when
        ``idx`` is out of range, so short or misaligned windows are never
        returned and plain iteration over the dataset terminates.
        """
        size = len(self)
        if idx < 0:
            idx += size
        if not 0 <= idx < size:
            raise IndexError(f"index out of range for dataset of size {size}")

        stop = idx + self.block_size
        input_ids = self.tokens[idx:stop]
        target_ids = self.tokens[idx + 1 : stop + 1]
        return input_ids, target_ids

# Hyperparameters

# -- Model architecture --
# NOTE(review): the model further down in this file is constructed with
# tokenizer.vocab_size rather than this constant; confirm it is still needed.
vocab_size = 30_000
embed_dim = 512
num_heads = 8
num_layers = 6
hidden_dim = 2048

# -- Sequence lengths --
max_seq_len = 256
block_size = max_seq_len - 1  # window length handed to TextDataset

# -- Optimisation --
batch_size = 64  # Adjust based on GPU memory; try 128 or 256 for higher utilization.
learning_rate = 5e-5
epochs = 10

# Load dataset
dataset = load_dataset("openwebtext")

# Collect only as many leading documents as are needed to cover the character
# budget. Iterating rows one at a time avoids materialising the entire text
# column, and a corpus-sized joined string, just to slice off a prefix.
max_train_chars = 10_000_000
train_documents = []
collected_chars = 0
for example in dataset["train"]:
    document = example["text"]
    train_documents.append(document)
    collected_chars += len(document) + 1  # +1 for the joining space
    if collected_chars > max_train_chars:
        break
# Same value as joining every document and slicing, at a fraction of the cost.
train_text = " ".join(train_documents)[:max_train_chars]

# Tokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Tokenize the selected documents in batches. Only the documents gathered
# above are tokenized, so the character cap actually bounds the training data
# (approximately: the final document is kept whole). Each document is still
# truncated to the tokenizer's maximum length; padding is unnecessary because
# the per-document ids are concatenated into one stream.
tokenize_batch_size = 1000
tokenized_text = []
for start in range(0, len(train_documents), tokenize_batch_size):
    batch_texts = train_documents[start : start + tokenize_batch_size]
    batch_ids = tokenizer(batch_texts, truncation=True)["input_ids"]
    tokenized_text.extend(torch.tensor(ids, dtype=torch.long) for ids in batch_ids)

# Combine tokenized texts and save
tokens = torch.cat(tokenized_text, dim=0)
torch.save(tokens, "tokenized_openwebtext.pt")
print("Tokenized dataset saved to 'tokenized_openwebtext.pt'.")

# Wrap the token stream in sliding windows and serve shuffled mini-batches
# using four loader workers.
dataset = TextDataset(tokens, block_size)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=4,
    shuffle=True,
)

# Build the network (its vocabulary size follows the tokenizer), move it to
# the selected device, then create the loss, optimizer and AMP loss scaler.
model = LargeLanguageModel(
    vocab_size=tokenizer.vocab_size,
    embed_dim=embed_dim,
    num_heads=num_heads,
    num_layers=num_layers,
    hidden_dim=hidden_dim,
    max_seq_len=max_seq_len,
)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
scaler = torch.cuda.amp.GradScaler()  # Mixed precision training

# Training loop: one full pass over the dataloader per epoch.
for epoch in range(epochs):
    model.train()
    epoch_loss = 0
    for batch_idx, (input_ids, target_ids) in enumerate(dataloader):
        # Inputs and targets must live on the same device as the model.
        input_ids = input_ids.to(device)
        target_ids = target_ids.to(device)

        # One-off sanity report at the start of every epoch.
        is_first_batch = batch_idx == 0
        if is_first_batch:
            print(f"Model device: {next(model.parameters()).device}")
            print(f"Input device: {input_ids.device}, Target device: {target_ids.device}")

        # Forward pass under autocast; the loss is computed over flattened
        # (position, vocab) logits against flattened target ids.
        with torch.cuda.amp.autocast():
            logits = model(input_ids)
            flat_logits = logits.view(-1, tokenizer.vocab_size)
            flat_targets = target_ids.view(-1)
            loss = criterion(flat_logits, flat_targets)

        # Backward pass through the gradient scaler. The order matters:
        # clear grads, scaled backward, optimizer step, scaler update.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()

        # Progress line every tenth batch.
        should_log = batch_idx % 10 == 0
        if should_log:
            print(f"Epoch {epoch + 1}, Batch {batch_idx}, Loss: {loss.item()}")

    # Mean per-batch loss for the epoch.
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss / len(dataloader):.4f}")

# Persist only the learned parameters (the state dict), not the module object.
trained_weights = model.state_dict()
torch.save(trained_weights, "large_language_model.pt")
print("Model training complete and saved to 'large_language_model.pt'.")
