In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image

# ====== Минимальный вариант архитектуры из DeCLIP для ViT (из declip.py, сильно упрощён) ======

class QuickTransformer(nn.Module):
    def __init__(self, width, layers, heads):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(d_model=width, nhead=heads)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=layers)

    def forward(self, x):
        return self.encoder(x)

class DeCLIP(nn.Module):
    def __init__(self, 
                 embed_dim=512, 
                 vision_width=768, 
                 vision_layers=12, 
                 vision_patch_size=16,
                 image_resolution=224, 
                 text_width=512, 
                 text_layers=12, 
                 text_heads=8, 
                 vocab_size=49408):
        super().__init__()

        self.image_resolution = image_resolution
        self.vision_patch_size = vision_patch_size
        self.visual_conv1 = nn.Conv2d(3, vision_width, kernel_size=vision_patch_size, stride=vision_patch_size, bias=False)
        scale = vision_width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(vision_width))
        num_patches = (image_resolution // vision_patch_size) ** 2
        self.positional_embedding = nn.Parameter(scale * torch.randn(num_patches + 1, vision_width))
        self.ln_pre = nn.LayerNorm(vision_width)
        self.visual_transformer = QuickTransformer(vision_width, vision_layers, heads=12)
        self.ln_post = nn.LayerNorm(vision_width)
        self.visual_projection = nn.Parameter(torch.randn(vision_width, embed_dim))

        # Текстовый энкодер (очень упрощён)
        self.token_embedding = nn.Embedding(vocab_size, text_width)
        self.positional_embedding_text = nn.Parameter(torch.empty(77, text_width))
        self.text_transformer = QuickTransformer(text_width, text_layers, text_heads)
        self.ln_final = nn.LayerNorm(text_width)
        self.text_projection = nn.Parameter(torch.randn(text_width, embed_dim))

        self.logit_scale = nn.Parameter(torch.ones([]) * torch.log(torch.tensor(1 / 0.07)))
        
    def encode_image(self, image):
        # image: [B, 3, H, W]
        x = self.visual_conv1(image)   # [B, C, H', W']
        x = x.flatten(2).permute(2, 0, 1)  # [N, B, C]
        class_emb = self.class_embedding.unsqueeze(0).unsqueeze(1).expand(1, x.size(1), -1)
        x = torch.cat([class_emb, x], dim=0)
        x = x + self.positional_embedding[:x.size(0), :].unsqueeze(1)
        x = self.ln_pre(x)
        x = self.visual_transformer(x)
        x = x[0]
        x = self.ln_post(x)
        x = x @ self.visual_projection
        x = x / x.norm(dim=-1, keepdim=True)
        return x

    def encode_text(self, text):
        # text: [B, seq_len]
        x = self.token_embedding(text) + self.positional_embedding_text[:text.size(1), :]
        x = x.permute(1, 0, 2)  # [seq_len, batch, dim]
        x = self.text_transformer(x)
        x = x.permute(1, 0, 2)  # [batch, seq_len, dim]
        x = self.ln_final(x)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        x = x / x.norm(dim=-1, keepdim=True)
        return x

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)
        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_features @ text_features.t()
        return logits

# =================== DeCLIP Inference Wrapper ===================

class DeCLIPInference:
    def __init__(self, model_ckpt_path, classnames, device=None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # Обычно параметры можно взять из документации к весам/статье
        self.model = DeCLIP(
            embed_dim=512, vision_width=768, vision_layers=12, 
            vision_patch_size=16, image_resolution=224,
            text_width=512, text_layers=12, text_heads=8, vocab_size=49408
        ).to(self.device)
        print(self.model)
        # Загрузка весов с удалением префикса 'module.'
        state_dict = torch.load(model_ckpt_path, map_location=self.device)
        if "model" in state_dict:
            state_dict = state_dict["model"]
        new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        self.model.load_state_dict(new_state_dict, strict=False)
        self.model.eval()
        # Препроцессинг (CLIP-style)
        self.preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.48145466, 0.4578275, 0.40821073],
                std=[0.26862954, 0.26130258, 0.27577711]
            ),
        ])
        # Классы и токенизация
        import clip  # нужен clip.tokenize (или возьми токенизацию из DeCLIP!)
        self.classnames = classnames
        with torch.no_grad():
            text_tokens = clip.tokenize([f"a photo of a {c}" for c in classnames]).to(self.device)
            self.text_features = self.model.encode_text(text_tokens)
            self.text_features = self.text_features / self.text_features.norm(dim=-1, keepdim=True)

    def predict(self, image_paths, batch_size=16):
        preds = []
        for i in range(0, len(image_paths), batch_size):
            batch_paths = image_paths[i:i+batch_size]
            images = [self.preprocess(Image.open(p).convert("RGB")) for p in batch_paths]
            images = torch.stack(images).to(self.device)
            with torch.no_grad():
                image_features = self.model.encode_image(images)
                logits = (self.model.logit_scale.exp() * image_features @ self.text_features.t()).softmax(dim=-1)
                pred_indices = logits.argmax(dim=-1).cpu().tolist()
                batch_preds = [self.classnames[i] for i in pred_indices]
                preds.extend(batch_preds)
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        return preds

    def predict_proba(self, image_paths, topk=5, batch_size=16):
        topk_preds = []
        for i in range(0, len(image_paths), batch_size):
            batch_paths = image_paths[i:i+batch_size]
            images = [self.preprocess(Image.open(p).convert("RGB")) for p in batch_paths]
            images = torch.stack(images).to(self.device)
            with torch.no_grad():
                image_features = self.model.encode_image(images)
                logits = (self.model.logit_scale.exp() * image_features @ self.text_features.t()).softmax(dim=-1)
                probs = logits.cpu().tolist()
                for p in probs:
                    topk_indices = sorted(range(len(p)), key=lambda i: p[i], reverse=True)[:topk]
                    topk_labels = [self.classnames[i] for i in topk_indices]
                    topk_preds.append(topk_labels)
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        return topk_preds

    def get_name(self):
        return "DeCLIP-original"
    
# =========== Пример использования ===========

if __name__ == "__main__":
    classnames = ["cat", "dog", "car"]
    model_ckpt_path = "vitb32.pth.tar"
    #model = DeCLIPInference(model_ckpt_path, classnames, device="cuda")
    new_state_dict = DeCLIPInference(model_ckpt_path=model_ckpt_path, classnames=classnames, device="cuda")


DeCLIP(
  (visual_conv1): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16), bias=False)
  (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (visual_transformer): QuickTransformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-11): 12 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (linear1): Linear(in_features=768, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=768, bias=True)
          (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (ln_post): LayerNorm((768,), eps=1e-05, elementwise_a

In [None]:
model_keys = set(model.state_dict().keys())
ckpt_keys = set(new_state_dict.keys())
print("Не совпало:", ckpt_keys - model_keys)
print("Отсутствуют в ckpt:", model_keys - ckpt_keys)

In [1]:
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from sentence_transformers import SentenceTransformer
from PIL import Image
import torch

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b").to("cuda")
st_model = SentenceTransformer('all-mpnet-base-v2', device="cuda")

classnames = ["cat", "dog", "car", "tree", "person"]
prompts = [f"a photo of a {c}" for c in classnames]
class_embeds = st_model.encode(prompts, normalize_embeddings=True)

img = Image.open("img1.jpg").convert("RGB")
inputs = processor(images=img, text="Describe this image", return_tensors="pt").to("cuda")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=15)
    caption = processor.decode(out[0], skip_special_tokens=True).lower()
caption_embed = st_model.encode([caption], normalize_embeddings=True)
sims = (caption_embed @ class_embeds.T).squeeze(0)
pred_idx = int(sims.argmax())
print("Predicted class:", classnames[pred_idx])

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Predicted class: person


In [2]:
import os
import shutil
import random

def split_classes(all_classes, n_train, n_test, seed=42):
    random.seed(seed)
    shuffled = list(all_classes)
    random.shuffle(shuffled)
    train_classes = shuffled[:n_train]
    test_classes = shuffled[n_train:n_train+n_test]
    return train_classes, test_classes

def split_dataset_by_class(
    source_dir,
    output_dir,
    n_train_classes,
    n_test_classes,
    n_train_images_per_class,
    n_test_images_per_class,
    seed=42
):
    # Получаем все классы (папки) в датасете
    all_classes = [d for d in os.listdir(source_dir) if os.path.isdir(os.path.join(source_dir, d))]
    train_classes, test_classes = split_classes(all_classes, n_train_classes, n_test_classes, seed)
    
    print(f"Train classes: {train_classes}")
    print(f"Test classes: {test_classes}")
    
    for split, splited_classes, n_images in [('train', train_classes, n_train_images_per_class),
                                           ('test', test_classes, n_test_images_per_class)]:
        split_dir = os.path.join(output_dir, split)
        os.makedirs(split_dir, exist_ok=True)
        for class_name in splited_classes:
            src_class_dir = os.path.join(source_dir, class_name)
            dst_class_dir = os.path.join(split_dir, class_name)
            os.makedirs(dst_class_dir, exist_ok=True)
            images = os.listdir(src_class_dir)
            random.seed(seed)  # Для стабильности отбора изображений — фиксируем seed
            random.shuffle(images)
            selected = images[:n_images] if n_images else images
            for img in selected:
                shutil.copy(os.path.join(src_class_dir, img), os.path.join(dst_class_dir, img))

# ---------- FungiCLEF ----------
def split_fungi_clef_dataset(
    source_dir,
    csv_file,
    output_dir,
    n_train_classes,
    n_test_classes,
    n_train_images_per_class,
    n_test_images_per_class,
    seed=42
):
    import csv
    from collections import defaultdict
    
    # Загружаем метаданные из CSV
    samples_by_class = defaultdict(list)
    with open(os.path.join(source_dir, csv_file), newline='', encoding='utf-8') as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            class_name = " ".join(row["scientificName"].split()[:2])  # Берем только род и вид
            img_path = os.path.join(source_dir, "DF20_300", row["image_path"])
            if os.path.isfile(img_path):
                samples_by_class[class_name].append(img_path)
    
    # Получаем список всех классов
    all_classes = sorted(samples_by_class.keys())
    train_classes, test_classes = split_classes(all_classes, n_train_classes, n_test_classes, seed)
    
    print(f"Train classes: {train_classes}")
    print(f"Test classes: {test_classes}")
    
    for split, splited_classes, n_images in [('train', train_classes, n_train_images_per_class),
                                           ('test', test_classes, n_test_images_per_class)]:
        split_dir = os.path.join(output_dir, split)
        os.makedirs(split_dir, exist_ok=True)
        
        for class_name in splited_classes:
            dst_class_dir = os.path.join(split_dir, class_name)
            os.makedirs(dst_class_dir, exist_ok=True)
            
            # Получаем все изображения для класса
            images = samples_by_class[class_name]
            random.seed(seed)
            random.shuffle(images)
            selected = images[:n_images] if n_images else images
            
            for img_path in selected:
                img_name = os.path.basename(img_path)
                dst_path = os.path.join(dst_class_dir, img_name)
                try:
                    shutil.copy(img_path, dst_path)
                except Exception as e:
                    print(f"Error copying {img_path}: {e}")

# # Example usage:
# split_fungi_clef_dataset(
#     source_dir='fungi_clef_2022',
#     csv_file='DF20-train_metadata.csv',
#     output_dir='fungi_clef_2022_split',
#     n_train_classes=120,
#     n_test_classes=20,
#     n_train_images_per_class=250,
#     n_test_images_per_class=250,
#     seed=42
# )

# # ---------- DTD ----------
# split_dataset_by_class(
#     source_dir='dtd/images/',             # замените на свой путь
#     output_dir='dtd_split/',       # куда сохранить
#     n_train_classes=29,
#     n_test_classes=18,
#     n_train_images_per_class=120,
#     n_test_images_per_class=120,
#     seed=42
# )

# ---------- FungiCLEF ----------

# # ---------- CUB_200_2011 ----------
# split_dataset_by_class(
#     source_dir='CUB_200_2011/CUB_200_2011/images',
#     output_dir='CUB_200_2011_split/',
#     n_train_classes=130,
#     n_test_classes=70,
#     n_train_images_per_class=58,    # ~7500/130
#     n_test_images_per_class=59,     # ~4100/70
#     seed=42
# )



In [2]:
import os

path = "dtd/images/pleated"
print("Exists:", os.path.exists(path))
print("Readable:", os.access(path, os.R_OK))
print("Writable:", os.access(path, os.W_OK))

Exists: True
Readable: True
Writable: True


In [5]:
import os
import shutil
import random

def split_train_val_with_unseen(
    train_dir,
    test_dir,
    output_dir,
    validation_ratio=0.2,
    seed=42
):
    """
    Создает валидационную выборку, включающую:
    - часть изображений из train (seen classes)
    - все изображения из test (unseen classes)

    train_dir: путь к папке с обучающими seen-классами
    test_dir: путь к папке с unseen-классами
    output_dir: папка, где будет создана структура val/
    """
    val_dir = os.path.join(output_dir, 'val')
    os.makedirs(val_dir, exist_ok=True)

    # 1. seen-классы: берем часть изображений
    seen_classes = [d for d in os.listdir(train_dir) if os.path.isdir(os.path.join(train_dir, d))]

    for class_name in seen_classes:
        src_class_dir = os.path.join(train_dir, class_name)
        val_class_dir = os.path.join(val_dir, class_name)
        os.makedirs(val_class_dir, exist_ok=True)

        images = [f for f in os.listdir(src_class_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
        random.seed(seed)
        random.shuffle(images)
        n_val = int(len(images) * validation_ratio)
        val_images = images[:n_val]

        for img in val_images:
            shutil.copy(
                os.path.join(src_class_dir, img),
                os.path.join(val_class_dir, img)
            )

    # 2. unseen-классы: копируем все изображения
    unseen_classes = [d for d in os.listdir(test_dir) if os.path.isdir(os.path.join(test_dir, d))]

    for class_name in unseen_classes:
        src_class_dir = os.path.join(test_dir, class_name)
        val_class_dir = os.path.join(val_dir, class_name)
        os.makedirs(val_class_dir, exist_ok=True)

        images = [f for f in os.listdir(src_class_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

        for img in images:
            shutil.copy(
                os.path.join(src_class_dir, img),
                os.path.join(val_class_dir, img)
            )
                
# ---------- DTD ----------
# split_dataset_by_class(
#     source_dir='dtd/images/',
#     output_dir='dtd_split/',
#     n_train_classes=29,
#     n_test_classes=18,
#     n_train_images_per_class=120,
#     n_test_images_per_class=120,
#     seed=42
# )
# split_train_val_with_unseen(
#     train_dir='dtd_split/train',
#     test_dir='dtd_split/test',
#     output_dir='dtd_split',
#     validation_ratio=0.2,
#     seed=42
# )

# # ---------- FungiCLEF ----------
# split_fungi_clef_dataset(
#     source_dir='fungi_clef_2022',
#     csv_file='DF20-train_metadata.csv',
#     output_dir='fungi_clef_2022_split',
#     n_train_classes=120,
#     n_test_classes=20,
#     n_train_images_per_class=250,
#     n_test_images_per_class=250,
#     seed=42
#)
split_train_val_with_unseen(
    train_dir='fungi_clef_2022_split/train',
    test_dir='fungi_clef_2022_split/test',
    output_dir='fungi_clef_2022_split',
    validation_ratio=0.2,
    seed=42
)

# ---------- CUB_200_2011 ----------
split_dataset_by_class(
    source_dir='CUB_200_2011/CUB_200_2011/images',
    output_dir='CUB_200_2011_split/',
    n_train_classes=130,
    n_test_classes=70,
    n_train_images_per_class=58,
    n_test_images_per_class=59,
    seed=42
)
split_train_val_with_unseen(
    train_dir='CUB_200_2011_split/train',
    test_dir='CUB_200_2011_split/test',
    output_dir='CUB_200_2011_split',
    validation_ratio=0.2,
    seed=42
)

Train classes: ['067.Anna_Hummingbird', '188.Pileated_Woodpecker', '102.Western_Wood_Pewee', '194.Cactus_Wren', '112.Great_Grey_Shrike', '122.Harris_Sparrow', '014.Indigo_Bunting', '003.Sooty_Albatross', '065.Slaty_backed_Gull', '045.Northern_Fulmar', '137.Cliff_Swallow', '171.Myrtle_Warbler', '129.Song_Sparrow', '077.Tropical_Kingbird', '159.Black_and_white_Warbler', '168.Kentucky_Warbler', '046.Gadwall', '131.Vesper_Sparrow', '031.Black_billed_Cuckoo', '004.Groove_billed_Ani', '160.Black_throated_Blue_Warbler', '016.Painted_Bunting', '043.Yellow_bellied_Flycatcher', '127.Savannah_Sparrow', '187.American_Three_toed_Woodpecker', '163.Cape_May_Warbler', '048.European_Goldfinch', '111.Loggerhead_Shrike', '006.Least_Auklet', '195.Carolina_Wren', '038.Great_Crested_Flycatcher', '123.Henslow_Sparrow', '013.Bobolink', '037.Acadian_Flycatcher', '154.Red_eyed_Vireo', '070.Green_Violetear', '062.Herring_Gull', '196.House_Wren', '033.Yellow_billed_Cuckoo', '173.Orange_crowned_Warbler', '172.Nash