In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
from transformers import Trainer, TrainingArguments
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import torch 
import numpy as np
from transformers import DataCollatorWithPadding
from peft import get_peft_model, LoraConfig, TaskType


# ---- Experiment configuration ---------------------------------------------
MODEL_NAME = "google/byt5-small"  # alternative checkpoint: "bert-base-uncased"
MAX_LENGTH = 1024                 # max_length passed to the tokenizer (with truncation)
BATCH_SIZE = 12
NUM_TRAIN_EPOCHS = 3
NUM_LABELS = 10

# Feature switches -- flip the value to toggle.
peft_en = False     # True: wrap the model with LoRA adapters before training
byte_input = True   # True: use the raw `byte_array` field; False: use the `hex` field
report_en = True    # True: report metrics to Weights & Biases; False: no reporting

dataset = load_dataset(
    "json",
    data_files="./datasets/cifar10_bytes_from_pil.jsonl",
    split="train",
)

def byte_list_to_char_list(byte_list):
    """Return a list holding, for each integer in *byte_list*, the one-character string ``chr(value)``."""
    return list(map(chr, byte_list))

def rename_column(example):
    """Build the ``{"text", "label"}`` record consumed by the tokenisation step.

    When ``byte_input`` is set, ``text`` is the example's ``byte_array`` converted
    to a list of one-character strings; otherwise it is the example's ``hex``
    field unchanged. ``label`` is copied through as-is.
    """
    text = byte_list_to_char_list(example["byte_array"]) if byte_input else example["hex"]
    return {"text": text, "label": example["label"]}

# Add the "text" column, then hold out 10% of the rows (fixed seed) for validation.
dataset = dataset.map(rename_column).train_test_split(test_size=0.1, seed=42)
train_ds, val_ds = dataset["train"], dataset["test"]

# Keep the first ten training texts for a prediction sanity check after training,
# and show the corresponding raw rows.
dataset_example = train_ds.select(range(10))
sample_text = [row["text"] for row in dataset_example]
for row in dataset_example:
    print(row)

# Load the tokenizer and a sequence-classification model whose NUM_LABELS-way
# head is newly initialised (the checkpoint ships no classification head).
config = AutoConfig.from_pretrained(MODEL_NAME)
config.num_labels = NUM_LABELS
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, config=config)

def _has_byte_vocab(tok):
    """Return True if each of chr(0)..chr(255) is a distinct, known single token in *tok*.

    This holds for byte-level vocabularies such as ByT5's. It fails for, e.g.,
    WordPiece vocabularies, which map many of these characters to the unknown token.
    """
    ids = tok.convert_tokens_to_ids([chr(b) for b in range(256)])
    return tok.unk_token_id not in ids and len(set(ids)) == 256


def tokenize_fn(example):
    """Encode a batch of ``text`` values into fixed-length (MAX_LENGTH) model inputs.

    Args:
        example: a batch dict (``datasets.map(batched=True)``) whose ``"text"``
            entry holds strings, or, in byte mode, lists of ``chr(b)`` characters.

    Returns:
        A dict/BatchEncoding of per-example lists (``input_ids``,
        ``attention_mask``, plus whatever else the tokenizer emits), padded and
        truncated to MAX_LENGTH.

    Byte mode: passing ``chr(b)`` characters through a byte-level tokenizer
    such as ByT5's UTF-8-encodes them again. Every byte >= 0x80 then turns into
    two tokens (0xFF -> 0xC3 0xBF), so the model never sees the original bytes
    and more of each sample is lost to truncation. For byte-level vocabularies
    each ``chr(b)`` is therefore looked up directly as a token, which yields
    exactly one id per input byte. ``prepare_for_model`` then adds special
    tokens, truncation and padding the same way the tokenizer call would.
    """
    if not byte_input:
        return tokenizer(
            example["text"],
            padding="max_length",
            truncation=True,
            max_length=MAX_LENGTH,
        )

    if not _has_byte_vocab(tokenizer):
        # No one-token-per-byte lookup available: keep the tokenizer's own handling.
        return tokenizer(
            example["text"],
            is_split_into_words=True,
            padding="max_length",
            truncation=True,
            max_length=MAX_LENGTH,
        )

    features = {}
    for chars in example["text"]:
        encoded = tokenizer.prepare_for_model(
            tokenizer.convert_tokens_to_ids(list(chars)),
            padding="max_length",
            truncation=True,
            max_length=MAX_LENGTH,
        )
        for key, value in encoded.items():
            features.setdefault(key, []).append(value)
    return features


# Tokenize both splits (train first, then validation), then expose only the
# model inputs and the label, as torch tensors.
train_ds, val_ds = (split_ds.map(tokenize_fn, batched=True) for split_ds in (train_ds, val_ds))
for split_ds in (train_ds, val_ds):
    split_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

print(train_ds[0])

def compute_metrics(eval_pred):
    """Return ``{"accuracy": ...}`` for a Trainer evaluation pass.

    ``eval_pred`` unpacks into ``(predictions, label_ids)``. If the predictions
    arrive as a tuple (several model outputs), its first element is used as
    the logits. The predicted class is the arg-max over the last axis.
    """
    predictions, label_ids = eval_pred
    class_logits = predictions[0] if isinstance(predictions, tuple) else predictions
    predicted_classes = np.argmax(class_logits, axis=-1)
    return {"accuracy": accuracy_score(label_ids, predicted_classes)}

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# LoRA adapters on the T5/ByT5 attention projections named "q" and "v".
lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q", "v"],
    # PEFT's SEQ_CLS task only keeps modules named "classifier"/"score"
    # trainable. T5ForSequenceClassification names its freshly initialised
    # head "classification_head", so it must be listed here, otherwise the
    # random head stays frozen.
    modules_to_save=["classification_head"],
)

# Wrap the model BEFORE building the Trainer, so the Trainer optimises,
# evaluates and checkpoints the PEFT model rather than the bare backbone.
if peft_en:
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()  # expect only LoRA + classification-head params


def preprocess_logits_for_metrics(logits, labels):
    """Keep only the classification logits from each evaluation batch.

    Seq2seq classifiers return extra tensors alongside the logits (presumably
    including T5's encoder_last_hidden_state, batch x seq_len x d_model --
    verify). Without this hook the Trainer accumulates all of them over the
    whole evaluation set.
    """
    return logits[0] if isinstance(logits, tuple) else logits


training_args = TrainingArguments(
    output_dir=f"./{MODEL_NAME}-byte-classifier",
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE*2,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    learning_rate=2e-4 if peft_en else 5e-5,  # LoRA tolerates a larger step size
    report_to="wandb" if report_en else "none",
    eval_strategy="epoch",
    logging_strategy="steps",
    logging_steps=500,
    logging_dir="./logs",
    save_strategy="epoch",
    load_best_model_at_end=True,
    save_total_limit=2,
    metric_for_best_model="accuracy",
    greater_is_better=True,
    auto_find_batch_size=True,
    bf16=True,
    bf16_full_eval=True,
    eval_accumulation_steps=12,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)

trainer.train()

def predict_batch(texts):
    """Classify a batch of inputs with the current global ``model``.

    Args:
        texts: a list of inputs in the same format as the dataset's ``text``
            column (lists of one-character strings in byte mode, strings otherwise).

    Returns:
        ``(preds, probs)``: the arg-max class id for each input, and the
        per-class softmax probabilities for each input, both as plain Python lists.
    """
    if not texts:
        return [], []

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Disable dropout: the model may still be in training mode after trainer.train().
    model.eval()

    # Reuse the training-time preprocessing so inference inputs are encoded,
    # truncated and padded exactly like the training data (fixed MAX_LENGTH,
    # so the lists are rectangular).
    encoded = tokenize_fn({"text": texts})
    inputs = {name: torch.tensor(values, device=device) for name, values in encoded.items()}

    with torch.no_grad():
        logits = model(**inputs).logits       # (batch_size, num_labels)
        probs = logits.softmax(dim=-1)        # (batch_size, num_labels)
        preds = probs.argmax(dim=-1)          # (batch_size,)

    # Convert to plain Python types for printing / serialisation.
    return preds.tolist(), probs.tolist()



preds, probs = predict_batch(sample_text)

# One line per sample: the predicted class id, then the full probability vector.
for predicted_class, class_probs in zip(preds, probs):
    print(f"预测类别: {predicted_class} 置信度: {class_probs}")




{'label': 4, 'byte_array': [255, 216, 255, 224, 0, 16, 74, 70, 73, 70, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 255, 219, 0, 67, 0, 8, 6, 6, 7, 6, 5, 8, 7, 7, 7, 9, 9, 8, 10, 12, 20, 13, 12, 11, 11, 12, 25, 18, 19, 15, 20, 29, 26, 31, 30, 29, 26, 28, 28, 32, 36, 46, 39, 32, 34, 44, 35, 28, 28, 40, 55, 41, 44, 48, 49, 52, 52, 52, 31, 39, 57, 61, 56, 50, 60, 46, 51, 52, 50, 255, 219, 0, 67, 1, 9, 9, 9, 12, 11, 12, 24, 13, 13, 24, 50, 33, 28, 33, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 255, 192, 0, 17, 8, 0, 32, 0, 32, 3, 1, 34, 0, 2, 17, 1, 3, 17, 1, 255, 196, 0, 31, 0, 0, 1, 5, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 255, 196, 0, 181, 16, 0, 2, 1, 3, 3, 2, 4, 3, 5, 5, 4, 4, 0, 0, 1, 125, 1, 2, 3, 0, 4, 17, 5, 18, 33, 49, 65, 6, 19, 81, 97, 7, 34, 113, 20, 50, 129, 145, 161, 8, 35, 66, 177, 193, 21, 82, 2

Some weights of T5ForSequenceClassification were not initialized from the model checkpoint at google/byt5-small and are newly initialized: ['classification_head.dense.bias', 'classification_head.dense.weight', 'classification_head.out_proj.bias', 'classification_head.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


{'label': tensor(4), 'input_ids': tensor([198, 194, 198,  ..., 102, 197,   1]), 'attention_mask': tensor([1, 1, 1,  ..., 1, 1, 1])}


[34m[1mwandb[0m: Currently logged in as: [33mypfree[0m ([33mypfree-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 