<a href="https://colab.research.google.com/github/CHEN-886a/bart_pretrain02/blob/main/bart_pretrain02.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
from google.colab import drive
# Mount Google Drive so the data and results paths under /content/drive used below are reachable.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import torch
from transformers import BartTokenizer, BartForConditionalGeneration, Trainer, TrainingArguments, DataCollatorForSeq2Seq
import json

class FinancialNewsDataset(torch.utils.data.Dataset):
    """Map-style dataset of (company name + news content) -> summary pairs read from JSONL.

    Each non-blank line of ``file_path`` must be a JSON object shaped like
    ``{"input": {"company_name": ..., "content": ...}, "output": ...}``.
    All records are loaded eagerly; tokenization happens lazily in ``__getitem__``.

    Args:
        file_path: Path to a UTF-8 JSONL file.
        tokenizer: Hugging Face-style tokenizer callable (e.g. ``BartTokenizer``).
        max_length: Token limit for the encoder input (and, by default, the labels).
            Sequences are truncated and padded to exactly this length.
        max_target_length: Optional separate token limit for the labels.
            Defaults to ``max_length``.
    """

    def __init__(self, file_path, tokenizer, max_length=512, max_target_length=None):
        with open(file_path, 'r', encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing newline at EOF) instead of crashing in json.loads.
            self.data = [json.loads(line) for line in f if line.strip()]
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.max_target_length = max_length if max_target_length is None else max_target_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize and return the example at ``idx``.

        Returns:
            dict of 1-D tensors: ``input_ids`` and ``attention_mask`` (length
            ``max_length``), and ``labels`` (length ``max_target_length``). In
            ``labels``, padded positions are set to -100 so the loss ignores them.
        """
        record = self.data[idx]
        company_name = record["input"]["company_name"]
        news_content = record["input"]["content"]
        combined_input = f"{company_name}: {news_content}"
        summary = record["output"]

        inputs = self.tokenizer(combined_input, return_tensors="pt", max_length=self.max_length, truncation=True, padding="max_length")
        outputs = self.tokenizer(summary, return_tensors="pt", max_length=self.max_target_length, truncation=True, padding="max_length")

        # Take row 0 of the batch explicitly; a bare .squeeze() would also drop a length-1 sequence axis.
        labels = outputs["input_ids"][0].clone()
        # Padded label positions must be -100 (the ignore index used by the seq2seq loss);
        # otherwise the model is trained to emit <pad> tokens. DataCollatorForSeq2Seq only
        # writes -100 for padding it adds itself, and these labels are already full length.
        labels[outputs["attention_mask"][0] == 0] = -100

        return {
            "input_ids": inputs["input_ids"][0],
            "attention_mask": inputs["attention_mask"][0],
            "labels": labels,
        }

# Load the tokenizer and the train/test datasets.
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
train_file_path = "/content/drive/MyDrive/colab notebook/data/train_dataset_bart_02.jsonl"
test_file_path = "/content/drive/MyDrive/colab notebook/data/test_dataset_bart_02.jsonl"

train_dataset = FinancialNewsDataset(train_file_path, tokenizer)
test_dataset = FinancialNewsDataset(test_file_path, tokenizer)

# Check GPU availability.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the pretrained model.
# NOTE: an earlier version passed num_hidden_layers=12 here, intending to add decoder layers.
# BART builds its stacks from config.encoder_layers / config.decoder_layers, so that kwarg only
# overwrote an otherwise-unused config attribute (and was persisted into config.json). To really
# change depth, pass e.g. decoder_layers=12; the extra layers would then be randomly initialised.
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
model.to(device)

# Output locations (shared by checkpointing and the final save below).
results_dir = '/content/drive/MyDrive/colab notebook/results'
trained_model_dir = f'{results_dir}/trained_model'

# Training configuration.
training_args = TrainingArguments(
    output_dir=results_dir,
    num_train_epochs=5,
    per_device_train_batch_size=4,  # tune to the available GPU memory
    per_device_eval_batch_size=4,
    save_steps=1000,  # with load_best_model_at_end, must be a multiple of eval_steps
    save_total_limit=3,
    logging_dir='/content/drive/MyDrive/colab notebook/logs',
    logging_steps=100,
    eval_strategy="steps",
    eval_steps=500,
    load_best_model_at_end=True,
    metric_for_best_model="loss",  # Trainer resolves this to "eval_loss"
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# Build the trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=data_collator,
)

# Train.
trainer.train()

# Save the model and tokenizer (the inference cell loads from this same directory).
model.save_pretrained(trained_model_dir)
tokenizer.save_pretrained(trained_model_dir)


In [None]:
from transformers import BartTokenizer, BartForConditionalGeneration

import torch

# Load the fine-tuned model and tokenizer from the directory the training cell saved to.
# (A bare 'trained_model' is resolved against the Colab working directory, /content,
# where nothing was saved, so it fails on a fresh run.)
model_path = '/content/drive/MyDrive/colab notebook/results/trained_model'
tokenizer = BartTokenizer.from_pretrained(model_path)
model = BartForConditionalGeneration.from_pretrained(model_path)

# Define the device here so this cell does not rely on state left by the training cell.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Example input text.
# NOTE(review): training inputs were formatted as "<company_name>: <content>"; this example
# has no company prefix. Consider matching the training format.
input_text = "Mainland China and Hong Kong stocks ended lower, with a key index logging its fifth straight losing session. Investors were disappointed by a lack of policy stimulus measures amid a weak economic recovery, rising geopolitical tensions and foreign outflows.In France, a leftist alliance unexpectedly took top spot ahead of the far right in Sunday's election, a major upset that was set to prevent Marine Le Pen's National Rally from running the government.The weaker than expected showing for the far right was something of a relief for investors, though they also have concerns the left s plans could unwind many of President Emmanuel Macrons pro-market reforms"

# Upper bound on generated summary length. Without an explicit limit, generate() falls back
# to model.generation_config, which may be the library default of max_length=20 (verify).
max_summary_tokens = 128

# Tokenize with the same 512-token limit used in training, then generate the summary.
# Passing **inputs forwards the attention_mask together with input_ids.
inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True).to(device)
summary_ids = model.generate(**inputs, max_new_tokens=max_summary_tokens)

# Decode the summary and print the result.
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
print("Generated Summary:", summary)