# TS-PEFT

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import pandas as pd
from tqdm import tqdm
import numpy as np
import math
import warnings
from sklearn.metrics import (
    accuracy_score, balanced_accuracy_score, precision_recall_fscore_support,
    roc_auc_score, average_precision_score, confusion_matrix,
    matthews_corrcoef, cohen_kappa_score
)
warnings.filterwarnings('ignore')


class CodeDataset(Dataset):
    """Map-style dataset that tokenizes code samples read from a CSV file.

    The CSV is loaded eagerly with pandas. Each row must provide a ``func``
    column (the code text) and a ``label`` column (convertible to ``int``).
    """

    def __init__(self, csv_path, tokenizer, max_length=512):
        """Load the CSV and remember the tokenizer settings.

        Args:
            csv_path: Path passed to ``pandas.read_csv``.
            tokenizer: Callable applied to each code string; it receives the
                ``max_length``, ``padding``, ``truncation`` and
                ``return_tensors`` keyword arguments.
            max_length: Length every sample is padded/truncated to.
        """
        self.data = pd.read_csv(csv_path)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of rows in the loaded CSV."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize row ``idx`` and return its tensors together with the label.

        Returns:
            dict with ``input_ids``, ``attention_mask`` and ``labels`` (a
            scalar ``torch.long`` tensor).
        """
        sample = self.data.iloc[idx]
        source_text = str(sample['func'])
        class_id = int(sample['label'])

        tokenized = self.tokenizer(
            source_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        # squeeze(0) drops the leading dimension of the tokenizer output
        # (presumably the batch dimension of size 1 added by return_tensors='pt').
        item = {
            field: tokenized[field].squeeze(0)
            for field in ('input_ids', 'attention_mask')
        }
        item['labels'] = torch.tensor(class_id, dtype=torch.long)
        return item


class LoRALayer(nn.Module):
    """Low-rank adapter computing ``scaling * dropout(x) @ A^T @ B^T``.

    ``lora_A`` has shape ``(rank, in_features)`` and is Kaiming-uniform
    initialised; ``lora_B`` has shape ``(out_features, rank)`` and starts at
    zero, so the adapter contributes nothing until ``lora_B`` is trained.
    """

    def __init__(self, in_features, out_features, rank=4, alpha=1.0, dropout=0.0):
        """Create the two low-rank factors.

        Args:
            in_features: Size of the last dimension of the input.
            out_features: Size of the last dimension of the output.
            rank: Inner dimension of the factorisation.
            alpha: Numerator of the output scale ``alpha / rank``.
            dropout: Dropout probability applied to the input; values ``<= 0``
                disable dropout (an ``nn.Identity`` is used instead).
        """
        super().__init__()
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank

        down_proj = torch.zeros(rank, in_features)
        up_proj = torch.zeros(out_features, rank)
        self.lora_A = nn.Parameter(down_proj)
        self.lora_B = nn.Parameter(up_proj)

        if dropout > 0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = nn.Identity()

        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def forward(self, x):
        """Return the scaled low-rank update for ``x``."""
        reduced = self.dropout(x) @ self.lora_A.T
        expanded = reduced @ self.lora_B.T
        return expanded * self.scaling


class TSPEFTLayer(nn.Module):
    """Frozen linear layer with a threshold-gated LoRA branch.

    For every position the layer computes, over the last dimension,
    ``r_i = ||lora(x)_i||_2 / (||base(x)_i||_2 + eps)`` and adds the LoRA
    output only where ``r_i >= tau``. ``tau`` is a buffer (not a parameter);
    it is moved by :meth:`update_threshold` using Adam-style first/second
    moment estimates kept in the ``m``, ``v`` and ``step`` buffers.
    """

    def __init__(self, base_layer, rank=4, alpha=1.0, dropout=0.0, s=4e-5, lambda_reg=1e-5, beta1=0.9, beta2=0.98, eps=1e-8):
        """Wrap ``base_layer`` and attach a LoRA branch.

        Args:
            base_layer: Module exposing ``in_features``/``out_features``; its
                parameters are frozen here.
            rank, alpha, dropout: Forwarded to ``LoRALayer``.
            s: Scale applied to the threshold statistic and to the threshold
                step.
            lambda_reg: Weight of the sparsity term in the threshold statistic.
            beta1, beta2: Decay rates of the moment estimates.
            eps: Stabiliser used both in the norm ratio and in the step
                denominator.
        """
        super().__init__()
        self.base_layer = base_layer
        self.base_layer.requires_grad_(False)

        self.lora = LoRALayer(
            base_layer.in_features,
            base_layer.out_features,
            rank=rank,
            alpha=alpha,
            dropout=dropout,
        )

        self.s = s
        self.lambda_reg = lambda_reg
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps

        self.register_buffer('tau', torch.tensor(0.0))
        self.register_buffer('m', torch.tensor(0.0))
        self.register_buffer('v', torch.tensor(0.0))
        self.register_buffer('step', torch.tensor(0))

    def compute_relative_magnitude(self, base_output, lora_output):
        """Return the per-position L2-norm ratio of LoRA output to base output."""
        base_norm = torch.norm(base_output, p=2, dim=-1, keepdim=True)
        lora_norm = torch.norm(lora_output, p=2, dim=-1, keepdim=True)
        r_i = lora_norm / (base_norm + self.eps)
        return r_i.squeeze(-1)

    def forward(self, x):
        """Return ``base(x) + gate * lora(x)`` with ``gate = (r_i >= tau)``.

        In training mode the intermediate tensors are cached (detached) in
        ``_cache_for_backward`` for the next :meth:`update_threshold` call.
        """
        base_output = self.base_layer(x)
        lora_output = self.lora(x)

        r_i = self.compute_relative_magnitude(base_output, lora_output)
        # The comparison is non-differentiable, so no gradient reaches tau
        # through the gate; tau is updated separately in update_threshold().
        gate = (r_i >= self.tau).float()
        gated_output = base_output + gate.unsqueeze(-1) * lora_output

        if self.training:
            # Store detached copies: the cache is only read for threshold
            # statistics, and holding graph-attached tensors would keep the
            # autograd history alive until update_threshold() clears it.
            self._cache_for_backward = {
                'r_i': r_i.detach(),
                'gate': gate.detach(),
                'lora_output': lora_output.detach(),
                'base_output': base_output.detach(),
            }

        return gated_output

    def compute_threshold_gradient(self, grad_output):
        """Return the scalar statistic ``g_k`` that drives the ``tau`` update.

        With ``mu_i = sum(grad_output * lora_output, dim=-1)``:
        ``g_k = -s * sum(consistency_mask * mu_i) - s * sum(gate * lambda_reg)``,
        where ``consistency_mask`` is 1 where ``(mu_i >= 0)`` equals the gate.

        Args:
            grad_output: Tensor broadcastable against the cached LoRA output.

        Returns:
            float: ``g_k``, or ``0.0`` when no training forward pass is cached.
        """
        if not hasattr(self, '_cache_for_backward'):
            return 0.0

        cache = self._cache_for_backward
        gate = cache['gate']
        lora_output = cache['lora_output']

        # Pure bookkeeping: never build an autograd graph for it.
        with torch.no_grad():
            mu_i = (grad_output * lora_output).sum(dim=-1)

            consistency_mask = ((mu_i >= 0).float() == gate).float()
            sparsity_mask = gate

            grad_loss = -self.s * (consistency_mask * mu_i).sum()
            grad_sparsity = -self.s * (sparsity_mask * self.lambda_reg).sum()

            g_k = grad_loss + grad_sparsity

        return g_k.item()

    def update_threshold(self, grad_output, lr=1.0):
        """Advance ``tau`` by one bias-corrected moment step and drop the cache.

        Does nothing in eval mode. ``tau`` is clamped to be non-negative.

        NOTE(review): ``s`` scales ``g_k`` in compute_threshold_gradient() and
        scales the step again here — confirm the double application is intended.
        """
        if not self.training:
            return

        g_k = self.compute_threshold_gradient(grad_output)

        self.step += 1

        self.m = self.beta1 * self.m + (1 - self.beta1) * g_k
        self.v = self.beta2 * self.v + (1 - self.beta2) * (g_k ** 2)

        # Bias correction of the running moments.
        m_hat = self.m / (1 - self.beta1 ** self.step.item())
        v_hat = self.v / (1 - self.beta2 ** self.step.item())

        tau_update = lr * self.s * m_hat / (torch.sqrt(v_hat) + self.eps)
        self.tau = torch.clamp(self.tau + tau_update, min=0.0)

        if hasattr(self, '_cache_for_backward'):
            delattr(self, '_cache_for_backward')


class TSPEFTVulnerabilityModel(nn.Module):
    """Sequence classifier built on a frozen seq2seq encoder with TS-PEFT adapters.

    The query and value projections of every encoder block's self-attention
    are replaced by ``TSPEFTLayer`` wrappers; the masked mean of the encoder's
    last hidden states is fed to a small MLP classification head.
    """

    def __init__(self, base_model_name, num_classes=2, rank=32, alpha=0.5, dropout=0.05, s=4e-5, lambda_reg=4.5e-5):
        """Load the backbone, freeze it, and install the adapters and the head.

        Args:
            base_model_name: Identifier passed to
                ``AutoModelForSeq2SeqLM.from_pretrained``.
            num_classes: Output size of the classification head.
            rank, alpha, dropout, s, lambda_reg: Forwarded to every
                ``TSPEFTLayer``.
        """
        super().__init__()

        self.base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)

        # Freeze the backbone before the adapters are attached, so only the
        # adapter and classifier parameters remain trainable.
        for frozen_param in self.base_model.parameters():
            frozen_param.requires_grad = False

        self.d_model = self.base_model.config.d_model

        adapter_kwargs = dict(rank=rank, alpha=alpha, dropout=dropout, s=s, lambda_reg=lambda_reg)
        self.ts_peft_layers = nn.ModuleDict()

        for idx, block in enumerate(self.base_model.get_encoder().block):
            attention = block.layer[0].SelfAttention
            for proj_name in ('q', 'v'):
                wrapped = TSPEFTLayer(getattr(attention, proj_name), **adapter_kwargs)
                # The same module object is registered here and spliced into
                # the encoder, so both views share parameters.
                self.ts_peft_layers[f'encoder_{proj_name}_{idx}'] = wrapped
                setattr(attention, proj_name, wrapped)

        self.classifier = nn.Sequential(
            nn.Linear(self.d_model, self.d_model),
            nn.Tanh(),
            nn.Dropout(0.1),
            nn.Linear(self.d_model, num_classes)
        )

    def encode_with_ts_peft(self, input_ids, attention_mask):
        """Run the encoder and return the attention-mask-weighted mean of its states."""
        encoder_out = self.base_model.get_encoder()(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        token_states = encoder_out.last_hidden_state

        weights = attention_mask.unsqueeze(-1).expand(token_states.size()).float()
        weighted_sum = torch.sum(token_states * weights, dim=1)
        # Clamp so a fully masked row does not divide by zero.
        token_counts = torch.clamp(weights.sum(dim=1), min=1e-9)

        return weighted_sum / token_counts

    def update_thresholds(self, lr=1.0):
        """Call ``update_threshold`` on every adapter that is training and has a cache.

        NOTE(review): the ``grad_output`` passed to each layer is an all-ones
        tensor shaped like the cached base output, not a gradient produced by
        ``loss.backward()`` — confirm this is intended.
        """
        for layer in self.ts_peft_layers.values():
            if not (hasattr(layer, '_cache_for_backward') and layer.training):
                continue
            placeholder_grad = torch.ones_like(layer._cache_for_backward['base_output'])
            layer.update_threshold(placeholder_grad, lr)

    def forward(self, input_ids, attention_mask, labels=None):
        """Return ``logits``, or ``(loss, logits)`` when ``labels`` is given."""
        logits = self.classifier(self.encode_with_ts_peft(input_ids, attention_mask))

        if labels is None:
            return logits

        return F.cross_entropy(logits, labels), logits


def train_epoch(model, dataloader, optimizer, device):
    """Run one training pass over ``dataloader`` and return the mean batch loss.

    Per batch: forward with labels, backward, adapter-threshold update,
    gradient-norm clipping to 1.0, then the optimizer step.

    Args:
        model: Module whose call returns ``(loss, logits)`` when given labels
            and which exposes ``update_thresholds``.
        dataloader: Iterable of dicts with ``input_ids``, ``attention_mask``
            and ``labels`` tensors.
        optimizer: Optimizer stepping the trainable parameters.
        device: Device the batch tensors are moved to.

    Returns:
        float: Sum of per-batch losses divided by the number of batches.
    """
    model.train()
    running_loss = 0
    batches_seen = 0

    for batch in tqdm(dataloader, desc="Training"):
        input_ids, attention_mask, labels = (
            batch[field].to(device)
            for field in ('input_ids', 'attention_mask', 'labels')
        )

        optimizer.zero_grad()
        loss, _ = model(input_ids, attention_mask, labels)
        loss.backward()

        # Thresholds are updated after backward and before the optimizer step.
        model.update_thresholds(lr=1.0)

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        running_loss += loss.item()
        batches_seen += 1

    return running_loss / batches_seen


def evaluate(model, dataloader, device):
    """Run inference over ``dataloader`` and collect labels, predictions and scores.

    The model is switched to eval mode and gradients are disabled.

    Returns:
        tuple of three NumPy arrays: ground-truth labels, argmax predictions
        and softmax probabilities, in dataloader order.
    """
    model.eval()

    gathered_labels, gathered_preds, gathered_probs = [], [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids, attention_mask, labels = (
                batch[field].to(device)
                for field in ('input_ids', 'attention_mask', 'labels')
            )

            logits = model(input_ids, attention_mask)
            class_probs = F.softmax(logits, dim=-1)
            class_ids = torch.argmax(logits, dim=-1)

            gathered_labels.extend(labels.cpu().numpy())
            gathered_preds.extend(class_ids.cpu().numpy())
            gathered_probs.extend(class_probs.cpu().numpy())

    return (
        np.array(gathered_labels),
        np.array(gathered_preds),
        np.array(gathered_probs),
    )


def compute_comprehensive_metrics(labels, preds, probs):
    """Compute a dictionary of binary-classification metrics.

    Args:
        labels: Ground-truth class ids (the binary metrics assume 0/1).
        preds: Predicted class ids.
        probs: Per-class scores, one row per sample; column 1 is used as the
            positive-class score.

    Returns:
        dict with accuracy, balanced accuracy, macro/binary/weighted and
        per-class precision/recall/F1, ROC-AUC and PR-AUC (0.5 when the score
        cannot be computed), the 2x2 confusion matrix and the rates derived
        from it, MCC and Cohen's kappa.
    """
    probs = np.asarray(probs)
    metrics = {}

    metrics['accuracy'] = accuracy_score(labels, preds)
    metrics['balanced_accuracy'] = balanced_accuracy_score(labels, preds)

    for average in ('macro', 'binary', 'weighted'):
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, preds, average=average, zero_division=0
        )
        metrics[f'precision_{average}'] = precision
        metrics[f'recall_{average}'] = recall
        metrics[f'f1_{average}'] = f1

    prec_per_class, rec_per_class, f1_per_class, support = precision_recall_fscore_support(
        labels, preds, average=None, zero_division=0
    )
    metrics['precision_per_class'] = prec_per_class
    metrics['recall_per_class'] = rec_per_class
    metrics['f1_per_class'] = f1_per_class
    metrics['support_per_class'] = support

    # Only the failures these scorers are expected to raise (e.g. a single
    # class present, or a score matrix without a second column) fall back to
    # 0.5; anything else is a real bug and must surface.
    score_errors = (ValueError, IndexError)

    try:
        metrics['roc_auc_binary'] = roc_auc_score(labels, probs[:, 1])
    except score_errors:
        metrics['roc_auc_binary'] = 0.5

    try:
        if probs.ndim == 2 and probs.shape[1] == 2:
            # For binary labels sklearn requires 1-D positive-class scores;
            # passing the 2-column matrix raises and used to be silently
            # reported as 0.5. With two complementary columns the
            # one-vs-rest macro average equals the positive-class AUC.
            metrics['roc_auc_macro'] = roc_auc_score(labels, probs[:, 1])
        else:
            metrics['roc_auc_macro'] = roc_auc_score(
                labels, probs, multi_class='ovr', average='macro'
            )
    except score_errors:
        metrics['roc_auc_macro'] = 0.5

    try:
        metrics['pr_auc'] = average_precision_score(labels, probs[:, 1])
    except score_errors:
        metrics['pr_auc'] = 0.5

    # Pin the label set so the matrix is always 2x2, even when only one class
    # occurs in labels/preds; otherwise the 4-way unpack below fails.
    metrics['confusion_matrix'] = confusion_matrix(labels, preds, labels=[0, 1])

    metrics['mcc'] = matthews_corrcoef(labels, preds)

    metrics['cohen_kappa'] = cohen_kappa_score(labels, preds)

    tn, fp, fn, tp = metrics['confusion_matrix'].ravel()
    metrics['true_negatives'] = tn
    metrics['false_positives'] = fp
    metrics['false_negatives'] = fn
    metrics['true_positives'] = tp

    # Each rate falls back to 0 when its denominator is empty.
    metrics['specificity'] = tn / (tn + fp) if (tn + fp) > 0 else 0
    metrics['sensitivity'] = tp / (tp + fn) if (tp + fn) > 0 else 0
    metrics['fpr'] = fp / (fp + tn) if (fp + tn) > 0 else 0
    metrics['fnr'] = fn / (fn + tp) if (fn + tp) > 0 else 0
    metrics['npv'] = tn / (tn + fn) if (tn + fn) > 0 else 0
    metrics['fdr'] = fp / (fp + tp) if (fp + tp) > 0 else 0

    return metrics


def print_comprehensive_metrics(metrics, phase="Test"):
    """Print a formatted report of the metrics dictionary to stdout.

    Args:
        metrics: Dictionary with the keys produced by
            ``compute_comprehensive_metrics``.
        phase: Label shown (upper-cased) in the report header.
    """
    rule = '=' * 80

    report = [
        f"\n{rule}",
        f"{phase.upper()} EVALUATION METRICS",
        f"{rule}\n",
        "Overall Metrics:",
        f"  Accuracy:                {metrics['accuracy']:.4f}",
        f"  Balanced Accuracy:       {metrics['balanced_accuracy']:.4f}",
        f"  Matthews Correlation:    {metrics['mcc']:.4f}",
        f"  Cohen's Kappa:           {metrics['cohen_kappa']:.4f}",
        "\nMacro-Averaged Metrics:",
        f"  Precision (Macro):       {metrics['precision_macro']:.4f}",
        f"  Recall (Macro):          {metrics['recall_macro']:.4f}",
        f"  F1-Score (Macro):        {metrics['f1_macro']:.4f}",
        f"  ROC-AUC (Macro):         {metrics['roc_auc_macro']:.4f}",
        "\nBinary Metrics:",
        f"  Precision (Binary):      {metrics['precision_binary']:.4f}",
        f"  Recall (Binary):         {metrics['recall_binary']:.4f}",
        f"  F1-Score (Binary):       {metrics['f1_binary']:.4f}",
        f"  ROC-AUC (Binary):        {metrics['roc_auc_binary']:.4f}",
        f"  PR-AUC:                  {metrics['pr_auc']:.4f}",
        "\nWeighted Metrics:",
        f"  Precision (Weighted):    {metrics['precision_weighted']:.4f}",
        f"  Recall (Weighted):       {metrics['recall_weighted']:.4f}",
        f"  F1-Score (Weighted):     {metrics['f1_weighted']:.4f}",
        "\nPer-Class Metrics:",
    ]

    # One sub-section per class, driven by the per-class precision array.
    for class_idx, class_precision in enumerate(metrics['precision_per_class']):
        report += [
            f"  Class {class_idx}:",
            f"    Precision:  {class_precision:.4f}",
            f"    Recall:     {metrics['recall_per_class'][class_idx]:.4f}",
            f"    F1-Score:   {metrics['f1_per_class'][class_idx]:.4f}",
            f"    Support:    {metrics['support_per_class'][class_idx]}",
        ]

    report += [
        "\nConfusion Matrix Components:",
        f"  True Positives:          {metrics['true_positives']}",
        f"  True Negatives:          {metrics['true_negatives']}",
        f"  False Positives:         {metrics['false_positives']}",
        f"  False Negatives:         {metrics['false_negatives']}",
        "\nAdditional Binary Metrics:",
        f"  Sensitivity (TPR):       {metrics['sensitivity']:.4f}",
        f"  Specificity (TNR):       {metrics['specificity']:.4f}",
        f"  False Positive Rate:     {metrics['fpr']:.4f}",
        f"  False Negative Rate:     {metrics['fnr']:.4f}",
        f"  Negative Pred. Value:    {metrics['npv']:.4f}",
        f"  False Discovery Rate:    {metrics['fdr']:.4f}",
        "\nConfusion Matrix:",
        str(metrics['confusion_matrix']),
        rule,
    ]

    print("\n".join(report))


def count_parameters(model):
    """Return ``(trainable, total, frozen)`` parameter element counts for ``model``."""
    trainable = 0
    total = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return trainable, total, total - trainable


def main():
    """Train the TS-PEFT classifier for a fixed number of epochs and report metrics.

    Loads the tokenizer and the two CSV datasets, builds the model, trains
    with AdamW, evaluates after every epoch and prints a full report at the
    end. Paths and hyper-parameters are hard-coded below.

    NOTE(review): the per-epoch "Val" numbers are computed on ``test_loader``,
    the same data used for the final test report — confirm that is intended.
    """
    # Prefer CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f'Device: {device}')

    model_name = 'Salesforce/codet5p-220m'
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    train_dataset = CodeDataset('/traincodex.csv', tokenizer)
    test_dataset = CodeDataset('/testcodex.csv', tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=0)

    print("\nInitializing TS-PEFT Model for Vulnerability Detection...")
    model = TSPEFTVulnerabilityModel(model_name, num_classes=2, rank=32, alpha=0.5, dropout=0.05, s=4e-5, lambda_reg=4.5e-5).to(device)

    trainable_params, total_params, frozen_params = count_parameters(model)

    print(f"\n{'='*80}")
    print(f"MODEL PARAMETERS")
    print(f"{'='*80}")
    print(f"Total Parameters:      {total_params:,}")
    print(f"Trainable Parameters:  {trainable_params:,}")
    print(f"Frozen Parameters:     {frozen_params:,}")
    print(f"Trainable Percentage:  {100 * trainable_params / total_params:.4f}%")
    print(f"{'='*80}\n")

    # Only parameters left trainable (adapters + classifier) are optimised.
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=3e-4,
        weight_decay=0.01
    )

    num_epochs = 5

    print(f"{'='*80}")
    print(f"TRAINING TS-PEFT FOR VULNERABILITY DETECTION")
    print(f"{'='*80}\n")

    best_f1 = 0

    for epoch in range(num_epochs):
        train_loss = train_epoch(model, train_loader, optimizer, device)

        labels, preds, probs = evaluate(model, test_loader, device)
        metrics = compute_comprehensive_metrics(labels, preds, probs)

        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Val Acc: {metrics['accuracy']:.4f} | F1: {metrics['f1_binary']:.4f} | AUC: {metrics['roc_auc_binary']:.4f}")
        print(f"{'-'*80}")

        if metrics['f1_binary'] > best_f1:
            best_f1 = metrics['f1_binary']

    # The best per-epoch score was tracked but never surfaced; report it.
    # No checkpoint is kept, so the final report below uses the last-epoch weights.
    print(f"\nBest epoch F1 (binary): {best_f1:.4f}")

    print(f"\n{'='*80}")
    print(f"FINAL TEST EVALUATION")
    print(f"{'='*80}\n")

    labels, preds, probs = evaluate(model, test_loader, device)
    final_metrics = compute_comprehensive_metrics(labels, preds, probs)

    print_comprehensive_metrics(final_metrics, phase="Test")

    print(f"\n{'='*80}")
    print(f"TS-PEFT VULNERABILITY DETECTION COMPLETED")
    print(f"{'='*80}")


if __name__ == "__main__":
    main()

# LoRA

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import pandas as pd
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support, roc_auc_score, 
    confusion_matrix, classification_report, average_precision_score,
    matthews_corrcoef, cohen_kappa_score, balanced_accuracy_score
)
from tqdm import tqdm
import numpy as np
import warnings
warnings.filterwarnings('ignore')


class CodeDataset(Dataset):
    """Map-style dataset that tokenizes code samples read from a CSV file.

    The code text is taken from the ``func`` column and the class from the
    ``label`` column when those columns exist. Otherwise the first column is
    used as code and the second as label; a CSV without a second column
    yields label 0 for every row.
    """

    def __init__(self, csv_path, tokenizer, max_length=512):
        """Load the CSV, resolve the code/label columns and store the tokenizer.

        Args:
            csv_path: Path passed to ``pandas.read_csv``.
            tokenizer: Callable applied to each code string; it receives the
                ``max_length``, ``padding``, ``truncation`` and
                ``return_tensors`` keyword arguments.
            max_length: Length every sample is padded/truncated to.
        """
        self.data = pd.read_csv(csv_path)
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Prefer named columns so an extra leading column (e.g. a saved index)
        # is not mistaken for the code text; fall back to the positional
        # layout for CSVs that do not use these names.
        names = list(self.data.columns)
        self._code_pos = names.index('func') if 'func' in names else 0
        if 'label' in names:
            self._label_pos = names.index('label')
        elif len(names) > 1:
            self._label_pos = 1
        else:
            self._label_pos = None

    def __len__(self):
        """Return the number of rows in the loaded CSV."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize row ``idx`` and return its tensors together with the label.

        Returns:
            dict with ``input_ids``, ``attention_mask`` and ``labels`` (a
            scalar ``torch.long`` tensor).
        """
        row = self.data.iloc[idx]
        code = str(row.iloc[self._code_pos])
        label = int(row.iloc[self._label_pos]) if self._label_pos is not None else 0

        encoding = self.tokenizer(
            code,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.long)
        }


class LoRALayer(nn.Module):
    """Low-rank adapter computing ``scaling * x @ A @ B``.

    ``lora_A`` has shape ``(in_features, rank)`` and ``lora_B`` has shape
    ``(rank, out_features)``. ``lora_B`` starts at zero, so the adapter
    contributes nothing until it is trained.
    """

    def __init__(self, in_features, out_features, rank=8, alpha=16):
        """Create the two low-rank factors.

        Args:
            in_features: Size of the last dimension of the input.
            out_features: Size of the last dimension of the output.
            rank: Inner dimension of the factorisation.
            alpha: Numerator of the output scale ``alpha / rank``.
        """
        super().__init__()
        self.rank = rank
        self.alpha = alpha
        self.scaling = self.alpha / self.rank

        self.lora_A = nn.Parameter(torch.zeros(in_features, rank))
        self.lora_B = nn.Parameter(torch.zeros(rank, out_features))

        # lora_A is stored as (in_features, rank). For a 2-D tensor PyTorch's
        # default 'fan_in' mode reads size(1), i.e. `rank`, which makes the
        # initial weights far larger than in the usual (rank, in_features)
        # layout. 'fan_out' reads size(0) = in_features, giving the standard
        # 1/sqrt(in_features) bound.
        nn.init.kaiming_uniform_(self.lora_A, a=np.sqrt(5), mode='fan_out')
        nn.init.zeros_(self.lora_B)

    def forward(self, x):
        """Return the scaled low-rank update for ``x``."""
        return (x @ self.lora_A @ self.lora_B) * self.scaling


class LoRALinear(nn.Module):
    """Frozen linear layer whose output is summed with a LoRA update."""

    def __init__(self, linear_layer, rank=8, alpha=16):
        """Wrap ``linear_layer`` and freeze its parameters.

        Args:
            linear_layer: Module exposing ``in_features``/``out_features``.
            rank: Rank forwarded to ``LoRALayer``.
            alpha: Scale numerator forwarded to ``LoRALayer``.
        """
        super().__init__()
        self.linear = linear_layer
        self.lora = LoRALayer(
            linear_layer.in_features,
            linear_layer.out_features,
            rank=rank,
            alpha=alpha,
        )

        # Only the adapter stays trainable.
        self.linear.requires_grad_(False)

    def forward(self, x):
        """Return ``linear(x) + lora(x)``."""
        frozen_out = self.linear(x)
        return frozen_out + self.lora(x)


class LoRATuningModel(nn.Module):
    """Sequence classifier built on a frozen seq2seq encoder with LoRA adapters.

    The query and value projections of every encoder block's self-attention
    are wrapped in ``LoRALinear``; the encoder state at position 0 is fed to
    a small MLP classification head.
    """

    def __init__(self, base_model_name, num_classes=2, rank=8, alpha=16):
        """Load the backbone, freeze it, and install the adapters and the head.

        Args:
            base_model_name: Identifier passed to
                ``AutoModelForSeq2SeqLM.from_pretrained``.
            num_classes: Output size of the classification head.
            rank: LoRA rank for every wrapped projection.
            alpha: LoRA scale numerator for every wrapped projection.
        """
        super().__init__()

        self.base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)

        # Freeze the backbone first; adapters added afterwards stay trainable.
        for frozen_param in self.base_model.parameters():
            frozen_param.requires_grad = False

        self.d_model = self.base_model.config.d_model

        self._apply_lora_to_model(rank, alpha)

        self.classifier = nn.Sequential(
            nn.Linear(self.d_model, 256),
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )

    def _apply_lora_to_model(self, rank, alpha):
        """Replace the q and v projections of each encoder block with ``LoRALinear``."""
        for block in self.base_model.get_encoder().block:
            attention = block.layer[0].SelfAttention
            for proj_name in ('q', 'v'):
                original_proj = getattr(attention, proj_name)
                setattr(attention, proj_name, LoRALinear(original_proj, rank=rank, alpha=alpha))

    def forward(self, input_ids, attention_mask):
        """Return class logits computed from the encoder state at position 0."""
        encoder_out = self.base_model.get_encoder()(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )

        first_token_state = encoder_out.last_hidden_state[:, 0, :]
        return self.classifier(first_token_state)


def train_epoch(model, dataloader, optimizer, criterion, device):
    """Run one training pass and return ``(mean loss, training accuracy)``.

    Per batch: forward, loss, backward, gradient-norm clipping to 1.0, then
    the optimizer step. Accuracy is computed from the predictions made during
    the pass.

    Args:
        model: Module returning logits for ``(input_ids, attention_mask)``.
        dataloader: Iterable of dicts with ``input_ids``, ``attention_mask``
            and ``labels`` tensors; must support ``len``.
        optimizer: Optimizer stepping the trainable parameters.
        criterion: Loss callable applied to ``(logits, labels)``.
        device: Device the batch tensors are moved to.
    """
    model.train()
    running_loss = 0
    epoch_preds, epoch_labels = [], []

    for batch in tqdm(dataloader, desc="Training"):
        input_ids, attention_mask, labels = (
            batch[field].to(device)
            for field in ('input_ids', 'attention_mask', 'labels')
        )

        optimizer.zero_grad()

        logits = model(input_ids, attention_mask)
        loss = criterion(logits, labels)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        running_loss += loss.item()

        batch_preds = torch.argmax(logits, dim=1)
        epoch_preds.extend(batch_preds.cpu().numpy())
        epoch_labels.extend(labels.cpu().numpy())

    mean_loss = running_loss / len(dataloader)
    return mean_loss, accuracy_score(epoch_labels, epoch_preds)


def evaluate(model, dataloader, criterion, device):
    """Run inference over ``dataloader`` and collect loss, labels and predictions.

    The model is switched to eval mode and gradients are disabled.

    Returns:
        tuple ``(mean loss, labels, predictions, probabilities)`` where the
        last three are NumPy arrays in dataloader order.
    """
    model.eval()
    running_loss = 0
    gathered_preds, gathered_labels, gathered_probs = [], [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids, attention_mask, labels = (
                batch[field].to(device)
                for field in ('input_ids', 'attention_mask', 'labels')
            )

            logits = model(input_ids, attention_mask)
            running_loss += criterion(logits, labels).item()

            class_probs = F.softmax(logits, dim=1)
            class_ids = torch.argmax(logits, dim=1)

            gathered_preds.extend(class_ids.cpu().numpy())
            gathered_labels.extend(labels.cpu().numpy())
            gathered_probs.extend(class_probs.cpu().numpy())

    mean_loss = running_loss / len(dataloader)

    return (
        mean_loss,
        np.array(gathered_labels),
        np.array(gathered_preds),
        np.array(gathered_probs),
    )


def compute_comprehensive_metrics(labels, preds, probs):
    """Compute a dictionary of binary-classification metrics.

    Args:
        labels: Ground-truth class ids (the binary metrics assume 0/1).
        preds: Predicted class ids.
        probs: Per-class scores, one row per sample; column 1 is used as the
            positive-class score.

    Returns:
        dict with accuracy, balanced accuracy, macro/binary/weighted and
        per-class precision/recall/F1, ROC-AUC and PR-AUC (0.5 when the score
        cannot be computed), the 2x2 confusion matrix and the rates derived
        from it, MCC and Cohen's kappa.
    """
    probs = np.asarray(probs)
    metrics = {}

    metrics['accuracy'] = accuracy_score(labels, preds)
    metrics['balanced_accuracy'] = balanced_accuracy_score(labels, preds)

    for average in ('macro', 'binary', 'weighted'):
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, preds, average=average, zero_division=0
        )
        metrics[f'precision_{average}'] = precision
        metrics[f'recall_{average}'] = recall
        metrics[f'f1_{average}'] = f1

    prec_per_class, rec_per_class, f1_per_class, support = precision_recall_fscore_support(
        labels, preds, average=None, zero_division=0
    )
    metrics['precision_per_class'] = prec_per_class
    metrics['recall_per_class'] = rec_per_class
    metrics['f1_per_class'] = f1_per_class
    metrics['support_per_class'] = support

    # Only the failures these scorers are expected to raise (e.g. a single
    # class present, or a score matrix without a second column) fall back to
    # 0.5; anything else is a real bug and must surface.
    score_errors = (ValueError, IndexError)

    try:
        metrics['roc_auc_binary'] = roc_auc_score(labels, probs[:, 1])
    except score_errors:
        metrics['roc_auc_binary'] = 0.5

    try:
        if probs.ndim == 2 and probs.shape[1] == 2:
            # For binary labels sklearn requires 1-D positive-class scores;
            # passing the 2-column matrix raises and used to be silently
            # reported as 0.5. With two complementary columns the
            # one-vs-rest macro average equals the positive-class AUC.
            metrics['roc_auc_macro'] = roc_auc_score(labels, probs[:, 1])
        else:
            metrics['roc_auc_macro'] = roc_auc_score(
                labels, probs, multi_class='ovr', average='macro'
            )
    except score_errors:
        metrics['roc_auc_macro'] = 0.5

    try:
        metrics['pr_auc'] = average_precision_score(labels, probs[:, 1])
    except score_errors:
        metrics['pr_auc'] = 0.5

    # Pin the label set so the matrix is always 2x2, even when only one class
    # occurs in labels/preds; otherwise the 4-way unpack below fails.
    metrics['confusion_matrix'] = confusion_matrix(labels, preds, labels=[0, 1])

    metrics['mcc'] = matthews_corrcoef(labels, preds)

    metrics['cohen_kappa'] = cohen_kappa_score(labels, preds)

    tn, fp, fn, tp = metrics['confusion_matrix'].ravel()
    metrics['true_negatives'] = tn
    metrics['false_positives'] = fp
    metrics['false_negatives'] = fn
    metrics['true_positives'] = tp

    # Each rate falls back to 0 when its denominator is empty.
    metrics['specificity'] = tn / (tn + fp) if (tn + fp) > 0 else 0
    metrics['sensitivity'] = tp / (tp + fn) if (tp + fn) > 0 else 0
    metrics['fpr'] = fp / (fp + tn) if (fp + tn) > 0 else 0
    metrics['fnr'] = fn / (fn + tp) if (fn + tp) > 0 else 0
    metrics['npv'] = tn / (tn + fn) if (tn + fn) > 0 else 0
    metrics['fdr'] = fp / (fp + tp) if (fp + tp) > 0 else 0

    return metrics


def print_comprehensive_metrics(metrics, phase="Test"):
    """Print a formatted report of the metrics dictionary to stdout.

    Args:
        metrics: Dictionary with the keys produced by
            ``compute_comprehensive_metrics``.
        phase: Label shown (upper-cased) in the report header.
    """
    rule = '=' * 80

    report = [
        f"\n{rule}",
        f"{phase.upper()} EVALUATION METRICS",
        f"{rule}\n",
        "Overall Metrics:",
        f"  Accuracy:                {metrics['accuracy']:.4f}",
        f"  Balanced Accuracy:       {metrics['balanced_accuracy']:.4f}",
        f"  Matthews Correlation:    {metrics['mcc']:.4f}",
        f"  Cohen's Kappa:           {metrics['cohen_kappa']:.4f}",
        "\nMacro-Averaged Metrics:",
        f"  Precision (Macro):       {metrics['precision_macro']:.4f}",
        f"  Recall (Macro):          {metrics['recall_macro']:.4f}",
        f"  F1-Score (Macro):        {metrics['f1_macro']:.4f}",
        f"  ROC-AUC (Macro):         {metrics['roc_auc_macro']:.4f}",
        "\nBinary Metrics:",
        f"  Precision (Binary):      {metrics['precision_binary']:.4f}",
        f"  Recall (Binary):         {metrics['recall_binary']:.4f}",
        f"  F1-Score (Binary):       {metrics['f1_binary']:.4f}",
        f"  ROC-AUC (Binary):        {metrics['roc_auc_binary']:.4f}",
        f"  PR-AUC:                  {metrics['pr_auc']:.4f}",
        "\nWeighted Metrics:",
        f"  Precision (Weighted):    {metrics['precision_weighted']:.4f}",
        f"  Recall (Weighted):       {metrics['recall_weighted']:.4f}",
        f"  F1-Score (Weighted):     {metrics['f1_weighted']:.4f}",
        "\nPer-Class Metrics:",
    ]

    # One sub-section per class, driven by the per-class precision array.
    for class_idx, class_precision in enumerate(metrics['precision_per_class']):
        report += [
            f"  Class {class_idx}:",
            f"    Precision:  {class_precision:.4f}",
            f"    Recall:     {metrics['recall_per_class'][class_idx]:.4f}",
            f"    F1-Score:   {metrics['f1_per_class'][class_idx]:.4f}",
            f"    Support:    {metrics['support_per_class'][class_idx]}",
        ]

    report += [
        "\nConfusion Matrix Components:",
        f"  True Positives:          {metrics['true_positives']}",
        f"  True Negatives:          {metrics['true_negatives']}",
        f"  False Positives:         {metrics['false_positives']}",
        f"  False Negatives:         {metrics['false_negatives']}",
        "\nAdditional Binary Metrics:",
        f"  Sensitivity (TPR):       {metrics['sensitivity']:.4f}",
        f"  Specificity (TNR):       {metrics['specificity']:.4f}",
        f"  False Positive Rate:     {metrics['fpr']:.4f}",
        f"  False Negative Rate:     {metrics['fnr']:.4f}",
        f"  Negative Pred. Value:    {metrics['npv']:.4f}",
        f"  False Discovery Rate:    {metrics['fdr']:.4f}",
        "\nConfusion Matrix:",
        str(metrics['confusion_matrix']),
        rule,
    ]

    print("\n".join(report))


def count_parameters(model):
    """Return ``(trainable, total)`` parameter element counts for ``model``."""
    trainable = 0
    total = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return trainable, total


def main():
    """Train the LoRA classifier for a fixed number of epochs and report metrics.

    Loads the tokenizer and the two CSV datasets, builds the model, trains
    with AdamW and cross-entropy, evaluates after every epoch and prints a
    full report at the end. Paths and hyper-parameters are hard-coded below.

    NOTE(review): the per-epoch "Val" numbers are computed on ``test_loader``,
    the same data used for the final test report — confirm that is intended.
    """
    # Prefer CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f'Device: {device}')

    model_name = 'Salesforce/codet5p-220m'
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    train_dataset = CodeDataset('/traincodex.csv', tokenizer)
    test_dataset = CodeDataset('/testcodex.csv', tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=0)

    print("\nInitializing LoRA Tuning Model...")
    model = LoRATuningModel(model_name, num_classes=2, rank=8, alpha=16).to(device)

    trainable_params, total_params = count_parameters(model)

    print(f"\n{'='*80}")
    print(f"MODEL PARAMETERS")
    print(f"{'='*80}")
    print(f"Total Parameters:      {total_params:,}")
    print(f"Trainable Parameters:  {trainable_params:,}")
    print(f"Trainable Percentage:  {100 * trainable_params / total_params:.4f}%")
    print(f"{'='*80}\n")

    criterion = nn.CrossEntropyLoss()
    # Only parameters left trainable (adapters + classifier) are optimised.
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=3e-4,
        weight_decay=0.01
    )

    num_epochs = 5

    print(f"{'='*80}")
    print(f"TRAINING LORA TUNING")
    print(f"{'='*80}\n")

    best_f1 = 0

    for epoch in range(num_epochs):
        train_loss, train_acc = train_epoch(
            model, train_loader, optimizer, criterion, device
        )

        val_loss, val_labels, val_preds, val_probs = evaluate(
            model, test_loader, criterion, device
        )

        val_metrics = compute_comprehensive_metrics(val_labels, val_preds, val_probs)

        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
        print(f"  Val Loss:   {val_loss:.4f} | Val Acc:   {val_metrics['accuracy']:.4f} | Val F1 (Macro): {val_metrics['f1_macro']:.4f}")
        print(f"{'-'*80}")

        if val_metrics['f1_macro'] > best_f1:
            best_f1 = val_metrics['f1_macro']

    # The best per-epoch score was tracked but never surfaced; report it.
    # No checkpoint is kept, so the final report below uses the last-epoch weights.
    print(f"\nBest epoch F1 (macro): {best_f1:.4f}")

    print(f"\n{'='*80}")
    print(f"FINAL TEST EVALUATION")
    print(f"{'='*80}\n")

    test_loss, test_labels, test_preds, test_probs = evaluate(
        model, test_loader, criterion, device
    )

    test_metrics = compute_comprehensive_metrics(test_labels, test_preds, test_probs)

    print_comprehensive_metrics(test_metrics, phase="Test")

    print(f"\n{'='*80}")
    print(f"LORA TUNING COMPLETED")
    print(f"{'='*80}")


if __name__ == "__main__":
    main()

# GateRA

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import pandas as pd
from tqdm import tqdm
import numpy as np
import math
import warnings
from sklearn.metrics import (
    accuracy_score, balanced_accuracy_score, precision_recall_fscore_support,
    roc_auc_score, average_precision_score, confusion_matrix,
    matthews_corrcoef, cohen_kappa_score
)
warnings.filterwarnings('ignore')


class CodeDataset(Dataset):
    """Map-style dataset that tokenizes one CSV row per item.

    The CSV at ``csv_path`` is read eagerly with pandas. Each row must offer
    a ``func`` column (converted to ``str`` and tokenized) and a ``label``
    column (converted to ``int``). Tokenization is deferred to
    ``__getitem__`` and requests ``max_length`` padding plus truncation.
    """

    def __init__(self, csv_path, tokenizer, max_length=512):
        self.data = pd.read_csv(csv_path)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` for row ``idx``."""
        record = self.data.iloc[idx]
        source_text = str(record['func'])
        class_id = int(record['label'])

        tokens = self.tokenizer(
            source_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        # The tokenizer output carries a leading axis of size one; drop it so
        # the DataLoader can stack samples itself.
        sample = {
            field: tokens[field].squeeze(0)
            for field in ('input_ids', 'attention_mask')
        }
        sample['labels'] = torch.tensor(class_id, dtype=torch.long)
        return sample


class GatingModule(nn.Module):
    """Scalar sigmoid gate computed from the last axis of the input.

    One linear unit maps ``input_dim`` features to a single logit. Both its
    weight and bias are zero-initialised, so a freshly built gate outputs
    0.5 for every input.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.gate_linear = nn.Linear(input_dim, 1, bias=True)
        for tensor in (self.gate_linear.weight, self.gate_linear.bias):
            nn.init.zeros_(tensor)

    def forward(self, x):
        """Return gate values in (0, 1) with the last axis reduced to size 1."""
        return torch.sigmoid(self.gate_linear(x))


class GateRALayer(nn.Module):
    """Wrap a projection layer with a gated low-rank (LoRA-style) update.

    The forward pass returns
    ``base_layer(x) + gate(x) * dropout(x @ lora_A @ lora_B) * (alpha / rank)``
    where ``gate`` is a ``GatingModule`` applied to the same input rows.
    ``lora_B`` is zero-initialised, so the low-rank term is zero right after
    construction.

    Args:
        base_layer: Module whose output the update is added to.
        rank: Inner dimension of the low-rank factors.
        alpha: Numerator of the scaling factor ``alpha / rank``.
        dropout: Dropout probability applied to the low-rank output; values
            that are not greater than 0 disable dropout.
        input_dim: Size of the last axis of the input.
        output_dim: Size of the last axis of the low-rank update.
    """

    def __init__(self, base_layer, rank, alpha, dropout, input_dim, output_dim):
        super().__init__()
        self.base_layer = base_layer
        self.rank = rank
        self.scaling = alpha / rank
        self.dropout = dropout

        self.lora_A = nn.Parameter(torch.zeros(input_dim, rank))
        self.lora_B = nn.Parameter(torch.zeros(rank, output_dim))

        self.gating_module = GatingModule(input_dim)

        if dropout > 0.0:
            self.dropout_layer = nn.Dropout(p=dropout)
        else:
            self.dropout_layer = nn.Identity()

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the factors: Kaiming-uniform ``lora_A``, zero ``lora_B``."""
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def forward(self, x):
        """Add the gated low-rank update to the base layer's output."""
        base_output = self.base_layer(x)

        # Three-dimensional inputs are flattened to rows for the adapter
        # maths and restored afterwards; other ranks pass through untouched.
        is_sequence = x.dim() == 3
        rows = x.reshape(-1, x.shape[2]) if is_sequence else x

        gate = self.gating_module(rows)
        update = self.dropout_layer(rows @ self.lora_A @ self.lora_B)
        update = gate * update * self.scaling

        if is_sequence:
            update = update.reshape(x.shape[0], x.shape[1], -1)

        return base_output + update


class GateRAVulnerabilityModel(nn.Module):
    """Frozen seq2seq encoder with GateRA adapters and a classification head.

    The pretrained model is loaded and frozen, then the ``q`` and ``v``
    projections of every encoder self-attention block are replaced by
    ``GateRALayer`` wrappers. Encoder states are mean-pooled over the
    attention mask and passed to a small MLP classifier.

    During training an entropy penalty on the gate activations is added to
    the cross-entropy loss. The penalty is computed on the gates actually
    produced for the current batch (captured with forward hooks), not on
    synthetic inputs, and positions masked out by ``attention_mask`` are
    excluded from the average.

    Args:
        base_model_name: Identifier passed to
            ``AutoModelForSeq2SeqLM.from_pretrained``.
        num_classes: Number of output logits.
        rank: Low-rank dimension of every adapter.
        alpha: LoRA scaling numerator (scaling is ``alpha / rank``).
        dropout: Dropout probability inside the adapters.
        entropy_reg_weight: Weight of the gate-entropy penalty; ``0``
            disables it.
    """

    def __init__(self, base_model_name, num_classes=2, rank=16, alpha=16.0, dropout=0.0, entropy_reg_weight=0.01):
        super().__init__()

        self.base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)

        # Freeze the backbone before the adapters are attached so that only
        # the parameters created below remain trainable.
        for param in self.base_model.parameters():
            param.requires_grad = False

        config = self.base_model.config
        self.d_model = config.d_model
        self.d_kv = config.d_kv
        self.num_heads = config.num_heads
        self.entropy_reg_weight = entropy_reg_weight

        # Gate activations recorded by the forward hooks during one pass.
        self._gate_cache = []
        self._collect_gates = False

        encoder = self.base_model.get_encoder()
        self.gatera_layers = nn.ModuleDict()
        inner_dim = self.d_kv * self.num_heads

        for i, block in enumerate(encoder.block):
            attention = block.layer[0].SelfAttention
            for proj_name in ('q', 'v'):
                wrapped = GateRALayer(
                    getattr(attention, proj_name), rank=rank, alpha=alpha, dropout=dropout,
                    input_dim=self.d_model, output_dim=inner_dim
                )
                wrapped.gating_module.register_forward_hook(self._record_gate)
                self.gatera_layers[f'encoder_{proj_name}_{i}'] = wrapped
                setattr(attention, proj_name, wrapped)

        self.classifier = nn.Sequential(
            nn.Linear(self.d_model, self.d_model),
            nn.Tanh(),
            nn.Dropout(0.1),
            nn.Linear(self.d_model, num_classes)
        )

    def _record_gate(self, module, inputs, output):
        """Forward hook: remember a gate activation while collection is on."""
        if self._collect_gates:
            self._gate_cache.append(output)

    @staticmethod
    def _binary_entropy(gate_values):
        """Element-wise Bernoulli entropy of ``gate_values``.

        The clamp margin is the machine epsilon of the tensor's dtype. A
        fixed ``1e-8`` is below float32 resolution near 1.0, which let a
        saturated gate evaluate ``0 * log(0)`` and turn the loss into NaN.
        """
        eps = torch.finfo(gate_values.dtype).eps
        gate_values = torch.clamp(gate_values, eps, 1.0 - eps)
        return -gate_values * torch.log(gate_values) - (1 - gate_values) * torch.log(1 - gate_values)

    def compute_entropy_loss(self, gate_values):
        """Return the mean binary entropy of ``gate_values`` (a scalar tensor)."""
        return self._binary_entropy(gate_values).mean()

    def _mean_gate_entropy(self, attention_mask=None):
        """Average the entropy of the recorded gates across adapter layers.

        When ``attention_mask`` flattens to one entry per gate row, masked
        positions are excluded from each layer's average. Returns ``None``
        if no gate was recorded.
        """
        if not self._gate_cache:
            return None

        token_weights = None
        if attention_mask is not None:
            token_weights = attention_mask.reshape(-1, 1)

        per_layer = []
        for gate_values in self._gate_cache:
            entropy = self._binary_entropy(gate_values)
            rows_match = (
                token_weights is not None
                and entropy.dim() == 2
                and entropy.shape[0] == token_weights.shape[0]
            )
            if rows_match:
                weights = token_weights.to(entropy.dtype)
                denom = (weights.sum() * entropy.shape[-1]).clamp(min=1.0)
                per_layer.append((entropy * weights).sum() / denom)
            else:
                per_layer.append(entropy.mean())

        return torch.stack(per_layer).mean()

    def encode_with_gatera(self, input_ids, attention_mask):
        """Run the adapted encoder and mean-pool its states over the mask."""
        encoder = self.base_model.get_encoder()

        outputs = encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )

        hidden_states = outputs.last_hidden_state

        mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        sum_hidden = torch.sum(hidden_states * mask_expanded, dim=1)
        # Guard against division by zero for rows whose mask is all zeros.
        sum_mask = torch.clamp(mask_expanded.sum(dim=1), min=1e-9)
        pooled = sum_hidden / sum_mask

        return pooled

    def forward(self, input_ids, attention_mask, labels=None):
        """Classify a batch.

        Returns:
            ``logits`` when ``labels`` is ``None``; otherwise
            ``(total_loss, logits)`` where ``total_loss`` is the
            cross-entropy plus, in training mode with a positive
            ``entropy_reg_weight``, the weighted gate-entropy penalty.
        """
        regularize = labels is not None and self.training and self.entropy_reg_weight > 0

        self._gate_cache = []
        self._collect_gates = regularize
        try:
            encoded = self.encode_with_gatera(input_ids, attention_mask)
        finally:
            # Never leave collection switched on if the encoder raised.
            self._collect_gates = False

        logits = self.classifier(encoded)

        if labels is None:
            return logits

        total_loss = F.cross_entropy(logits, labels)

        if regularize:
            gate_entropy = self._mean_gate_entropy(attention_mask)
            if gate_entropy is not None:
                total_loss = total_loss + self.entropy_reg_weight * gate_entropy

        # Drop the references so the activations are not kept alive past
        # this step by the module itself.
        self._gate_cache = []

        return total_loss, logits


def train_epoch(model, dataloader, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    The model is called as ``model(input_ids, attention_mask, labels)`` and
    must return ``(loss, logits)``. Gradients are clipped to a norm of 1.0
    before each optimiser step.

    Returns:
        The mean of the per-batch loss values.
    """
    model.train()
    batch_losses = []

    for batch in tqdm(dataloader, desc="Training"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        loss, _ = model(input_ids, attention_mask, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        batch_losses.append(loss.item())

    return sum(batch_losses) / len(batch_losses)


def evaluate(model, dataloader, device):
    """Collect labels, predictions and class probabilities over ``dataloader``.

    The model is put in eval mode and called without labels, so it must
    return logits only. Gradients are disabled for the whole pass.

    Returns:
        ``(labels, preds, probs)`` as NumPy arrays; ``probs`` holds the
        softmax of the logits along the last axis.
    """
    model.eval()

    collected = {'labels': [], 'preds': [], 'probs': []}

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            logits = model(input_ids, attention_mask)
            batch_values = {
                'labels': labels,
                'preds': torch.argmax(logits, dim=-1),
                'probs': F.softmax(logits, dim=-1),
            }
            for key, tensor in batch_values.items():
                collected[key].extend(tensor.cpu().numpy())

    return (
        np.array(collected['labels']),
        np.array(collected['preds']),
        np.array(collected['probs']),
    )


def compute_comprehensive_metrics(labels, preds, probs):
    """Compute a dictionary of binary-classification metrics.

    Classes are assumed to be ``0`` (negative) and ``1`` (positive); column
    ``1`` of ``probs`` is used as the positive-class score.

    Args:
        labels: Ground-truth class ids, one per sample.
        preds: Predicted class ids, one per sample.
        probs: Per-class probabilities, shape ``(n_samples, n_classes)``.

    Returns:
        dict with accuracy, balanced accuracy, macro / binary / weighted /
        per-class precision-recall-F1, ROC-AUC and PR-AUC, the confusion
        matrix and its four counts, MCC, Cohen's kappa, and the derived
        rates (specificity, sensitivity, FPR, FNR, NPV, FDR).
    """
    probs = np.asarray(probs)
    # Pin the class order so per-class arrays and the confusion matrix keep
    # their shape even when a class is absent from this split.
    class_ids = [0, 1]
    metrics = {}

    metrics['accuracy'] = accuracy_score(labels, preds)
    metrics['balanced_accuracy'] = balanced_accuracy_score(labels, preds)

    prec_macro, rec_macro, f1_macro, _ = precision_recall_fscore_support(
        labels, preds, average='macro', zero_division=0
    )
    metrics['precision_macro'] = prec_macro
    metrics['recall_macro'] = rec_macro
    metrics['f1_macro'] = f1_macro

    prec_binary, rec_binary, f1_binary, _ = precision_recall_fscore_support(
        labels, preds, average='binary', zero_division=0
    )
    metrics['precision_binary'] = prec_binary
    metrics['recall_binary'] = rec_binary
    metrics['f1_binary'] = f1_binary

    prec_weighted, rec_weighted, f1_weighted, _ = precision_recall_fscore_support(
        labels, preds, average='weighted', zero_division=0
    )
    metrics['precision_weighted'] = prec_weighted
    metrics['recall_weighted'] = rec_weighted
    metrics['f1_weighted'] = f1_weighted

    prec_per_class, rec_per_class, f1_per_class, support = precision_recall_fscore_support(
        labels, preds, average=None, labels=class_ids, zero_division=0
    )
    metrics['precision_per_class'] = prec_per_class
    metrics['recall_per_class'] = rec_per_class
    metrics['f1_per_class'] = f1_per_class
    metrics['support_per_class'] = support

    # The AUC scores are undefined when only one class is present (sklearn
    # raises ValueError); 0.5 is kept as the chance-level placeholder.
    try:
        metrics['roc_auc_binary'] = roc_auc_score(labels, probs[:, 1])
    except (ValueError, IndexError):
        metrics['roc_auc_binary'] = 0.5

    try:
        if probs.ndim == 2 and probs.shape[1] == 2:
            # roc_auc_score rejects a two-column score matrix for binary
            # targets, which made this metric always fall back to 0.5.
            # With complementary scores both one-vs-rest AUCs coincide, so
            # the macro average equals the positive-class AUC.
            metrics['roc_auc_macro'] = roc_auc_score(labels, probs[:, 1], average='macro')
        else:
            metrics['roc_auc_macro'] = roc_auc_score(labels, probs, multi_class='ovr', average='macro')
    except (ValueError, IndexError):
        metrics['roc_auc_macro'] = 0.5

    try:
        metrics['pr_auc'] = average_precision_score(labels, probs[:, 1])
    except (ValueError, IndexError):
        metrics['pr_auc'] = 0.5

    # labels=class_ids guarantees a 2x2 matrix, so the unpack below cannot
    # fail when labels and predictions contain a single class.
    metrics['confusion_matrix'] = confusion_matrix(labels, preds, labels=class_ids)

    metrics['mcc'] = matthews_corrcoef(labels, preds)

    metrics['cohen_kappa'] = cohen_kappa_score(labels, preds)

    tn, fp, fn, tp = metrics['confusion_matrix'].ravel()
    metrics['true_negatives'] = tn
    metrics['false_positives'] = fp
    metrics['false_negatives'] = fn
    metrics['true_positives'] = tp

    # Each rate falls back to 0 when its denominator is empty.
    metrics['specificity'] = tn / (tn + fp) if (tn + fp) > 0 else 0
    metrics['sensitivity'] = tp / (tp + fn) if (tp + fn) > 0 else 0
    metrics['fpr'] = fp / (fp + tn) if (fp + tn) > 0 else 0
    metrics['fnr'] = fn / (fn + tp) if (fn + tp) > 0 else 0
    metrics['npv'] = tn / (tn + fn) if (tn + fn) > 0 else 0
    metrics['fdr'] = fp / (fp + tp) if (fp + tp) > 0 else 0

    return metrics


def print_comprehensive_metrics(metrics, phase="Test"):
    print(f"\n{'='*80}")
    print(f"{phase.upper()} EVALUATION METRICS")
    print(f"{'='*80}\n")
    
    print(f"Overall Metrics:")
    print(f"  Accuracy:                {metrics['accuracy']:.4f}")
    print(f"  Balanced Accuracy:       {metrics['balanced_accuracy']:.4f}")
    print(f"  Matthews Correlation:    {metrics['mcc']:.4f}")
    print(f"  Cohen's Kappa:           {metrics['cohen_kappa']:.4f}")
    
    print(f"\nMacro-Averaged Metrics:")
    print(f"  Precision (Macro):       {metrics['precision_macro']:.4f}")
    print(f"  Recall (Macro):          {metrics['recall_macro']:.4f}")
    print(f"  F1-Score (Macro):        {metrics['f1_macro']:.4f}")
    print(f"  ROC-AUC (Macro):         {metrics['roc_auc_macro']:.4f}")
    
    print(f"\nBinary Metrics:")
    print(f"  Precision (Binary):      {metrics['precision_binary']:.4f}")
    print(f"  Recall (Binary):         {metrics['recall_binary']:.4f}")
    print(f"  F1-Score (Binary):       {metrics['f1_binary']:.4f}")
    print(f"  ROC-AUC (Binary):        {metrics['roc_auc_binary']:.4f}")
    print(f"  PR-AUC:                  {metrics['pr_auc']:.4f}")
    
    print(f"\nWeighted Metrics:")
    print(f"  Precision (Weighted):    {metrics['precision_weighted']:.4f}")
    print(f"  Recall (Weighted):       {metrics['recall_weighted']:.4f}")
    print(f"  F1-Score (Weighted):     {metrics['f1_weighted']:.4f}")
    
    print(f"\nPer-Class Metrics:")
    for i in range(len(metrics['precision_per_class'])):
        print(f"  Class {i}:")
        print(f"    Precision:  {metrics['precision_per_class'][i]:.4f}")
        print(f"    Recall:     {metrics['recall_per_class'][i]:.4f}")
        print(f"    F1-Score:   {metrics['f1_per_class'][i]:.4f}")
        print(f"    Support:    {metrics['support_per_class'][i]}")
    
    print(f"\nConfusion Matrix Components:")
    print(f"  True Positives:          {metrics['true_positives']}")
    print(f"  True Negatives:          {metrics['true_negatives']}")
    print(f"  False Positives:         {metrics['false_positives']}")
    print(f"  False Negatives:         {metrics['false_negatives']}")
    
    print(f"\nAdditional Binary Metrics:")
    print(f"  Sensitivity (TPR):       {metrics['sensitivity']:.4f}")
    print(f"  Specificity (TNR):       {metrics['specificity']:.4f}")
    print(f"  False Positive Rate:     {metrics['fpr']:.4f}")
    print(f"  False Negative Rate:     {metrics['fnr']:.4f}")
    print(f"  Negative Pred. Value:    {metrics['npv']:.4f}")
    print(f"  False Discovery Rate:    {metrics['fdr']:.4f}")
    
    print(f"\nConfusion Matrix:")
    print(metrics['confusion_matrix'])
    print(f"{'='*80}")


def count_parameters(model):
    """Return ``(trainable, total, frozen)`` parameter element counts of ``model``.

    A parameter counts as trainable when its ``requires_grad`` flag is set.
    """
    trainable = 0
    frozen = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
        else:
            frozen += param.numel()
    return trainable, trainable + frozen, frozen


def main():
    """Train the GateRA classifier on the code CSVs and print its metrics.

    Loads the CodeT5+ tokenizer and the two hard-coded CSV files, builds a
    ``GateRAVulnerabilityModel``, trains it for five epochs with AdamW over
    the trainable parameters only, prints a short evaluation after every
    epoch and a full report at the end.

    NOTE: the test split is also used for the per-epoch evaluation, so the
    per-epoch numbers (and the best F1 printed after training) are not
    independent of the final report.
    """
    # Prefer CUDA when present, then Apple MPS, and fall back to the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f'Device: {device}')

    model_name = 'Salesforce/codet5p-220m'
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    train_dataset = CodeDataset('/traincodex.csv', tokenizer)
    test_dataset = CodeDataset('/testcodex.csv', tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=0)

    print("\nInitializing GateRA Model for Vulnerability Detection...")
    model = GateRAVulnerabilityModel(model_name, num_classes=2, rank=16, alpha=16.0, dropout=0.0, entropy_reg_weight=0.01).to(device)

    trainable_params, total_params, frozen_params = count_parameters(model)

    print(f"\n{'='*80}")
    print(f"MODEL PARAMETERS")
    print(f"{'='*80}")
    print(f"Total Parameters:      {total_params:,}")
    print(f"Trainable Parameters:  {trainable_params:,}")
    print(f"Frozen Parameters:     {frozen_params:,}")
    print(f"Trainable Percentage:  {100 * trainable_params / total_params:.4f}%")
    print(f"{'='*80}\n")

    # Only the adapter and classifier parameters are handed to the optimiser.
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=3e-4,
        weight_decay=0.01
    )

    num_epochs = 5

    print(f"{'='*80}")
    print(f"TRAINING GateRA FOR VULNERABILITY DETECTION")
    print(f"{'='*80}\n")

    best_f1 = 0
    best_epoch = None

    for epoch in range(num_epochs):
        train_loss = train_epoch(model, train_loader, optimizer, device)

        labels, preds, probs = evaluate(model, test_loader, device)
        metrics = compute_comprehensive_metrics(labels, preds, probs)

        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Val Acc: {metrics['accuracy']:.4f} | F1: {metrics['f1_binary']:.4f} | AUC: {metrics['roc_auc_binary']:.4f}")
        print(f"{'-'*80}")

        if metrics['f1_binary'] > best_f1:
            best_f1 = metrics['f1_binary']
            best_epoch = epoch + 1

    # The best score used to be tracked but never shown; report it so the
    # per-epoch tracking is actually visible to the user.
    if best_epoch is not None:
        print(f"\nBest Val F1 (Binary): {best_f1:.4f} (epoch {best_epoch})")

    print(f"\n{'='*80}")
    print(f"FINAL TEST EVALUATION")
    print(f"{'='*80}\n")

    labels, preds, probs = evaluate(model, test_loader, device)
    final_metrics = compute_comprehensive_metrics(labels, preds, probs)

    print_comprehensive_metrics(final_metrics, phase="Test")

    print(f"\n{'='*80}")
    print(f"GateRA VULNERABILITY DETECTION COMPLETED")
    print(f"{'='*80}")


# Entry point: run the GateRA pipeline only when this code is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()

# prefix

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import pandas as pd
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support, roc_auc_score, 
    confusion_matrix, classification_report, average_precision_score,
    matthews_corrcoef, cohen_kappa_score, balanced_accuracy_score
)
from tqdm import tqdm
import numpy as np
import warnings
warnings.filterwarnings('ignore')


class CodeDataset(Dataset):
    """Map-style dataset that tokenizes one CSV row per item.

    Column resolution: the source text is taken from the ``func`` column
    and the class id from the ``label`` column when the CSV has them, which
    matches the other training scripts in this file. A column that is
    missing by name falls back to position (first unused column for the
    text, next unused column for the label), so header-less or differently
    named two-column files behave as before. Without any label column every
    sample gets label ``0``.
    """

    def __init__(self, csv_path, tokenizer, max_length=512):
        self.data = pd.read_csv(csv_path)
        self.tokenizer = tokenizer
        self.max_length = max_length

        columns = list(self.data.columns)
        code_pos = columns.index('func') if 'func' in columns else None
        label_pos = columns.index('label') if 'label' in columns else None

        # Positions not already claimed by a named column, in file order.
        unused = [
            pos for pos, _ in enumerate(columns)
            if pos not in (code_pos, label_pos)
        ]
        if code_pos is None:
            code_pos = unused.pop(0) if unused else 0
        if label_pos is None:
            label_pos = unused.pop(0) if unused else None

        self._code_pos = code_pos
        self._label_pos = label_pos

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` for row ``idx``."""
        row = self.data.iloc[idx]
        code = str(row.iloc[self._code_pos])
        label = int(row.iloc[self._label_pos]) if self._label_pos is not None else 0

        encoding = self.tokenizer(
            code,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # The tokenizer output carries a leading axis of size one; drop it.
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.long)
        }


class PrefixEncoder(nn.Module):
    """Produce per-layer key/value prefix tensors from learned embeddings.

    ``prefix_length`` embeddings of width ``prefix_hidden_size`` are passed
    through a two-layer MLP whose output width is
    ``num_layers * 2 * num_heads * head_dim``, then reshaped into one
    stacked key/value tensor per layer.
    """

    def __init__(self, prefix_length, num_layers, num_heads, head_dim, prefix_hidden_size=512):
        super().__init__()
        self.prefix_length = prefix_length
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim

        self.embedding = nn.Embedding(prefix_length, prefix_hidden_size)

        projected_width = num_layers * 2 * num_heads * head_dim
        self.trans = nn.Sequential(
            nn.Linear(prefix_hidden_size, prefix_hidden_size),
            nn.Tanh(),
            nn.Linear(prefix_hidden_size, projected_width)
        )

    def forward(self, batch_size):
        """Return a tuple of ``num_layers`` tensors.

        Each tensor has shape
        ``(2, batch_size, num_heads, prefix_length, head_dim)``; index 0 of
        the first axis is consumed as the key and index 1 as the value by
        the caller.
        """
        device = self.embedding.weight.device
        positions = torch.arange(self.prefix_length).to(device)
        positions = positions.unsqueeze(0).expand(batch_size, -1)

        projected = self.trans(self.embedding(positions))

        stacked = projected.view(
            batch_size,
            self.prefix_length,
            self.num_layers * 2,
            self.num_heads,
            self.head_dim
        )
        # (layers * 2, batch, heads, prefix, head_dim)
        stacked = stacked.permute(2, 0, 3, 1, 4)

        # Chunks of two along the first axis pair each layer's key and value.
        return stacked.split(2)


class PrefixTuningModel(nn.Module):
    """Frozen seq2seq encoder steered by learned key/value prefixes.

    All pretrained weights are frozen; the trainable parts are the
    ``PrefixEncoder`` and a two-layer classification head applied to the
    encoder state at sequence position 0.

    Args:
        base_model_name: Identifier passed to
            ``AutoModelForSeq2SeqLM.from_pretrained``.
        num_classes: Number of output logits.
        prefix_length: Number of prefix positions prepended per layer.
        prefix_hidden_size: Hidden width of the prefix encoder MLP.
    """

    def __init__(self, base_model_name, num_classes=2, prefix_length=10, prefix_hidden_size=512):
        super().__init__()

        self.base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)
        for frozen in self.base_model.parameters():
            frozen.requires_grad = False

        cfg = self.base_model.config
        self.d_model = cfg.d_model
        self.num_layers = cfg.num_layers
        self.num_heads = cfg.num_heads
        self.head_dim = cfg.d_kv

        self.prefix_length = prefix_length

        self.prefix_encoder = PrefixEncoder(
            prefix_length=prefix_length,
            num_layers=self.num_layers,
            num_heads=self.num_heads,
            head_dim=self.head_dim,
            prefix_hidden_size=prefix_hidden_size
        )

        self.dropout = nn.Dropout(cfg.dropout_rate)

        self.classifier = nn.Sequential(
            nn.Linear(self.d_model, 256),
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )

    def get_prompt(self, batch_size):
        """Return one ``(key, value)`` pair per layer from the prefix encoder."""
        return tuple(
            (layer_past[0], layer_past[1])
            for layer_past in self.prefix_encoder(batch_size)
        )

    def forward(self, input_ids, attention_mask):
        """Return classification logits for a batch of token ids."""
        batch_size = input_ids.shape[0]

        prefix_kv = self.get_prompt(batch_size)

        # Prefix positions are always attended to, so their mask is all ones.
        prefix_mask = torch.ones(
            batch_size, self.prefix_length
        ).to(attention_mask.device)
        extended_mask = torch.cat([prefix_mask, attention_mask], dim=1)

        # NOTE(review): this relies on the encoder accepting
        # `past_key_values`; confirm against the installed transformers
        # version.
        encoder_out = self.base_model.get_encoder()(
            input_ids=input_ids,
            attention_mask=extended_mask,
            past_key_values=prefix_kv,
            return_dict=True
        )

        first_token = self.dropout(encoder_out.last_hidden_state[:, 0, :])

        return self.classifier(first_token)


def train_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    The model is called as ``model(input_ids, attention_mask)`` and must
    return logits; ``criterion(logits, labels)`` supplies the loss.
    Gradients are clipped to a norm of 1.0 before each optimiser step.

    Returns:
        ``(avg_loss, accuracy)``: the summed batch losses divided by the
        number of batches, and the accuracy of the argmax predictions made
        during the pass.
    """
    model.train()
    running_loss = 0
    predictions = []
    targets = []

    for batch in tqdm(dataloader, desc="Training"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()

        logits = model(input_ids, attention_mask)
        loss = criterion(logits, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        running_loss += loss.item()
        predictions.extend(torch.argmax(logits, dim=1).cpu().numpy())
        targets.extend(labels.cpu().numpy())

    mean_loss = running_loss / len(dataloader)
    return mean_loss, accuracy_score(targets, predictions)


def evaluate(model, dataloader, criterion, device):
    """Score ``model`` on ``dataloader`` without updating it.

    The model is put in eval mode and gradients are disabled for the pass.

    Returns:
        ``(avg_loss, labels, preds, probs)``: the summed batch losses
        divided by the number of batches, followed by NumPy arrays of the
        ground-truth labels, the argmax predictions and the softmax
        probabilities.
    """
    model.eval()
    running_loss = 0
    predictions, targets, probabilities = [], [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            logits = model(input_ids, attention_mask)
            running_loss += criterion(logits, labels).item()

            batch_probs = F.softmax(logits, dim=1)
            batch_preds = torch.argmax(logits, dim=1)

            predictions.extend(batch_preds.cpu().numpy())
            targets.extend(labels.cpu().numpy())
            probabilities.extend(batch_probs.cpu().numpy())

    mean_loss = running_loss / len(dataloader)

    return mean_loss, np.array(targets), np.array(predictions), np.array(probabilities)


def compute_comprehensive_metrics(labels, preds, probs):
    """Compute a dictionary of binary-classification metrics.

    Classes are assumed to be ``0`` (negative) and ``1`` (positive); column
    ``1`` of ``probs`` is used as the positive-class score.

    Args:
        labels: Ground-truth class ids, one per sample.
        preds: Predicted class ids, one per sample.
        probs: Per-class probabilities, shape ``(n_samples, n_classes)``.

    Returns:
        dict with accuracy, balanced accuracy, macro / binary / weighted /
        per-class precision-recall-F1, ROC-AUC and PR-AUC, the confusion
        matrix and its four counts, MCC, Cohen's kappa, and the derived
        rates (specificity, sensitivity, FPR, FNR, NPV, FDR).
    """
    probs = np.asarray(probs)
    # Pin the class order so per-class arrays and the confusion matrix keep
    # their shape even when a class is absent from this split.
    class_ids = [0, 1]
    metrics = {}

    metrics['accuracy'] = accuracy_score(labels, preds)
    metrics['balanced_accuracy'] = balanced_accuracy_score(labels, preds)

    prec_macro, rec_macro, f1_macro, _ = precision_recall_fscore_support(
        labels, preds, average='macro', zero_division=0
    )
    metrics['precision_macro'] = prec_macro
    metrics['recall_macro'] = rec_macro
    metrics['f1_macro'] = f1_macro

    prec_binary, rec_binary, f1_binary, _ = precision_recall_fscore_support(
        labels, preds, average='binary', zero_division=0
    )
    metrics['precision_binary'] = prec_binary
    metrics['recall_binary'] = rec_binary
    metrics['f1_binary'] = f1_binary

    prec_weighted, rec_weighted, f1_weighted, _ = precision_recall_fscore_support(
        labels, preds, average='weighted', zero_division=0
    )
    metrics['precision_weighted'] = prec_weighted
    metrics['recall_weighted'] = rec_weighted
    metrics['f1_weighted'] = f1_weighted

    prec_per_class, rec_per_class, f1_per_class, support = precision_recall_fscore_support(
        labels, preds, average=None, labels=class_ids, zero_division=0
    )
    metrics['precision_per_class'] = prec_per_class
    metrics['recall_per_class'] = rec_per_class
    metrics['f1_per_class'] = f1_per_class
    metrics['support_per_class'] = support

    # The AUC scores are undefined when only one class is present (sklearn
    # raises ValueError); 0.5 is kept as the chance-level placeholder.
    try:
        metrics['roc_auc_binary'] = roc_auc_score(labels, probs[:, 1])
    except (ValueError, IndexError):
        metrics['roc_auc_binary'] = 0.5

    try:
        if probs.ndim == 2 and probs.shape[1] == 2:
            # roc_auc_score rejects a two-column score matrix for binary
            # targets, which made this metric always fall back to 0.5.
            # With complementary scores both one-vs-rest AUCs coincide, so
            # the macro average equals the positive-class AUC.
            metrics['roc_auc_macro'] = roc_auc_score(labels, probs[:, 1], average='macro')
        else:
            metrics['roc_auc_macro'] = roc_auc_score(labels, probs, multi_class='ovr', average='macro')
    except (ValueError, IndexError):
        metrics['roc_auc_macro'] = 0.5

    try:
        metrics['pr_auc'] = average_precision_score(labels, probs[:, 1])
    except (ValueError, IndexError):
        metrics['pr_auc'] = 0.5

    # labels=class_ids guarantees a 2x2 matrix, so the unpack below cannot
    # fail when labels and predictions contain a single class.
    metrics['confusion_matrix'] = confusion_matrix(labels, preds, labels=class_ids)

    metrics['mcc'] = matthews_corrcoef(labels, preds)

    metrics['cohen_kappa'] = cohen_kappa_score(labels, preds)

    tn, fp, fn, tp = metrics['confusion_matrix'].ravel()
    metrics['true_negatives'] = tn
    metrics['false_positives'] = fp
    metrics['false_negatives'] = fn
    metrics['true_positives'] = tp

    # Each rate falls back to 0 when its denominator is empty.
    metrics['specificity'] = tn / (tn + fp) if (tn + fp) > 0 else 0
    metrics['sensitivity'] = tp / (tp + fn) if (tp + fn) > 0 else 0
    metrics['fpr'] = fp / (fp + tn) if (fp + tn) > 0 else 0
    metrics['fnr'] = fn / (fn + tp) if (fn + tp) > 0 else 0
    metrics['npv'] = tn / (tn + fn) if (tn + fn) > 0 else 0
    metrics['fdr'] = fp / (fp + tp) if (fp + tp) > 0 else 0

    return metrics


def print_comprehensive_metrics(metrics, phase="Test"):
    print(f"\n{'='*80}")
    print(f"{phase.upper()} EVALUATION METRICS")
    print(f"{'='*80}\n")
    
    print(f"Overall Metrics:")
    print(f"  Accuracy:                {metrics['accuracy']:.4f}")
    print(f"  Balanced Accuracy:       {metrics['balanced_accuracy']:.4f}")
    print(f"  Matthews Correlation:    {metrics['mcc']:.4f}")
    print(f"  Cohen's Kappa:           {metrics['cohen_kappa']:.4f}")
    
    print(f"\nMacro-Averaged Metrics:")
    print(f"  Precision (Macro):       {metrics['precision_macro']:.4f}")
    print(f"  Recall (Macro):          {metrics['recall_macro']:.4f}")
    print(f"  F1-Score (Macro):        {metrics['f1_macro']:.4f}")
    print(f"  ROC-AUC (Macro):         {metrics['roc_auc_macro']:.4f}")
    
    print(f"\nBinary Metrics:")
    print(f"  Precision (Binary):      {metrics['precision_binary']:.4f}")
    print(f"  Recall (Binary):         {metrics['recall_binary']:.4f}")
    print(f"  F1-Score (Binary):       {metrics['f1_binary']:.4f}")
    print(f"  ROC-AUC (Binary):        {metrics['roc_auc_binary']:.4f}")
    print(f"  PR-AUC:                  {metrics['pr_auc']:.4f}")
    
    print(f"\nWeighted Metrics:")
    print(f"  Precision (Weighted):    {metrics['precision_weighted']:.4f}")
    print(f"  Recall (Weighted):       {metrics['recall_weighted']:.4f}")
    print(f"  F1-Score (Weighted):     {metrics['f1_weighted']:.4f}")
    
    print(f"\nPer-Class Metrics:")
    for i in range(len(metrics['precision_per_class'])):
        print(f"  Class {i}:")
        print(f"    Precision:  {metrics['precision_per_class'][i]:.4f}")
        print(f"    Recall:     {metrics['recall_per_class'][i]:.4f}")
        print(f"    F1-Score:   {metrics['f1_per_class'][i]:.4f}")
        print(f"    Support:    {metrics['support_per_class'][i]}")
    
    print(f"\nConfusion Matrix Components:")
    print(f"  True Positives:          {metrics['true_positives']}")
    print(f"  True Negatives:          {metrics['true_negatives']}")
    print(f"  False Positives:         {metrics['false_positives']}")
    print(f"  False Negatives:         {metrics['false_negatives']}")
    
    print(f"\nAdditional Binary Metrics:")
    print(f"  Sensitivity (TPR):       {metrics['sensitivity']:.4f}")
    print(f"  Specificity (TNR):       {metrics['specificity']:.4f}")
    print(f"  False Positive Rate:     {metrics['fpr']:.4f}")
    print(f"  False Negative Rate:     {metrics['fnr']:.4f}")
    print(f"  Negative Pred. Value:    {metrics['npv']:.4f}")
    print(f"  False Discovery Rate:    {metrics['fdr']:.4f}")
    
    print(f"\nConfusion Matrix:")
    print(metrics['confusion_matrix'])
    print(f"{'='*80}")


def count_parameters(model):
    """Return ``(trainable, total)`` parameter element counts of ``model``.

    A parameter counts as trainable when its ``requires_grad`` flag is set.
    """
    trainable = 0
    total = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return trainable, total


def main():
    """Train the prefix-tuning classifier on the code CSVs and print its metrics.

    Loads the CodeT5+ tokenizer and the two hard-coded CSV files, builds a
    ``PrefixTuningModel``, trains it for five epochs with AdamW over the
    trainable parameters only, prints a short evaluation after every epoch
    and a full report at the end.

    NOTE: the test split is also used for the per-epoch evaluation, so the
    per-epoch numbers (and the best F1 printed after training) are not
    independent of the final report.
    """
    # Prefer CUDA when present, then Apple MPS, and fall back to the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f'Device: {device}')

    model_name = 'Salesforce/codet5p-220m'
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    train_dataset = CodeDataset('/traincodex.csv', tokenizer)
    test_dataset = CodeDataset('/testcodex.csv', tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=0)

    print("\nInitializing Prefix Tuning Model...")
    model = PrefixTuningModel(model_name, num_classes=2, prefix_length=10, prefix_hidden_size=512).to(device)

    trainable_params, total_params = count_parameters(model)

    print(f"\n{'='*80}")
    print(f"MODEL PARAMETERS")
    print(f"{'='*80}")
    print(f"Total Parameters:      {total_params:,}")
    print(f"Trainable Parameters:  {trainable_params:,}")
    print(f"Trainable Percentage:  {100 * trainable_params / total_params:.4f}%")
    print(f"{'='*80}\n")

    criterion = nn.CrossEntropyLoss()
    # Only the prefix encoder and classifier parameters are optimised.
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=3e-4,
        weight_decay=0.01
    )

    num_epochs = 5

    print(f"{'='*80}")
    print(f"TRAINING PREFIX TUNING")
    print(f"{'='*80}\n")

    best_f1 = 0
    best_epoch = None

    for epoch in range(num_epochs):
        train_loss, train_acc = train_epoch(
            model, train_loader, optimizer, criterion, device
        )

        val_loss, val_labels, val_preds, val_probs = evaluate(
            model, test_loader, criterion, device
        )

        val_metrics = compute_comprehensive_metrics(val_labels, val_preds, val_probs)

        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
        print(f"  Val Loss:   {val_loss:.4f} | Val Acc:   {val_metrics['accuracy']:.4f} | Val F1 (Macro): {val_metrics['f1_macro']:.4f}")
        print(f"{'-'*80}")

        if val_metrics['f1_macro'] > best_f1:
            best_f1 = val_metrics['f1_macro']
            best_epoch = epoch + 1

    # The best score used to be tracked but never shown; report it so the
    # per-epoch tracking is actually visible to the user.
    if best_epoch is not None:
        print(f"\nBest Val F1 (Macro): {best_f1:.4f} (epoch {best_epoch})")

    print(f"\n{'='*80}")
    print(f"FINAL TEST EVALUATION")
    print(f"{'='*80}\n")

    test_loss, test_labels, test_preds, test_probs = evaluate(
        model, test_loader, criterion, device
    )

    test_metrics = compute_comprehensive_metrics(test_labels, test_preds, test_probs)

    print_comprehensive_metrics(test_metrics, phase="Test")

    print(f"\n{'='*80}")
    print(f"PREFIX TUNING COMPLETED")
    print(f"{'='*80}")


# Entry point: run the prefix-tuning pipeline only when this code is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()

# Adapter tuning

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import pandas as pd
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support, roc_auc_score, 
    confusion_matrix, classification_report, average_precision_score,
    matthews_corrcoef, cohen_kappa_score, balanced_accuracy_score
)
from tqdm import tqdm
import numpy as np
import warnings
warnings.filterwarnings('ignore')


class CodeDataset(Dataset):
    """Dataset that tokenizes one CSV row per item, on demand.

    Columns are read by position: column 0 holds the text to tokenize and
    column 1 the integer class label. Rows with a single column get label 0.
    """

    def __init__(self, csv_path, tokenizer, max_length=512):
        self.data = pd.read_csv(csv_path)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data.iloc[idx]
        text = str(record.iloc[0])
        target = int(record.iloc[1]) if len(record) > 1 else 0

        tokens = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        # squeeze(0) removes the size-one leading axis so the DataLoader can
        # stack individual samples into a batch.
        sample = {
            key: tokens[key].squeeze(0)
            for key in ('input_ids', 'attention_mask')
        }
        sample['labels'] = torch.tensor(target, dtype=torch.long)
        return sample


class AdapterModule(nn.Module):
    """Bottleneck adapter: down-projection, non-linearity, up-projection, skip.

    Args:
        input_dim: Size of the last axis of the input (and of the output).
        bottleneck_dim: Width of the intermediate projection.
        non_linearity: 'gelu' or 'relu'; any other value selects Tanh.
    """

    def __init__(self, input_dim, bottleneck_dim, non_linearity='gelu'):
        super().__init__()
        self.down_project = nn.Linear(input_dim, bottleneck_dim)
        self.up_project = nn.Linear(bottleneck_dim, input_dim)

        activation_cls = {'gelu': nn.GELU, 'relu': nn.ReLU}.get(non_linearity, nn.Tanh)
        self.activation = activation_cls()

        # Tiny weights and zero biases make the projection branch contribute
        # almost nothing at first, so the module starts close to identity.
        projections = (self.down_project, self.up_project)
        for projection in projections:
            nn.init.normal_(projection.weight, std=1e-3)
        for projection in projections:
            nn.init.zeros_(projection.bias)

    def forward(self, x):
        bottleneck = self.activation(self.down_project(x))
        return self.up_project(bottleneck) + x


class AdapterTuningModel(nn.Module):
    """Frozen seq2seq encoder with trainable adapters and a classifier head.

    Two ``AdapterModule`` instances are created per encoder block: one is
    applied to the output of the block's first sub-layer and one to the
    output of its last sub-layer. They are attached with forward hooks so
    they run *inside* the encoder. Applying them afterwards to the tuple of
    returned hidden states would let only the final pair influence the
    logits, leaving every other adapter without gradients.

    Args:
        base_model_name: Checkpoint name/path passed to
            ``AutoModelForSeq2SeqLM.from_pretrained``.
        num_classes: Number of output logits.
        reduction_factor: ``d_model // reduction_factor`` is the adapter
            bottleneck width (clamped to at least 1).
        non_linearity: Activation name forwarded to ``AdapterModule``.
    """

    def __init__(self, base_model_name, num_classes=2, reduction_factor=16, non_linearity='gelu'):
        super().__init__()

        self.base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)

        # The pretrained weights stay fixed; only adapters and the head train.
        for param in self.base_model.parameters():
            param.requires_grad = False

        config = self.base_model.config
        self.d_model = config.d_model
        self.num_layers = config.num_layers
        self.reduction_factor = reduction_factor

        # Avoid a zero-width bottleneck when reduction_factor exceeds d_model.
        self.bottleneck_dim = max(1, self.d_model // reduction_factor)

        self.adapters = nn.ModuleList(
            AdapterModule(self.d_model, self.bottleneck_dim, non_linearity)
            for _ in range(self.num_layers * 2)
        )
        # Handles are kept so the hooks can be removed later if needed.
        self._adapter_hook_handles = []
        self._attach_adapters()

        self.dropout = nn.Dropout(config.dropout_rate)

        self.classifier = nn.Sequential(
            nn.Linear(self.d_model, 256),
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )

    @staticmethod
    def _make_adapter_hook(adapter):
        """Build a forward hook that routes a sub-layer's output through ``adapter``."""
        def hook(_module, _inputs, output):
            # Some sub-layers return a tuple whose first item is the hidden
            # states; the remaining items are passed through untouched.
            if isinstance(output, tuple):
                return (adapter(output[0]),) + tuple(output[1:])
            return adapter(output)
        return hook

    def _attach_adapters(self):
        """Register one adapter after the first and last sub-layer of every encoder block."""
        encoder = self.base_model.get_encoder()
        adapter_pairs = zip(self.adapters[0::2], self.adapters[1::2])
        for block, (first_adapter, last_adapter) in zip(encoder.block, adapter_pairs):
            self._adapter_hook_handles.append(
                block.layer[0].register_forward_hook(self._make_adapter_hook(first_adapter))
            )
            self._adapter_hook_handles.append(
                block.layer[-1].register_forward_hook(self._make_adapter_hook(last_adapter))
            )

    def forward(self, input_ids, attention_mask):
        """Return class logits computed from the first token's encoder state."""
        encoder = self.base_model.get_encoder()

        outputs = encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )

        # Position 0 of the final hidden states is used as the sequence summary.
        pooled_output = outputs.last_hidden_state[:, 0, :]
        pooled_output = self.dropout(pooled_output)

        return self.classifier(pooled_output)


def train_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimization pass over ``dataloader``.

    Each batch is a dict with 'input_ids', 'attention_mask' and 'labels'.

    Returns:
        ``(avg_loss, accuracy)``: the mean of the per-batch losses and the
        accuracy of the argmax predictions gathered during the pass.
    """
    model.train()
    running_loss = 0
    predictions, targets = [], []

    for batch in tqdm(dataloader, desc="Training"):
        input_ids, attention_mask, labels = (
            batch[key].to(device)
            for key in ('input_ids', 'attention_mask', 'labels')
        )

        optimizer.zero_grad()
        logits = model(input_ids, attention_mask)
        loss = criterion(logits, labels)
        loss.backward()
        # Cap the global gradient norm at 1.0 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        running_loss += loss.item()
        predictions.extend(torch.argmax(logits, dim=1).cpu().numpy())
        targets.extend(labels.cpu().numpy())

    return running_loss / len(dataloader), accuracy_score(targets, predictions)


def evaluate(model, dataloader, criterion, device):
    """Score ``model`` on ``dataloader`` without tracking gradients.

    Returns:
        ``(avg_loss, labels, preds, probs)`` where ``avg_loss`` is the mean of
        the per-batch losses and the other three are NumPy arrays holding the
        true labels, argmax predictions and softmax probabilities.
    """
    model.eval()
    running_loss = 0
    targets, predictions, probabilities = [], [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids, attention_mask, labels = (
                batch[key].to(device)
                for key in ('input_ids', 'attention_mask', 'labels')
            )

            logits = model(input_ids, attention_mask)
            running_loss += criterion(logits, labels).item()

            predictions.extend(logits.argmax(dim=1).cpu().numpy())
            targets.extend(labels.cpu().numpy())
            probabilities.extend(F.softmax(logits, dim=1).cpu().numpy())

    mean_loss = running_loss / len(dataloader)
    return mean_loss, np.array(targets), np.array(predictions), np.array(probabilities)


def compute_comprehensive_metrics(labels, preds, probs):
    """Compute a dictionary of metrics for a binary (classes 0/1) classifier.

    Args:
        labels: 1-D array-like of true class ids.
        preds: 1-D array-like of predicted class ids.
        probs: Per-class scores, one row per sample; column 1 is used as the
            positive-class score.

    Returns:
        dict with accuracy, balanced accuracy, macro/binary/weighted and
        per-class precision/recall/F1, ROC-AUC and PR-AUC, the 2x2 confusion
        matrix and its components, MCC, Cohen's kappa and derived rates.
        AUC-style scores that are undefined for the given data are reported
        as 0.5.
    """
    probs = np.asarray(probs)
    metrics = {
        'accuracy': accuracy_score(labels, preds),
        'balanced_accuracy': balanced_accuracy_score(labels, preds),
    }

    for average in ('macro', 'binary', 'weighted'):
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, preds, average=average, zero_division=0
        )
        metrics[f'precision_{average}'] = precision
        metrics[f'recall_{average}'] = recall
        metrics[f'f1_{average}'] = f1

    prec_per_class, rec_per_class, f1_per_class, support = precision_recall_fscore_support(
        labels, preds, average=None, zero_division=0
    )
    metrics['precision_per_class'] = prec_per_class
    metrics['recall_per_class'] = rec_per_class
    metrics['f1_per_class'] = f1_per_class
    metrics['support_per_class'] = support

    def score_or_chance(metric_fn, scores, **kwargs):
        # Only the "score is undefined for this data" failure is absorbed
        # (e.g. a single class in ``labels``); other errors still surface.
        if scores is None:
            return 0.5
        try:
            return metric_fn(labels, scores, **kwargs)
        except ValueError:
            return 0.5

    has_positive_column = probs.ndim == 2 and probs.shape[1] > 1
    positive_scores = probs[:, 1] if has_positive_column else None

    # For binary targets roc_auc_score only accepts 1-D scores; handing it the
    # (n, 2) matrix raises, which previously made this metric always 0.5.
    # With two classes the one-vs-rest macro AUC equals the positive-class AUC.
    if has_positive_column and probs.shape[1] == 2:
        macro_scores = positive_scores
    else:
        macro_scores = probs

    metrics['roc_auc_binary'] = score_or_chance(roc_auc_score, positive_scores)
    metrics['roc_auc_macro'] = score_or_chance(
        roc_auc_score, macro_scores, multi_class='ovr', average='macro'
    )
    metrics['pr_auc'] = score_or_chance(average_precision_score, positive_scores)

    # Pin both classes so the matrix is 2x2 even if one class never appears;
    # otherwise the four-way unpacking below fails.
    metrics['confusion_matrix'] = confusion_matrix(labels, preds, labels=[0, 1])

    metrics['mcc'] = matthews_corrcoef(labels, preds)

    metrics['cohen_kappa'] = cohen_kappa_score(labels, preds)

    tn, fp, fn, tp = metrics['confusion_matrix'].ravel()
    metrics['true_negatives'] = tn
    metrics['false_positives'] = fp
    metrics['false_negatives'] = fn
    metrics['true_positives'] = tp

    # Each rate falls back to 0 when its denominator is empty.
    metrics['specificity'] = tn / (tn + fp) if (tn + fp) > 0 else 0
    metrics['sensitivity'] = tp / (tp + fn) if (tp + fn) > 0 else 0
    metrics['fpr'] = fp / (fp + tn) if (fp + tn) > 0 else 0
    metrics['fnr'] = fn / (fn + tp) if (fn + tp) > 0 else 0
    metrics['npv'] = tn / (tn + fn) if (tn + fn) > 0 else 0
    metrics['fdr'] = fp / (fp + tp) if (fp + tp) > 0 else 0

    return metrics


def print_comprehensive_metrics(metrics, phase="Test"):
    """Print the metrics dictionary as a sectioned report on stdout.

    Args:
        metrics: Mapping with the keys read below (overall, macro, binary,
            weighted and per-class scores, confusion-matrix components,
            derived rates and the confusion matrix itself).
        phase: Name shown, upper-cased, in the report title.
    """
    rule = '=' * 80

    def show(heading, rows, spec='.4f'):
        # Labels are left-justified to a fixed column so values line up.
        print(heading)
        for label, key in rows:
            print(f"{label:<27}{format(metrics[key], spec)}")

    print(f"\n{rule}")
    print(f"{phase.upper()} EVALUATION METRICS")
    print(f"{rule}\n")

    show("Overall Metrics:", (
        ("  Accuracy:", 'accuracy'),
        ("  Balanced Accuracy:", 'balanced_accuracy'),
        ("  Matthews Correlation:", 'mcc'),
        ("  Cohen's Kappa:", 'cohen_kappa'),
    ))
    show("\nMacro-Averaged Metrics:", (
        ("  Precision (Macro):", 'precision_macro'),
        ("  Recall (Macro):", 'recall_macro'),
        ("  F1-Score (Macro):", 'f1_macro'),
        ("  ROC-AUC (Macro):", 'roc_auc_macro'),
    ))
    show("\nBinary Metrics:", (
        ("  Precision (Binary):", 'precision_binary'),
        ("  Recall (Binary):", 'recall_binary'),
        ("  F1-Score (Binary):", 'f1_binary'),
        ("  ROC-AUC (Binary):", 'roc_auc_binary'),
        ("  PR-AUC:", 'pr_auc'),
    ))
    show("\nWeighted Metrics:", (
        ("  Precision (Weighted):", 'precision_weighted'),
        ("  Recall (Weighted):", 'recall_weighted'),
        ("  F1-Score (Weighted):", 'f1_weighted'),
    ))

    print("\nPer-Class Metrics:")
    per_class_rows = (
        ("    Precision:", 'precision_per_class', '.4f'),
        ("    Recall:", 'recall_per_class', '.4f'),
        ("    F1-Score:", 'f1_per_class', '.4f'),
        ("    Support:", 'support_per_class', ''),
    )
    for i in range(len(metrics['precision_per_class'])):
        print(f"  Class {i}:")
        for label, key, spec in per_class_rows:
            print(f"{label:<16}{format(metrics[key][i], spec)}")

    # Raw counts are printed without a float format.
    show("\nConfusion Matrix Components:", (
        ("  True Positives:", 'true_positives'),
        ("  True Negatives:", 'true_negatives'),
        ("  False Positives:", 'false_positives'),
        ("  False Negatives:", 'false_negatives'),
    ), spec='')
    show("\nAdditional Binary Metrics:", (
        ("  Sensitivity (TPR):", 'sensitivity'),
        ("  Specificity (TNR):", 'specificity'),
        ("  False Positive Rate:", 'fpr'),
        ("  False Negative Rate:", 'fnr'),
        ("  Negative Pred. Value:", 'npv'),
        ("  False Discovery Rate:", 'fdr'),
    ))

    print("\nConfusion Matrix:")
    print(metrics['confusion_matrix'])
    print(rule)


def count_parameters(model):
    """Return ``(trainable, total)`` parameter element counts for ``model``."""
    trainable = total = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return trainable, total


def main():
    """Train the adapter-tuning classifier and print a final evaluation report.

    Loads the tokenizer and the two CSV splits, builds ``AdapterTuningModel``,
    trains it for a fixed number of epochs with AdamW, reports per-epoch
    metrics and finally prints the full metric report for the test split.
    """
    # Prefer CUDA, then Apple MPS, then CPU. Checking only MPS made CUDA
    # machines train on the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f'Device: {device}')

    model_name = 'Salesforce/codet5p-220m'
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    train_dataset = CodeDataset('/traincodex.csv', tokenizer)
    test_dataset = CodeDataset('/testcodex.csv', tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=0)

    print("\nInitializing Adapter Tuning Model...")
    model = AdapterTuningModel(model_name, num_classes=2, reduction_factor=16, non_linearity='gelu').to(device)

    trainable_params, total_params = count_parameters(model)

    print(f"\n{'='*80}")
    print(f"MODEL PARAMETERS")
    print(f"{'='*80}")
    print(f"Total Parameters:      {total_params:,}")
    print(f"Trainable Parameters:  {trainable_params:,}")
    print(f"Trainable Percentage:  {100 * trainable_params / total_params:.4f}%")
    print(f"{'='*80}\n")

    criterion = nn.CrossEntropyLoss()
    # Only parameters left trainable by the model are handed to the optimizer.
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=3e-4,
        weight_decay=0.01
    )

    num_epochs = 5

    print(f"{'='*80}")
    print(f"TRAINING ADAPTER TUNING")
    print(f"{'='*80}\n")

    best_f1 = 0
    best_epoch = 0

    for epoch in range(num_epochs):
        train_loss, train_acc = train_epoch(
            model, train_loader, optimizer, criterion, device
        )

        # NOTE(review): the test split doubles as the validation split here,
        # so these per-epoch numbers must not be used for model selection.
        val_loss, val_labels, val_preds, val_probs = evaluate(
            model, test_loader, criterion, device
        )

        val_metrics = compute_comprehensive_metrics(val_labels, val_preds, val_probs)

        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
        print(f"  Val Loss:   {val_loss:.4f} | Val Acc:   {val_metrics['accuracy']:.4f} | Val F1 (Macro): {val_metrics['f1_macro']:.4f}")
        print(f"{'-'*80}")

        if val_metrics['f1_macro'] > best_f1:
            best_f1 = val_metrics['f1_macro']
            best_epoch = epoch + 1

    # The tracked best score was previously computed but never reported.
    print(f"\nBest Val F1 (Macro): {best_f1:.4f} (epoch {best_epoch})")

    print(f"\n{'='*80}")
    print(f"FINAL TEST EVALUATION")
    print(f"{'='*80}\n")

    # The final report reflects the weights after the last epoch.
    test_loss, test_labels, test_preds, test_probs = evaluate(
        model, test_loader, criterion, device
    )

    test_metrics = compute_comprehensive_metrics(test_labels, test_preds, test_probs)

    print_comprehensive_metrics(test_metrics, phase="Test")

    print(f"\n{'='*80}")
    print(f"ADAPTER TUNING COMPLETED")
    print(f"{'='*80}")


# Run the adapter-tuning experiment only when this cell/file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()

# BitFit

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import pandas as pd
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support, roc_auc_score, 
    confusion_matrix, classification_report, average_precision_score,
    matthews_corrcoef, cohen_kappa_score, balanced_accuracy_score
)
from tqdm import tqdm
import numpy as np
import warnings
warnings.filterwarnings('ignore')


class CodeDataset(Dataset):
    """CSV-backed dataset that tokenizes a row each time it is indexed.

    The first CSV column is the text to tokenize and the second column the
    integer class label; a row without a second column is labelled 0.
    """

    def __init__(self, csv_path, tokenizer, max_length=512):
        self.data = pd.read_csv(csv_path)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data.iloc[idx]
        source_text = str(entry.iloc[0])
        if len(entry) > 1:
            class_id = int(entry.iloc[1])
        else:
            class_id = 0

        encoded = self.tokenizer(
            source_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        # Drop the size-one leading axis; batching is left to the DataLoader.
        input_ids = encoded['input_ids'].squeeze(0)
        attention_mask = encoded['attention_mask'].squeeze(0)
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': torch.tensor(class_id, dtype=torch.long),
        }


class BitFitTuningModel(nn.Module):
    """Encoder classifier in which only a small parameter subset is trained.

    The whole pretrained seq2seq model is frozen first. Then, inside the
    encoder only, every parameter whose name contains 'bias' and every
    layer-norm weight is re-enabled. A freshly initialized classifier head
    on top of the first token's final hidden state is trained as well.

    Args:
        base_model_name: Checkpoint name/path passed to
            ``AutoModelForSeq2SeqLM.from_pretrained``.
        num_classes: Number of output logits.
    """

    def __init__(self, base_model_name, num_classes=2):
        super().__init__()

        self.base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)

        for param in self.base_model.parameters():
            param.requires_grad = False

        self._enable_bias_training()

        config = self.base_model.config
        self.d_model = config.d_model

        self.classifier = nn.Sequential(
            nn.Linear(self.d_model, 256),
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )

    def _enable_bias_training(self):
        """Unfreeze encoder 'bias' parameters and encoder layer-norm weights."""
        encoder = self.base_model.get_encoder()

        # forward() only runs the encoder. Matching 'bias' across the whole
        # model also unfroze decoder parameters that never receive a gradient,
        # which inflated the reported trainable-parameter count.
        for name, param in encoder.named_parameters():
            if 'bias' in name:
                param.requires_grad = True

        # Keep the token embedding frozen explicitly.
        if hasattr(encoder, 'embed_tokens'):
            if encoder.embed_tokens.weight is not None:
                encoder.embed_tokens.weight.requires_grad = False

        for block in encoder.block:
            for layer in block.layer:
                norm = getattr(layer, 'layer_norm', None)
                if norm is not None and norm.weight is not None:
                    norm.weight.requires_grad = True

        final_norm = getattr(encoder, 'final_layer_norm', None)
        if final_norm is not None and final_norm.weight is not None:
            final_norm.weight.requires_grad = True

    def forward(self, input_ids, attention_mask):
        """Return class logits computed from the first token's encoder state."""
        encoder = self.base_model.get_encoder()

        outputs = encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )

        pooled_output = outputs.last_hidden_state[:, 0, :]
        logits = self.classifier(pooled_output)

        return logits


def train_epoch(model, dataloader, optimizer, criterion, device):
    """Train ``model`` for one pass over ``dataloader``.

    Batches are dicts carrying 'input_ids', 'attention_mask' and 'labels'.

    Returns:
        ``(avg_loss, accuracy)``: mean per-batch loss and accuracy of the
        argmax predictions made while training.
    """
    model.train()
    loss_sum = 0
    seen_preds = []
    seen_labels = []

    progress = tqdm(dataloader, desc="Training")
    for batch in progress:
        ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()

        logits = model(ids, mask)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()

        # Clip the global gradient norm to 1.0 before stepping.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        loss_sum += batch_loss.item()
        seen_preds.extend(logits.argmax(dim=1).cpu().numpy())
        seen_labels.extend(labels.cpu().numpy())

    avg_loss = loss_sum / len(dataloader)
    return avg_loss, accuracy_score(seen_labels, seen_preds)


def evaluate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` with gradients disabled.

    Returns:
        ``(avg_loss, labels, preds, probs)``: the mean per-batch loss followed
        by NumPy arrays of true labels, argmax predictions and softmax
        probabilities.
    """
    model.eval()
    loss_sum = 0
    collected = {'preds': [], 'labels': [], 'probs': []}

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            logits = model(ids, mask)
            loss_sum += criterion(logits, labels).item()

            collected['preds'].extend(torch.argmax(logits, dim=1).cpu().numpy())
            collected['labels'].extend(labels.cpu().numpy())
            collected['probs'].extend(F.softmax(logits, dim=1).cpu().numpy())

    avg_loss = loss_sum / len(dataloader)
    return (
        avg_loss,
        np.array(collected['labels']),
        np.array(collected['preds']),
        np.array(collected['probs']),
    )


def compute_comprehensive_metrics(labels, preds, probs):
    """Compute a dictionary of metrics for a binary (classes 0/1) classifier.

    Args:
        labels: 1-D array-like of true class ids.
        preds: 1-D array-like of predicted class ids.
        probs: Per-class scores, one row per sample; column 1 is used as the
            positive-class score.

    Returns:
        dict with accuracy, balanced accuracy, macro/binary/weighted and
        per-class precision/recall/F1, ROC-AUC and PR-AUC, the 2x2 confusion
        matrix and its components, MCC, Cohen's kappa and derived rates.
        AUC-style scores that are undefined for the given data are reported
        as 0.5.
    """
    probs = np.asarray(probs)
    metrics = {
        'accuracy': accuracy_score(labels, preds),
        'balanced_accuracy': balanced_accuracy_score(labels, preds),
    }

    for average in ('macro', 'binary', 'weighted'):
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, preds, average=average, zero_division=0
        )
        metrics[f'precision_{average}'] = precision
        metrics[f'recall_{average}'] = recall
        metrics[f'f1_{average}'] = f1

    prec_per_class, rec_per_class, f1_per_class, support = precision_recall_fscore_support(
        labels, preds, average=None, zero_division=0
    )
    metrics['precision_per_class'] = prec_per_class
    metrics['recall_per_class'] = rec_per_class
    metrics['f1_per_class'] = f1_per_class
    metrics['support_per_class'] = support

    def score_or_chance(metric_fn, scores, **kwargs):
        # Only the "score is undefined for this data" failure is absorbed
        # (e.g. a single class in ``labels``); other errors still surface.
        if scores is None:
            return 0.5
        try:
            return metric_fn(labels, scores, **kwargs)
        except ValueError:
            return 0.5

    has_positive_column = probs.ndim == 2 and probs.shape[1] > 1
    positive_scores = probs[:, 1] if has_positive_column else None

    # For binary targets roc_auc_score only accepts 1-D scores; handing it the
    # (n, 2) matrix raises, which previously made this metric always 0.5.
    # With two classes the one-vs-rest macro AUC equals the positive-class AUC.
    if has_positive_column and probs.shape[1] == 2:
        macro_scores = positive_scores
    else:
        macro_scores = probs

    metrics['roc_auc_binary'] = score_or_chance(roc_auc_score, positive_scores)
    metrics['roc_auc_macro'] = score_or_chance(
        roc_auc_score, macro_scores, multi_class='ovr', average='macro'
    )
    metrics['pr_auc'] = score_or_chance(average_precision_score, positive_scores)

    # Pin both classes so the matrix is 2x2 even if one class never appears;
    # otherwise the four-way unpacking below fails.
    metrics['confusion_matrix'] = confusion_matrix(labels, preds, labels=[0, 1])

    metrics['mcc'] = matthews_corrcoef(labels, preds)

    metrics['cohen_kappa'] = cohen_kappa_score(labels, preds)

    tn, fp, fn, tp = metrics['confusion_matrix'].ravel()
    metrics['true_negatives'] = tn
    metrics['false_positives'] = fp
    metrics['false_negatives'] = fn
    metrics['true_positives'] = tp

    # Each rate falls back to 0 when its denominator is empty.
    metrics['specificity'] = tn / (tn + fp) if (tn + fp) > 0 else 0
    metrics['sensitivity'] = tp / (tp + fn) if (tp + fn) > 0 else 0
    metrics['fpr'] = fp / (fp + tn) if (fp + tn) > 0 else 0
    metrics['fnr'] = fn / (fn + tp) if (fn + tp) > 0 else 0
    metrics['npv'] = tn / (tn + fn) if (tn + fn) > 0 else 0
    metrics['fdr'] = fp / (fp + tp) if (fp + tp) > 0 else 0

    return metrics


def print_comprehensive_metrics(metrics, phase="Test"):
    """Write a sectioned, human-readable metrics report to stdout.

    Args:
        metrics: Mapping with the keys read below (overall, macro, binary,
            weighted and per-class scores, confusion-matrix components,
            derived rates and the confusion matrix itself).
        phase: Name shown, upper-cased, in the report title.
    """
    banner = '=' * 80

    def emit_section(title, entries, spec='.4f'):
        # Pad every label to the same column so the values are aligned.
        print(title)
        for label, key in entries:
            print(f"{label:<27}{format(metrics[key], spec)}")

    print(f"\n{banner}")
    print(f"{phase.upper()} EVALUATION METRICS")
    print(f"{banner}\n")

    float_sections = (
        ("Overall Metrics:", (
            ("  Accuracy:", 'accuracy'),
            ("  Balanced Accuracy:", 'balanced_accuracy'),
            ("  Matthews Correlation:", 'mcc'),
            ("  Cohen's Kappa:", 'cohen_kappa'),
        )),
        ("\nMacro-Averaged Metrics:", (
            ("  Precision (Macro):", 'precision_macro'),
            ("  Recall (Macro):", 'recall_macro'),
            ("  F1-Score (Macro):", 'f1_macro'),
            ("  ROC-AUC (Macro):", 'roc_auc_macro'),
        )),
        ("\nBinary Metrics:", (
            ("  Precision (Binary):", 'precision_binary'),
            ("  Recall (Binary):", 'recall_binary'),
            ("  F1-Score (Binary):", 'f1_binary'),
            ("  ROC-AUC (Binary):", 'roc_auc_binary'),
            ("  PR-AUC:", 'pr_auc'),
        )),
        ("\nWeighted Metrics:", (
            ("  Precision (Weighted):", 'precision_weighted'),
            ("  Recall (Weighted):", 'recall_weighted'),
            ("  F1-Score (Weighted):", 'f1_weighted'),
        )),
    )
    for title, entries in float_sections:
        emit_section(title, entries)

    print("\nPer-Class Metrics:")
    class_fields = (
        ("    Precision:", 'precision_per_class', '.4f'),
        ("    Recall:", 'recall_per_class', '.4f'),
        ("    F1-Score:", 'f1_per_class', '.4f'),
        ("    Support:", 'support_per_class', ''),
    )
    for i in range(len(metrics['precision_per_class'])):
        print(f"  Class {i}:")
        for label, key, spec in class_fields:
            print(f"{label:<16}{format(metrics[key][i], spec)}")

    # Raw counts are printed as-is, without a float format.
    emit_section("\nConfusion Matrix Components:", (
        ("  True Positives:", 'true_positives'),
        ("  True Negatives:", 'true_negatives'),
        ("  False Positives:", 'false_positives'),
        ("  False Negatives:", 'false_negatives'),
    ), spec='')
    emit_section("\nAdditional Binary Metrics:", (
        ("  Sensitivity (TPR):", 'sensitivity'),
        ("  Specificity (TNR):", 'specificity'),
        ("  False Positive Rate:", 'fpr'),
        ("  False Negative Rate:", 'fnr'),
        ("  Negative Pred. Value:", 'npv'),
        ("  False Discovery Rate:", 'fdr'),
    ))

    print("\nConfusion Matrix:")
    print(metrics['confusion_matrix'])
    print(banner)


def count_parameters(model):
    """Count parameter elements of ``model``; returns ``(trainable, total)``."""
    sizes = [(param.numel(), param.requires_grad) for param in model.parameters()]
    total = sum(size for size, _ in sizes)
    trainable = sum(size for size, needs_grad in sizes if needs_grad)
    return trainable, total


def main():
    """Train the BitFit classifier and print a final evaluation report.

    Loads the tokenizer and the two CSV splits, builds ``BitFitTuningModel``,
    trains it for a fixed number of epochs with AdamW, reports per-epoch
    metrics and finally prints the full metric report for the test split.
    """
    # Prefer CUDA, then Apple MPS, then CPU. Checking only MPS made CUDA
    # machines train on the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f'Device: {device}')

    model_name = 'Salesforce/codet5p-220m'
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    train_dataset = CodeDataset('/traincodex.csv', tokenizer)
    test_dataset = CodeDataset('/testcodex.csv', tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=0)

    print("\nInitializing BitFit Tuning Model...")
    model = BitFitTuningModel(model_name, num_classes=2).to(device)

    trainable_params, total_params = count_parameters(model)

    print(f"\n{'='*80}")
    print(f"MODEL PARAMETERS")
    print(f"{'='*80}")
    print(f"Total Parameters:      {total_params:,}")
    print(f"Trainable Parameters:  {trainable_params:,}")
    print(f"Trainable Percentage:  {100 * trainable_params / total_params:.4f}%")
    print(f"{'='*80}\n")

    criterion = nn.CrossEntropyLoss()
    # Only parameters left trainable by the model are handed to the optimizer.
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=3e-4,
        weight_decay=0.01
    )

    num_epochs = 5

    print(f"{'='*80}")
    print(f"TRAINING BITFIT TUNING")
    print(f"{'='*80}\n")

    best_f1 = 0
    best_epoch = 0

    for epoch in range(num_epochs):
        train_loss, train_acc = train_epoch(
            model, train_loader, optimizer, criterion, device
        )

        # NOTE(review): the test split doubles as the validation split here,
        # so these per-epoch numbers must not be used for model selection.
        val_loss, val_labels, val_preds, val_probs = evaluate(
            model, test_loader, criterion, device
        )

        val_metrics = compute_comprehensive_metrics(val_labels, val_preds, val_probs)

        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
        print(f"  Val Loss:   {val_loss:.4f} | Val Acc:   {val_metrics['accuracy']:.4f} | Val F1 (Macro): {val_metrics['f1_macro']:.4f}")
        print(f"{'-'*80}")

        if val_metrics['f1_macro'] > best_f1:
            best_f1 = val_metrics['f1_macro']
            best_epoch = epoch + 1

    # The tracked best score was previously computed but never reported.
    print(f"\nBest Val F1 (Macro): {best_f1:.4f} (epoch {best_epoch})")

    print(f"\n{'='*80}")
    print(f"FINAL TEST EVALUATION")
    print(f"{'='*80}\n")

    # The final report reflects the weights after the last epoch.
    test_loss, test_labels, test_preds, test_probs = evaluate(
        model, test_loader, criterion, device
    )

    test_metrics = compute_comprehensive_metrics(test_labels, test_preds, test_probs)

    print_comprehensive_metrics(test_metrics, phase="Test")

    print(f"\n{'='*80}")
    print(f"BITFIT TUNING COMPLETED")
    print(f"{'='*80}")


# Run the BitFit experiment only when this cell/file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()

# A-Lore Code vul

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import time
import psutil
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from sklearn.metrics import (confusion_matrix, accuracy_score, balanced_accuracy_score, 
                             precision_score, recall_score, f1_score, matthews_corrcoef, 
                             cohen_kappa_score, jaccard_score)
from tqdm import tqdm

# Module-level device used by the code in this cell: CUDA when available,
# otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class CodeDataset(Dataset):
    """Dataset built from the 'code' and 'label' columns of a CSV file.

    Both columns are loaded into plain Python lists up front; tokenization
    happens lazily when an item is requested. Items are tuples
    ``(input_ids, attention_mask, label)``.
    """

    def __init__(self, csv_path, tokenizer, max_len=512):
        frame = pd.read_csv(csv_path)
        self.codes = frame['code'].astype(str).tolist()
        self.labels = frame['label'].astype(int).tolist()
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.codes)

    def __getitem__(self, idx):
        encoded = self.tokenizer(
            self.codes[idx],
            truncation=True,
            padding='max_length',
            max_length=self.max_len,
            return_tensors='pt',
        )
        # squeeze(0) drops the size-one leading axis of each tokenizer output.
        input_ids = encoded['input_ids'].squeeze(0)
        attention_mask = encoded['attention_mask'].squeeze(0)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return input_ids, attention_mask, label

class AttentionAlignedAdapter(nn.Module):
    """Low-rank adapter whose bottleneck features are mixed by a matrix ``S``.

    Computes ``Wu(S @ Wd(X))`` using two bias-free linear projections:
    ``Wd`` maps ``hidden_dim -> rank`` and ``Wu`` maps ``rank -> hidden_dim``.
    """

    def __init__(self, hidden_dim=768, rank=64):
        super().__init__()
        self.Wd = nn.Linear(hidden_dim, rank, bias=False)
        self.Wu = nn.Linear(rank, hidden_dim, bias=False)

    def forward(self, X, S):
        # Project down, mix rows with S, then project back up.
        return self.Wu(S @ self.Wd(X))

class PEFTSelfAttention(nn.Module):
    """Self-attention wrapper that adds an attention-aligned adapter branch.

    Reuses the q/k/v/o projections of the wrapped attention module
    ``t5_attn`` and adds two new pieces: an ``AttentionAlignedAdapter`` fed
    with the head-averaged attention weights, and an output bias ``b_o``.

    ``forward`` returns ``(outputs, position_bias)``; extra keyword
    arguments are accepted and ignored.
    """

    def __init__(self, t5_attn, hidden_dim=768, rank=64):
        super().__init__()
        self.t5_attn = t5_attn
        self.adapter = AttentionAlignedAdapter(hidden_dim, rank)
        self.b_o = nn.Parameter(torch.zeros(hidden_dim))

        # Geometry is copied from the wrapped attention module.
        self.dim = self.t5_attn.d_model
        self.n_heads = self.t5_attn.n_heads
        self.key_value_proj_dim = self.t5_attn.key_value_proj_dim

    def forward(self, hidden_states, mask=None, position_bias=None, **kwargs):
        batch_size, seq_length = hidden_states.shape[:2]

        def split_heads(projected):
            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
            return projected.view(
                batch_size, -1, self.n_heads, self.key_value_proj_dim
            ).transpose(1, 2)

        query = split_heads(self.t5_attn.q(hidden_states))
        key = split_heads(self.t5_attn.k(hidden_states))
        value = split_heads(self.t5_attn.v(hidden_states))

        scores = torch.matmul(query, key.transpose(3, 2))

        # No bias supplied: compute it if the wrapped module has a
        # relative-attention bias, otherwise use zeros.
        if position_bias is None:
            if self.t5_attn.has_relative_attention_bias:
                position_bias = self.t5_attn.compute_bias(seq_length, seq_length)
            else:
                position_bias = torch.zeros_like(scores)

        scores += position_bias
        if mask is not None:
            scores += mask

        # Softmax is evaluated in float32 and cast back to the score dtype.
        attn_weights = F.softmax(scores.float(), dim=-1).type_as(scores)

        context = torch.matmul(attn_weights, value)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.dim)

        # Attention weights averaged over heads steer the adapter's mixing.
        delta_a = self.adapter(hidden_states, attn_weights.mean(dim=1))

        projected = self.t5_attn.o(context + delta_a) + self.b_o
        return projected, position_bias

class PEFTCodeT5(nn.Module):
    """Frozen seq2seq encoder with PEFT self-attention and a linear classifier.

    All pretrained weights are frozen; every encoder block's self-attention is
    then replaced by a ``PEFTSelfAttention`` wrapper, whose newly created
    parameters remain trainable, and a linear head maps the first token's
    final hidden state to class logits.

    Args:
        model_name: Checkpoint passed to
            ``AutoModelForSeq2SeqLM.from_pretrained``.
        num_classes: Number of output logits.
        rank: Bottleneck rank handed to each ``PEFTSelfAttention``.
    """

    def __init__(self, model_name="Salesforce/codet5p-220m", num_classes=2, rank=64):
        super().__init__()
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

        # Freeze first: modules created below are new and stay trainable.
        for p in self.model.parameters():
            p.requires_grad = False

        # Read the width from the checkpoint instead of hard-coding 768 so
        # other model sizes build matching adapters and classifier.
        hidden_dim = self.model.config.d_model

        encoder = self.model.encoder
        for block in encoder.block:
            old_attn = block.layer[0].SelfAttention
            block.layer[0].SelfAttention = PEFTSelfAttention(
                old_attn, hidden_dim=hidden_dim, rank=rank
            )

        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return class logits computed from the first token's encoder state."""
        outputs = self.model.encoder(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state
        logits = self.classifier(last_hidden_state[:, 0, :])
        return logits

def compute_metrics(y_true, y_pred, time_taken=0.0, mem_used=0.0, avg="macro"):
    """Assemble a dictionary of classification metrics.

    The confusion matrix is computed over labels 0 and 1 and its four cells
    feed the rate-style metrics (specificity, NPV, FPR, FNR).

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        time_taken: Stored unchanged under "Inference Time (s)".
        mem_used: Stored unchanged under "Memory (MB)".
        avg: ``average`` argument forwarded to the precision, recall, F1 and
            Jaccard scorers.

    Returns:
        dict mapping metric name to value, in a fixed insertion order.
    """
    eps = 1e-9  # keeps the ratio denominators away from zero
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

    results = {}
    results["Accuracy"] = accuracy_score(y_true, y_pred)
    results["Balanced Acc"] = balanced_accuracy_score(y_true, y_pred)

    # Scorers that share the same averaging / zero-division settings.
    averaged_scorers = (
        ("Precision", precision_score),
        ("Recall", recall_score),
        ("F1 Score", f1_score),
        ("Jaccard", jaccard_score),
    )
    for name, scorer in averaged_scorers:
        results[name] = scorer(y_true, y_pred, average=avg, zero_division=0)

    results["MCC"] = matthews_corrcoef(y_true, y_pred)
    results["Kappa"] = cohen_kappa_score(y_true, y_pred)

    # Rates derived directly from the confusion-matrix cells.
    results["Specificity"] = tn / (tn + fp + eps)
    results["NPV"] = tn / (tn + fn + eps)
    results["FPR"] = fp / (fp + tn + eps)
    results["FNR"] = fn / (fn + tp + eps)

    # Resource figures are passed through as given by the caller.
    results["Inference Time (s)"] = time_taken
    results["Memory (MB)"] = mem_used
    return results

def run_evaluation(model, loader, description="Evaluating"):
    """Run ``model`` over ``loader`` without gradients and collect metrics.

    The model is switched to eval mode and left in that mode; callers that
    continue training must call ``model.train()`` again. Tensors are moved
    to the module-level ``device``.

    Args:
        model: Module called as ``model(ids, mask)`` and returning logits.
        loader: Iterable whose batches unpack into ``(ids, mask, labels)``.
        description: Label shown on the progress bar.

    Returns:
        The dict produced by ``compute_metrics`` (macro averaging) with the
        wall-clock duration of the loop, the peak CUDA memory in MB
        (0.0 when CUDA is unavailable), and an extra ``'Loss'`` entry holding
        the mean of the per-batch cross-entropy losses.
    """
    model.eval()
    y_true, y_pred = [], []
    loss_accum = 0.0

    # The CUDA memory counters raise on machines without CUDA, so only touch
    # them when a GPU is actually present.
    track_cuda_memory = torch.cuda.is_available()

    start_time = time.time()
    if track_cuda_memory:
        torch.cuda.reset_peak_memory_stats()

    criterion = nn.CrossEntropyLoss()

    pbar = tqdm(loader, desc=description, leave=True, bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')

    with torch.no_grad():
        for ids, mask, labels in pbar:
            ids, mask, labels = ids.to(device), mask.to(device), labels.to(device)
            logits = model(ids, mask)
            loss = criterion(logits, labels)
            loss_accum += loss.item()

            preds = torch.argmax(logits, dim=1)
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())

    end_time = time.time()
    if track_cuda_memory:
        mem_used = torch.cuda.max_memory_allocated() / (1024 * 1024)
    else:
        mem_used = 0.0

    metrics = compute_metrics(y_true, y_pred, end_time - start_time, mem_used, avg="macro")
    metrics['Loss'] = loss_accum / len(loader)
    return metrics

# ---------------------------------------------------------------------------
# Data: tokenizer, datasets and loaders.
# ---------------------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5p-220m", use_fast=False)
full_train = CodeDataset('/traincodex.csv', tokenizer)
test_dataset = CodeDataset('/testcodex.csv', tokenizer)

# Hold out 10% of the training file for validation.
val_size = int(0.1 * len(full_train))
train_size = len(full_train) - val_size
train_dataset, val_dataset = random_split(full_train, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=16)
test_loader = DataLoader(test_dataset, batch_size=16)

# ---------------------------------------------------------------------------
# Model and optimizer: only parameters still marked trainable are optimized.
# ---------------------------------------------------------------------------
model = PEFTCodeT5().to(device)
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-4)

num_epochs = 10
test_epochs = (5, num_epochs)  # epochs after which the test set is evaluated

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    y_true, y_pred = [], []
    # Running counters keep the progress-bar accuracy O(1) per step instead
    # of re-scoring every prediction seen so far on each iteration.
    n_correct = 0
    n_seen = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}, Loss={postfix[0]:.4f}, Acc={postfix[1]:.2f}%]', postfix=[0.0, 0.0])

    # NOTE(review): assumes each batch unpacks into (ids, mask, labels) —
    # confirm against the CodeDataset / collate function used with these loaders.
    for i, (ids, mask, labels) in enumerate(pbar):
        ids, mask, labels = ids.to(device), mask.to(device), labels.to(device)

        optimizer.zero_grad()
        logits = model(ids, mask)
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        preds = torch.argmax(logits, dim=1)
        # Full label/prediction history is still needed for the epoch-level F1.
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(preds.cpu().numpy())

        n_correct += (preds == labels).sum().item()
        n_seen += labels.numel()
        current_acc = n_correct / n_seen * 100
        pbar.postfix[0] = train_loss / (i + 1)
        pbar.postfix[1] = current_acc
        # update(0) redraws the bar with the new postfix without advancing it.
        pbar.update(0)

    val_metrics = run_evaluation(model, val_loader, "Evaluating")

    train_f1 = f1_score(y_true, y_pred, average='macro')
    train_loss_avg = train_loss / len(train_loader)

    print(f"\nTrain Loss: {train_loss_avg:.4f} | Train F1: {train_f1:.4f}")
    print(f"Val Loss:   {val_metrics['Loss']:.4f} | Val F1:   {val_metrics['F1 Score']:.4f}\n")

    if (epoch + 1) in test_epochs:
        print("="*80)
        print(f"PERFORMING FULL TEST ANALYSIS AT EPOCH {epoch+1}")
        print("="*80)
        test_metrics = run_evaluation(model, test_loader, "Testing")

        for k, v in test_metrics.items():
            if "Time" in k:
                print(f"{k:<20}: {v:.4f} s")
            elif "Memory" in k:
                print(f"{k:<20}: {v:.2f} MB")
            else:
                print(f"{k:<20}: {v:.4f}")
        print("="*80 + "\n")