In [15]:
import torch, torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torchvision.models as tvm
import timm
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
import torchvision.models as models
from torchvision.models import ResNet50_Weights 

device = "cuda" if torch.cuda.is_available() else "cpu"

### check the picture size if at least 100*100 pixels

In [16]:
from PIL import Image
# open the file
img = Image.open("Animals_with_Attributes2/JPEGImages/antelope/antelope_10002.jpg")  
width, height = img.size
print(f"picture width: {width}, and height: {height}")


picture width: 1024, and height: 768


# Part1 loading the dataset

In [17]:
import os, math, random, torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import ImageFolder
from torchvision import transforms as T, models
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True  # 兜底

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ResNet50 冻结为 2048 维特征
try:
    from torchvision.models import ResNet50_Weights
    resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
    preprocess_eval = ResNet50_Weights.IMAGENET1K_V2.transforms()
except Exception:
    resnet = models.resnet50(pretrained=True)
    preprocess_eval = T.Compose([
        T.Resize((224,224)),
        T.ToTensor(),
        T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
    ])
resnet.fc = nn.Identity()
resnet = resnet.to(device).eval()
torch.set_grad_enabled(False)


<torch.autograd.grad_mode.set_grad_enabled at 0x7fdfaad869c0>

In [18]:
# 高斯噪声变换
class AddGaussianNoise(object):
    def __init__(self, mean=0., std=0.05):
        self.mean, self.std = mean, std
    def __call__(self, tensor):
        return tensor + torch.randn_like(tensor)*self.std + self.mean

# 训练增强（旋转±15°、亮度抖动、噪声）+ 与 ResNet 匹配的归一化
train_tf = T.Compose([
    T.Resize((224,224)),
    T.RandomApply([T.RandomRotation(degrees=15)], p=0.8),
    T.ColorJitter(brightness=0.2),          # ±20% 亮度
    T.ToTensor(),
    AddGaussianNoise(std=0.03),             # 适中噪声
    T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
])

# 测试/验证不做增强（与预训练分布一致）
test_tf = preprocess_eval


In [19]:
# 过滤坏图（可选，建议保留）
def is_valid_image(p: str) -> bool:
    try:
        with Image.open(p) as im:
            im.verify()
        return True
    except Exception:
        return False

# 构建“干净”的全集（先用 test_tf；训练时再用同索引替换 transform）
root = "Animals_with_Attributes2/JPEGImages"
whole_ds = ImageFolder(root=root, transform=test_tf, is_valid_file=is_valid_image)
class_to_idx = whole_ds.class_to_idx  # {'antelope':0,...}

# 每个样本 (path, label)
samples = whole_ds.samples             # list[(path,label)]
labels  = [lbl for _, lbl in samples]
num_classes = len(class_to_idx)        # 50

# —— 类级 10 折（按类ID均分，确保 Zero-Shot）——
all_classes = list(range(num_classes))
random.Random(42).shuffle(all_classes)
fold_size = math.ceil(num_classes/10)
fold_classes = [all_classes[i*fold_size:(i+1)*fold_size] for i in range(10)]

def idx_of_classes(lbls, cls_set):
    s = set(cls_set); return [i for i, y in enumerate(lbls) if y in s]

def make_dataloaders_for_fold(k, batch=64, nw=4):
    unseen_cls = fold_classes[k]                         # 本折的 unseen 类
    seen_cls   = [c for c in all_classes if c not in unseen_cls]

    tr_idx = idx_of_classes(labels, seen_cls)
    te_idx = idx_of_classes(labels, unseen_cls)

    # 训练子集用“训练增强” transform
    train_subset = Subset(ImageFolder(root=root, transform=train_tf, is_valid_file=is_valid_image), tr_idx)
    # 测试子集用“评估” transform
    test_subset  = Subset(whole_ds, te_idx)  # whole_ds 已是 test_tf

    train_loader = DataLoader(train_subset, batch_size=batch, shuffle=True,
                              num_workers=nw, pin_memory=True, persistent_workers=(nw>0))
    test_loader  = DataLoader(test_subset,  batch_size=batch, shuffle=False,
                              num_workers=nw, pin_memory=True, persistent_workers=(nw>0))
    return train_loader, test_loader, torch.tensor(seen_cls), torch.tensor(unseen_cls)


In [20]:
import re

def read_title_and_desc_clean(txt_path):
    """读取 AwA2 licenses txt 文件中的 TITLE 和 DESCRIPTION 字段（去掉符号框线）"""
    title, desc = "", ""
    with open(txt_path, "r", encoding="utf-8", errors="ignore") as f:
        lines = [line.strip() for line in f if line.strip()]

    clean = lambda s: re.sub(r'[\+\-\|\_]+', '', s).strip()  # 删除 + - | _
    
    for i, line in enumerate(lines):
        if "TITLE" in line.upper() and i + 1 < len(lines):
            title = clean(lines[i + 1])
        if "DESCRIPTION" in line.upper():
            desc_lines = []
            for j in range(i + 1, len(lines)):
                if any(k in lines[j].upper() for k in ["TITLE", "INFO", "TAGS", "PHOTOGRAPHER", "LICENSE"]):
                    break
                desc_lines.append(clean(lines[j]))
            desc = " ".join(desc_lines).strip()
            break
    return title, desc

In [21]:
#check
txt_path = "Animals_with_Attributes2/licenses/antelope/antelope_10021.txt"
title, desc = read_title_and_desc_clean(txt_path)
print("TITLE:", title)
print("DESCRIPTION:", desc[:200], "...")

TITLE: 
DESCRIPTION: You are free to use this photo  (including commercial use) under attribution to the author. If being used online please add a link to <a href="http://ujora.de" rel="nofollow">ujora.de</a> Dieses Foto  ...


In [22]:
import pandas as pd
IMG_ROOT = "Animals_with_Attributes2/JPEGImages"
TXT_ROOT = "Animals_with_Attributes2/licenses"

# 用于与标签对齐（确保顺序和 y_test 的 0..49 一致）
class_to_idx = datasets.ImageFolder(IMG_ROOT).class_to_idx  # {'antelope':0, ...}
idx_to_class = {v: k for k, v in class_to_idx.items()}

# 读取并清洗单个 txt 的 TITLE 和 DESCRIPTION（去掉框线/下划线/HTML标签）
def read_title_and_desc_clean(p):
    title, desc = "", ""
    clean = lambda s: re.sub(r'[\+\-\|\_]+', '', s).strip()
    with open(p, "r", encoding="utf-8", errors="ignore") as f:
        lines = [line.strip() for line in f if line.strip()]

    for i, line in enumerate(lines):
        U = line.upper()
        if "TITLE" in U and i + 1 < len(lines):
            title = clean(lines[i + 1])
        if "DESCRIPTION" in U:
            buf = []
            for j in range(i + 1, len(lines)):
                if any(k in lines[j].upper() for k in ["TITLE","INFO","TAGS","PHOTOGRAPHER","LICENSE"]):
                    break
                buf.append(clean(lines[j]))
            desc = " ".join(buf).strip()
            break
    # 去掉 HTML 标签
    title = re.sub(r"<.*?>", "", title)
    desc  = re.sub(r"<.*?>", "", desc)
    return title, desc

# 按类别汇总：每个子文件夹 -> 拼接所有 txt 的 title+desc
rows = []
for cls in sorted(os.listdir(TXT_ROOT)):
    cls_dir = os.path.join(TXT_ROOT, cls)
    if not os.path.isdir(cls_dir):
        continue
    pieces = []
    for fname in sorted(os.listdir(cls_dir)):
        if fname.endswith(".txt"):
            t, d = read_title_and_desc_clean(os.path.join(cls_dir, fname))
            text = " ".join([t, d]).strip()
            if text:
                pieces.append(text)
    merged = " ".join(pieces)              # 该类别的整合文本
    rows.append({"class": cls, "text": merged, "n_txt": len(pieces)})

# 变成 DataFrame，并按 label 顺序(0..49) 排好，方便与 y_test 对齐
df = pd.DataFrame(rows)
df["label"] = df["class"].map(class_to_idx)
df = df.sort_values("label").reset_index(drop=True)
df = df.iloc[:-1].reset_index(drop=True)
print("行数(应为50):", len(df))
print(df[["label","class","n_txt"]].head())


行数(应为50): 50
   label       class  n_txt
0    0.0    antelope   1046
1    1.0         bat    178
2    2.0      beaver    147
3    3.0  blue+whale    174
4    4.0      bobcat    627


In [23]:
df.head()

Unnamed: 0,class,text,n_txt,label
0,antelope,"\And God said, Let the earth bring forth the l...",1046,0.0
1,bat,Found below the power lines at Hamilton Beach....,178,1.0
2,beaver,the local beavers on Christmas day 2007 (no de...,147,2.0
3,blue+whale,(no description) Free Fall breaching. (no desc...,174,3.0
4,bobcat,"One of the cubs walking in the enclosure, unde...",627,4.0


In [24]:
from sklearn.feature_extraction.text import TfidfVectorizer

# 你已经有 df（50行），每行是一个类别的 text
texts = df["text"].fillna("").tolist()

# 1) 定义 TF-IDF 模型
tfidf = TfidfVectorizer(
    max_features=1000,       # 取前1000个高频特征，可调
    stop_words="english",    # 去除英文停用词
    lowercase=True           # 全部转小写
)

# 2) 拟合并变换
X_tfidf = tfidf.fit_transform(texts)      # shape (50, vocab_size)
print("TF-IDF shape:", X_tfidf.shape)

# 3) 转成 torch.Tensor，方便和你的 test_whole 拼接
X_tfidf_tensor = torch.tensor(X_tfidf.toarray(), dtype=torch.float32)



TF-IDF shape: (50, 1000)


In [25]:
S = X_tfidf_tensor.float()  

In [26]:

V = S.shape[1]

def extract_feats(dl):
    feats, lbls = [], []
    for imgs, y in dl:
        imgs = imgs.to(device, non_blocking=True)
        f = resnet(imgs).detach().cpu()      # [B, 2048]
        feats.append(f); lbls.append(y)
    return torch.cat(feats,0), torch.cat(lbls,0)

# —— 示例：跑一折，得到训练/测试特征，并与语义对齐 ——
k = 0  # 第 1 折
train_loader, test_loader, seen_cls, unseen_cls = make_dataloaders_for_fold(k)

X_tr, y_tr = extract_feats(train_loader)     # [Ntr, 2048], [Ntr]
X_te, y_te = extract_feats(test_loader)      # [Nte, 2048], [Nte]

# 训练目标：样本所属类的语义向量
Y_tr = S[y_tr]                                # [Ntr, V]

# ——（可选）拼接通道：把图像特征与其类语义拼接，做一些对照基线 —— 
# 这在 ZSL 训练阶段常用于判别器/回归器的额外输入；真正测试时仍只用图像→语义映射。
X_tr_concat = torch.cat([X_tr, Y_tr], dim=1)  # [Ntr, 2048+V]


In [27]:
for k in range(10):
    train_loader, test_loader, seen_cls, unseen_cls = make_dataloaders_for_fold(k)
    print(f"Fold {k+1}:")
    print(f"  Seen class number={len(seen_cls)}, Unseen class number={len(unseen_cls)}")
    print(f"  training samples={len(train_loader.dataset)}, test samples={len(test_loader.dataset)}")
    print("  Seen class index:", seen_cls.tolist())
    print("  Unseen class index:", unseen_cls.tolist())
    print("-"*60)




Fold 1:
  Seen class number=45, Unseen class number=5
  training samples=33048, test samples=4258
  Seen class index: [45, 26, 9, 29, 16, 31, 21, 12, 3, 39, 38, 10, 24, 35, 0, 43, 18, 33, 48, 41, 30, 28, 20, 22, 42, 46, 36, 32, 44, 13, 49, 47, 2, 27, 37, 5, 34, 6, 8, 14, 15, 17, 1, 7, 40]
  Unseen class index: [25, 23, 19, 11, 4]
------------------------------------------------------------
Fold 2:
  Seen class number=45, Unseen class number=5
  training samples=33013, test samples=4293
  Seen class index: [25, 23, 19, 11, 4, 31, 21, 12, 3, 39, 38, 10, 24, 35, 0, 43, 18, 33, 48, 41, 30, 28, 20, 22, 42, 46, 36, 32, 44, 13, 49, 47, 2, 27, 37, 5, 34, 6, 8, 14, 15, 17, 1, 7, 40]
  Unseen class index: [45, 26, 9, 29, 16]
------------------------------------------------------------
Fold 3:
  Seen class number=45, Unseen class number=5
  training samples=33783, test samples=3523
  Seen class index: [25, 23, 19, 11, 4, 45, 26, 9, 29, 16, 38, 10, 24, 35, 0, 43, 18, 33, 48, 41, 30, 28, 20, 22, 42

In [28]:
#mlp

In [29]:
import math, random, torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, Subset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_grad_enabled(False)

def set_seed(s=42):
    random.seed(s); torch.manual_seed(s); 
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(s)

@torch.no_grad()
def extract_feats(dl, feature_net):
    feats, lbls = [], []
    for imgs, y in dl:
        imgs = imgs.to(device, non_blocking=True)
        f = feature_net(imgs).detach().cpu()     # [B, 2048]
        feats.append(f); lbls.append(y)
    return torch.cat(feats,0), torch.cat(lbls,0)

@torch.no_grad()
def zsl_acc(pred_sem, S_cand, true_lbls, cand_lbl_ids):
    pred_sem = F.normalize(pred_sem, dim=1)
    S_cand   = F.normalize(S_cand, dim=1)
    sims = pred_sem @ S_cand.T
    pred_idx = sims.argmax(dim=1)
    pred_lbl = cand_lbl_ids[pred_idx]
    return (pred_lbl == true_lbls).float().mean().item()


In [30]:
def split_classes_k3(seen_classes, seed=0):
    """把 seen 类平均分成 3 份，返回 [C1, C2, C3]（每份是类ID列表）"""
    clz = list(map(int, seen_classes))
    random.Random(seed).shuffle(clz)
    s = math.ceil(len(clz)/3)
    return [clz[:s], clz[s:2*s], clz[2*s:]]

def idx_of_classes(y_vec, cls_set):
    """给定样本标签向量 y 和类集合，返回属于这些类的样本索引列表"""
    ss = set(map(int, cls_set))
    y_list = y_vec if isinstance(y_vec, list) else y_vec.tolist()
    return [i for i, yy in enumerate(y_list) if int(yy) in ss]


In [31]:
class MLP(nn.Module):
    def __init__(self, in_dim, hid, out_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hid),
            nn.ReLU(),
            nn.Linear(hid, out_dim),
        )
    def forward(self, x): return self.net(x)

def train_epoch(model, X, Y, opt, batch_size=512):
    model.train(); torch.set_grad_enabled(True)
    N = X.shape[0]
    perm = torch.randperm(N)
    loss_sum = 0.0
    for i in range(0, N, batch_size):
        idx = perm[i:i+batch_size]
        xb, yb = X[idx], Y[idx]
        pred = model(xb)
        loss = F.mse_loss(pred, yb)
        opt.zero_grad(); loss.backward(); opt.step()
        loss_sum += loss.item() * xb.size(0)
    torch.set_grad_enabled(False)
    return loss_sum / N


In [32]:
def tune_mlp_inner3(Xtr, ytr, S, seen_cls, grid, epochs=30, wd=0.0, seed=0):
    """
    Xtr: [Ntr, 2048] 训练特征（seen 样本）
    ytr: [Ntr]       训练标签（全局类ID）
    S:   [50, V]     类语义矩阵
    seen_cls: [Ns]   本折的 seen 类ID
    grid: list of dicts: [{'hid':..., 'lr':...}, ...]
    """
    parts = split_classes_k3(seen_cls, seed=seed)   # [C1, C2, C3]
    best_cfg, best_score, best_key = None, -1.0, None

    # 为了“中间值更常被选中”的要求，准备一个tie-break键
    hid_order = [512,1024,2048]; lr_order = [1e-2,1e-3,1e-4]

    for cfg in grid:
        scores = []
        for k in range(3):
            val_cls = parts[k]
            tr_cls  = parts[(k+1)%3] + parts[(k+2)%3]

            tr_idx = idx_of_classes(ytr, tr_cls)
            va_idx = idx_of_classes(ytr, val_cls)

            X_tr, Y_tr = Xtr[tr_idx], S[ytr[tr_idx]]
            X_va, y_va = Xtr[va_idx], ytr[va_idx]

            model = MLP(2048, cfg['hid'], S.shape[1]).to('cpu')
            opt   = torch.optim.Adam(model.parameters(), lr=cfg['lr'], weight_decay=wd)

            for ep in range(epochs):
                train_epoch(model, X_tr, Y_tr, opt)

            with torch.no_grad():
                pred_sem = model(X_va)
                acc = zsl_acc(pred_sem, S[val_cls], y_va, torch.tensor(val_cls))
            scores.append(acc)

        mean_acc = sum(scores)/len(scores)
        # tie-break：靠近中档优先
        key = (mean_acc, -abs(hid_order.index(cfg['hid'])-1), -abs(lr_order.index(cfg['lr'])-1))
        if (mean_acc > best_score) or (abs(mean_acc-best_score)<1e-12 and (best_key is None or key > best_key)):
            best_cfg, best_score, best_key = cfg, mean_acc, key

    return best_cfg, best_score


In [33]:
def tune_once_get_best_cfg(resnet, S, *, seed=42, inner_epochs=30, wd=0.0):
    """
    只在一次 k=3 内层上确定 best_cfg，然后全程复用。
    用第 0 折的 seen 类做伪 ZSL（split_classes_k3 + tune_mlp_inner3）。
    """
    set_seed(seed)
    k_tune = 0
    train_loader, _, seen_cls, _ = make_dataloaders_for_fold(k_tune)
    Xtr, ytr = extract_feats(train_loader, resnet)   # 你已有的接口签名

    grid = [{'hid': h, 'lr': lr}
            for h in [512, 1024, 2048]
            for lr in [1e-2, 1e-3, 1e-4]]

    # ⏳ 这里会训练（tune_mlp_inner3 里会跑 3 次 * len(grid) 轮）
    best_cfg, inner_score = tune_mlp_inner3(Xtr, ytr, S, seen_cls, grid,
                                            epochs=inner_epochs, wd=wd, seed=seed)
    print(f"[TuneOnce@fold0] best_cfg={best_cfg} | inner-3fold acc={inner_score:.4f}")
    return best_cfg, inner_score


In [34]:
def run_one_fold_with_cfg(k, resnet, S, best_cfg, *, epochs_final=40, seed=42):
    """
    使用已经确定的 best_cfg：
    - 训练：在本折全部 seen 训练样本上用 MLP 学图像->语义（⏳）
    - 测试：只在本折 unseen 类里做 ZSL 评估
    """
    set_seed(seed)

    # 1) 本折数据
    train_loader, test_loader, seen_cls, unseen_cls = make_dataloaders_for_fold(k)

    # 2) 提特征（复用你的函数）
    X_tr, y_tr = extract_feats(train_loader, resnet)  # [Ntr, 2048], [Ntr]
    X_te, y_te = extract_feats(test_loader,  resnet)  # [Nte, 2048], [Nte]

    # 3) 训练 MLP（⏳）
    torch.set_grad_enabled(True)
    model = MLP(2048, best_cfg['hid'], S.shape[1]).to('cpu')
    opt   = torch.optim.Adam(model.parameters(), lr=best_cfg['lr'])
    Y_tr  = S[y_tr]   # 训练目标：类语义
    for _ in range(epochs_final):
        train_epoch(model, X_tr, Y_tr, opt)
    torch.set_grad_enabled(False)

    # 4) 外层真实 ZSL 测试（只在 unseen 候选里做语义最近邻）
    with torch.no_grad():
        pred_sem = model(X_te)
        acc = zsl_acc(pred_sem, S[unseen_cls], y_te, torch.tensor(unseen_cls))
    return acc, best_cfg


In [36]:
# 固定随机种子
set_seed(42)

# 第一步：只跑一次 k=3 内层调参
best_cfg, inner_score = tune_once_get_best_cfg(resnet, S, seed=42, inner_epochs=30, wd=0.0)

# 第二步：用固定 best_cfg 跑 10 折（每折仅训练 MLP + 测试）
accs, cfgs = [], []
for k in range(10):
    acc, _ = run_one_fold_with_cfg(k, resnet, S, best_cfg, epochs_final=40, seed=42+k)
    accs.append(acc)
    cfgs.append(best_cfg)
    print(f"Fold {k+1}: acc={acc:.4f}")

# 汇总
mean_acc = sum(accs) / len(accs)
import math
std = math.sqrt(sum((a - mean_acc)**2 for a in accs) / (len(accs) - 1))
ci95 = 1.96 * std / math.sqrt(len(accs))

print("\n======== Summary ========")
print("Per-fold acc :", [f"{a:.4f}" for a in accs])
print(f"Mean acc     : {mean_acc:.4f}")
print(f"Std (10fold) : {std:.4f}")
print(f"95% CI       : ±{ci95:.4f}")
print(f"Fixed cfg    : {best_cfg}, (inner-3fold={inner_score:.4f})")


[TuneOnce@fold0] best_cfg={'hid': 512, 'lr': 0.001} | inner-3fold acc=0.1096


  acc = zsl_acc(pred_sem, S[unseen_cls], y_te, torch.tensor(unseen_cls))


Fold 1: acc=0.4605
Fold 2: acc=0.4990
Fold 3: acc=0.2552
Fold 4: acc=0.2939
Fold 5: acc=0.2946
Fold 6: acc=0.4452
Fold 7: acc=0.3520
Fold 8: acc=0.1516
Fold 9: acc=0.3123
Fold 10: acc=0.3034

Per-fold acc : ['0.4605', '0.4990', '0.2552', '0.2939', '0.2946', '0.4452', '0.3520', '0.1516', '0.3123', '0.3034']
Mean acc     : 0.3368
Std (10fold) : 0.1053
95% CI       : ±0.0653
Fixed cfg    : {'hid': 512, 'lr': 0.001}, (inner-3fold=0.1096)


In [38]:
# ===== f-CLSWGAN: 模型与训练 =====
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np, random

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class _FeatDS(Dataset):
    def __init__(self, X: torch.Tensor, y: torch.Tensor):
        self.X = X.float().contiguous()
        self.y = y.long().contiguous()
    def __len__(self): return self.y.numel()
    def __getitem__(self, i): return self.X[i], self.y[i]

class G(nn.Module):  # 生成器
    def __init__(self, z_dim, a_dim, x_dim, hidden=2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(z_dim + a_dim, hidden), nn.LeakyReLU(0.2, True),
            nn.Linear(hidden, hidden),         nn.LeakyReLU(0.2, True),
            nn.Linear(hidden, x_dim),
        )
    def forward(self, z, a): return self.net(torch.cat([z, a], dim=1))

class D(nn.Module):  # 判别器
    def __init__(self, x_dim, a_dim, hidden=2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(x_dim + a_dim, hidden), nn.LeakyReLU(0.2, True),
            nn.Linear(hidden, hidden),         nn.LeakyReLU(0.2, True),
            nn.Linear(hidden, 1),
        )
    def forward(self, x, a): return self.net(torch.cat([x, a], dim=1))

def _grad_penalty(Dmod, real_x, fake_x, a, gp_lambda=10.0):
    b = real_x.size(0)
    eps = torch.rand(b, 1, device=real_x.device).expand_as(real_x)
    inter = (eps*real_x + (1-eps)*fake_x).requires_grad_(True)
    d_inter = Dmod(inter, a)
    grads = torch.autograd.grad(d_inter, inter, torch.ones_like(d_inter),
                                create_graph=True, retain_graph=True)[0]
    return gp_lambda * ((grads.view(b, -1).norm(2, dim=1) - 1)**2).mean()

def train_wgan_seen(  # ⏳ 训练 GAN（只用 seen 样本）
    X_seen, y_seen_sid, A_seen, *,
    z_dim=100, g_hidden=2048, d_hidden=2048,
    lr=1e-4, n_epochs=40, batch_size=256, n_critic=5, gp_lambda=10.0, seed=42
):
    random.seed(seed); torch.manual_seed(seed); 
    x_dim, a_dim = X_seen.size(1), A_seen.size(1)
    Gmod, Dmod = G(z_dim, a_dim, x_dim, g_hidden).to(device), D(x_dim, a_dim, d_hidden).to(device)
    optG = torch.optim.Adam(Gmod.parameters(), lr=lr, betas=(0.5, 0.9))
    optD = torch.optim.Adam(Dmod.parameters(), lr=lr, betas=(0.5, 0.9))
    A_seen = A_seen.float().to(device)
    dl = DataLoader(_FeatDS(X_seen, y_seen_sid), batch_size=batch_size, shuffle=True, drop_last=True)

    for _ in range(n_epochs):
        for xb, yb_sid in dl:
            xb, yb_sid = xb.to(device), yb_sid.to(device)
            B = xb.size(0)
            for _ in range(n_critic):
                a = A_seen[yb_sid]; z = torch.randn(B, z_dim, device=device)
                x_fake = Gmod(z, a).detach()
                lossD = -(Dmod(xb, a).mean() - Dmod(x_fake, a).mean()) + _grad_penalty(Dmod, xb, x_fake, a, gp_lambda)
                optD.zero_grad(); lossD.backward(); optD.step()
            a = A_seen[yb_sid]; z = torch.randn(B, z_dim, device=device)
            lossG = - Dmod(Gmod(z, a), a).mean()
            optG.zero_grad(); lossG.backward(); optG.step()
    return Gmod

def synth_unseen(Gmod, A_unseen, n_per_class=300, z_dim=100):
    Gmod.eval()
    with torch.no_grad():
        K = A_unseen.size(0)
        z = torch.randn(K*n_per_class, z_dim, device=device)
        a = A_unseen.float().to(device).repeat_interleave(n_per_class, 0)
        Xsyn = Gmod(z, a).cpu()
    return Xsyn  # [K*n, x_dim]

# 轻量线性分类器（在合成特征上训练）
class LinearSoftmax(nn.Module):
    def __init__(self, in_dim, n_cls): super().__init__(); self.fc = nn.Linear(in_dim, n_cls)
    def forward(self, x): return self.fc(x)

def train_clf(  # ⏳ 训练分类器
    X, y, n_cls, *, epochs=30, lr=1e-3, wd=1e-4, bs=256, seed=42
):
    random.seed(seed); torch.manual_seed(seed)
    clf = LinearSoftmax(X.size(1), n_cls).to(device)
    opt = torch.optim.Adam(clf.parameters(), lr=lr, weight_decay=wd)
    dl = DataLoader(_FeatDS(X, y), batch_size=bs, shuffle=True)
    for _ in range(epochs):
        for xb, yb in dl:
            xb, yb = xb.to(device), yb.to(device)
            loss = F.cross_entropy(clf(xb), yb)
            opt.zero_grad(); loss.backward(); opt.step()
    return clf

def acc_of(clf, X, y):
    clf.eval()
    with torch.no_grad():
        pred = clf(X.to(device)).argmax(1).cpu()
    return float((pred == y).float().mean().item())


In [39]:
!pip -q install tqdm

[0m

In [47]:
# === f-CLSWGAN ① 依赖与小工具（无需等待） ===
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
import numpy as np, random

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def set_seed_all(seed=42):
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); 
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)

class _FeatDS(Dataset):
    def __init__(self, X: torch.Tensor, y: torch.Tensor):
        self.X = X.float().contiguous()
        self.y = y.long().contiguous()
    def __len__(self): return self.y.numel()
    def __getitem__(self, i): return self.X[i], self.y[i]

def acc_of(model: nn.Module, X: torch.Tensor, y: torch.Tensor) -> float:
    model.eval()
    with torch.no_grad():
        pred = model(X.to(device)).argmax(1).cpu()
    return float((pred == y).float().mean().item())


In [67]:
class G(nn.Module):  # 生成器
    def __init__(self, z_dim, a_dim, x_dim, hidden=2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(z_dim + a_dim, hidden),
            nn.LeakyReLU(0.2, inplace=False),
            nn.Linear(hidden, hidden),
            nn.LeakyReLU(0.2, inplace=False),
            nn.Linear(hidden, x_dim),
        )
    def forward(self, z, a):
        # 保障 dtype/device 一致
        if a.dtype != z.dtype: a = a.to(z.dtype)
        if a.device != z.device: a = a.to(z.device)
        return self.net(torch.cat([z, a], dim=1))

class D(nn.Module):  # 判别器
    def __init__(self, x_dim, a_dim, hidden=2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(x_dim + a_dim, hidden),
            nn.LeakyReLU(0.2, inplace=False),   # ⚠️ 关闭 inplace
            nn.Linear(hidden, hidden),
            nn.LeakyReLU(0.2, inplace=False),   # ⚠️ 关闭 inplace
            nn.Linear(hidden, 1),
        )
    def forward(self, x, a):
        if a.dtype != x.dtype: a = a.to(x.dtype)
        if a.device != x.device: a = a.to(x.device)
        out = self.net(torch.cat([x, a], dim=1))
        return out.view(out.size(0))            # 直接返回 (B,)

def grad_penalty(Dmod, real_x, fake_x, a, gp_lambda):
    B = real_x.size(0)
    eps = torch.rand(B, 1, device=real_x.device)       # 向量特征 -> (B,1)

    # 只对 inter 打开梯度，real/fake 可 detach 以免传回 G
    inter = (eps * real_x.detach() + (1 - eps) * fake_x.detach()).requires_grad_(True)

    with torch.enable_grad():  # 覆盖外层任何 no_grad
        d_inter = Dmod(inter, a)                         # 形状 (B,)
        grad_outputs = torch.ones_like(d_inter)
        grads = torch.autograd.grad(
            outputs=d_inter,
            inputs=inter,
            grad_outputs=grad_outputs,
            create_graph=True,
            retain_graph=True,
            only_inputs=True
        )[0]                                            # (B, x_dim)

    gp = gp_lambda * ((grads.view(B, -1).norm(2, dim=1) - 1.0) ** 2).mean()
    return gp


In [69]:
def train_wgan_seen(  # 只用 seen 样本训练 WGAN-GP（稳定版）
    X_seen, y_seen_sid, A_seen, *,
    z_dim=100, g_hidden=2048, d_hidden=2048,
    lr=1e-4, n_epochs=40, batch_size=256,
    n_critic=5, gp_lambda=10.0, seed=42,
    desc="GAN(seen)"
):
    set_seed_all(seed)
    x_dim, a_dim = X_seen.size(1), A_seen.size(1)

    # 1) 模型 & 优化器
    Gmod = G(z_dim, a_dim, x_dim, g_hidden).to(device)
    Dmod = D(x_dim, a_dim, d_hidden).to(device)
    optG = torch.optim.Adam(Gmod.parameters(), lr=lr, betas=(0.5, 0.9))
    optD = torch.optim.Adam(Dmod.parameters(), lr=lr, betas=(0.5, 0.9))

    # 2) 数据
    A_seen = A_seen.float().to(device)
    dl = DataLoader(_FeatDS(X_seen, y_seen_sid),
                    batch_size=batch_size, shuffle=True, drop_last=True)

    # 3) 训练
    for _ in tqdm(range(n_epochs), desc=desc):
        for xb, yb_sid in dl:
            xb, yb_sid = xb.to(device), yb_sid.to(device)
            B = xb.size(0)

            # -------------------------------
            # D-step (Critic update)
            # -------------------------------
            for p in Dmod.parameters(): p.requires_grad_(True)
            for p in Gmod.parameters(): p.requires_grad_(False)

            for _ in range(n_critic):
                with torch.enable_grad():                # 保证梯度开启
                    a = A_seen[yb_sid]
                    z = torch.randn(B, z_dim, device=device)
                    x_fake = Gmod(z, a).detach()         # 只断开 G

                    d_real = Dmod(xb, a).mean()
                    d_fake = Dmod(x_fake, a).mean()
                    gp     = grad_penalty(Dmod, xb, x_fake, a, gp_lambda)

                    lossD = -(d_real - d_fake) + gp
                    optD.zero_grad(set_to_none=True)
                    lossD.backward()
                    optD.step()

            # -------------------------------
            # G-step (Generator update)
            # -------------------------------
            for p in Dmod.parameters(): p.requires_grad_(False)
            for p in Gmod.parameters(): p.requires_grad_(True)

            with torch.enable_grad():
                a = A_seen[yb_sid]
                z = torch.randn(B, z_dim, device=device)
                x_fake = Gmod(z, a)
                lossG  = - Dmod(x_fake, a).mean()

                optG.zero_grad(set_to_none=True)
                lossG.backward()
                optG.step()

    return Gmod


In [80]:
def synth_unseen(Gmod, A_unseen, n_per_class=300, z_dim=100):
    Gmod.eval()
    with torch.no_grad():
        K = A_unseen.size(0)
        z = torch.randn(K*n_per_class, z_dim, device=device)
        a = A_unseen.float().to(device).repeat_interleave(n_per_class, 0)
        Xsyn = Gmod(z, a).cpu()
    return Xsyn  # [K*n, x_dim]

class LinearSoftmax(nn.Module):
    def __init__(self, in_dim, n_cls): super().__init__(); self.fc = nn.Linear(in_dim, n_cls)
    def forward(self, x): return self.fc(x)

def train_clf(X, y, n_cls, *, epochs=30, lr=1e-3, wd=1e-4, bs=256, seed=42, desc="CLS"):
    set_seed_all(seed)

    # 1) 把输入统一成 torch 张量 + 正确 dtype/device
    if not torch.is_tensor(X): X = torch.tensor(X)
    if not torch.is_tensor(y): y = torch.tensor(y)
    X = X.float().to(device)
    y = y.long().to(device)

    # 2) 简单线性分类器（和你之前的 LinearSoftmax 等价）
    clf = nn.Sequential(nn.Linear(X.size(1), n_cls)).to(device)
    opt = torch.optim.Adam(clf.parameters(), lr=lr, weight_decay=wd)
    dl  = DataLoader(_FeatDS(X, y), batch_size=bs, shuffle=True, drop_last=True)

    # 3) 训练：强制开启梯度 + 明确 train() + 自检
    clf.train()
    assert any(p.requires_grad for p in clf.parameters()), "clf params require_grad=False ?"

    with torch.set_grad_enabled(True):            # 覆盖任何历史 set_grad_enabled(False)
        for _ in tqdm(range(epochs), desc=desc):
            for xb, yb in dl:
                xb, yb = xb.to(device), yb.to(device)

                opt.zero_grad(set_to_none=True)
                logits = clf(xb)                  # 不要 .detach() / .argmax()
                loss   = F.cross_entropy(logits, yb)

                # 运行前自检（若失败，可立刻定位）
                assert logits.requires_grad, "logits has no grad_fn (被detach/禁梯度了?)"
                assert loss.requires_grad,   "loss has no grad_fn (梯度被全局关闭?)"

                loss.backward()
                opt.step()

    return clf


In [81]:
# === f-CLSWGAN ④ 只调一次参（k=3，⏳） ===
def fclswgan_tune_once_k3(resnet, S, *, seed=42):
    set_seed_all(seed)
    # 用第 0 折的 seen 类做 3 折伪 ZSL
    k_tune = 0
    train_loader, _, seen_cls, _ = make_dataloaders_for_fold(k_tune)
    Xtr, ytr = extract_feats(train_loader, resnet)
    parts = split_classes_k3(seen_cls, seed=seed)  # [C1, C2, C3]

    # 两个超参各三值（作业要求）
    grid = [(g_hidden, lr) for g_hidden in [1024, 2048, 3072] for lr in [5e-5, 1e-4, 2e-4]]

    def eval_cfg(g_hidden, lr):
        scores = []
        for i in range(3):
            val_cls = parts[i]
            tr_cls  = parts[(i+1)%3] + parts[(i+2)%3]
            tr_idx  = idx_of_classes(ytr.tolist(), tr_cls)
            va_idx  = idx_of_classes(ytr.tolist(), val_cls)
            X_tr, y_tr_gid = Xtr[tr_idx], ytr[tr_idx]
            X_va, y_va_gid = Xtr[va_idx], ytr[va_idx]

            # y -> seen 局部ID
            train_seen_ids = sorted(list(map(int, set(tr_cls))))
            gid2sid = {gid: sid for sid, gid in enumerate(train_seen_ids)}
            y_tr_sid = torch.tensor([gid2sid[int(g)] for g in y_tr_gid], dtype=torch.long)

            A_in_seen   = S[torch.tensor(train_seen_ids, dtype=torch.long)].cpu()
            A_in_unseen = S[torch.tensor(val_cls,       dtype=torch.long)].cpu()

            Gmod = train_wgan_seen(
                X_tr, y_tr_sid, A_in_seen,
                z_dim=100, g_hidden=g_hidden, d_hidden=2048,
                lr=lr, n_epochs=40, batch_size=256, n_critic=5, gp_lambda=10.0, seed=seed,
                desc=f"GAN inner cfg(g={g_hidden},lr={lr}) fold{i+1}/3"
            )
            X_syn = synth_unseen(Gmod, A_in_unseen, n_per_class=300, z_dim=100)
            y_syn = torch.repeat_interleave(torch.tensor(val_cls, dtype=torch.long), repeats=300)
            clf   = train_clf(X_syn, y_syn, n_cls=S.size(0), epochs=30, lr=1e-3, wd=1e-4, bs=256, seed=seed,
                              desc=f"CLS inner fold{i+1}/3")
            scores.append(acc_of(clf, X_va, y_va_gid))
        return float(np.mean(scores))

    best_cfg, best_score = None, -1.0
    for g_hidden, lr in tqdm(grid, desc="Grid(k=3)"):  # 每个配置 3 次小训练
        s = eval_cfg(g_hidden, lr)
        if s > best_score:
            best_score = s
            best_cfg = dict(g_hidden=g_hidden, lr=lr, d_hidden=2048,
                            n_per_class=300, epochs_gan=40, batch_size=256,
                            n_critic=5, gp_lambda=10.0)
    print(f"[TuneOnce@fold0] best={best_cfg} | inner-3fold acc={best_score:.4f}")
    return best_cfg, best_score


In [82]:
# === f-CLSWGAN ⑤ 10折（用固定 best_cfg，⏳） ===
def fclswgan_run_10fold(resnet, S, best_cfg, *, seed=42, nw=0):
    accs = []
    for k in tqdm(range(10), desc="Outer 10-fold"):
        train_loader, test_loader, seen_cls, unseen_cls = make_dataloaders_for_fold(k, nw=nw)
        X_tr, y_tr = extract_feats(train_loader, resnet)
        X_te, y_te = extract_feats(test_loader,  resnet)

        A_seen   = S[torch.tensor(seen_cls,   dtype=torch.long)].cpu()
        A_unseen = S[torch.tensor(unseen_cls, dtype=torch.long)].cpu()

        gid2sid = {gid: sid for sid, gid in enumerate(seen_cls)}
        y_tr_sid = torch.tensor([gid2sid[int(g)] for g in y_tr.tolist()], dtype=torch.long)

        # ⏳ 训练 GAN
        Gmod = train_wgan_seen(
            X_tr, y_tr_sid, A_seen,
            z_dim=100,
            g_hidden=best_cfg["g_hidden"], d_hidden=best_cfg["d_hidden"],
            lr=best_cfg["lr"], n_epochs=best_cfg["epochs_gan"],
            batch_size=best_cfg["batch_size"], n_critic=best_cfg["n_critic"],
            gp_lambda=best_cfg["gp_lambda"], seed=seed+k, desc=f"GAN fold{k+1}"
        )

        # 合成 + ⏳ 训练分类器
        X_syn = synth_unseen(Gmod, A_unseen, n_per_class=best_cfg["n_per_class"], z_dim=100)
        y_syn = torch.repeat_interleave(torch.tensor(unseen_cls, dtype=torch.long),
                                        repeats=best_cfg["n_per_class"])
        clf = train_clf(X_syn, y_syn, n_cls=S.size(0), epochs=30, lr=1e-3, wd=1e-4, bs=256,
                        seed=seed+k, desc=f"CLS fold{k+1}")

        acc = acc_of(clf, X_te, y_te)
        accs.append(acc)
        print(f"Fold {k+1}: unseen acc = {acc:.4f}")
    mean_acc = float(np.mean(accs))
    print(f"\n===> 10-fold mean unseen acc = {mean_acc:.4f}")
    return accs, mean_acc


In [None]:
best_cfg, inner_acc = fclswgan_tune_once_k3(resnet, S, seed=42)

Grid(k=3):   0%|          | 0/9 [00:00<?, ?it/s]

GAN inner cfg(g=1024,lr=5e-05) fold1/3:   0%|          | 0/40 [00:00<?, ?it/s]

CLS inner fold1/3:   0%|          | 0/30 [00:00<?, ?it/s]

GAN inner cfg(g=1024,lr=5e-05) fold2/3:   0%|          | 0/40 [00:00<?, ?it/s]

CLS inner fold2/3:   0%|          | 0/30 [00:00<?, ?it/s]

GAN inner cfg(g=1024,lr=5e-05) fold3/3:   0%|          | 0/40 [00:00<?, ?it/s]

CLS inner fold3/3:   0%|          | 0/30 [00:00<?, ?it/s]

GAN inner cfg(g=1024,lr=0.0001) fold1/3:   0%|          | 0/40 [00:00<?, ?it/s]

CLS inner fold1/3:   0%|          | 0/30 [00:00<?, ?it/s]

GAN inner cfg(g=1024,lr=0.0001) fold2/3:   0%|          | 0/40 [00:00<?, ?it/s]

CLS inner fold2/3:   0%|          | 0/30 [00:00<?, ?it/s]

GAN inner cfg(g=1024,lr=0.0001) fold3/3:   0%|          | 0/40 [00:00<?, ?it/s]

CLS inner fold3/3:   0%|          | 0/30 [00:00<?, ?it/s]

GAN inner cfg(g=1024,lr=0.0002) fold1/3:   0%|          | 0/40 [00:00<?, ?it/s]

CLS inner fold1/3:   0%|          | 0/30 [00:00<?, ?it/s]

GAN inner cfg(g=1024,lr=0.0002) fold2/3:   0%|          | 0/40 [00:00<?, ?it/s]

CLS inner fold2/3:   0%|          | 0/30 [00:00<?, ?it/s]

GAN inner cfg(g=1024,lr=0.0002) fold3/3:   0%|          | 0/40 [00:00<?, ?it/s]

CLS inner fold3/3:   0%|          | 0/30 [00:00<?, ?it/s]

GAN inner cfg(g=2048,lr=5e-05) fold1/3:   0%|          | 0/40 [00:00<?, ?it/s]

CLS inner fold1/3:   0%|          | 0/30 [00:00<?, ?it/s]

GAN inner cfg(g=2048,lr=5e-05) fold2/3:   0%|          | 0/40 [00:00<?, ?it/s]

CLS inner fold2/3:   0%|          | 0/30 [00:00<?, ?it/s]

GAN inner cfg(g=2048,lr=5e-05) fold3/3:   0%|          | 0/40 [00:00<?, ?it/s]

CLS inner fold3/3:   0%|          | 0/30 [00:00<?, ?it/s]

GAN inner cfg(g=2048,lr=0.0001) fold1/3:   0%|          | 0/40 [00:00<?, ?it/s]

CLS inner fold1/3:   0%|          | 0/30 [00:00<?, ?it/s]

GAN inner cfg(g=2048,lr=0.0001) fold2/3:   0%|          | 0/40 [00:00<?, ?it/s]

CLS inner fold2/3:   0%|          | 0/30 [00:00<?, ?it/s]

GAN inner cfg(g=2048,lr=0.0001) fold3/3:   0%|          | 0/40 [00:00<?, ?it/s]

CLS inner fold3/3:   0%|          | 0/30 [00:00<?, ?it/s]

GAN inner cfg(g=2048,lr=0.0002) fold1/3:   0%|          | 0/40 [00:00<?, ?it/s]

CLS inner fold1/3:   0%|          | 0/30 [00:00<?, ?it/s]

GAN inner cfg(g=2048,lr=0.0002) fold2/3:   0%|          | 0/40 [00:00<?, ?it/s]

CLS inner fold2/3:   0%|          | 0/30 [00:00<?, ?it/s]

GAN inner cfg(g=2048,lr=0.0002) fold3/3:   0%|          | 0/40 [00:00<?, ?it/s]

CLS inner fold3/3:   0%|          | 0/30 [00:00<?, ?it/s]

GAN inner cfg(g=3072,lr=5e-05) fold1/3:   0%|          | 0/40 [00:00<?, ?it/s]

CLS inner fold1/3:   0%|          | 0/30 [00:00<?, ?it/s]

GAN inner cfg(g=3072,lr=5e-05) fold2/3:   0%|          | 0/40 [00:00<?, ?it/s]

CLS inner fold2/3:   0%|          | 0/30 [00:00<?, ?it/s]

GAN inner cfg(g=3072,lr=5e-05) fold3/3:   0%|          | 0/40 [00:00<?, ?it/s]

CLS inner fold3/3:   0%|          | 0/30 [00:00<?, ?it/s]

GAN inner cfg(g=3072,lr=0.0001) fold1/3:   0%|          | 0/40 [00:00<?, ?it/s]

CLS inner fold1/3:   0%|          | 0/30 [00:00<?, ?it/s]

GAN inner cfg(g=3072,lr=0.0001) fold2/3:   0%|          | 0/40 [00:00<?, ?it/s]

CLS inner fold2/3:   0%|          | 0/30 [00:00<?, ?it/s]

GAN inner cfg(g=3072,lr=0.0001) fold3/3:   0%|          | 0/40 [00:00<?, ?it/s]

CLS inner fold3/3:   0%|          | 0/30 [00:00<?, ?it/s]

GAN inner cfg(g=3072,lr=0.0002) fold1/3:   0%|          | 0/40 [00:00<?, ?it/s]

CLS inner fold1/3:   0%|          | 0/30 [00:00<?, ?it/s]

GAN inner cfg(g=3072,lr=0.0002) fold2/3:   0%|          | 0/40 [00:00<?, ?it/s]

CLS inner fold2/3:   0%|          | 0/30 [00:00<?, ?it/s]

GAN inner cfg(g=3072,lr=0.0002) fold3/3:   0%|          | 0/40 [00:00<?, ?it/s]

In [None]:
# wait for loading
accs, mean_acc = fclswgan_run_10fold(resnet, S, best_cfg, seed=42, nw=0)