In [6]:
from transformers import AutoTokenizer, AutoModel
import torch

import os 
os.environ["http_proxy"] = "http://127.0.0.1:6666"
os.environ["https_proxy"] = "http://127.0.0.1:6666"

# 加载 CodeBERT 模型
model_name = "microsoft/graphcodebert-base"
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="/home/models")
model = AutoModel.from_pretrained(model_name, cache_dir="/home/models")

def get_code_embedding(code_snippet):
    inputs = tokenizer(code_snippet, return_tensors="pt", truncation=True, padding=True, max_length=512)
    outputs = model(**inputs)
    # 获取 [CLS] 的嵌入作为整个代码片段的表示
    return outputs.last_hidden_state[:, 0, :].detach()

from sklearn.metrics.pairwise import cosine_similarity

def calculate_similarity(vec1, vec2):
    return cosine_similarity(vec1.reshape(1, -1), vec2.reshape(1, -1))[0][0]

# 示例：缺陷代码和上下文代码片段
defect_code = "this.state = state + 1;"
context_snippets = [
    "this.state = 0;",
    "console.log(this.state);",
    "if (this.state > 10) this.resetState();"
]

# 生成缺陷代码的向量
defect_vec = get_code_embedding(defect_code)
## 打印一下维度
print(defect_vec.shape)

# 生成上下文代码的向量
context_vectors = [get_code_embedding(snippet) for snippet in context_snippets]

# 计算每个上下文片段的相似度
similarities = [calculate_similarity(defect_vec, ctx_vec) for ctx_vec in context_vectors]

# 根据相似度排序
sorted_contexts = sorted(zip(context_snippets, similarities), key=lambda x: x[1], reverse=True)

# 输出过滤和排序后的上下文
for snippet, score in sorted_contexts:
    print(f"Context: {snippet} | Similarity: {score}")




Some weights of RobertaModel were not initialized from the model checkpoint at microsoft/graphcodebert-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


torch.Size([1, 768])
Context: this.state = 0; | Similarity: 0.9714752435684204
Context: if (this.state > 10) this.resetState(); | Similarity: 0.9319815635681152
Context: console.log(this.state); | Similarity: 0.7940301895141602


In [127]:
prompt_template = """
You are an AI debugging assistant. Your task is to fix the provided code based on the error log.

### Input:
Code snippets:
{code_snippet}

Error log:
{error_log}

### Task:
Please analyze the code and error log, then provide the fixed code directly.
Keep the original indentation of each code snippet.

### Output format:
Return the fixed code snippets with original indentation preserved.
"""

In [143]:
import json
import re

def extract_arkts_context(code: str, variable_name: str):
    lines = code.split("\n")
    context = {
        "definition": [],
        "usage": [], 
        "blocks": []
    }
    
    # 匹配变量定义
    definition_pattern = rf"@State.*{variable_name}.*=.*;"
    # 匹配变量使用 
    usage_pattern = rf".*{variable_name}.*"
    # 匹配代码块上下文
    block_pattern = rf"(for|if|while).*{{"

    # 先找出所有blocks
    for i, line in enumerate(lines):
        if re.search(block_pattern, line):
            block_end = i + 1
            indent_level = len(line) - len(line.lstrip())  # 计算缩进层级
            while block_end < len(lines) and (
                len(lines[block_end].strip()) == 0 or  # 空行跳过
                len(lines[block_end]) - len(lines[block_end].lstrip()) > indent_level  # 子块仍在范围内
            ):
                block_end += 1
            context["blocks"].append((i + 1, lines[i:block_end]))

    # 移除嵌套子块
    filtered_blocks = []
    for start, block in context["blocks"]:
        if not any(start > parent_start and start < parent_end for parent_start, parent_block in filtered_blocks for parent_end in [parent_start + len(parent_block)]):
            filtered_blocks.append((start, block))
    context["blocks"] = filtered_blocks

    # 找出所有definition和usage
    for i, line in enumerate(lines):
        if re.search(definition_pattern, line):
            # 检查是否在block中
            in_block = False
            for block_start, block in context["blocks"]:
                block_end = block_start + len(block)
                if i + 1 >= block_start and i + 1 <= block_end:
                    in_block = True
                    break
            if not in_block:
                context["definition"].append((i + 1, line))
        elif re.search(usage_pattern, line):
            # 检查是否在block中
            in_block = False
            for block_start, block in context["blocks"]:
                block_end = block_start + len(block)
                if i + 1 >= block_start and i + 1 <= block_end:
                    in_block = True
                    break
            if not in_block:
                context["usage"].append((i + 1, line))

    return context

# 示例代码
arkts_code = """
import hilog from '@ohos.hilog'
@Entry
@Component
struct MyComponent{
  @State message: string = '';
  build() {
    Column() {
      Button('点击打印日志')
        .onClick(() => {
          this.message = 'click';
          for (let k = 0; k < 10; k++) {        
            for(let j = 0; j < 10; j++) {
              for (let i = 0; i < 10; i++) {
                hilog.info(0x0000, 'TAG', '%{public}s', this.message);
              }
            }
          }
        })
        .width('90%')
        .backgroundColor(Color.Blue)
        .fontColor(Color.White)
        .margin({
          top: 10
        })
    }
    .justifyContent(FlexAlign.Start)
    .alignItems(HorizontalAlign.Center)
    .margin({
      top: 15
    })
  }
}
"""
variable_name = "message"
context = extract_arkts_context(arkts_code, variable_name)

# 合并所有代码片段并按行号排序
code_snippets = []
code_snippets.extend([(line_num, code) for line_num, code in context["definition"]])
code_snippets.extend([(line_num, code) for line_num, code in context["usage"]])
code_snippets.extend([(start, "\n".join(block)) for start, block in context["blocks"]])

code_snippets.sort(key=lambda x: x[0])
surrounding_context = [snippet for _, snippet in code_snippets]


In [139]:
### definition, usage, blocks都作为上下文, 不需要行号，直接拿代码块
def flatten_str_or_list(item):
    if isinstance(item, str):
        return item
    elif isinstance(item, list):
        return "\n".join(item)
    return ""

surrounding_context = "\n".join(surrounding_context)

print(surrounding_context)

  @State message: string = '';
          this.message = 'click';
          for (let k = 0; k < 10; k++) {        
            for(let j = 0; j < 10; j++) {
              for (let i = 0; i < 10; i++) {
                hilog.info(0x0000, 'TAG', '%{public}s', this.message);
              }
            }


In [132]:
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(''))))
from llm import get_openai_answer

## 提取潜在的```language内的内容
def extract_code_from_response(response):
    """从回复中提取代码块内容"""
    # 使用正则表达式匹配```开头和结尾的代码块
    code_blocks = re.findall(r'```(?:\w+)?\n(.*?)```', response, re.DOTALL)
    
    # 如果找到多个代码块,拼接它们
    if code_blocks:
        return '\n'.join(code_blocks)
        
    # 如果没找到代码块,返回原文
    return response
    

error_log = "Don't use state variable in the loop, use a local variable instead"

prompt = prompt_template.format(code_snippet=surrounding_context, error_log=error_log)

response = get_openai_answer(prompt, model_name="gpt-4o-mini")
diff = extract_code_from_response(response)

In [133]:
print(diff)

  @State message: string = '';
          const localMessage = 'click';
          for (let k = 0; k < 10; k++) {        
            for(let j = 0; j < 10; j++) {
              for (let i = 0; i < 10; i++) {
                hilog.info(0x0000, 'TAG', '%{public}s', localMessage);
              }
            }



In [140]:
print(context)

{'definition': [(6, "  @State message: string = '';")], 'usage': [(11, "          this.message = 'click';")], 'blocks': [(12, ['          for (let k = 0; k < 10; k++) {        ', '            for(let j = 0; j < 10; j++) {', '              for (let i = 0; i < 10; i++) {', "                hilog.info(0x0000, 'TAG', '%{public}s', this.message);", '              }', '            }'])]}


In [145]:
### 通过context, surrounding context和diff，修改arkts_code
def modify_arkts_code(original_code, context, surrounding_context, diff):
    """修改ArkTS代码"""
    prompt = f"""请帮我修改以下ArkTS代码。

原始代码:
{original_code}

提取出的有关上下文代码内容及行号：
{context}

对于上下文的修改代码:
{diff}

请返回完整的修改后代码，直接返回代码，不要返回任何其它内容。
"""
    # 获取修改建议
    response = get_openai_answer(prompt)
    
    # 提取代码
    modified_code = extract_code_from_response(response)
    
    return modified_code

# 修改代码
repair_code = modify_arkts_code(arkts_code, context, surrounding_context, diff)
print(repair_code)



修改后的代码:
import hilog from '@ohos.hilog'
@Entry
@Component
struct MyComponent{
  @State message: string = '';
  build() {
    Column() {
      Button('点击打印日志')
        .onClick(() => {
          const localMessage = 'click';
          for (let k = 0; k < 10; k++) {        
            for(let j = 0; j < 10; j++) {
              for (let i = 0; i < 10; i++) {
                hilog.info(0x0000, 'TAG', '%{public}s', localMessage);
              }
            }
          }
        })
        .width('90%')
        .backgroundColor(Color.Blue)
        .fontColor(Color.White)
        .margin({
          top: 10
        })
    }
    .justifyContent(FlexAlign.Start)
    .alignItems(HorizontalAlign.Center)
    .margin({
      top: 15
    })
  }
}
