In [None]:
def set_seed(seed=42):
    """Seed every random number generator used in this file.

    The same ``seed`` is applied, in order, to Python's ``random`` module,
    NumPy's global generator, PyTorch's default generator, and the
    generators of all CUDA devices.

    Args:
        seed: Integer seed value. Defaults to 42.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

In [None]:
def compute_standard_ppl_with_sliding_window(model, tokenizer, dataset, device="cuda", stride=512):
    """Compute corpus-level perplexity with a strided sliding window.

    The texts in ``dataset["text"]`` are joined with blank lines and
    tokenized as one long sequence. A window of at most the model's context
    length is moved over that sequence in steps of ``stride``; in every
    window only the tokens that were not scored by a previous window are
    used as labels, so each scored token sees as much left context as the
    window allows. The negative log-likelihoods of all windows are summed
    and divided by the number of tokens that were actually scored before
    exponentiating (global aggregation).

    Args:
        model: Causal LM that accepts ``model(input_ids, labels=...)`` and
            returns an object with a mean ``loss``. Its config must expose
            ``n_positions`` or ``max_position_embeddings``.
        tokenizer: Tokenizer callable returning ``input_ids`` as a tensor
            when called with ``return_tensors="pt"``.
        dataset: Object whose ``["text"]`` item is an iterable of strings.
        device: Device the window tensors are moved to.
        stride: Step between consecutive windows. Values larger than the
            model's context length are clamped to it so no token is skipped.

    Returns:
        The perplexity as a Python float.

    Raises:
        ValueError: If ``stride`` is not positive, or if the corpus contains
            no token that can be scored (fewer than two tokens).
    """
    model.eval()

    # Tokenize the whole corpus as a single sequence.
    encodings = tokenizer("\n\n".join(dataset["text"]), return_tensors="pt")
    all_ids = encodings.input_ids
    seq_len = all_ids.size(1)

    # GPT-2 style configs call the context length `n_positions`; other
    # configs expose it as `max_position_embeddings`.
    max_length = getattr(model.config, "n_positions", None)
    if max_length is None:
        max_length = model.config.max_position_embeddings

    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")
    # A stride larger than the window would leave gaps of unscored tokens.
    stride = min(stride, max_length)

    total_nll = 0.0
    total_scored = 0
    prev_end_loc = 0

    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + max_length, seq_len)
        # Number of tokens in this window not covered by a previous window.
        trg_len = end_loc - prev_end_loc

        input_ids = all_ids[:, begin_loc:end_loc].to(device)
        target_ids = input_ids.clone()
        # Tokens already scored earlier only serve as context here.
        target_ids[:, :-trg_len] = -100

        # The model shifts labels by one position internally, so the label
        # at index 0 never contributes to the loss. `outputs.loss` is the
        # mean over the remaining non-ignored labels; count exactly those to
        # turn the mean back into a sum.
        n_scored = int((target_ids[:, 1:] != -100).sum().item())
        if n_scored > 0:
            with torch.no_grad():
                outputs = model(input_ids, labels=target_ids)
            total_nll += outputs.loss.item() * n_scored
            total_scored += n_scored

        prev_end_loc = end_loc
        if end_loc == seq_len:
            break

    if total_scored == 0:
        raise ValueError(
            "Cannot compute perplexity: the tokenized corpus has no token with left context to score"
        )

    # Global aggregation: total NLL divided by the number of scored tokens.
    return torch.exp(torch.tensor(total_nll / total_scored)).item()

In [None]:
def group_texts(examples):
    """Concatenate a batch of tokenized examples and cut it into fixed-size chunks (packing).

    Every column of ``examples`` is flattened into one long list, which is
    then split into consecutive chunks of ``MAX_LENGTH`` items (a
    module-level constant). When the flattened length is at least
    ``MAX_LENGTH`` the trailing remainder is dropped; a shorter batch is
    kept as a single, shorter chunk.

    Args:
        examples: Mapping from column name to a list of lists. It must
            contain an ``"input_ids"`` column, and the chunk count is derived
            from the length of the first column.

    Returns:
        A dict with the same columns, chunked, plus a ``"labels"`` column
        that is a shallow copy of the chunked ``"input_ids"`` list.
    """
    from itertools import chain

    # chain.from_iterable flattens in linear time; sum(lists, []) re-copies
    # the accumulated list on every addition and is quadratic.
    concatenated_examples = {
        k: list(chain.from_iterable(examples[k])) for k in examples.keys()
    }
    total_length = len(concatenated_examples[list(examples.keys())[0]])

    # Drop the remainder so every chunk has exactly MAX_LENGTH items.
    if total_length >= MAX_LENGTH:
        total_length = (total_length // MAX_LENGTH) * MAX_LENGTH

    result = {
        k: [t[i : i + MAX_LENGTH] for i in range(0, total_length, MAX_LENGTH)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

def train_model(model, train_dataset, eval_dataset, tokenizer, output_dir):
    """Fine-tune ``model`` as a causal language model and save the result.

    Both datasets are tokenized (with EOS appended to every text), stripped
    of all their original columns, and packed into fixed-size chunks with
    ``group_texts``. Training evaluates and checkpoints once per epoch and
    reloads the checkpoint with the lowest eval loss at the end. The model
    and tokenizer are then written to ``output_dir``.

    Hyper-parameters (``EPOCHS``, ``TRAIN_BATCH_SIZE``,
    ``GRADIENT_ACCUMULATION``, ``LEARNING_RATE``, ``WEIGHT_DECAY``,
    ``WARMUP_RATIO``) and the precision flags (``use_bf16``, ``use_fp16``)
    are read from module scope.

    Args:
        model: Model handed to ``Trainer``.
        train_dataset: Dataset with a ``"text"`` column, used for training.
        eval_dataset: Dataset with a ``"text"`` column, used for evaluation.
        tokenizer: Tokenizer used for preprocessing; also saved alongside
            the model.
        output_dir: Directory for checkpoints and the final model.

    Returns:
        The ``model`` object that was passed in.
    """

    def append_eos_and_tokenize(batch):
        # Terminate every text with EOS before tokenizing the whole batch.
        texts_with_eos = [text + tokenizer.eos_token for text in batch["text"]]
        return tokenizer(texts_with_eos)

    def to_packed_lm_dataset(raw_dataset):
        # Remember every pre-existing column so it can be removed during
        # tokenization; leftover columns would break group_texts.
        original_columns = raw_dataset.column_names

        # Step 1: tokenize and drop the original columns.
        tokenized = raw_dataset.map(
            append_eos_and_tokenize,
            batched=True,
            num_proc=8,
            remove_columns=original_columns,
        )

        # Step 2: pack the token streams into fixed-size chunks.
        return tokenized.map(group_texts, batched=True, num_proc=8)

    packed_train = to_packed_lm_dataset(train_dataset)
    packed_eval = to_packed_lm_dataset(eval_dataset)
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    argument_values = {
        # Output location
        "output_dir": output_dir,
        "overwrite_output_dir": True,
        # Schedule and optimisation
        "num_train_epochs": EPOCHS,
        "per_device_train_batch_size": TRAIN_BATCH_SIZE,
        "gradient_accumulation_steps": GRADIENT_ACCUMULATION,
        "learning_rate": LEARNING_RATE,
        "weight_decay": WEIGHT_DECAY,
        "lr_scheduler_type": "cosine",
        "warmup_ratio": WARMUP_RATIO,
        "optim": "adamw_torch_fused",
        # Precision and data loading
        "bf16": use_bf16,
        "fp16": use_fp16,
        "dataloader_num_workers": 8,
        # Per-epoch evaluation / checkpointing; keep only the best checkpoint
        "eval_strategy": "epoch",
        "save_strategy": "epoch",
        "load_best_model_at_end": True,
        "metric_for_best_model": "loss",
        "greater_is_better": False,
        "save_total_limit": 1,
        # Logging
        "report_to": "none",
        "logging_steps": 20,
    }
    training_args = TrainingArguments(**argument_values)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=packed_train,
        eval_dataset=packed_eval,
        data_collator=collator,
    )
    trainer.train()

    # Persist the final weights together with the tokenizer.
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    return model

In [None]:
def generate_synthetic_data(model, tokenizer, prompt_dataset, num_samples=None):
    """Generate "mirror" continuations of the texts in ``prompt_dataset``.

    Blank texts are discarded. Each remaining text is truncated to
    ``prompt_len`` tokens and used as a prompt; the model then samples a
    continuation whose total length is capped at the longest original text
    in the batch plus 10 tokens, but never more than ``max_len``.

    ``batch_size``, ``max_len``, ``prompt_len`` and ``device`` are read from
    module scope.

    Side effects: puts ``model`` in eval mode and assigns the EOS token as
    the pad token when the tokenizer has none. The tokenizer's padding side
    is switched to ``"left"`` for generation and restored to its previous
    value afterwards, even if generation raises.

    Args:
        model: Model exposing ``generate``.
        tokenizer: Tokenizer matching ``model``.
        prompt_dataset: Object whose ``["text"]`` item is a sequence of
            strings.
        num_samples: If given, only the first ``num_samples`` texts are used.

    Returns:
        A list of decoded generated strings with blank generations removed.
    """
    model.eval()

    # --- 1. Tokenizer setup ---
    # Batched generation needs left padding so that every prompt ends at the
    # last position of its row.
    original_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    try:
        # --- 2. Collect and clean the prompts ---
        if num_samples is None:
            raw_texts = prompt_dataset["text"]
        else:
            raw_texts = prompt_dataset["text"][:num_samples]

        # Keep only non-blank texts; the loop below assumes clean input.
        clean_texts = [t for t in raw_texts if len(t.strip()) > 0]
        print(f">>> [Data Clean] Filtered {len(raw_texts)} -> {len(clean_texts)} samples")

        synthetic_texts = []

        print(f">>> [Gen] Generating samples with Batch Size {batch_size}...")

        # --- 3. Batched generation ---
        for i in tqdm(range(0, len(clean_texts), batch_size), desc="Mirror Gen"):
            batch_prompts = clean_texts[i : i + batch_size]

            # Token length of each full original text; tokenized without
            # tensors, so nothing is placed on the device.
            batch_lens = [len(t) for t in tokenizer(batch_prompts, add_special_tokens=False)["input_ids"]]

            # Target length: longest original text plus a small margin,
            # capped by the hard limit.
            current_max_target = min(max(batch_lens) + 10, max_len)

            # Model input: each text truncated to prompt_len tokens.
            inputs = tokenizer(
                batch_prompts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=prompt_len
            ).to(device)

            with torch.no_grad():
                # Run generation under bfloat16 autocast.
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    outputs = model.generate(
                        input_ids=inputs.input_ids,
                        attention_mask=inputs.attention_mask,
                        max_length=current_max_target,
                        do_sample=True,
                        top_k=50,
                        top_p=0.95,
                        pad_token_id=tokenizer.eos_token_id,
                        use_cache=True
                    )

            gen_texts_batch = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            synthetic_texts.extend(gen_texts_batch)

            # Drop the tensor references promptly; deliberately no
            # torch.cuda.empty_cache() call here.
            del inputs, outputs
    finally:
        # --- 4. Restore the caller's tokenizer configuration ---
        tokenizer.padding_side = original_padding_side

    # Remove generations that decoded to blank text.
    final_data = [t for t in synthetic_texts if len(t.strip()) > 0]
    print(f">>> [Post Clean] Final count: {len(final_data)} (Removed {len(synthetic_texts)-len(final_data)} empty/failed)")

    # Return the cleaned list, i.e. exactly what the count above reports.
    return final_data

In [None]:
class MetricsEvaluator:
    """Compare synthetic text with real text using a fixed GPT-2 Large model.

    The same pretrained ``gpt2-large`` model serves as feature extractor and
    as judge for every evaluation. Statistics of the real data are computed
    on the first call to ``evaluate_dataset`` and cached on the instance, so
    later calls reuse them regardless of the ``real_texts`` passed in.
    """

    def __init__(self, device="cuda"):
        """Load the ``gpt2-large`` tokenizer and model onto ``device``.

        Args:
            device: Device the model and all input batches are placed on.
        """
        self.device = device
        # GPT-2 Large is the fixed reference model for all metrics.
        print(">>> [Metrics] Loading Oracle Model (gpt2-large) for Evaluation...")
        self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2-large")
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = GPT2LMHeadModel.from_pretrained("gpt2-large").to(device)
        self.model.eval()

        # Real-data statistics, filled lazily by evaluate_dataset.
        self.real_embeddings_cache = None
        self.real_mu = None
        self.real_sigma = None
        self.real_entropy = None

    def get_embeddings_and_entropy(self, texts, batch_size=32):
        """Compute one embedding per text and the mean predictive entropy.

        Blank texts are dropped first. Each text is truncated to 512 tokens.
        Its embedding is the attention-mask-weighted mean of the last hidden
        layer. Its entropy is the Shannon entropy of the model's next-token
        distribution, averaged over the non-padding positions.

        Args:
            texts: Iterable of strings.
            batch_size: Number of texts per forward pass.

        Returns:
            A tuple ``(embeddings, mean_entropy)``: a 2-D NumPy array with
            one row per non-blank text, and the mean of the per-text
            entropies.

        Raises:
            ValueError: If ``texts`` contains no non-blank string.
        """
        all_embeddings = []
        all_entropies = []

        # Drop blank texts.
        texts = [t for t in texts if len(t.strip()) > 0]
        if not texts:
            raise ValueError("get_embeddings_and_entropy needs at least one non-empty text")

        for i in tqdm(range(0, len(texts), batch_size), desc="Computing Metrics", leave=False):
            batch_texts = texts[i : i + batch_size]
            inputs = self.tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True, max_length=512).to(self.device)

            with torch.no_grad():
                # Pass the attention mask so padding positions are ignored
                # by the model whichever side the tokenizer pads on.
                outputs = self.model(
                    inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    output_hidden_states=True,
                )

                # --- 1. Embedding: masked mean of the last hidden layer ---
                # hidden_states: [batch, seq_len, hidden_dim]
                # attention_mask: [batch, seq_len]
                hidden_states = outputs.hidden_states[-1]
                mask = inputs.attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
                # Sum over real tokens divided by their count; the clamp
                # guards against division by zero.
                sum_embeddings = torch.sum(hidden_states * mask, 1)
                sum_mask = torch.clamp(mask.sum(1), min=1e-9)
                mean_embeddings = sum_embeddings / sum_mask
                all_embeddings.append(mean_embeddings.cpu().numpy())

                # --- 2. Entropy of the next-token distribution ---
                # logits: [batch, seq_len, vocab]
                logits = outputs.logits
                probs = torch.softmax(logits, dim=-1)
                # Shannon entropy -sum(p * log(p)); the epsilon avoids log(0).
                token_entropy = -torch.sum(probs * torch.log(probs + 1e-9), dim=-1) # [batch, seq_len]

                # Average over non-padding positions of each sequence.
                active_elements = inputs.attention_mask.sum(1)
                seq_entropy = (token_entropy * inputs.attention_mask).sum(1) / active_elements
                all_entropies.extend(seq_entropy.cpu().tolist())

        return np.concatenate(all_embeddings, axis=0), np.mean(all_entropies)

    def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6):
        """Compute the Frechet distance between two Gaussians (FID / FBD).

        The value is ``|mu1 - mu2|^2 + Tr(sigma1) + Tr(sigma2)
        - 2 * Tr(sqrt(sigma1 @ sigma2))``.

        Args:
            mu1: Mean vector of the first distribution.
            sigma1: Covariance matrix of the first distribution.
            mu2: Mean vector of the second distribution.
            sigma2: Covariance matrix of the second distribution.
            eps: Value added to both covariance diagonals when the matrix
                square root is not finite.

        Returns:
            The Frechet distance as a scalar.

        Raises:
            ValueError: If the matrix square root has a diagonal imaginary
                part that is not close to zero.
        """
        mu1 = np.atleast_1d(mu1)
        mu2 = np.atleast_1d(mu2)
        sigma1 = np.atleast_2d(sigma1)
        sigma2 = np.atleast_2d(sigma2)

        diff = mu1 - mu2

        # The product might be almost singular. sqrtm is called without
        # `disp`: the error estimate it would add was never used here.
        covmean = linalg.sqrtm(sigma1.dot(sigma2))
        if not np.isfinite(covmean).all():
            print("WARNING: fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps)
            offset = np.eye(sigma1.shape[0]) * eps
            covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

        # Numerical error might give a slight imaginary component.
        if np.iscomplexobj(covmean):
            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                m = np.max(np.abs(covmean.imag))
                raise ValueError("Imaginary component {}".format(m))
            covmean = covmean.real

        tr_covmean = np.trace(covmean)
        return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)

    def calculate_ot_distance(self, emb1, emb2):
        """Approximate an optimal-transport distance between two embedding sets.

        For speed this does not solve a full OT problem. It computes the 1-D
        Wasserstein-1 distance between the two sets separately for every
        embedding coordinate and returns the mean of those distances. Only
        coordinate axes are used (no random projections), so this compares
        marginals and is a cheap proxy rather than a true sliced Wasserstein
        distance.

        At most 1000 rows are drawn without replacement from each set, using
        NumPy's global random state.

        Args:
            emb1: 2-D array of embeddings, one row per sample.
            emb2: 2-D array of embeddings with the same number of columns.

        Returns:
            The mean per-coordinate Wasserstein distance.
        """
        # Subsample equally from both sets to bound the cost.
        n_sample = min(1000, len(emb1), len(emb2))
        idx1 = np.random.choice(len(emb1), n_sample, replace=False)
        idx2 = np.random.choice(len(emb2), n_sample, replace=False)

        # One 1-D distance per embedding coordinate.
        dists = [
            wasserstein_distance(emb1[idx1, i], emb2[idx2, i])
            for i in range(emb1.shape[1])
        ]

        return np.mean(dists)

    def evaluate_dataset(self, real_texts, synthetic_texts):
        """Compute all metrics for ``synthetic_texts`` against the real data.

        The real-data embeddings, mean, covariance and entropy are computed
        from ``real_texts`` only while the cache is empty; afterwards the
        cached values are reused and ``real_texts`` is ignored.

        Args:
            real_texts: Iterable of real reference strings.
            synthetic_texts: Iterable of generated strings to evaluate.

        Returns:
            A dict with float values under the keys ``"fid"``,
            ``"ot_distance"``, ``"syn_entropy"`` and ``"real_entropy"``.
        """
        # 1. Real-data statistics (computed once, then cached).
        if self.real_embeddings_cache is None:
            print("    | [Metrics] Computing Real Data Embeddings...")
            self.real_embeddings_cache, self.real_entropy = self.get_embeddings_and_entropy(real_texts)
            self.real_mu = np.mean(self.real_embeddings_cache, axis=0)
            self.real_sigma = np.cov(self.real_embeddings_cache, rowvar=False)

        # 2. Synthetic-data statistics.
        print("    | [Metrics] Computing Synthetic Data Embeddings...")
        syn_embeddings, syn_entropy = self.get_embeddings_and_entropy(synthetic_texts)
        syn_mu = np.mean(syn_embeddings, axis=0)
        syn_sigma = np.cov(syn_embeddings, rowvar=False)

        # 3. Frechet distance between the two Gaussian fits.
        fid = self.calculate_frechet_distance(self.real_mu, self.real_sigma, syn_mu, syn_sigma)

        # 4. Approximate OT distance on the raw embeddings.
        ot_dist = self.calculate_ot_distance(self.real_embeddings_cache, syn_embeddings)

        return {
            "fid": float(fid),
            "ot_distance": float(ot_dist),
            "syn_entropy": float(syn_entropy),
            "real_entropy": float(self.real_entropy) # reported for reference
        }