# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer
from datasets import load_dataset
import math

# Custom Collate Function to Pad Sequences
def custom_collate_fn(batch, pad_token_id, label_pad_value=None):
    """Right-pad a batch of ``(inputs, labels)`` token tensors to equal length.

    Args:
        batch: Sequence of ``(inputs, labels)`` pairs of token tensors, padded
            along their first (sequence) dimension by ``pad_sequence``.
        pad_token_id: Padding value for ``inputs``. It is also used for
            ``labels`` unless ``label_pad_value`` is given.
        label_pad_value: Optional separate padding value for ``labels``. For
            example, ``-100`` (the default ``ignore_index`` of
            ``nn.CrossEntropyLoss``) lets the loss skip padding without also
            skipping genuine tokens that share ``pad_token_id``, such as an
            EOS token reused as the pad token.

    Returns:
        ``(inputs_padded, labels_padded)``, each batch-first and padded to the
        longest sequence in ``batch``.
    """
    if label_pad_value is None:
        label_pad_value = pad_token_id
    inputs, labels = zip(*batch)
    inputs_padded = torch.nn.utils.rnn.pad_sequence(
        inputs, batch_first=True, padding_value=pad_token_id
    )
    labels_padded = torch.nn.utils.rnn.pad_sequence(
        labels, batch_first=True, padding_value=label_pad_value
    )
    return inputs_padded, labels_padded

# RoPE Function 
def apply_rope(x, position_ids, base=10000):
    """Apply rotary position embeddings (RoPE) to interleaved feature pairs.

    Each pair ``(x[..., 2i], x[..., 2i + 1])`` is rotated by
    ``position * base ** (-2i / head_dim)`` radians.

    Args:
        x: Tensor of shape ``(batch, num_heads, seq_len, head_dim)``;
            ``head_dim`` must be even.
        position_ids: Tensor of shape ``(batch, seq_len)`` with the position
            of each token.
        base: Frequency base of the rotation schedule.

    Returns:
        Tensor of the same shape as ``x`` holding the rotated features.
    """
    _, _, _, head_dim = x.shape
    assert head_dim % 2 == 0, "head_dim must be even for RoPE"

    pair_index = torch.arange(0, head_dim // 2, dtype=torch.float, device=x.device)
    inv_freq = base ** (-2 * pair_index / head_dim)
    # (batch, 1, seq_len, 1) * (head_dim // 2,) -> broadcasts over the head axis.
    phase = position_ids[:, None, :, None] * inv_freq
    cos, sin = torch.cos(phase), torch.sin(phase)

    evens, odds = x[..., 0::2], x[..., 1::2]
    rotated_pairs = torch.stack(
        (evens * cos - odds * sin, evens * sin + odds * cos), dim=-1
    )
    # (..., head_dim // 2, 2) -> (..., head_dim): re-interleave even/odd outputs.
    return rotated_pairs.reshape(x.shape)

# MultiHeadAttention 
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings (RoPE).

    Causal by default: when no mask is passed, each position attends only to
    itself and earlier positions.

    Args:
        d_model: Model width. Must be divisible by ``num_heads``.
        num_heads: Number of attention heads. The per-head width
            ``d_model // num_heads`` must be even so RoPE can rotate pairs.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        # Fail fast: otherwise a bad configuration only surfaces at forward
        # time as an opaque ``view`` shape error.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        assert self.d_k % 2 == 0, "d_k must be even for RoPE"
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """Attend over ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)``.
            mask: Optional boolean tensor where ``True`` marks score entries to
                exclude (they are filled with ``-inf`` before the softmax). It
                must broadcast to ``(batch, num_heads, seq_len, seq_len)``.
                Defaults to a causal (upper-triangular) mask.

        Returns:
            Tensor of shape ``(batch, seq_len, d_model)``.
        """
        batch_size, seq_len, d_model = x.shape
        position_ids = torch.arange(seq_len, dtype=torch.float, device=x.device).unsqueeze(0).expand(batch_size, seq_len)
        # Project and split into heads: (batch, num_heads, seq_len, d_k).
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        # Positions are injected by rotating queries and keys only.
        Q = apply_rope(Q, position_ids)
        K = apply_rope(K, position_ids)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is None:
            # True above the diagonal = future positions are hidden.
            mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
            mask = mask[None, None, :, :]
        scores = scores.masked_fill(mask, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)
        # Merge heads back to (batch, seq_len, d_model).
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        return self.W_o(attn_output)

# TransformerDecoderBlock 
class TransformerDecoderBlock(nn.Module):
    """Pre-LayerNorm decoder block: self-attention followed by a GELU MLP.

    Each sub-layer normalises its input and adds its output back onto the
    residual stream::

        x = x + attention(norm1(x))
        x = x + ffn(norm2(x))

    Args:
        d_model: Width of the residual stream.
        num_heads: Number of heads for ``MultiHeadAttention``.
        ff_hidden_dim: Hidden width of the two-layer feed-forward network.
    """

    def __init__(self, d_model, num_heads, ff_hidden_dim):
        super().__init__()
        # Attention is built before the MLP so parameter initialisation draws
        # from the RNG in the same order; ``ffn`` keeps indices 0/1/2 so
        # ``state_dict`` keys stay ``ffn.0.*`` and ``ffn.2.*``.
        self.attention = MultiHeadAttention(d_model, num_heads)
        expand = nn.Linear(d_model, ff_hidden_dim)
        project = nn.Linear(ff_hidden_dim, d_model)
        self.ffn = nn.Sequential(expand, nn.GELU(), project)
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)

    def forward(self, x):
        """Apply the attention and feed-forward residual sub-layers to ``x``."""
        x = x + self.attention(self.norm1(x))
        return x + self.ffn(self.norm2(x))


class GPT(nn.Module):
    """Decoder-only language model: embedding -> decoder blocks -> vocab logits.

    No positional embedding is added here; position information comes from
    the RoPE rotation inside each block's attention.

    Args:
        vocab_size: Number of token ids (rows of the embedding, logit width).
        d_model: Hidden width.
        num_heads: Attention heads per block.
        num_blocks: Number of stacked ``TransformerDecoderBlock`` layers.
        ff_hidden_dim: Hidden width of each block's feed-forward network.
    """

    def __init__(self, vocab_size, d_model, num_heads, num_blocks, ff_hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.blocks = nn.ModuleList(
            TransformerDecoderBlock(d_model, num_heads, ff_hidden_dim)
            for _ in range(num_blocks)
        )
        self.output = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Map token ids ``(batch, seq_len)`` to logits ``(batch, seq_len, vocab_size)``."""
        hidden = self.embedding(x)
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.output(hidden)


class QADataset(Dataset):
    """Map-style dataset of next-token-prediction pairs built from QA chats.

    Each item uses the first two turns of ``data[idx]['conversation']`` (a
    user question followed by an assistant answer). The turns are formatted as
    ``"Question: {q}\\nAnswer: {a}"``, tokenized with truncation to
    ``max_len`` tokens, and split into ``(inputs, labels)``, where ``labels``
    is the token sequence shifted left by one.

    If ``data[idx]`` is not a valid user->assistant exchange, the following
    entries (wrapping around) are scanned and the first valid one is used.

    Args:
        data: Indexable collection of records, each with a ``'conversation'``
            list of ``{'role': ..., 'content': ...}`` dicts.
        tokenizer: Object providing ``encode(text, truncation=, max_length=,
            return_tensors='pt')`` that returns a ``(1, n)`` tensor.
        max_len: Maximum number of tokens kept per example.
    """

    def __init__(self, data, tokenizer, max_len):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _is_valid(convo):
        """Return True if ``convo`` opens with a user turn then an assistant turn."""
        return len(convo) >= 2 and convo[0]['role'] == 'user' and convo[1]['role'] == 'assistant'

    def __getitem__(self, idx):
        """Return ``(inputs, labels)`` 1-D token tensors for example ``idx``.

        Raises:
            ValueError: If no conversation in ``data`` is valid.
        """
        n = len(self.data)
        # Index ``data`` directly first so an out-of-range ``idx`` still
        # raises whatever ``data`` raises for it.
        convo = self.data[idx]['conversation']
        checked = 1
        # Iterative scan bounded by ``n``: an all-invalid dataset raises
        # instead of recursing until the interpreter's recursion limit.
        while not self._is_valid(convo):
            if checked >= n:
                raise ValueError("No valid conversations found in dataset")
            idx = (idx + 1) % n
            convo = self.data[idx]['conversation']
            checked += 1

        question = convo[0]['content']
        answer = convo[1]['content']
        qa_text = f"Question: {question}\nAnswer: {answer}"
        # ``[0]`` drops only the batch dimension; ``.squeeze()`` would also
        # collapse a 1-token result into a 0-dim tensor that cannot be sliced.
        tokens = self.tokenizer.encode(
            qa_text,
            truncation=True,
            max_length=self.max_len,
            return_tensors='pt'
        )[0]
        return tokens[:-1], tokens[1:]

# Hyperparameters
vocab_size = 50257          # GPT-2 BPE vocabulary size (the 'gpt2' tokenizer is used below)
d_model = 256               # embedding / residual-stream width
num_heads = 8               # attention heads per block (head width 256 // 8 = 32)
num_blocks = 4              # number of stacked decoder blocks
ff_hidden_dim = 1024        # hidden width of each block's feed-forward network
max_len = 64                # tokenizer truncation length per QA example
batch_size = 2              # micro-batch size per forward/backward pass
num_epochs = 100            # passes over the filtered training set
learning_rate = 5e-5        # AdamW learning rate
gradient_accumulation_steps = 8  # micro-batches per intended optimizer step

# Load the training split, then keep only its first 1,000 rows
dataset = load_dataset("lmsys/lmsys-chat-1m", split='train')
dataset = dataset.select(range(1_000))

# Filter valid conversations
def is_valid_conversation(example):
    """Return True if ``example`` holds a user turn followed by an assistant turn.

    Only the first two turns of ``example['conversation']`` are inspected;
    a missing ``'conversation'`` key or fewer than two turns yields False.
    """
    if 'conversation' not in example:
        return False
    turns = example['conversation']
    if len(turns) < 2:
        return False
    return turns[0]['role'] == 'user' and turns[1]['role'] == 'assistant'

valid_data = dataset.filter(is_valid_conversation)
print(f"Number of valid conversations: {len(valid_data)}")

# Initialize tokenizer and dataset
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# Reuse EOS as the padding token.
# NOTE(review): the loss below ignores pad_token_id, so any EOS target would be
# ignored too. QADataset also does not append an EOS, so the model is never
# trained to emit the EOS that generation stops on.
tokenizer.pad_token = tokenizer.eos_token
train_dataset = QADataset(valid_data, tokenizer, max_len)
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=lambda batch: custom_collate_fn(batch, tokenizer.pad_token_id)
)

# Model, Optimizer, Loss
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GPT(vocab_size, d_model, num_heads, num_blocks, ff_hidden_dim).to(device)
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

# Training loop with gradient accumulation: gradients from
# `gradient_accumulation_steps` micro-batches are summed before each update.
model.train()
for epoch in range(num_epochs):
    total_loss = 0.0
    steps = 0
    optimizer.zero_grad()
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs.view(-1, vocab_size), labels.view(-1))
        # Scale each micro-batch so the accumulated gradient is their mean.
        # Gradients are deliberately NOT zeroed here; zeroing before every
        # backward would discard the accumulation.
        (loss / gradient_accumulation_steps).backward()
        # Track the unscaled loss so the reported average is the real loss.
        total_loss += loss.item()
        steps += 1
        if steps % gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
    # Apply a trailing partial accumulation window instead of dropping it.
    if steps % gradient_accumulation_steps != 0:
        optimizer.step()
        optimizer.zero_grad()
    avg_loss = total_loss / steps if steps else float('nan')
    print(f"Epoch {epoch+1}/{num_epochs}, Avg Loss: {avg_loss:.4f}")

# Improved Generation with Top-k Sampling
def generate_response(question, max_new_tokens=50, top_k=50, temperature=0.7):
    """Generate an answer to ``question`` with temperature-scaled top-k sampling.

    Uses the module-level ``model``, ``tokenizer`` and ``device``. The prompt is
    ``"Question: {question}\\nAnswer:"``. Generation stops after
    ``max_new_tokens`` tokens or when the tokenizer's EOS token is sampled.

    Args:
        question: Question text inserted into the prompt.
        max_new_tokens: Maximum number of tokens to append (0 returns the prompt).
        top_k: Sample only among the ``top_k`` most likely tokens (clamped to
            the vocabulary size).
        temperature: Logit divisor; values ``<= 0`` select greedy decoding.

    Returns:
        Decoded text of the prompt followed by the generated tokens.
    """
    was_training = model.training
    model.eval()
    input_ids = tokenizer.encode(f"Question: {question}\nAnswer:", return_tensors='pt').to(device)
    try:
        with torch.no_grad():
            for _ in range(max_new_tokens):
                logits = model(input_ids)[:, -1, :]
                if temperature <= 0:
                    # Greedy decoding; avoids dividing logits by zero.
                    next_token = logits.argmax(dim=-1, keepdim=True)
                else:
                    k = min(top_k, logits.size(-1))
                    probs = F.softmax(logits / temperature, dim=-1)
                    top_k_probs, top_k_indices = torch.topk(probs, k, dim=-1)
                    # multinomial accepts unnormalized non-negative weights.
                    choice = torch.multinomial(top_k_probs, num_samples=1)
                    next_token = top_k_indices.gather(-1, choice)
                input_ids = torch.cat([input_ids, next_token], dim=-1)
                if next_token.item() == tokenizer.eos_token_id:
                    break
    finally:
        # Restore whatever mode the model was in before generation.
        model.train(was_training)
    return tokenizer.decode(input_ids[0])

# Save model first so that a failure in the sample generation below cannot
# lose the weights from the training run above.
torch.save(model.state_dict(), "gpt_like_qa_model.pth")
print("Model saved.")

print("Sample Generation:", generate_response("What is AI?"))

