In [2]:
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
# Read the conversation table from a CSV file; the path is relative to the
# current working directory.
conversations_df = pd.read_csv(filepath_or_buffer="./formatted_csv.csv")

In [12]:
class ConversationDataset(Dataset):
    """Dataset that tokenizes two-speaker conversation rows on demand.

    For each row, the ``Speaker1`` and ``Speaker2`` fields are cast to ``str``,
    joined with the literal text ``" <|endoftext|> "`` and passed to the
    tokenizer with padding and truncation to ``max_length`` requested.

    Args:
        conversations: Table supporting ``len()`` and ``.iloc[idx]`` whose rows
            can be indexed by ``'Speaker1'`` and ``'Speaker2'``.
        tokenizer: Callable tokenizer returning a mapping with ``"input_ids"``
            and ``"attention_mask"`` (the Hugging Face tokenizer call API).
        max_length: Maximum/padded sequence length requested from the tokenizer.
    """

    def __init__(self, conversations, tokenizer, max_length):
        self.conversations = conversations
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of conversation rows."""
        return len(self.conversations)

    def __getitem__(self, idx):
        """Tokenize the row at position ``idx``.

        Returns:
            dict with two ``torch.long`` tensors: ``"input_ids"`` and
            ``"attention_mask"``.
        """
        conversation = self.conversations.iloc[idx]
        # str() guards against non-string cells; note that a missing value
        # (NaN/None) therefore becomes the literal text "nan"/"None".
        speaker1 = str(conversation['Speaker1'])
        speaker2 = str(conversation['Speaker2'])
        # NOTE(review): "<|endoftext|>" is inserted as plain text here; nothing
        # in this file registers it as a special token - confirm it is the
        # intended separator for this tokenizer.
        input_text = speaker1 + " <|endoftext|> " + speaker2
        # `padding="max_length"` replaces the deprecated `pad_to_max_length=True`.
        # Calling the tokenizer (instead of `.encode`) also yields the attention
        # mask that marks real tokens versus padding.
        encoded = self.tokenizer(
            input_text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
        )
        return {
            "input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
        }


In [13]:
# Build the tokenizer, wrap the conversation table in a dataset, and batch it.

max_length = 128
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
dataset = ConversationDataset(
    conversations=conversations_df,
    tokenizer=tokenizer,
    max_length=max_length,
)
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

In [14]:
# 2. Model selection and preparation
model_config = BartConfig.from_pretrained("facebook/bart-base")
# Load the pretrained weights. Constructing the model from the config alone
# (BartForConditionalGeneration(config=...)) yields randomly initialised
# parameters, which is training from scratch rather than fine-tuning.
model = BartForConditionalGeneration.from_pretrained(
    "facebook/bart-base", config=model_config
)

# 3. Fine-tune the model
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

In [None]:
num_epochs = 5
# Id the model treats as padding; used to derive a mask when the batch does
# not carry one.
pad_token_id = model.config.pad_token_id
for epoch in range(num_epochs):
    model.train()
    # Wrap the dataloader in a tqdm progress bar, recreated for every epoch.
    dataloader_iterator = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)

    for batch in dataloader_iterator:
        inputs = batch['input_ids'].to(device)

        # Prefer a dataset-provided attention mask; otherwise treat every
        # non-pad position as a real token.
        attention_mask = batch.get('attention_mask')
        if attention_mask is None:
            attention_mask = inputs.ne(pad_token_id)
        attention_mask = attention_mask.to(device)

        # Labels set to -100 are ignored by the loss, so padded positions no
        # longer contribute to it.
        # NOTE(review): the labels are the inputs themselves, so the model is
        # trained to reproduce the whole concatenated conversation - confirm
        # this is the intended objective.
        labels = inputs.masked_fill(attention_mask.eq(0), -100)

        optimizer.zero_grad()
        outputs = model(input_ids=inputs, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        # Show the current batch loss on the progress bar.
        dataloader_iterator.set_postfix({"Loss": loss.item()}, refresh=True)

In [None]:
# Save the trained model (weights + config) and the tokenizer files into the
# same directory so the checkpoint can be reloaded on its own.
model.save_pretrained("chatbot_bart_model")
tokenizer.save_pretrained("chatbot_bart_model")