In [None]:
from sklearn.metrics import classification_report, accuracy_score
from transformers import CLIPModel, CLIPProcessor
from typing import List, Union, Optional, Tuple
from PIL import Image
import pandas as pd
import numpy as np
import torch
import os

In [None]:
class FashionClipEncoder:
    """Batch encoder that turns images and texts into FashionCLIP embeddings.

    Loads a Hugging Face ``CLIPModel`` / ``CLIPProcessor`` pair, moves the model
    to the selected device, switches it to eval mode, and exposes two helpers
    that return NumPy arrays of embeddings.
    """

    def __init__(self,
                 model_name: str = "patrickjohncyh/fashion-clip",
                 device: Optional[str] = None) -> None:
        """
        Initializes the FashionCLIP encoder.

        Args:
            model_name: Hugging Face model identifier
            device: Optional device override ('cuda', 'cpu', or None for auto-detection)
        """
        self.device = device
        if not self.device:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = CLIPModel.from_pretrained(model_name).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.model.eval()

    @staticmethod
    def _load_image(img: Union[str, Image.Image]) -> Image.Image:
        """Return a PIL image; paths are opened, decoded and the file is closed."""
        if isinstance(img, str):
            # convert() decodes the pixels into a new in-memory image, so the
            # file handle can be released as soon as the `with` block exits.
            with Image.open(img) as handle:
                return handle.convert("RGB")
        return img

    def _encode_in_batches(self,
                           items: list,
                           batch_size: int,
                           verbose: bool,
                           normalize: bool,
                           noun: str,
                           build_inputs,
                           forward) -> np.ndarray:
        """Shared batching loop used by encode_images and encode_texts.

        Args:
            items: Full list of inputs to encode
            batch_size: Number of items per forward pass (must be >= 1)
            verbose: Whether to print progress
            normalize: Whether to L2-normalize each embedding
            noun: Word used in the progress message ("image" or "text")
            build_inputs: Callable mapping a batch (list slice) to processor output
            forward: Model method that maps processor output to embeddings

        Returns:
            Numpy array of shape (len(items), embedding_dim)

        Raises:
            ValueError: If batch_size is smaller than 1
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")
        if not items:
            # np.concatenate([]) would raise; return a well-shaped empty result.
            return np.empty((0, self.model.config.projection_dim), dtype=np.float32)

        total_batches = (len(items) - 1) // batch_size + 1
        chunks = []
        for number, start in enumerate(range(0, len(items), batch_size), start=1):
            if verbose:
                print(f"Processing {noun} batch {number}/{total_batches}")

            inputs = build_inputs(items[start:start + batch_size]).to(self.device)

            with torch.no_grad():
                embeddings = forward(**inputs)
                if normalize:
                    embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
            chunks.append(embeddings.cpu().numpy())

        return np.concatenate(chunks)

    def encode_images(
        self,
        images: List[Union[str, Image.Image]],
        batch_size: int = 32,
        verbose: bool = False,
        normalize: bool = True
    ) -> np.ndarray:
        """
        Encodes images in batches.

        Args:
            images: List of image paths or PIL Images
            batch_size: Number of images to process simultaneously
            verbose: Whether to print progress
            normalize: Whether to normalize embeddings to unit vectors

        Returns:
            Numpy array of all image embeddings (len(images), embedding_dim);
            an empty list yields an array with zero rows.

        Raises:
            ValueError: If images is not a list or batch_size is smaller than 1
        """
        if not isinstance(images, list):
            raise ValueError("Input must be a list of images")

        def build_inputs(batch):
            loaded_images = [self._load_image(img) for img in batch]
            # No `padding` argument here: it is a tokenizer option and has no
            # meaning for image preprocessing.
            return self.processor(images=loaded_images, return_tensors="pt")

        return self._encode_in_batches(
            images, batch_size, verbose, normalize,
            "image", build_inputs, self.model.get_image_features
        )

    def encode_texts(
        self,
        texts: List[str],
        batch_size: int = 128,
        verbose: bool = False,
        normalize: bool = True
    ) -> np.ndarray:
        """
        Encodes texts in batches.

        Args:
            texts: List of text strings to encode
            batch_size: Number of texts to process simultaneously
            verbose: Whether to print progress
            normalize: Whether to normalize embeddings to unit vectors

        Returns:
            Numpy array of all text embeddings (len(texts), embedding_dim);
            an empty list yields an array with zero rows.

        Raises:
            ValueError: If texts is not a list or batch_size is smaller than 1
        """
        if not isinstance(texts, list):
            raise ValueError("Input must be a list of text strings")

        def build_inputs(batch):
            # Inputs are truncated to 77 tokens and padded to the longest
            # sequence in the batch.
            return self.processor(
                text=batch,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=77
            )

        return self._encode_in_batches(
            texts, batch_size, verbose, normalize,
            "text", build_inputs, self.model.get_text_features
        )

In [None]:
def get_test_set(
    base_dir: str = "/kaggle/input/fashion-styles-3/fashion_styles_3",
) -> Tuple[List[str], List[str]]:
    """Collect image paths and their style labels from a directory tree.

    The expected layout is ``base_dir/<style>/<image file>``: every
    sub-directory name is used as the label of the files it contains.

    Args:
        base_dir: Root directory holding one sub-directory per style.

    Returns:
        Two parallel lists ``(image_paths, true_labels)``; entry ``i`` of
        ``true_labels`` is the name of the directory containing
        ``image_paths[i]``. Entries are ordered by style name, then file name.
    """
    image_paths: List[str] = []
    true_labels: List[str] = []
    # os.listdir returns entries in arbitrary order; sort for reproducible runs.
    for style_dir in sorted(os.listdir(base_dir)):
        style_path = os.path.join(base_dir, style_dir)
        if not os.path.isdir(style_path):
            # Stray files in the root (readme, metadata, ...) are not styles.
            continue
        for img_name in sorted(os.listdir(style_path)):
            img_path = os.path.join(style_path, img_name)
            if not os.path.isfile(img_path):
                # Nested directories cannot be opened as images.
                continue
            image_paths.append(img_path)
            true_labels.append(style_dir)
    return image_paths, true_labels

def get_style_descriptions() -> dict:
    """Return the textual description associated with each style label.

    Returns:
        A new dict mapping style name to its comma-separated description,
        in the order formal, streetwear, minimalist, athleisure.
    """
    style_prompts = [
        ("formal", "business formal, sharply tailored suit, polished"),
        ("streetwear", "streetwear, urban casual, relax"),
        ("minimalist", "minimal, clean, monochrome, high‑quality, neutral tones, sophisticated"),
        ("athleisure", "athleisure, sporty outfit"),
    ]
    return dict(style_prompts)

def get_predictions(encoder: "FashionClipEncoder",
                    image_paths: List[str],
                    labels_desc: dict,
                    threshold: float = 0.2,
                    batch_size: int = 64,
                    verbose: bool = True) -> List[str]:
    """Assign each image the style whose description embedding matches it best.

    Descriptions and images are embedded with ``encoder``; the score of an
    (image, style) pair is the dot product of their embeddings (cosine
    similarity when the encoder returns unit vectors, which is the default of
    ``encode_texts`` / ``encode_images``). Images whose best score is below
    ``threshold`` are labelled ``"other"``.

    Args:
        encoder: Object exposing ``encode_texts`` and ``encode_images``.
        image_paths: Images to classify.
        labels_desc: Mapping of style label to its text description.
        threshold: Minimum best-match score required to accept a style.
        batch_size: Batch size forwarded to both encoder calls.
        verbose: Whether the encoder prints batch progress.

    Returns:
        One label per entry of ``image_paths``, in the same order; an empty
        ``image_paths`` yields an empty list.

    Raises:
        ValueError: If ``labels_desc`` is empty.
    """
    if not labels_desc:
        raise ValueError("labels_desc must contain at least one style description")
    if len(image_paths) == 0:
        # Nothing to classify; avoid running the model on empty input.
        return []

    labels = list(labels_desc.keys())
    descriptions = list(labels_desc.values())
    # The fallback label sits right after the real styles.
    fallback_idx = len(labels)
    labels.append("other")

    text_embs = encoder.encode_texts(descriptions, batch_size=batch_size, verbose=verbose)
    image_embs = encoder.encode_images(image_paths, batch_size=batch_size, verbose=verbose)

    # Rows are images, columns are styles.
    sim_matrix = image_embs @ text_embs.T
    best_idx = np.argmax(sim_matrix, axis=1)
    confidence = np.max(sim_matrix, axis=1)
    best_idx = np.where(confidence >= threshold, best_idx, fallback_idx)
    return [labels[idx] for idx in best_idx]

In [None]:
# Load the model, gather the labelled evaluation images, then classify them.
encoder = FashionClipEncoder()
image_paths, true_labels = get_test_set()
style_descriptions = get_style_descriptions()
pred_labels = get_predictions(
    encoder=encoder,
    image_paths=image_paths,
    labels_desc=style_descriptions,
)
# Every label a prediction can take: the described styles plus the fallback.
labels = [*style_descriptions, "other"]

In [None]:
# Per-class precision / recall / F1.
print("\n" + "="*50)
print("Отчёт о классификации:")
# "other" can be predicted without ever occurring in the ground truth, which
# leaves some per-class metrics undefined. zero_division=0 reports them as 0
# (the same value the default uses) without emitting a warning for each one.
print(classification_report(true_labels, pred_labels, zero_division=0))

# Overall share of images whose predicted style equals the directory label.
print("\n" + "="*50)
print(f"Общая точность (Accuracy): {accuracy_score(true_labels, pred_labels):.4f}")

# Confusion table: rows are true styles, columns are predicted styles,
# with row/column totals added by margins=True.
print("\n" + "="*50)
print("Cross-tab отчет:")
cross_tab = pd.crosstab(pd.Series(true_labels, name='Истинный стиль'),
                        pd.Series(pred_labels, name='Предсказанный стиль'),
                        margins=True)
print(cross_tab)
