In [None]:
pip install tiktoken

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import pickle
import mmap
import random
import os
import logging 
import tiktoken
import re
import unicodedata
import time
import math

In [None]:
# Conservative settings for 15GB VRAM
batch_size = 2
context_win = 512
learning_rate = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
n_embd = 1024  # Must be even for RoPE
n_layer = 4
n_head = 4   # Must be divisible by n_query_groups (2)
max_epochs = 5
dropout = 0.1
vocab_size = 100277  # presumably the cl100k_base vocabulary size -- verify against enc.n_vocab
torch.manual_seed(2021)

# Verify dimensions (each invariant is checked exactly once)
assert n_embd % 2 == 0, "n_embd must be even for RoPE"
assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
assert n_head % 2 == 0, "n_head must be divisible by n_query_groups"

# Estimated VRAM usage: ~13GB
# Will need gradient checkpointing enabled

In [None]:
# # Standard settings for 40GB VRAM
# batch_size = 16        # Your original setting
# context_win = 8192     # Your original setting
# learning_rate = 5e-4
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# n_embd = 4096         # Your original setting
# n_layer = 32          # Your original setting
# n_head = 32           # Your original setting
# max_epochs = 5
# dropout = 0.2
# vocab_size = 100277
# torch.manual_seed(2021)

# # Estimated VRAM usage: ~32GB
# # Can run without gradient checkpointing

In [None]:
# # Expanded settings for 100GB+ VRAM
# batch_size = 32        # Larger batch size
# context_win = 16384    # Doubled context window
# learning_rate = 8e-4   # Slightly higher learning rate
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# n_embd = 5120         # Larger embedding size
# n_layer = 40          # More layers
# n_head = 40           # More attention heads
# max_epochs = 5
# dropout = 0.2
# vocab_size = 100277
# torch.manual_seed(2021)

# # Estimated VRAM usage: ~80GB
# # Can enable additional features like:
# # - Larger expert count in MoE
# # - Bigger batch sizes
# # - Model parallelism

In [None]:
# Hyperparameters for the attention variants and the sparse mixture-of-experts
# feed-forward layer defined further down in this notebook.
n_query_groups = 2  # for GQA
window_size = 256   # for sliding window attention
num_experts = 8     # for sparse MoE
num_active = 2      # for sparse MoE
capacity_factor = 1.2  # for sparse MoE

In [None]:
# Learning-rate schedule settings handed to CosineWarmupScheduler:
# the number of warmup steps, the peak rate, and the floor (10% of the peak).
warmup_steps = 1000
max_lr = learning_rate
min_lr = learning_rate * 0.1

In [None]:
# #only for TPU run
# num_cores = xm.xrt_world_size()
# # Select the TPU device
# device = xm.xla_device()

In [None]:
enc = tiktoken.get_encoding("cl100k_base")

In [None]:
# Evaluation settings.
# The training loop tests `counter % eval_interval == 0` to decide when to evaluate.
eval_interval = 1000
# Number of batches estimate_loss() draws for each split.
eval_iters = 200

In [None]:
device

In [None]:
train_file = r"/kaggle/input/small-web-crawled-text/webcrawled_train.txt"
val_file = r"/kaggle/input/small-web-crawled-text/webcrawled_val.txt"

In [None]:
logging.basicConfig(filename='data_loading.log', level=logging.WARNING)  


In [None]:
# Heuristic number of bytes to read for every token the batch needs. The read
# size is expressed in bytes while the context window is expressed in tokens,
# so the chunk must be comfortably larger than `context_win` bytes.
_BYTES_PER_TOKEN_ESTIMATE = 8

# Byte offset of the next read in the validation file. Kept apart from the
# global `current_position` (the training cursor) so that evaluation can
# neither rewind nor advance the position training resumes from.
_val_position = 0


def get_batch(split):
    """Sample a random batch of token windows from the next chunk of a text file.

    The file is read in binary mode so that the cursor, the file size and the
    amount read are all measured in bytes. The chunk is decoded as UTF-8
    (invalid or split sequences become U+FFFD), tokenised with `enc`, and
    `batch_size` windows of `context_win` tokens are drawn at random from it.

    Args:
        split: 'train' reads `train_file` and uses the global
            `current_position` cursor; any other value reads `val_file` and
            uses a private validation cursor.

    Returns:
        (x, y): two long tensors of shape (batch_size, context_win) on
        `device`, where y is x shifted one token to the right; or
        (None, None) when the chunk is empty, yields too few tokens, or an
        error occurred.

    Note:
        Calling this with 'train' from evaluation code still advances the
        training cursor.
    """
    global current_position, _val_position

    is_train = split == 'train'
    filename = train_file if is_train else val_file
    filesize = os.path.getsize(filename)
    position = current_position if is_train else _val_position

    # Never read less than the original double-sized chunk, and always enough
    # bytes to expect at least one full window plus its shifted target.
    chunk_bytes = max(context_win * batch_size * 2,
                      (context_win + 2) * _BYTES_PER_TOKEN_ESTIMATE)

    x = y = None
    try:
        # Wrap around when less than one full chunk remains before end of file
        if position >= filesize - chunk_bytes:
            position = 0

        with open(filename, 'rb') as f:
            f.seek(position)
            raw = f.read(chunk_bytes)

        if not raw:
            position = 0
        else:
            text = raw.decode('utf-8', errors='replace')
            # disallowed_special=() makes text such as "<|endoftext|>" encode
            # as ordinary characters instead of raising ValueError.
            tokens = enc.encode(text, disallowed_special=())
            data = torch.tensor(tokens, dtype=torch.long)

            # A start index i is valid when data[i+1 : i+context_win+1] exists
            num_starts = len(data) - context_win
            if num_starts >= 1:
                ix = torch.randint(0, num_starts, (batch_size,))
                x = torch.stack([data[i:i+context_win] for i in ix]).to(device)
                y = torch.stack([data[i+1:i+context_win+1] for i in ix]).to(device)

            # Move forward by half the chunk even when the chunk was unusable,
            # so one bad region cannot pin the cursor in place.
            position += max(len(raw) // 2, 1)

    except Exception as e:
        print(f"Error in get_batch: {str(e)}")
        print(f"Current position: {position}")
        print(f"File size: {filesize}")
        position = 0
        x = y = None

    if is_train:
        current_position = position
    else:
        _val_position = position

    return x, y

In [None]:
@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss of the global `model` on both data splits.

    Draws up to `eval_iters` batches per split with get_batch() and averages
    the losses of the batches that were actually returned. Batches for which
    get_batch() yields None are skipped and do not count towards the mean
    (they used to be averaged in as 0.0, biasing the estimate downwards).

    Returns:
        dict mapping 'train' and 'val' to a 0-dim float tensor holding the
        mean loss, or NaN when no batch could be loaded for that split.
    """
    out = {}
    model.eval()
    try:
        for split in ['train', 'val']:
            batch_losses = []
            for _ in range(eval_iters):
                X, Y = get_batch(split)
                if X is None:  # get_batch could not produce a batch
                    continue
                _, loss = model(X, Y)
                batch_losses.append(loss.item())
            if batch_losses:
                out[split] = torch.tensor(batch_losses).mean()
            else:
                out[split] = torch.tensor(float('nan'))
    finally:
        # Always hand the model back in training mode, even after an error
        model.train()
    return out

In [None]:
class CosineWarmupScheduler:
    """Linear warmup followed by cosine decay, applied to an optimizer's param groups.

    Args:
        optimizer: object exposing `param_groups`, a list of dicts with an 'lr' key.
        warmup_steps: number of steps over which the rate ramps linearly to `max_lr`.
        max_lr: peak learning rate, reached at the end of warmup.
        min_lr: learning rate at the end of the cosine decay.
        total_steps: total number of scheduler steps. When None (the default)
            it is read lazily from the notebook globals as
            `num_batches_per_epoch * max_epochs`, as before.
    """

    def __init__(self, optimizer, warmup_steps, max_lr, min_lr, total_steps=None):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.total_steps = total_steps
        self.current_step = 0

    def _resolve_total_steps(self):
        """Return the explicit step budget, or derive it from the notebook globals."""
        if self.total_steps is not None:
            return self.total_steps
        return num_batches_per_epoch * max_epochs

    def step(self):
        """Advance one step, write the new rate into every param group and return it."""
        self.current_step += 1
        if self.current_step < self.warmup_steps:
            # Linear warmup
            lr = self.max_lr * (self.current_step / self.warmup_steps)
        else:
            # Cosine decay
            decay_steps = self._resolve_total_steps() - self.warmup_steps
            if decay_steps > 0:
                # Clamp so the rate stays at min_lr once the budget is spent
                # instead of following the cosine back up.
                progress = min(1.0, (self.current_step - self.warmup_steps) / decay_steps)
            else:
                # No room for a decay phase: go straight to the floor
                progress = 1.0
            lr = self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1 + math.cos(math.pi * progress))

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr

In [None]:
class Head(nn.Module):
    """One causal self-attention head projecting `n_embd` features to `head_size`."""

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size)
        self.query = nn.Linear(n_embd, head_size)
        self.value = nn.Linear(n_embd, head_size)
        # Lower-triangular ones: entry (i, j) is 1 when position i may see j
        causal = torch.tril(torch.ones(context_win, context_win))
        self.register_buffer('tril', causal)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the attention output for x of shape (B, T, C) as (B, T, head_size)."""
        _, seq_len, _ = x.shape
        keys = self.key(x)
        queries = self.query(x)

        # Scaled dot-product scores, with future positions masked out
        scale = keys.shape[-1] ** -0.5
        scores = (queries @ keys.transpose(-2, -1)) * scale
        hidden = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(hidden, float('-inf'))

        weights = self.dropout(F.softmax(scores, dim=-1))
        return weights @ self.value(x)
        

In [None]:
class GroupedQueryAttention(nn.Module):
    """Causal multi-head attention in which heads are organised into query groups.

    Keys and values are projected to `n_head` heads of size `n_embd // n_head`.
    Queries are projected to only `n_query_groups` heads, and each query head
    is shared by the `n_head // n_query_groups` key/value heads of its group.

    Args:
        n_embd: embedding width of the input and of the output.
        n_head: number of key/value heads; must be divisible by `n_query_groups`.
        n_query_groups: number of distinct query heads.
    """

    def __init__(self, n_embd, n_head, n_query_groups=2):
        super().__init__()
        assert n_head % n_query_groups == 0, "n_head must be divisible by n_query_groups"

        self.n_head = n_head
        self.n_query_groups = n_query_groups
        self.head_size = n_embd // n_head

        # Adjust dimensions for grouped queries
        self.key = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd * n_query_groups // n_head)

        self.register_buffer('tril', torch.tril(torch.ones(context_win, context_win)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map x of shape (B, T, C) to an attention output of the same shape."""
        B, T, C = x.shape  # batch, sequence length, embedding dimension

        # Split heads while keeping query groups: (B, heads, T, head_size)
        k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2)
        v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2)
        q = self.query(x).view(B, T, self.n_query_groups, self.head_size).transpose(1, 2)

        # Arrange K/V as (B, groups, heads_per_group, T, head_size) and let the
        # group's single query broadcast over its heads.
        heads_per_group = self.n_head // self.n_query_groups
        k = k.reshape(B, self.n_query_groups, heads_per_group, T, self.head_size)
        v = v.reshape(B, self.n_query_groups, heads_per_group, T, self.head_size)
        q = q.unsqueeze(2)  # Add dimension for heads_per_group

        # Attention scores: (B, groups, heads_per_group, T, T)
        att = (q @ k.transpose(-2, -1)) * (self.head_size ** -0.5)
        att = att.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.dropout(att)

        # Combine heads. The time axis has to come right after the batch axis
        # BEFORE flattening; otherwise the flatten interleaves different time
        # steps (including future ones) into each output row.
        out = att @ v  # (B, groups, heads_per_group, T, head_size)
        out = out.permute(0, 3, 1, 2, 4).reshape(B, T, C)

        return out

In [None]:
class SlidingWindowAttention(nn.Module):
    """Single-head causal attention restricted to a trailing window of positions.

    Position i attends to itself and to the `window_size - 1` positions before
    it. No future position is visible and no padding takes part in the
    softmax, so the layer is safe for next-token prediction.

    Args:
        n_embd: embedding width of the input and of the output.
        window_size: number of most recent positions (current one included)
            each query may attend to.
    """

    def __init__(self, n_embd, window_size=256):
        super().__init__()
        self.window_size = window_size
        self.query = nn.Linear(n_embd, n_embd)
        self.key = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)

    def forward(self, x):
        """Map x of shape (B, T, C) to an attention output of the same shape."""
        B, T, C = x.shape
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        # All pairwise scores at once: (B, T, T)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(C)

        # offset[i, j] = i - j: how far key j lies behind query i
        pos = torch.arange(T, device=x.device)
        offset = pos.unsqueeze(1) - pos.unsqueeze(0)
        # A window of at least one keeps the diagonal visible, so no row is fully masked
        window = max(int(self.window_size), 1)
        visible = (offset >= 0) & (offset < window)

        att = att.masked_fill(~visible, float('-inf'))
        att = F.softmax(att, dim=-1)
        return att @ v

In [None]:
class RotaryEmbedding(nn.Module):
    """Rotate pairs of feature channels by a position-dependent angle.

    Args:
        dim: feature width; an odd value is rounded down to the next even one.
        max_position_embeddings: stored on the module; not used by forward().
    """

    def __init__(self, dim, max_position_embeddings=2048):
        super().__init__()
        # Round an odd width down so the channels split into pairs
        dim -= dim % 2

        # One inverse frequency per channel pair
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer('inv_freq', 1.0 / (10000 ** exponents))
        self.max_position_embeddings = max_position_embeddings
        self.dim = dim

    def forward(self, x, seq_len):
        """Rotate x of shape (B, T, C); returns (B, T, C) laid out as
        [rotated even channels, rotated odd channels]."""
        batch, length, _ = x.shape

        # Angle for every (position, frequency) pair: [seq_len, dim/2]
        positions = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        angles = torch.einsum('i,j->ij', positions, self.inv_freq)

        # Shape to [batch, seq_len, dim/2] for element-wise use
        cos = angles.cos().view(1, length, -1).expand(batch, -1, -1)
        sin = angles.sin().view(1, length, -1).expand(batch, -1, -1)

        # Even and odd channels form the two components of each rotated pair
        even = x[..., ::2]
        odd = x[..., 1::2]

        first_half = even * cos - odd * sin
        second_half = odd * cos + even * sin
        return torch.cat((first_half, second_half), dim=-1)

In [None]:
class SparseExpertLayer(nn.Module):
    """Sparse mixture-of-experts feed-forward layer with top-k routing.

    A linear gate scores every token against every expert. Each token is sent
    to its `num_active` best experts and the expert outputs are summed,
    weighted by the renormalised gate probabilities, so the gate receives
    gradient. Each expert processes at most `expert_capacity` tokens per call;
    assignments beyond that are dropped at random and contribute nothing.

    Args:
        n_embd: embedding width of the input and of the output.
        num_experts: number of expert MLPs.
        num_active: number of experts each token is routed to.
        capacity_factor: multiplier on the even-split token count used to
            size `expert_capacity`.
    """

    def __init__(self, n_embd, num_experts=8, num_active=2, capacity_factor=1.2):
        super().__init__()
        self.num_experts = num_experts
        self.num_active = num_active
        self.capacity_factor = capacity_factor

        # Expert capacity, sized from the notebook's training batch shape
        self.expert_capacity = int(capacity_factor * batch_size * context_win / num_experts)

        self.gate = nn.Linear(n_embd, num_experts)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(n_embd, 4 * n_embd),
                nn.GELU(),  # Using GELU instead of ReLU
                nn.Linear(4 * n_embd, n_embd),
                nn.Dropout(dropout)
            ) for _ in range(num_experts)
        ])

    def forward(self, x):
        """Route x of shape (B, T, C) through the experts; returns (B, T, C)."""
        B, T, C = x.shape
        x_flat = x.reshape(-1, C)

        # Gate probabilities over experts for every token
        gates = F.softmax(self.gate(x_flat), dim=-1)

        # Top-k experts per token, with weights renormalised to sum to one
        expert_weights, expert_indices = torch.topk(gates, self.num_active, dim=-1)
        expert_weights = expert_weights / expert_weights.sum(dim=-1, keepdim=True)

        final_output = torch.zeros_like(x_flat)

        for expert_idx, expert in enumerate(self.experts):
            # (token, slot) pairs routed to this expert; top-k indices are
            # distinct per token, so each token appears at most once here
            token_ids, slot_ids = (expert_indices == expert_idx).nonzero(as_tuple=True)
            routed = token_ids.numel()
            if routed == 0:
                continue

            if routed > self.expert_capacity:
                # Keep a random subset of exactly `expert_capacity` assignments
                keep = torch.randperm(routed, device=token_ids.device)[:self.expert_capacity]
                token_ids, slot_ids = token_ids[keep], slot_ids[keep]

            processed = expert(x_flat[token_ids])
            weights = expert_weights[token_ids, slot_ids].unsqueeze(-1)

            # Accumulate rather than overwrite, so a token keeps the
            # contributions of all of its active experts
            contribution = (processed * weights).to(final_output.dtype)
            final_output = final_output.index_add(0, token_ids, contribution)

        return final_output.view(B, T, C)

In [None]:
class Block(nn.Module):
    """Transformer block: grouped-query attention, sliding-window attention, sparse MoE.

    Each sub-layer is added to the residual stream. The sub-layers are
    configured from the notebook-level hyperparameters `n_query_groups`,
    `window_size`, `num_experts`, `num_active` and `capacity_factor`, so
    editing those settings changes the model.

    Args:
        n_embd: embedding width.
        n_head: number of attention heads; must be divisible by `n_query_groups`.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        # Ensure n_head is divisible by n_query_groups
        assert n_head % n_query_groups == 0, "n_head must be divisible by n_query_groups"

        self.gqa = GroupedQueryAttention(n_embd, n_head, n_query_groups)
        self.sliding_attention = SlidingWindowAttention(n_embd, window_size=window_size)
        self.rope = RotaryEmbedding(n_embd)
        self.sparse_ffwd = SparseExpertLayer(
            n_embd,
            num_experts=num_experts,
            num_active=num_active,
            capacity_factor=capacity_factor,
        )
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Apply the block to x of shape (B, T, C); returns the same shape."""
        _, T, _ = x.shape

        # Apply layer norm and RoPE
        x_ln = self.ln1(x)
        x_rope = self.rope(x_ln, T)

        # Multi-scale attention (ln1 is shared by both attention sub-layers)
        x = x + self.gqa(x_rope)
        x = x + self.sliding_attention(self.ln1(x))

        # Sparse MoE
        x = x + self.sparse_ffwd(self.ln2(x))
        return x

In [None]:
class GPTModel(nn.Module):
    """Language model: token embedding -> stack of Blocks -> LayerNorm -> vocabulary logits.

    Sizes come from the notebook-level settings `vocab_size`, `n_embd`,
    `n_head` and `n_layer`.
    """

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, index, target=None):
        """Compute logits and, when targets are given, the cross-entropy loss.

        Args:
            index: long tensor of token ids, shape (B, T).
            target: optional long tensor of next-token ids, shape (B, T).

        Returns:
            (logits, loss). Without `target`, logits has shape
            (B, T, vocab_size) and loss is None. With `target`, logits is
            flattened to (B*T, vocab_size) and loss is a scalar tensor.
        """
        B, T = index.shape  # also rejects inputs that are not 2-D

        # Token embeddings
        x = self.token_embedding_table(index)  # [B, T, n_embd]

        # Process through blocks
        x = self.blocks(x)  # [B, T, n_embd]
        x = self.ln_f(x)
        logits = self.lm_head(x)  # [B, T, vocab_size]

        if target is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            target = target.view(B*T)
            loss = F.cross_entropy(logits, target)

        return logits, loss

    @torch.no_grad()
    def generate(self, index, max_size):
        """Sample `max_size` new tokens and return `index` with them appended.

        Sampling runs without gradient tracking and in eval mode (dropout
        off); the module's previous train/eval mode is restored afterwards.

        Args:
            index: long tensor of prompt token ids, shape (B, T).
            max_size: number of tokens to generate.

        Returns:
            Long tensor of shape (B, T + max_size).
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_size):
                # Only the last context_win tokens fit in the causal mask
                index_cropped = index[:, -context_win:]
                logits, loss = self(index_cropped)
                logits = logits[:, -1, :]
                prob = F.softmax(logits, dim=-1)
                next_word = torch.multinomial(prob, num_samples=1)
                index = torch.cat((index, next_word), dim=1)
        finally:
            self.train(was_training)
        return index

# Build the MoE model and place it on the selected device
model = GPTModel().to(device)

In [None]:
# checkpoint = torch.load(r'/kaggle/working/D:\gpt saved model/model_iter_300000.pt', map_location=torch.device('cuda')) 
# model.load_state_dict(checkpoint)

In [None]:
print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
scheduler = CosineWarmupScheduler(optimizer, warmup_steps, max_lr, min_lr)

current_lr = scheduler.step()

In [None]:
num_batches_per_epoch = 2200000
num_batches_per_epoch

In [None]:
save_dir = r'D:\gpt saved model'
os.makedirs(save_dir, exist_ok=True)

In [None]:
def save_model(model, counter):
    """Write the model's state_dict to `save_dir` as model_iter_<counter>.pt and report it."""
    checkpoint_name = f'model_iter_{counter}.pt'
    destination = os.path.join(save_dir, checkpoint_name)
    torch.save(model.state_dict(), destination)
    print(f"Model saved at iteration {counter}")

In [None]:
def calculate_perplexity(model):
    """Compute perplexity on validation batches.

    Draws up to 32 batches with get_batch('val'), runs `model` on each one and
    averages the loss weighted by the number of target tokens. Batches that
    get_batch() could not produce are skipped. The model's previous
    train/eval mode is restored before returning.

    Args:
        model: callable as `model(x, y)` returning `(logits, loss)`.

    Returns:
        0-dim float tensor: exp(mean loss), or NaN when no batch was available.
    """
    num_batches_per_validation = 32  # Define the number of batches for validation
    total_loss = 0.0
    total_tokens = 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(num_batches_per_validation):
                xb, yb = get_batch('val')
                if xb is None:
                    # No usable batch from this chunk; try the next one
                    continue

                # Score the batch that was just fetched
                _, loss = model(xb, yb)

                # Weight the mean batch loss by its number of target tokens
                total_loss += loss.item() * yb.numel()
                total_tokens += yb.numel()
    finally:
        model.train(was_training)

    if total_tokens == 0:
        return torch.tensor(float('nan'))

    average_loss = total_loss / total_tokens
    perplexity = torch.exp(torch.tensor(average_loss))

    return perplexity


In [None]:
current_position = 0
counter = 0

In [None]:
# Training loop: one optimizer update per successfully loaded batch.
for epoch in range(max_epochs):
    for step in range(num_batches_per_epoch):
        start = time.time()
        xb, yb = get_batch('train')
        if xb is None:
            # No usable batch from this chunk; try the next one
            continue

        # Count only iterations that actually performed an update
        counter += 1

        # Forward, backward, parameter update
        logits, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # Advance the warmup/cosine schedule; without this the learning rate
        # stays at the value set by the single step() call made at setup
        current_lr = scheduler.step()

        # Evaluate periodically, and once more at the very end of training
        is_final_step = epoch == max_epochs - 1 and step == num_batches_per_epoch - 1
        if counter % eval_interval == 0 or is_final_step:
            endtime = time.time() - start  # seconds spent on this training step
            losses = estimate_loss()
            print(f"Iteration {counter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, total time taken: {endtime},current position {current_position}")
            print(enc.decode(model.generate(torch.zeros((1,1),dtype=torch.long,device=device),max_size=256)[0].tolist()))
            print('\n')
            print("------------")

        # Progress dot every 100 updates, checkpoint every 5000
        if counter % 100 == 0:
            print(".", end="", flush=True)
        if counter % 5000 == 0:
            save_model(model, counter)

    # Start the next epoch from the beginning of the training file
    current_position = 0
 

In [None]:
current_position,counter

In [None]:
#previous current_position,counter (135236057, 35524)

In [None]:
perplexity = calculate_perplexity(model)
print(f"Iteration {counter}: Perplexity: {perplexity}")

In [None]:
print(enc.decode(model.generate(torch.zeros((1,1),dtype=torch.long,device=device),max_size=1500)[0].tolist()))

In [None]:
print(f"Iteration {counter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

In [None]:
save_model(model, counter)