In [2]:
import pandas as pd
import torch
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
from torch.utils.data import DataLoader, Dataset
conversations_df = pd.read_csv("./formatted_csv.csv")

In [12]:
class ConversationDataset(Dataset):
    def __init__(self, conversations, tokenizer, max_length):
        self.conversations = conversations
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.conversations)
    
    def __getitem__(self, idx):
        conversation = self.conversations.iloc[idx]
        speaker1 = str(conversation['Speaker1'])  # Convert to string
        speaker2 = str(conversation['Speaker2'])  # Convert to string
        input_text = speaker1 + " <|endoftext|> " + speaker2
        inputs = self.tokenizer.encode(
            input_text, 
            add_special_tokens=True, 
            max_length=self.max_length,
            pad_to_max_length=True,
            truncation=True
        )
        return {"input_ids": torch.tensor(inputs, dtype=torch.long)}


In [13]:
# Load and preprocess your CSV data

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
max_length = 128
dataset = ConversationDataset(conversations_df, tokenizer, max_length)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

In [14]:
# 2. 模型选择和准备
model_config = BartConfig.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration(config=model_config)

# 3. 微调模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

In [15]:
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    # 使用 tqdm 创建进度条，并在每个 epoch 中更新
    dataloader_iterator = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)
    
    for batch in dataloader_iterator:
        inputs = batch['input_ids'].to(device)
        labels = batch['input_ids'].to(device)
        optimizer.zero_grad()
        outputs = model(input_ids=inputs, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        
        # 更新进度条，显示当前的损失
        dataloader_iterator.set_postfix({"Loss": loss.item()}, refresh=True)

Epoch [1/5] - Loss: 10.96448802947998
Epoch [1/5] - Loss: 10.698569297790527
Epoch [1/5] - Loss: 10.543939590454102
Epoch [1/5] - Loss: 10.395811080932617
Epoch [1/5] - Loss: 10.256211280822754
Epoch [1/5] - Loss: 10.130712509155273
Epoch [1/5] - Loss: 10.011857986450195
Epoch [1/5] - Loss: 9.788372039794922
Epoch [1/5] - Loss: 9.749850273132324
Epoch [1/5] - Loss: 9.563823699951172
Epoch [1/5] - Loss: 9.426315307617188
Epoch [1/5] - Loss: 9.313831329345703
Epoch [1/5] - Loss: 9.151132583618164
Epoch [1/5] - Loss: 8.959895133972168
Epoch [1/5] - Loss: 8.869429588317871
Epoch [1/5] - Loss: 8.610897064208984
Epoch [1/5] - Loss: 8.531085014343262
Epoch [1/5] - Loss: 8.364372253417969
Epoch [1/5] - Loss: 8.247746467590332
Epoch [1/5] - Loss: 8.11317253112793
Epoch [1/5] - Loss: 7.959365367889404
Epoch [1/5] - Loss: 7.7777180671691895
Epoch [1/5] - Loss: 7.629390716552734
Epoch [1/5] - Loss: 7.504663467407227
Epoch [1/5] - Loss: 7.3996148109436035
Epoch [1/5] - Loss: 7.204122066497803
Epoch

KeyboardInterrupt: 

In [None]:
model.save_pretrained("chatbot_bart_model")