In [2]:
#gpt2
import torch.nn as nn
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding, TrainingArguments, Trainer
import torch
from transformers import GPT2Config, GPT2Tokenizer, GPT2ForSequenceClassification
import numpy as np
import evaluate
import torch.nn.functional as F



# Fix the NumPy and PyTorch random seeds for reproducibility.
np.random.seed(56)
torch.manual_seed(48)

# Read the configuration of the local GPT-2 checkpoint, overriding the depth
# so that the model is built with only 2 transformer layers.
config = GPT2Config.from_pretrained("./GPT-2", n_layer=2)

# Tokenizer and sequence-classification model come from the same local
# checkpoint directory; the model is instantiated from the reduced config.
tokenizer = GPT2Tokenizer.from_pretrained("./GPT-2")
model = GPT2ForSequenceClassification.from_pretrained("./GPT-2", config=config)




# Evaluation metric, loaded from a local metric script.
metric = evaluate.load("./tools/accuracy.py")


def compute_metrics(eval_pred):
    """Score one evaluation pass with the loaded metric.

    ``eval_pred`` unpacks into ``(logits, labels)``. The predicted class of
    each row is the index of its largest logit along the last axis. Returns
    whatever ``metric.compute`` produces for those predictions and labels.
    """
    scores, references = eval_pred
    predicted = np.argmax(scores, axis=-1)
    return metric.compute(predictions=predicted, references=references)







# Path of the training CSV; the file must already be downloaded locally.
data_file = "./data/train.csv"

# Read the CSV and discard rows whose "seq" value is missing.
dataset = load_dataset("csv", data_files=data_file).filter(
    lambda row: row["seq"] is not None
)

# Hold out 10% of the remaining rows; the result has "train" and "test" splits.
datasets = dataset["train"].train_test_split(test_size=0.1)

# Tokenizer used for dataset preprocessing.
# NOTE: this rebinds `tokenizer`, replacing the GPT2Tokenizer created earlier
# in the script. "GPT-2" is a relative path and is expected to resolve to the
# same local directory as "./GPT-2" — confirm when changing the working dir.
tokenizer = AutoTokenizer.from_pretrained("GPT-2")


def process_function(examples, max_length=500):
    """Tokenize a batch of sequences one character at a time.

    Each string in ``examples["seq"]`` is expanded so that its characters are
    separated by single spaces before being passed to the module-level
    ``tokenizer``. The input batch is left unmodified.

    Args:
        examples: Batch mapping with a ``"seq"`` column (strings) and a
            ``"label"`` column.
        max_length: Maximum number of tokens kept per sequence; longer
            sequences are truncated.

    Returns:
        The tokenizer output for the batch with an added ``"labels"`` entry
        copied from ``examples["label"]``.
    """
    # Build a new list instead of overwriting examples["seq"] in place, so the
    # caller's batch is not mutated as a side effect of tokenization.
    spaced_sequences = [" ".join(seq) for seq in examples["seq"]]
    tokenized_examples = tokenizer(
        spaced_sequences, max_length=max_length, truncation=True
    )
    tokenized_examples["labels"] = examples["label"]
    return tokenized_examples

# Tokenize both splits. No padding happens here; sequences are padded per
# batch later by the data collator.
tokenized_datasets = datasets.map(process_function, batched=True)

# Use the EOS token for padding and make the model agree with the tokenizer.
# GPT2ForSequenceClassification relies on config.pad_token_id to find the last
# non-padding token of each row, so that id must be the one the tokenizer
# actually pads with. The previous code additionally registered a separate
# '[PAD]' token on the tokenizer only, which left the model's pad id pointing
# at EOS and could introduce a token id with no embedding row.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id


# Training configuration: one epoch, evaluation and checkpointing once per
# epoch, and the checkpoint with the best accuracy restored at the end.
# NOTE(review): save_steps presumably has no effect while save_strategy is
# "epoch" — confirm before relying on it.
args = TrainingArguments(
    output_dir="model_for_seqclassification",
    seed=4112,
    num_train_epochs=1,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    fp16=True,
    logging_steps=10,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_steps=4,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
)

# Wire the model, data splits, padding collator and metric into a Trainer,
# then run fine-tuning.
trainer = Trainer(
    model=model,
    args=args,
    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()


# Path of the test CSV; the file must already be downloaded locally.
data_file = "./data/test.csv"

# Read the CSV and discard rows whose "seq" value is missing. The rows are
# taken from the "train" key of the loaded dataset.
dataset = load_dataset("csv", data_files=data_file).filter(
    lambda row: row["seq"] is not None
)
datasets = dataset["train"]

# Tokenize with the same process_function used for the training data (it reads
# both the "seq" and "label" columns), then run the fine-tuned model on the
# test set and print the raw prediction output.
test_datasets = datasets.map(process_function, batched=True)
predictions = trainer.predict(test_datasets)
print(predictions)

Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at ./GPT-2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Map: 100%|██████████| 425/425 [00:00<00:00, 2240.28 examples/s]
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.6772,0.694218,0.508235


PredictionOutput(predictions=array([[ 2.26898193e-02, -3.03649902e-02],
       [-5.37597656e-01, -6.26464844e-01],
       [-1.58996582e-02,  3.30200195e-02],
       [ 8.11157227e-02,  4.89501953e-02],
       [-3.43017578e-02, -1.34521484e-01],
       [-2.37792969e-01, -2.64160156e-01],
       [-3.97949219e-01, -5.11718750e-01],
       [-5.83496094e-01, -6.38671875e-01],
       [-6.87011719e-01, -6.26464844e-01],
       [-1.42578125e-01, -1.52099609e-01],
       [-5.60058594e-01, -6.54296875e-01],
       [-5.62500000e-01, -6.23535156e-01],
       [-4.09423828e-01, -4.83398438e-01],
       [-3.97644043e-02, -9.10644531e-02],
       [-3.78417969e-01, -4.52636719e-01],
       [-2.55126953e-01, -2.95166016e-01],
       [-4.04541016e-01, -4.48242188e-01],
       [-4.05029297e-01, -4.48486328e-01],
       [-3.61022949e-02, -9.23461914e-02],
       [-1.32324219e-01, -1.29028320e-01],
       [-4.55810547e-01, -5.55664062e-01],
       [-5.40039062e-01, -6.63574219e-01],
       [ 1.30176544e-03, 