In [None]:
import os

# `!export VAR=...` runs in a throw-away subshell, so it never changed this
# kernel's environment. Set the variable in-process instead, before torch is
# imported; it only takes effect if CUDA has not already been initialised in
# this kernel (i.e. run this cell first in a fresh kernel).
os.environ["CUDA_VISIBLE_DEVICES"] = "8"

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
# !wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json
# !wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json

In [None]:
# from datasets import load_dataset

# # Load SQuAD dataset and select only 5 examples
# squad = load_dataset("squad", split="train[:1]")

# for example in squad:
#     print(f"Context: {example['context']}")
#     print(f"Question: {example['question']}")
#     print(f"Answer: {example['answers']['text'][0]}")
#     print("="*80)
#     break


In [None]:
# from transformers import GPT2Tokenizer

# tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# tokenizer.pad_token = tokenizer.eos_token

# def format_qa(example):
#     prompt = f"Context: {example['context']}\nQuestion: {example['question']}\nAnswer:"
#     answer = example['answers']['text'][0] + tokenizer.eos_token
#     return {"input_text": prompt, "output_text": answer}

# formatted = squad.map(format_qa)

# def tokenize(example):
#     input_ids = tokenizer(example['input_text'], truncation=True, padding="max_length", max_length=512, return_tensors="pt")["input_ids"][0]
#     labels = tokenizer(example['input_text'] + " " + example['output_text'], truncation=True, padding="max_length", max_length=512, return_tensors="pt")["input_ids"][0]
#     return {"input_ids": input_ids, "labels": labels}

# tokenized_dataset = formatted.map(tokenize)


In [None]:
# from transformers import GPT2LMHeadModel, Trainer, TrainingArguments
# import torch
# from torch.utils.data import Dataset

# class QADataset(Dataset):
#     def __init__(self, data):
#         self.data = data

#     def __getitem__(self, idx):
#         return {
#             "input_ids": self.data[idx]["input_ids"],
#             "attention_mask": self.data[idx]["input_ids"] != tokenizer.pad_token_id,
#             "labels": self.data[idx]["labels"]
#         }

#     def __len__(self):
#         return len(self.data)

# model = GPT2LMHeadModel.from_pretrained("gpt2")
# model.resize_token_embeddings(len(tokenizer))

# qa_dataset = QADataset(tokenized_dataset)

# training_args = TrainingArguments(
#     output_dir="./qa-gpt2",
#     per_device_train_batch_size=8,
#     num_train_epochs=2,
#     logging_steps=1,
#     save_steps=10,
#     save_total_limit=1,
#     report_to="none",
# )

# trainer = Trainer(
#     model=model,
#     args=training_args,
#     train_dataset=qa_dataset,
# )

# trainer.train()


In [None]:
# model.eval()

# def generate_answer(context, question):
#     prompt = f"Context: {context}\nQuestion: {question}\nAnswer:"
#     input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to('cuda')
#     output = model.generate(input_ids, max_new_tokens=10,  eos_token_id=tokenizer.eos_token_id,pad_token_id=tokenizer.eos_token_id)  # Avoid warning if pad token is undefined)
#     answer = tokenizer.decode(output[0], skip_special_tokens=True)
#     return answer.split("Answer:")[-1].strip()

# # Test on training examples
# for example in squad:
#     print("Q:", example['question'])
#     print("Predicted A:", generate_answer(example['context'], example['question']))
#     print("True A:", example['answers']['text'][0])
#     print("="*60)
#     break


### Custom GPT-2 Model with Separate Query, Key, and Value Projections

In [None]:
import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel, GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model, GPT2PreTrainedModel
from transformers.modeling_utils import Conv1D

# --------- Custom Attention Module using Conv1D ----------
class CustomGPT2Attention(nn.Module):
    """Causal multi-head self-attention with separate Q, K and V projections.

    Instead of one fused projection, queries, keys and values each get their own
    ``Conv1D`` (``q_proj``, ``k_proj``, ``v_proj``); ``c_proj`` is the output
    projection. A key/value cache is supported through ``layer_past`` /
    ``use_cache``.
    """

    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        assert self.head_dim * self.num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        width = self.embed_dim
        # Registered in q, k, v, out order.
        self.q_proj = Conv1D(width, width)
        self.k_proj = Conv1D(width, width)
        self.v_proj = Conv1D(width, width)
        self.c_proj = Conv1D(width, width)

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.scale = self.head_dim ** -0.5

        # Lower-triangular boolean mask over every position; sliced in forward().
        n_pos = config.max_position_embeddings
        lower_tri = torch.ones((n_pos, n_pos), dtype=torch.bool).tril()
        self.register_buffer("bias", lower_tri.view(1, 1, n_pos, n_pos))
        # Not read by forward(); presumably kept to mirror GPT-2's buffer names.
        self.register_buffer("masked_bias", torch.tensor(-1e4))

    def _split_heads(self, x):
        """Reshape (batch, seq, embed_dim) -> (batch, heads, seq, head_dim)."""
        batch_size, seq_len, embed_dim = x.size()
        assert embed_dim == self.num_heads * self.head_dim, f"Embed dim {embed_dim} != num_heads * head_dim {self.num_heads * self.head_dim}"
        per_head = x.view(batch_size, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def _merge_heads(self, x):
        """Inverse of _split_heads: (batch, heads, seq, head_dim) -> (batch, seq, embed_dim)."""
        seq_major = x.transpose(1, 2).contiguous()
        return seq_major.view(*seq_major.shape[:-2], self.embed_dim)

    def forward(self, hidden_states, layer_past=None, attention_mask=None,
                head_mask=None, use_cache=False, output_attentions=False):
        """Attend causally over ``hidden_states`` (plus any cached keys/values).

        Args:
            hidden_states: (batch, seq, embed_dim) input.
            layer_past: optional ``(past_key, past_value)`` prepended along the
                sequence axis of the new keys/values.
            attention_mask: optional additive mask added to the scores after
                causal masking.
            head_mask: optional multiplicative mask applied to the attention
                probabilities.
            use_cache: if true, return the concatenated ``(key, value)``.
            output_attentions: if true, also return the attention probabilities.

        Returns:
            ``(attn_output, present)`` with ``present`` ``None`` when
            ``use_cache`` is false, plus ``attn_weights`` when
            ``output_attentions`` is true.
        """
        q, k, v = (
            self._split_heads(projection(hidden_states))
            for projection in (self.q_proj, self.k_proj, self.v_proj)
        )

        if layer_past is not None:
            cached_k, cached_v = layer_past
            k = torch.cat((cached_k, k), dim=-2)
            v = torch.cat((cached_v, v), dim=-2)

        present = (k, v) if use_cache else None

        scores = (q @ k.transpose(-1, -2)) * self.scale

        # Only the last q_len rows of the causal mask apply when keys include a cache.
        q_len = q.size(-2)
        k_len = k.size(-2)
        allowed = self.bias[:, :, k_len - q_len : k_len, :k_len]
        floor = torch.full([], torch.finfo(scores.dtype).min, dtype=scores.dtype, device=scores.device)
        scores = torch.where(allowed, scores, floor)

        if attention_mask is not None:
            scores = scores + attention_mask

        probs = nn.functional.softmax(scores, dim=-1).to(v.dtype)
        probs = self.attn_dropout(probs)
        if head_mask is not None:
            probs = probs * head_mask

        context = self._merge_heads(probs @ v)
        context = self.resid_dropout(self.c_proj(context))

        result = (context, present)
        return result + (probs,) if output_attentions else result

In [None]:
# --------- Custom GPT2 Block ----------
class CustomGPT2Block(nn.Module):
    """One pre-LayerNorm transformer block that uses CustomGPT2Attention.

    Layout: ``h = x + attn(ln_1(x))`` followed by ``h + mlp(ln_2(h))``. The
    submodule names (``ln_1``, ``attn``, ``ln_2``, ``mlp``) follow the stock
    GPT-2 block so pretrained weights can be copied by name. Cross-attention is
    not implemented: ``encoder_hidden_states`` / ``encoder_attention_mask`` are
    accepted only for call-signature compatibility and are ignored.
    """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.attn = CustomGPT2Attention(config)
        self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        # Reuse the feed-forward sub-module of a stock GPT2Block so its structure
        # matches GPT-2; the rest of that temporary block is discarded.
        self.mlp = GPT2Block(config).mlp

    def forward(
        self,
        hidden_states,
        layer_past=None,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=False,
        output_attentions=False,
    ):
        """Run the block.

        Returns:
            ``(hidden_states, present[, attn_weights])`` when ``use_cache`` is
            true, and ``(hidden_states[, attn_weights])`` otherwise.

            The ``present`` slot is dropped when not caching, matching Hugging
            Face's GPT2Block. GPT2Model reads the attention weights from index 2
            when caching and from index 1 otherwise (confirm against the
            installed transformers version). The previous always-present layout
            handed back ``None`` instead of the weights whenever
            ``use_cache=False`` and ``output_attentions=True``.
        """
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)

        attn_outputs = self.attn(
            hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )

        attn_output = attn_outputs[0]
        # attn_outputs[1] is the (key, value) cache entry, or None when not caching.
        outputs = attn_outputs[1:] if use_cache else attn_outputs[2:]

        hidden_states = residual + attn_output

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        hidden_states = residual + feed_forward_hidden_states

        return (hidden_states,) + outputs

In [None]:
# --------- Custom GPT2 Model ----------
class CustomGPT2Model(GPT2Model):
    """GPT2Model whose transformer blocks are all CustomGPT2Block instances."""

    def __init__(self, config):
        super().__init__(config)
        # Overwrite the block list set up by the parent constructor with blocks
        # that use the split Q/K/V attention.
        layers = nn.ModuleList()
        for _ in range(config.num_hidden_layers):
            layers.append(CustomGPT2Block(config))
        self.h = layers

In [None]:
class CustomGPT2LMHeadModel(GPT2LMHeadModel):
    """GPT-2 LM-head model built on CustomGPT2Model (separate Q/K/V projections).

    Inherits GPT2LMHeadModel's class-level machinery (e.g. ``generate``), but
    constructs its own ``transformer`` and ``lm_head`` so that no stock
    GPT2Model is ever instantiated.
    """

    def __init__(self, config):
        # super(GPT2PreTrainedModel, self) resumes the MRO lookup *after*
        # GPT2PreTrainedModel. This deliberately skips GPT2LMHeadModel.__init__,
        # which would build a stock transformer only for it to be replaced.
        super(GPT2PreTrainedModel, self).__init__(config)

        # Replace the transformer with our custom one
        self.transformer = CustomGPT2Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # GPT2LMHeadModel.__init__ is skipped, so set the model-parallel attributes
        # it normally initialises. They are read by its parallelize()/deparallelize()
        # helpers in the transformers versions this targets; confirm for yours.
        self.model_parallel = False
        self.device_map = None

        # Initialize weights via the transformers post-init hook
        self.post_init()

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the transformer and LM head, optionally computing the LM loss.

        Args:
            labels: optional token ids aligned with the input. They are shifted
                internally (position t is scored against label t+1). Entries
                equal to -100 (CrossEntropyLoss's default ``ignore_index``) are
                ignored.
            All other arguments are forwarded unchanged to ``self.transformer``.

        Returns:
            ``CausalLMOutputWithCrossAttentions`` when ``return_dict`` resolves to
            true; otherwise the tuple ``([loss,] logits, *transformer_outputs[1:])``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Forward through transformer (excluding labels)
        transformer_outputs = self.transformer(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Keep labels on the logits' device (matters if the model is split across devices).
            labels = labels.to(lm_logits.device)
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )

    def get_output_embeddings(self):
        """Return the LM head (output projection to the vocabulary)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        """Build the model kwargs for one generation step.

        When a cache is present, only the not-yet-cached suffix of ``input_ids``
        is kept. If an attention mask is given without position ids, position
        ids are derived from the mask's cumulative sum.
        """
        token_type_ids = kwargs.get("token_type_ids", None)
        # only last token for inputs_ids if past is defined in kwargs
        if past_key_values:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -input_ids.shape[1] :]

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]
        else:
            position_ids = None

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "position_ids": position_ids,
                "attention_mask": attention_mask,
                "token_type_ids": token_type_ids,
            }
        )

        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder every cached tensor along dim 0 by ``beam_idx`` (presumably for beam search)."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past

In [None]:
# --------- Copy Weights from Original GPT2 Model ----------
def copy_weights(original_model, custom_model):
    """Copy pretrained GPT-2 weights into a model with separate Q/K/V projections.

    The stock checkpoint stores query/key/value as one fused ``attn.c_attn``
    whose weight is ``(embed_dim, 3 * embed_dim)`` and bias ``(3 * embed_dim,)``.
    These are split into the custom model's ``attn.q_proj`` / ``attn.k_proj`` /
    ``attn.v_proj``. Every other tensor whose name also exists in the custom
    model is copied verbatim.

    Args:
        original_model: module exposing the stock GPT-2 state dict.
        custom_model: module with split projections; updated in place.

    Raises:
        RuntimeError: if any parameter of ``custom_model`` received no value
            from ``original_model``. Previously such parameters were silently
            left at their random initialisation.
    """
    orig_state_dict = original_model.state_dict()
    custom_state_dict = custom_model.state_dict()
    filled = set()

    for name, param in orig_state_dict.items():
        if name.endswith("attn.c_attn.weight"):
            prefix = name[: -len("c_attn.weight")]
            # Conv1D weights are (in_features, out_features): q, k and v are
            # consecutive column blocks of width embed_dim, so no transpose is needed.
            embed_dim = param.shape[0]
            parts = torch.split(param, embed_dim, dim=1)
            kind = "weight"
        elif name.endswith("attn.c_attn.bias"):
            prefix = name[: -len("c_attn.bias")]
            parts = torch.split(param, param.shape[0] // 3)
            kind = "bias"
        else:
            # Everything else (including attn.c_proj) shares name and layout.
            if name in custom_state_dict:
                custom_state_dict[name].copy_(param)
                filled.add(name)
            continue

        for proj, part in zip(("q_proj", "k_proj", "v_proj"), parts):
            key = f"{prefix}{proj}.{kind}"
            custom_state_dict[key].copy_(part)
            filled.add(key)

    custom_model.load_state_dict(custom_state_dict)

    missing = [n for n, _ in custom_model.named_parameters() if n not in filled]
    if missing:
        raise RuntimeError(f"copy_weights: no pretrained value for parameters {missing}")

In [None]:
# config = GPT2Config.from_pretrained("gpt2")
# original_model = GPT2LMHeadModel.from_pretrained("gpt2")
# custom_model = CustomGPT2LMHeadModel(config)
# copy_weights(original_model, custom_model)


In [None]:
# 4 times

In [None]:
from datasets import load_dataset
from transformers import GPT2Tokenizer, GPT2LMHeadModel, get_scheduler
from transformers import GPT2TokenizerFast


from torch.optim import AdamW,SGD
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# -------------------------------
# Load and prepare SQuAD dataset
# -------------------------------
squad = load_dataset("squad", split="train[:20000]")  # Small subset for testing

# Fast tokenizer (needed later for return_offsets_mapping)
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # Set pad token to EOS

MAX_LENGTH = 512

def format_qa(example):
    """Split a SQuAD row into the prompt text and its EOS-terminated target answer."""
    prompt = f"Context: {example['context']}\nQuestion: {example['question']}\nAnswer:"
    answer = example['answers']['text'][0] + tokenizer.eos_token
    return {"input_text": prompt, "output_text": answer}

formatted = squad.map(format_qa)

def tokenize(example, max_length=MAX_LENGTH):
    """Tokenize one example for causal-LM fine-tuning with the loss on the answer only.

    ``input_ids`` / ``attention_mask`` encode the full ``prompt + " " + answer``
    sequence, padded with the pad token (EOS) to ``max_length``. ``labels`` holds
    the same ids, but with prompt and padding positions set to -100. The model's
    CrossEntropyLoss ignores those positions, so only the answer and its
    terminating EOS are trained on.

    Previously ``input_ids`` contained only the padded prompt while ``labels``
    contained prompt + answer. The answer tokens were therefore never fed to the
    model (it had to predict them from padding positions), and every padding EOS
    was a training target.
    """
    full_enc = tokenizer(
        example['input_text'] + " " + example['output_text'],
        truncation=True,
        padding="max_length",
        max_length=max_length,
    )
    # The prompt ends with "Answer:" and the continuation starts with a space, so
    # the prompt's tokens are expected to be a prefix of the full sequence's tokens.
    prompt_len = len(tokenizer(example['input_text'])["input_ids"])
    labels = [
        token if (pos >= prompt_len and is_real) else -100
        for pos, (token, is_real) in enumerate(zip(full_enc["input_ids"], full_enc["attention_mask"]))
    ]
    # NOTE: if the prompt alone fills max_length, every label is -100 and the
    # example contributes no loss.
    return {
        "input_ids": full_enc["input_ids"],
        "attention_mask": full_enc["attention_mask"],
        "labels": labels,
    }

tokenized = formatted.map(tokenize)

# -------------------------------
# PyTorch Dataset
# -------------------------------
class QADataset(Dataset):
    """Map-style dataset that turns tokenized rows into tensors for a DataLoader.

    Each item is a dict with ``input_ids``, ``attention_mask`` and ``labels``,
    each converted with ``torch.tensor``.
    """

    _FIELDS = ("input_ids", "attention_mask", "labels")

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data[idx]
        return {field: torch.tensor(row[field]) for field in self._FIELDS}

qa_dataset = QADataset(tokenized)

# Seed the global torch RNG before the model is built and the shuffled DataLoader
# is iterated, so batch order and dropout masks are reproducible across runs.
SEED = 42
torch.manual_seed(SEED)

# DataLoader with default collate_fn
dataloader = DataLoader(qa_dataset, batch_size=16, shuffle=True)

# -------------------------------
# Model and Optimizer Setup
# -------------------------------
config = GPT2Config.from_pretrained("gpt2")
original_model = GPT2LMHeadModel.from_pretrained("gpt2")
model = CustomGPT2LMHeadModel(config)
copy_weights(original_model, model)
# The stock model is only a weight source; release it instead of keeping a second copy in memory.
del original_model
# No tokens are added (pad reuses EOS), so this is presumably a no-op — confirm.
model.resize_token_embeddings(len(tokenizer))

optimizer = AdamW(model.parameters(), lr=5e-5)

# Alternative: higher learning rate for the query/key projections.
# optimizer = AdamW(
#     [
#         {"params": [p for n, p in model.named_parameters() if 'attn.q_proj' in n or 'attn.k_proj' in n], "lr": 2e-4},
#         {"params": [p for n, p in model.named_parameters() if 'attn.q_proj' not in n and 'attn.k_proj' not in n], "lr": 5e-5}
#     ],
#     weight_decay=0.01,
#     eps=1e-8
# )

epochs = 25
model.to(device)

# Linear decay to zero over all training steps, no warmup
num_training_steps = len(dataloader) * epochs
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)

# -------------------------------
# Training Loop with Logging
# -------------------------------
logging_steps = 200
global_step = 0

# One optimizer step per batch: forward with labels, backprop, clip the global
# grad norm to 1.0, then advance optimizer and LR scheduler.
model.train()
for epoch_idx in range(epochs):
    print(f"\nEpoch {epoch_idx + 1}/{epochs}")
    progress = tqdm(dataloader, desc="Training", leave=False)

    for batch in progress:
        global_step += 1

        model_inputs = {key: batch[key].to(device) for key in ("input_ids", "attention_mask", "labels")}
        loss = model(**model_inputs).loss
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        progress.set_postfix(loss=loss.item())

        if global_step % logging_steps == 0:
            print(f"Step {global_step} - Loss: {loss.item():.4f}")

In [None]:
import numpy as np
import torch.nn.functional as F

In [None]:
import string

def normalize_text(text):
    """Lowercase and strip ``text``, drop ASCII punctuation, and collapse runs of whitespace to single spaces."""
    lowered = text.lower().strip()
    without_punct = lowered.translate(str.maketrans("", "", string.punctuation))
    return " ".join(without_punct.split())

def generate_outputs(model, tokenizer, input_ids, answer_ids, device):
    """Greedily decode up to ``len(answer_ids)`` tokens after the prompt.

    At each step the function records two things. The first is the probability
    the model assigns to the reference token ``answer_ids[step]``; this is
    conditioned on the model's own greedy prefix, not the reference prefix. The
    second is each layer's attention row for the newest query position.
    Decoding stops early once EOS is generated.

    Args:
        model: causal LM called as ``model(input_ids=..., output_attentions=True)``.
        tokenizer: provides ``eos_token_id`` and ``decode``.
        input_ids: prompt ids; only batch element 0 is decoded/scored.
        answer_ids: 1-D tensor of reference answer ids.
        device: unused; kept for call compatibility.

    Returns:
        ``(decoded_continuation, attentions_per_step, true_token_probs)``.
    """
    eos_id = tokenizer.eos_token_id
    sequence = input_ids.clone()
    step_attentions = []
    reference_probs = []

    for reference_id in answer_ids:
        with torch.no_grad():
            result = model(input_ids=sequence, output_attentions=True)

        distribution = F.softmax(result.logits[:, -1, :], dim=-1)
        reference_probs.append(distribution[0, reference_id.item()].item())

        chosen = distribution.argmax(dim=-1)

        # Per layer: attention of the last query position, shape [batch, heads, seq_len].
        step_attentions.append([layer_attn[:, :, -1, :] for layer_attn in result.attentions])

        sequence = torch.cat([sequence, chosen.unsqueeze(-1)], dim=-1)

        if chosen.item() == eos_id:
            break

    continuation = sequence[0, input_ids.shape[-1]:]
    decoded = tokenizer.decode(continuation, skip_special_tokens=True).strip()
    return decoded, step_attentions, reference_probs

In [None]:
# Main evaluation loop
# For every example, greedily decode an answer and record
# [normalized true answer, normalized prediction, mean true-token probability,
#  attention score on the answer span].

# Index into the per-layer attentions returned by generate_outputs. Index 1 selects
# the second layer, matching the "layer_1" results filename. An earlier comment
# wrongly called this the last layer.
ATTN_LAYER = 1

model.eval()
analysis = []
for example in tqdm(squad):
    context = example["context"]
    question = example["question"]
    answer_text = example["answers"]["text"][0]
    prompt = f"Context: {context}\nQuestion: {question}\nAnswer:"

    # Encode prompt with character offsets so the answer span can be mapped to tokens
    encodings = tokenizer(prompt, return_offsets_mapping=True, add_special_tokens=False, return_tensors="pt")
    input_ids = encodings["input_ids"].to(device)
    offsets = encodings["offset_mapping"][0].tolist()  # list of (start, end)

    # answer_start is relative to the context; shift it to prompt coordinates
    context_start_in_prompt = prompt.find(context)
    answer_start_char = example["answers"]["answer_start"][0] + context_start_in_prompt
    answer_end_char = answer_start_char + len(answer_text)

    # Prompt tokens whose character span overlaps the answer span
    token_indices = [
        i for i, (start, end) in enumerate(offsets)
        if start < answer_end_char and end > answer_start_char
    ]

    # Ground-truth answer tokens WITH a leading space, as generated tokens usually include it
    answer_ids = torch.tensor(tokenizer.encode(" " + answer_text, add_special_tokens=False)).to(device)

    # Generate predicted answer + track true token probabilities
    output, attns, true_token_probs = generate_outputs(model, tokenizer, input_ids, answer_ids, device)

    # Normalize answers for fair comparison
    norm_true = normalize_text(answer_text)
    norm_pred = normalize_text(output)

    # Average true token probability over all steps
    avg_true_token_prob = sum(true_token_probs) / len(true_token_probs)

    # Per decoding step: take layer ATTN_LAYER's attention from the newest query
    # position (batch 0, shape (num_heads, seq_len)), sum it over the answer-span
    # tokens, and average over heads. Then average over steps. Span indices lie
    # inside the prompt, so they are always within seq_len; an empty span scores 0.
    step_scores = [
        step_attn[ATTN_LAYER][0][:, token_indices].sum(dim=-1).mean().item()
        for step_attn in attns
    ]
    true_token_attention_score = sum(step_scores) / len(step_scores)

    analysis.append([norm_true, norm_pred, avg_true_token_prob, true_token_attention_score])

In [None]:
import json

# Persist the per-example rows built by the evaluation loop
# ([norm_true, norm_pred, avg_true_token_prob, attention_score]).
results_path = "same_train_qa_eval_results_layer_1.json"
with open(results_path, "w") as fh:
    fh.write(json.dumps(analysis, indent=2))

In [None]:
# Per-example attention score (last column of each `analysis` row), split at 0.05.
tt = np.array(analysis)[:, -1].astype(float)
low_attention = (tt <= 0.05).sum()
high_attention = (tt > 0.05).sum()
low_attention, high_attention
#386 114 # 10 times  500 dp
#325, 208# 100 times 500 dp
#419, 81  500 dp

In [None]:
# NOTE(review): 1283/5000 and 149/5000 look like counts recorded from an earlier
# 5000-example run, displayed next to the current scores for comparison; confirm their source.
tt,1283/5000,149/5000

In [None]:
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap
import seaborn as sns

# 2-D histogram of per-example (answer-span attention score, mean true-token
# probability), shown as the percentage of evaluated examples in each cell.
EDGES = np.array([0, 0.05, 0.1, 0.5, 1])
bin_edges = [EDGES, EDGES]

analysis_arr = np.array(analysis)
attention_scores = analysis_arr[:, -1].astype(float)
token_prob_scores = analysis_arr[:, -2].astype(float)

# Normalise by the actual number of examples. This was hard-coded to 200, which
# only yields percentages when exactly 20000 examples are evaluated.
num = len(attention_scores)
hist, _, _ = np.histogram2d(attention_scores, token_prob_scores, bins=bin_edges)
hist = 100.0 * hist / num

annot_array = np.array([[f"{value:.1f}" for value in row] for row in hist], dtype=object)

# Tick at the right/top edge of each heatmap cell, labelled with the bin's upper edge.
tick_positions = np.arange(1, len(EDGES), dtype=float)

fig, ax = plt.subplots(figsize=(6, 6))
# Rows of hist.T are probability bins (y), columns are attention bins (x);
# the colour scale spans 5%..70%.
sns.heatmap(hist.T, annot=annot_array.T, cmap=sns.color_palette("coolwarm"),
            annot_kws={"size": 18}, cbar=False, vmin=5, vmax=70, fmt="", ax=ax)

ax.set_xticks(tick_positions)
ax.set_yticks(tick_positions)
ax.set_xticklabels(EDGES[1:])
ax.set_yticklabels(EDGES[1:])
ax.tick_params(axis='both', labelsize=14)

# Put the lowest bins at the bottom-left.
ax.invert_yaxis()

ax.set_xlabel(r"Distinct Token Attention", fontweight="bold", fontsize=16)
ax.set_ylabel(r"Relevant Token Probability", fontweight="bold", fontsize=16)

fig.savefig("faster_qk_10_times_squad_train_l1.pdf", bbox_inches='tight')
plt.show()

In [None]:
# model.eval()

# def generate_answer(context, question):
#     prompt = f"Context: {context}\nQuestion: {question}\nAnswer:"
#     input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to('cuda')
#     output = model.generate(input_ids, max_new_tokens=10,  eos_token_id=tokenizer.eos_token_id,pad_token_id=tokenizer.eos_token_id)  # Avoid warning if pad token is undefined)
#     answer = tokenizer.decode(output[0], skip_special_tokens=True)
#     return answer.split("Answer:")[-1].strip()

# count = 0 
# # Test on training examples
# for example in squad:
#     print("Q:", example['question'])
#     print("Predicted A:", generate_answer(example['context'], example['question']))
#     print("True A:", example['answers']['text'][0])
#     print("="*60)
#     count +=1
#     if count>50:
#         break


In [None]:
import string
import re
from collections import Counter

def normalize_answer(s):
    """Normalize an answer SQuAD-style for comparison.

    The steps, in order: lowercase, drop ASCII punctuation, replace the
    standalone words a/an/the with spaces, and collapse whitespace.
    """
    text = s.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r'\b(a|an|the)\b', ' ', text, flags=re.IGNORECASE)
    return " ".join(text.split())

def get_tokens(s):
    """Return the whitespace tokens of ``normalize_answer(s)``; falsy input gives []."""
    return normalize_answer(s).split() if s else []

def compute_exact_match(a_gold, a_pred):
    """Return 1 if both answers are equal after normalize_answer, otherwise 0."""
    return 1 if normalize_answer(a_gold) == normalize_answer(a_pred) else 0

def compute_f1(a_gold, a_pred):
    """Token-level F1 between a gold and a predicted answer.

    Two empty token lists score 1.0, and exactly one empty list scores 0.0.
    Otherwise the score is the harmonic mean of precision and recall over the
    multiset overlap of tokens.
    """
    gold = get_tokens(a_gold)
    pred = get_tokens(a_pred)

    if not (gold or pred):
        return 1.0
    if not (gold and pred):
        return 0.0

    overlap = sum((Counter(gold) & Counter(pred)).values())
    if overlap == 0:
        return 0.0

    precision = overlap / len(pred)
    recall = overlap / len(gold)
    return (2 * precision * recall) / (precision + recall)

# Your evaluation code with metrics
model.eval()

def generate_answer(context, question, max_new_tokens=10):
    """Greedy-decode an answer to ``question`` given ``context``.

    Args:
        context: passage text, inserted into the training prompt template.
        question: question text, inserted into the same template.
        max_new_tokens: cap on generated tokens. Longer answers are truncated,
            which lowers EM/F1. The default of 10 keeps the previous behaviour.

    Returns:
        Only the newly generated text after the prompt, with special tokens
        removed and surrounding whitespace stripped.
    """
    prompt = f"Context: {context}\nQuestion: {question}\nAnswer:"
    # Use the configured device rather than hard-coding 'cuda' (which crashed on CPU).
    enc = tokenizer(prompt, return_tensors="pt").to(device)
    output = model.generate(
        enc.input_ids,
        attention_mask=enc.attention_mask,
        max_new_tokens=max_new_tokens,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the new tokens instead of splitting the full text on "Answer:",
    # which truncated answers that themselves contain "Answer:".
    new_tokens = output[0, enc.input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

# Score generated answers on `squad` with SQuAD-style Exact Match (EM) and F1.
# Each prediction is compared with every gold answer and the best score is kept.
exact_match_scores = []
f1_scores = []
predictions = []
ground_truths = []

print("Evaluating model on SQuAD examples...")
print("=" * 80)

# Test on training examples (`squad` as loaded earlier in the notebook)
for example in tqdm(squad):
    predicted_answer = generate_answer(example['context'], example['question'])
    # All valid gold answers. An empty list (unanswerable question) is scored
    # against the empty string, as the official SQuAD v2 script does, instead
    # of making max() below raise ValueError.
    true_answers = example['answers']['text'] or [""]

    # Best score over all gold answers (SQuAD style)
    em_score = max(compute_exact_match(true_ans, predicted_answer) for true_ans in true_answers)
    f1_score = max(compute_f1(true_ans, predicted_answer) for true_ans in true_answers)

    exact_match_scores.append(em_score)
    f1_scores.append(f1_score)
    predictions.append(predicted_answer)
    ground_truths.append(true_answers[0])  # first gold answer, for display only

# Aggregate once; max(..., 1) keeps an empty split from raising ZeroDivisionError.
num_questions = len(exact_match_scores)
denom = max(num_questions, 1)
num_exact = sum(exact_match_scores)
num_f1_above_half = sum(1 for score in f1_scores if score > 0.5)
nonzero_f1 = [score for score in f1_scores if score > 0]

overall_em = num_exact / denom * 100
overall_f1 = sum(f1_scores) / denom * 100

# Summary
print("\n" + "=" * 80)
print("EVALUATION SUMMARY")
print("=" * 80)
print(f"Total Questions: {num_questions}")
print(f"Exact Match Score: {overall_em:.2f}%")
print(f"F1 Score: {overall_f1:.2f}%")

print("\nDetailed Statistics:")
print(f"Questions with EM = 1: {num_exact} ({num_exact / denom * 100:.1f}%)")
print(f"Questions with F1 > 0.5: {num_f1_above_half} ({num_f1_above_half / denom * 100:.1f}%)")
print(f"Average F1 for non-zero scores: {sum(nonzero_f1) / max(1, len(nonzero_f1)):.3f}")

# A few examples from each performance bucket
print("\n" + "=" * 80)
print("EXAMPLE ANALYSIS")
print("=" * 80)

# (index, prediction, first gold answer) triples, bucketed by score
scored_rows = list(zip(exact_match_scores, f1_scores, predictions, ground_truths))
perfect_matches = [(i, pred, true) for i, (em, f1, pred, true) in enumerate(scored_rows) if em == 1]
partial_matches = [(i, pred, true) for i, (em, f1, pred, true) in enumerate(scored_rows) if em == 0 and f1 > 0.3]
no_matches = [(i, pred, true) for i, (em, f1, pred, true) in enumerate(scored_rows) if f1 == 0]

if perfect_matches:
    print(f"\nPerfect Matches (EM=1): {len(perfect_matches)} examples")
    for i, (idx, pred, true) in enumerate(perfect_matches[:3]):
        print(f"  {i+1}. Predicted: '{pred}' | True: '{true}'")

if partial_matches:
    print(f"\nPartial Matches (EM=0, F1>0.3): {len(partial_matches)} examples")
    for i, (idx, pred, true) in enumerate(partial_matches[:3]):
        print(f"  {i+1}. Predicted: '{pred}' | True: '{true}' | F1: {f1_scores[idx]:.3f}")

if no_matches:
    print(f"\nNo Matches (F1=0): {len(no_matches)} examples")
    for i, (idx, pred, true) in enumerate(no_matches[:3]):
        print(f"  {i+1}. Predicted: '{pred}' | True: '{true}'")

print("=" * 80)

In [None]:
# Save the fine-tuned model and its tokenizer into the same directory so both
# can later be reloaded with `from_pretrained(save_path)`.
# save_pretrained creates the directory itself if it does not exist.
save_path = "./faster_saved"

# Save model (weights and config)
model.save_pretrained(save_path)

# Save tokenizer; as the cell's last expression, its return value is what Jupyter displays
tokenizer.save_pretrained(save_path)

In [None]:
# Evaluation data: the first N_VALIDATION examples of the SQuAD validation split,
# pushed through the format_qa / tokenize helpers defined earlier in the notebook.
N_VALIDATION = 5000
squad_validation = load_dataset("squad", split=f"validation[:{N_VALIDATION}]")

formatted_validation = squad_validation.map(format_qa)  # prompt / answer text fields
tokenized_validation = formatted_validation.map(tokenize)  # input_ids / labels

qa_dataset_validation = QADataset(tokenized_validation)

# Evaluation needs no shuffling. A fixed order keeps batch-level results
# reproducible between runs (shuffle=True with no seed did not).
dataloader_validation = DataLoader(qa_dataset_validation, batch_size=8, shuffle=False)

In [None]:
# Main evaluation loop: for each validation question, record
#   [normalized gold answer, normalized prediction,
#    mean gold-token probability, attention score on the gold-answer span]
# where the probability and attention come from `generate_outputs`.
#
# NOTE(review): ATTN_LAYER = 1 selects the *second* transformer layer, not the
# last one (a previous comment here said "last layer"). The output file names
# end in "_l1", which suggests layer 1 is intended; use -1 for the last layer.
ATTN_LAYER = 1


def _answer_token_indices(offsets, start_char, end_char):
    """Indices of prompt tokens whose (start, end) character span overlaps [start_char, end_char)."""
    return [i for i, (start, end) in enumerate(offsets) if start < end_char and end > start_char]


def _answer_attention_score(attns, token_indices, layer=ATTN_LAYER):
    """Mean over generation steps of the head-averaged attention mass on ``token_indices``.

    Assumes ``attns[step][layer][0]`` is shaped (num_heads, seq_len), as the
    original comment stated — TODO confirm against ``generate_outputs``.
    Indices at or beyond seq_len are ignored. Returns a Python float, or NaN
    when there are no generation steps.
    """
    if len(attns) == 0:
        return float("nan")
    total_attention = 0.0
    for step_attn in attns:
        layer_attn = step_attn[layer][0]
        step_total = 0.0
        for head_attn in layer_attn:
            step_total += sum(head_attn[j] for j in token_indices if j < head_attn.shape[0])
        total_attention += step_total / layer_attn.shape[0]  # average over heads
    # float() instead of .item(): when no index is in range, sum() yields a plain
    # Python float, and .item() raised AttributeError on it.
    return float(total_attention / len(attns))  # average over steps


model.eval()
analysis_validation = []
for example in tqdm(squad_validation):
    context = example["context"]
    question = example["question"]
    answer_text = example["answers"]["text"][0]
    prompt = f"Context: {context}\nQuestion: {question}\nAnswer:"

    # Encode the prompt with character offsets (offset mapping needs a "fast" HF tokenizer).
    encodings = tokenizer(prompt, return_offsets_mapping=True, add_special_tokens=False, return_tensors="pt")
    input_ids = encodings["input_ids"].to(device)
    offsets = encodings["offset_mapping"][0].tolist()  # list of (start, end)

    # answer_start is relative to the context; shift it into prompt coordinates.
    context_start_in_prompt = prompt.find(context)
    answer_start_char = example["answers"]["answer_start"][0] + context_start_in_prompt
    answer_end_char = answer_start_char + len(answer_text)
    token_indices = _answer_token_indices(offsets, answer_start_char, answer_end_char)

    # Encode ground-truth answer tokens WITH leading space because generated tokens usually include it
    answer_ids = torch.tensor(tokenizer.encode(" " + answer_text, add_special_tokens=False)).to(device)

    # Generate predicted answer + track true token probabilities
    output, attns, true_token_probs = generate_outputs(model, tokenizer, input_ids, answer_ids, device)

    # Normalize answers for fair comparison
    norm_true = normalize_text(answer_text)
    norm_pred = normalize_text(output)

    # Average true-token probability over all steps (NaN if there were no steps)
    if len(true_token_probs) > 0:
        avg_true_token_prob = sum(true_token_probs) / len(true_token_probs)
    else:
        avg_true_token_prob = float("nan")

    true_token_attention_score = _answer_attention_score(attns, token_indices)

    analysis_validation.append([norm_true, norm_pred, avg_true_token_prob, true_token_attention_score])

In [None]:
import json

# Persist the per-example analysis rows built by the validation loop as indented JSON.
results_path = "same_validation_qa_eval_results_validation_l1.json"
with open(results_path, "w") as results_file:
    json.dump(analysis_validation, results_file, indent=2)

In [None]:
# Last field of each analysis row is the answer-span attention score;
# show how many examples fall at/below vs. above 0.05.
tt = np.array(analysis_validation)[:, -1].astype(float)
(tt <= 0.05).sum(), (tt > 0.05).sum()

In [None]:
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap
import seaborn as sns


# 2-D histogram of validation examples: x = attention on the gold-answer span
# (last field of each analysis row), y = mean gold-token probability (second to
# last field). Cells are annotated with the percentage of all analysed examples.
edges = np.array([0, 0.05, 0.1, 0.5, 1])
x_edges = edges
y_edges = edges
bin_edges = [x_edges, y_edges]

analysis_array = np.array(analysis_validation)
attention_values = analysis_array[:, -1].astype(float)
probability_values = analysis_array[:, -2].astype(float)

hist, _, _ = np.histogram2d(attention_values, probability_values, bins=bin_edges)

# Convert counts to percentages of the analysed examples. This used to divide by
# a hard-coded 50, which equals N / 100 only when N == 5000.
num = len(analysis_validation) / 100
hist = hist / num

# One "xx.x" label per cell
annot_array = np.array([[f"{value:.1f}" for value in row] for row in hist], dtype=object)

# Heatmap cell k spans [k, k + 1); place ticks on the right/top edge of each cell
xtick_positions = np.arange(len(x_edges) - 1) + 1.0
ytick_positions = np.arange(len(y_edges) - 1) + 1.0

fig, ax = plt.subplots(figsize=(6, 6))
# Transpose so rows are probability bins (y) and columns are attention bins (x)
sns.heatmap(hist.T, annot=annot_array.T, cmap=sns.color_palette("coolwarm"),
            annot_kws={"size": 18}, cbar=False, vmin=5, vmax=70, fmt="", ax=ax)

ax.set_xticks(xtick_positions)
ax.set_yticks(ytick_positions)
ax.set_xticklabels(x_edges[1:])
ax.set_yticklabels(y_edges[1:])
ax.tick_params(axis='x', labelsize=14)
ax.tick_params(axis='y', labelsize=14)

# seaborn draws row 0 at the top; invert so low probabilities sit at the bottom
ax.invert_yaxis()

ax.set_xlabel(r"Distinct Token Attention", fontweight="bold", fontsize=16)
ax.set_ylabel(r"Relevant Token Probability", fontweight="bold", fontsize=16)

fig.savefig("faster_qk_10_times_squad_validation_l1.pdf", bbox_inches='tight')
plt.show()

In [None]:
# Score generated answers on `squad_validation` with SQuAD-style Exact Match (EM)
# and F1. Each prediction is compared with every gold answer and the best score is kept.
exact_match_scores = []
f1_scores = []
predictions = []
ground_truths = []

print("Evaluating model on SQuAD examples...")
print("=" * 80)

# Evaluate on the validation examples loaded earlier in the notebook
for example in tqdm(squad_validation):
    predicted_answer = generate_answer(example['context'], example['question'])
    # All valid gold answers. An empty list (unanswerable question) is scored
    # against the empty string, as the official SQuAD v2 script does, instead
    # of making max() below raise ValueError.
    true_answers = example['answers']['text'] or [""]

    # Best score over all gold answers (SQuAD style)
    em_score = max(compute_exact_match(true_ans, predicted_answer) for true_ans in true_answers)
    f1_score = max(compute_f1(true_ans, predicted_answer) for true_ans in true_answers)

    exact_match_scores.append(em_score)
    f1_scores.append(f1_score)
    predictions.append(predicted_answer)
    ground_truths.append(true_answers[0])  # first gold answer, for display only

# Aggregate once; max(..., 1) keeps an empty split from raising ZeroDivisionError.
num_questions = len(exact_match_scores)
denom = max(num_questions, 1)
num_exact = sum(exact_match_scores)
num_f1_above_half = sum(1 for score in f1_scores if score > 0.5)
nonzero_f1 = [score for score in f1_scores if score > 0]

overall_em = num_exact / denom * 100
overall_f1 = sum(f1_scores) / denom * 100

# Summary
print("\n" + "=" * 80)
print("EVALUATION SUMMARY")
print("=" * 80)
print(f"Total Questions: {num_questions}")
print(f"Exact Match Score: {overall_em:.2f}%")
print(f"F1 Score: {overall_f1:.2f}%")

print("\nDetailed Statistics:")
print(f"Questions with EM = 1: {num_exact} ({num_exact / denom * 100:.1f}%)")
print(f"Questions with F1 > 0.5: {num_f1_above_half} ({num_f1_above_half / denom * 100:.1f}%)")
print(f"Average F1 for non-zero scores: {sum(nonzero_f1) / max(1, len(nonzero_f1)):.3f}")

# A few examples from each performance bucket
print("\n" + "=" * 80)
print("EXAMPLE ANALYSIS")
print("=" * 80)

# (index, prediction, first gold answer) triples, bucketed by score
scored_rows = list(zip(exact_match_scores, f1_scores, predictions, ground_truths))
perfect_matches = [(i, pred, true) for i, (em, f1, pred, true) in enumerate(scored_rows) if em == 1]
partial_matches = [(i, pred, true) for i, (em, f1, pred, true) in enumerate(scored_rows) if em == 0 and f1 > 0.3]
no_matches = [(i, pred, true) for i, (em, f1, pred, true) in enumerate(scored_rows) if f1 == 0]

if perfect_matches:
    print(f"\nPerfect Matches (EM=1): {len(perfect_matches)} examples")
    for i, (idx, pred, true) in enumerate(perfect_matches[:3]):
        print(f"  {i+1}. Predicted: '{pred}' | True: '{true}'")

if partial_matches:
    print(f"\nPartial Matches (EM=0, F1>0.3): {len(partial_matches)} examples")
    for i, (idx, pred, true) in enumerate(partial_matches[:3]):
        print(f"  {i+1}. Predicted: '{pred}' | True: '{true}' | F1: {f1_scores[idx]:.3f}")

if no_matches:
    print(f"\nNo Matches (F1=0): {len(no_matches)} examples")
    for i, (idx, pred, true) in enumerate(no_matches[:3]):
        print(f"  {i+1}. Predicted: '{pred}' | True: '{true}'")

print("=" * 80)

### Custom GPT-2 Model vs. Pretrained GPT-2: Weight and Output Comparison

In [None]:
# import torch
# import torch.nn as nn
# from transformers import GPT2LMHeadModel, GPT2Config
# from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model, GPT2PreTrainedModel
# from transformers.modeling_utils import Conv1D

# # --------- Custom Attention Module using Conv1D ----------
# class CustomGPT2Attention(nn.Module):
#     def __init__(self, config):
#         super().__init__()
#         self.embed_dim = config.hidden_size
#         self.num_heads = config.num_attention_heads
#         self.head_dim = self.embed_dim // self.num_heads
#         assert self.head_dim * self.num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

#         self.q_proj = Conv1D(self.embed_dim, self.embed_dim)
#         self.k_proj = Conv1D(self.embed_dim, self.embed_dim)
#         self.v_proj = Conv1D(self.embed_dim, self.embed_dim)
#         self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

#         self.attn_dropout = nn.Dropout(config.attn_pdrop)
#         self.resid_dropout = nn.Dropout(config.resid_pdrop)
#         self.scale = self.head_dim ** -0.5
        
#         # Register causal mask buffer (same as original GPT2)
#         max_positions = config.max_position_embeddings
#         self.register_buffer(
#             "bias",
#             torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
#                 1, 1, max_positions, max_positions
#             ),
#         )
#         self.register_buffer("masked_bias", torch.tensor(-1e4))

#     def _split_heads(self, x):
#         batch_size, seq_len, embed_dim = x.size()
#         # embed_dim should equal num_heads * head_dim
#         assert embed_dim == self.num_heads * self.head_dim, f"Embed dim {embed_dim} != num_heads * head_dim {self.num_heads * self.head_dim}"
        
#         # reshape to (batch_size, seq_len, num_heads, head_dim)
#         x = x.view(batch_size, seq_len, self.num_heads, self.head_dim)
#         # permute to (batch_size, num_heads, seq_len, head_dim)
#         return x.permute(0, 2, 1, 3)

#     def _merge_heads(self, x):
#         x = x.permute(0, 2, 1, 3).contiguous()
#         new_shape = x.size()[:-2] + (self.embed_dim,)
#         return x.view(*new_shape)

#     def forward(self, hidden_states, layer_past=None, attention_mask=None,
#                 head_mask=None, use_cache=False, output_attentions=False):
    
#         # Project to Q, K, V
#         query = self.q_proj(hidden_states)
#         key = self.k_proj(hidden_states)
#         value = self.v_proj(hidden_states)

#         # Split heads
#         query = self._split_heads(query)
#         key = self._split_heads(key)
#         value = self._split_heads(value)

#         # Handle past keys/values for generation
#         if layer_past is not None:
#             past_key, past_value = layer_past
#             key = torch.cat((past_key, key), dim=-2)
#             value = torch.cat((past_value, value), dim=-2)

#         if use_cache:
#             present = (key, value)
#         else:
#             present = None

#         # Compute attention weights
#         attn_weights = torch.matmul(query, key.transpose(-1, -2))
#         attn_weights = attn_weights * self.scale

#         # Apply causal mask
#         query_length, key_length = query.size(-2), key.size(-2)
#         causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
#         mask_value = torch.finfo(attn_weights.dtype).min
#         mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
#         attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)

#         # Apply attention mask if provided
#         if attention_mask is not None:
#             attn_weights = attn_weights + attention_mask

#         # Softmax
#         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
#         attn_weights = attn_weights.type(value.dtype)
#         attn_weights = self.attn_dropout(attn_weights)

#         # Apply head mask if provided
#         if head_mask is not None:
#             attn_weights = attn_weights * head_mask

#         # Apply attention to values
#         attn_output = torch.matmul(attn_weights, value)
#         attn_output = self._merge_heads(attn_output)

#         # Final projection
#         attn_output = self.c_proj(attn_output)
#         attn_output = self.resid_dropout(attn_output)

#         outputs = (attn_output, present)
#         if output_attentions:
#             outputs += (attn_weights,)

#         return outputs
# # --------- Custom GPT2 Block ----------
# class CustomGPT2Block(nn.Module):
#     def __init__(self, config):
#         super().__init__()
#         self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
#         self.attn = CustomGPT2Attention(config)
#         self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
#         self.mlp = GPT2Block(config).mlp

#     def forward(
#         self,
#         hidden_states,
#         layer_past=None,
#         attention_mask=None,
#         head_mask=None,
#         encoder_hidden_states=None,
#         encoder_attention_mask=None,
#         use_cache=False,
#         output_attentions=False,
#     ):
#         residual = hidden_states
#         hidden_states = self.ln_1(hidden_states)

#         attn_outputs = self.attn(
#             hidden_states,
#             layer_past=layer_past,
#             attention_mask=attention_mask,
#             head_mask=head_mask,
#             use_cache=use_cache,
#             output_attentions=output_attentions,
#         )

#         attn_output = attn_outputs[0]
#         outputs = attn_outputs[1:]

#         hidden_states = residual + attn_output

#         residual = hidden_states
#         hidden_states = self.ln_2(hidden_states)
#         feed_forward_hidden_states = self.mlp(hidden_states)
#         hidden_states = residual + feed_forward_hidden_states

#         return (hidden_states,) + outputs

# # --------- Custom GPT2 Model ----------
# class CustomGPT2Model(GPT2Model):
#     def __init__(self, config):
#         super().__init__(config)
#         self.h = nn.ModuleList([CustomGPT2Block(config) for _ in range(config.num_hidden_layers)])

# # --------- Custom GPT2 LM Model ----------
# class CustomGPT2LMHeadModel(GPT2PreTrainedModel):
#     def __init__(self, config):
#         super().__init__(config)
#         self.transformer = CustomGPT2Model(config)
#         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
#         self.init_weights()

#     def forward(self, input_ids=None, **kwargs):
#         transformer_outputs = self.transformer(input_ids=input_ids, **kwargs)
#         hidden_states = transformer_outputs[0]
#         lm_logits = self.lm_head(hidden_states)
#         return lm_logits

# # --------- Copy Weights from Original GPT2 Model ----------
# def copy_weights(original_model, custom_model):
#     orig_state_dict = original_model.state_dict()
#     custom_state_dict = custom_model.state_dict()

#     for name, param in orig_state_dict.items():
#         if "attn.c_attn.weight" in name:
#             layer_num = int(name.split('.')[2])
#             prefix = f'transformer.h.{layer_num}.attn.'
            
#             # Original c_attn weight shape: (embed_dim, 3 * embed_dim)
#             # Need to split along dim=1 (the 3 * embed_dim dimension)
#             embed_dim = param.shape[0]
#             q_weight, k_weight, v_weight = torch.split(param, embed_dim, dim=1)
            
#             # Conv1D weight shape is (input_dim, output_dim), no transpose needed
#             custom_state_dict[f'{prefix}q_proj.weight'].copy_(q_weight)
#             custom_state_dict[f'{prefix}k_proj.weight'].copy_(k_weight)  
#             custom_state_dict[f'{prefix}v_proj.weight'].copy_(v_weight)

#         elif "attn.c_attn.bias" in name:
#             layer_num = int(name.split('.')[2])
#             prefix = f'transformer.h.{layer_num}.attn.'
#             hidden_size = param.shape[0] // 3

#             q_bias, k_bias, v_bias = torch.split(param, hidden_size)
#             custom_state_dict[f'{prefix}q_proj.bias'].copy_(q_bias)
#             custom_state_dict[f'{prefix}k_proj.bias'].copy_(k_bias)
#             custom_state_dict[f'{prefix}v_proj.bias'].copy_(v_bias)

#         elif "attn.c_proj.weight" in name:
#             # Copy c_proj weights directly
#             layer_num = int(name.split('.')[2])
#             prefix = f'transformer.h.{layer_num}.attn.'
#             custom_state_dict[f'{prefix}c_proj.weight'].copy_(param)
            
#         else:
#             if name in custom_state_dict:
#                 custom_state_dict[name].copy_(param)

#     custom_model.load_state_dict(custom_state_dict)

# # --------- Parameter Comparison ----------
# def compare_model_parameters(original_model, custom_model):
#     orig_params = dict(original_model.named_parameters())
#     cust_params = dict(custom_model.named_parameters())

#     print("--- Comparing Non-Split Parameters ---")
#     for name, orig_param in orig_params.items():
#         if "attn.c_attn" in name or "attn.c_proj.weight" in name:
#             continue
#         if name not in cust_params:
#             print(f"Parameter {name} missing in custom model")
#             continue
#         cust_param = cust_params[name]
#         if torch.allclose(orig_param, cust_param, atol=1e-6):
#             c = 1 #print(f"Parameter {name} matches.")
#         else:
#             print(f"Parameter {name} differs!")

#     print("\n--- Comparing Split QKV Parameters ---")
#     for i in range(original_model.config.num_hidden_layers):
#         prefix = f'transformer.h.{i}.'
#         orig_w = orig_params[f'{prefix}attn.c_attn.weight']
#         orig_b = orig_params[f'{prefix}attn.c_attn.bias']
#         orig_proj_w = orig_params[f'{prefix}attn.c_proj.weight']
        
#         hidden_size = orig_w.shape[0]

#         # Split the original concatenated weights
#         orig_qw, orig_kw, orig_vw = torch.split(orig_w, hidden_size, dim=1)
#         orig_qb, orig_kb, orig_vb = torch.split(orig_b, hidden_size)

#         # Get custom model parameters
#         cust_qw = cust_params[f'{prefix}attn.q_proj.weight']
#         cust_qb = cust_params[f'{prefix}attn.q_proj.bias']
#         cust_kw = cust_params[f'{prefix}attn.k_proj.weight']
#         cust_kb = cust_params[f'{prefix}attn.k_proj.bias']
#         cust_vw = cust_params[f'{prefix}attn.v_proj.weight']
#         cust_vb = cust_params[f'{prefix}attn.v_proj.bias']
#         cust_proj_w = cust_params[f'{prefix}attn.c_proj.weight']

#         # Compare weights (no transpose needed for Conv1D)
#         if torch.allclose(orig_qw, cust_qw, atol=1e-6): c = 1#print(f"Layer {i} Q weight matches.")
#         else: print(f"Layer {i} Q weight differs!")
        
#         if torch.allclose(orig_kw, cust_kw, atol=1e-6): c =1 #print(f"Layer {i} K weight matches.")
#         else: print(f"Layer {i} K weight differs!")
        
#         if torch.allclose(orig_vw, cust_vw, atol=1e-6): c =1 #print(f"Layer {i} V weight matches.")
#         else: print(f"Layer {i} V weight differs!")
        
#         if torch.allclose(orig_proj_w, cust_proj_w, atol=1e-6): c= 1#print(f"Layer {i} c_proj weight matches.")
#         else: print(f"Layer {i} c_proj weight differs!")

#         # Compare biases (no transpose needed)
#         if torch.allclose(orig_qb, cust_qb, atol=1e-6): c= 1#print(f"Layer {i} Q bias matches.")
#         else: print(f"Layer {i} Q bias differs!")
        
#         if torch.allclose(orig_kb, cust_kb, atol=1e-6): c =1 #print(f"Layer {i} K bias matches.")
#         else: print(f"Layer {i} K bias differs!")
        
#         if torch.allclose(orig_vb, cust_vb, atol=1e-6): c = 1#print(f"Layer {i} V bias matches.")
#         else: print(f"Layer {i} V bias differs!")

# # --------- Output Comparison ----------
# def check_outputs_identical(original_model, custom_model):
#     original_model.eval()
#     custom_model.eval()

#     batch_size = 2
#     seq_len = 16
#     vocab_size = original_model.config.vocab_size
#     input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))

#     with torch.no_grad():
#         orig_logits = original_model(input_ids).logits
#         cust_logits = custom_model(input_ids)

#     max_diff = torch.max(torch.abs(orig_logits - cust_logits)).item()
#     print(f"\n--- Comparing Model Outputs ---")
#     print(f"Max absolute difference between outputs: {max_diff:.10f}")# Additional debugging steps to minimize differences

# def debug_attention_step_by_step(original_model, custom_model, input_ids):
#     """Compare intermediate outputs in the first attention layer"""
#     original_model.eval()
#     custom_model.eval()
    
#     with torch.no_grad():
#         # Get embeddings (should be identical)
#         orig_embeds = original_model.transformer.wte(input_ids) + original_model.transformer.wpe(torch.arange(input_ids.size(1), device=input_ids.device))
#         cust_embeds = custom_model.transformer.wte(input_ids) + custom_model.transformer.wpe(torch.arange(input_ids.size(1), device=input_ids.device))
        
#         print("Embedding difference:", torch.max(torch.abs(orig_embeds - cust_embeds)).item())
        
#         # First layer norm
#         orig_ln1 = original_model.transformer.h[0].ln_1(orig_embeds)
#         cust_ln1 = custom_model.transformer.h[0].ln_1(cust_embeds)
        
#         print("First LayerNorm difference:", torch.max(torch.abs(orig_ln1 - cust_ln1)).item())
        
#         # QKV projections
#         orig_qkv = original_model.transformer.h[0].attn.c_attn(orig_ln1)
#         orig_q, orig_k, orig_v = orig_qkv.split(original_model.config.hidden_size, dim=2)
        
#         cust_q = custom_model.transformer.h[0].attn.q_proj(cust_ln1)
#         cust_k = custom_model.transformer.h[0].attn.k_proj(cust_ln1)
#         cust_v = custom_model.transformer.h[0].attn.v_proj(cust_ln1)
        
#         print("Q projection difference:", torch.max(torch.abs(orig_q - cust_q)).item())
#         print("K projection difference:", torch.max(torch.abs(orig_k - cust_k)).item())
#         print("V projection difference:", torch.max(torch.abs(orig_v - cust_v)).item())

# def ensure_identical_initialization():
#     """Ensure both models have identical random initialization states"""
#     torch.manual_seed(42)
#     torch.cuda.manual_seed_all(42)
#     # Set deterministic behavior
#     torch.backends.cudnn.deterministic = True
#     torch.backends.cudnn.benchmark = False

# # Alternative: Use double precision for comparison
# def check_outputs_double_precision(original_model, custom_model):
#     """Check outputs using double precision for more accurate comparison"""
#     original_model.eval()
#     custom_model.eval()
    
#     # Convert to double precision
#     original_model = original_model.double()
#     custom_model = custom_model.double()
    
#     batch_size = 2
#     seq_len = 16
#     vocab_size = original_model.config.vocab_size
#     input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))

#     with torch.no_grad():
#         orig_logits = original_model(input_ids).logits
#         cust_logits = custom_model(input_ids)

#     max_diff = torch.max(torch.abs(orig_logits - cust_logits)).item()
#     print(f"Double precision max difference: {max_diff:.15f}")
#     return max_diff
#     if max_diff < 1e-4:
#         print("Success! Outputs match within tolerance.")
#     else:
#         print("Failure! Outputs still differ.")


# config = GPT2Config.from_pretrained("gpt2")

# print("Loading original GPT-2 model...")
# original_model = GPT2LMHeadModel.from_pretrained("gpt2")

# print("Creating custom GPT-2 model with split Q,K,V...")
# custom_model = CustomGPT2LMHeadModel(config)

# print("Copying weights...")
# copy_weights(original_model, custom_model)

# print("\nComparing parameters...")
# compare_model_parameters(original_model, custom_model)

# check_outputs_identical(original_model, custom_model)


In [None]:
# # Additional debugging steps to minimize differences

# def debug_attention_step_by_step(original_model, custom_model, input_ids):
#     """Compare intermediate outputs in the first attention layer"""
#     original_model.eval()
#     custom_model.eval()
    
#     with torch.no_grad():
#         # Get embeddings (should be identical)
#         orig_embeds = original_model.transformer.wte(input_ids) + original_model.transformer.wpe(torch.arange(input_ids.size(1), device=input_ids.device))
#         cust_embeds = custom_model.transformer.wte(input_ids) + custom_model.transformer.wpe(torch.arange(input_ids.size(1), device=input_ids.device))
        
#         print("Embedding difference:", torch.max(torch.abs(orig_embeds - cust_embeds)).item())
        
#         # First layer norm
#         orig_ln1 = original_model.transformer.h[0].ln_1(orig_embeds)
#         cust_ln1 = custom_model.transformer.h[0].ln_1(cust_embeds)
        
#         print("First LayerNorm difference:", torch.max(torch.abs(orig_ln1 - cust_ln1)).item())
        
#         # QKV projections
#         orig_qkv = original_model.transformer.h[0].attn.c_attn(orig_ln1)
#         orig_q, orig_k, orig_v = orig_qkv.split(original_model.config.hidden_size, dim=2)
        
#         cust_q = custom_model.transformer.h[0].attn.q_proj(cust_ln1)
#         cust_k = custom_model.transformer.h[0].attn.k_proj(cust_ln1)
#         cust_v = custom_model.transformer.h[0].attn.v_proj(cust_ln1)
        
#         print("Q projection difference:", torch.max(torch.abs(orig_q - cust_q)).item())
#         print("K projection difference:", torch.max(torch.abs(orig_k - cust_k)).item())
#         print("V projection difference:", torch.max(torch.abs(orig_v - cust_v)).item())

# def ensure_identical_initialization():
#     """Ensure both models have identical random initialization states"""
#     torch.manual_seed(42)
#     torch.cuda.manual_seed_all(42)
#     # Set deterministic behavior
#     torch.backends.cudnn.deterministic = True
#     torch.backends.cudnn.benchmark = False

# # Alternative: Use double precision for comparison
# def check_outputs_double_precision(original_model, custom_model):
#     """Check outputs using double precision for more accurate comparison"""
#     original_model.eval()
#     custom_model.eval()
    
#     # Convert to double precision
#     original_model = original_model.double()
#     custom_model = custom_model.double()
    
#     batch_size = 2
#     seq_len = 16
#     vocab_size = original_model.config.vocab_size
#     input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))

#     with torch.no_grad():
#         orig_logits = original_model(input_ids).logits
#         cust_logits = custom_model(input_ids)

#     max_diff = torch.max(torch.abs(orig_logits - cust_logits)).item()
#     print(f"Double precision max difference: {max_diff:.15f}")
    
#     # Convert back to float
#     original_model = original_model.float()
#     custom_model = custom_model.float()
    
#     return max_diff

In [None]:
# vocab_size = original_model.config.vocab_size
# input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))

# debug_attention_step_by_step(original_model, custom_model, input_ids)

In [None]:
# check_outputs_double_precision(original_model, custom_model)