<a href="https://colab.research.google.com/github/Shadid12/NLP_Projects/blob/main/improved_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch
!pip install -U "datasets>=2.14.6"
!pip install "fsspec==2023.9.2"
!pip install tokenizers

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to token embeddings.

    A table of shape ``[1, max_seq_len, d_model]`` is precomputed once and
    stored as a non-trainable buffer named ``pe``. Even feature indices hold
    ``sin(pos * freq)`` and odd feature indices hold ``cos(pos * freq)``.

    Args:
        d_model: Feature dimension of the embeddings (even or odd).
        max_seq_len: Largest sequence length the table covers.
    """

    def __init__(self, d_model: int, max_seq_len: int = 512):
        super().__init__()

        # Table of shape [max_seq_len, d_model]
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # [max_seq_len, ceil(d_model / 2)]

        # sin on even feature indices (2i)
        pe[:, 0::2] = torch.sin(angles)
        # cos on odd feature indices (2i+1); when d_model is odd there is one
        # fewer odd column than even columns, so drop the last frequency.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Register as a buffer (not a parameter, but moves with model.to(device))
        self.register_buffer("pe", pe.unsqueeze(0))  # shape [1, max_seq_len, d_model]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: Tensor of shape [batch_size, seq_len, d_model]
        returns: same shape, with positional encoding added

        Raises:
            ValueError: if seq_len exceeds the precomputed table length.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            # Without this check the addition fails with an opaque broadcast error.
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_len {max_len}"
            )
        return x + self.pe[:, :seq_len]


class TransformerInputEmbedding(nn.Module):
    """Token embedding lookup followed by positional encoding."""

    def __init__(self, vocab_size: int, d_model: int, max_seq_len: int = 512):
        super().__init__()
        # Learned lookup table: token id -> d_model-dimensional vector.
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        # Position information is added on top of the token vectors.
        self.positional_encoding = PositionalEncoding(d_model, max_seq_len)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids and add positional information.

        input_ids: Tensor of shape [batch_size, seq_len]
        returns: Tensor of shape [batch_size, seq_len, d_model]
        """
        return self.positional_encoding(self.token_embedding(input_ids))

In [None]:
from datasets import load_dataset
# Load the "javascript" configuration of the "code_search_net" dataset into `ds`.
# NOTE(review): download_mode="force_redownload" presumably bypasses any cached
# copy on every run — confirm this is still wanted once the cache is healthy.
ds = load_dataset("code_search_net", "javascript", download_mode="force_redownload")

In [None]:
 # a Byte-Level BPE tokenizer
from tokenizers import ByteLevelBPETokenizer
# Start from an untrained tokenizer; it is trained by the train_from_iterator call in this cell.
tokenizer = ByteLevelBPETokenizer()

def code_iterator():
    """Yield the ``func_code_string`` field of each row in ``ds['train']``, one at a time."""
    train_rows = ds['train']
    for record in train_rows:
        yield record["func_code_string"]

# Train the BPE tokenizer on the function source strings yielded by code_iterator().
tokenizer.train_from_iterator(
    code_iterator(),
    vocab_size=32_000,                 # typical sweet-spot for code
    min_frequency=2,
    # Special tokens are listed explicitly so they are part of the trained vocabulary.
    special_tokens=["<pad>", "<s>", "</s>", "<unk>", "<mask>"]
)

In [None]:
import os
# exist_ok=True lets this cell be re-run without raising FileExistsError
# when the "js-bpe" directory is already present.
os.makedirs("js-bpe", exist_ok=True)
# Write the trained tokenizer's model files into "js-bpe".
tokenizer.save_model("js-bpe")

## Improved Transformer

In [None]:

class GPTBlock(nn.Module):
    """One pre-norm decoder block: causal self-attention followed by an MLP.

    Both sub-layers are wrapped in residual connections, and LayerNorm is
    applied to the input of each sub-layer (pre-norm).

    Args:
        d_model: Feature dimension of the hidden states.
        nhead: Number of attention heads.
        dropout: Dropout probability used inside attention, after the
            attention output, and at the end of the MLP.
    """

    def __init__(self, d_model, nhead, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the block to ``x`` of shape [batch, seq_len, d_model]; returns the same shape."""
        seq_len = x.size(1)
        # True above the diagonal = "may not attend": position i only sees positions <= i.
        # Built directly on x's device to avoid a CPU allocation plus a copy per call.
        causal_mask = torch.ones(seq_len, seq_len, device=x.device).triu(diagonal=1).bool()

        # Normalize once and reuse for query, key and value (self-attention).
        normed = self.ln1(x)
        # need_weights=False: the attention weights are never used here, so skip computing them.
        attn_out, _ = self.attn(
            normed, normed, normed,
            attn_mask=causal_mask,
            need_weights=False,
        )
        x = x + self.dropout(attn_out)

        # Position-wise MLP with its own residual connection.
        x = x + self.mlp(self.ln2(x))
        return x

class GPTStyleTransformer(nn.Module):
    """Decoder-only language model: input embedding, a stack of GPTBlock layers,
    a final LayerNorm, and a linear head producing per-position vocabulary logits."""

    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, max_seq_len=512, dropout=0.1):
        super().__init__()
        self.embedding = TransformerInputEmbedding(vocab_size, d_model, max_seq_len)
        self.dropout = nn.Dropout(dropout)

        layers = []
        for _ in range(num_layers):
            layers.append(GPTBlock(d_model, nhead, dropout))
        self.blocks = nn.ModuleList(layers)

        self.ln_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size)

        # Re-initialise every Linear / Embedding submodule after construction.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Normal(0, 0.02) weights for Linear and Embedding; zero biases for Linear."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is None:
                return
            torch.nn.init.zeros_(module.bias)
            return
        if isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids):
        """Return logits for ``input_ids``; the last dimension has ``vocab_size`` entries."""
        hidden = self.dropout(self.embedding(input_ids))
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.lm_head(self.ln_f(hidden))

## Improved Code Dataset

In [None]:
import re
from torch.utils.data import Dataset

class ImprovedCodeDataset(Dataset):
    """Next-token-prediction dataset built from JavaScript source strings.

    Each text is cleaned (comments stripped, whitespace normalised), tokenized,
    and cut into overlapping windows of ``seq_len`` tokens. Every sample is a
    pair ``(input_ids, target_ids)`` where the targets are the inputs shifted
    one token to the right.

    Args:
        texts: Iterable of source-code strings.
        tokenizer: Object whose ``encode(text).ids`` gives a list of token ids.
        seq_len: Number of tokens per sample.
        min_length: Cleaned texts with this many characters or fewer are dropped.
    """

    # Alternation order matters: string literals are matched (and kept, via
    # group 1) before comments, so "//" or "/*" inside a string is not treated
    # as a comment start. Regex literals are not recognised.
    _COMMENT_OR_STRING = (
        r'("(?:\\.|[^"\\\n])*"'      # double-quoted string
        r"|'(?:\\.|[^'\\\n])*'"      # single-quoted string
        r'|`(?:\\.|[^`\\])*`)'       # template literal (may span lines)
        r'|/\*.*?\*/'                # block comment
        r'|//[^\n]*'                 # line comment
    )

    def __init__(self, texts, tokenizer, seq_len=256, min_length=50):
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.samples = []

        # Clean every text and drop the ones that end up too short.
        processed_texts = []
        for text in texts:
            cleaned = self.clean_code(text)
            if len(cleaned) > min_length:
                processed_texts.append(cleaned)

        print(f"Processed {len(processed_texts)} code samples")

        # Overlapping windows (half-window stride) yield more samples per text.
        # max(1, ...) keeps range() valid when seq_len < 2.
        stride = max(1, seq_len // 2)
        for text in processed_texts:
            token_ids = tokenizer.encode(text).ids

            # Texts with fewer than seq_len + 1 tokens give an empty range and
            # therefore contribute no samples.
            for start in range(0, len(token_ids) - seq_len, stride):
                input_ids = token_ids[start:start + seq_len]
                target_ids = token_ids[start + 1:start + 1 + seq_len]
                self.samples.append((input_ids, target_ids))

        print(f"Created {len(self.samples)} training samples")

    def clean_code(self, code):
        """Strip comments and normalise whitespace in a JavaScript source string.

        Comments are removed first, in a single left-to-right pass that skips
        over string literals, so whitespace left behind by removed comments is
        normalised as well.
        """
        # Remove // and /* */ comments; keep string literals untouched.
        code = re.sub(
            ImprovedCodeDataset._COMMENT_OR_STRING,
            lambda match: match.group(1) or '',
            code,
            flags=re.DOTALL,
        )

        code = re.sub(r' +', ' ', code)  # Multiple spaces to single space
        code = code.replace('\t', '    ')  # Tabs to 4 spaces
        code = re.sub(r'[ \t]+\n', '\n', code)  # Trailing whitespace left by removed comments
        code = re.sub(r'\n\s*\n\s*\n', '\n\n', code)  # Max 2 consecutive newlines

        return code.strip()

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of two long tensors: "input_ids" and "labels"."""
        input_ids, target_ids = self.samples[idx]
        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "labels": torch.tensor(target_ids, dtype=torch.long),
        }

# Create dataset with better preprocessing
def create_improved_dataset():
    """Build an ImprovedCodeDataset (seq_len=256) from the de-duplicated 'train' split of ``ds``.

    Uses the module-level ``ds`` and ``tokenizer``.
    """
    texts = []
    for split in ['train']:  # Add 'validation' if available
        for row in ds[split]:
            texts.append(row["func_code_string"])

    # Drop exact duplicates. dict.fromkeys keeps first-seen order, so the
    # resulting sample order is the same on every run (set() ordering of
    # strings varies between interpreter runs).
    unique_texts = list(dict.fromkeys(texts))
    print(f"Unique code samples: {len(unique_texts)}")

    dataset = ImprovedCodeDataset(unique_texts, tokenizer, seq_len=256)
    return dataset


### Encode, Decode, and Helper Functions

In [None]:
from typing import List, Sequence, Union

# Rebuild the tokenizer from the vocab/merges files under "js-bpe/"
# (the directory the trained tokenizer was saved to). This rebinds the
# module-level `tokenizer` used by encode() and decode() below.
tokenizer = ByteLevelBPETokenizer(
    "js-bpe/vocab.json",
    "js-bpe/merges.txt",
)


def encode(
    text: Union[str, Sequence[str]],
    add_special_tokens: bool = True,
) -> Union[List[int], List[List[int]]]:
    """
    Tokenize JavaScript source with the module-level ``tokenizer``.

    A single ``str`` yields one list of token IDs; any other iterable of
    strings is encoded as a batch and yields one ID list per string.
    """
    if not isinstance(text, str):
        # Batch mode
        encodings = tokenizer.encode_batch(
            list(text), add_special_tokens=add_special_tokens
        )
        return [encoding.ids for encoding in encodings]

    single = tokenizer.encode(text, add_special_tokens=add_special_tokens)
    return single.ids

def decode(
    ids: Union[Sequence[int], Sequence[Sequence[int]]],
    skip_special_tokens: bool = True,
) -> Union[str, List[str]]:
    """
    Turn token IDs back into text with the module-level ``tokenizer``.

    A non-empty input whose first element is a list or tuple is treated as a
    batch and returns a list of strings; anything else is decoded as one
    sequence and returns a single string.
    """
    looks_batched = bool(ids) and isinstance(ids[0], (list, tuple))
    if not looks_batched:
        return tokenizer.decode(ids, skip_special_tokens=skip_special_tokens)  # type: ignore[arg-type]

    decoded = []
    for seq in ids:
        decoded.append(tokenizer.decode(seq, skip_special_tokens=skip_special_tokens))  # type: ignore[arg-type]
    return decoded

In [None]:
import torch
import torch.nn.functional as F

def top_k_top_p_filtering(logits, top_k=50, top_p=0.9, filter_value=-float('Inf')):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Works on a single distribution of shape ``[vocab]`` or on a batch of shape
    ``[..., vocab]``; filtering is applied independently along the last axis.

    Args:
        logits: Tensor of logits. It is modified in place and also returned.
        top_k: Keep only the ``top_k`` highest logits; ``<= 0`` disables top-k.
        top_p: Keep the smallest set of tokens whose cumulative probability
            exceeds ``top_p``; values ``<= 0`` or ``>= 1`` disable nucleus filtering.
        filter_value: Value written over removed logits.

    Returns:
        The same ``logits`` tensor with removed entries set to ``filter_value``.
    """
    top_k = min(top_k, logits.size(-1))  # Safety check

    if top_k > 0:
        # Remove every token whose logit is below the k-th largest logit
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value

    # top_p >= 1 means "keep everything"; skipping it avoids dropping tail
    # tokens when the cumulative sum rounds slightly above 1.0.
    if 0.0 < top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token that crosses the threshold is kept too
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Map the mask from sorted order back to vocabulary order. scatter along
        # the last axis handles both 1-D and batched inputs.
        indices_to_remove = torch.zeros_like(sorted_indices_to_remove).scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value

    return logits

def improved_generate(model, prompt: str, max_new_tokens: int = 50,
                     temperature: float = 0.8, top_k: int = 50,
                     top_p: float = 0.9, repetition_penalty: float = 1.1,
                     max_seq_len=None):
    """
    Sample a continuation of ``prompt`` with temperature, top-k / top-p
    filtering and a repetition penalty.

    Args:
        model: Module mapping ``[1, T]`` token ids to ``[1, T, vocab]`` logits.
            It is switched to eval mode.
        prompt: Text to continue; encoded with the module-level ``encode``.
        max_new_tokens: Maximum number of tokens to sample.
        temperature: Logits are divided by this value before sampling; must be > 0.
        top_k, top_p: Passed to ``top_k_top_p_filtering``.
        repetition_penalty: Values > 1 make already-generated tokens less likely.
        max_seq_len: Largest context fed to the model. When None it is read from
            ``model.embedding.positional_encoding.pe`` if present, else 512.

    Returns:
        The decoded prompt plus generated tokens as a single string.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be > 0")

    device = next(model.parameters()).device
    model.eval()

    if max_seq_len is None:
        try:
            max_seq_len = model.embedding.positional_encoding.pe.size(1)
        except AttributeError:
            max_seq_len = 512  # fallback for models without that attribute path

    # Encode the prompt
    input_ids = torch.tensor(encode(prompt), dtype=torch.long, device=device).unsqueeze(0)

    # Looked up once; None (token missing) simply never matches a sampled id.
    eos_id = tokenizer.token_to_id("</s>")

    # Distinct token ids generated so far, for the repetition penalty
    generated_tokens = set()

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Feed only the most recent max_seq_len tokens to the model, but keep
            # the full sequence so the returned text still contains the prompt.
            context = input_ids[:, -max_seq_len:]
            outputs = model(context)
            next_token_logits = outputs[0, -1, :] / temperature

            # Repetition penalty: shrink positive logits and push negative logits
            # further down, so a seen token always becomes less likely.
            if repetition_penalty != 1.0 and generated_tokens:
                seen = torch.tensor(sorted(generated_tokens), dtype=torch.long, device=device)
                seen_logits = next_token_logits[seen]
                next_token_logits[seen] = torch.where(
                    seen_logits < 0,
                    seen_logits * repetition_penalty,
                    seen_logits / repetition_penalty,
                )

            # Apply top-k and top-p filtering
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)

            # Sample from the filtered distribution
            probs = F.softmax(filtered_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Append to sequence
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=-1)
            token_id = next_token.item()
            generated_tokens.add(token_id)

            # Stop if we hit end token
            if token_id == eos_id:
                break

    # Decode the result
    return decode(input_ids[0].tolist())

# Alternative: Beam search for more deterministic generation
def beam_search_generate(model, prompt: str, max_new_tokens: int = 50,
                        num_beams: int = 4, temperature: float = 1.0,
                        max_seq_len=None):
    """
    Beam search generation for more coherent but less diverse output.

    Keeps the ``num_beams`` highest-scoring sequences at each step, where a
    sequence's score is the sum of its tokens' log-probabilities. A beam whose
    last token is ``</s>`` is carried forward unchanged.

    Args:
        model: Module mapping ``[1, T]`` token ids to ``[1, T, vocab]`` logits.
            It is switched to eval mode.
        prompt: Text to continue; encoded with the module-level ``encode``.
        max_new_tokens: Maximum number of expansion steps.
        num_beams: Number of sequences kept per step.
        temperature: Logits are divided by this value; must be > 0.
        max_seq_len: Largest context fed to the model. When None it is read from
            ``model.embedding.positional_encoding.pe`` if present, else 512.

    Returns:
        The decoded highest-scoring sequence (prompt included).

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be > 0")

    device = next(model.parameters()).device
    model.eval()

    if max_seq_len is None:
        try:
            max_seq_len = model.embedding.positional_encoding.pe.size(1)
        except AttributeError:
            max_seq_len = 512  # fallback for models without that attribute path

    # Encode prompt
    input_ids = torch.tensor(encode(prompt), dtype=torch.long, device=device).unsqueeze(0)

    # Looked up once instead of once per beam per step.
    eos_id = tokenizer.token_to_id("</s>")

    # Initialize beams
    beams = [(input_ids, 0.0)]  # (sequence, score)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            candidates = []
            all_finished = True

            for seq, score in beams:
                if seq[0, -1].item() == eos_id:
                    candidates.append((seq, score))
                    continue
                all_finished = False

                # Only the most recent max_seq_len tokens fit in the model's context.
                outputs = model(seq[:, -max_seq_len:])
                # log_softmax avoids the underflow of log(softmax(...)).
                log_probs = F.log_softmax(outputs[0, -1, :] / temperature, dim=-1)

                # Expand this beam with its best next tokens
                k = min(num_beams, log_probs.size(-1))
                top_log_probs, top_indices = torch.topk(log_probs, k)

                for log_prob, idx in zip(top_log_probs.tolist(), top_indices):
                    new_seq = torch.cat([seq, idx.view(1, 1)], dim=-1)
                    candidates.append((new_seq, score + log_prob))

            # Every beam already ended with </s>: further steps would change nothing.
            if all_finished:
                break

            # Select top beams
            beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:num_beams]

    # Return best sequence
    best_seq = beams[0][0]
    return decode(best_seq[0].tolist())

## Training loop

In [None]:
from torch.utils.data import DataLoader
from itertools import islice
from torch.optim import Adam
import torch.nn as nn

# Use the GPU when one is available; otherwise fall back to CPU so the cell
# still runs (slowly) on CPU-only runtimes instead of crashing on .to("cuda").
device = "cuda" if torch.cuda.is_available() else "cpu"

# Prepare data: first 15000 training functions, cut into 128-token windows
texts = [row["func_code_string"] for row in islice(ds["train"], 15000)]
dataset = ImprovedCodeDataset(texts, tokenizer, seq_len=128)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Model: max_seq_len matches the dataset's seq_len
vocab_size = tokenizer.get_vocab_size()
model = GPTStyleTransformer(
    vocab_size=vocab_size,
    d_model=512,
    nhead=8,
    num_layers=4,
    max_seq_len=128,
    dropout=0.1
).to(device)

optimizer = Adam(model.parameters(), lr=3e-4)
loss_fn = nn.CrossEntropyLoss()

# Training loop
for epoch in range(30):
    total_loss = 0
    model.train()
    for batch in dataloader:
        input_ids = batch["input_ids"].to(device)        # [B, T]
        labels = batch["labels"].to(device)              # [B, T]

        optimizer.zero_grad()
        logits = model(input_ids)                        # [B, T, V]
        # Flatten batch and time so each position is one classification example
        loss = loss_fn(logits.view(-1, vocab_size), labels.view(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Mean of the per-batch losses for this epoch
    print(f"Epoch {epoch+1}: Loss = {total_loss / len(dataloader):.4f}")

print("Training completed!")

# Test generation
model.eval()
test_prompts = [
    "function add(a, b) {",
    "const arr = [1, 2, 3];",
    "if (condition) {"
]

for prompt in test_prompts:
    result = improved_generate(model, prompt, max_new_tokens=30)
    print(f"Prompt: {prompt}")
    print(f"Generated: {result}")
    print("-" * 50)

## Saving Model for later inference

In [None]:
import json

# Persist everything needed to rebuild the model later: weights, the
# constructor hyper-parameters, and a copy of the tokenizer files.
print("Saving the model...")

# 1. Create a directory to save the model
output_dir = "./gpt_js_model"
os.makedirs(output_dir, exist_ok=True)

# 2. Save the model's state_dict
model_save_path = os.path.join(output_dir, "pytorch_model.bin")
torch.save(model.state_dict(), model_save_path)

# 3. Save the model's configuration
#    Each value is read back from the live model, so the config matches the
#    saved weights; the keys mirror GPTStyleTransformer's constructor arguments.
model_config = {
    "vocab_size": model.lm_head.out_features,
    "d_model": model.embedding.token_embedding.embedding_dim,
    "nhead": model.blocks[0].attn.num_heads,
    "num_layers": len(model.blocks),
    "max_seq_len": model.embedding.positional_encoding.pe.size(1),
    "dropout": model.dropout.p,
}
config_save_path = os.path.join(output_dir, "config.json")
with open(config_save_path, 'w') as f:
    json.dump(model_config, f)

# 4. The tokenizer is already saved in the "js-bpe" directory.
#    You can optionally copy it to your model directory for a self-contained package.
import shutil
tokenizer_dir = os.path.join(output_dir, "tokenizer")
# Remove any previous copy first so the copytree call below starts from a clean target.
if os.path.exists(tokenizer_dir):
    shutil.rmtree(tokenizer_dir) # remove if it exists
shutil.copytree("js-bpe", tokenizer_dir)


print(f"Model, config, and tokenizer saved in {output_dir}")