In [1]:
import torch
import transformers

# Record the library versions this notebook ran with, one per line.
print(torch.__version__, transformers.__version__, sep="\n")

2.0.1+cpu
4.12.1


In [4]:
import json
import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    """SQuAD-style question-answering dataset of raw text triples.

    The JSON file must have the layout
    ``{"data": [{"paragraphs": [{"context": str, "qas": [{"question": str,
    "answers": [{"text": str, ...}, ...]}, ...]}, ...]}, ...]}``.
    Every question contributes one example that pairs it with its paragraph
    context and the text of its *first* answer.

    Args:
        file_path: Path to the UTF-8 encoded JSON file.
        tokenizer: Stored on the instance as ``self.tokenizer``; not used by
            the dataset itself.
        mode: ``"train"`` keeps ``data[:split_index]``; any other value keeps
            ``data[split_index:]``.
        max_context_length: Stored on the instance; not applied here.
        max_question_length: Stored on the instance; not applied here.
        split_index: Index into the top-level ``data`` list that separates the
            train slice from the held-out slice.
    """

    def __init__(self, file_path, tokenizer, mode='train', max_context_length=512,
                 max_question_length=128, split_index=6000):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_context_length = max_context_length
        self.max_question_length = max_question_length
        self.data = []

        # Read the file once, then pick the slice that matches ``mode``.
        with open(file_path, 'r', encoding='utf-8') as f:
            articles = json.load(f)['data']
        raw_data = articles[:split_index] if mode == "train" else articles[split_index:]

        for article in raw_data:
            for paragraph in article['paragraphs']:
                context = paragraph['context']
                for qa in paragraph['qas']:
                    answers = qa.get('answers') or []
                    if not answers:
                        # Questions without any answer have no target text to
                        # learn from, so they are skipped instead of raising
                        # IndexError on answers[0].
                        continue
                    self.data.append({
                        'context': context,
                        'question': qa['question'],
                        'answer': answers[0]['text']
                    })

    def __len__(self):
        """Return the number of (context, question, answer) examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return example ``idx`` as a ``(context, question, answer)`` tuple of strings."""
        example = self.data[idx]
        return example['context'], example['question'], example['answer']

In [5]:
# model

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel


class MyGPTModel(nn.Module):
    """GPT-2 token tagger for extractive question answering.

    The language-model head of the pretrained ``GPT2DoubleHeadsModel`` is
    replaced by a single-logit linear layer, so each token gets one score that
    is trained with binary cross-entropy to be high on answer tokens.

    Every example is fed to GPT-2 as ``question tokens + context tokens``.
    GPT-2 attends left-to-right, so putting the question first lets every
    context token see it. Only context positions contribute to the loss.

    Args:
        tokenizer: Tokenizer providing ``encode(text, add_special_tokens=False)``
            and ``__len__`` (e.g. ``GPT2Tokenizer``). No padding token is needed;
            batches are padded manually and masked.
        max_context_length: Maximum number of context tokens kept per example.
        max_question_length: Maximum number of question tokens kept per example.

    Raises:
        ValueError: If the two maximum lengths together exceed the number of
            positions supported by the pretrained model.
    """

    def __init__(self, tokenizer, max_context_length=512, max_question_length=128):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_context_length = max_context_length
        self.max_question_length = max_question_length
        self.model = GPT2DoubleHeadsModel.from_pretrained('gpt2')
        self.model.resize_token_embeddings(len(tokenizer))

        config = self.model.config
        if max_context_length + max_question_length > config.n_positions:
            raise ValueError(
                'max_context_length + max_question_length must not exceed '
                f'{config.n_positions} positions'
            )
        # The new 1-logit head has a different shape than the embedding matrix,
        # so it must never be re-tied to the input embeddings.
        config.tie_word_embeddings = False
        self.model.lm_head = nn.Linear(config.n_embd, 1)
        # multiple_choice_head is a SequenceSummary wrapper (called with two
        # arguments inside the wrapped model), not an nn.Linear. Only its inner
        # projection is swapped so the wrapped forward pass keeps working.
        self.model.multiple_choice_head.summary = nn.Linear(config.n_embd, 2)

    @staticmethod
    def _as_batch(value):
        """Wrap a single string in a list; turn any other sequence into a list."""
        return [value] if isinstance(value, str) else list(value)

    def _pad_token_id(self):
        """Return an id usable for padding (padded positions are masked out)."""
        for attribute in ('pad_token_id', 'eos_token_id'):
            token_id = getattr(self.tokenizer, attribute, None)
            if token_id is not None:
                return token_id
        return 0

    def _encode_example(self, context, question, answer):
        """Tokenize one example and tag every token as answer (1) or not (0).

        The context is split at the first occurrence of ``answer`` and the
        pieces before, inside and after the answer are tokenized separately.
        Token labels therefore line up exactly with the answer text, instead of
        reusing character offsets as token indices. If the answer does not
        occur in the context, all tokens are tagged 0.

        Returns:
            ``(token_ids, tags, n_question_tokens)`` where ``token_ids`` and
            ``tags`` are equally long Python lists covering the truncated
            question followed by the truncated context.
        """
        def encode(text):
            return self.tokenizer.encode(text, add_special_tokens=False) if text else []

        question_ids = encode(question)[:self.max_question_length]

        start = context.find(answer) if answer else -1
        if start < 0:
            segments = [(context, 0)]
        else:
            end = start + len(answer)
            segments = [(context[:start], 0), (context[start:end], 1), (context[end:], 0)]

        context_ids, context_tags = [], []
        for text, tag in segments:
            piece_ids = encode(text)
            context_ids.extend(piece_ids)
            context_tags.extend([tag] * len(piece_ids))
        # Truncation happens on tokens, so an answer beyond the limit is simply
        # cut off together with the rest of the context.
        context_ids = context_ids[:self.max_context_length]
        context_tags = context_tags[:self.max_context_length]

        token_ids = question_ids + context_ids
        tags = [0] * len(question_ids) + context_tags
        return token_ids, tags, len(question_ids)

    def forward(self, context, question, answer):
        """Compute the token-tagging loss for one example or a batch.

        Args:
            context: Context string, or a sequence of context strings.
            question: Question string(s), aligned with ``context``.
            answer: Answer string(s), aligned with ``context``.

        Returns:
            Scalar tensor: mean binary cross-entropy over all context tokens
            in the batch (question and padding positions are excluded).

        Raises:
            ValueError: If the three inputs differ in length or contain no
                tokens at all.
        """
        contexts = self._as_batch(context)
        questions = self._as_batch(question)
        answers = self._as_batch(answer)
        if not (len(contexts) == len(questions) == len(answers)):
            raise ValueError('context, question and answer must have the same batch size')

        encoded = [
            self._encode_example(c, q, a)
            for c, q, a in zip(contexts, questions, answers)
        ]
        longest = max((len(ids) for ids, _, _ in encoded), default=0)
        if longest == 0:
            raise ValueError('forward() received no tokens to score')

        # Build every tensor directly on the device the weights live on.
        device = next(self.model.parameters()).device
        batch_shape = (len(encoded), longest)
        input_ids = torch.full(batch_shape, self._pad_token_id(), dtype=torch.long, device=device)
        attention_mask = torch.zeros(batch_shape, dtype=torch.long, device=device)
        labels = torch.zeros(batch_shape, dtype=torch.float, device=device)
        context_mask = torch.zeros(batch_shape, dtype=torch.bool, device=device)
        for row, (ids, tags, n_question) in enumerate(encoded):
            n_tokens = len(ids)
            input_ids[row, :n_tokens] = torch.tensor(ids, dtype=torch.long, device=device)
            attention_mask[row, :n_tokens] = 1
            labels[row, :n_tokens] = torch.tensor(tags, dtype=torch.float, device=device)
            context_mask[row, n_question:n_tokens] = True

        # Keyword arguments are required: the second positional parameter of
        # the wrapped model is past_key_values, not a second input sequence.
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        token_logits = outputs.logits.squeeze(-1)

        if not context_mask.any():
            # No context tokens to score: return a zero that still carries a
            # graph so loss.backward() does not fail.
            return token_logits.sum() * 0.0
        return F.binary_cross_entropy_with_logits(token_logits[context_mask], labels[context_mask])

In [None]:
class Trainer:
    """Training / evaluation loop for a question-answering token tagger.

    The wrapped ``model`` must return a scalar loss from
    ``model(context_batch, question_batch, answer_batch)`` and expose
    ``tokenizer``, ``max_context_length``, ``max_question_length`` and the
    underlying network as ``model.model``.

    Evaluation feeds each example to ``model.model`` as question tokens
    followed by context tokens and reports token-level accuracy over the
    context positions only.

    Args:
        model: The module to train (see requirements above).
        train_dataset: Dataset yielding ``(context, question, answer)`` strings.
        test_dataset: Dataset of the same form used for evaluation.
        device: Target device. A CUDA device is replaced by ``'cpu'`` when CUDA
            is not available.
        batch_size: Batch size for both data loaders.
        learning_rate: Learning rate for the Adam optimizer.
        num_epochs: Number of passes over the training data.
    """

    def __init__(self, model, train_dataset, test_dataset, device='cuda', batch_size=8, learning_rate=1e-5, num_epochs=10):
        self.model = model
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        if str(device).startswith('cuda') and not torch.cuda.is_available():
            # Avoid crashing on CPU-only installations of PyTorch.
            print(f"Device '{device}' requested but CUDA is not available; using 'cpu' instead.")
            device = 'cpu'
        self.device = device
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.num_epochs = num_epochs
        self.train_dataloader = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)
        self.test_dataloader = DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)

    def _encode_example(self, context, question, answer):
        """Build evaluation inputs and token-level labels for one example.

        The context is split at the first occurrence of ``answer`` and the
        pieces are tokenized separately, so the 0/1 labels line up with tokens
        rather than with character offsets. An answer that does not occur in
        the context yields all-zero labels.

        Returns:
            ``(input_ids, labels, n_question_tokens)``. ``input_ids`` has shape
            ``(1, n_question_tokens + n_context_tokens)`` and ``labels`` has
            shape ``(1, n_context_tokens)``; both live on ``self.device``.
        """
        tokenizer = self.model.tokenizer

        def encode(text):
            return tokenizer.encode(text, add_special_tokens=False) if text else []

        question_ids = encode(question)[:self.model.max_question_length]

        start = context.find(answer) if answer else -1
        if start < 0:
            segments = [(context, 0)]
        else:
            end = start + len(answer)
            segments = [(context[:start], 0), (context[start:end], 1), (context[end:], 0)]

        context_ids, tags = [], []
        for text, tag in segments:
            piece_ids = encode(text)
            context_ids.extend(piece_ids)
            tags.extend([tag] * len(piece_ids))
        context_ids = context_ids[:self.model.max_context_length]
        tags = tags[:self.model.max_context_length]

        input_ids = torch.tensor([question_ids + context_ids], dtype=torch.long, device=self.device)
        labels = torch.tensor([tags], dtype=torch.long, device=self.device)
        return input_ids, labels, len(question_ids)

    def evaluate(self):
        """Return token-level accuracy on the test set.

        A context token counts as correct when the thresholded sigmoid of its
        logit (> 0.5) equals its 0/1 answer label. Returns ``0.0`` when the
        test set contains no context tokens.
        """
        self.model.to(self.device)
        self.model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for context_batch, question_batch, answer_batch in self.test_dataloader:
                for context, question, answer in zip(context_batch, question_batch, answer_batch):
                    input_ids, labels, n_question = self._encode_example(context, question, answer)
                    if labels.numel() == 0:
                        continue
                    # Keyword argument on purpose: the second positional
                    # parameter of the wrapped network is not an input sequence.
                    outputs = self.model.model(input_ids=input_ids, return_dict=True)
                    context_logits = outputs.logits.squeeze(-1)[:, n_question:]
                    preds = (torch.sigmoid(context_logits) > 0.5).long()
                    # Count tokens, not rows, so the ratio stays within [0, 1].
                    total += labels.numel()
                    correct += (preds == labels).sum().item()
        return correct / total if total else 0.0

    def train(self):
        """Train for ``num_epochs`` epochs, printing loss and test accuracy per epoch."""
        self.model.to(self.device)
        for epoch in range(self.num_epochs):
            self.model.train()
            total_loss = 0.0
            for context_batch, question_batch, answer_batch in self.train_dataloader:
                self.optimizer.zero_grad()
                loss = self.model(context_batch, question_batch, answer_batch)
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()
            # max(..., 1) keeps an empty training set from dividing by zero.
            train_loss = total_loss / max(len(self.train_dataloader), 1)
            print(f'Epoch {epoch+1}, train_loss: {train_loss:.4f}')

            test_accuracy = self.evaluate()
            print(f'Epoch {epoch+1}, test_accuracy: {test_accuracy:.4f}')

In [None]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# Positional arguments must come before keyword arguments
# (`mode="train", tokenizer` is a SyntaxError).
train_dataset = MyDataset('path/to/big_train_data.json', tokenizer, mode="train")
# The test data lives in its own file, so it is read with the default mode,
# which keeps the file's first 6000 articles rather than those from 6000 on.
test_dataset = MyDataset('path/to/test_data.json', tokenizer)
model = MyGPTModel(tokenizer)
# Pick the device explicitly so the notebook also runs on CPU-only PyTorch builds.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
trainer = Trainer(model, train_dataset, test_dataset, device=device)
trainer.train()

在`MyGPTModel`中，我们通过`max_context_length`和`max_question_length`来限制输入的最大长度。如果训练数据中答案的结束位置超过了`max_context_length`，我们会将其截断到`max_context_length`以内。在`Trainer`中，我们将模型和数据都移到GPU上，并在评估阶段使用`torch.no_grad()`来关闭梯度计算，在测试集上计算模型的准确率。