In [1]:
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset

import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
# Per-channel normalization statistics in RGB order.
# NOTE(review): these are the ImageNet statistics. OpenCLIP ViT-B-32 models are
# normally paired with the OpenAI CLIP statistics (open_clip.OPENAI_DATASET_MEAN /
# OPENAI_DATASET_STD) — confirm which ones the fine-tuned checkpoint was trained
# with before changing these values.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]


def _resize_and_center_crop():
    """Return fresh geometry steps: resize to 256x256, then center-crop to 224x224."""
    return [A.Resize(256, 256), A.CenterCrop(224, 224)]


def _normalize_and_to_tensor():
    """Return fresh final steps: normalize with `mean`/`std`, then convert to a torch tensor."""
    return [A.Normalize(mean=mean, std=std), ToTensorV2()]


# Training pipeline: the same geometry as evaluation plus light random augmentation.
train_transform = A.Compose(
    _resize_and_center_crop()
    + [
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.2),
        A.Rotate(limit=10, p=0.2),
    ]
    + _normalize_and_to_tensor()
)

# Validation/test pipeline: resize + center crop + normalization, no augmentation.
val_transform = A.Compose(_resize_and_center_crop() + _normalize_and_to_tensor())

class CsvImageDataset(Dataset):
    """Image-classification dataset backed by a CSV index.

    The CSV must contain at least the columns ``split``, ``class_name`` and
    ``image_path``. Only rows whose ``split`` value equals ``split`` are kept.

    Args:
        csv_path: Path to the CSV index file.
        split: Value of the ``split`` column to select (e.g. ``'test'``).
        transform: Optional albumentations-style callable, invoked as
            ``transform(image=array)`` and expected to return a dict with an
            ``'image'`` key.
        class_names: Optional ordered list of class names that defines the label
            indices. When omitted, the sorted unique class names of the selected
            split are used.
    """

    def __init__(self, csv_path, split, transform=None, class_names=None):
        df = pd.read_csv(csv_path)
        self.df = df[df['split'] == split].reset_index(drop=True)
        self.transform = transform

        if class_names is None:
            self.class_names = sorted(self.df['class_name'].unique())
        else:
            self.class_names = class_names

        # The label of a class is its position in `class_names`.
        self.class_to_idx = {cls: i for i, cls in enumerate(self.class_names)}

    def __len__(self):
        """Number of rows in the selected split."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        ``image`` is the RGB pixel array, passed through ``transform`` when one is
        set. ``label`` is the integer index of the row's class in ``class_names``.
        Raises KeyError if the row's class is not in ``class_names``.
        """
        row = self.df.iloc[idx]
        # Close the file handle immediately. PIL opens lazily and would otherwise
        # keep it open until garbage collection, which can exhaust file
        # descriptors over long runs or with many DataLoader workers.
        with Image.open(row['image_path']) as img:
            image = np.array(img.convert('RGB'))
        if self.transform is not None:
            image = self.transform(image=image)['image']
        label = self.class_to_idx[row['class_name']]
        return image, label

  check_for_updates()


In [5]:
import torch
import open_clip
from torch.utils.data import DataLoader
from sklearn.metrics import f1_score, accuracy_score
from tqdm import tqdm

# === Configuration ===
CSV_PATH = 'merged_dataset_v2.csv'  # CSV index of all images; adapt to your own CSV
BATCH_SIZE = 128
NUM_WORKERS = 0
# NOTE(review): NUM_EPOCHS / LR / WEIGHT_DECAY are presumably training
# hyperparameters for a cell not shown here.
NUM_EPOCHS = 30
LR = 1e-6
WEIGHT_DECAY = 1e-2
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
MODEL_NAME = 'ViT-B-32'
PRETRAINED = 'laion2b_s34b_b79k'
CHECKPOINT_PATH = 'finetuned_openclip_v3'
# Split labels in English and Russian spellings ('обучение' = train,
# 'валидация' = validation, 'тест' = test).
# NOTE(review): not referenced in the visible cells.
SPLITS = ['train', 'val', 'test', 'обучение', 'валидация', 'тест']

# Load the full index and collect the sorted class vocabulary across all splits.
df = pd.read_csv(CSV_PATH)
all_classes = sorted(df['class_name'].unique().tolist())

# Tokenize one prompt per class: the bare class name, with no template.
# NOTE(review): the per-domain evaluation builds its own prompts with the
# "a photo of a {class}" template, so these tokens are not the ones it uses.
tokenizer = open_clip.get_tokenizer(MODEL_NAME)
prompts = list(map(str, all_classes))
text_tokens = tokenizer(prompts).to(DEVICE)


In [4]:
import matplotlib.pyplot as plt

# Metric histories, one independent list per metric.
# NOTE(review): nothing in the visible cells appends to these; presumably they
# are filled per epoch by a training loop not shown here.
train_losses, val_accuracies, val_f1s = [], [], []
gzsl_seen_accs, gzsl_unseen_accs, gzsl_hmeans = [], [], []


def compute_gzsl(trues_cls, preds_cls, seen_classes, unseen_classes):
    """Compute generalized zero-shot learning (GZSL) accuracies.

    Args:
        trues_cls: Ground-truth class names, one per sample.
        preds_cls: Predicted class names, aligned with ``trues_cls``.
        seen_classes: Collection of class names treated as "seen".
        unseen_classes: Collection of class names treated as "unseen".

    Returns:
        ``(seen_acc, unseen_acc, h_mean)``. Each accuracy is the fraction of
        samples whose TRUE class is in the group and that are predicted
        correctly. ``h_mean`` is the harmonic mean of the two.
        A group with no samples gets accuracy 0.0 (not NaN), which forces
        ``h_mean`` to 0.0. A domain without seen test classes therefore reports
        seen acc 0 and H-mean 0 regardless of model quality.
    """
    def group_accuracy(group):
        # Group membership is decided by each sample's true class.
        in_group = [true_cls in group for true_cls in trues_cls]
        n_correct = sum(
            true_cls == pred_cls
            for true_cls, pred_cls, keep in zip(trues_cls, preds_cls, in_group)
            if keep
        )
        # max(..., 1) avoids division by zero when the group has no samples.
        return n_correct / max(sum(in_group), 1)

    seen_acc = group_accuracy(seen_classes)
    unseen_acc = group_accuracy(unseen_classes)
    if seen_acc + unseen_acc > 0:
        h_mean = 2 * seen_acc * unseen_acc / (seen_acc + unseen_acc)
    else:
        h_mean = 0.0
    return seen_acc, unseen_acc, h_mean

In [7]:
def evaluate_on_domain(model, data_loader, text_tokens, class_names, seen_classes, unseen_classes, domain_name):
    """Zero-shot evaluate a CLIP-style model on one domain and print the metrics.

    Each image is assigned the class whose normalized text embedding has the
    highest dot product (cosine similarity) with the normalized image embedding.

    Args:
        model: Model exposing ``encode_image`` / ``encode_text`` (e.g. an
            OpenCLIP model). It is switched to eval mode as a side effect.
        data_loader: Iterable of ``(images, labels)`` batches. Labels are
            indices into ``class_names``.
        text_tokens: Tokenized prompts, one per entry of ``class_names`` and in
            the same order (prediction index i is mapped to ``class_names[i]``).
        class_names: Ordered class names; position ``i`` is label ``i``.
        seen_classes: Class names counted as "seen" for the GZSL metrics.
        unseen_classes: Class names counted as "unseen" for the GZSL metrics.
        domain_name: Label used in the printed report.

    Returns:
        Tuple ``(acc, macro_f1, seen_acc, unseen_acc, h_mean)``.
    """
    model.eval()
    preds, trues = [], []
    with torch.no_grad():
        # The class prompts are identical for every batch, so encode them once
        # instead of re-running the text encoder per batch.
        text_features = model.encode_text(text_tokens)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        for images, labels in data_loader:
            images = images.to(DEVICE)
            image_features = model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            # Scaled cosine similarities; the 100x factor does not affect argmax.
            logits = 100. * image_features @ text_features.T
            pred = logits.argmax(dim=1).cpu().numpy()
            preds.extend(pred)
            trues.extend(labels.cpu().numpy())
    acc = accuracy_score(trues, preds)
    macro_f1 = f1_score(trues, preds, average='macro')
    print(f"[{domain_name} Test] Top-1 Accuracy: {acc:.4f}, Macro F1: {macro_f1:.4f}")

    # Map integer labels/predictions back to class names for the GZSL split.
    idx_to_class = {i: c for i, c in enumerate(class_names)}
    trues_cls = [idx_to_class[i] for i in trues]
    preds_cls = [idx_to_class[i] for i in preds]
    seen_acc, unseen_acc, h_mean = compute_gzsl(trues_cls, preds_cls, seen_classes, unseen_classes)
    print(f"[{domain_name} GZSL] Seen acc: {seen_acc:.4f} | Unseen acc: {unseen_acc:.4f} | H-mean: {h_mean:.4f}")
    return acc, macro_f1, seen_acc, unseen_acc, h_mean



# Whole test split across all domains (labels indexed over every test class).
test_dataset = CsvImageDataset(CSV_PATH, split='test', transform=val_transform)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

# Domain names to evaluate separately; compared against the CSV's 'domain' column.
DATASET_DIRS = [
    'CUB_200_2011_split',
    'dtd_split',
    'fungi_clef_2022_split',
]



import open_clip
import torch

# Path to the fine-tuned checkpoint to evaluate.
checkpoint_path = "finetuned_openclip_v3_epoch2_h0.44423542499502605.pth"

# Architecture name; must match the one used for fine-tuning (e.g. 'ViT-B-32').
model_name = "ViT-B-32"
pretrained = None  # no pretrained tag requested: the weights come from the checkpoint below

# Build the model.
model, _, preprocess = open_clip.create_model_and_transforms(
    model_name,
    pretrained=pretrained
)

# Load the weights. The checkpoint is either a raw state dict or a dict that
# wraps it under "state_dict".
checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
# Checkpoints saved from a DataParallel / DistributedDataParallel wrapper prefix
# every key with "module." (e.g. 'module.logit_scale'). Strip that prefix so the
# keys match the bare model; a strict load would otherwise fail on every key.
state_dict = {
    (key[len("module."):] if key.startswith("module.") else key): value
    for key, value in state_dict.items()
}
model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()




# Classes present anywhere in the training split. This is the same for every
# domain, so it is computed once outside the loop.
# NOTE(review): only the literal split value 'train' is matched, although SPLITS
# also lists 'обучение' (Russian for "training") — confirm the CSV never uses it.
train_classes = set(df[df['split'] == 'train']['class_name'].unique())

# Metrics returned by evaluate_on_domain, keyed by domain name.
domain_results = {}

for domain in DATASET_DIRS:
    # Test rows of the current domain only.
    domain_df = df[(df['split'] == 'test') & (df['domain'] == domain)].reset_index(drop=True)
    if len(domain_df) == 0:
        continue

    # Label space restricted to this domain's test classes, in sorted order.
    domain_class_names = sorted(domain_df['class_name'].unique())
    domain_dataset = CsvImageDataset(
        CSV_PATH, split='test', transform=val_transform, class_names=domain_class_names
    )
    # The constructor selects every test row; keep only the current domain's rows.
    domain_dataset.df = domain_df
    domain_loader = DataLoader(domain_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

    # GZSL split for this domain: test classes that were / were not seen in training.
    test_classes = set(domain_class_names)
    unseen_classes = test_classes - train_classes
    seen_classes = train_classes & test_classes

    # One prompt per domain class, in the same order as domain_dataset.class_names.
    prompts = [f"a photo of a {c.lower()}" for c in domain_dataset.class_names]
    text_tokens = tokenizer(prompts).to(DEVICE)

    domain_results[domain] = evaluate_on_domain(
        model, domain_loader, text_tokens, domain_dataset.class_names,
        seen_classes, unseen_classes, domain
    )

[CUB_200_2011_split Test] Top-1 Accuracy: 0.7440, Macro F1: 0.7246
[CUB_200_2011_split GZSL] Seen acc: 0.0000 | Unseen acc: 0.7440 | H-mean: 0.0000
[dtd_split Test] Top-1 Accuracy: 0.6213, Macro F1: 0.6042
[dtd_split GZSL] Seen acc: 0.0000 | Unseen acc: 0.6213 | H-mean: 0.0000
[fungi_clef_2022_split Test] Top-1 Accuracy: 0.1873, Macro F1: 0.1245
[fungi_clef_2022_split GZSL] Seen acc: 0.0000 | Unseen acc: 0.1873 | H-mean: 0.0000


odict_keys(['module.logit_scale', 'module.visual.class_embedding', 'module.visual.positional_embedding', 'module.visual.proj', 'module.visual.conv1.weight', 'module.visual.ln_pre.weight', 'module.visual.ln_pre.bias', 'module.visual.transformer.resblocks.0.attn.in_proj_weight', 'module.visual.transformer.resblocks.0.attn.in_proj_bias', 'module.visual.transformer.resblocks.0.attn.out_proj.weight', 'module.visual.transformer.resblocks.0.attn.out_proj.bias', 'module.visual.transformer.resblocks.0.ln_1.weight', 'module.visual.transformer.resblocks.0.ln_1.bias', 'module.visual.transformer.resblocks.0.mlp.c_fc.weight', 'module.visual.transformer.resblocks.0.mlp.c_fc.bias', 'module.visual.transformer.resblocks.0.mlp.c_proj.weight', 'module.visual.transformer.resblocks.0.mlp.c_proj.bias', 'module.visual.transformer.resblocks.0.ln_2.weight', 'module.visual.transformer.resblocks.0.ln_2.bias', 'module.visual.transformer.resblocks.1.attn.in_proj_weight', 'module.visual.transformer.resblocks.1.attn.

In [None]:
# Evaluation output recorded for checkpoint finetuned_openclip_v3_epoch1_h0.3820202953874789

[CUB_200_2011_split Test] Top-1 Accuracy: 0.7487, Macro F1: 0.7293
[CUB_200_2011_split GZSL] Seen acc: 0.0000 | Unseen acc: 0.7487 | H-mean: 0.0000
[dtd_split Test] Top-1 Accuracy: 0.6213, Macro F1: 0.6008
[dtd_split GZSL] Seen acc: 0.0000 | Unseen acc: 0.6213 | H-mean: 0.0000
[fungi_clef_2022_split Test] Top-1 Accuracy: 0.1770, Macro F1: 0.1113
[fungi_clef_2022_split GZSL] Seen acc: 0.0000 | Unseen acc: 0.1770 | H-mean: 0.0000