In [None]:
import json
import os
import sys
sys.path.append(os.path.abspath('/Circuit_LoRa'))
from collections import defaultdict
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer
)

import random
import numpy as np
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random generators and pin the cuDNN flags.

    Args:
        seed: Integer handed unchanged to every seeding call (default 42).
    """
    # Same seed for each library, applied in a fixed order.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # The CUDA generators are only seeded when a CUDA device is reported as available.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Disable cuDNN autotuning and request deterministic cuDNN kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Restrict this process to GPU 0 *before* the first CUDA query: set_seed() calls
# torch.cuda.is_available(), and the CUDA runtime reads CUDA_VISIBLE_DEVICES when it is
# first initialised, so assigning the variable afterwards may have no effect.
# NOTE(review): ideally this assignment would sit above `import torch` as well.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
set_seed(42)

from changecircuit_edge_keylayer import define_critical_layers_via_edges
from circuit_weighted_lora import (
    apply_circuit_weighted_lora,
    freeze_non_critical_layers,
    circuit_regularization,
    save_initial_params
)


In [None]:
def identify_critical_layers(before_circuit_json_path, after_circuit_json_path, model_prefix='base_model.model.gpt_neox.layers', threshold=None, top_k=5):
    """Look up the critical layers via ``define_critical_layers_via_edges`` and print them.

    Args:
        before_circuit_json_path: Path of the circuit JSON from before fine-tuning.
        after_circuit_json_path: Path of the circuit JSON from after fine-tuning.
        model_prefix: Forwarded unchanged as ``model_prefix``.
        threshold: Forwarded unchanged as ``threshold``.
        top_k: Forwarded unchanged as ``top_k``; a value equal to 0 short-circuits
            to an empty list without calling the helper.

    Returns:
        ``[]`` when ``top_k == 0``; otherwise whatever
        ``define_critical_layers_via_edges`` returns, after printing each element
        on its own line.
    """
    if top_k != 0:
        selection_options = {
            'model_prefix': model_prefix,
            'threshold': threshold,
            'top_k': top_k,
        }
        critical_layers = define_critical_layers_via_edges(
            before_circuit_json_path,
            after_circuit_json_path,
            **selection_options
        )

        for layer_name in critical_layers:
            print(layer_name)

        return critical_layers

    return []

# Selection settings forwarded to identify_critical_layers.
# With top_k = 0 that function returns [] straight away, so no critical layers are used.
top_k = 0
threshold = None
model_prefix = 'gpt_neox.layers'

# Circuit JSON path before fine-tuning.
before_circuit_json_path = '/2_arithmetic_operations_100/graph_results_100/lora_graph_results/graph_100_2arithmetic_operations_1.4b_epoch_0.json'
# Circuit JSON path after fine-tuning.
after_circuit_json_path = '/2_arithmetic_operations_100/graph_results_100/lora_graph_results/graph_100_2arithmetic_operations_1.4b_r2_epoch_2.json'

critical_layers = identify_critical_layers(
    before_circuit_json_path=before_circuit_json_path,
    after_circuit_json_path=after_circuit_json_path,
    model_prefix=model_prefix,
    threshold=threshold,
    top_k=top_k,
)


In [None]:
def load_and_preprocess_data(train_file, validation_file, tokenizer, max_length=32):
    """Load the JSON train/validation files and tokenize them for causal-LM fine-tuning.

    Each record must provide an ``input`` and an ``output`` field. The training text is
    ``f"{input}\\n"`` followed by ``str(output)``; it is truncated and padded to
    ``max_length`` tokens.

    Labels are a copy of ``input_ids`` in which every position that is not part of the
    answer is replaced by ``-100`` (the index ignored by the cross-entropy loss):

    * the prompt tokens, and
    * every padding position (``attention_mask == 0``). The notebook uses the EOS token
      as pad token, so leaving these positions unmasked would make the loss include
      every pad slot.

    Args:
        train_file: Path of the training JSON/JSONL file.
        validation_file: Path of the validation JSON/JSONL file.
        tokenizer: Callable Hugging Face tokenizer with a pad token configured.
        max_length: Sequence length used for truncation and padding (default 32).

    Returns:
        The dataset returned by ``dataset.map``, with ``train`` and ``validation``
        splits and the tokenizer outputs plus a ``labels`` column added.
    """
    data_files = {
        'train': train_file,
        'validation': validation_file
    }
    dataset = load_dataset('json', data_files=data_files)

    def preprocess_function(examples):
        prompts = [f"{inp}\n" for inp in examples['input']]
        full_texts = [prompt + str(out) for prompt, out in zip(prompts, examples['output'])]

        tokenized_full = tokenizer(
            full_texts,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_attention_mask=True,
        )

        labels = []
        for prompt, input_ids, attention_mask in zip(
            prompts, tokenized_full['input_ids'], tokenized_full['attention_mask']
        ):
            # Number of tokens the prompt alone occupies (assumes the prompt tokenization
            # is a prefix of the full-text tokenization, as the original code did).
            prompt_len = len(tokenizer.encode(prompt, truncation=True, max_length=max_length))

            # With a left-padding tokenizer the real tokens start after the pad run,
            # so the answer offset is measured from the first attended position.
            first_real = attention_mask.index(1) if 1 in attention_mask else len(attention_mask)
            answer_start = first_real + prompt_len

            labels.append([
                token_id if (position >= answer_start and attended == 1) else -100
                for position, (token_id, attended) in enumerate(zip(input_ids, attention_mask))
            ])

        tokenized_full['labels'] = labels

        return tokenized_full

    tokenized_datasets = dataset.map(preprocess_function, batched=True)

    return tokenized_datasets

# Hugging Face model id used to load the tokenizer.
model_name = 'EleutherAI/gpt-neo-2.7B'

# Fine-tuning data (train / held-out split).
validation_file = '/2_arithmetic_operations_100/finetune_pythia_100/finetune_data/test_100.jsonl'
train_file = '/2_arithmetic_operations_100/finetune_pythia_100/finetune_data/train_100.jsonl'

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Preprocessing pads to a fixed length, which needs a pad token; reuse EOS when none is defined.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

tokenized_datasets = load_and_preprocess_data(train_file, validation_file, tokenizer)

# Sanity check: show the first five tokenized rows of each split.
print(tokenized_datasets['train'][:5])
print(tokenized_datasets['validation'][:5])

train_size, validation_size = len(tokenized_datasets['train']), len(tokenized_datasets['validation'])

In [None]:
# Load the base causal LM for `model_name`.
model = AutoModelForCausalLM.from_pretrained(model_name)
# Wrap the model with the project's circuit-weighted LoRA.
# The same five hyper-parameter values are hard-coded again in the adapter_config.json
# written by save_updated_weights — keep the two places in sync.
# NOTE(review): extra_r / critical_alpha presumably only affect the critical layers —
# confirm in circuit_weighted_lora.
model = apply_circuit_weighted_lora(
    model=model,
    critical_layers=critical_layers,
    r=32,                 
    alpha=64,             
    extra_r=0,             
    critical_alpha=0,    
    dropout=0           
)

# Project helper; presumably disables gradients outside LoRA / critical layers — TODO confirm.
model = freeze_non_critical_layers(model, critical_layers)

# Snapshot taken before training; it is later passed to CustomTrainer as `initial_params`.
initial_params = save_initial_params(model, critical_layers)

def print_trainable_parameters(model):
    """Print every parameter that requires gradients, then the trainable/total counts.

    Args:
        model: Object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).
    """
    print("\nTrainable Parameters:")

    # Collect (name, element count, trainable?) once, then report from that list.
    param_info = [
        (param_name, tensor.numel(), tensor.requires_grad)
        for param_name, tensor in model.named_parameters()
    ]

    for param_name, count, is_trainable in param_info:
        if is_trainable:
            print(f" - {param_name}: {count} parameters")

    trainable_params = sum(count for _, count, is_trainable in param_info if is_trainable)
    total_params = sum(count for _, count, _ in param_info)
    print(f"\nTotal trainable params: {trainable_params} / Total params: {total_params}")

# List the parameters that still require gradients on the current `model`.
print_trainable_parameters(model)


In [10]:
import os
import json
from safetensors.torch import save_file

def save_updated_weights(model, critical_layers, output_dir, lora_config=None):
    """Save the trainable LoRA / critical-layer weights and a small JSON config.

    Writes two files into ``output_dir`` (created if missing):

    * ``adapter_model.safetensors`` — every parameter with ``requires_grad`` whose name
      contains ``"lora_"`` or any of the ``critical_layers`` substrings.
    * ``adapter_config.json`` — ``critical_layers`` plus the LoRA hyper-parameters.

    A failure while writing the safetensors file is reported on stdout and does not
    abort the call; the config file is still written afterwards.

    Args:
        model: Object exposing ``named_parameters()``.
        critical_layers: Iterable of name substrings marking critical layers.
            ``None`` is treated as "no critical layers".
        output_dir: Destination directory.
        lora_config: Optional mapping whose entries override the default
            hyper-parameter values written to ``adapter_config.json``
            (``r=32, alpha=64, extra_r=0, critical_alpha=0, dropout=0``).
    """
    # CustomTrainer defaults critical_layers to None; iterating None would raise TypeError.
    critical_layers = list(critical_layers or [])

    os.makedirs(output_dir, exist_ok=True)
    updated_weights = {}

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if "lora_" in name:
            print(f"Saving LoRA weight: {name}")
        elif any(layer in name for layer in critical_layers):
            print(f"Saving critical layer weight: {name}")
        else:
            continue
        # safetensors refuses non-contiguous tensors, so normalise the layout here.
        updated_weights[name] = param.detach().cpu().contiguous()

    weights_path = os.path.join(output_dir, "adapter_model.safetensors")
    try:
        save_file(updated_weights, weights_path)
        print(f"Updated weights saved to {weights_path}")
    except Exception as e:
        # Best effort: report the failure but let the caller (e.g. a training run) continue.
        print(f"Error saving safetensors file: {e}")

    config = {
        "critical_layers": critical_layers,
        "r": 32,
        "alpha": 64,
        "extra_r": 0,
        "critical_alpha": 0,
        "dropout": 0,
    }
    if lora_config:
        config.update(lora_config)

    config_path = os.path.join(output_dir, "adapter_config.json")
    with open(config_path, "w") as f:
        json.dump(config, f, indent=4)
    print(f"LoRA config saved to {config_path}")

class CustomTrainer(Trainer):
    """Trainer that adds a circuit regularization term and saves only the adapter weights.

    Args:
        critical_layers: Critical-layer identifiers handed to ``circuit_regularization``
            and ``save_updated_weights``. ``None`` / empty disables the regularizer.
        initial_params: Parameter snapshot handed to ``circuit_regularization``.
            ``None`` / empty disables the regularizer.
        lambda_reg: Value passed to ``circuit_regularization`` as its last argument.
        *args, **kwargs: Forwarded unchanged to ``Trainer.__init__``.
    """

    def __init__(self, *args, critical_layers=None, initial_params=None, lambda_reg=1e-3, **kwargs):
        super().__init__(*args, **kwargs)
        self.critical_layers = critical_layers
        self.initial_params = initial_params
        self.lambda_reg = lambda_reg

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the model loss plus the circuit regularization term.

        Args:
            model: Model to run; called as ``model(**inputs)``.
            inputs: Keyword arguments for the forward pass (must yield ``outputs.loss``).
            return_outputs: When true, return ``(loss, outputs)`` instead of ``loss``.
            num_items_in_batch: Passed by newer ``transformers`` Trainer versions; without
                this parameter the override fails there with a ``TypeError``. It is
                forwarded to the model only when the Trainer reports that the model
                accepts loss kwargs, mirroring the base implementation.
        """
        # Older Trainer versions have no `model_accepts_loss_kwargs`; getattr keeps them working.
        if num_items_in_batch is not None and getattr(self, "model_accepts_loss_kwargs", False):
            inputs = {**inputs, "num_items_in_batch": num_items_in_batch}

        outputs = model(**inputs)
        loss = outputs.loss
        # The regularizer is only applied when both pieces of circuit information are present.
        if self.critical_layers and self.initial_params:
            reg_loss = circuit_regularization(model, self.critical_layers, self.initial_params, self.lambda_reg)
            loss = loss + reg_loss

        return (loss, outputs) if return_outputs else loss

    def save_model(self, output_dir=None, **kwargs):
        """Save through ``save_updated_weights`` instead of the default full-model save.

        Args:
            output_dir: Target directory; defaults to ``self.args.output_dir``.
            **kwargs: Accepted for compatibility with the base signature and ignored.
        """
        if output_dir is None:
            output_dir = self.args.output_dir
        # critical_layers defaults to None; hand the saver an empty list in that case.
        save_updated_weights(self.model, self.critical_layers or [], output_dir)


In [11]:
training_args = TrainingArguments(
    # --- Optimisation ---
    num_train_epochs=2,                            # Number of training epochs
    learning_rate=2e-4,                            # Learning rate (higher than for full fine-tuning)
    warmup_steps=25,                               # Number of warmup steps
    weight_decay=0.01,                             # Weight decay
    per_device_train_batch_size=8,                 # Batch size per device
    gradient_accumulation_steps=4,                 # Gradient accumulation steps
    fp16=True,                                     # Mixed precision
    # --- Checkpointing ---
    output_dir='./lora_gpt_results/r32a64',        # Output directory
    save_strategy="steps",                         # Save by steps
    save_steps=25,                                 # Save a checkpoint every 25 steps
    save_total_limit=10,                           # Keep at most 10 checkpoints
    # --- Logging ---
    logging_dir='./circuit_weighted_lora_logs',    # Logging directory
    logging_steps=10,                              # Log every 10 steps
    report_to="none",                              # Disable default reporting
)


In [None]:
from torch.optim import AdamW

# One optimizer param group per trainable parameter, so each can carry its own learning rate.
# Both branches currently assign 3e-4, so the critical / non-critical split only changes
# the message that is printed.
optimizer_grouped_parameters = []
trainable_named_params = [
    (name, param) for name, param in model.named_parameters() if param.requires_grad
]
for name, param in trainable_named_params:
    if any(layer in name for layer in critical_layers):
        lr = 3e-4
        print(f"Critical Layer Param: {name} | Learning Rate: {lr}")
    else:
        lr = 3e-4
        print(f"Non-Critical Layer Param: {name} | Learning Rate: {lr}")
    optimizer_grouped_parameters.append(dict(params=param, lr=lr))

optimizer = AdamW(optimizer_grouped_parameters, weight_decay=0.01)

# No scheduler is supplied (second tuple element is None).
trainer_options = dict(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
    critical_layers=critical_layers,
    initial_params=initial_params,
    lambda_reg=0,
    optimizers=(optimizer, None),
)
trainer = CustomTrainer(**trainer_options)

In [13]:
# NOTE(review): this rebinds `trainer`. Unlike the CustomTrainer built in the previous cell,
# no `optimizers=` argument is passed, so the AdamW instance created there is not used by
# this trainer. Presumably Trainer then builds its own optimizer from `training_args`
# (learning_rate=2e-4) — confirm which of the two cells is meant to run before training.
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
    critical_layers=critical_layers,
    initial_params=initial_params,
    lambda_reg=0,  
)

In [None]:
# Start fine-tuning with whichever `trainer` was constructed last.
trainer.train()