In [1]:
import torch
import torch.nn as nn
from PIL import Image
from transformers import AutoProcessor
from torch.utils.data import DataLoader
import polars as pl
import pandas as pd
from torchvision import transforms
import os

to_tensor = transforms.ToTensor()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from transformers import SiglipProcessor, SiglipModel

model_name = "google/siglip2-base-patch16-224"
processor = SiglipProcessor.from_pretrained(model_name)
model = SiglipModel.from_pretrained(model_name)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [3]:
import re

def extract_binomial(name):
    """
    提取双名法：
    - 杂交属（× Genus species）去掉前缀 ×
    - 杂交种（Genus × species）保留 ×
    """
    # 1. 杂交属（× Genus species）
    hybrid_genus_pattern = r'^\s*×\s*([A-Z][a-z-]+)\s+([a-z-]+)'
    m = re.match(hybrid_genus_pattern, name)
    if m:
        genus, species = m.groups()
        genus = genus.replace('×', '').strip()
        return f"{genus} {species}"

    # 2. 普通属 + 种，或自然杂交种（Genus × species）
    normal_pattern = r'^([A-Z][a-z-]+)\s+(×?\s?[a-z-]+)'
    m = re.match(normal_pattern, name)
    if m:
        genus, species = m.groups()
        species = species.replace('×', '').strip()
        return f"{genus} {species}"

    return None

In [4]:
df_species_ids = pd.read_csv('./species_ids.csv')

df_metadata = pd.read_csv('./PlantCLEF2024_single_plant_training_metadata.csv', sep=';', dtype={'partner': str})
class_map = df_species_ids['species_id'].to_dict() # dictionary to map the species model Id with the species Id
id_to_species = df_metadata[['species_id', 'species']].drop_duplicates().set_index('species_id')
id_to_species_dict = id_to_species['species'].to_dict()

species_to_id_dict = {extract_binomial(sp):id1 for id1, sp in id_to_species_dict.items()}
id_genus_dict = {id1:sp.split(' ')[0] for sp, id1 in species_to_id_dict.items()}
genus_set = set(id_genus_dict.values())
all_species_name = list(species_to_id_dict.keys())

In [5]:
def generate_prompt(family, genus, species, part):
    prompts = [
        f"A {part} of a plant from {genus} {species} of the {family} family.", 
        f"{genus} {species} of ({family}) {part}.",
        f"A {genus} {species} plant, which belongs to the {family} family, showing its {part}."
    ]
    return prompts

import random

class PlantDataset(torch.utils.data.Dataset):
    def __init__(self, samples, mode='train'):
        self.samples = samples
        self.mode = mode
        self.labels = list({s["species"] for s in samples})
        self.classes = sorted(self.labels)
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}

    def __getitem__(self, idx):
        item = self.samples[idx]
        image = Image.open(item["image_path"]).convert("RGB")
        text = random.choice(item["prompts"])
        if self.mode == 'train':
            return image, text
        elif self.mode == 'eval':
            label = item["species"]
            ids = self.class_to_idx[label]
            return image, ids

    def __len__(self):
        return len(self.samples)

def smart_collate_fn(batch):
    images = [item[0] for item in batch]
    texts = [item[1] for item in batch]
    
    inputs = processor(
        text=texts,
        images=images,
        padding="max_length", # 或者 True，视需求而定
        truncation=True,
        max_length=64,       # 限制文本长度，防止过长拖慢速度
        return_tensors="pt"
    )
    return inputs

def collate_fn_eval(batch):
    images = [item[0] for item in batch]
    labels = [item[1] for item in batch]
    
    inputs = processor(
        images=images,
        padding="max_length", # 或者 True，视需求而定
        truncation=True,
        max_length=64,       # 限制文本长度，防止过长拖慢速度
        return_tensors="pt"
    )
    inputs['labels'] = torch.tensor(labels)
    return inputs

def clip_loss(image_embeds, text_embeds, temperature=0.07):
    image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
    text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)

    logits = image_embeds @ text_embeds.T / temperature
    labels = torch.arange(len(logits)).to(logits.device)

    loss_i = torch.nn.functional.cross_entropy(logits, labels)
    loss_t = torch.nn.functional.cross_entropy(logits.T, labels)
    return (loss_i + loss_t) / 2

In [6]:
from peft import LoraConfig, get_peft_model
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"], # Example target modules
    lora_dropout=0.05,
    bias="none",
    # ... other config based on the model's structure
)

In [7]:
model = get_peft_model(model, lora_config).to("cuda")
model.print_trainable_parameters()

trainable params: 884,736 || all params: 376,072,706 || trainable%: 0.2353


In [8]:
device = "cuda"
processor = AutoProcessor.from_pretrained(model_name, use_fast=True)

In [9]:
label = pl.read_csv('PlantCLEF2024_single_plant_training_metadata.csv', separator=';')

In [10]:
top5_df = label.group_by("species_id").head(20)
print(label.shape)
print(top5_df.shape)
label_1 = top5_df[label.columns]

(1408033, 20)
(136388, 20)


In [11]:
image_base_path = '/data/images_max_side_800'
samples = []
for row in label_1.iter_rows():
    # print(row)
    sample = {}
    sample["image_path"] = f'{image_base_path}/{row[2]}/{row[0]}'
    genus, species = extract_binomial(row[11]).split(' ')
    part = row[1]
    sample["prompts"] = generate_prompt(row[13], genus, species, part)
    sample["species"] = f'{genus} {species}'
    # sample["prompts"] = f'{genus} {species}'
    samples.append(sample)

In [12]:
dataset = PlantDataset(samples)

dataloader = DataLoader(
    dataset,
    collate_fn=smart_collate_fn,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    prefetch_factor=2,
    persistent_workers=True
)

eval_dataset = PlantDataset(samples, mode='eval')

eval_dataloader = DataLoader(
    eval_dataset,
    collate_fn=collate_fn_eval,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    prefetch_factor=2,
    persistent_workers=True)

In [13]:
def get_class_text_embeddings(model, processor, all_scientific_names, device):
    """
    预计算所有类别的文本嵌入（作为分类头）
    """
    model.eval()
    
    # 构造提示
    prompts = [f"A photo of a plant from {name}." for name in all_scientific_names]
    
    batch_size = 64
    all_text_embeds = []
    
    # 禁用梯度计算，节省显存
    with torch.no_grad():
        for i in range(0, len(prompts), batch_size):
            batch_prompts = prompts[i:i + batch_size]
            
            # 1. 此时 inputs 只包含 'input_ids' 和 'attention_mask'，没有 'pixel_values'
            inputs = processor(text=batch_prompts, return_tensors="pt", padding=True).to(device)
            
            # 2. 关键修改：使用 get_text_features 而不是 model(**inputs)
            # 这会自动调用文本编码器并经过投影层
            text_features = model.get_text_features(**inputs)
            
            # 3. 归一化 (Normalization)
            # SigLIP/CLIP 的相似度计算依赖于余弦相似度，所以必须归一化
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            
            all_text_embeds.append(text_features.cpu())

    # 合并所有批次: [7800, hidden_dim]
    class_text_embeds = torch.cat(all_text_embeds, dim=0)
    
    return class_text_embeds.to(device)

In [14]:
def evaluate_zero_shot_accuracy(model, processor, eval_dataloader, class_text_embeds, device):
    model.eval()
    total_correct = 0
    total_samples = 0
    
    with torch.no_grad():
        for batch in eval_dataloader:
            images = batch['pixel_values']
            true_labels_ids = batch['labels']
            inputs = processor(images=images, return_tensors="pt").to(device)
            
            # 2. 获取图像嵌入
            outputs = model.vision_model(**inputs)
            image_embeds = outputs.pooler_output
            
            # 归一化图像嵌入
            image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)

            # 3. 计算相似度矩阵
            # 矩阵形状: [Batch Size] x [7800 Classes]
            # 每个元素是 I_i 与 C_j 的相似度
            similarity_matrix = image_embeds @ class_text_embeds.T

            # 4. 预测类别
            # 找到相似度最高的文本索引（即预测的类别ID）
            predicted_ids = torch.argmax(similarity_matrix, dim=1).cpu()
            
            # 5. 计算准确率
            correct = (predicted_ids == true_labels_ids).sum().item()
            
            total_correct += correct
            total_samples += len(true_labels_ids)
            print(total_samples)
            if total_samples > 1000:
                break

    return total_correct / total_samples

In [15]:
print("--- 预计算所有类别的文本嵌入 ---")
class_text_embeds = get_class_text_embeddings(model, processor, all_species_name, device)

--- 预计算所有类别的文本嵌入 ---




In [16]:
num_epochs = 10
best_accuracy = 0.0
load_hist = False

output_dir = "./checkpoints/lora_siglip_plantclef24/"
epoch_start = 0
if load_hist:
    model_name = "google/siglip2-base-patch16-224"
    processor = SiglipProcessor.from_pretrained(model_name)
    base_model = SiglipModel.from_pretrained(model_name)
    from peft import PeftModel
    model = PeftModel.from_pretrained(base_model, output_dir, is_trainable=True)
    model.to(device)

    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=1e-4,
        weight_decay=1e-4
    )

    checkpoint_path = os.path.join(output_dir, "training_state.pt")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch_start = checkpoint['epoch'] + 1
    best_accuracy = checkpoint['best_accuracy']
    print(f"✅ 加载历史训练状态， 从第 {epoch_start} 轮")
else:
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=1e-4,
        weight_decay=1e-4
    )

model.train()

for epoch in range(epoch_start, epoch_start + num_epochs):
    # --- 训练阶段 ---
    total_loss = 0.0
    num_batches = 0
    for batch_inputs in dataloader:
        inputs = {k: v.to(device) for k, v in batch_inputs.items()}
        
        if inputs["pixel_values"].dtype != model.dtype:
            inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)

        outputs = model(**inputs)
        loss = clip_loss(
            outputs.image_embeds,
            outputs.text_embeds
        )

        loss.backward()
        # 3. 累计损失和准确率
        total_loss += loss.item()
        num_batches += 1

        optimizer.step()
        optimizer.zero_grad()
        print('num_batches:', num_batches)
    avg_loss = total_loss / num_batches
    # 打印训练损失
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_loss:.4f}")

    # --- 评估阶段 ---
    print(f"--- Epoch {epoch+1} Zero-Shot 评估 ---")
    
    # 计算准确率
    eval_acc = evaluate_zero_shot_accuracy(
        model, 
        processor, 
        eval_dataloader, 
        class_text_embeds, 
        device
    )
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Zero-Shot Acc: {eval_acc:.4f}")

    # ==========================
    # 3. 保存最佳模型逻辑 (Save Best)
    # ==========================
    if eval_acc > best_accuracy:
        print(f"🚀 发现新高! 准确率从 {best_accuracy:.4f} 提升到 {eval_acc:.4f}")
        best_accuracy = eval_acc
        
        # 确保目录存在
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
            
        # A. 保存 LoRA 权重 (PeftModel 的特有方法)
        # 这只会保存 adapter_model.bin (很小，几MB) 和 adapter_config.json
        model.save_pretrained(output_dir)
        
        # B. 保存 Processor (分词器配置等)
        # 这一步很重要，确保推理时使用相同的预处理逻辑
        processor.save_pretrained(output_dir)
        
        print(f"✅ 模型已保存到: {output_dir}")
    else:
        print(f"本轮准确率未超过历史最高 ({best_accuracy:.4f})，跳过保存。")

num_batches: 1274
num_batches: 1275
num_batches: 1276
num_batches: 1277
num_batches: 1278
num_batches: 1279
num_batches: 1280
num_batches: 1281
num_batches: 1282
num_batches: 1283
num_batches: 1284
num_batches: 1285
num_batches: 1286
num_batches: 1287
num_batches: 1288
num_batches: 1289
num_batches: 1290
num_batches: 1291
num_batches: 1292
num_batches: 1293
num_batches: 1294
num_batches: 1295
num_batches: 1296
num_batches: 1297
num_batches: 1298
num_batches: 1299
num_batches: 1300
num_batches: 1301
num_batches: 1302
num_batches: 1303
num_batches: 1304
num_batches: 1305
num_batches: 1306
num_batches: 1307
num_batches: 1308
num_batches: 1309
num_batches: 1310
num_batches: 1311
num_batches: 1312
num_batches: 1313
num_batches: 1314
num_batches: 1315
num_batches: 1316
num_batches: 1317
num_batches: 1318
num_batches: 1319
num_batches: 1320
num_batches: 1321
num_batches: 1322
num_batches: 1323
num_batches: 1324
num_batches: 1325
num_batches: 1326
num_batches: 1327
num_batches: 1328
num_batche

In [None]:
# model.save_pretrained(output_dir)
# processor.save_pretrained(output_dir)

[]

In [None]:
# checkpoint_path = os.path.join(output_dir, "training_state.pt")

# torch.save({
#     'epoch': epoch,
#     'best_accuracy': best_accuracy,
#     'optimizer_state_dict': optimizer.state_dict(),
# }, checkpoint_path)