In [22]:
from pathlib import Path
from transformers import CLIPModel, CLIPProcessor
from PIL import Image
import pandas as pd
import numpy as np
import torch
from tqdm import tqdm
import os
from torch.cuda.amp import autocast
from sentence_transformers import SentenceTransformer

In [23]:
IMAGE_FOLDER_TRAIN = "/kaggle/input/hw-2-zero-shot-image-classification/train"
IMAGE_FOLDER_TEST = "/kaggle/input/hw-2-zero-shot-image-classification/test"
OUTPUT_CSV = "/kaggle/working/submit.csv"

In [24]:
df_train = pd.read_csv(
    os.path.join(IMAGE_FOLDER_TRAIN, 'labels.csv'),
    usecols=['class', 'group', 'language', 'ID']
)
df_train.head(10)

Unnamed: 0,class,group,language,ID
0,Elefante,Mamífero,Español,1.jpg
1,bos-gaurus,Animalia,Latina,2.jpg
2,martes-americana,Animalia,Latina,3.jpg
3,poutine,Dishes,English,4.jpg
4,triceratops-horridus,Animalia,Latina,5.jpg
5,eudocimus-albus,Animalia,Latina,6.jpg
6,eudocimus-albus,Animalia,Latina,7.jpg
7,loxodonta-africana,Animalia,Latina,8.jpg
8,puma-concolor,Animalia,Latina,9.jpg
9,desmodus-rotundus,Animalia,Latina,10.jpg


In [25]:
df_test = pd.read_csv(os.path.join(IMAGE_FOLDER_TEST,'labels.csv'))
df_test.head(10)

Unnamed: 0,group,language,ID
0,Dishes,English,1.jpg
1,Animalia,Latina,2.jpg
2,Dishes,English,3.jpg
3,Dishes,English,4.jpg
4,Other,English,5.jpg
5,Dishes,English,6.jpg
6,Other,English,7.jpg
7,Animalia,Latina,8.jpg
8,Dishes,English,9.jpg
9,Other,English,10.jpg


In [26]:
df_train['language'].unique()

array(['Español', 'Latina', 'English'], dtype=object)

In [27]:
# Посчитайте частоту испанских классов
df_train[df_train['language'] == 'Español']['class'].unique()

array(['Elefante', 'flamenco', 'delfín', 'cocodrilo', 'Hormiga'],
      dtype=object)

In [28]:
la_classes = df_train[df_train['language'] == 'Latina']['class'].unique()
la_classes

array(['bos-gaurus', 'martes-americana', 'triceratops-horridus',
       'eudocimus-albus', 'loxodonta-africana', 'puma-concolor',
       'desmodus-rotundus', 'phoebetria-fusca', 'salmo-salar',
       'ara-macao', 'megaptera-novaeangliae', 'anas-platyrhynchos',
       'ursus-maritimus', 'equus-quagga', 'gallus-gallus-domesticus',
       'passerina-ciris', 'pteranodon-longiceps', 'falco-peregrinus',
       'tyrannus-tyrannus', 'poecile-atricapillus', 'pongo-abelii',
       'phascolarctos-cinereus', 'mimus-polyglottos',
       'varanus-komodoensis', 'pavo-cristatus', 'mammuthus-primigeniu',
       'codium-fragile', 'balaenoptera-musculus', 'colaptes-auratus',
       'pantherophis-alleghaniensis', 'circus-hudsonius', 'bison-bison',
       'ovis-canadensis', 'hapalochlaena-maculosa', 'ovis-aries',
       'cathartes-aura', 'monodon-monoceros', 'eidolon-helvum',
       'felis-catus', 'heterocera', 'giraffa-camelopardalis',
       'panthera-pardus', 'icterus-gularis', 'okapia-johnstoni',
     

In [29]:
df_train['group'].unique()

array(['Mamífero', 'Animalia', 'Dishes', 'aves', 'Other', 'Cat',
       'Anfibio', 'Insectos', 'Faces'], dtype=object)

In [30]:
SPANISH_TO_ENGLISH = {
    "flamenco": "flamingo",
    "cocodrilo": "crocodile",
    "Elefante": "elephant",
    "delfín": "dolphin",
    "Hormiga": "ant"
}

In [31]:
LATIN_TO_ENGLISH = {
    # Млекопитающие
    "bos-gaurus": "gaur",
    "martes-americana": "american marten",
    "loxodonta-africana": "african elephant",
    "puma-concolor": "puma",
    "desmodus-rotundus": "common vampire bat",
    "ursus-maritimus": "polar bear",
    "equus-quagga": "quagga",
    "gallus-gallus-domesticus": "domestic chicken",
    "pongo-abelii": "sumatran orangutan",
    "phascolarctos-cinereus": "koala",
    "bison-bison": "american bison",
    "ovis-canadensis": "bighorn sheep",
    "ovis-aries": "domestic sheep",
    "eidolon-helvum": "straw-coloured fruit bat",
    "felis-catus": "domestic cat",
    "giraffa-camelopardalis": "giraffe",
    "panthera-pardus": "leopard",
    "icterus-gularis": "orange-breasted blackbird",
    "okapia-johnstoni": "okapi",
    "homo-sapiens": "human",
    "centrochelys-sulcata": "african spurred tortoise",
    "equus-caballus": "horse",
    "hippopotamus-amphibius": "hippopotamus",
    "procyon-lotor": "raccoon",
    "panthera-leo": "lion",
    "enhydra-lutris": "sea otter",
    "iguana-iguana": "green iguana",
    "bradypus-variegatus": "brown-throated sloth",
    "rattus-rattus": "black rat",
    "hydrurga-leptonyx": "leopard seal",
    "tapirus": "tapir",
    "dugong-dugon": "dugong",
    "panthera-onca": "jaguar",
    "lemur-catta": "ring-tailed lemur",
    "lepus-americanus": "snowshoe hare",
    "alces-alces": "moose",
    "dasypus-novemcinctus": "nine-banded armadillo",
    "vulpes-vulpes": "red fox",
    "taurotragus-oryx": "eland",
    "canis-lupus": "gray wolf",
    "canis-lupus-familiaris": "domestic dog",
    "rusa-unicolor": "sambar deer",
    "connochaetes-gnou": "black wildebeest",
    "panthera-tigris": "tiger",
    "gorilla-gorilla": "western gorilla",
    "ailuropoda-melanoleuca": "giant panda",
    "ailurus-fulgens": "red panda",
    "ursus-arctos-horribilis": "grizzly bear",
    "acrocinonyx jubatus": "cheetah",
    "acinonyx-jubatus": "cheetah",
    "bos-taurus": "cow",

    # Птицы (aves)
    "eudocimus-albus": "white ibis",
    "anas-platyrhynchos": "mallard duck",
    "passerina-ciris": "painted bunting",
    "poecile-atricapillus": "black-capped chickadee",
    "mimus-polyglottos": "northern mockingbird",
    "pavo-cristatus": "peacock",
    "colaptes-auratus": "northern flicker",
    "circus-hudsonius": "northern harrier",
    "cathartes-aura": "turkey vulture",
    "melanerpes-carolinus": "red-bellied woodpecker",
    "ardea-herodias": "great blue heron",
    "mellisuga-helenae": "bee hummingbird",
    "aquila-chrysaetos": "golden eagle",
    "haliaeetus-leucocephalus": "bald eagle",
    "correlophus-ciliatus": "crested gecko",
    "spheniscus-demersus": "african penguin",
    "aptenodytes-forsteri": "emperor penguin",
    "phoenicopterus-ruber": "american flamingo",
    "branta-canadensis": "canada goose",
    "struthio-camelus": "ostrich",
    "geococcyx-californianus": "greater roadrunner",
    "mergus-serrator": "red-breasted merganser",
    "cyanocitta-cristata": "blue jay",
    "icterus-galbula": "baltimore oriole",
    "icterus-spurius": "orchard oriole",
    "aethia-cristatella": "crested auklet",
    "turdus-migratorius": "american robin",
    "thryothorus-ludovicianus": "carolina wren",
    "cardinalis-cardinalis": "northern cardinal",
    "falco-peregrinus": "peregrine falcon",

    # Рептилии и амфибии
    "varanus-komodoensis": "komodo dragon",
    "heloderma-suspectum": "gila monster",
    "agkistrodon-contortrix": "copperhead snake",
    "crocodylus-niloticus": "nile crocodile",
    "gavialis-gangeticus": "gharial",
    "telmatobufo-bullocki": "bullock's false toad",
    "agalychnis-callidryas": "red-eyed tree frog",
    "phyllobates-terribilis": "golden poison frog",
    "dendrobatidae": "poison dart frog",
    "chelonia-mydas": "green sea turtle",
    "dermochelys-coriacea": "leatherback sea turtle",
    "pantherophis-guttatus": "corn snake",
    "pantherophis-alleghaniensis": "northern pine snake",
    "lampropeltis-triangulum": "milk snake",
    "crotalus-atrox": "western diamondback rattlesnake",
    "malayopython-reticulatus": "reticulated python",
    "eunectes-murinus": "green anaconda",
    "ophiophagus-hannah": "king cobra",
    "hapalochlaena-maculosa": "southern blue-ringed octopus",
    
    # Рыбы и морские животные
    "salmo-salar": "atlantic salmon",
    "megaptera-novaeangliae": "humpback whale",
    "balaenoptera-musculus": "blue whale",
    "tursiops-truncatus": "bottlenose dolphin",
    "delphinapterus-leucas": "beluga whale",
    "orcinus-orca": "orca",
    "physeter-macrocephalus": "sperm whale",
    "monodon-monoceros": "narwhal",
    "sphyrna-mokarran": "great hammerhead shark",
    "carcharodon-carcharias": "great white shark",
    "pterois-volitans": "red lionfish",
    "pterois-mombasae": "frillfin lionfish",
    "enteroctopus-dofleini": "giant pacific octopus",
    "architeuthis-dux": "giant squid",
    "coelacanthiformes": "coelacanth",
    "ina-geoffrensis": "amazon river dolphin",
    "physalia-physalis": "portuguese man o' war",

    # Насекомые и беспозвоночные
    "heterocera": "moth",
    "formicidae": "ant",
    "ceratitis-capitata": "mediterranean fruit fly",
    "musca-domestica": "housefly",
    "periplaneta-americana": "american cockroach",
    "apis-mellifera": "honey bee",
    "danaus-plexippus": "monarch butterfly",
    "papilio-glaucus": "eastern tiger swallowtail",
    "centruroides-vittatus": "striped bark scorpion",
    "codium-fragile": "green seaweed",

    # Динозавры и ископаемые
    "triceratops-horridus": "triceratops",
    "tyrannosaurus-rex": "tyrannosaurus",
    "stegosaurus-stenops": "stegosaurus",
    "ankylosaurus-magniventris": "ankylosaurus",
    "iguanodon-bernissartensis": "iguanodon",
    "spinosaurus-aegyptiacus": "spinosaurus",
    "smilodon-populator": "sabertooth cat",
    "mammuthus-primigeniu": "woolly mammoth",
    "diplodocus": "diplodocus",
    "pteranodon-longiceps": "pteranodon",
    "trilobita": "trilobite",

    # Прочие
    "ara-macao": "scarlet macaw",
    "correlophus-ciliatus": "crested gecko",
    "sciurus-carolinensis": "eastern gray squirrel",
    "vultur-gryphus": "andean condor",
}

In [32]:
def normalize_class_name(class_name: str, language: str) -> str:
    if language == 'Latina':
        key = class_name
        return LATIN_TO_ENGLISH.get(key, class_name.replace('-', ' '))
    elif language == 'Español':
        return SPANISH_TO_ENGLISH.get(class_name, class_name) 
    return class_name  # оставляем как есть: "Elefante", "poutine" и т.д.

# Применяем
df_train['class_norm'] = df_train.apply(
    lambda row: normalize_class_name(row['class'], row['language']),
    axis=1
)

# Проверим несколько примеров
sample_check = df_train[df_train['language'].isin(['Latina', 'Español'])][['class', 'language', 'class_norm']].head(5)
print(sample_check)

                  class language       class_norm
0              Elefante  Español         elephant
1            bos-gaurus   Latina             gaur
2      martes-americana   Latina  american marten
4  triceratops-horridus   Latina      triceratops
5       eudocimus-albus   Latina       white ibis


In [33]:
GROUP_PROMPT_TEMPLATES = {
    "Mamífero": [
        "a mammal called {}",
        "a photo of a {} mammal in its natural habitat",
        "a close-up of a {} mammal with natural lighting",
        "a wild {} mammal standing in a forest",
        "a {} mammal walking through tall grass",
        "a high-quality photo of a {} mammal, sharp focus",
        "a {} mammal captured in the wild, daylight",
        "a furry {} mammal with detailed texture",
        "a {} mammal looking directly at the camera",
        "a {} mammal in a zoo enclosure, clear view",
    ],
    
    "Animalia": [
        "an animal known as {}",
        "a photo of the animal {} in nature",
        "a living creature: {}, captured in the wild",
        "a {} animal in its natural environment",
        "a high-resolution image of the {} animal",
        "a {} animal with detailed features, natural background",
        "a {} animal standing still, full body visible",
        "a realistic photo of {} in daylight",
        "a {} animal in motion, clear and focused",
        "a {} animal observed from a safe distance in the wild",
    ],
    
    "Dishes": [
        "a delicious dish named {}",
        "a photo of the food {} served on a white ceramic plate",
        "a plate of {} with garnish, overhead view",
        "a gourmet {} plated professionally in a restaurant",
        "a close-up of {} food with steam rising, warm lighting",
        "a high-quality food photography of {}",
        "a {} dish on a wooden table, natural light",
        "a mouth-watering {} served fresh",
        "a traditional {} cuisine, presented beautifully",
        "a {} recipe plated with herbs and spices",
    ],
    
    "aves": [
        "a bird species called {}",
        "a photo of a {} bird perched on a branch",
        "a wild {} bird in flight against a blue sky",
        "a {} bird standing on the ground, full body visible",
        "a close-up of a {} bird with detailed feathers",
        "a {} bird in its natural habitat: forest or wetland",
        "a {} bird singing on a tree branch, morning light",
        "a colorful {} bird with vibrant plumage",
        "a {} bird captured with a telephoto lens, sharp focus",
        "a {} bird resting on a rock near water",
    ],
    
    "Other": [
        "a photo of {}",
        "an object named {} on a plain background",
        "something called {}, isolated on white",
        "a high-resolution image of the object {}",
        "a {} object with clear detailing and sharp focus",
        "a realistic photo of {} in neutral lighting",
        "a {} object placed on a clean surface",
        "a centered photo of {} with no distractions",
        "a {} object captured in studio lighting",
        "a detailed view of the {} object, macro lens",
    ],
    
    "Cat": [
        "a domestic cat named {}",
        "a photo of a {} cat sitting indoors",
        "a cute {} cat with bright eyes looking at the camera",
        "a fluffy {} cat resting on a sofa",
        "a {} cat playing with a toy, natural light",
        "a close-up portrait of a {} cat",
        "a {} cat in a cozy home environment",
        "a {} cat with soft fur, high detail",
        "a {} cat lying in sunlight near a window",
        "a playful {} kitten captured in motion",
    ],
    
    "Anfibio": [
        "an amphibian called {}",
        "a photo of the amphibian {} near water",
        "a {} frog or salamander on a wet leaf",
        "a {} amphibian in a moist, shaded forest floor",
        "a close-up of a {} amphibian with glistening skin",
        "a {} amphibian sitting on a rock by a pond",
        "a {} frog with bright colors in tropical rainforest",
        "a nocturnal {} amphibian under dim light",
        "a {} amphibian partially submerged in water",
        "a realistic photo of {} amphibian in natural habitat",
    ],
    
    "Insectos": [
        "an insect known as {}",
        "a photo of an insect: {} on a flower",
        "a tiny {} bug crawling on a green leaf",
        "a close-up macro photo of a {} insect",
        "a {} insect with detailed wings and body",
        "a {} bug in natural daylight on a plant stem",
        "a {} insect captured in flight or resting",
        "a realistic image of {} insect on bark or soil",
        "a {} bug with iridescent colors, sharp focus",
        "a microscopic view of the {} insect (simulated photo)",
    ],
    
    "Faces": [
        "a human face of {}",
        "a photo of a person named {} looking at the camera",
        "a portrait of {} with neutral expression, soft lighting",
        "a close-up of {}'s face with clear skin details",
        "a {} person in front of a blurred background",
        "a high-quality facial portrait of {}",
        "a {} individual photographed in studio conditions",
        "a frontal view of {}'s face, even lighting",
        "a {} person smiling gently, natural light",
        "a realistic photo of {}'s face, no makeup, minimal editing",
    ],
}

In [34]:
GENERAL_PROMPT_TEMPLATES = [
    'a photo of {}',
    '{}',
]

In [35]:
# Собираем уникальные строки: (class_norm, original_class, group)
unique_class_info = df_train[['class_norm', 'class', 'group']].drop_duplicates()

prompts4clf = []
prompt_to_class = {}

for _, row in unique_class_info.iterrows():
    norm_class = row['class_norm']
    orig_class = row['class']
    group = row['group']
    
    # Общие шаблоны
    for template in GENERAL_PROMPT_TEMPLATES:
        prompt = template.format(norm_class)
        prompts4clf.append(prompt)
        prompt_to_class[prompt] = orig_class

    # Групповые шаблоны, если группа известна
    if group in GROUP_PROMPT_TEMPLATES:
        for template in GROUP_PROMPT_TEMPLATES[group]:
            prompt = template.format(norm_class)
            prompts4clf.append(prompt)
            prompt_to_class[prompt] = orig_class
    else:
        # На случай новых групп (защита)
        for template in GENERAL_PROMPT_TEMPLATES:
            prompt = template.format(norm_class)
            prompts4clf.append(prompt)
            prompt_to_class[prompt] = orig_class

In [36]:
# Сортируем уникальные классы
unique_classes = sorted(set(prompt_to_class.values()))
# Каждому классу сопоставляем его индекс
class_to_idx = {cls: idx for idx, cls in enumerate(unique_classes)}

# Маппинг: индекс промпта -> индекс класса
prompt_to_class_idx = [class_to_idx[prompt_to_class[p]] for p in prompts4clf]
# Массив, в котором каждому prompt_to_class_idx[i] сопоставляется с i промптом этого класса

num_classes = len(unique_classes)

# Загрузка модели и процессора 

In [37]:
MODEL_PATH = "/kaggle/input/clip-mod/kaggle/util/clip-vit-base-patch32-saved"
device = "cpu"

# Модель содержит и текстовый, и визуальный энкодеры
model = CLIPModel.from_pretrained(MODEL_PATH).to(device)
# Отвечает за предобработку текста и изображений в формат, понятный модели
processor = CLIPProcessor.from_pretrained(MODEL_PATH)

# Словарь с ключами 
text_inputs = processor(
    text=prompts4clf,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=77
).to(device)

with torch.no_grad():
    with autocast():
        text_features = model.get_text_features(**text_inputs) # тензор (#промтов, размерность)
        # Нормализация к единичной длине
        text_features = text_features / text_features.norm(dim=-1, keepdim=True) 

  with autocast():


## Загрузка всех изображений

In [38]:
imgs = [
        p for p in Path(IMAGE_FOLDER_TRAIN).iterdir()
        if p.suffix.lower() in {'.jpg', '.jpeg', '.png'}
    ]
imgs = sorted(imgs, key=lambda x: x.name)

## Предсказание класса изображения

In [39]:
def classify_images_ensemble(image_paths, model, processor, text_features, prompt_to_class_idx, unique_classes, device, batch_size=32):
    results = []
    num_classes = len(unique_classes)
    
    for i in tqdm(range(0, len(image_paths), batch_size), desc="Classifying"):
        # Берем изображения по пакетам
        batch_paths = image_paths[i:i + batch_size]
        images = [Image.open(p).convert("RGB") for p in batch_paths]
        
        image_inputs = processor(images=images, return_tensors="pt").to(device)
        
        with torch.no_grad():
            with autocast():
                image_features = model.get_image_features(**image_inputs)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)
                logits = 100.0 * image_features @ text_features.T  # сходства (больше вес лучше)
                
                # Преобразуем в вероятности (изображение, промпт)
                probs = logits.softmax(dim=-1)  
                
                # Агрегируем по классам: усредняем вероятности
                class_probs = torch.zeros(len(images), num_classes, device=device)
                prompt_count = torch.zeros(num_classes, device=device)
                
                for prompt_idx, cls_idx in enumerate(prompt_to_class_idx):
                    # Для каждого промпта добавляем его вероятность к соответствующему классу
                    class_probs[:, cls_idx] += probs[:, prompt_idx]
                    prompt_count[cls_idx] += 1

                # #промтов на каждый класс
                prompt_count = prompt_count.clamp(min=1)
                # Усредненная вероятность по классу
                class_probs /= prompt_count.unsqueeze(0)
                # Для каждого изображения выбирается класс с максимальной усредненной вероятностью
                pred_class_indices = class_probs.argmax(dim=1)
        
        for j, img_path in enumerate(batch_paths):
            pred_class = unique_classes[pred_class_indices[j].item()]
            results.append({"ID": img_path.name, "predicted_class": pred_class})
    
    return pd.DataFrame(results)

In [40]:
train_results = classify_images_ensemble(
        imgs, model, processor, text_features,
        prompt_to_class_idx, unique_classes, device, batch_size=32
)

eval_df = train_results.merge(df_train[['ID', 'class']], on='ID', how='inner')
eval_df.rename(columns={'class': 'true_class'}, inplace=True)

accuracy = (eval_df['true_class'] == eval_df['predicted_class']).mean()
print(f"\nСреднее совпадение на: {accuracy * 100:.2f}%")

  with autocast():
Classifying: 100%|██████████| 45/45 [02:24<00:00,  3.21s/it]


Среднее совпадение на: 77.76%





In [41]:
test_image_paths = [
    p for p in Path(IMAGE_FOLDER_TEST).iterdir()
    if p.is_file() and p.suffix.lower() in {'.jpg', '.jpeg', '.png'}
]
test_image_paths = sorted(test_image_paths, key=lambda x: x.name)

In [42]:
test_results = classify_images_ensemble(
    test_image_paths, 
    model, 
    processor, 
    text_features,
    prompt_to_class_idx, 
    unique_classes, 
    device, 
    batch_size=32
)

  with autocast():
Classifying: 100%|██████████| 45/45 [02:29<00:00,  3.33s/it]


In [43]:
test_results['ID'] = test_results['ID'].apply(lambda x: Path(x).name)
submit_df = test_results
submit_df.head()

Unnamed: 0,ID,predicted_class
0,1.jpg,ice_cream
1,10.jpg,airplanes
2,100.jpg,ice_cream
3,1000.jpg,triceratops-horridus
4,1001.jpg,panthera-onca


In [44]:
submit_df.to_csv(OUTPUT_CSV, index=False)