# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertModel, BertTokenizer
from sklearn.feature_extraction.text import TfidfVectorizer
from torch.utils.data import DataLoader, Dataset

# Dataset class
class TextDataset(Dataset):
    """Map-style dataset that pairs each sample with its label.

    ``texts`` and ``labels`` are stored exactly as given and are indexed in
    parallel, so both must support indexing and ``texts`` must support
    ``len()``.
    """

    def __init__(self, texts, labels):
        self.texts, self.labels = texts, labels

    def __len__(self):
        # The dataset size is defined by the samples, not by the labels.
        sample_count = len(self.texts)
        return sample_count

    def __getitem__(self, idx):
        # Return the (sample, label) pair stored at position ``idx``.
        sample = self.texts[idx]
        target = self.labels[idx]
        return sample, target

# BERT feature extraction
class BERTFeatureExtractor(nn.Module):
    """Wrap a pretrained BERT encoder and expose one vector per sequence.

    Args:
        pretrained_model_name: Name or path handed to
            ``BertModel.from_pretrained``.
    """

    def __init__(self, pretrained_model_name='bert-base-uncased'):
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)

    def forward(self, input_ids, attention_mask):
        """Return the encoder's last hidden state at sequence position 0."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = encoded.last_hidden_state
        # Index 0 along dim 1 selects the first token of every sequence
        # (typically the [CLS] token when a BERT tokenizer built the inputs).
        first_token_embedding = hidden_states[:, 0, :]
        return first_token_embedding

# TF-IDF feature extraction
class TFIDFFeatureExtractor:
    """Thin wrapper around a scikit-learn ``TfidfVectorizer``.

    The vectorizer is created with its default settings and kept on
    ``self.vectorizer``; both methods simply delegate to it.
    """

    def __init__(self):
        self.vectorizer = TfidfVectorizer()

    def fit_transform(self, texts):
        """Fit the vectorizer on ``texts`` and return what it produces."""
        tfidf_matrix = self.vectorizer.fit_transform(texts)
        return tfidf_matrix

    def transform(self, texts):
        """Vectorize ``texts`` using the vectorizer's current state."""
        tfidf_matrix = self.vectorizer.transform(texts)
        return tfidf_matrix

# Two-stage model
class TwoStageModel(nn.Module):
    """Classifier over concatenated BERT and TF-IDF features.

    A ``BERTFeatureExtractor`` turns token ids into one vector per sample;
    that vector is concatenated with precomputed TF-IDF features and fed to
    a single linear layer that produces the class logits.

    Args:
        bert_feature_size: Width of the vectors returned by the BERT
            extractor (must match the encoder's hidden size).
        tfidf_feature_size: Number of TF-IDF features per sample.
        num_classes: Number of output logits.
    """

    def __init__(self, bert_feature_size, tfidf_feature_size, num_classes):
        super(TwoStageModel, self).__init__()
        self.bert_extractor = BERTFeatureExtractor()
        # Not used by forward(), which receives precomputed TF-IDF features;
        # the attribute is kept so existing code that reads it keeps working.
        self.tfidf_extractor = TFIDFFeatureExtractor()
        self.linear = nn.Linear(bert_feature_size + tfidf_feature_size, num_classes)

    def forward(self, bert_input_ids, bert_attention_mask, tfidf_features):
        """Return class logits for a batch.

        Args:
            bert_input_ids: Token ids passed to the BERT extractor.
            bert_attention_mask: Attention mask passed to the BERT extractor.
            tfidf_features: Per-sample TF-IDF features; a tensor or any
                array-like accepted by ``torch.as_tensor`` whose second
                dimension equals ``tfidf_feature_size``.

        Returns:
            The output of the linear layer, one row of logits per sample.
        """
        bert_features = self.bert_extractor(bert_input_ids, bert_attention_mask)
        # TF-IDF matrices usually come from NumPy as float64. Concatenating
        # them unchanged would either fail or promote the result to float64,
        # which the float32 linear layer rejects, so align dtype and device
        # with the BERT features first.
        tfidf_features = torch.as_tensor(
            tfidf_features,
            dtype=bert_features.dtype,
            device=bert_features.device,
        )
        combined_features = torch.cat((bert_features, tfidf_features), dim=1)
        return self.linear(combined_features)

# Data preprocessing and model training code
def train_model(train_texts, train_labels, val_texts, val_labels, num_epochs=10, batch_size=32, learning_rate=1e-3):
    """Train a ``TwoStageModel`` on BERT + TF-IDF features and return it.

    The texts are tokenized with the ``bert-base-uncased`` tokenizer, a
    TF-IDF vectorizer is fitted on the training texts, and the model is
    optimized with Adam and cross-entropy loss. The mean training loss is
    printed after every epoch; when validation data is supplied, the
    validation loss and accuracy are printed as well.

    Args:
        train_texts: Sequence of training strings.
        train_labels: Integer class indices, one per training text.
        val_texts: Sequence of validation strings; may be empty or ``None``
            to skip validation.
        val_labels: Integer class indices, one per validation text.
        num_epochs: Number of passes over the training data.
        batch_size: Mini-batch size for both loaders.
        learning_rate: Learning rate handed to Adam.

    Returns:
        The trained ``TwoStageModel``. The model is built with two output
        classes and a BERT feature size of 768, as hard-coded below.

    Raises:
        ValueError: If the training set is empty or a text/label pair of
            sequences differs in length.
    """
    if len(train_texts) == 0:
        raise ValueError("train_texts must contain at least one sample")
    if len(train_texts) != len(train_labels):
        raise ValueError("train_texts and train_labels must have the same length")
    has_validation = (
        val_texts is not None and val_labels is not None and len(val_texts) > 0
    )
    if has_validation and len(val_texts) != len(val_labels):
        raise ValueError("val_texts and val_labels must have the same length")

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    tfidf_extractor = TFIDFFeatureExtractor()

    # The TF-IDF vocabulary is fitted on the training texts only; validation
    # texts are projected onto it so both splits share one feature width.
    train_loader, tfidf_feature_size = _build_feature_loader(
        tokenizer, tfidf_extractor, train_texts, train_labels, batch_size, training=True
    )
    val_loader = None
    if has_validation:
        val_loader, _ = _build_feature_loader(
            tokenizer, tfidf_extractor, val_texts, val_labels, batch_size, training=False
        )

    model = TwoStageModel(bert_feature_size=768, tfidf_feature_size=tfidf_feature_size, num_classes=2)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        # The DataLoader's default collate already delivers batched tensors:
        # ((input_ids, attention_mask, tfidf), labels). No re-stacking needed.
        for (input_ids, attention_mask, tfidf_batch), labels in train_loader:
            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask, tfidf_batch)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(train_loader)}")

        if val_loader is not None:
            val_loss, val_accuracy = _evaluate(model, val_loader, criterion)
            print(f"Epoch {epoch+1}/{num_epochs}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}")

    return model


def _build_feature_loader(tokenizer, tfidf_extractor, texts, labels, batch_size, *, training):
    """Tokenize and vectorize ``texts``; return ``(DataLoader, tfidf_width)``."""
    # Tokenize once and reuse the result for both ids and attention mask.
    encoded = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
    if training:
        tfidf_matrix = tfidf_extractor.fit_transform(texts)
    else:
        tfidf_matrix = tfidf_extractor.transform(texts)
    # The vectorizer output is float64; the model's weights are float32.
    tfidf_tensor = torch.tensor(tfidf_matrix.toarray(), dtype=torch.float32)
    # CrossEntropyLoss expects integer class indices.
    targets = torch.as_tensor(labels, dtype=torch.long)

    samples = list(zip(encoded['input_ids'], encoded['attention_mask'], tfidf_tensor))
    loader = DataLoader(TextDataset(samples, targets), batch_size=batch_size, shuffle=training)
    return loader, tfidf_tensor.shape[1]


def _evaluate(model, loader, criterion):
    """Return ``(mean loss, accuracy)`` of ``model`` over ``loader`` without gradients."""
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for (input_ids, attention_mask, tfidf_batch), labels in loader:
            logits = model(input_ids, attention_mask, tfidf_batch)
            batch_len = labels.size(0)
            # Weight by batch size so a short final batch does not skew the mean.
            loss_sum += criterion(logits, labels).item() * batch_len
            correct += (logits.argmax(dim=1) == labels).sum().item()
            seen += batch_len
    return loss_sum / seen, correct / seen

# Dummy data example
# Each split pairs two toy sentences with labels: 0 for the sample described
# as human written, 1 for the sample described as AI generated.
train_texts, train_labels = (
    ["This is a human written text.", "This is an AI generated text."],
    [0, 1],
)
val_texts, val_labels = (
    ["Another human written text.", "Another AI generated text."],
    [0, 1],
)

# Train the model
trained_model = train_model(
    train_texts,
    train_labels,
    val_texts,
    val_labels,
)
