# GateRA

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import pandas as pd
from tqdm import tqdm
import numpy as np
import time
import math
import warnings
warnings.filterwarnings('ignore')


class CodeSearchDataset(Dataset):
    """Aligned (comment, code) pairs read from a CSV for contrastive code search.

    The CSV is expected to contain ``comment`` and ``code`` columns. Each item
    tokenizes both fields independently, padding/truncating them to
    ``max_length`` tokens.

    Args:
        csv_path: path passed to ``pd.read_csv``.
        tokenizer: Hugging Face tokenizer called with ``return_tensors='pt'``.
        max_length: fixed token length of every encoded field.
    """

    def __init__(self, csv_path, tokenizer, max_length=256):
        self.data = pd.read_csv(csv_path)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _to_text(value):
        """Return ``value`` as text, mapping missing CSV cells (NaN/None) to ``""``.

        ``pd.read_csv`` stores empty cells as NaN and ``str(float('nan'))`` is
        ``"nan"``; without this every missing field would become the literal
        word "nan".
        """
        return "" if pd.isna(value) else str(value)

    def _encode(self, text):
        """Tokenize ``text`` into 1-D ``(input_ids, attention_mask)`` tensors."""
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # return_tensors='pt' adds a batch dimension of size 1; drop it.
        return encoding['input_ids'].squeeze(0), encoding['attention_mask'].squeeze(0)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        comment_ids, comment_mask = self._encode(self._to_text(row['comment']))
        code_ids, code_mask = self._encode(self._to_text(row['code']))

        return {
            'comment_input_ids': comment_ids,
            'comment_attention_mask': comment_mask,
            'code_input_ids': code_ids,
            'code_attention_mask': code_mask
        }


class GatingModule(nn.Module):
    """Scalar gate per input row: ``sigmoid(w . x + b)``.

    ``w`` and ``b`` are zero-initialised, so every gate starts at exactly 0.5.
    The linear map broadcasts over leading dimensions; the output keeps them
    and has a trailing dimension of size 1.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.gate_linear = nn.Linear(input_dim, 1, bias=True)
        for tensor in (self.gate_linear.weight, self.gate_linear.bias):
            nn.init.zeros_(tensor)

    def forward(self, x):
        return self.gate_linear(x).sigmoid()


class GateRALayer(nn.Module):
    """Frozen projection plus an input-gated low-rank update.

    ``output = base_layer(x) + g(x) * dropout(x @ A @ B) * (alpha / rank)``
    where ``g`` is a ``GatingModule`` giving one gate per row/token. ``A`` is
    Kaiming-uniform initialised and ``B`` starts at zero, so the wrapped
    layer initially returns exactly ``base_layer(x)``.

    Args:
        base_layer: the module being adapted (called on ``x`` unchanged).
        rank: inner dimension of the low-rank update.
        alpha: numerator of the ``alpha / rank`` scaling.
        dropout: dropout probability applied to the low-rank output (0 disables).
        input_dim: last dimension of ``x``.
        output_dim: last dimension of ``base_layer(x)``.
    """

    def __init__(self, base_layer, rank, alpha, dropout, input_dim, output_dim):
        super().__init__()
        self.base_layer = base_layer
        self.rank = rank
        self.scaling = alpha / rank
        self.dropout = dropout  # the probability; the module lives in dropout_layer

        self.lora_A = nn.Parameter(torch.zeros(input_dim, rank))
        self.lora_B = nn.Parameter(torch.zeros(rank, output_dim))

        self.gating_module = GatingModule(input_dim)

        if dropout > 0.0:
            self.dropout_layer = nn.Dropout(p=dropout)
        else:
            self.dropout_layer = nn.Identity()

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def forward(self, x):
        base_output = self.base_layer(x)
        # Both the gate's Linear and the matmuls broadcast over any leading
        # dimensions, so no flatten/restore of x is required.
        gate = self.gating_module(x)
        delta = self.dropout_layer(x @ self.lora_A @ self.lora_B)
        return base_output + gate * delta * self.scaling


class GateRACodeSearchModel(nn.Module):
    """Bi-encoder for code search: frozen seq2seq encoder with GateRA on attention q/v.

    Comments and code share one encoder; token states are mean-pooled over
    non-padding positions and passed through a Linear+Tanh projection.
    Training minimises an in-batch contrastive loss (cross-entropy over the
    comment/code cosine-similarity matrix, matching pairs on the diagonal)
    plus, optionally, the mean binary entropy of the GateRA gate activations,
    which pushes the gates towards 0 or 1.

    Args:
        base_model_name: id passed to ``AutoModelForSeq2SeqLM.from_pretrained``.
        rank, alpha, dropout: forwarded to every ``GateRALayer``.
        entropy_reg_weight: weight of the gate-entropy penalty (0 disables it).
        temperature: divisor applied to cosine similarities before the loss.
    """

    def __init__(self, base_model_name, rank=16, alpha=16.0, dropout=0.0, entropy_reg_weight=0.01, temperature=0.07):
        super().__init__()

        self.base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)

        for param in self.base_model.parameters():
            param.requires_grad = False

        config = self.base_model.config
        self.d_model = config.d_model
        self.d_kv = config.d_kv
        self.num_heads = config.num_heads
        self.entropy_reg_weight = entropy_reg_weight
        self.temperature = temperature

        # Gate activations from the real forward pass, appended by
        # ``_record_gate_values`` and consumed/cleared in ``forward``.
        self._gate_values = []

        encoder = self.base_model.get_encoder()
        self.gatera_layers = nn.ModuleDict()

        for i, block in enumerate(encoder.block):
            q_proj = block.layer[0].SelfAttention.q
            v_proj = block.layer[0].SelfAttention.v

            q_gatera = GateRALayer(
                q_proj, rank=rank, alpha=alpha, dropout=dropout,
                input_dim=self.d_model, output_dim=self.d_kv * self.num_heads
            )
            v_gatera = GateRALayer(
                v_proj, rank=rank, alpha=alpha, dropout=dropout,
                input_dim=self.d_model, output_dim=self.d_kv * self.num_heads
            )

            self.gatera_layers[f'encoder_q_{i}'] = q_gatera
            self.gatera_layers[f'encoder_v_{i}'] = v_gatera

            block.layer[0].SelfAttention.q = q_gatera
            block.layer[0].SelfAttention.v = v_gatera

        # Capture the gates the encoder actually computes, instead of probing
        # the gating modules with synthetic inputs.
        for layer in self.gatera_layers.values():
            layer.gating_module.register_forward_hook(self._record_gate_values)

        self.projection = nn.Sequential(
            nn.Linear(self.d_model, self.d_model),
            nn.Tanh()
        )

    def _record_gate_values(self, module, inputs, output):
        """Forward hook for a gating module: keep its output while training."""
        if module.training and self.entropy_reg_weight > 0:
            self._gate_values.append(output)

    def compute_entropy_loss(self, gate_values):
        """Mean binary entropy of ``gate_values`` (clamped away from 0 and 1)."""
        eps = 1e-8
        gate_values = torch.clamp(gate_values, eps, 1.0 - eps)
        entropy = -gate_values * torch.log(gate_values) - (1 - gate_values) * torch.log(1 - gate_values)
        return entropy.mean()

    def _gate_entropy_loss(self):
        """Average ``compute_entropy_loss`` over every recorded gate tensor."""
        return torch.stack([self.compute_entropy_loss(g) for g in self._gate_values]).mean()

    def encode_with_gatera(self, input_ids, attention_mask):
        """Encode token ids and return projected, mask-averaged embeddings.

        Returns a ``(batch, d_model)`` tensor; it is not L2-normalised.
        """
        encoder = self.base_model.get_encoder()

        outputs = encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )

        hidden_states = outputs.last_hidden_state

        # Mean over non-padding tokens only; the clamp avoids 0/0 for all-pad rows.
        mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        sum_hidden = torch.sum(hidden_states * mask_expanded, dim=1)
        sum_mask = torch.clamp(mask_expanded.sum(dim=1), min=1e-9)
        pooled = sum_hidden / sum_mask

        return self.projection(pooled)

    def forward(self, comment_input_ids, comment_attention_mask, code_input_ids, code_attention_mask):
        """Return the loss for a batch of aligned (comment, code) pairs.

        Row ``i`` of the comment tensors is the positive for row ``i`` of the
        code tensors; every other row of the batch acts as a negative.
        """
        # Drop gates from any earlier pass that was never consumed.
        self._gate_values = []
        try:
            comment_emb = self.encode_with_gatera(comment_input_ids, comment_attention_mask)
            code_emb = self.encode_with_gatera(code_input_ids, code_attention_mask)

            comment_emb = F.normalize(comment_emb, p=2, dim=1)
            code_emb = F.normalize(code_emb, p=2, dim=1)

            similarities = torch.matmul(comment_emb, code_emb.T) / self.temperature
            labels = torch.arange(comment_emb.shape[0], device=comment_emb.device)
            total_loss = F.cross_entropy(similarities, labels)

            if self.training and self.entropy_reg_weight > 0 and self._gate_values:
                total_loss = total_loss + self.entropy_reg_weight * self._gate_entropy_loss()
        finally:
            # Release the recorded gate tensors (and the graph they reference).
            self._gate_values = []

        return total_loss


def train_epoch(model, dataloader, optimizer, device):
    """Run one training epoch and return the mean per-batch loss.

    Each batch is moved to ``device``, the model's loss is back-propagated,
    gradients are clipped to a global norm of 1.0 and the optimizer steps.
    """
    model.train()
    input_keys = (
        'comment_input_ids',
        'comment_attention_mask',
        'code_input_ids',
        'code_attention_mask',
    )
    running_loss = 0
    steps = 0

    for batch in tqdm(dataloader, desc="Training"):
        inputs = [batch[key].to(device) for key in input_keys]

        optimizer.zero_grad()
        loss = model(*inputs)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        running_loss += loss.item()
        steps += 1

    return running_loss / steps


def evaluate_code_search(model, dataloader, device):
    """Embed every (comment, code) pair in ``dataloader`` and measure throughput.

    Returns:
        ``(comment_embeddings, code_embeddings, tokens_per_second)``. The
        embeddings are the model's un-normalised projected vectors stacked in
        dataloader order. Throughput counts every input-id position fed to the
        encoder (padding included) for both fields. The timed span includes
        the DataLoader's work for each batch.
    """
    model.eval()
    device_type = torch.device(device).type

    def _synchronize():
        # Accelerator kernels run asynchronously; wait for them so the
        # wall-clock interval covers the actual encoding work.
        if device_type == 'cuda':
            torch.cuda.synchronize()
        elif device_type == 'mps':
            torch.mps.synchronize()

    all_code_embeddings = []
    all_comment_embeddings = []
    total_tokens = 0

    _synchronize()
    start_time = time.perf_counter()

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Encoding"):
            code_input_ids = batch['code_input_ids'].to(device)
            code_attention_mask = batch['code_attention_mask'].to(device)
            comment_input_ids = batch['comment_input_ids'].to(device)
            comment_attention_mask = batch['comment_attention_mask'].to(device)

            code_emb = model.encode_with_gatera(code_input_ids, code_attention_mask)
            comment_emb = model.encode_with_gatera(comment_input_ids, comment_attention_mask)

            all_code_embeddings.append(code_emb)
            all_comment_embeddings.append(comment_emb)
            # Count what was actually encoded instead of assuming 256 tokens.
            total_tokens += code_input_ids.numel() + comment_input_ids.numel()

        all_code_embeddings = torch.cat(all_code_embeddings, dim=0)
        all_comment_embeddings = torch.cat(all_comment_embeddings, dim=0)

    _synchronize()
    total_time = time.perf_counter() - start_time
    tokens_per_second = total_tokens / total_time if total_time > 0 else 0

    return all_comment_embeddings, all_code_embeddings, tokens_per_second


def compute_code_search_metrics_fast(comment_embeddings, code_embeddings, top_k_values=(1, 5, 10)):
    """Rank all codes for every comment query and return MRR and Recall@k.

    Query ``i`` (row ``i`` of ``comment_embeddings``) is paired with code
    ``i``. Similarity is cosine similarity. The rank of the paired code is
    1 + the number of codes scoring strictly higher. Queries without a paired
    code (``i >= number of codes``) get rank ``num_queries + 1``.

    Args:
        comment_embeddings: ``(num_queries, dim)`` query embeddings.
        code_embeddings: ``(num_codes, dim)`` candidate embeddings.
        top_k_values: cut-offs for which ``'Recall@k'`` is reported.

    Returns:
        dict with ``'MRR'`` followed by one ``'Recall@k'`` per value in
        ``top_k_values``. All values are floats, 0.0 for empty input.
    """
    num_queries = comment_embeddings.shape[0]

    metrics = {'MRR': 0.0}
    metrics.update({f'Recall@{k}': 0.0 for k in top_k_values})
    if num_queries == 0:
        return metrics

    comment_embeddings = F.normalize(comment_embeddings, p=2, dim=1)
    code_embeddings = F.normalize(code_embeddings, p=2, dim=1)

    similarities = torch.matmul(comment_embeddings, code_embeddings.T)

    num_paired = min(num_queries, similarities.shape[1])
    ranks = torch.full((num_queries,), num_queries + 1, dtype=torch.long, device=similarities.device)
    if num_paired > 0:
        # Vectorised rank: count codes that beat the paired code's score.
        paired = torch.arange(num_paired, device=similarities.device)
        true_scores = similarities[paired, paired]
        ranks[:num_paired] = (similarities[:num_paired] > true_scores.unsqueeze(1)).sum(dim=1) + 1

    # Average in float64 on the CPU (MPS has no float64).
    ranks = ranks.cpu().double()
    metrics['MRR'] = (1.0 / ranks).mean().item()
    for k in top_k_values:
        metrics[f'Recall@{k}'] = (ranks <= k).double().mean().item()

    return metrics


def print_code_search_metrics(metrics, tokens_per_second, phase="Test"):
    """Print ranking metrics and encoder throughput in an aligned table.

    Args:
        metrics: dict with an ``'MRR'`` entry plus any ``'Recall@k'`` entries,
            printed in the dict's order.
        tokens_per_second: encoder throughput to report.
        phase: label shown (upper-cased) in the header.
    """
    rule = '=' * 80

    def _row(label, text):
        # Left-align labels in a 29-character column so the values line up.
        print(f"  {label:<29}{text}")

    print(f"\n{rule}")
    print(f"{phase.upper()} CODE SEARCH METRICS")
    print(f"{rule}\n")

    print("Ranking Metrics:")
    _row("MRR (Mean Reciprocal Rank):", f"{metrics['MRR']:.4f}")
    for name, value in metrics.items():
        if name.startswith('Recall@'):
            _row(f"{name}:", f"{value:.4f}")

    # Only the encoder runs during evaluation; nothing is generated.
    print("\nEncoding Speed:")
    _row("Tokens per Second:", f"{tokens_per_second:.2f}")

    print(rule)


def count_parameters(model):
    """Return ``(trainable, total, frozen)`` parameter-element counts of ``model``."""
    trainable = total = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return trainable, total, total - trainable


def main(train_csv='/train (1).csv', test_csv='/test (3).csv', num_epochs=5):
    """Fine-tune GateRA for code search and report test-set metrics.

    Args:
        train_csv: CSV with ``comment``/``code`` columns used for training.
        test_csv: CSV scored after every epoch and again at the end. There is
            no separate validation split, so per-epoch scores are test scores.
        num_epochs: number of passes over the training data.
    """
    # Prefer CUDA, then Apple MPS, then fall back to CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f'Device: {device}')

    model_name = 'Salesforce/codet5p-220m'
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    train_dataset = CodeSearchDataset(train_csv, tokenizer)
    test_dataset = CodeSearchDataset(test_csv, tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)

    print("\nInitializing GateRA Model for Code Search...")
    model = GateRACodeSearchModel(model_name, rank=16, alpha=16.0, dropout=0.0, entropy_reg_weight=0.01).to(device)

    trainable_params, total_params, frozen_params = count_parameters(model)

    print(f"\n{'='*80}")
    print("MODEL PARAMETERS")
    print(f"{'='*80}")
    print(f"Total Parameters:      {total_params:,}")
    print(f"Trainable Parameters:  {trainable_params:,}")
    print(f"Frozen Parameters:     {frozen_params:,}")
    print(f"Trainable Percentage:  {100 * trainable_params / total_params:.4f}%")
    print(f"{'='*80}\n")

    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=3e-4,
        weight_decay=0.01
    )

    print(f"{'='*80}")
    print("TRAINING GateRA FOR CODE SEARCH")
    print(f"{'='*80}\n")

    best_mrr = 0.0

    for epoch in range(num_epochs):
        train_loss = train_epoch(model, train_loader, optimizer, device)

        comment_embeddings, code_embeddings, tokens_per_second = evaluate_code_search(
            model, test_loader, device
        )

        metrics = compute_code_search_metrics_fast(comment_embeddings, code_embeddings, top_k_values=[1, 5, 10])

        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print(f"  Train Loss: {train_loss:.4f}")
        # These scores come from test_loader, so label them as test scores.
        print(f"  Test MRR:   {metrics['MRR']:.4f} | Recall@1: {metrics['Recall@1']:.4f} | Recall@5: {metrics['Recall@5']:.4f} | Recall@10: {metrics['Recall@10']:.4f}")
        print(f"{'-'*80}")

        best_mrr = max(best_mrr, metrics['MRR'])

    print(f"\n{'='*80}")
    print("FINAL TEST EVALUATION")
    print(f"{'='*80}\n")

    comment_embeddings, code_embeddings, tokens_per_second = evaluate_code_search(
        model, test_loader, device
    )

    final_metrics = compute_code_search_metrics_fast(comment_embeddings, code_embeddings, top_k_values=[1, 5, 10])

    print_code_search_metrics(final_metrics, tokens_per_second, phase="Test")
    if num_epochs > 0:
        print(f"Best per-epoch Test MRR: {best_mrr:.4f}")

    print(f"\n{'='*80}")
    print("GateRA CODE SEARCH COMPLETED")
    print(f"{'='*80}")


if __name__ == "__main__":
    main()

# TS-PEFT

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import pandas as pd
from tqdm import tqdm
import numpy as np
import time
import math
import warnings
warnings.filterwarnings('ignore')


class CodeSearchDataset(Dataset):
    """Aligned (comment, code) pairs read from a CSV for contrastive code search.

    The CSV is expected to contain ``comment`` and ``code`` columns. Each item
    tokenizes both fields independently, padding/truncating them to
    ``max_length`` tokens.

    Args:
        csv_path: path passed to ``pd.read_csv``.
        tokenizer: Hugging Face tokenizer called with ``return_tensors='pt'``.
        max_length: fixed token length of every encoded field.
    """

    def __init__(self, csv_path, tokenizer, max_length=256):
        self.data = pd.read_csv(csv_path)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _to_text(value):
        """Return ``value`` as text, mapping missing CSV cells (NaN/None) to ``""``.

        ``pd.read_csv`` stores empty cells as NaN and ``str(float('nan'))`` is
        ``"nan"``; without this every missing field would become the literal
        word "nan".
        """
        return "" if pd.isna(value) else str(value)

    def _encode(self, text):
        """Tokenize ``text`` into 1-D ``(input_ids, attention_mask)`` tensors."""
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # return_tensors='pt' adds a batch dimension of size 1; drop it.
        return encoding['input_ids'].squeeze(0), encoding['attention_mask'].squeeze(0)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        comment_ids, comment_mask = self._encode(self._to_text(row['comment']))
        code_ids, code_mask = self._encode(self._to_text(row['code']))

        return {
            'comment_input_ids': comment_ids,
            'comment_attention_mask': comment_mask,
            'code_input_ids': code_ids,
            'code_attention_mask': code_mask
        }


class LoRALayer(nn.Module):
    """Low-rank update ``(dropout(x) @ A.T @ B.T) * (alpha / rank)``.

    ``A`` has shape ``(rank, in_features)`` and is Kaiming-uniform initialised.
    ``B`` has shape ``(out_features, rank)`` and starts at zero, so the update
    is initially zero.
    """

    def __init__(self, in_features, out_features, rank=4, alpha=1.0, dropout=0.0):
        super().__init__()
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank

        self.lora_A = nn.Parameter(torch.zeros(rank, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
        if dropout > 0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = nn.Identity()

        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def forward(self, x):
        down = self.dropout(x) @ self.lora_A.T
        up = down @ self.lora_B.T
        return up * self.scaling


class TSPEFTLayer(nn.Module):
    """Frozen linear layer plus a LoRA update switched on per token by a threshold.

    For every token the relative magnitude
    ``r_i = ||lora(x)_i|| / (||base(x)_i|| + eps)`` is compared with a scalar
    threshold ``tau``. The LoRA output is added only where ``r_i >= tau``.
    ``tau`` is a buffer, not a Parameter. It is updated outside autograd by
    ``update_threshold`` with an Adam-style rule (moments ``m``/``v``, bias
    correction via ``step``) and clamped to stay non-negative. Because ``tau``
    starts at 0 and ``r_i >= 0``, every token is initially gated on.

    Args:
        base_layer: module exposing ``in_features``/``out_features`` (e.g.
            ``nn.Linear``); its parameters are frozen here.
        rank, alpha, dropout: forwarded to ``LoRALayer``.
        s: scale applied both to the threshold gradient and to the threshold step.
        lambda_reg: per-active-token sparsity term in the threshold gradient.
        beta1, beta2: moment decay rates of the threshold update.
        eps: guard used in the relative magnitude and in the update denominator.
    """
    def __init__(self, base_layer, rank=4, alpha=1.0, dropout=0.0, s=4e-5, lambda_reg=1e-5, beta1=0.9, beta2=0.98, eps=1e-8):
        super().__init__()
        self.base_layer = base_layer
        self.base_layer.requires_grad_(False)
        
        self.lora = LoRALayer(
            base_layer.in_features,
            base_layer.out_features,
            rank=rank,
            alpha=alpha,
            dropout=dropout,
        )
        
        self.s = s
        self.lambda_reg = lambda_reg
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        
        # Threshold and its optimiser state live in buffers so they follow
        # .to(device) and appear in state_dict, but receive no autograd grads.
        self.register_buffer('tau', torch.tensor(0.0))
        self.register_buffer('m', torch.tensor(0.0))
        self.register_buffer('v', torch.tensor(0.0))
        self.register_buffer('step', torch.tensor(0))
        
    def compute_relative_magnitude(self, base_output, lora_output):
        """Per-token ratio of LoRA-output norm to base-output norm (last dim reduced)."""
        base_norm = torch.norm(base_output, p=2, dim=-1, keepdim=True)
        lora_norm = torch.norm(lora_output, p=2, dim=-1, keepdim=True)
        r_i = lora_norm / (base_norm + self.eps)
        return r_i.squeeze(-1)
        
    def forward(self, x):
        """Return ``base(x) + gate * lora(x)`` with a hard 0/1 gate per token.

        The gate comes from a comparison, so no gradient reaches ``tau``
        through autograd; LoRA parameters only receive gradient where the
        gate is 1. In training mode the tensors needed by
        ``compute_threshold_gradient`` are stored in ``_cache_for_backward``
        (each call overwrites the previous cache).
        """
        base_output = self.base_layer(x)
        lora_output = self.lora(x)
        
        if not self.training:
            r_i = self.compute_relative_magnitude(base_output, lora_output)
            gate = (r_i >= self.tau).float().unsqueeze(-1)
            return base_output + gate * lora_output
        
        r_i = self.compute_relative_magnitude(base_output, lora_output)
        gate = (r_i >= self.tau).float()
        
        gated_output = base_output + gate.unsqueeze(-1) * lora_output
        
        self._cache_for_backward = {
            'r_i': r_i,
            'gate': gate,
            'lora_output': lora_output,
            'base_output': base_output,
        }
        
        return gated_output
    
    def compute_threshold_gradient(self, grad_output):
        """Return the scalar threshold signal ``g_k`` from the cached forward pass.

        ``grad_output`` is expected to be dL/d(layer output), shaped like the
        cached ``lora_output``. Since d(output)/d(gate_i) = lora_output_i,
        ``mu_i`` below is then dL/d(gate_i). Passing anything other than the
        true output gradient (e.g. all ones) makes ``mu_i`` independent of
        the loss. Returns 0.0 when no cache is present.

        NOTE(review): the sign convention of ``g_k`` relative to how
        ``update_threshold`` applies it (``tau + step``) should be confirmed
        against the TS-PEFT formulation.
        """
        if not hasattr(self, '_cache_for_backward'):
            return 0.0
            
        cache = self._cache_for_backward
        r_i = cache['r_i']
        gate = cache['gate']
        lora_output = cache['lora_output']
        
        mu_i = (grad_output * lora_output).sum(dim=-1)
        
        # 1 where the sign of mu_i agrees with the gate state
        # (mu_i >= 0 and gate on, or mu_i < 0 and gate off).
        consistency_mask = ((mu_i >= 0).float() == gate).float()
        sparsity_mask = gate
        
        grad_loss = -self.s * (consistency_mask * mu_i).sum()
        # Every active token contributes lambda_reg (sparsity pressure).
        grad_sparsity = -self.s * (sparsity_mask * self.lambda_reg).sum()
        
        g_k = grad_loss + grad_sparsity
        
        return g_k.item()
    
    def update_threshold(self, grad_output, lr=1.0):
        """Advance ``tau`` by one Adam-style step using ``grad_output``; no-op in eval.

        Consumes (deletes) ``_cache_for_backward`` afterwards.

        NOTE(review): ``s`` scales ``g_k`` and scales the step again here.
        The ratio ``m_hat / sqrt(v_hat)`` is (up to ``eps``) invariant to a
        constant rescaling of ``g_k``, so the first factor largely cancels.
        Also, when no cache exists ``g_k`` is 0.0 but ``step``/``m``/``v`` are
        still advanced. Confirm both are intended.
        """
        if not self.training:
            return
            
        g_k = self.compute_threshold_gradient(grad_output)
        
        self.step += 1
        
        self.m = self.beta1 * self.m + (1 - self.beta1) * g_k
        self.v = self.beta2 * self.v + (1 - self.beta2) * (g_k ** 2)
        
        # Bias-corrected first/second moments.
        m_hat = self.m / (1 - self.beta1 ** self.step.item())
        v_hat = self.v / (1 - self.beta2 ** self.step.item())
        
        tau_update = lr * self.s * m_hat / (torch.sqrt(v_hat) + self.eps)
        # The threshold is kept non-negative (r_i is never negative).
        self.tau = torch.clamp(self.tau + tau_update, min=0.0)
        
        if hasattr(self, '_cache_for_backward'):
            delattr(self, '_cache_for_backward')


class TSPEFTCodeSearchModel(nn.Module):
    """Bi-encoder for code search with TS-PEFT layers on the encoder's attention q/v.

    Comments and code share one frozen encoder. Token states are mean-pooled
    over non-padding positions and passed through a Linear+Tanh projection.
    The loss is an in-batch contrastive cross-entropy over the cosine-
    similarity matrix, with matching pairs on the diagonal.

    Threshold training: in training mode a forward hook on every TS-PEFT
    layer records the layer's cached tensors. It also registers a tensor
    hook that saves the true gradient of the loss w.r.t. that layer's output.
    After ``loss.backward()``, ``update_thresholds`` passes these gradients to
    ``update_threshold``. The gradients cover both the comment and the code
    pass, flattened to tokens.

    Args:
        base_model_name: id passed to ``AutoModelForSeq2SeqLM.from_pretrained``.
        rank, alpha, dropout, s, lambda_reg: forwarded to every ``TSPEFTLayer``.
        temperature: divisor applied to cosine similarities before the loss.
    """

    def __init__(self, base_model_name, rank=32, alpha=0.5, dropout=0.05, s=4e-5, lambda_reg=4.5e-5, temperature=0.07):
        super().__init__()

        self.base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)

        for param in self.base_model.parameters():
            param.requires_grad = False

        config = self.base_model.config
        self.d_model = config.d_model
        self.d_kv = config.d_kv
        self.num_heads = config.num_heads
        self.temperature = temperature

        # layer name -> list of {'cache': ..., 'grad': ...} for the current step.
        self._threshold_records = {}

        encoder = self.base_model.get_encoder()
        self.ts_peft_layers = nn.ModuleDict()

        for i, block in enumerate(encoder.block):
            q_proj = block.layer[0].SelfAttention.q
            v_proj = block.layer[0].SelfAttention.v

            self.ts_peft_layers[f'encoder_q_{i}'] = TSPEFTLayer(
                q_proj, rank=rank, alpha=alpha, dropout=dropout, s=s, lambda_reg=lambda_reg
            )
            self.ts_peft_layers[f'encoder_v_{i}'] = TSPEFTLayer(
                v_proj, rank=rank, alpha=alpha, dropout=dropout, s=s, lambda_reg=lambda_reg
            )

            block.layer[0].SelfAttention.q = self.ts_peft_layers[f'encoder_q_{i}']
            block.layer[0].SelfAttention.v = self.ts_peft_layers[f'encoder_v_{i}']

        for name, layer in self.ts_peft_layers.items():
            layer.register_forward_hook(self._make_capture_hook(name))

        self.projection = nn.Sequential(
            nn.Linear(self.d_model, self.d_model),
            nn.Tanh()
        )

    def _make_capture_hook(self, name):
        """Build a forward hook that records layer ``name``'s cache and output gradient."""
        def capture(module, inputs, output):
            cache = getattr(module, '_cache_for_backward', None)
            if not module.training or cache is None or not output.requires_grad:
                return None
            record = {'cache': cache, 'grad': None}

            def save_grad(grad):
                record['grad'] = grad.detach()

            output.register_hook(save_grad)
            self._threshold_records.setdefault(name, []).append(record)
            return None

        return capture

    @staticmethod
    def _flatten_tokens(tensors, has_features):
        """Detach ``tensors`` and concatenate them as one flat batch of tokens."""
        if has_features:
            flat = [t.detach().reshape(-1, t.shape[-1]) for t in tensors]
        else:
            flat = [t.detach().reshape(-1) for t in tensors]
        return torch.cat(flat, dim=0)

    def encode_with_ts_peft(self, input_ids, attention_mask):
        """Encode token ids and return projected, mask-averaged embeddings.

        Returns a ``(batch, d_model)`` tensor; it is not L2-normalised.
        """
        encoder = self.base_model.get_encoder()

        outputs = encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )

        hidden_states = outputs.last_hidden_state

        # Mean over non-padding tokens only; the clamp avoids 0/0 for all-pad rows.
        mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        sum_hidden = torch.sum(hidden_states * mask_expanded, dim=1)
        sum_mask = torch.clamp(mask_expanded.sum(dim=1), min=1e-9)
        pooled = sum_hidden / sum_mask

        return self.projection(pooled)

    def update_thresholds(self, lr=1.0):
        """Apply one threshold update per TS-PEFT layer from the captured gradients.

        Call after ``loss.backward()``. Layers that are in eval mode, or that
        captured no gradient since the last forward, are skipped. Their stale
        cache is dropped.
        """
        records, self._threshold_records = self._threshold_records, {}
        per_token_keys = {'r_i': False, 'gate': False, 'lora_output': True, 'base_output': True}

        for name, layer in self.ts_peft_layers.items():
            captured = [rec for rec in records.get(name, ()) if rec['grad'] is not None]
            if not layer.training or not captured:
                if hasattr(layer, '_cache_for_backward'):
                    del layer._cache_for_backward
                continue

            # Merge the comment and code passes into one token-level cache so
            # the layer takes a single update driven by the real loss gradient.
            layer._cache_for_backward = {
                key: self._flatten_tokens([rec['cache'][key] for rec in captured], has_features)
                for key, has_features in per_token_keys.items()
            }
            grad_output = self._flatten_tokens([rec['grad'] for rec in captured], True)
            layer.update_threshold(grad_output, lr)

    def forward(self, comment_input_ids, comment_attention_mask, code_input_ids, code_attention_mask):
        """Return the contrastive loss for a batch of aligned (comment, code) pairs.

        Row ``i`` of the comment tensors is the positive for row ``i`` of the
        code tensors; every other row of the batch acts as a negative.
        """
        # Start a fresh set of threshold records for this step.
        self._threshold_records = {}

        comment_emb = self.encode_with_ts_peft(comment_input_ids, comment_attention_mask)
        code_emb = self.encode_with_ts_peft(code_input_ids, code_attention_mask)

        comment_emb = F.normalize(comment_emb, p=2, dim=1)
        code_emb = F.normalize(code_emb, p=2, dim=1)

        similarities = torch.matmul(comment_emb, code_emb.T) / self.temperature
        labels = torch.arange(comment_emb.shape[0], device=comment_emb.device)

        return F.cross_entropy(similarities, labels)


def train_epoch(model, dataloader, optimizer, device):
    """Run one TS-PEFT training epoch and return the mean per-batch loss.

    After each backward pass the per-layer thresholds are updated through
    ``model.update_thresholds`` before gradients are clipped (global norm 1.0)
    and the optimizer steps.
    """
    model.train()
    input_keys = (
        'comment_input_ids',
        'comment_attention_mask',
        'code_input_ids',
        'code_attention_mask',
    )
    running_loss = 0
    steps = 0

    for batch in tqdm(dataloader, desc="Training"):
        inputs = [batch[key].to(device) for key in input_keys]

        optimizer.zero_grad()
        loss = model(*inputs)
        loss.backward()

        model.update_thresholds(lr=1.0)

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        running_loss += loss.item()
        steps += 1

    return running_loss / steps


def evaluate_code_search(model, dataloader, device):
    """Embed every (comment, code) pair in ``dataloader`` and measure throughput.

    Returns:
        ``(comment_embeddings, code_embeddings, tokens_per_second)``. The
        embeddings are the model's un-normalised projected vectors stacked in
        dataloader order. Throughput counts every input-id position fed to the
        encoder (padding included) for both fields. The timed span includes
        the DataLoader's work for each batch.
    """
    model.eval()
    device_type = torch.device(device).type

    def _synchronize():
        # Accelerator kernels run asynchronously; wait for them so the
        # wall-clock interval covers the actual encoding work.
        if device_type == 'cuda':
            torch.cuda.synchronize()
        elif device_type == 'mps':
            torch.mps.synchronize()

    all_code_embeddings = []
    all_comment_embeddings = []
    total_tokens = 0

    _synchronize()
    start_time = time.perf_counter()

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Encoding"):
            code_input_ids = batch['code_input_ids'].to(device)
            code_attention_mask = batch['code_attention_mask'].to(device)
            comment_input_ids = batch['comment_input_ids'].to(device)
            comment_attention_mask = batch['comment_attention_mask'].to(device)

            code_emb = model.encode_with_ts_peft(code_input_ids, code_attention_mask)
            comment_emb = model.encode_with_ts_peft(comment_input_ids, comment_attention_mask)

            all_code_embeddings.append(code_emb)
            all_comment_embeddings.append(comment_emb)
            # Count what was actually encoded instead of assuming 256 tokens.
            total_tokens += code_input_ids.numel() + comment_input_ids.numel()

        all_code_embeddings = torch.cat(all_code_embeddings, dim=0)
        all_comment_embeddings = torch.cat(all_comment_embeddings, dim=0)

    _synchronize()
    total_time = time.perf_counter() - start_time
    tokens_per_second = total_tokens / total_time if total_time > 0 else 0

    return all_comment_embeddings, all_code_embeddings, tokens_per_second


def compute_code_search_metrics_fast(comment_embeddings, code_embeddings, top_k_values=(1, 5, 10)):
    """Rank all codes for every comment query and return MRR and Recall@k.

    Query ``i`` (row ``i`` of ``comment_embeddings``) is paired with code
    ``i``. Similarity is cosine similarity. The rank of the paired code is
    1 + the number of codes scoring strictly higher. Queries without a paired
    code (``i >= number of codes``) get rank ``num_queries + 1``.

    Args:
        comment_embeddings: ``(num_queries, dim)`` query embeddings.
        code_embeddings: ``(num_codes, dim)`` candidate embeddings.
        top_k_values: cut-offs for which ``'Recall@k'`` is reported.

    Returns:
        dict with ``'MRR'`` followed by one ``'Recall@k'`` per value in
        ``top_k_values``. All values are floats, 0.0 for empty input.
    """
    num_queries = comment_embeddings.shape[0]

    metrics = {'MRR': 0.0}
    metrics.update({f'Recall@{k}': 0.0 for k in top_k_values})
    if num_queries == 0:
        return metrics

    comment_embeddings = F.normalize(comment_embeddings, p=2, dim=1)
    code_embeddings = F.normalize(code_embeddings, p=2, dim=1)

    similarities = torch.matmul(comment_embeddings, code_embeddings.T)

    num_paired = min(num_queries, similarities.shape[1])
    ranks = torch.full((num_queries,), num_queries + 1, dtype=torch.long, device=similarities.device)
    if num_paired > 0:
        # Vectorised rank: count codes that beat the paired code's score.
        paired = torch.arange(num_paired, device=similarities.device)
        true_scores = similarities[paired, paired]
        ranks[:num_paired] = (similarities[:num_paired] > true_scores.unsqueeze(1)).sum(dim=1) + 1

    # Average in float64 on the CPU (MPS has no float64).
    ranks = ranks.cpu().double()
    metrics['MRR'] = (1.0 / ranks).mean().item()
    for k in top_k_values:
        metrics[f'Recall@{k}'] = (ranks <= k).double().mean().item()

    return metrics


def print_code_search_metrics(metrics, tokens_per_second, phase="Test"):
    """Print ranking metrics and encoder throughput in an aligned table.

    Args:
        metrics: dict with an ``'MRR'`` entry plus any ``'Recall@k'`` entries,
            printed in the dict's order.
        tokens_per_second: encoder throughput to report.
        phase: label shown (upper-cased) in the header.
    """
    rule = '=' * 80

    def _row(label, text):
        # Left-align labels in a 29-character column so the values line up.
        print(f"  {label:<29}{text}")

    print(f"\n{rule}")
    print(f"{phase.upper()} CODE SEARCH METRICS")
    print(f"{rule}\n")

    print("Ranking Metrics:")
    _row("MRR (Mean Reciprocal Rank):", f"{metrics['MRR']:.4f}")
    for name, value in metrics.items():
        if name.startswith('Recall@'):
            _row(f"{name}:", f"{value:.4f}")

    # Only the encoder runs during evaluation; nothing is generated.
    print("\nEncoding Speed:")
    _row("Tokens per Second:", f"{tokens_per_second:.2f}")

    print(rule)


def count_parameters(model):
    """Return ``(trainable, total, frozen)`` parameter-element counts of ``model``."""
    trainable = total = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return trainable, total, total - trainable


def main(train_csv='/train (1).csv', test_csv='/test (3).csv', num_epochs=5):
    """Fine-tune TS-PEFT for code search and report test-set metrics.

    Args:
        train_csv: CSV with ``comment``/``code`` columns used for training.
        test_csv: CSV scored after every epoch and again at the end. There is
            no separate validation split, so per-epoch scores are test scores.
        num_epochs: number of passes over the training data.
    """
    # Prefer CUDA, then Apple MPS, then fall back to CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f'Device: {device}')

    model_name = 'Salesforce/codet5p-220m'
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    train_dataset = CodeSearchDataset(train_csv, tokenizer)
    test_dataset = CodeSearchDataset(test_csv, tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)

    print("\nInitializing TS-PEFT Model for Code Search...")
    model = TSPEFTCodeSearchModel(model_name, rank=32, alpha=0.5, dropout=0.05, s=4e-5, lambda_reg=4.5e-5).to(device)

    trainable_params, total_params, frozen_params = count_parameters(model)

    print(f"\n{'='*80}")
    print("MODEL PARAMETERS")
    print(f"{'='*80}")
    print(f"Total Parameters:      {total_params:,}")
    print(f"Trainable Parameters:  {trainable_params:,}")
    print(f"Frozen Parameters:     {frozen_params:,}")
    print(f"Trainable Percentage:  {100 * trainable_params / total_params:.4f}%")
    print(f"{'='*80}\n")

    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=3e-4,
        weight_decay=0.01
    )

    print(f"{'='*80}")
    print("TRAINING TS-PEFT FOR CODE SEARCH")
    print(f"{'='*80}\n")

    best_mrr = 0.0

    for epoch in range(num_epochs):
        train_loss = train_epoch(model, train_loader, optimizer, device)

        comment_embeddings, code_embeddings, tokens_per_second = evaluate_code_search(
            model, test_loader, device
        )

        metrics = compute_code_search_metrics_fast(comment_embeddings, code_embeddings, top_k_values=[1, 5, 10])

        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print(f"  Train Loss: {train_loss:.4f}")
        # These scores come from test_loader, so label them as test scores.
        print(f"  Test MRR:   {metrics['MRR']:.4f} | Recall@1: {metrics['Recall@1']:.4f} | Recall@5: {metrics['Recall@5']:.4f} | Recall@10: {metrics['Recall@10']:.4f}")
        print(f"{'-'*80}")

        best_mrr = max(best_mrr, metrics['MRR'])

    print(f"\n{'='*80}")
    print("FINAL TEST EVALUATION")
    print(f"{'='*80}\n")

    comment_embeddings, code_embeddings, tokens_per_second = evaluate_code_search(
        model, test_loader, device
    )

    final_metrics = compute_code_search_metrics_fast(comment_embeddings, code_embeddings, top_k_values=[1, 5, 10])

    print_code_search_metrics(final_metrics, tokens_per_second, phase="Test")
    if num_epochs > 0:
        print(f"Best per-epoch Test MRR: {best_mrr:.4f}")

    print(f"\n{'='*80}")
    print("TS-PEFT CODE SEARCH COMPLETED")
    print(f"{'='*80}")


if __name__ == "__main__":
    main()

# Adapter

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import pandas as pd
from tqdm import tqdm
import numpy as np
import time
import warnings
warnings.filterwarnings('ignore')


class CodeSearchDataset(Dataset):
    """Aligned (comment, code) pairs read from a CSV for contrastive code search.

    The CSV is expected to contain ``comment`` and ``code`` columns. Each item
    tokenizes both fields independently, padding/truncating them to
    ``max_length`` tokens.

    Args:
        csv_path: path passed to ``pd.read_csv``.
        tokenizer: Hugging Face tokenizer called with ``return_tensors='pt'``.
        max_length: fixed token length of every encoded field.
    """

    def __init__(self, csv_path, tokenizer, max_length=256):
        self.data = pd.read_csv(csv_path)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _to_text(value):
        """Return ``value`` as text, mapping missing CSV cells (NaN/None) to ``""``.

        ``pd.read_csv`` stores empty cells as NaN and ``str(float('nan'))`` is
        ``"nan"``; without this every missing field would become the literal
        word "nan".
        """
        return "" if pd.isna(value) else str(value)

    def _encode(self, text):
        """Tokenize ``text`` into 1-D ``(input_ids, attention_mask)`` tensors."""
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # return_tensors='pt' adds a batch dimension of size 1; drop it.
        return encoding['input_ids'].squeeze(0), encoding['attention_mask'].squeeze(0)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        comment_ids, comment_mask = self._encode(self._to_text(row['comment']))
        code_ids, code_mask = self._encode(self._to_text(row['code']))

        return {
            'comment_input_ids': comment_ids,
            'comment_attention_mask': comment_mask,
            'code_input_ids': code_ids,
            'code_attention_mask': code_mask
        }


class Adapter(nn.Module):
    """Bottleneck adapter: Linear down-projection, ReLU, Linear up-projection,
    added back onto the input as a residual.

    Args:
        input_size: feature size of the input (and output).
        bottleneck_size: width of the hidden bottleneck.
    """

    def __init__(self, input_size, bottleneck_size=64):
        super().__init__()
        self.down_project = nn.Linear(input_size, bottleneck_size)
        self.up_project = nn.Linear(bottleneck_size, input_size)
        self.activation = nn.ReLU()

    def forward(self, x):
        bottleneck = self.activation(self.down_project(x))
        delta = self.up_project(bottleneck)
        return x + delta


class AdapterCodeSearchModel(nn.Module):
    """Bi-encoder for code search: frozen seq2seq encoder + bottleneck adapters.

    One ``Adapter`` is inserted after each encoder block through a forward
    hook, so every adapter's output is what the next block consumes. (The
    previous version edited ``outputs.hidden_states`` after the encoder had
    already run, which left every adapter except the last disconnected from
    the output and untrained.) The attention-masked mean of the encoder's
    last hidden state is passed through a trainable ``Linear + Tanh`` head.

    Args:
        base_model_name: model id for ``AutoModelForSeq2SeqLM.from_pretrained``.
        bottleneck_size: hidden width of each adapter.
        temperature: temperature of the in-batch contrastive loss.
    """

    def __init__(self, base_model_name, bottleneck_size=64, temperature=0.07):
        super().__init__()

        self.base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)
        for param in self.base_model.parameters():
            param.requires_grad = False

        config = self.base_model.config
        self.d_model = config.d_model
        self.temperature = temperature

        self.adapters = nn.ModuleList(
            Adapter(self.d_model, bottleneck_size) for _ in range(config.num_layers)
        )

        # NOTE(review): assumes a T5-style encoder exposing its layers as
        # ``encoder.block`` (the LoRA variant in this notebook relies on it too).
        encoder = self.base_model.get_encoder()
        for block, adapter in zip(encoder.block, self.adapters):
            block.register_forward_hook(self._make_adapter_hook(adapter))

        self.projection = nn.Sequential(
            nn.Linear(self.d_model, self.d_model),
            nn.Tanh()
        )

    @staticmethod
    def _make_adapter_hook(adapter):
        """Build a forward hook that passes a block's hidden-state output through ``adapter``."""
        def hook(module, inputs, output):
            # Blocks may return a tuple whose first item is the hidden state;
            # any remaining items are forwarded untouched.
            if isinstance(output, tuple):
                return (adapter(output[0]),) + tuple(output[1:])
            return adapter(output)
        return hook

    def encode(self, input_ids, attention_mask):
        """Encode token ids into one projected embedding per sequence.

        Adapters run inside the encoder via hooks; padding positions
        (``attention_mask == 0``) are excluded from the mean pooling.
        """
        encoder = self.base_model.get_encoder()
        outputs = encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        hidden_states = outputs.last_hidden_state

        mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
        summed = (hidden_states * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)  # guard all-padding rows
        pooled = summed / counts

        return self.projection(pooled)

    def forward(self, comment_input_ids, comment_attention_mask, code_input_ids, code_attention_mask):
        """Return the in-batch contrastive (InfoNCE) loss for comment -> code retrieval.

        Row ``i`` of the comment batch is the positive for row ``i`` of the
        code batch; all other rows in the batch act as negatives.
        """
        comment_emb = F.normalize(self.encode(comment_input_ids, comment_attention_mask), p=2, dim=1)
        code_emb = F.normalize(self.encode(code_input_ids, code_attention_mask), p=2, dim=1)

        logits = torch.matmul(comment_emb, code_emb.T) / self.temperature
        labels = torch.arange(comment_emb.shape[0], device=comment_emb.device)

        return F.cross_entropy(logits, labels)


def train_epoch(model, dataloader, optimizer, device, max_grad_norm=1.0):
    """Run one training epoch and return the mean per-batch loss.

    ``model`` is called as ``model(comment_input_ids, comment_attention_mask,
    code_input_ids, code_attention_mask)`` and must return a scalar loss.
    Gradients of trainable parameters are clipped to ``max_grad_norm``.

    Returns:
        Average of ``loss.item()`` over batches, or ``nan`` when the
        dataloader yields no batches (previously a ZeroDivisionError).
    """
    model.train()
    # Frozen parameters never receive gradients, so clipping only the
    # trainable ones gives the same result without scanning the whole model.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    total_loss = 0.0
    num_batches = 0

    for batch in tqdm(dataloader, desc="Training"):
        comment_input_ids = batch['comment_input_ids'].to(device)
        comment_attention_mask = batch['comment_attention_mask'].to(device)
        code_input_ids = batch['code_input_ids'].to(device)
        code_attention_mask = batch['code_attention_mask'].to(device)

        optimizer.zero_grad()
        loss = model(comment_input_ids, comment_attention_mask, code_input_ids, code_attention_mask)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        return float('nan')
    return total_loss / num_batches


def evaluate_code_search(model, dataloader, device):
    """Encode every (comment, code) pair in ``dataloader`` without gradients.

    Calls ``model.encode`` on each side of every batch and concatenates the
    results in dataloader order, so row ``i`` of both outputs is the same pair.

    Returns:
        (comment_embeddings, code_embeddings, tokens_per_second), where the
        throughput counts every token position fed to the encoder (padding
        included) instead of assuming a fixed length of 256.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    model.eval()

    code_chunks = []
    comment_chunks = []
    total_tokens = 0

    start_time = time.perf_counter()

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Encoding"):
            code_input_ids = batch['code_input_ids'].to(device)
            code_attention_mask = batch['code_attention_mask'].to(device)
            comment_input_ids = batch['comment_input_ids'].to(device)
            comment_attention_mask = batch['comment_attention_mask'].to(device)

            code_chunks.append(model.encode(code_input_ids, code_attention_mask))
            comment_chunks.append(model.encode(comment_input_ids, comment_attention_mask))

            total_tokens += code_input_ids.numel() + comment_input_ids.numel()

    if not code_chunks:
        raise ValueError("evaluate_code_search: dataloader produced no batches")

    all_code_embeddings = torch.cat(code_chunks, dim=0)
    all_comment_embeddings = torch.cat(comment_chunks, dim=0)

    # NOTE(review): on asynchronous backends (CUDA/MPS) queued kernels may not
    # have finished here, so this is an approximate throughput.
    total_time = time.perf_counter() - start_time
    tokens_per_second = total_tokens / total_time if total_time > 0 else 0

    return all_comment_embeddings, all_code_embeddings, tokens_per_second


def compute_code_search_metrics_fast(comment_embeddings, code_embeddings, top_k_values=(1, 5, 10)):
    """Score comment->code retrieval with MRR and Recall@k.

    Query ``i`` (row ``i`` of ``comment_embeddings``) is paired with code row
    ``i``; every code row is a candidate. Similarity is cosine (both sides are
    L2-normalized here). The rank of the true snippet is
    ``1 + number of candidates scoring strictly higher``, computed in a single
    vectorized pass (no per-query argsort / device sync), which also makes
    tie handling deterministic.

    Args:
        comment_embeddings: (num_queries, dim) tensor.
        code_embeddings: (num_candidates, dim) tensor, num_candidates >= num_queries.
        top_k_values: cut-offs for which ``Recall@k`` is reported.

    Returns:
        dict with ``'MRR'`` and ``'Recall@{k}'`` for each k, as Python floats.

    Raises:
        ValueError: if there are no queries or fewer candidates than queries.
    """
    num_queries = comment_embeddings.shape[0]
    if num_queries == 0:
        raise ValueError("compute_code_search_metrics_fast: no queries to score")
    if code_embeddings.shape[0] < num_queries:
        raise ValueError(
            "compute_code_search_metrics_fast: need at least one code embedding per query"
        )

    comment_embeddings = F.normalize(comment_embeddings, p=2, dim=1)
    code_embeddings = F.normalize(code_embeddings, p=2, dim=1)

    similarities = torch.matmul(comment_embeddings, code_embeddings.T)
    true_scores = similarities.diagonal().unsqueeze(1)

    # Move to CPU before casting: float64 is unavailable on some backends (MPS).
    better_counts = (similarities > true_scores).sum(dim=1).cpu().to(torch.float64)
    ranks = better_counts + 1.0

    metrics = {'MRR': float((1.0 / ranks).mean())}
    for k in top_k_values:
        metrics[f'Recall@{k}'] = float((ranks <= k).to(torch.float64).mean())

    return metrics


def print_code_search_metrics(metrics, tokens_per_second, phase="Test"):
    """Print a boxed report of ranking metrics and encoding throughput.

    ``metrics`` must contain 'MRR', 'Recall@1', 'Recall@5' and 'Recall@10'.
    """
    rule = '=' * 80
    report = [
        "",
        rule,
        f"{phase.upper()} CODE SEARCH METRICS",
        rule,
        "",
        "Ranking Metrics:",
        f"  MRR (Mean Reciprocal Rank):  {metrics['MRR']:.4f}",
        f"  Recall@1:                    {metrics['Recall@1']:.4f}",
        f"  Recall@5:                    {metrics['Recall@5']:.4f}",
        f"  Recall@10:                   {metrics['Recall@10']:.4f}",
        "",
        "Generation Speed:",
        f"  Tokens per Second:           {tokens_per_second:.2f}",
        rule,
    ]
    print("\n".join(report))


def count_parameters(model):
    """Return (trainable, total, frozen) element counts over ``model.parameters()``."""
    n_total = 0
    n_trainable = 0
    for param in model.parameters():
        size = param.numel()
        n_total += size
        if param.requires_grad:
            n_trainable += size
    return n_trainable, n_total, n_total - n_trainable


def main():
    """Train and evaluate the bottleneck-adapter code-search model on CodeT5+ 220M.

    Trains with AdamW on the trainable parameters only, prints retrieval
    metrics after every epoch, then runs a final test evaluation.
    """
    # Prefer CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f'Device: {device}')

    model_name = 'Salesforce/codet5p-220m'
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    train_dataset = CodeSearchDataset('/train (1).csv', tokenizer)
    # Single leading slash, matching the train path (was '//test (3).csv').
    test_dataset = CodeSearchDataset('/test (3).csv', tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)

    print("\nInitializing Adapter Model for Code Search...")
    model = AdapterCodeSearchModel(model_name, bottleneck_size=64).to(device)

    trainable_params, total_params, frozen_params = count_parameters(model)

    print(f"\n{'='*80}")
    print("MODEL PARAMETERS")
    print(f"{'='*80}")
    print(f"Total Parameters:      {total_params:,}")
    print(f"Trainable Parameters:  {trainable_params:,}")
    print(f"Frozen Parameters:     {frozen_params:,}")
    print(f"Trainable Percentage:  {100 * trainable_params / total_params:.4f}%")
    print(f"{'='*80}\n")

    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=1e-4,
        weight_decay=0.01
    )

    num_epochs = 5

    print(f"{'='*80}")
    print("TRAINING ADAPTER FOR CODE SEARCH")
    print(f"{'='*80}\n")

    best_mrr = 0
    best_epoch = 0

    for epoch in range(num_epochs):
        train_loss = train_epoch(model, train_loader, optimizer, device)

        # NOTE(review): per-epoch "Val" numbers are computed on the test split;
        # there is no separate validation set.
        comment_embeddings, code_embeddings, tokens_per_second = evaluate_code_search(
            model, test_loader, device
        )

        metrics = compute_code_search_metrics_fast(comment_embeddings, code_embeddings, top_k_values=[1, 5, 10])

        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Val MRR:    {metrics['MRR']:.4f} | Recall@1: {metrics['Recall@1']:.4f} | Recall@5: {metrics['Recall@5']:.4f} | Recall@10: {metrics['Recall@10']:.4f}")
        print(f"{'-'*80}")

        if metrics['MRR'] > best_mrr:
            best_mrr = metrics['MRR']
            best_epoch = epoch + 1

    print(f"\nBest per-epoch MRR: {best_mrr:.4f} (epoch {best_epoch})")

    print(f"\n{'='*80}")
    print("FINAL TEST EVALUATION")
    print(f"{'='*80}\n")

    comment_embeddings, code_embeddings, tokens_per_second = evaluate_code_search(
        model, test_loader, device
    )

    final_metrics = compute_code_search_metrics_fast(comment_embeddings, code_embeddings, top_k_values=[1, 5, 10])

    print_code_search_metrics(final_metrics, tokens_per_second, phase="Test")

    print(f"\n{'='*80}")
    print("ADAPTER CODE SEARCH COMPLETED")
    print(f"{'='*80}")


# Run the Adapter experiment's main() when the file/cell is executed directly.
if __name__ == "__main__":
    main()

# LoRA

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import pandas as pd
from tqdm import tqdm
import numpy as np
import time
import warnings
warnings.filterwarnings('ignore')


class CodeSearchDataset(Dataset):
    """Paired (comment, code) examples loaded from a CSV file.

    The CSV must provide ``comment`` and ``code`` columns. Each item tokenizes
    both fields independently, padded/truncated to ``max_length``, and returns
    1-D ``input_ids`` / ``attention_mask`` tensors for each side.
    """

    def __init__(self, csv_path, tokenizer, max_length=256):
        self.data = pd.read_csv(csv_path)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return self.data.shape[0]

    def _encode(self, text):
        """Tokenize one string; return (input_ids, attention_mask) without the batch dim."""
        encoded = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        return encoded['input_ids'].squeeze(0), encoded['attention_mask'].squeeze(0)

    def __getitem__(self, idx):
        record = self.data.iloc[idx]
        # Cells are coerced to str so non-string CSV values still tokenize.
        comment_ids, comment_mask = self._encode(str(record['comment']))
        code_ids, code_mask = self._encode(str(record['code']))
        return {
            'comment_input_ids': comment_ids,
            'comment_attention_mask': comment_mask,
            'code_input_ids': code_ids,
            'code_attention_mask': code_mask
        }


class LoRALayer(nn.Module):
    """Low-rank additive update ``original_output + scaling * dropout(x) @ A @ B``.

    ``A`` is (in_features, r) with small random init and ``B`` is
    (r, out_features) initialised to zero, so the update starts at zero.
    ``scaling = lora_alpha / r``.
    """

    def __init__(self, in_features, out_features, r=8, lora_alpha=16, lora_dropout=0.05):
        super().__init__()
        self.r = r
        self.lora_alpha = lora_alpha
        self.scaling = lora_alpha / r

        self.lora_dropout = nn.Dropout(p=lora_dropout)
        self.lora_A = nn.Parameter(torch.randn(in_features, r) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(r, out_features))

    def forward(self, x, original_output):
        low_rank = torch.matmul(self.lora_dropout(x), self.lora_A)
        delta = torch.matmul(low_rank, self.lora_B)
        return original_output + delta * self.scaling


class LoRACodeSearchModel(nn.Module):
    """Bi-encoder for code search: frozen seq2seq encoder with LoRA on attention q/v.

    For every encoder block, a ``LoRALayer`` is attached to the self-attention
    ``q`` and ``v`` projections through forward hooks. The base model's own
    encoder forward is then used unchanged. The previous hand-written
    re-implementation of that forward diverged from the pretrained network:
    it skipped the pre-attention layer norm, added an extra 1/sqrt(d_kv) score
    scaling, dropped the relative position bias, and added the feed-forward
    residual twice. The attention-masked mean of the last hidden state is
    passed through a trainable ``Linear + Tanh`` head.

    Args:
        base_model_name: model id for ``AutoModelForSeq2SeqLM.from_pretrained``.
        r, lora_alpha, lora_dropout: ``LoRALayer`` hyper-parameters.
        temperature: temperature of the in-batch contrastive loss.
    """

    def __init__(self, base_model_name, r=8, lora_alpha=16, lora_dropout=0.05, temperature=0.07):
        super().__init__()

        self.base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)
        for param in self.base_model.parameters():
            param.requires_grad = False

        config = self.base_model.config
        self.d_model = config.d_model
        self.d_kv = config.d_kv
        self.num_heads = config.num_heads
        self.temperature = temperature

        inner_dim = self.d_kv * self.num_heads
        encoder = self.base_model.get_encoder()
        self.lora_layers = nn.ModuleDict()

        # NOTE(review): assumes a T5-style encoder (``encoder.block[i].layer[0].SelfAttention``
        # with ``q``/``v`` projection modules), as the original code did.
        for i, block in enumerate(encoder.block):
            attention = block.layer[0].SelfAttention

            q_lora = LoRALayer(self.d_model, inner_dim, r, lora_alpha, lora_dropout)
            v_lora = LoRALayer(self.d_model, inner_dim, r, lora_alpha, lora_dropout)
            self.lora_layers[f'encoder_q_{i}'] = q_lora
            self.lora_layers[f'encoder_v_{i}'] = v_lora

            attention.q.register_forward_hook(self._make_lora_hook(q_lora))
            attention.v.register_forward_hook(self._make_lora_hook(v_lora))

        self.projection = nn.Sequential(
            nn.Linear(self.d_model, self.d_model),
            nn.Tanh()
        )

    @staticmethod
    def _make_lora_hook(lora):
        """Build a forward hook that adds ``lora``'s low-rank update to a projection's output."""
        def hook(module, inputs, output):
            # Assumes the projection is called with its input as the first
            # positional argument (forward hooks only see positional inputs).
            return lora(inputs[0], output)
        return hook

    def encode_with_lora(self, input_ids, attention_mask):
        """Encode token ids into one projected embedding per sequence.

        LoRA updates are applied inside the encoder via hooks; padding
        positions (``attention_mask == 0``) are excluded from mean pooling.
        """
        encoder = self.base_model.get_encoder()
        outputs = encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        hidden_states = outputs.last_hidden_state

        mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
        summed = (hidden_states * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)  # guard all-padding rows
        pooled = summed / counts

        return self.projection(pooled)

    def forward(self, comment_input_ids, comment_attention_mask, code_input_ids, code_attention_mask):
        """Return the in-batch contrastive (InfoNCE) loss for comment -> code retrieval.

        Row ``i`` of the comment batch is the positive for row ``i`` of the
        code batch; all other rows in the batch act as negatives.
        """
        comment_emb = F.normalize(self.encode_with_lora(comment_input_ids, comment_attention_mask), p=2, dim=1)
        code_emb = F.normalize(self.encode_with_lora(code_input_ids, code_attention_mask), p=2, dim=1)

        logits = torch.matmul(comment_emb, code_emb.T) / self.temperature
        labels = torch.arange(comment_emb.shape[0], device=comment_emb.device)

        return F.cross_entropy(logits, labels)


def train_epoch(model, dataloader, optimizer, device, max_grad_norm=1.0):
    """Run one training epoch and return the mean per-batch loss.

    ``model`` is called as ``model(comment_input_ids, comment_attention_mask,
    code_input_ids, code_attention_mask)`` and must return a scalar loss.
    Gradients of trainable parameters are clipped to ``max_grad_norm``.

    Returns:
        Average of ``loss.item()`` over batches, or ``nan`` when the
        dataloader yields no batches (previously a ZeroDivisionError).
    """
    model.train()
    # Frozen parameters never receive gradients, so clipping only the
    # trainable ones gives the same result without scanning the whole model.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    total_loss = 0.0
    num_batches = 0

    for batch in tqdm(dataloader, desc="Training"):
        comment_input_ids = batch['comment_input_ids'].to(device)
        comment_attention_mask = batch['comment_attention_mask'].to(device)
        code_input_ids = batch['code_input_ids'].to(device)
        code_attention_mask = batch['code_attention_mask'].to(device)

        optimizer.zero_grad()
        loss = model(comment_input_ids, comment_attention_mask, code_input_ids, code_attention_mask)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        return float('nan')
    return total_loss / num_batches


def evaluate_code_search(model, dataloader, device):
    """Encode every (comment, code) pair in ``dataloader`` without gradients.

    Calls ``model.encode_with_lora`` on each side of every batch and
    concatenates the results in dataloader order, so row ``i`` of both
    outputs is the same pair.

    Returns:
        (comment_embeddings, code_embeddings, tokens_per_second), where the
        throughput counts every token position fed to the encoder (padding
        included) instead of assuming a fixed length of 256.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    model.eval()

    code_chunks = []
    comment_chunks = []
    total_tokens = 0

    start_time = time.perf_counter()

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Encoding"):
            code_input_ids = batch['code_input_ids'].to(device)
            code_attention_mask = batch['code_attention_mask'].to(device)
            comment_input_ids = batch['comment_input_ids'].to(device)
            comment_attention_mask = batch['comment_attention_mask'].to(device)

            code_chunks.append(model.encode_with_lora(code_input_ids, code_attention_mask))
            comment_chunks.append(model.encode_with_lora(comment_input_ids, comment_attention_mask))

            total_tokens += code_input_ids.numel() + comment_input_ids.numel()

    if not code_chunks:
        raise ValueError("evaluate_code_search: dataloader produced no batches")

    all_code_embeddings = torch.cat(code_chunks, dim=0)
    all_comment_embeddings = torch.cat(comment_chunks, dim=0)

    # NOTE(review): on asynchronous backends (CUDA/MPS) queued kernels may not
    # have finished here, so this is an approximate throughput.
    total_time = time.perf_counter() - start_time
    tokens_per_second = total_tokens / total_time if total_time > 0 else 0

    return all_comment_embeddings, all_code_embeddings, tokens_per_second


def compute_code_search_metrics_fast(comment_embeddings, code_embeddings, top_k_values=(1, 5, 10)):
    """Score comment->code retrieval with MRR and Recall@k.

    Query ``i`` (row ``i`` of ``comment_embeddings``) is paired with code row
    ``i``; every code row is a candidate. Similarity is cosine (both sides are
    L2-normalized here). The rank of the true snippet is
    ``1 + number of candidates scoring strictly higher``, computed in a single
    vectorized pass (no per-query argsort / device sync), which also makes
    tie handling deterministic.

    Args:
        comment_embeddings: (num_queries, dim) tensor.
        code_embeddings: (num_candidates, dim) tensor, num_candidates >= num_queries.
        top_k_values: cut-offs for which ``Recall@k`` is reported.

    Returns:
        dict with ``'MRR'`` and ``'Recall@{k}'`` for each k, as Python floats.

    Raises:
        ValueError: if there are no queries or fewer candidates than queries.
    """
    num_queries = comment_embeddings.shape[0]
    if num_queries == 0:
        raise ValueError("compute_code_search_metrics_fast: no queries to score")
    if code_embeddings.shape[0] < num_queries:
        raise ValueError(
            "compute_code_search_metrics_fast: need at least one code embedding per query"
        )

    comment_embeddings = F.normalize(comment_embeddings, p=2, dim=1)
    code_embeddings = F.normalize(code_embeddings, p=2, dim=1)

    similarities = torch.matmul(comment_embeddings, code_embeddings.T)
    true_scores = similarities.diagonal().unsqueeze(1)

    # Move to CPU before casting: float64 is unavailable on some backends (MPS).
    better_counts = (similarities > true_scores).sum(dim=1).cpu().to(torch.float64)
    ranks = better_counts + 1.0

    metrics = {'MRR': float((1.0 / ranks).mean())}
    for k in top_k_values:
        metrics[f'Recall@{k}'] = float((ranks <= k).to(torch.float64).mean())

    return metrics


def print_code_search_metrics(metrics, tokens_per_second, phase="Test"):
    """Print a boxed report of ranking metrics and encoding throughput.

    ``metrics`` must contain 'MRR', 'Recall@1', 'Recall@5' and 'Recall@10'.
    """
    rule = '=' * 80
    report = [
        "",
        rule,
        f"{phase.upper()} CODE SEARCH METRICS",
        rule,
        "",
        "Ranking Metrics:",
        f"  MRR (Mean Reciprocal Rank):  {metrics['MRR']:.4f}",
        f"  Recall@1:                    {metrics['Recall@1']:.4f}",
        f"  Recall@5:                    {metrics['Recall@5']:.4f}",
        f"  Recall@10:                   {metrics['Recall@10']:.4f}",
        "",
        "Generation Speed:",
        f"  Tokens per Second:           {tokens_per_second:.2f}",
        rule,
    ]
    print("\n".join(report))


def count_parameters(model):
    """Return (trainable, total, frozen) element counts over ``model.parameters()``."""
    n_total = 0
    n_trainable = 0
    for param in model.parameters():
        size = param.numel()
        n_total += size
        if param.requires_grad:
            n_trainable += size
    return n_trainable, n_total, n_total - n_trainable


def main():
    """Train and evaluate the LoRA code-search model on CodeT5+ 220M.

    Trains with AdamW on the trainable parameters only, prints retrieval
    metrics after every epoch, then runs a final test evaluation.
    """
    # Prefer CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f'Device: {device}')

    model_name = 'Salesforce/codet5p-220m'
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    train_dataset = CodeSearchDataset('/train (1).csv', tokenizer)
    test_dataset = CodeSearchDataset('/test (3).csv', tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)

    print("\nInitializing LoRA Model for Code Search...")
    model = LoRACodeSearchModel(model_name, r=8, lora_alpha=16, lora_dropout=0.05).to(device)

    trainable_params, total_params, frozen_params = count_parameters(model)

    print(f"\n{'='*80}")
    print("MODEL PARAMETERS")
    print(f"{'='*80}")
    print(f"Total Parameters:      {total_params:,}")
    print(f"Trainable Parameters:  {trainable_params:,}")
    print(f"Frozen Parameters:     {frozen_params:,}")
    print(f"Trainable Percentage:  {100 * trainable_params / total_params:.4f}%")
    print(f"{'='*80}\n")

    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=3e-4,
        weight_decay=0.01
    )

    num_epochs = 5

    print(f"{'='*80}")
    print("TRAINING LORA FOR CODE SEARCH")
    print(f"{'='*80}\n")

    best_mrr = 0
    best_epoch = 0

    for epoch in range(num_epochs):
        train_loss = train_epoch(model, train_loader, optimizer, device)

        # NOTE(review): per-epoch "Val" numbers are computed on the test split;
        # there is no separate validation set.
        comment_embeddings, code_embeddings, tokens_per_second = evaluate_code_search(
            model, test_loader, device
        )

        metrics = compute_code_search_metrics_fast(comment_embeddings, code_embeddings, top_k_values=[1, 5, 10])

        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Val MRR:    {metrics['MRR']:.4f} | Recall@1: {metrics['Recall@1']:.4f} | Recall@5: {metrics['Recall@5']:.4f} | Recall@10: {metrics['Recall@10']:.4f}")
        print(f"{'-'*80}")

        if metrics['MRR'] > best_mrr:
            best_mrr = metrics['MRR']
            best_epoch = epoch + 1

    print(f"\nBest per-epoch MRR: {best_mrr:.4f} (epoch {best_epoch})")

    print(f"\n{'='*80}")
    print("FINAL TEST EVALUATION")
    print(f"{'='*80}\n")

    comment_embeddings, code_embeddings, tokens_per_second = evaluate_code_search(
        model, test_loader, device
    )

    final_metrics = compute_code_search_metrics_fast(comment_embeddings, code_embeddings, top_k_values=[1, 5, 10])

    print_code_search_metrics(final_metrics, tokens_per_second, phase="Test")

    print(f"\n{'='*80}")
    print("LORA CODE SEARCH COMPLETED")
    print(f"{'='*80}")


# Run the LoRA experiment's main() when the file/cell is executed directly.
if __name__ == "__main__":
    main()

# BitFit

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import pandas as pd
from tqdm import tqdm
import numpy as np
import time
import warnings
warnings.filterwarnings('ignore')


class CodeSearchDataset(Dataset):
    """Paired (comment, code) examples loaded from a CSV file.

    The CSV must provide ``comment`` and ``code`` columns. Each item tokenizes
    both fields independently, padded/truncated to ``max_length``, and returns
    1-D ``input_ids`` / ``attention_mask`` tensors for each side.
    """

    def __init__(self, csv_path, tokenizer, max_length=256):
        self.data = pd.read_csv(csv_path)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return self.data.shape[0]

    def _encode(self, text):
        """Tokenize one string; return (input_ids, attention_mask) without the batch dim."""
        encoded = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        return encoded['input_ids'].squeeze(0), encoded['attention_mask'].squeeze(0)

    def __getitem__(self, idx):
        record = self.data.iloc[idx]
        # Cells are coerced to str so non-string CSV values still tokenize.
        comment_ids, comment_mask = self._encode(str(record['comment']))
        code_ids, code_mask = self._encode(str(record['code']))
        return {
            'comment_input_ids': comment_ids,
            'comment_attention_mask': comment_mask,
            'code_input_ids': code_ids,
            'code_attention_mask': code_mask
        }


class BitFitCodeSearchModel(nn.Module):
    """Bi-encoder for code search trained with BitFit (bias-only fine-tuning).

    All base-model weights are frozen except the encoder's bias vectors, i.e.
    parameters whose name is ``bias`` or ends in ``.bias``. Only the encoder
    is used, so decoder biases stay frozen instead of being counted as
    trainable while receiving no gradient. When ``add_missing_biases`` is
    true, encoder ``nn.Linear`` layers built without a bias get a
    zero-initialised one. This leaves the encoder's initial output unchanged
    but gives BitFit something to train in bias-free architectures. A
    trainable ``Linear + Tanh`` head is applied to the attention-masked mean
    of the last hidden state.

    Args:
        base_model_name: model id for ``AutoModelForSeq2SeqLM.from_pretrained``.
        add_missing_biases: add zero biases to bias-less encoder Linear layers.
        temperature: temperature of the in-batch contrastive loss.
    """

    def __init__(self, base_model_name, add_missing_biases=True, temperature=0.07):
        super().__init__()

        self.base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)
        self.temperature = temperature

        for param in self.base_model.parameters():
            param.requires_grad = False

        encoder = self.base_model.get_encoder()

        if add_missing_biases:
            # NOTE(review): T5-family checkpoints build their Linear layers with
            # bias=False, so without this there are no bias vectors to tune.
            for module in encoder.modules():
                if isinstance(module, nn.Linear) and module.bias is None:
                    module.bias = nn.Parameter(torch.zeros(
                        module.out_features,
                        dtype=module.weight.dtype,
                        device=module.weight.device,
                    ))

        # Match on the parameter-name suffix: a substring test ('bias' in name)
        # would also unfreeze tensors that merely contain "bias" in their name
        # (e.g. a relative-position-bias embedding weight).
        for name, param in encoder.named_parameters():
            if name == 'bias' or name.endswith('.bias'):
                param.requires_grad = True

        config = self.base_model.config
        self.d_model = config.d_model

        self.projection = nn.Sequential(
            nn.Linear(self.d_model, self.d_model),
            nn.Tanh()
        )

    def encode(self, input_ids, attention_mask):
        """Encode token ids into one projected embedding per sequence.

        Padding positions (``attention_mask == 0``) are excluded from the mean pooling.
        """
        encoder = self.base_model.get_encoder()
        outputs = encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        hidden_states = outputs.last_hidden_state

        mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
        summed = (hidden_states * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)  # guard all-padding rows
        pooled = summed / counts

        return self.projection(pooled)

    def forward(self, comment_input_ids, comment_attention_mask, code_input_ids, code_attention_mask):
        """Return the in-batch contrastive (InfoNCE) loss for comment -> code retrieval.

        Row ``i`` of the comment batch is the positive for row ``i`` of the
        code batch; all other rows in the batch act as negatives.
        """
        comment_emb = F.normalize(self.encode(comment_input_ids, comment_attention_mask), p=2, dim=1)
        code_emb = F.normalize(self.encode(code_input_ids, code_attention_mask), p=2, dim=1)

        logits = torch.matmul(comment_emb, code_emb.T) / self.temperature
        labels = torch.arange(comment_emb.shape[0], device=comment_emb.device)

        return F.cross_entropy(logits, labels)


def train_epoch(model, dataloader, optimizer, device, max_grad_norm=1.0):
    """Run one training epoch and return the mean per-batch loss.

    ``model`` is called as ``model(comment_input_ids, comment_attention_mask,
    code_input_ids, code_attention_mask)`` and must return a scalar loss.
    Gradients of trainable parameters are clipped to ``max_grad_norm``.

    Returns:
        Average of ``loss.item()`` over batches, or ``nan`` when the
        dataloader yields no batches (previously a ZeroDivisionError).
    """
    model.train()
    # Frozen parameters never receive gradients, so clipping only the
    # trainable ones gives the same result without scanning the whole model.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    total_loss = 0.0
    num_batches = 0

    for batch in tqdm(dataloader, desc="Training"):
        comment_input_ids = batch['comment_input_ids'].to(device)
        comment_attention_mask = batch['comment_attention_mask'].to(device)
        code_input_ids = batch['code_input_ids'].to(device)
        code_attention_mask = batch['code_attention_mask'].to(device)

        optimizer.zero_grad()
        loss = model(comment_input_ids, comment_attention_mask, code_input_ids, code_attention_mask)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        return float('nan')
    return total_loss / num_batches


def evaluate_code_search(model, dataloader, device):
    """Encode every (comment, code) pair in ``dataloader`` without gradients.

    Calls ``model.encode`` on each side of every batch and concatenates the
    results in dataloader order, so row ``i`` of both outputs is the same pair.

    Returns:
        (comment_embeddings, code_embeddings, tokens_per_second), where the
        throughput counts every token position fed to the encoder (padding
        included) instead of assuming a fixed length of 256.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    model.eval()

    code_chunks = []
    comment_chunks = []
    total_tokens = 0

    start_time = time.perf_counter()

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Encoding"):
            code_input_ids = batch['code_input_ids'].to(device)
            code_attention_mask = batch['code_attention_mask'].to(device)
            comment_input_ids = batch['comment_input_ids'].to(device)
            comment_attention_mask = batch['comment_attention_mask'].to(device)

            code_chunks.append(model.encode(code_input_ids, code_attention_mask))
            comment_chunks.append(model.encode(comment_input_ids, comment_attention_mask))

            total_tokens += code_input_ids.numel() + comment_input_ids.numel()

    if not code_chunks:
        raise ValueError("evaluate_code_search: dataloader produced no batches")

    all_code_embeddings = torch.cat(code_chunks, dim=0)
    all_comment_embeddings = torch.cat(comment_chunks, dim=0)

    # NOTE(review): on asynchronous backends (CUDA/MPS) queued kernels may not
    # have finished here, so this is an approximate throughput.
    total_time = time.perf_counter() - start_time
    tokens_per_second = total_tokens / total_time if total_time > 0 else 0

    return all_comment_embeddings, all_code_embeddings, tokens_per_second


def compute_code_search_metrics_fast(comment_embeddings, code_embeddings, top_k_values=(1, 5, 10)):
    """Score comment->code retrieval with MRR and Recall@k.

    Query ``i`` (row ``i`` of ``comment_embeddings``) is paired with code row
    ``i``; every code row is a candidate. Similarity is cosine (both sides are
    L2-normalized here). The rank of the true snippet is
    ``1 + number of candidates scoring strictly higher``, computed in a single
    vectorized pass (no per-query argsort / device sync), which also makes
    tie handling deterministic.

    Args:
        comment_embeddings: (num_queries, dim) tensor.
        code_embeddings: (num_candidates, dim) tensor, num_candidates >= num_queries.
        top_k_values: cut-offs for which ``Recall@k`` is reported.

    Returns:
        dict with ``'MRR'`` and ``'Recall@{k}'`` for each k, as Python floats.

    Raises:
        ValueError: if there are no queries or fewer candidates than queries.
    """
    num_queries = comment_embeddings.shape[0]
    if num_queries == 0:
        raise ValueError("compute_code_search_metrics_fast: no queries to score")
    if code_embeddings.shape[0] < num_queries:
        raise ValueError(
            "compute_code_search_metrics_fast: need at least one code embedding per query"
        )

    comment_embeddings = F.normalize(comment_embeddings, p=2, dim=1)
    code_embeddings = F.normalize(code_embeddings, p=2, dim=1)

    similarities = torch.matmul(comment_embeddings, code_embeddings.T)
    true_scores = similarities.diagonal().unsqueeze(1)

    # Move to CPU before casting: float64 is unavailable on some backends (MPS).
    better_counts = (similarities > true_scores).sum(dim=1).cpu().to(torch.float64)
    ranks = better_counts + 1.0

    metrics = {'MRR': float((1.0 / ranks).mean())}
    for k in top_k_values:
        metrics[f'Recall@{k}'] = float((ranks <= k).to(torch.float64).mean())

    return metrics


def print_code_search_metrics(metrics, tokens_per_second, phase="Test"):
    """Print a boxed report of ranking metrics and encoding throughput.

    ``metrics`` must contain 'MRR', 'Recall@1', 'Recall@5' and 'Recall@10'.
    """
    rule = '=' * 80
    report = [
        "",
        rule,
        f"{phase.upper()} CODE SEARCH METRICS",
        rule,
        "",
        "Ranking Metrics:",
        f"  MRR (Mean Reciprocal Rank):  {metrics['MRR']:.4f}",
        f"  Recall@1:                    {metrics['Recall@1']:.4f}",
        f"  Recall@5:                    {metrics['Recall@5']:.4f}",
        f"  Recall@10:                   {metrics['Recall@10']:.4f}",
        "",
        "Generation Speed:",
        f"  Tokens per Second:           {tokens_per_second:.2f}",
        rule,
    ]
    print("\n".join(report))


def count_parameters(model):
    """Return (trainable, total, frozen) element counts over ``model.parameters()``."""
    n_total = 0
    n_trainable = 0
    for param in model.parameters():
        size = param.numel()
        n_total += size
        if param.requires_grad:
            n_trainable += size
    return n_trainable, n_total, n_total - n_trainable


def main(train_csv='/train (1).csv', test_csv='/test (3).csv', num_epochs=5):
    """Train the BitFit code-search model, then report test-set retrieval metrics.

    Args:
        train_csv: training CSV with ``comment`` and ``code`` columns.
        test_csv: evaluation CSV with the same columns.
        num_epochs: number of passes over the training data.
    """
    # Prefer CUDA, then Apple MPS, then fall back to CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f'Device: {device}')

    model_name = 'Salesforce/codet5p-220m'
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    train_dataset = CodeSearchDataset(train_csv, tokenizer)
    test_dataset = CodeSearchDataset(test_csv, tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)

    print("\nInitializing BitFit Model for Code Search...")
    model = BitFitCodeSearchModel(model_name).to(device)

    trainable_params, total_params, frozen_params = count_parameters(model)
    rule = '=' * 80

    print(f"\n{rule}")
    print("MODEL PARAMETERS")
    print(rule)
    print(f"Total Parameters:      {total_params:,}")
    print(f"Trainable Parameters:  {trainable_params:,}")
    print(f"Frozen Parameters:     {frozen_params:,}")
    print(f"Trainable Percentage:  {100 * trainable_params / total_params:.4f}%")
    print(f"{rule}\n")

    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=1e-4,
        weight_decay=0.01
    )

    top_k_values = [1, 5, 10]

    print(rule)
    print("TRAINING BITFIT FOR CODE SEARCH")
    print(f"{rule}\n")

    # There is no separate validation split: the per-epoch numbers are measured on
    # the test set. The best epoch is therefore only reported, not used to pick a model.
    best_mrr = float('-inf')
    best_epoch = None

    for epoch in range(1, num_epochs + 1):
        train_loss = train_epoch(model, train_loader, optimizer, device)

        comment_embeddings, code_embeddings, _ = evaluate_code_search(
            model, test_loader, device
        )
        metrics = compute_code_search_metrics_fast(comment_embeddings, code_embeddings, top_k_values=top_k_values)

        print(f"\nEpoch {epoch}/{num_epochs}")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Test MRR:   {metrics['MRR']:.4f} | Recall@1: {metrics['Recall@1']:.4f} | Recall@5: {metrics['Recall@5']:.4f} | Recall@10: {metrics['Recall@10']:.4f}")
        print('-' * 80)

        if metrics['MRR'] > best_mrr:
            best_mrr, best_epoch = metrics['MRR'], epoch

    print(f"\n{rule}")
    print("FINAL TEST EVALUATION")
    print(f"{rule}\n")

    comment_embeddings, code_embeddings, tokens_per_second = evaluate_code_search(
        model, test_loader, device
    )
    final_metrics = compute_code_search_metrics_fast(comment_embeddings, code_embeddings, top_k_values=top_k_values)

    print_code_search_metrics(final_metrics, tokens_per_second, phase="Test")

    if best_epoch is not None:
        print(f"Best per-epoch test MRR: {best_mrr:.4f} (epoch {best_epoch}/{num_epochs})")

    print(f"\n{rule}")
    print("BITFIT CODE SEARCH COMPLETED")
    print(rule)


# Script entry point. In a Jupyter/IPython cell __name__ is also "__main__",
# so executing this cell starts training and evaluation immediately.
if __name__ == "__main__":
    main()

# Prefix

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import pandas as pd
from tqdm import tqdm
import numpy as np
import time
import warnings
warnings.filterwarnings('ignore')


class CodeSearchDataset(Dataset):
    """Paired (comment, code) examples read from a CSV with ``comment`` and ``code`` columns.

    Each item tokenizes both fields to exactly ``max_length`` positions (padded and
    truncated) and returns 1-D ``input_ids`` / ``attention_mask`` tensors for each side.

    ``tokenizer`` is called with ``max_length``, ``padding='max_length'``,
    ``truncation=True`` and ``return_tensors='pt'``. It must return a mapping whose
    ``input_ids`` and ``attention_mask`` carry a leading batch dimension of size 1.
    """

    def __init__(self, csv_path, tokenizer, max_length=256):
        # Read every cell verbatim as a string. With pandas' defaults, text such as
        # "NA", "NULL", "null", "nan" or an empty cell becomes NaN, which str() later
        # turned into the literal "nan". Numeric-looking text (e.g. "007") was also
        # parsed as a number. Empty cells now yield "".
        self.data = pd.read_csv(csv_path, dtype=str, keep_default_na=False)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def _encode(self, text):
        """Tokenize one string into fixed-length 1-D ``input_ids`` and ``attention_mask``."""
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # return_tensors='pt' yields shape (1, max_length); drop the batch axis.
        return encoding['input_ids'].squeeze(0), encoding['attention_mask'].squeeze(0)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        comment_ids, comment_mask = self._encode(str(row['comment']))
        code_ids, code_mask = self._encode(str(row['code']))

        return {
            'comment_input_ids': comment_ids,
            'comment_attention_mask': comment_mask,
            'code_input_ids': code_ids,
            'code_attention_mask': code_mask
        }


class PrefixEncoder(nn.Module):
    """Produce per-layer prefix key/value tensors from ``prefix_length`` learned positions.

    Each prefix position owns a ``prefix_hidden_size`` embedding. A Linear -> Tanh ->
    Linear MLP (``trans``) maps it to ``num_layers * 2 * num_heads * head_dim`` values.
    These are reshaped so that ``forward`` returns a tuple with one tensor per layer,
    each of shape ``(2, batch_size, num_heads, prefix_length, head_dim)``.
    """

    def __init__(self, prefix_length, num_layers, num_heads, head_dim, prefix_hidden_size=512):
        super().__init__()
        self.prefix_length = prefix_length
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim

        out_features = num_layers * 2 * num_heads * head_dim

        self.embedding = nn.Embedding(prefix_length, prefix_hidden_size)
        self.trans = nn.Sequential(
            nn.Linear(prefix_hidden_size, prefix_hidden_size),
            nn.Tanh(),
            nn.Linear(prefix_hidden_size, out_features)
        )

    def forward(self, batch_size):
        weight_device = self.embedding.weight.device
        # The same prefix positions (0..prefix_length-1) are used for every batch row.
        positions = torch.arange(self.prefix_length, device=weight_device).expand(batch_size, self.prefix_length)

        flat = self.trans(self.embedding(positions))
        grid = flat.view(batch_size, self.prefix_length, 2 * self.num_layers, self.num_heads, self.head_dim)

        # (2*layers, batch, heads, prefix, head_dim), then chunk pairs along dim 0.
        layer_major = grid.permute(2, 0, 3, 1, 4)
        return torch.split(layer_major, 2, dim=0)


class PrefixCodeSearchModel(nn.Module):
    """Dual-encoder code-search model on top of a frozen seq2seq encoder.

    ``prefix_length`` trainable vectors (``encoder_prefix_embeds``) are prepended to the
    token embeddings of every input. After encoding, the prefix positions are dropped,
    the remaining hidden states are mean-pooled over the attention mask, and the result
    passes through a trainable Linear + Tanh ``projection``. ``forward`` returns an
    in-batch contrastive loss in which comment i should match code i.

    Args:
        base_model_name: model id loaded with ``AutoModelForSeq2SeqLM.from_pretrained``;
            all of its weights are frozen.
        prefix_length: number of prepended prefix vectors.
        prefix_hidden_size: hidden size of the ``encoder_prefix`` MLP.
        temperature: divisor applied to cosine similarities before cross-entropy.
    """

    def __init__(self, base_model_name, prefix_length=10, prefix_hidden_size=512, temperature=0.07):
        super().__init__()

        self.base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)

        for param in self.base_model.parameters():
            param.requires_grad = False

        config = self.base_model.config
        self.d_model = config.d_model
        self.num_encoder_layers = config.num_layers
        self.num_decoder_layers = config.num_decoder_layers
        self.num_heads = config.num_heads
        self.head_dim = config.d_kv

        self.prefix_length = prefix_length
        self.temperature = temperature

        # NOTE(review): encode_with_prefix never calls this module; only
        # ``encoder_prefix_embeds`` reaches the encoder, so these parameters never
        # receive gradients. They are frozen so parameter counts and
        # requires_grad-filtered optimizers do not treat them as trained. The
        # submodule is kept so attribute names and state_dict keys stay unchanged.
        # Either wire it in (as per-layer past key/values) or remove it.
        self.encoder_prefix = PrefixEncoder(
            prefix_length=prefix_length,
            num_layers=self.num_encoder_layers,
            num_heads=self.num_heads,
            head_dim=self.head_dim,
            prefix_hidden_size=prefix_hidden_size
        )
        self.encoder_prefix.requires_grad_(False)

        self.encoder_prefix_embeds = nn.Parameter(torch.randn(1, prefix_length, self.d_model) * 0.02)

        self.projection = nn.Sequential(
            nn.Linear(self.d_model, self.d_model),
            nn.Tanh()
        )

    def encode_with_prefix(self, input_ids, attention_mask):
        """Encode a batch and return its projected, mask-mean-pooled embedding.

        Args:
            input_ids: (batch, seq_len) token ids.
            attention_mask: (batch, seq_len) mask, 1 for real tokens and 0 for padding.

        Returns:
            (batch, d_model) tensor, not L2-normalized.
        """
        batch_size = input_ids.shape[0]
        device = input_ids.device

        encoder = self.base_model.get_encoder()

        token_embeds = encoder.embed_tokens(input_ids)
        prefix_embeds = self.encoder_prefix_embeds.expand(batch_size, -1, -1)
        inputs_embeds = torch.cat([prefix_embeds, token_embeds], dim=1)

        # Prefix positions are always attended to. Build their mask in the input mask's
        # dtype so the concatenation needs no implicit type promotion.
        prefix_attention_mask = torch.ones(
            batch_size, self.prefix_length, dtype=attention_mask.dtype, device=device
        )
        extended_attention_mask = torch.cat([prefix_attention_mask, attention_mask], dim=1)

        outputs = encoder(
            inputs_embeds=inputs_embeds,
            attention_mask=extended_attention_mask,
            return_dict=True
        )

        # Drop the prefix positions and average only over real (unpadded) tokens.
        hidden_states = outputs.last_hidden_state[:, self.prefix_length:, :]
        mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
        pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

        return self.projection(pooled)

    def forward(self, comment_input_ids, comment_attention_mask, code_input_ids, code_attention_mask):
        """Return the in-batch contrastive loss for aligned (comment, code) pairs.

        Every comment is compared with every code in the batch by cosine similarity,
        divided by ``temperature``. Cross-entropy is taken against the diagonal, so
        comment i's positive is code i and the other codes act as negatives.
        """
        comment_emb = F.normalize(self.encode_with_prefix(comment_input_ids, comment_attention_mask), p=2, dim=1)
        code_emb = F.normalize(self.encode_with_prefix(code_input_ids, code_attention_mask), p=2, dim=1)

        logits = torch.matmul(comment_emb, code_emb.T) / self.temperature
        labels = torch.arange(logits.shape[0], device=logits.device)

        return F.cross_entropy(logits, labels)


def train_epoch(model, dataloader, optimizer, device):
    """Run one training epoch over ``dataloader`` and return the mean per-batch loss.

    ``model`` is called with each batch's comment/code ids and attention masks and
    must return a scalar loss. Before every optimizer step, gradients are clipped to
    a global L2 norm of 1.0.

    Raises:
        ValueError: if ``dataloader`` yields no batches, so no mean loss exists.
    """
    model.train()
    # Frozen parameters have no gradients to clip, so pass only the trainable ones.
    trainable_params = [p for p in model.parameters() if p.requires_grad]

    total_loss = 0.0
    num_batches = 0

    for batch in tqdm(dataloader, desc="Training"):
        comment_input_ids = batch['comment_input_ids'].to(device)
        comment_attention_mask = batch['comment_attention_mask'].to(device)
        code_input_ids = batch['code_input_ids'].to(device)
        code_attention_mask = batch['code_attention_mask'].to(device)

        optimizer.zero_grad()

        loss = model(comment_input_ids, comment_attention_mask, code_input_ids, code_attention_mask)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("dataloader yielded no batches; cannot compute an average training loss")

    return total_loss / num_batches


def evaluate_code_search(model, dataloader, device):
    """Embed every comment and code snippet in ``dataloader`` with ``model.encode_with_prefix``.

    Args:
        model: module exposing ``encode_with_prefix(input_ids, attention_mask)``.
        dataloader: yields dicts of ``comment_*`` / ``code_*`` ``input_ids`` and
            ``attention_mask`` tensors.
        device: device the batches are moved to.

    Returns:
        ``(comment_embeddings, code_embeddings, tokens_per_second)``. Row i of both
        embedding tensors comes from the same item, in dataloader order.
        ``tokens_per_second`` counts non-padding tokens (attention-mask ones) of
        comments plus code, divided by the wall-clock time of the whole loop
        (batch loading included).

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.eval()

    code_chunks = []
    comment_chunks = []
    num_tokens = 0

    start_time = time.perf_counter()

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Encoding"):
            # Count real tokens from the CPU-side masks instead of assuming a fixed
            # padded sequence length.
            num_tokens += int(batch['code_attention_mask'].sum()) + int(batch['comment_attention_mask'].sum())

            code_input_ids = batch['code_input_ids'].to(device)
            code_attention_mask = batch['code_attention_mask'].to(device)
            comment_input_ids = batch['comment_input_ids'].to(device)
            comment_attention_mask = batch['comment_attention_mask'].to(device)

            code_chunks.append(model.encode_with_prefix(code_input_ids, code_attention_mask))
            comment_chunks.append(model.encode_with_prefix(comment_input_ids, comment_attention_mask))

        if not code_chunks:
            raise ValueError("dataloader yielded no batches to evaluate")

        all_code_embeddings = torch.cat(code_chunks, dim=0)
        all_comment_embeddings = torch.cat(comment_chunks, dim=0)

    # CUDA/MPS kernels run asynchronously. Wait for them so the timing covers the
    # encoding work itself, not just its dispatch.
    device_type = torch.device(device).type
    if device_type == "cuda":
        torch.cuda.synchronize()
    elif device_type == "mps" and hasattr(torch, "mps"):
        torch.mps.synchronize()

    total_time = time.perf_counter() - start_time
    tokens_per_second = num_tokens / total_time if total_time > 0 else 0

    return all_comment_embeddings, all_code_embeddings, tokens_per_second


def compute_code_search_metrics_fast(comment_embeddings, code_embeddings, top_k_values=(1, 5, 10)):
    """Compute comment -> code retrieval metrics where query i's correct answer is code i.

    Both matrices are L2-normalized and compared by cosine similarity. The rank of
    the correct code for query i is 1 plus the number of codes scoring strictly
    higher, plus the number of lower-indexed codes scoring exactly the same. Ties are
    thus broken by index, as a stable descending sort would, so results are deterministic.
    All ranks are computed at once instead of sorting each row.

    Args:
        comment_embeddings: (N, D) query embeddings.
        code_embeddings: (N, D) candidate embeddings; row i pairs with query i.
        top_k_values: cut-offs k for which ``Recall@k`` is reported.

    Returns:
        dict with ``'MRR'`` followed by ``'Recall@k'`` for each k in ``top_k_values``.

    Raises:
        ValueError: if there are no queries or the two row counts differ.
    """
    num_queries = comment_embeddings.shape[0]
    if num_queries == 0:
        raise ValueError("need at least one query to compute retrieval metrics")
    if code_embeddings.shape[0] != num_queries:
        raise ValueError(
            f"expected one code embedding per query, got {code_embeddings.shape[0]} codes "
            f"for {num_queries} queries"
        )

    comment_embeddings = F.normalize(comment_embeddings, p=2, dim=1)
    code_embeddings = F.normalize(code_embeddings, p=2, dim=1)

    similarities = torch.matmul(comment_embeddings, code_embeddings.T)  # (N, N)

    idx = torch.arange(num_queries, device=similarities.device)
    true_scores = similarities[idx, idx].unsqueeze(1)  # (N, 1)

    num_higher = (similarities > true_scores).sum(dim=1)
    # earlier[i, j] is True when candidate j has a lower index than query i.
    earlier = idx.unsqueeze(0) < idx.unsqueeze(1)
    num_tied_earlier = ((similarities == true_scores) & earlier).sum(dim=1)

    # Move to CPU before switching to float64 (MPS has no float64 support).
    ranks = (num_higher + num_tied_earlier + 1).cpu().double()

    metrics = {'MRR': (1.0 / ranks).mean().item()}
    for k in top_k_values:
        metrics[f'Recall@{k}'] = (ranks <= k).double().mean().item()

    return metrics


def print_code_search_metrics(metrics, tokens_per_second, phase="Test"):
    """Print a banner-framed report of MRR, Recall@{1,5,10} and throughput.

    Args:
        metrics: mapping with 'MRR', 'Recall@1', 'Recall@5' and 'Recall@10' keys.
        tokens_per_second: number shown under "Generation Speed" with two decimals.
        phase: report label, shown upper-cased in the title.
    """
    divider = "=" * 80
    print(f"\n{divider}")
    print(f"{phase.upper()} CODE SEARCH METRICS")
    print(f"{divider}\n")

    print("Ranking Metrics:")
    ranking_rows = (
        ("MRR (Mean Reciprocal Rank):", "MRR"),
        ("Recall@1:", "Recall@1"),
        ("Recall@5:", "Recall@5"),
        ("Recall@10:", "Recall@10"),
    )
    for caption, key in ranking_rows:
        # Left-justify captions to 29 characters so the values line up.
        print(f"  {caption:<29}{metrics[key]:.4f}")

    print("\nGeneration Speed:")
    print(f"  {'Tokens per Second:':<29}{tokens_per_second:.2f}")

    print(divider)


def count_parameters(model):
    """Return ``(trainable, total, frozen)`` parameter element counts for ``model``.

    Trainable means ``requires_grad`` is set; frozen is everything else.
    """
    sizes = [(param.numel(), param.requires_grad) for param in model.parameters()]
    total = sum(size for size, _ in sizes)
    trainable = sum(size for size, needs_grad in sizes if needs_grad)
    return trainable, total, total - trainable


def main(train_csv='/train (1).csv', test_csv='/test (3).csv', num_epochs=5):
    """Train the prefix-embedding code-search model, then report test-set retrieval metrics.

    Args:
        train_csv: training CSV with ``comment`` and ``code`` columns.
        test_csv: evaluation CSV with the same columns.
        num_epochs: number of passes over the training data.
    """
    # Prefer CUDA, then Apple MPS, then fall back to CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f'Device: {device}')

    model_name = 'Salesforce/codet5p-220m'
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    train_dataset = CodeSearchDataset(train_csv, tokenizer)
    test_dataset = CodeSearchDataset(test_csv, tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)

    print("\nInitializing Prefix Tuning Model for Code Search...")
    model = PrefixCodeSearchModel(model_name, prefix_length=10, prefix_hidden_size=512).to(device)

    trainable_params, total_params, frozen_params = count_parameters(model)
    rule = '=' * 80

    print(f"\n{rule}")
    print("MODEL PARAMETERS")
    print(rule)
    print(f"Total Parameters:      {total_params:,}")
    print(f"Trainable Parameters:  {trainable_params:,}")
    print(f"Frozen Parameters:     {frozen_params:,}")
    print(f"Trainable Percentage:  {100 * trainable_params / total_params:.4f}%")
    print(f"{rule}\n")

    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=3e-4,
        weight_decay=0.01
    )

    top_k_values = [1, 5, 10]

    print(rule)
    print("TRAINING PREFIX TUNING FOR CODE SEARCH")
    print(f"{rule}\n")

    # There is no separate validation split: the per-epoch numbers are measured on
    # the test set. The best epoch is therefore only reported, not used to pick a model.
    best_mrr = float('-inf')
    best_epoch = None

    for epoch in range(1, num_epochs + 1):
        train_loss = train_epoch(model, train_loader, optimizer, device)

        comment_embeddings, code_embeddings, _ = evaluate_code_search(
            model, test_loader, device
        )
        metrics = compute_code_search_metrics_fast(comment_embeddings, code_embeddings, top_k_values=top_k_values)

        print(f"\nEpoch {epoch}/{num_epochs}")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Test MRR:   {metrics['MRR']:.4f} | Recall@1: {metrics['Recall@1']:.4f} | Recall@5: {metrics['Recall@5']:.4f} | Recall@10: {metrics['Recall@10']:.4f}")
        print('-' * 80)

        if metrics['MRR'] > best_mrr:
            best_mrr, best_epoch = metrics['MRR'], epoch

    print(f"\n{rule}")
    print("FINAL TEST EVALUATION")
    print(f"{rule}\n")

    comment_embeddings, code_embeddings, tokens_per_second = evaluate_code_search(
        model, test_loader, device
    )
    final_metrics = compute_code_search_metrics_fast(comment_embeddings, code_embeddings, top_k_values=top_k_values)

    print_code_search_metrics(final_metrics, tokens_per_second, phase="Test")

    if best_epoch is not None:
        print(f"Best per-epoch test MRR: {best_mrr:.4f} (epoch {best_epoch}/{num_epochs})")

    print(f"\n{rule}")
    print("PREFIX TUNING CODE SEARCH COMPLETED")
    print(rule)


# Script entry point. In a Jupyter/IPython cell __name__ is also "__main__",
# so executing this cell starts training and evaluation immediately.
if __name__ == "__main__":
    main()