In [None]:
from unsloth import FastLanguageModel
import torch
# Maximum sequence length in tokens; the same value is passed to the model
# loader, to the tokenizer's truncation limit and to the SFT trainer.
max_seq_length = 3100

In [None]:
# Load the local Llama 3.1 8B Instruct checkpoint together with its tokenizer.
base_model_path = "model/LLM-Research/Meta-Llama-3___1-8B-Instruct"

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=base_model_path,
    max_seq_length=max_seq_length,
    load_in_4bit=True,   # load the weights 4-bit quantized
    dtype=None,          # no explicit dtype; presumably the loader picks one - confirm in the Unsloth docs
    device_map="auto",
)

In [None]:
# Projection layers that receive LoRA adapters.
lora_target_modules = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]

# Wrap the loaded model with LoRA adapters; `model` is rebound to the wrapped model.
model = FastLanguageModel.get_peft_model(
    model,
    target_modules=lora_target_modules,
    r=16,                # suggested values: 8, 16, 32, 64, 128
    lora_alpha=16,
    lora_dropout=0,      # any value is supported, but 0 is the optimized path
    bias="none",         # any value is supported, but "none" is the optimized path
    # "unsloth" checkpointing is advertised as using less VRAM for very long
    # contexts; True is the other accepted setting.
    use_gradient_checkpointing="unsloth",
    use_rslora=True,     # rank-stabilized LoRA
    loftq_config=None,   # LoftQ is not used
    random_state=3407,
)

In [None]:
# Load the datasets
from datasets import load_dataset, concatenate_datasets

# Folds 1-3 form the training data.
dataset1, dataset2, dataset3 = (
    load_dataset('csv', data_files=fold_csv)['train']
    for fold_csv in (
        'data/dataset_5fold_1/dataset_fold_1.csv',
        'data/dataset_5fold_1/dataset_fold_2.csv',
        'data/dataset_5fold_1/dataset_fold_3.csv',
    )
)

# Fold 4 is held out as the validation set.
dataset5 = load_dataset('csv', data_files='data/dataset_5fold_1/dataset_fold_4.csv')['train']

# Merge the training folds and show the merged dataset.
merged_folds = concatenate_datasets([dataset1, dataset2, dataset3])
print(merged_folds)

# Drop the columns that the prompt builder does not read.
unused_columns = ["report", "code", "label"]
train_dataset = merged_folds.remove_columns(unused_columns)
validation_dataset = dataset5.remove_columns(unused_columns)

# Show both datasets after the unused columns were removed.
print(train_dataset)
print(validation_dataset)

In [None]:
# Tokenizer settings shared by every example that gets tokenized.
tokenizer.padding_side = "left"    # pad at the start of each sequence
# NOTE(review): the pad id is hard-coded to 0 - confirm that id 0 is a suitable
# padding token in this tokenizer's vocabulary.
tokenizer.pad_token_id = 0
tokenizer.add_eos_token = True     # request an EOS token on tokenized text
EOS_TOKEN = tokenizer.eos_token    # EOS string of the loaded tokenizer

def tokenize(prompt):
    """Tokenize one prompt for causal-LM fine-tuning.

    The prompt is truncated to ``max_seq_length`` tokens and left unpadded
    (padding is done later, per batch). The labels are a copy of the input
    ids, i.e. the model is trained to reproduce the whole prompt.

    Args:
        prompt: Full training text, including the expected response.

    Returns:
        The tokenizer output with an added ``"labels"`` entry.
    """
    result = tokenizer(
        prompt,
        truncation=True,
        max_length=max_seq_length,
        padding=False,
        return_tensors=None,
    )

    # The `add_eos_token` attribute is not honored by every fast tokenizer, so
    # make sure the example ends with EOS: append it when it is missing and the
    # sequence still has room. Sequences already at the length limit are left
    # untouched so the limit is never exceeded.
    eos_id = tokenizer.eos_token_id
    input_ids = result["input_ids"]
    if (
        eos_id is not None
        and len(input_ids) < max_seq_length
        and (not input_ids or input_ids[-1] != eos_id)
    ):
        input_ids.append(eos_id)
        if "attention_mask" in result:
            # Keep the mask aligned with the extended input ids.
            result["attention_mask"].append(1)

    # "Self-supervised learning" means the labels are also the inputs.
    result["labels"] = list(input_ids)

    return result

In [None]:
# # Disabled earlier variant: automatically truncates both the code snippet and the report (kept for reference only)

# MAX_LENGTH = 4000  # Llama3.1支持128k上下文

# def tokenize_and_truncate(data_point):
#     # 构建初始的 full_prompt
#     full_prompt = f"""You are a developer of the GCC compiler. Your job is to categorize bug reports. You are given a snippet of code that triggers the bug and a description of the bug.
# The bug reports are categorized as follows:'code-simplification-optimization-defects','control-flow-optimization-defects','data-flow-analysis-optimization-defects','infrastructure-defects','interprocedural-optimization-defects','memory-optimization-defects','numerical-analysis-optimization-defects','vectorization-defects'.
# ### Code Snippet:
# {data_point["code"]}
# ### Bug Description:
# {data_point["report"]}
# ### Response:
# {data_point["category"]}
# """
    
#     # Tokenize整个prompt
#     tokenized_prompt = tokenizer(full_prompt)
#     token_length = len(tokenized_prompt['input_ids'])

#     # If the prompt exceeds MAX_LENGTH tokens, truncate it
#     if token_length > MAX_LENGTH:
#         # Tokenize code 和 input 部分
#         tokenized_code = tokenizer(data_point["code"], truncation=False)
#         tokenized_input = tokenizer(data_point["report"], truncation=False)

#         # 分别计算 code 和 input 的 token 长度
#         code_token_length = len(tokenized_code['input_ids'])
#         input_token_length = len(tokenized_input['input_ids'])

#         # 保留的长度 = MAX_LENGTH - (固定部分的token长度，即非code和input部分)
#         fixed_prompt = f"""You are a developer of the GCC compiler. Your job is to categorize bug reports. You are given a snippet of code that triggers the bug and a description of the bug.
# The bug reports are categorized as follows:'code-simplification-optimization-defects','control-flow-optimization-defects','data-flow-analysis-optimization-defects','infrastructure-defects','interprocedural-optimization-defects','memory-optimization-defects','numerical-analysis-optimization-defects','vectorization-defects'.
# ### Code Snippet:
# ### Bug Description:
# ### Response:
# {data_point["category"]}
# """
#         fixed_token_length = len(tokenizer(fixed_prompt)['input_ids'])
#         remaining_length = MAX_LENGTH - fixed_token_length

#         # 优先截断 code 和 input
#         if code_token_length + input_token_length > remaining_length:
#             # 如果总长度超过剩余长度，首先截断较长的部分
#             if code_token_length > input_token_length:
#                 # 优先截断 code
#                 truncated_code = tokenizer.decode(tokenized_code['input_ids'][:remaining_length - input_token_length])
#                 truncated_input = data_point["report"]
#             else:
#                 # 优先截断 input
#                 truncated_code = data_point["code"]
#                 truncated_input = tokenizer.decode(tokenized_input['input_ids'][:remaining_length - code_token_length])
#         else:
#             # 如果总长度不超标，不做额外截断
#             truncated_code = data_point["code"]
#             truncated_input = data_point["report"]

#         # 构建最终截断后的 prompt
#         full_prompt = f"""You are a developer of the GCC compiler. Your job is to categorize bug reports. You are given a snippet of code that triggers the bug and a description of the bug.
# The bug reports are categorized as follows:'code-simplification-optimization-defects','control-flow-optimization-defects','data-flow-analysis-optimization-defects','infrastructure-defects','interprocedural-optimization-defects','memory-optimization-defects','numerical-analysis-optimization-defects','vectorization-defects'.
# ### Code Snippet:
# {truncated_code}
# ### Bug Description:
# {truncated_input}
# ### Response:
# {data_point["category"]}
# """
    
#     # 进行最后的tokenize处理，确保token长度满足要求
#     return tokenize(full_prompt)

In [None]:
# Build the training prompt and automatically truncate over-long bug descriptions.

# Token budget for one complete prompt (instructions + description + label).
# Llama 3.1 supports a 128k context, but examples are capped far below that.
# Keep this value below max_seq_length so that tokenize() never has to cut off
# the trailing label.
MAX_LENGTH = 2900


def _build_prompt(description, category):
    """Render the classification prompt for one bug description and its label."""
    return f"""You are a developer of the GCC compiler. Your job is to categorize bug reports. You are given a bug description.
The bug reports are categorized as follows:'code-simplification-optimization-defects','control-flow-optimization-defects','data-flow-analysis-optimization-defects','infrastructure-defects','interprocedural-optimization-defects','memory-optimization-defects','numerical-analysis-optimization-defects','vectorization-defects'.
### Bug Description:
{description}
### Response:
{category}
"""


def tokenize_and_truncate(data_point):
    """Turn one dataset row into a tokenized training example.

    The row's ``"text"`` (bug description) and ``"category"`` (label) are
    inserted into the prompt template. If the complete prompt is longer than
    ``MAX_LENGTH`` tokens, only the description is shortened, so the
    instructions and the label are always kept.

    Args:
        data_point: Mapping with the keys ``"text"`` and ``"category"``.

    Returns:
        The result of ``tokenize()`` for the (possibly shortened) prompt.
    """
    text = data_point["text"]
    category = data_point["category"]
    full_prompt = _build_prompt(text, category)

    # Measure the complete prompt.
    token_length = len(tokenizer(full_prompt)["input_ids"])

    if token_length > MAX_LENGTH:
        # Tokens taken by everything except the description (template, label
        # and whatever special tokens the tokenizer adds to a full prompt).
        overhead = len(tokenizer(_build_prompt("", category))["input_ids"])
        budget = max(MAX_LENGTH - overhead, 0)

        # Tokenize the bare description without special tokens; otherwise the
        # BOS marker would be decoded back into the middle of the prompt.
        text_ids = tokenizer(text, add_special_tokens=False, truncation=False)["input_ids"]
        truncated_input = tokenizer.decode(text_ids[:budget], skip_special_tokens=True)

        # Rebuild the prompt around the shortened description.
        full_prompt = _build_prompt(truncated_input, category)

    # Final tokenization; tokenize() still enforces max_seq_length as a hard limit.
    return tokenize(full_prompt)

In [None]:
# Process the data: build and tokenize the prompt for every row of both splits.
tokenized_train_dataset, tokenized_validation_dataset = (
    split.map(tokenize_and_truncate)
    for split in (train_dataset, validation_dataset)
)

In [None]:
from trl import SFTTrainer
from transformers import TrainingArguments, DataCollatorForSeq2Seq
from unsloth import is_bfloat16_supported


# Mixed precision: bf16 when the hardware supports it, fp16 otherwise.
use_bf16 = is_bfloat16_supported()

# Optimisation, logging and checkpointing settings.
# Batches per optimizer step: 4 per device x 8 accumulation steps.
# Evaluation is disabled: no eval strategy or eval steps are configured.
sft_args = TrainingArguments(
    output_dir="output/Llama3.1-ensemble-text-v1-4",
    num_train_epochs=25,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    learning_rate=2e-5,
    lr_scheduler_type="linear",
    warmup_steps=1000,
    weight_decay=0.01,
    optim="adamw_8bit",
    fp16=not use_bf16,
    bf16=use_bf16,
    logging_steps=50,
    save_strategy="steps",
    save_steps=200,
    report_to="none",  # turn off WandB and other log reporting
)

# Pads every batch dynamically, rounding the length up to a multiple of 8.
seq2seq_collator = DataCollatorForSeq2Seq(
    tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
)

# Only the training split is passed; the validation split is not used here.
# NOTE(review): the dataset is already tokenized and dataset_text_field points
# at the "labels" column - confirm that this field is consulted as intended.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=tokenized_train_dataset,
    max_seq_length=max_seq_length,
    packing=False,
    dataset_text_field="labels",
    dataset_num_proc=2,
    args=sft_args,
    data_collator=seq2seq_collator,
)

In [None]:
# Start fine-tuning with the trainer configured in the previous cell; the
# value returned by train() is shown as this cell's output.
trainer.train()

In [None]:
# Save the model
# The model and the tokenizer are written to the same directory.
# NOTE(review): `model` is the PEFT-wrapped model, so this presumably stores
# the LoRA adapter rather than merged weights - confirm before deployment.
save = "save/Llama3.1-ensemble-text-v1-4"
model.save_pretrained(save)
tokenizer.save_pretrained(save)