In [1]:
# !pip install polars

In [1]:
import polars as pl

In [2]:
import re

def extract_binomial(name):
    """
    提取双名法：
    - 杂交属（× Genus species）去掉前缀 ×
    - 杂交种（Genus × species）保留 ×
    """
    # 1. 杂交属（× Genus species）
    hybrid_genus_pattern = r'^\s*×\s*([A-Z][a-z-]+)\s+([a-z-]+)'
    m = re.match(hybrid_genus_pattern, name)
    if m:
        genus, species = m.groups()
        genus = genus.replace('×', '').strip()
        return f"{genus} {species}"

    # 2. 普通属 + 种，或自然杂交种（Genus × species）
    normal_pattern = r'^([A-Z][a-z-]+)\s+(×?\s?[a-z-]+)'
    m = re.match(normal_pattern, name)
    if m:
        genus, species = m.groups()
        species = species.replace('×', '').strip()
        return f"{genus} {species}"

    return None

def generate_prompt(family, genus, species, part):
    prompts = [
        f"A {part} of a plant from {genus} {species} of the {family} family.", 
        f"{genus} {species} of ({family}) {part}.",
        f"A {genus} {species} plant, which belongs to the {family} family, showing its {part}."
    ]
    return prompts

from PIL import Image
import torch
import random
    
class PlantDataset(torch.utils.data.Dataset):
    def __init__(self, samples, processor):
        """
        samples: List of dicts
        {
          "image_path": "...",
          "prompts": [str, str, ...]
        }
        """
        self.processor = processor
        self.samples = samples
        self.classes = sorted(list({s["prompts"] for s in samples}))
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}

    def __getitem__(self, idx):
        item = self.samples[idx]
        image = Image.open(item["image_path"]).convert("RGB")
        processed = self.processor(images=image, return_tensors="pt")        
        processed = {k:v.squeeze(0) for k,v in processed.items()}
        # text = random.choice(item["prompts"])
        text = item["prompts"]
        # return processed, text
        label = self.class_to_idx[item["prompts"]]
        return processed, label, text

    def __len__(self):
        return len(self.samples)


In [3]:
label = pl.read_csv('PlantCLEF2024_single_plant_training_metadata.csv', separator=';')

In [4]:
image_base_path = '/data/images_max_side_800'
samples = []
for row in label.iter_rows():
    sample = {}
    sample["image_path"] = f'{image_base_path}/{row[2]}/{row[0]}'
    genus, species = extract_binomial(row[11]).split(' ')
    part = row[1]
    # sample["prompts"] = generate_prompt(row[13], genus, species, part)
    sample["prompts"] = f'{genus} {species}'
    samples.append(sample)

In [5]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoProcessor

model_name = "google/siglip2-base-patch16-224"
device = "cuda"

processor = AutoProcessor.from_pretrained(model_name, use_fast=True)

siglip2_model = AutoModel.from_pretrained(
    model_name,
    dtype=torch.float16
).to(device)

# 冻结所有参数
for p in siglip2_model.parameters():
    p.requires_grad = False

siglip2_model.eval()

  from .autonotebook import tqdm as notebook_tqdm


SiglipModel(
  (text_model): SiglipTextTransformer(
    (embeddings): SiglipTextEmbeddings(
      (token_embedding): Embedding(256000, 768)
      (position_embedding): Embedding(64, 768)
    )
    (encoder): SiglipEncoder(
      (layers): ModuleList(
        (0-11): 12 x SiglipEncoderLayer(
          (layer_norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (self_attn): SiglipAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): SiglipMLP(
            (activation_fn): GELUTanh()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_feature

In [6]:
from torch.utils.data import DataLoader

dataset = PlantDataset(samples, processor)

dataloader = DataLoader(
    dataset,
    batch_size=512,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    prefetch_factor=2,
    persistent_workers=True
)

In [7]:
class PlantClassifier(nn.Module):
    def __init__(self, embed_dim, num_classes):
        super().__init__()
        self.head = nn.Sequential(
            nn.Linear(embed_dim, 1024),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, num_classes)
        )

    def forward(self, x):
        return self.head(x)

In [8]:
num_classes = len(label['species_id'].unique())
print(num_classes)

7806


In [9]:
embed_dim = siglip2_model.config.vision_config.hidden_size
classifier = PlantClassifier(embed_dim, num_classes).to(device)

In [10]:
import os

def save_checkpoint(
    path,
    epoch,
    classifier,
    optimizer,
    class_to_idx
):
    os.makedirs(os.path.dirname(path), exist_ok=True)

    torch.save({
        "epoch": epoch,
        "classifier_state": classifier.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "class_to_idx": class_to_idx,
    }, path)

    print(f"Checkpoint saved to {path}")

def load_checkpoint(path, classifier, optimizer, device):
    ckpt = torch.load(path, map_location=device)

    classifier.load_state_dict(ckpt["classifier_state"])
    optimizer.load_state_dict(ckpt["optimizer_state"])

    start_epoch = ckpt["epoch"] + 1
    class_to_idx = ckpt["class_to_idx"]

    print(f"Resumed from epoch {ckpt['epoch']}")
    return start_epoch, class_to_idx

In [10]:
optimizer = torch.optim.AdamW(classifier.parameters(), lr=1e-3, weight_decay=1e-4)
criterion = torch.nn.CrossEntropyLoss()

In [11]:
optimizer = torch.optim.AdamW(classifier.parameters(), lr=1e-3, weight_decay=1e-4)
criterion = torch.nn.CrossEntropyLoss()

vision_model = siglip2_model.vision_model

num_epochs = 5
for epoch in range(num_epochs):
    classifier.train()
    for batch in dataloader:
        inputs, labels, texts = batch
        labels = labels.to(device)

        # SigLIP2 embedding
        with torch.no_grad():
            outputs = vision_model(inputs['pixel_values'].to(device))
            embeddings = outputs.pooler_output.float()
            embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)

        # 分类 head
        logits = classifier(embeddings)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch} - Loss: {loss.item():.4f}")

FileNotFoundError: Caught FileNotFoundError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/work/plantCLEF/myenv/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/work/plantCLEF/myenv/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/work/plantCLEF/myenv/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 52, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
            ~~~~~~~~~~~~^^^^^
  File "/tmp/ipykernel_7618/599960993.py", line 55, in __getitem__
    image = Image.open(item["image_path"]).convert("RGB")
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/work/plantCLEF/myenv/lib/python3.11/site-packages/PIL/Image.py", line 3493, in open
    fp = builtins.open(filename, "rb")
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
FileNotFoundError: [Errno 2] No such file or directory: '/data/images_max_side_800/1357227/e0c6685ecd84a27e6440384bd848d4b5405de8ab.jpg'


In [11]:
num_epochs = 1
for epoch in range(num_epochs):
    classifier.train()
    for batch in dataloader:
        inputs, labels, texts = batch
        labels = labels.to(device)

        # SigLIP2 embedding
        with torch.no_grad():
            outputs = vision_model(inputs['pixel_values'].to(device))
            embeddings = outputs.pooler_output.float()
            embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)

        # 分类 head
        logits = classifier(embeddings)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch} - Loss: {loss.item():.4f}")

Epoch 0 - Loss: 9.3225


In [13]:
save_checkpoint(
        path=f"checkpoints/head_ft_epoch_6.pt",
        epoch=6,
        classifier=classifier,
        optimizer=optimizer,
        class_to_idx=dataset.class_to_idx
    )

Checkpoint saved to checkpoints/head_ft_epoch_6.pt


In [None]:
start_epoch = 0

ckpt_path = "checkpoints/head_ft_epoch_12.pt"
if os.path.exists(ckpt_path):
    start_epoch, class_to_idx = load_checkpoint(
        ckpt_path,
        classifier,
        optimizer,
        device
    )
    # dataset.class_to_idx = class_to_idx

Resumed from epoch 6


In [14]:
num_epochs = 20
vision_model = siglip2_model.vision_model
for epoch in range(start_epoch, start_epoch + num_epochs):
    classifier.train()
    for batch in dataloader:
        inputs, labels, texts = batch
        labels = labels.to(device)

        # SigLIP2 embedding
        with torch.no_grad():
            outputs = vision_model(inputs['pixel_values'].to(device))
            embeddings = outputs.pooler_output.float()
            embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)

        # 分类 head
        logits = classifier(embeddings)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch} - Loss: {loss.item():.4f}")

AcceleratorError: CUDA error: unspecified launch failure
Search for `cudaErrorLaunchFailure' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
