In [None]:
# # @title install
# !pip install bitsandbytes
# !pip install openai
# !pip install tqdm
# !pip install bitsandbytes accelerate
# !pip install geomloss
# !!rm -rf ~/.cache/huggingface/hub

In [None]:
from google.colab import drive
# Mount Google Drive; later cells read and write files under /content/drive/MyDrive.
drive.mount('/content/drive')

In [None]:
# !nvidia-smi
# import torch
# torch.cuda.is_available()
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, BitsAndBytesConfig
import bitsandbytes as bnb
import math
from tqdm import tqdm
from peft import prepare_model_for_kbit_training

# ---- 1. Configuration ----
# NOTE(review): the generation constants below have the same values as the
# keyword defaults of train_one_epoch / evaluate_on_val_set.
generate_id = "baichuan-inc/Baichuan2-7B-Chat"  # Hugging Face id of the generator model
eval_id = "01-ai/Yi-6B"  # Hugging Face id of the frozen evaluator model
MAX_NEW_TOKENS = 80  # soft tokens generated per poem
CHUNK_TOKENS = 32  # generation steps per chunk
TAU = 0.5  # Gumbel-softmax temperature
KV_KEEP_LAST = 256  # number of most recent KV-cache positions kept when trimming

In [None]:
# ---- 2. Model setup ----

# 4-bit BitsAndBytes quantisation settings, used to save GPU memory.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # load weights in 4-bit
    bnb_4bit_quant_type="nf4",              # NF4 quantisation type
    bnb_4bit_use_double_quant=True,         # double quantisation to shrink the model further
    bnb_4bit_compute_dtype=torch.bfloat16,  # run computations in bfloat16
)

# Tokenizer of the generator model.
generate_tokenizer = AutoTokenizer.from_pretrained(generate_id, trust_remote_code=True)

# Generator model, quantised to 4-bit and placed automatically on the available devices.
generate_model = AutoModelForCausalLM.from_pretrained(
    generate_id,
    quantization_config=bnb_config,
    device_map="auto",       # let accelerate place the model on CPU/GPU
    trust_remote_code=True,  # the model repository ships custom modelling code
    force_download=True,     # always re-download the weights (slow on every run)
)

# Prepare the quantised model with gradient checkpointing explicitly disabled,
# so it does not conflict with use_cache=True (KV cache) during generation.
# NOTE(review): per the PEFT documentation this helper also freezes every
# base-model parameter; unless adapters are added elsewhere, only the bridge
# layer ends up trainable -- confirm this is intended.
generate_model = prepare_model_for_kbit_training(
    generate_model, use_gradient_checkpointing=False
)

# Resolve the generator's device BEFORE it is used to place the evaluator.
# (Previously `device` was only assigned at the end of this cell, so
# `.to(device)` below raised NameError on a fresh kernel.)
device = next(generate_model.parameters()).device

# Tokenizer of the evaluator model.
eval_tokenizer = AutoTokenizer.from_pretrained(eval_id, trust_remote_code=True)

# Evaluator model (no LM head), loaded in bfloat16, moved next to the generator
# and switched to eval mode.
eval_model = AutoModel.from_pretrained(
    eval_id, torch_dtype=torch.bfloat16, trust_remote_code=True
).to(device).eval()

# Freeze the evaluator so training never updates it.
for p in eval_model.parameters():
    p.requires_grad = False


In [None]:
# Debug check: print the Python type of the first decoder layer's
# self_attn.W_pack weight (presumably to confirm the 4-bit quantisation took effect).
print(type(generate_model.model.layers[0].self_attn.W_pack.weight))


In [None]:
from tqdm import tqdm, trange
import random
import torch.nn.functional as F
import math
import numpy as np
from torch.optim.lr_scheduler import ReduceLROnPlateau
# ---- 3. Bridge layer setup ----

# Input (token) embedding matrix of the generator.
E_gen = generate_model.get_input_embeddings().weight

# Embedding width on the generator side and on the evaluator side.
d_g = E_gen.size(1)
d_e = eval_model.get_input_embeddings().embedding_dim

# dtype of the evaluator's embedding weights; the bridge is cast to match it.
eval_dtype = eval_model.get_input_embeddings().weight.dtype

# Reuse a bridge that an earlier run of this cell attached to the generator;
# otherwise create a bias-free linear map d_g -> d_e and attach it.
# NOTE(review): assigning an nn.Module as an attribute registers it as a
# submodule, so the bridge presumably also shows up in
# generate_model.parameters() / state_dict() -- confirm.
bridge_E = getattr(generate_model, "_embed_bridge", None)
if bridge_E is None:
    bridge_E = nn.Linear(d_g, d_e, bias=False)
    generate_model._embed_bridge = bridge_E

# Place the bridge on the generator's device, in the evaluator's dtype.
bridge_E = bridge_E.to(device=device, dtype=eval_dtype)

# The bridge is trained, so its parameters need gradients.
bridge_E.requires_grad_(True)






In [None]:
# @title kv-cache
def trim_past_key_values(past, keep_last: int):
    """Cap the KV cache length so it cannot grow without bound.

    Args:
        past: iterable of per-layer ``(key, value)`` tensor pairs, or ``None``.
        keep_last: number of trailing positions to keep along the
            second-to-last dimension; ``None`` disables trimming.

    Returns:
        ``past`` unchanged when either argument is ``None``; otherwise a tuple
        of ``(key, value)`` pairs, each sliced to the last ``keep_last``
        positions and made contiguous.
    """
    if past is None or keep_last is None:
        return past
    return tuple(
        (
            keys[..., -keep_last:, :].contiguous(),
            values[..., -keep_last:, :].contiguous(),
        )
        for keys, values in past
    )

In [None]:
# @title get_rep_loss
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

def get_rep_loss(
    poem_vecs: dict[int, torch.Tensor],
    good_poems_vecs: dict[int, torch.Tensor],
    top_p: float = 0.1,
    layer: int = 31,
    plot: bool = True
) -> torch.Tensor:
    """Repetition loss: mean of the highest cosine similarities to the good poems.

    For every generated poem the cosine similarity to every good poem is
    computed, the top ``top_p`` fraction (at least one value) is kept per
    generated poem, and all kept values are averaged. Higher means the
    generated poems are closer to (more repetitive of) the reference poems.

    Args:
        poem_vecs: per-layer vectors of the generated poems, ``[B, D]`` each.
        good_poems_vecs: per-layer vectors of the good poems, ``[N, D]`` each.
        top_p: fraction (0~1) of the N similarities kept per generated poem.
        layer: key of the layer to use in both dicts.
        plot: when True, also show a histogram of all similarities.

    Returns:
        Scalar tensor with the mean top-p similarity.

    Raises:
        ValueError: if ``layer`` is missing from either dict.
    """
    for vecs in (poem_vecs, good_poems_vecs):
        if layer not in vecs:
            raise ValueError(f"Layer {layer} not found in input vectors.")

    generated = F.normalize(poem_vecs[layer], dim=-1)        # [B, D], unit length
    reference = F.normalize(good_poems_vecs[layer], dim=-1)  # [N, D], unit length

    similarities = generated @ reference.T                   # [B, N] cosine similarities
    _, n_reference = similarities.shape

    # Translate the fraction into a count, never fewer than one value per row.
    keep = max(1, int(n_reference * top_p))
    loss = torch.topk(similarities, k=keep, dim=1).values.mean()

    if plot:
        # Histogram of every pairwise similarity, detached from the graph.
        flat_sims = similarities.detach().flatten().cpu().numpy()
        plt.figure(figsize=(8, 5))
        plt.hist(flat_sims, bins=50, color='skyblue', edgecolor='black')
        plt.title(f"Cosine Similarity Histogram (Layer {layer}, Top-{int(top_p*100)}%)")
        plt.xlabel("Cosine Similarity")
        plt.ylabel("Frequency")
        plt.grid(True)
        plt.tight_layout()
        plt.show()

    return loss


In [None]:
# @title get_quality_loss
import torch
import torch.nn.functional as F
from geomloss import SamplesLoss
import matplotlib.pyplot as plt

def get_quality_score(
    poem_vecs: dict[int, torch.Tensor],
    good_poems_vecs: dict[int, torch.Tensor],
    blur: float = 0.05,
    alpha: float = 0.5,
    plot: bool = True
) -> torch.Tensor:
    """Quality loss: weighted Sinkhorn distance between generated and good poems.

    The vectors of each layer are L2-normalised and compared with a Sinkhorn
    divergence. The returned loss is
    ``alpha * loss[31] + (1 - alpha) * loss[25]``.

    Only layers 25 and 31 enter the loss, so the other layers (26-30, 32) are
    evaluated only when ``plot`` is True, where they are needed for the bar
    chart. This avoids six unused Sinkhorn solves per call when plotting is off.

    Args:
        poem_vecs: per-layer vectors of the generated poems.
        good_poems_vecs: per-layer vectors of the good poems.
        blur: ``blur`` parameter passed to ``SamplesLoss``.
        alpha: weight of layer 31; layer 25 gets ``1 - alpha``.
        plot: when True, show the Sinkhorn distance of layers 25..32.

    Returns:
        Scalar tensor with the weighted quality loss.

    Raises:
        AssertionError: if layer 25 or 31 is missing from ``poem_vecs``.
        ValueError: if a required layer is missing from either input
            (layers 25 and 31 always; layers 25..32 when ``plot`` is True).
    """
    assert 25 in poem_vecs and 31 in poem_vecs, "缺少第 25 或 31 层的向量"

    loss_layers = (25, 31)                 # layers that contribute to the loss
    plot_layers = tuple(range(25, 33))     # layers shown in the bar chart
    needed_layers = plot_layers if plot else loss_layers

    loss_fn = SamplesLoss("sinkhorn", p=2, blur=blur)
    layer_losses = {}

    for layer in needed_layers:
        if layer not in poem_vecs or layer not in good_poems_vecs:
            raise ValueError(f"Layer {layer} missing in inputs.")

        vec_gen = F.normalize(poem_vecs[layer], dim=-1)
        vec_good = F.normalize(good_poems_vecs[layer], dim=-1)
        layer_losses[layer] = loss_fn(vec_gen, vec_good)

    if plot:
        plt.figure(figsize=(8, 5))
        plt.bar(
            [f"Layer {i}" for i in plot_layers],
            [layer_losses[i].item() for i in plot_layers],
            color="skyblue"
        )
        # Mark the two bars that actually enter the loss (bar positions 0 and 6).
        plt.axvline(x=plot_layers.index(25), color="orange", linestyle="--", label="Used in Loss (25)")
        plt.axvline(x=plot_layers.index(31), color="green", linestyle="--", label="Used in Loss (31)")
        plt.title("Sinkhorn Distance per Layer (Normalized)")
        plt.ylabel("Sinkhorn Distance")
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        plt.show()

    total_loss = alpha * layer_losses[31] + (1 - alpha) * layer_losses[25]
    return total_loss


In [None]:
# @title get_diversity_loss
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

def _mean_pairwise_cosine(vecs: torch.Tensor) -> torch.Tensor:
    """Mean cosine similarity over all distinct row pairs of ``vecs`` ([B, D])."""
    unit = F.normalize(vecs, dim=-1)
    sim_matrix = unit @ unit.T  # [B, B]
    count = sim_matrix.size(0)
    # Strict upper triangle: every unordered pair once, no self-similarity.
    pair_mask = torch.triu(torch.ones(count, count, device=unit.device), diagonal=1).bool()
    pair_sims = sim_matrix[pair_mask]
    if pair_sims.numel() == 0:
        # Fewer than two poems: no pairs, so report a zero loss.
        return torch.tensor(0.0, device=unit.device, requires_grad=True)
    return pair_sims.mean()


def get_diversity_loss(
    poem_vecs: dict[int, torch.Tensor],
    layer: int = 31,
    plot: bool = True
) -> torch.Tensor:
    """Diversity loss: average pairwise cosine similarity of the generated poems.

    The value for ``layer`` is always computed directly from
    ``poem_vecs[layer]``, so any layer present in ``poem_vecs`` can be
    requested (previously only layers 25..32 worked, and all of them had to be
    present even when nothing was plotted). The remaining layers in 25..32 are
    evaluated only for the optional plot.

    Args:
        poem_vecs: per-layer vectors of the generated poems, ``[B, D]`` each.
        layer: layer whose value is returned.
        plot: when True, plot the value of every available layer in 25..32
            (plus ``layer``).

    Returns:
        Scalar tensor: mean cosine similarity between distinct poems at
        ``layer`` (higher = less diverse); ``0.0`` when there are fewer than
        two poems.

    Raises:
        KeyError: if ``layer`` is not a key of ``poem_vecs``.
    """
    target_loss = _mean_pairwise_cosine(poem_vecs[layer])

    if plot:
        layer_losses = {layer: target_loss}
        for lyr in range(25, 33):
            if lyr in poem_vecs and lyr not in layer_losses:
                layer_losses[lyr] = _mean_pairwise_cosine(poem_vecs[lyr])

        layers = sorted(layer_losses)
        values = [layer_losses[l].detach().cpu().item() for l in layers]
        plt.figure(figsize=(8, 5))
        plt.plot(layers, values, marker="o", linestyle="-", color="green")
        plt.title("Diversity (Avg Cosine Similarity) per Layer")
        plt.xlabel("Layer")
        plt.ylabel("Avg Cosine Similarity (Lower = More Diverse)")
        plt.grid(True)
        plt.tight_layout()
        plt.show()

    return target_loss


In [None]:
# @title load good_poem_vecs
# Load the good-poem vectors and move layers 25..32 onto the same device as the models.
# NOTE(review): torch.load unpickles the file -- only load files you trust.
good_poem_vecs = torch.load("/content/drive/MyDrive/good_poems_vecs.pt")
good_poem_vecs = {layer: good_poem_vecs[layer].to(device) for layer in range(25, 33)}

In [None]:
# @title evaluate_on_val_set
def evaluate_on_val_set(
    val_theme_ids, label_to_theme, good_poems_vecs, W,
    generate_model, eval_model, bridge_E,
    GROUP_SIZE=10, MAX_NEW_TOKENS=80, CHUNK_TOKENS=32,
    TAU=0.5, KV_KEEP_LAST=256, E_gen=None, eval_dtype=torch.float32, device="cuda"
):
    """Compute the weighted validation loss for one randomly chosen theme.

    ``GROUP_SIZE`` soft poems are generated without gradients (Gumbel-softmax
    over the generator logits, mixed through ``E_gen``), mapped through
    ``bridge_E`` into the evaluator's embedding space, and encoded by
    ``eval_model``. The mean-pooled hidden states of layers 25..32 are then
    scored with the repetition, quality and diversity losses.

    Uses the module-level ``generate_tokenizer``, ``trim_past_key_values``,
    ``get_rep_loss``, ``get_quality_score`` and ``get_diversity_loss``.

    Args:
        val_theme_ids: theme ids to sample the validation theme from.
        label_to_theme: mapping from theme id to theme name used in the prompt.
        good_poems_vecs: per-layer vectors of the good poems.
        W: loss weights ``[rep, quality, diversity]``.
        generate_model: generator model; its train/eval mode is restored on exit.
        eval_model: evaluator model providing hidden states.
        bridge_E: module mapping generator embeddings to evaluator embeddings.
        GROUP_SIZE: number of poems generated for the evaluation.
        MAX_NEW_TOKENS: soft tokens generated per poem.
        CHUNK_TOKENS: generation steps per chunk.
        TAU: Gumbel-softmax temperature.
        KV_KEEP_LAST: number of most recent KV-cache positions kept per step
            (same trimming as in ``train_one_epoch``).
        E_gen: generator embedding matrix; defaults to the generator's input
            embedding weight when None.
        eval_dtype: dtype the bridge input is cast to.
        device: device the tokenised prompt is moved to.

    Returns:
        float: ``W[0] * rep + W[1] * quality + W[2] * diversity``.
    """
    if E_gen is None:
        E_gen = generate_model.get_input_embeddings().weight

    # Remember the caller's mode so evaluation does not leave the model in eval().
    was_training = generate_model.training
    generate_model.eval()
    val_vecs_tmp = []

    theme_idx = random.choice(val_theme_ids)
    val_prompt = f"你是一位唐代诗人，擅长写唐诗。请原创一首{label_to_theme[theme_idx]}主题的唐诗，每一句五或七个字，共四句，只输出四句正文，不输出题目。"

    try:
        with torch.no_grad():
            for _ in range(GROUP_SIZE):
                # Prefill: run the prompt once to build the KV cache.
                enc = generate_tokenizer(val_prompt, return_tensors="pt").to(device)
                out = generate_model(input_ids=enc["input_ids"], use_cache=True)
                past_key_values = out.past_key_values

                # NOTE(review): the last prompt token is already in the cache from
                # the prefill, so feeding its embedding again duplicates it; kept
                # as-is to match train_one_epoch -- confirm whether intended.
                cur_gen_embeds = generate_model.get_input_embeddings()(enc["input_ids"][:, -1:])
                chunk_eval_soft_embeds = []

                num_chunks = math.ceil(MAX_NEW_TOKENS / CHUNK_TOKENS)
                for ci in range(num_chunks):
                    steps_this_chunk = min(CHUNK_TOKENS, MAX_NEW_TOKENS - ci * CHUNK_TOKENS)
                    for _ in range(steps_this_chunk):
                        out = generate_model(
                            inputs_embeds=cur_gen_embeds,
                            past_key_values=past_key_values,
                            use_cache=True
                        )
                        logits = out.logits[:, -1, :]
                        # Bound the cache exactly like the training loop does.
                        past_key_values = trim_past_key_values(out.past_key_values, keep_last=KV_KEEP_LAST)
                        past_key_values = tuple((k.detach(), v.detach()) for k, v in past_key_values)

                        # Soft token: Gumbel-softmax weights mixed over the embedding table.
                        y_soft = F.gumbel_softmax(logits, tau=TAU, hard=False, dim=-1).unsqueeze(1)
                        next_gen_embed = y_soft.to(E_gen.dtype) @ E_gen
                        cur_gen_embeds = next_gen_embed

                        next_eval_embed = bridge_E(next_gen_embed.to(dtype=eval_dtype))
                        chunk_eval_soft_embeds.append(next_eval_embed)

                eval_inputs = torch.cat(chunk_eval_soft_embeds, dim=1)
                out_eval = eval_model(inputs_embeds=eval_inputs, output_hidden_states=True)

                # One vector per poem and layer: mean over the sequence dimension.
                poem_vec = {layer: out_eval.hidden_states[layer].float().mean(dim=1) for layer in range(25, 33)}
                val_vecs_tmp.append(poem_vec)

            group_poem_vecs = {layer: torch.cat([d[layer] for d in val_vecs_tmp], dim=0) for layer in range(25, 33)}

            rep_loss = get_rep_loss(poem_vecs=group_poem_vecs, good_poems_vecs=good_poems_vecs, top_p=0.1, layer=31, plot=False)
            quality_loss = get_quality_score(poem_vecs=group_poem_vecs, good_poems_vecs=good_poems_vecs, blur=0.05, alpha=0.5, plot=False)
            diversity_loss = get_diversity_loss(poem_vecs=group_poem_vecs, layer=31, plot=False)

            total_val_loss = W[0] * rep_loss + W[1] * quality_loss + W[2] * diversity_loss
            return total_val_loss.item()
    finally:
        generate_model.train(was_training)



In [None]:
# @title train_one_epoch
def train_one_epoch(
    epoch, theme_idx, label_to_theme, good_poem_vecs, val_theme_ids, W,
    generate_model, eval_model, bridge_E, optimizer,optim_params,
    GROUP_SIZE=10, MAX_NEW_TOKENS=80, CHUNK_TOKENS=32, TAU=0.5, KV_KEEP_LAST=256,
    E_gen=None, eval_dtype=torch.float32, device="cuda",
    GRAD_ACCUM_STEPS=4, MAX_GRAD_NORM=1.0
):
    """Run one training epoch on a single theme and evaluate afterwards.

    ``GROUP_SIZE`` soft poems are generated for the theme prompt. After every
    ``GRAD_ACCUM_STEPS`` poems (and after the last poem) the poems collected
    since the previous update form one group: the repetition, quality and
    diversity losses are computed on that group, back-propagated, the
    gradients are clipped and the optimizer takes a step.

    Uses the module-level ``generate_tokenizer``, ``trim_past_key_values``,
    ``get_rep_loss``, ``get_quality_score``, ``get_diversity_loss`` and
    ``evaluate_on_val_set``.

    Args:
        epoch: epoch number, used for logging only.
        theme_idx: id of the theme to train on.
        label_to_theme: mapping from theme id to theme name used in the prompt.
        good_poem_vecs: per-layer vectors of the good poems.
        val_theme_ids: theme ids passed to ``evaluate_on_val_set``.
        W: loss weights ``[rep, quality, diversity]``.
        generate_model: generator model (put into train mode).
        eval_model: evaluator model providing hidden states.
        bridge_E: module mapping generator embeddings to evaluator embeddings.
        optimizer: optimizer that is stepped once per group.
        optim_params: parameters whose gradient norm is clipped.
        GROUP_SIZE: poems generated in this epoch; must be at least 1.
        MAX_NEW_TOKENS: soft tokens generated per poem.
        CHUNK_TOKENS: generation steps per chunk.
        TAU: Gumbel-softmax temperature.
        KV_KEEP_LAST: number of most recent KV-cache positions kept per step.
        E_gen: generator embedding matrix; defaults to the generator's input
            embedding weight when None.
        eval_dtype: dtype the bridge input is cast to.
        device: device the tokenised prompt is moved to.
        GRAD_ACCUM_STEPS: number of poems per group / optimizer step; the
            group loss is divided by this value before ``backward``.
        MAX_GRAD_NORM: max norm for gradient clipping.

    Returns:
        dict with ``epoch``, ``theme_idx``, the losses of the LAST group
        (``rep_loss``, ``quality_loss``, ``diversity_loss``, ``total_loss``),
        ``val_loss`` and the current ``lr``.

    Raises:
        ValueError: if ``GROUP_SIZE`` is smaller than 1 (no loss could be
            computed or reported).
    """
    if GROUP_SIZE < 1:
        raise ValueError("GROUP_SIZE must be at least 1")
    if E_gen is None:
        E_gen = generate_model.get_input_embeddings().weight

    generate_model.train()
    optimizer.zero_grad(set_to_none=True)
    group_poem_vecs_tmp = []
    num_chunks = math.ceil(MAX_NEW_TOKENS / CHUNK_TOKENS)

    train_prompt = f"你是一位唐代诗人，擅长写唐诗。请原创一首{label_to_theme[theme_idx]}主题的唐诗，每一句五或七个字，共四句，只输出四句正文，不输出题目。"

    for gi in range(GROUP_SIZE):
        # Prefill without gradients: only builds the KV cache for the prompt.
        enc = generate_tokenizer(train_prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            init_out = generate_model(input_ids=enc["input_ids"], use_cache=True)
            past_key_values = init_out.past_key_values

        # NOTE(review): the last prompt token is already in the cache from the
        # prefill, so feeding its embedding again duplicates it -- confirm
        # whether this is intended before changing it.
        cur_gen_embeds = generate_model.get_input_embeddings()(enc["input_ids"][:, -1:])
        chunk_eval_soft_embeds = []

        for ci in range(num_chunks):
            steps_this_chunk = min(CHUNK_TOKENS, MAX_NEW_TOKENS - ci * CHUNK_TOKENS)
            for _ in range(steps_this_chunk):
                out = generate_model(
                    inputs_embeds=cur_gen_embeds,
                    past_key_values=past_key_values,
                    use_cache=True
                )
                logits = out.logits[:, -1, :]
                past_key_values = out.past_key_values
                past_key_values = trim_past_key_values(past_key_values, keep_last=KV_KEEP_LAST)
                # The cache is detached, so gradients do not flow through past steps' keys/values.
                past_key_values = tuple((k.detach(), v.detach()) for k, v in past_key_values)

                # Soft token: Gumbel-softmax weights mixed over the embedding table.
                y_soft = F.gumbel_softmax(logits, tau=TAU, hard=False, dim=-1).unsqueeze(1)
                next_gen_embed = y_soft.to(E_gen.dtype) @ E_gen
                cur_gen_embeds = next_gen_embed

                next_eval_embed = bridge_E(next_gen_embed.to(dtype=eval_dtype))
                chunk_eval_soft_embeds.append(next_eval_embed)

        eval_inputs = torch.cat(chunk_eval_soft_embeds, dim=1)
        out_eval = eval_model(inputs_embeds=eval_inputs, output_hidden_states=True)

        # One vector per poem and layer: mean over the sequence dimension.
        poem_vec = {layer: out_eval.hidden_states[layer].float().mean(dim=1) for layer in range(25, 33)}
        group_poem_vecs_tmp.append(poem_vec)

        # Close the current group every GRAD_ACCUM_STEPS poems and at the end.
        # Each group gets its own backward pass AND optimizer step, so no
        # gradients are carried over between groups.
        if (gi + 1) % GRAD_ACCUM_STEPS == 0 or (gi + 1) == GROUP_SIZE:
            group_poem_vecs = {
                layer: torch.cat([poem[layer] for poem in group_poem_vecs_tmp], dim=0)
                for layer in range(25, 33)
            }

            rep_loss = get_rep_loss(poem_vecs=group_poem_vecs, good_poems_vecs=good_poem_vecs, top_p=0.1, layer=31, plot=False)
            quality_loss = get_quality_score(poem_vecs=group_poem_vecs, good_poems_vecs=good_poem_vecs, blur=0.05, alpha=0.5, plot=False)
            diversity_loss = get_diversity_loss(poem_vecs=group_poem_vecs, layer=31, plot=False)
            loss = W[0] * rep_loss + W[1] * quality_loss + W[2] * diversity_loss

            # The group loss is scaled by 1 / GRAD_ACCUM_STEPS before backward.
            (loss / GRAD_ACCUM_STEPS).backward()

            # Gradient clipping + parameter update.
            torch.nn.utils.clip_grad_norm_(optim_params, MAX_GRAD_NORM)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            # Start a fresh group.
            group_poem_vecs_tmp.clear()

    # Validation on a randomly chosen validation theme.
    val_loss = evaluate_on_val_set(
        val_theme_ids, label_to_theme, good_poem_vecs, W,
        generate_model, eval_model, bridge_E,
        GROUP_SIZE=GROUP_SIZE, MAX_NEW_TOKENS=MAX_NEW_TOKENS, CHUNK_TOKENS=CHUNK_TOKENS,
        TAU=TAU, KV_KEEP_LAST=KV_KEEP_LAST,
        E_gen=E_gen, eval_dtype=eval_dtype, device=device
    )

    lr = optimizer.param_groups[0]['lr']
    print(f"\nEpoch {epoch} finished:")
    print(f"rep_loss = {rep_loss.item():.4f}, quality_loss = {quality_loss.item():.4f}, "
          f"diversity_loss = {diversity_loss.item():.4f}, val_loss = {val_loss:.4f}, "
          f"total_loss = {loss.item():.4f}, lr = {lr:.2e}")

    # The reported training losses are those of the last group of this epoch.
    return {
        "epoch": epoch,
        "theme_idx": theme_idx,
        "rep_loss": rep_loss.item(),
        "quality_loss": quality_loss.item(),
        "diversity_loss": diversity_loss.item(),
        "total_loss": loss.item(),
        "val_loss": val_loss,
        "lr": lr,
    }


In [None]:
# Theme names, keyed by their integer theme id (0..6).
label_to_theme = dict(enumerate((
    "爱情",
    "送别",
    "思乡",
    "田园",
    "边塞",
    "咏史怀古",
    "咏物",
)))


# Theme ids used for training and for validation.
# NOTE(review): both lists contain every theme, so validation themes are not
# held out from training.
train_theme_ids = list(label_to_theme)
val_theme_ids = list(label_to_theme)


In [None]:
GROUP_SIZE = 20  # poems generated per epoch
LEARNING_RATE = 1e-4  # optimizer learning rate (1e-5 noted as an alternative)
EPOCH_PER_ROUND = 2  # consecutive epochs trained on the same theme
ROUND_NUM = 70  # number of rounds; one theme is drawn per round
TOTAL_EPOCHS = ROUND_NUM * EPOCH_PER_ROUND  # also used as T_max of the LR scheduler
GRAD_ACCUM_STEPS = 4  # poems per loss group / optimizer step
MAX_GRAD_NORM = 1.0  # gradient-clipping norm
W = [1, 1.5, 1]  # loss weights: rep, quality, diversity



In [None]:
from tqdm import trange
import os
import json
from torch.optim.lr_scheduler import CosineAnnealingLR



# Parameters to optimise: trainable generator parameters + bridge parameters.
# NOTE(review): bridge_E is also attached to generate_model as `_embed_bridge`,
# so its weight presumably appears twice in this list. It is left unchanged
# because de-duplicating changes the param-group size and would make the saved
# optimizer state below fail to load -- fix together with a fresh checkpoint.
optim_params = list(filter(lambda p: p.requires_grad, generate_model.parameters())) + list(bridge_E.parameters())

# 8-bit AdamW (bitsandbytes) with a cosine learning-rate schedule.
optimizer = bnb.optim.AdamW8bit(optim_params, lr=LEARNING_RATE)
scheduler = CosineAnnealingLR(optimizer, T_max=TOTAL_EPOCHS, eta_min=1e-6)

# === Resume from checkpoint ===
# NOTE(review): the scheduler state is not part of the checkpoint, so the
# schedule restarts from this freshly created scheduler -- confirm intended.
checkpoint = torch.load("/content/drive/MyDrive/checkpoints/poetry_model_epoch140.pth", map_location=device)
generate_model.load_state_dict(checkpoint["generate_model"],strict=False)
bridge_E.load_state_dict(checkpoint["bridge_E"])
optimizer.load_state_dict(checkpoint["optimizer"])
global_epoch = checkpoint["epoch"] + 1

# === Load training_logs (keep only entries from before the resumed epoch) ===
if os.path.exists("training_logs.json"):
    with open("training_logs.json", "r", encoding="utf-8") as f:
        training_logs = json.load(f)
    print(f"已加载 {len(training_logs)} 条训练日志，保留至 epoch {global_epoch - 1}")
    training_logs = [log for log in training_logs if log["epoch"] < global_epoch]
else:
    training_logs = []
    print("未发现训练日志，将重新开始记录")

# The Drive log directory must exist before the first epoch writes its log;
# previously it was only created when a checkpoint was saved.
drive_log_path = "/content/drive/MyDrive/training_data/training_logs.json"
os.makedirs(os.path.dirname(drive_log_path), exist_ok=True)

# NOTE(review): the loop always runs ROUND_NUM * EPOCH_PER_ROUND further epochs,
# regardless of how many epochs the checkpoint already covers.
with tqdm(total=TOTAL_EPOCHS, initial=global_epoch - 1, desc="Training Epochs", ncols=100) as pbar:
    for round_idx in range(ROUND_NUM):
        # === Draw one theme at random for this round ===
        theme_idx = random.choice(train_theme_ids)

        for local_epoch in range(EPOCH_PER_ROUND):
            print(f"[Round {round_idx + 1}/{ROUND_NUM} | Theme: {label_to_theme[theme_idx]} "
                  f"| Local Epoch {local_epoch + 1}/{EPOCH_PER_ROUND}] "
                  f"(Global Epoch {global_epoch}/{TOTAL_EPOCHS})")

            log = train_one_epoch(
                global_epoch,
                theme_idx,
                label_to_theme,
                good_poem_vecs,
                val_theme_ids,
                W,
                generate_model,
                eval_model,
                bridge_E,
                optimizer=optimizer,
                optim_params=optim_params,
                GROUP_SIZE=GROUP_SIZE,
                # Pass the configured generation settings explicitly so the
                # configuration cell, not the function defaults, controls them.
                MAX_NEW_TOKENS=MAX_NEW_TOKENS,
                CHUNK_TOKENS=CHUNK_TOKENS,
                TAU=TAU,
                KV_KEEP_LAST=KV_KEEP_LAST,
                E_gen=E_gen,
                eval_dtype=eval_dtype,
                device=device,
                GRAD_ACCUM_STEPS=GRAD_ACCUM_STEPS,
                MAX_GRAD_NORM=MAX_GRAD_NORM
            )

            training_logs.append(log)

            # Save the logs (local + Google Drive).
            with open("training_logs.json", "w", encoding="utf-8") as f:
                json.dump(training_logs, f, ensure_ascii=False, indent=2)

            with open(drive_log_path, "w", encoding="utf-8") as f:
                json.dump(training_logs, f, ensure_ascii=False, indent=2)

            # Save a checkpoint every 20 epochs.
            if global_epoch % 20 == 0:
                local_path = f"./checkpoints/poetry_model_epoch{global_epoch}.pth"
                drive_path = f"/content/drive/MyDrive/training_data/poetry_model_epoch{global_epoch}.pth"
                os.makedirs(os.path.dirname(local_path), exist_ok=True)
                os.makedirs(os.path.dirname(drive_path), exist_ok=True)

                model_state = {
                    "generate_model": generate_model.state_dict(),
                    "bridge_E": bridge_E.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "epoch": global_epoch,
                }

                torch.save(model_state, local_path)
                torch.save(model_state, drive_path)
                print(f"已保存中间模型至本地：{local_path}")
                print(f"已备份模型至 Google Drive：{drive_path}")

            global_epoch += 1
            pbar.update(1)

            scheduler.step()

# === Training finished: save the final model ===
final_local = "./poetry_model_final.pth"
final_drive = "/content/drive/MyDrive/training_data/poetry_model_final.pth"

final_state = {
    "generate_model": generate_model.state_dict(),
    "bridge_E": bridge_E.state_dict(),
    "optimizer": optimizer.state_dict(),
    # global_epoch was already advanced past the last finished epoch; store the
    # last COMPLETED epoch, like the intermediate checkpoints do, so that
    # resuming with `checkpoint["epoch"] + 1` does not skip an epoch number.
    "epoch": global_epoch - 1,
}

torch.save(final_state, final_local)
torch.save(final_state, final_drive)

print(f"最终模型已保存至本地：{final_local}")
print(f"最终模型已备份至 Google Drive：{final_drive}")



In [None]:
import matplotlib.pyplot as plt

def plot_training_logs(logs: list[dict]):
    """Plot the per-epoch training losses, validation loss and learning rate.

    Args:
        logs (list[dict]): one record per epoch with the keys ``epoch``,
            ``rep_loss``, ``quality_loss``, ``diversity_loss`` and
            ``total_loss``; ``val_loss`` and ``lr`` are optional. The
            validation curve / learning-rate axis is drawn only when every
            record provides the corresponding value.

    Returns:
        None. Prints a message and returns early when ``logs`` is empty.
    """
    if not logs:
        print("日志为空，无法绘图。")
        return

    # Pull every series out of the records before any drawing happens.
    x_epochs = [int(record["epoch"]) for record in logs]
    loss_series = [
        ([record["rep_loss"] for record in logs], {"marker": "o", "label": "Rep Loss"}),
        ([record["quality_loss"] for record in logs], {"marker": "s", "label": "Quality Loss"}),
        ([record["diversity_loss"] for record in logs], {"marker": "^", "label": "Diversity Loss"}),
        ([record["total_loss"] for record in logs], {"marker": "*", "linestyle": "--", "label": "Total Loss"}),
    ]
    val_series = [record.get("val_loss") for record in logs]
    lr_series = [record.get("lr") for record in logs]

    fig, loss_axis = plt.subplots(figsize=(10, 6))

    # --- Primary axis: loss curves ---
    for y_values, line_style in loss_series:
        loss_axis.plot(x_epochs, y_values, **line_style)

    # Validation loss, only when present in every record.
    if None not in val_series:
        loss_axis.plot(x_epochs, val_series, color="orange", marker="D", linestyle=":", label="Val Loss")

    loss_axis.set_xlabel("Epoch")
    loss_axis.set_ylabel("Loss Value")
    loss_axis.legend(loc="upper left")
    loss_axis.grid(True)

    # --- Secondary axis: learning rate, only when present in every record ---
    if None not in lr_series:
        lr_axis = loss_axis.twinx()
        lr_axis.plot(x_epochs, lr_series, color="purple", linestyle="-", marker="x", label="Learning Rate")
        lr_axis.set_ylabel("Learning Rate", color="purple")
        lr_axis.tick_params(axis='y', labelcolor="purple")
        lr_axis.legend(loc="upper right")

    plt.title("Training & Validation Losses with Learning Rate Per Epoch")
    plt.tight_layout()
    plt.show()


In [None]:
# Plot the loss / learning-rate curves recorded in training_logs by the training loop.
plot_training_logs(training_logs)