In [3]:
import torch
import torch.nn.functional as F
import os
import logging
from transformers import GPT2LMHeadModel, GPT2Tokenizer, BertForSequenceClassification, BertTokenizer
import json
import random

# Logging setup: INFO-level records formatted as "time - level - message".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Device selection: use CUDA when torch reports it as available, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

class GPT2Generator:
    """Simplified GPT-2 text generator backed by a locally stored model.

    ``__init__`` sets two attributes:
        tokenizer: ``GPT2Tokenizer`` loaded from ``model_path``.
        model: ``GPT2LMHeadModel`` loaded from ``model_path`` and moved to the
            module-level ``device``.
    """

    def __init__(self, model_path):
        """Load the GPT-2 model and tokenizer from ``model_path``.

        Args:
            model_path: Directory containing the saved model and tokenizer.
                Both are loaded with ``local_files_only=True``.
        """
        print(f"Loading GPT-2 generator from {model_path}")

        # Load the local model and tokenizer directly.
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_path, local_files_only=True)
        self.model = GPT2LMHeadModel.from_pretrained(model_path, local_files_only=True).to(device)

        # Make sure a padding token exists; reuse EOS when none is configured.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        print("Successfully loaded generator model and tokenizer")

    def generate_text(self, max_length=200, temperature=0.9, top_k=40, top_p=0.9,
                      repetition_penalty=1.2):
        """Sample one text from the model, starting from the start token only.

        Args:
            max_length: Forwarded to ``model.generate`` as ``max_length``.
            temperature: Sampling temperature forwarded to ``model.generate``.
            top_k: Top-k sampling cutoff forwarded to ``model.generate``.
            top_p: Nucleus sampling cutoff forwarded to ``model.generate``.
            repetition_penalty: Repetition penalty forwarded to
                ``model.generate`` (previously fixed at 1.2).

        Returns:
            The decoded text of the single generated sequence, with special
            tokens stripped.

        Raises:
            ValueError: If the tokenizer defines neither a BOS nor an EOS
                token id to start generation from.
        """
        self.model.eval()

        # Start from the model's start token so the model itself chooses the
        # first real word. Fall back to EOS if the tokenizer has no BOS id.
        start_token_id = self.tokenizer.bos_token_id
        if start_token_id is None:
            start_token_id = self.tokenizer.eos_token_id
        if start_token_id is None:
            raise ValueError(
                "Tokenizer defines neither a BOS nor an EOS token id; "
                "cannot build a start prompt for generation"
            )

        input_ids = torch.tensor([[start_token_id]], device=device)
        # The prompt is one real token, so the mask is all ones. It is passed
        # explicitly because pad_token_id is set to EOS below, which leaves
        # the library unable to tell padding from content on its own.
        attention_mask = torch.ones_like(input_ids)

        output = self.model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            num_return_sequences=1,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        # Decode the only returned sequence.
        return self.tokenizer.decode(output[0], skip_special_tokens=True)

class BERTDiscriminator:
    """Simplified BERT author classifier backed by a locally stored model.

    ``__init__`` sets these attributes:
        tokenizer: ``BertTokenizer`` loaded from ``model_path``.
        model: ``BertForSequenceClassification`` loaded from ``model_path``
            and moved to the module-level ``device``.
        author_labels: Content of ``label_names.json`` as parsed by ``json``.
        author_indices: Mapping of author name to its position in
            ``author_labels``; ``None`` entries are skipped.
    """

    def __init__(self, model_path):
        """Load the classifier, its tokenizer and the author label list.

        Args:
            model_path: Directory containing the saved model, the tokenizer
                and a ``label_names.json`` file.
        """
        print(f"Loading BERT discriminator from {model_path}")

        # Load the local model and tokenizer directly.
        self.tokenizer = BertTokenizer.from_pretrained(model_path, local_files_only=True)
        self.model = BertForSequenceClassification.from_pretrained(model_path, local_files_only=True).to(device)

        # Load the author labels. Author names may contain non-ASCII
        # characters, so read the JSON as UTF-8 rather than relying on the
        # platform's default encoding.
        label_path = os.path.join(model_path, "label_names.json")
        with open(label_path, "r", encoding="utf-8") as f:
            self.author_labels = json.load(f)

        # Map each author name to its index in the label list.
        self.author_indices = {
            author: idx
            for idx, author in enumerate(self.author_labels)
            if author is not None
        }
        print(f"Loaded BERT discriminator with authors: {', '.join(self.author_indices.keys())}")

    def evaluate_text(self, text):
        """Return the classifier's probability for each known author.

        Args:
            text: The text to classify. It is padded/truncated to 512 tokens
                by the tokenizer call.

        Returns:
            A dict mapping every author in ``author_indices`` to the softmax
            probability of that author's class, as a Python float.
        """
        self.model.eval()

        # Encode the text.
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=512
        ).to(device)

        # Predict without tracking gradients.
        with torch.no_grad():
            logits = self.model(**inputs).logits
            probs = F.softmax(logits, dim=1)

        # Copy the single row of probabilities to Python floats once instead
        # of calling .item() separately for every author.
        row = probs[0].tolist()
        return {author: row[idx] for author, idx in self.author_indices.items()}

def generate_best_sample(generator, discriminator, author, num_samples=10,
                         max_length=200, temperature=0.9, top_k=40, top_p=0.9):
    """Generate several samples and pick the one with the highest reward.

    The reward of a sample is the probability that ``discriminator`` assigns
    to ``author`` for that sample; an author missing from the discriminator's
    result counts as reward 0.

    Args:
        generator: Object with a ``generate_text(max_length=..., temperature=...,
            top_k=..., top_p=...)`` method returning a string.
        discriminator: Object with an ``evaluate_text(text)`` method returning
            a dict of author name to probability.
        author: Author name whose probability is used as the reward.
        num_samples: Number of samples to generate; must be at least 1.
        max_length: Forwarded to ``generator.generate_text``.
        temperature: Forwarded to ``generator.generate_text``.
        top_k: Forwarded to ``generator.generate_text``.
        top_p: Forwarded to ``generator.generate_text``.

    Returns:
        A tuple ``(best_sample, best_reward, samples, rewards)`` where
        ``samples`` and ``rewards`` list every generated text and its reward
        in generation order. On ties the earliest sample wins.

    Raises:
        ValueError: If ``num_samples`` is less than 1.
    """
    # Fail early with a clear message instead of the opaque
    # "max() arg is an empty sequence" that an empty reward list would cause.
    if num_samples < 1:
        raise ValueError(f"num_samples must be at least 1, got {num_samples}")

    samples = []
    rewards = []

    print(f"Generating {num_samples} samples for {author}...")
    for i in range(num_samples):
        # Generate one text.
        generated_text = generator.generate_text(
            max_length=max_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
        samples.append(generated_text)

        # Reward = discriminator probability for the target author.
        reward = discriminator.evaluate_text(generated_text).get(author, 0)
        rewards.append(reward)

        print(f"Sample {i+1}, Reward: {reward:.4f}")

    # Pick the first sample holding the maximum reward.
    best_idx = rewards.index(max(rewards))
    best_sample = samples[best_idx]
    best_reward = rewards[best_idx]

    print(f"\nBest sample (reward={best_reward:.4f}):\n{best_sample}")

    return best_sample, best_reward, samples, rewards

# Script entry point.
def main(author="Charlotte_Brontë", generator_path=None, discriminator_path=None,
         output_dir="brontë_samples", num_samples=10):
    """Load both models, generate samples for one author and save the results.

    Args:
        author: Author name used as the reward target.
        generator_path: Directory of the saved GPT-2 generator. ``None``
            selects the original hard-coded location.
        discriminator_path: Directory of the saved BERT discriminator.
            ``None`` selects the original hard-coded location.
        output_dir: Directory that receives ``best_sample.txt``; created if
            it does not exist.
        num_samples: Number of samples to generate and rank.
    """
    # Default to the specific model locations this script was written for.
    if generator_path is None:
        generator_path = r"C:\Users\HTL\OneDrive\Desktop\Author-Identifier\stylegan_output\Charlotte_Brontë\generator\Charlotte_Brontë\gan_best_model"
    if discriminator_path is None:
        discriminator_path = r"C:\Users\HTL\OneDrive\Desktop\Author-Identifier\stylegan_output\Charlotte_Brontë\discriminator\best_model"

    print(f"Generator path: {generator_path}")
    print(f"Discriminator path: {discriminator_path}")

    # Load the models.
    generator = GPT2Generator(generator_path)
    discriminator = BERTDiscriminator(discriminator_path)

    # Generate the samples and select the best one.
    best_sample, best_reward, all_samples, all_rewards = generate_best_sample(
        generator, discriminator, author, num_samples=num_samples
    )

    # Save the best sample first, then every sample with its reward.
    os.makedirs(output_dir, exist_ok=True)

    output_path = os.path.join(output_dir, "best_sample.txt")
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(f"Best sample (reward={best_reward:.4f}):\n")
        f.write(best_sample)
        f.write("\n\n" + "="*50 + "\n\n")
        f.write("All samples:\n\n")
        for i, (sample, reward) in enumerate(zip(all_samples, all_rewards)):
            f.write(f"Sample {i+1}, Reward: {reward:.4f}\n")
            f.write(sample)
            f.write("\n\n" + "-"*50 + "\n\n")

    print(f"Results saved to {output_path}")

if __name__ == "__main__":
    main()

Using device: cuda
Generator path: C:\Users\HTL\OneDrive\Desktop\Author-Identifier\stylegan_output\Charlotte_Brontë\generator\Charlotte_Brontë\gan_best_model
Discriminator path: C:\Users\HTL\OneDrive\Desktop\Author-Identifier\stylegan_output\Charlotte_Brontë\discriminator\best_model
Loading GPT-2 generator from C:\Users\HTL\OneDrive\Desktop\Author-Identifier\stylegan_output\Charlotte_Brontë\generator\Charlotte_Brontë\gan_best_model
Successfully loaded generator model and tokenizer
Loading BERT discriminator from C:\Users\HTL\OneDrive\Desktop\Author-Identifier\stylegan_output\Charlotte_Brontë\discriminator\best_model
Loaded BERT discriminator with authors: Charles_Dickens, Charlotte_Brontë, Honoré_de_Balzac, Jane_Austen, Joseph_Conrad, O._Henry, W._Somerset_Maugham, Unknown
Generating 10 samples for Charlotte_Brontë...
Sample 1, Reward: 0.9954
Sample 2, Reward: 0.9963
Sample 3, Reward: 0.9983
Sample 4, Reward: 0.9979
Sample 5, Reward: 0.0007
Sample 6, Reward: 0.9964
Sample 7, Reward: 0.