In [None]:
import os
import random
from typing import List, Tuple
from pathlib import Path
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.nn.functional import cross_entropy
from torchvision.datasets import ImageFolder
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
from torchvision import transforms
from transformers import AutoProcessor, AutoModel
import timm
import shutil

# ========== 1. Device ==========
# Prefer the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("device:", device)

# ========== 2. Data paths ==========
# Point root_dir at the folder that holds the four class sub-folders:
#   root_dir/
#     classA/
#     classB/
#     classC/
#     classD/
root_dir = r"D:/OneDriveFiles/OneDrive/人工智能基础期末/dataset2/"

# Folders read as the training / validation image sets.
train_dir = r"D:/OneDriveFiles/OneDrive/人工智能基础期末/data_split/train"
val_dir = r"D:/OneDriveFiles/OneDrive/人工智能基础期末/data_split/val"

# Local directory of the pretrained model.  TODO: change to your own path
local_model_dir = r"D:/OneDriveFiles/OneDrive/人工智能基础期末/medsiglip-448"

# Side length (in pixels) of the square images fed to the model.
image_size = 448

# ========== 3. Training hyper-parameters ==========
batch_size, num_workers, num_epochs = 32, 0, 30
lr = 1e-3            # only the linear head is trained, so this can be a little larger
weight_decay = 1e-2
# NOTE(review): the optimizer cell further down passes literal lr / weight_decay
# values instead of these two names.

# ========== 4. Random seeds (keeps every run's split identical) ==========
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if device == "cuda":
    torch.cuda.manual_seed_all(seed)

In [None]:
# %% Load MedSigLIP from the local directory and pull out its vision tower
print("Loading base model from local dir:", local_model_dir)

raw_model = AutoModel.from_pretrained(local_model_dir, local_files_only=True)
print("raw_model class:", type(raw_model))

# ---- Key step: keep only the image encoder ----
# Fail early when neither known entry point exists on the loaded model.
if not (hasattr(raw_model, "vision_model") or hasattr(raw_model, "get_image_features")):
    raise RuntimeError(
        "在 raw_model 里找不到 vision_model 或 get_image_features，"
        "请 print(raw_model) 看看结构，然后再定位视觉塔。"
    )

if hasattr(raw_model, "vision_model"):
    # Most CLIP/SigLIP-style multimodal models name their vision tower `vision_model`.
    img_encoder = raw_model.vision_model
    print("Use raw_model.vision_model as image encoder.")
else:
    # Some implementations expose get_image_features directly on the model.
    img_encoder = raw_model
    print("Use raw_model itself as image encoder (get_image_features).")

img_encoder.to(device)
img_encoder.eval()

# ---- Probe the embedding width with one dummy forward pass ----
with torch.no_grad():
    dummy = torch.zeros(1, 3, image_size, image_size).to(device)   # [1, 3, image_size, image_size]

    try:
        out = img_encoder(pixel_values=dummy)   # keyword call first
    except TypeError:
        out = img_encoder(dummy)                # some models only accept it positionally

    # Same lookup order as the classifier's forward():
    # image_embeds -> pooler_output -> mean of last_hidden_state -> raw tensor.
    for field in ("image_embeds", "pooler_output"):
        if hasattr(out, field):
            feats = getattr(out, field)                     # [1, D]
            break
    else:
        if hasattr(out, "last_hidden_state"):
            feats = out.last_hidden_state.mean(dim=1)       # [1, D]
        elif isinstance(out, torch.Tensor):
            feats = out
        else:
            print("Unknown output type:", type(out))
            print(out)
            raise RuntimeError(
                "无法从 img_encoder 的输出中找到特征，请 print(out) 再调整逻辑。"
            )

embed_dim = feats.size(-1)
print("image embed dim:", embed_dim)


# ---- 封装分类模型：视觉塔 + 线性 head ----
class MedSigVisionClassifier(nn.Module):
    """Image classifier: an image encoder followed by one linear head.

    Args:
        img_encoder: module applied to the image batch. It is called with the
            ``pixel_values=`` keyword first and positionally if that raises
            ``TypeError``.
        embed_dim: width of the feature vector handed to the head.
        num_classes: number of logits produced by the head.
    """

    def __init__(self, img_encoder, embed_dim, num_classes):
        super().__init__()
        self.encoder = img_encoder
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, pixel_values):
        """Return logits of shape [B, num_classes] for a batch of images.

        Features are taken from the encoder output in this order of preference:
        ``image_embeds``, ``pooler_output``, the mean of ``last_hidden_state``
        over dim 1, or the output itself when it is a plain tensor.

        Raises:
            RuntimeError: if the encoder output matches none of the above.
        """
        try:
            encoded = self.encoder(pixel_values=pixel_values)
        except TypeError:
            # Encoder does not accept the keyword; retry positionally.
            encoded = self.encoder(pixel_values)

        for field in ("image_embeds", "pooler_output"):
            if hasattr(encoded, field):
                pooled = getattr(encoded, field)
                break
        else:
            if hasattr(encoded, "last_hidden_state"):
                pooled = encoded.last_hidden_state.mean(dim=1)
            elif isinstance(encoded, torch.Tensor):
                pooled = encoded
            else:
                raise RuntimeError("encoder 输出里找不到特征，需根据实际结构单独处理。")

        # L2-normalise the features (CLIP-style); the epsilon guards against a zero norm.
        scale = pooled.norm(dim=-1, keepdim=True) + 1e-6
        return self.head(pooled / scale)    # [B, num_classes]


# ==== Build the classifier, wrapping it in DataParallel when several GPUs exist ====
model = MedSigVisionClassifier(img_encoder, embed_dim, 4)

if torch.cuda.device_count() > 1:
    print("使用", torch.cuda.device_count(), "张 GPU 进行 DataParallel")
    model = nn.DataParallel(model)   # splits each batch across the visible GPUs

model = model.to(device)

# DataParallel hides the real model behind `.module`; unwrap it when present.
core = model
if isinstance(model, nn.DataParallel):
    core = model.module

# Baseline setup: freeze the vision tower and train the head only.
core.encoder.requires_grad_(False)
core.head.requires_grad_(True)

# Parameter counts are taken on `core`, the unwrapped model.
total_params = 0
trainable_params = 0
for p in core.parameters():
    total_params += p.numel()
    if p.requires_grad:
        trainable_params += p.numel()
print(f"总参数量: {total_params:,}")
print(f"当前可训练参数量(仅 head): {trainable_params:,}")


In [None]:
# %% DataLoaders with a hand-written preprocessing pipeline
# Grayscale medical images -> 3 channels, resized to image_size x image_size, then normalised.

mean = [0.5] * 3
std = [0.5] * 3

# Training pipeline: same preprocessing as validation plus light augmentation
# (random horizontal flip and a rotation of up to 10 degrees).
train_steps = [
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((image_size, image_size)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
]
train_transform = transforms.Compose(train_steps)

# Validation pipeline: deterministic, no augmentation.
val_steps = [
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
]
val_transform = transforms.Compose(val_steps)

train_dataset = ImageFolder(train_dir, transform=train_transform)
val_dataset = ImageFolder(val_dir, transform=val_transform)

print("Classes:", train_dataset.classes)
print("train samples:", len(train_dataset))
print("val samples:", len(val_dataset))

# Settings shared by both loaders; only the shuffling differs.
loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)


In [None]:
# LoRA 版 Linear：在 frozen 的 base Linear 外面加一个低秩补丁 ΔW = BA
import math

class LoRALinear(nn.Module):
    """A frozen ``nn.Linear`` plus a trainable low-rank correction.

    The output is ``base(x) + lora_B(lora_A(dropout(x))) * (alpha / r)``.
    ``lora_B`` starts at zero, so a freshly built layer reproduces ``base``.

    Args:
        base_linear: the linear layer to wrap; its weight and bias are frozen.
        r: rank of the low-rank update.
        alpha: numerator of the scaling factor ``alpha / r``.
        dropout: dropout probability applied to the input of the LoRA branch only.
    """

    def __init__(self, base_linear: nn.Linear, r=8, alpha=16, dropout=0.1):
        super().__init__()
        assert isinstance(base_linear, nn.Linear)
        self.in_features, self.out_features = (
            base_linear.in_features,
            base_linear.out_features,
        )

        # Keep the original layer but stop it from receiving gradients.
        self.base = base_linear
        self.base.requires_grad_(False)

        self.r = r
        self.lora_A = nn.Linear(self.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, self.out_features, bias=False)

        # A gets a Kaiming-uniform init; B is zeroed so the update starts at 0.
        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B.weight)

        self.dropout = nn.Dropout(dropout)
        self.scaling = alpha / r

        # Place the new adapter weights on the same device as the wrapped layer.
        target_device = self.base.weight.device
        for adapter in (self.lora_A, self.lora_B):
            adapter.to(target_device)

    def forward(self, x):
        """Return the frozen layer's output plus the scaled low-rank update."""
        frozen_out = self.base(x)
        low_rank = self.lora_A(self.dropout(x))
        return frozen_out + self.lora_B(low_rank) * self.scaling


In [None]:
import torch.nn as nn

def add_lora_fc_last_blocks(model, r=8, alpha=16, dropout=0.1, last_n_blocks=4):
    """Inject LoRA adapters into the MLP of the last encoder blocks.

    Every parameter under ``core.encoder`` is frozen first. Then ``mlp.fc1`` and
    ``mlp.fc2`` of the last ``last_n_blocks`` entries of
    ``core.encoder.encoder.layers`` are replaced by ``LoRALinear`` wrappers.
    Parameters outside ``core.encoder`` (for example the classification head)
    are not touched, so their ``requires_grad`` flags stay as they were.

    The call is safe to repeat (e.g. when a notebook cell is re-run): a layer
    that is already a ``LoRALinear`` is not wrapped again. Its adapter weights
    are re-enabled for training and keep the ``r`` / ``alpha`` / ``dropout``
    they were created with. Adapters in blocks outside the current selection
    stay in place but remain frozen.

    Args:
        model: the classifier, optionally wrapped in a container that exposes
            the real model as ``.module`` (DataParallel / DDP).
        r: LoRA rank passed to ``LoRALinear``.
        alpha: LoRA alpha passed to ``LoRALinear``.
        dropout: LoRA dropout passed to ``LoRALinear``.
        last_n_blocks: how many trailing blocks receive adapters. ``0`` means
            none; values above the number of blocks are clamped with a warning.

    Raises:
        ValueError: if ``last_n_blocks`` is negative.
        AttributeError: if the encoder has no ``encoder.layers``; the encoder
            structure is printed before the error is re-raised.
    """
    if last_n_blocks < 0:
        raise ValueError(f"last_n_blocks must be >= 0, got {last_n_blocks}")

    # Works for DataParallel / DDP wrappers as well as for the bare model.
    core = model.module if hasattr(model, "module") else model

    # The vision tower (expected to be a SiglipVisionTransformer).
    vit = core.encoder

    # 1. Collect all transformer blocks.
    try:
        blocks = list(vit.encoder.layers)
    except AttributeError:
        print("encoder 结构如下，请发我这个打印：")
        print(vit)
        raise

    n_blocks = len(blocks)
    if last_n_blocks > n_blocks:
        last_n_blocks = n_blocks
        print(f"⚠ last_n_blocks 超过总层数，自动改为 {n_blocks}")

    # Slice from an explicit start index: blocks[-0:] would be the WHOLE list,
    # which would put adapters on every block when last_n_blocks == 0.
    target_blocks = blocks[n_blocks - last_n_blocks:]

    # 2. Freeze the entire encoder (all base weights, and any existing adapters).
    for p in vit.parameters():
        p.requires_grad = False

    # 3. Attach (or re-enable) LoRA on mlp.fc1 / mlp.fc2 of the target blocks.
    for block in target_blocks:
        mlp = block.mlp
        for fc_name in ("fc1", "fc2"):
            layer = getattr(mlp, fc_name)
            if isinstance(layer, LoRALinear):
                # Already wrapped by an earlier call: step 2 just froze its
                # adapters, so switch them back on instead of nesting wrappers.
                for adapter in (layer.lora_A, layer.lora_B):
                    for p in adapter.parameters():
                        p.requires_grad = True
            else:
                setattr(mlp, fc_name, LoRALinear(layer, r=r, alpha=alpha, dropout=dropout))


In [None]:
# Attach the LoRA adapters, then optimise only what is still trainable (LoRA + head).
add_lora_fc_last_blocks(model, r=8, alpha=16, dropout=0.1, last_n_blocks=4)

trainable_params = []
for param in model.parameters():
    if param.requires_grad:
        trainable_params.append(param)

optimizer = torch.optim.AdamW(trainable_params, lr=1e-4, weight_decay=1e-2)
print("可训练参数个数:", sum(param.numel() for param in trainable_params))

# %% Loss and learning-rate schedule shared by the train / eval functions
criterion = nn.CrossEntropyLoss()

# Cosine annealing with a period of the full number of epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

def train_one_epoch(epoch: int):
    """Run one full-precision training pass over ``train_loader``.

    Uses the module-level ``model``, ``optimizer``, ``criterion`` and ``device``.

    Args:
        epoch: epoch number; only used in the progress-bar label.

    Returns:
        ``(epoch_loss, epoch_acc)``: the sample-weighted mean loss and the
        accuracy over all samples seen in this pass.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    progress = tqdm(train_loader, desc=f"Epoch {epoch} [train]")
    for imgs, labels in progress:
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        logits = model(imgs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # criterion returns a batch mean; weight it by the batch size so the
        # running average stays correct when the last batch is smaller.
        loss_sum += loss.item() * imgs.size(0)
        predicted = torch.max(logits, dim=1).indices
        n_correct += (predicted == labels).sum().item()
        n_seen += labels.size(0)

        running = {
            "loss": f"{loss_sum/n_seen:.4f}",
            "acc":  f"{n_correct/n_seen:.4f}",
        }
        progress.set_postfix(running)

    return loss_sum / n_seen, n_correct / n_seen


@torch.no_grad()
def eval_one_epoch(epoch: int):
    """Evaluate the model on ``val_loader`` in full precision, without gradients.

    Uses the module-level ``model``, ``criterion`` and ``device``.

    Args:
        epoch: epoch number; only used in the progress-bar label.

    Returns:
        ``(epoch_loss, epoch_acc)``: the sample-weighted mean loss and the
        accuracy over the validation samples.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    progress = tqdm(val_loader, desc=f"Epoch {epoch} [val]  ")
    for imgs, labels in progress:
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        logits = model(imgs)
        loss = criterion(logits, labels)

        # Weight the batch-mean loss by the batch size before accumulating.
        loss_sum += loss.item() * imgs.size(0)
        predicted = torch.max(logits, dim=1).indices
        n_correct += (predicted == labels).sum().item()
        n_seen += labels.size(0)

        running = {
            "loss": f"{loss_sum/n_seen:.4f}",
            "acc":  f"{n_correct/n_seen:.4f}",
        }
        progress.set_postfix(running)

    return loss_sum / n_seen, n_correct / n_seen

# Move the model to `device` once more after LoRA injection.
# NOTE(review): presumably a safety net so newly added layers share the device — confirm it is still needed.
model = model.to(device)

In [None]:
import torch.nn as nn

# Reach the real model (strip the DataParallel shell when there is one).
core = model
if isinstance(model, nn.DataParallel):
    core = model.module

# One pass over the named parameters gathers all three counters.
total_params = 0
trainable_params = 0
lora_params = 0   # parameters whose name contains lora_A / lora_B
for name, p in core.named_parameters():
    total_params += p.numel()
    if p.requires_grad:
        trainable_params += p.numel()
    if "lora_A" in name or "lora_B" in name:
        lora_params += p.numel()

# Trainable parameters of the classification head (core.head).
head_params = 0
for p in core.head.parameters():
    if p.requires_grad:
        head_params += p.numel()

print("====== 参数统计（LoRA 注入后）======")
print(f"总参数量          : {total_params:,}")
print(f"当前可训练参数量  : {trainable_params:,}")
print(f"  其中 LoRA 参数  : {lora_params:,}")
print(f"  其中 head 参数  : {head_params:,}")
print(f"可训练比例        : {trainable_params / total_params:.4%}")
print(f"LoRA 占总参数比例 : {lora_params / total_params:.4%}")
print(f"LoRA 占可训练比例 : {lora_params / trainable_params:.4%}")
print("✅ 当前参与训练的参数：")

# List every parameter with a tag showing whether it is trained or frozen.
for name, p in core.named_parameters():
    print("Train:" if p.requires_grad else "Freez:", name)

In [None]:
from torch.cuda.amp import autocast, GradScaler
from tqdm.auto import tqdm

# Gradient scaler for mixed-precision training: train_one_epoch_amp passes the
# loss through scaler.scale() before backward() and steps the optimizer via it.
scaler = GradScaler(enabled=True)

def train_one_epoch_amp(epoch):
    """Run one mixed-precision training pass over ``train_loader``.

    The forward pass and the loss run inside ``autocast()``. The loss is passed
    through ``scaler.scale`` before ``backward`` and the optimizer is stepped
    through ``scaler.step`` / ``scaler.update``. Uses the module-level ``model``,
    ``optimizer``, ``criterion``, ``scaler`` and ``device``.

    Args:
        epoch: epoch number; used in the progress-bar label and the summary line.

    Returns:
        ``(epoch_loss, epoch_acc)``: the sample-weighted mean loss and the
        accuracy over all samples seen in this pass.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    # Wrap the loader in a progress bar that disappears once the pass is done.
    progress = tqdm(train_loader, desc=f"Epoch {epoch} [train]", leave=False)

    for imgs, labels in progress:
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()

        with autocast():
            logits = model(imgs)
            loss = criterion(logits, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Weight the batch-mean loss by the batch size before accumulating.
        loss_sum += loss.item() * imgs.size(0)
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += labels.size(0)

        # Show the running loss / accuracy on the progress bar.
        progress.set_postfix(
            loss=f"{loss_sum / n_seen:.4f}",
            acc=f"{n_correct / n_seen:.4f}",
        )

    epoch_loss, epoch_acc = loss_sum / n_seen, n_correct / n_seen
    print(f"[Epoch {epoch}] train_loss={epoch_loss:.4f} train_acc={epoch_acc:.4f}")
    return epoch_loss, epoch_acc


@torch.no_grad()
def eval_one_epoch_amp(epoch):
    """Evaluate the model on ``val_loader`` under ``autocast()``, without gradients.

    Uses the module-level ``model``, ``criterion`` and ``device``.

    Args:
        epoch: epoch number; used in the progress-bar label and the summary line.

    Returns:
        ``(epoch_loss, epoch_acc)``: the sample-weighted mean loss and the
        accuracy over the validation samples.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    progress = tqdm(val_loader, desc=f"Epoch {epoch} [val]  ", leave=False)

    for imgs, labels in progress:
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        with autocast():
            logits = model(imgs)
            loss = criterion(logits, labels)

        # Weight the batch-mean loss by the batch size before accumulating.
        loss_sum += loss.item() * imgs.size(0)
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += labels.size(0)

        progress.set_postfix(
            loss=f"{loss_sum / n_seen:.4f}",
            acc=f"{n_correct / n_seen:.4f}",
        )

    epoch_loss, epoch_acc = loss_sum / n_seen, n_correct / n_seen
    print(f"[Epoch {epoch}] val_loss={epoch_loss:.4f} val_acc={epoch_acc:.4f}")
    return epoch_loss, epoch_acc




In [None]:
# %% Main training loop
# Each epoch: train, validate, step the LR scheduler, append a line to the log
# file and remember the weights of the epoch with the best validation accuracy.
log_path = "trainlog_linear.log"
best_val_acc = 0.0
best_state_dict = None
torch.cuda.empty_cache()
for epoch in range(1, num_epochs + 1):
    train_loss, train_acc = train_one_epoch_amp(epoch)
    val_loss, val_acc     = eval_one_epoch_amp(epoch)
    line = (
        f"[Epoch {epoch}] "
        f"train_loss={train_loss:.4f} train_acc={train_acc:.4f} | "
        f"val_loss={val_loss:.4f} val_acc={val_acc:.4f}"
    )

    scheduler.step()

    print(line)

    # Append so that earlier runs' lines are kept.
    with open(log_path, "a", encoding="utf-8") as f:
        f.write(line + "\n")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # model.state_dict() returns references to the live parameter tensors,
        # which later epochs keep updating in place. Copy them (onto the CPU, so
        # no second copy of the model occupies GPU memory) to freeze a real
        # snapshot of the best epoch.
        best_state_dict = {
            key: tensor.detach().cpu().clone()
            for key, tensor in model.state_dict().items()
        }

print("Best val_acc:", best_val_acc)

# Nothing is saved when no epoch ever beat the initial best_val_acc of 0.0.
if best_state_dict is not None:
    save_path = f"./medsiglip448_cls_best_acc{best_val_acc:.4f}.pth"
    torch.save(best_state_dict, save_path)
    print("Saved best model to:", save_path)
