In [None]:
# Initialize the DeepSpeed distributed environment (NCCL backend) when a local
# rank was supplied, i.e. when running under a distributed launcher.
if args.local_rank != -1:
    # Bind this process to its GPU before initializing NCCL so each rank
    # communicates from its own device rather than defaulting to cuda:0.
    torch.cuda.set_device(args.local_rank)
    # The API is `deepspeed.init_distributed`; the previous `int_distributed`
    # typo raised AttributeError on every distributed launch.
    deepspeed.init_distributed(dist_backend='nccl')
    logger.info("finish")

In [None]:
# Load the tokenizer and the base causal LM, then attach the LoRA adapter.
# (Excerpt from a GRPOTrainer method: `self` is the trainer instance.)
self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_path)
model_kwargs = {
    # `from_pretrained` expects the keyword `torch_dtype`; the former key
    # 'torch.dtype' was not a recognized argument, so the fp16 setting never
    # took effect. `torch.cuda.is_available()` replaces the misspelled
    # `cuda_avilable` flag, which is not defined in this notebook excerpt.
    'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32,
    # Disable the KV cache for training; it is only needed for generation.
    'use_cache': False,
}
self.model = AutoModelForCausalLM.from_pretrained(
    self.base_model_path,
    **model_kwargs
)
# Load the adapter in trainable mode; with PEFT's default is_trainable=False
# the adapter weights are loaded frozen (inference mode).
self.model = PeftModel.from_pretrained(
    self.model,
    self.lora_model_path,
    is_trainable=True,
)

In [None]:
# Make every LoRA parameter trainable and collect them for the optimizer.
# Parameters are selected purely by having "lora" in their (lower-cased) name.
trainable_params = []
# `named_parameters()` is the torch.nn.Module API; the previous
# `name_parameters()` does not exist and raised AttributeError.
for name, param in self.model.named_parameters():
    if 'lora' in name.lower():
        param.requires_grad = True
        trainable_params.append(param)
# Surface an empty selection early (e.g. adapter not attached or renamed).
logger.info(f"trainable LoRA params: {len(trainable_params)}")

In [None]:
# Wrap the model with DeepSpeed for distributed training. Only the trainable
# (LoRA) parameters collected above are passed as model_parameters; training
# settings come from ds_config.
ds_args = dict(
    model=self.model,
    model_parameters=trainable_params,
    config=ds_config,
)
model_engine, optimizer, _, _ = deepspeed.initialize(**ds_args)
# From here on the trainer drives the DeepSpeed engine, not the raw model.
self.model, self.optimizer = model_engine, optimizer

In [None]:
# Build training samples from the loaded JSON. Each usable item yields a prompt
# asking the model to write unit tests for a Java class, kept together with the
# class source and its (possibly empty) reference test code.
train_data = []
if "data" in json_data:
    samples = json_data["data"]
    # Bind to a local first: nesting double quotes inside a double-quoted
    # f-string (json_data["data"]) is a SyntaxError before Python 3.12.
    logger.info(f"find {len(samples)} samples")
    for item in samples:
        source_code = item.get("source_code", "")
        test_code = item.get("test_code", "")
        # Without source code there is nothing to prompt on; skip the item.
        if not source_code:
            continue
        prompt = f"请为以下Java类生成单元测试用例: ```java {source_code} ```生成的测试用例："
        train_data.append({
            "prompt": prompt,
            "source_code": source_code,
            "test_code": test_code
        })
else:
    logger.info("no data")
    # NOTE(review): assumes json_data is already a list of sample dicts in
    # this layout — confirm, since a plain dict would be used as-is here.
    train_data = json_data

In [None]:
# Test-generation reward environment.
class TestGenerationEnvironment:
    """Scores generated unit tests with a weighted sum of static heuristics.

    total = coverage_weight * coverage + mutation_weight * mutation
            + readability_weight * readability

    Similarity to a reference test is computed and reported in the metrics
    dict but does not contribute to the total.
    """

    def __init__(self, jacoco_path: str = None, pit_path: str = None):
        """Set the scoring weights and create a temporary working directory.

        Args:
            jacoco_path: Optional JaCoCo path; stored on the instance (the static
                analysis called in this class does not read it).
            pit_path: Optional PIT path; stored on the instance likewise.
        """
        # These were previously accepted and silently discarded.
        self.jacoco_path = jacoco_path
        self.pit_path = pit_path

        # Component weights (sum to 1.0).
        self.coverage_weight = 0.4
        self.mutation_weight = 0.3
        self.readability_weight = 0.3

        # Temporary working directory; call cleanup() to delete it.
        self.temp_dir = tempfile.mkdtemp(prefix="test_eval_")
        logger.info(f"new temp work folder: {self.temp_dir}")

    def cleanup(self) -> None:
        """Delete the temporary working directory created in __init__."""
        import shutil

        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def evaluate_test(self,
                      generated_test: str,
                      source_code: str,
                      reference_test: str = None) -> Tuple[float, Dict]:
        """Score one generated test.

        Args:
            generated_test: Source of the generated test.
            source_code: Source of the class under test (not used by the
                scoring helpers called here).
            reference_test: Optional reference test; when truthy, a similarity
                score is computed and reported (not added to the total).

        Returns:
            (total_score, metrics) with per-component scores in ``metrics``.
            If any helper raises, the fixed fallback (0.65, {...}) is returned
            with the exception message under "error".
        """
        logger.info("start test")
        try:
            # Static coverage analysis. The original bound this to a misspelled
            # name (`converage_score`), so the total below raised NameError and
            # every call fell through to the fallback reward.
            coverage_score = self._static_coverage_analysis(generated_test)
            # Static mutation analysis.
            mutation_score = self._static_mutation_analysis(generated_test)
            # Readability score.
            readability_score = self._calculate_readability(generated_test)
            # Optional similarity to the reference test (reporting only).
            similarity_score = 0.0
            if reference_test:
                similarity_score = self._calculate_similarity(generated_test, reference_test)

            total_score = (
                self.coverage_weight * coverage_score
                + self.mutation_weight * mutation_score
                + self.readability_weight * readability_score
            )
            logger.info(f"total score: {total_score}")

            return total_score, {
                "coverage_score": coverage_score,
                "mutation_score": mutation_score,
                "readability_score": readability_score,
                "similarity_score": similarity_score
            }

        except Exception as e:
            # Best-effort: keep training alive, but keep the traceback in the log.
            # NOTE(review): a plausible non-zero fallback reward can mask
            # evaluation failures from the optimizer; consider returning 0.0.
            logger.error(f"测试评估过程中出现未捕获的错误: {str(e)}", exc_info=True)
            return 0.65, {
                "coverage_score": 0.6,
                "mutation_score": 0.6,
                "readability_score": 0.8,
                "similarity_score": 0.0,
                "error": str(e)}

In [None]:
# 生成测试样本
try:
    generated_test = self.generated_test(prompt, num_samples=num_samples)
except Exception as e:
    return {"loss": 0, "mean_reward": 0, "mean_kl_div": 0, "num_samples": 0}

try:
    rewards, metrics_list = self.compute_rewards(generated_test, source_code, reference_test)
except Excetion as e:
    return {"loss": 0, "mean_reward": 0, "mean_kl_div": 0, "num_samples": 0}

value = rewards.mean()

advantages = rewards - value

if len(advantages) > 1 and advantages.std() >0:
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

In [None]:
def compute_logprobs(self, model, inputs, attention_mask=None):
    model_inputs = {"input_ids": input_ids}