In [3]:
import os
import json
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
import timm
import torchvision.transforms as transforms

# Mapping dictionaries
# Transliterated (Latin-script) color label -> Russian color name.
# NOTE(review): 'serebristyi' transliterates "серебристый" but is mapped to
# 'серебряный' — presumably this matches the dataset labels; confirm.
TRANSLIT_TO_RU = {
    'bezhevyi': 'бежевый',
    'belyi': 'белый',
    'biryuzovyi': 'бирюзовый',
    'bordovyi': 'бордовый',
    'goluboi': 'голубой',
    'zheltyi': 'желтый',
    'zelenyi': 'зеленый',
    'zolotoi': 'золотой',
    'korichnevyi': 'коричневый',
    'krasnyi': 'красный',
    'oranzhevyi': 'оранжевый',
    'raznocvetnyi': 'разноцветный',
    'rozovyi': 'розовый',
    'serebristyi': 'серебряный',
    'seryi': 'серый',
    'sinii': 'синий',
    'fioletovyi': 'фиолетовый',
    'chernyi': 'черный'
}

# Create reverse mapping from Russian to transliteration
RU_TO_TRANSLIT = {v: k for k, v in TRANSLIT_TO_RU.items()}

# Colors dictionary: Russian color name -> English color name.
# The KEY ORDER matters: predict_color() uses list(COLORS.keys()) to turn a
# model output index into a color, so entries must not be reordered.
# 'голубой' and 'синий' are distinct classes, so they get distinct English
# names ('light blue' vs 'blue') instead of both collapsing onto 'blue'.
COLORS = {
    'бежевый': 'beige',
    'белый': 'white',
    'бирюзовый': 'turquoise',
    'бордовый': 'burgundy',
    'голубой': 'light blue',
    'желтый': 'yellow',
    'зеленый': 'green',
    'золотой': 'gold',
    'коричневый': 'brown',
    'красный': 'red',
    'оранжевый': 'orange',
    'разноцветный': 'variegated',
    'розовый': 'pink',
    'серебряный': 'silver',
    'серый': 'gray',
    'синий': 'blue',
    'фиолетовый': 'purple',
    'черный': 'black'
}

# Sanity checks: the reverse mapping must be lossless, and every color the
# model can predict must have a transliterated label.
assert len(RU_TO_TRANSLIT) == len(TRANSLIT_TO_RU), "duplicate Russian names in TRANSLIT_TO_RU"
assert set(RU_TO_TRANSLIT) == set(COLORS), "TRANSLIT_TO_RU and COLORS list different colors"

# Product categories accepted by predict_color(); an entry's position in this
# list is the integer index passed to the model.
CATEGORIES = [
    'одежда для девочек',
    'столы',
    'стулья',
    'сумки',
]

# Filled in by initialize_model(); stays None until the first call
MODEL = None
# Inference device: CUDA when it is available, CPU otherwise
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class ColorClassifier(nn.Module):
    """Color classifier combining image features with a product-category embedding.

    A timm ``beitv2_large_patch16_224`` backbone (created without its
    classification layer) produces image features; these are concatenated
    with a learned 32-dimensional category embedding and passed through a
    small MLP that outputs one logit per color.

    Args:
        num_colors (int): Number of output classes (colors).
        num_categories (int): Number of distinct product categories.
        pretrained (bool): Forwarded to ``timm.create_model``. Defaults to
            ``True`` (the previous hard-coded behavior). Pass ``False`` when
            the weights are about to be overwritten by ``load_state_dict``,
            so the pretrained weights do not have to be fetched first.
    """

    # Size of the learned category embedding vector
    CATEGORY_EMBED_DIM = 32

    def __init__(self, num_colors, num_categories, pretrained=True):
        super().__init__()
        # BEiT v2 large (patch 16, 224 px input) used as a feature extractor
        self.backbone = timm.create_model(
            'beitv2_large_patch16_224',
            pretrained=pretrained,
            num_classes=0,  # Without top classification layer
        )

        # Freeze every parameter tensor except the last 30 to speed up training
        for param in list(self.backbone.parameters())[:-30]:
            param.requires_grad = False

        # Make sure no classification head is attached (num_classes=0 above
        # already requests this; the call is kept from the original code)
        self.backbone.reset_classifier(0)

        # Model feature dimension
        self.feature_dim = self.backbone.embed_dim

        # Category embedding
        self.category_embedding = nn.Embedding(num_categories, self.CATEGORY_EMBED_DIM)

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(self.feature_dim + self.CATEGORY_EMBED_DIM, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_colors)
        )

        # Example inputs for torch.jit optimization. These are plain
        # attributes (not parameters or registered buffers), so they are not
        # part of state_dict and are not moved by .to(device).
        self.example_input = torch.zeros(1, 3, 224, 224)
        self.example_category = torch.LongTensor([0])

    def forward(self, x, category):
        """Return color logits for a batch of images and their categories.

        Args:
            x: Image batch fed to the backbone
                (assumed (batch, 3, 224, 224), as in ``example_input``).
            category: Long tensor of category indices, one per image
                (assumed shape (batch,), as in ``example_category``).

        Returns:
            Tensor of raw (pre-softmax) scores with ``num_colors`` columns.
        """
        features = self.backbone(x)

        category_emb = self.category_embedding(category)
        # Join image features and category embedding along the feature axis
        combined = torch.cat([features, category_emb], dim=1)

        return self.classifier(combined)

def load_model(model_path):
    """
    Loads a previously trained model from the specified path.

    The file is first tried as a TorchScript archive; if that fails it is
    loaded as a regular checkpoint (either a bare state dict or a dict with a
    'model_state_dict' entry) into a freshly built ColorClassifier.

    Args:
        model_path (str): Path to the TorchScript archive or checkpoint file.

    Returns:
        The model, moved to CUDA when available (CPU otherwise) and switched
        to evaluation mode on both loading paths.

    Raises:
        FileNotFoundError: If model_path is a path that does not point to a file.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Fail fast on a wrong path instead of building the whole backbone first
    # and only then discovering that the checkpoint cannot be opened.
    if isinstance(model_path, (str, os.PathLike)) and not os.path.isfile(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")

    try:
        # Try to load as a TorchScript model
        model = torch.jit.load(model_path, map_location=device)
    except Exception:
        # Not a TorchScript archive; fall back to regular weights below.
        # (Exception, not a bare except, so Ctrl-C / SystemExit still propagate.)
        model = None

    if model is not None:
        print("Loaded optimized TorchScript model")
    else:
        # Load as a regular model
        print("Loading model from standard weights...")
        model = ColorClassifier(len(COLORS), len(CATEGORIES))

        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from a trusted source;
        # consider weights_only=True if the checkpoint holds tensors only.
        checkpoint = torch.load(model_path, map_location=device)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint)

        model = model.to(device)

    # Set model to evaluation mode on BOTH paths, so dropout is disabled
    # for a TorchScript model as well.
    model.eval()

    return model

def initialize_model(model_path="vit_color_classifier.pth"):
    """
    Load the model on first use and cache it in the module-level MODEL.

    Later calls return the cached model without touching the disk; in that
    case model_path is not used.

    Args:
        model_path (str): Path to the model weights

    Returns:
        The loaded (or previously cached) model
    """
    global MODEL

    # Fast path: a model is already cached
    if MODEL is not None:
        print("Model already loaded, reusing...")
        return MODEL

    print("Loading model for the first time...")
    MODEL = load_model(model_path)
    return MODEL

def predict_color(image_path, category_name):
    """
    Predicts the color of a product from an image and its category.
    Uses the globally loaded model (make sure to call initialize_model first).

    Args:
        image_path (str): Path to the product image
        category_name (str): Category name of the product (must be one of CATEGORIES)

    Returns:
        tuple: (best_color, top5_colors) where:
            - best_color (str): The most likely color in transliterated form (e.g. 'bezhevyi')
            - top5_colors (dict): Dictionary with top 5 colors (in transliterated form) and their
              probabilities, ordered from most to least likely; its first key equals best_color

    Raises:
        RuntimeError: If the model is not initialized, the image cannot be
            processed, or the model output size does not match COLORS.
        ValueError: If category_name is not one of CATEGORIES.
        FileNotFoundError: If image_path does not exist.
    """
    # Check if model is loaded
    if MODEL is None:
        raise RuntimeError("Model not initialized. Please call initialize_model() first.")

    # Validate category name
    if category_name not in CATEGORIES:
        raise ValueError(f"Category must be one of: {CATEGORIES}")

    # Check if image exists
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")

    # Prepare image transformation: 224x224 resize plus per-channel normalization
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    # Load and transform the image
    try:
        # The context manager closes the file handle promptly; convert()
        # returns an in-memory copy that stays valid after the file is closed.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        image_tensor = transform(image).unsqueeze(0).to(DEVICE)
    except Exception as e:
        raise RuntimeError(f"Error processing image: {str(e)}") from e

    # Get category index
    category_idx = CATEGORIES.index(category_name)
    category_tensor = torch.tensor([category_idx], dtype=torch.long).to(DEVICE)

    # Make prediction
    with torch.no_grad():
        outputs = MODEL(image_tensor, category_tensor)
        probs = torch.softmax(outputs, dim=1).cpu().numpy()[0]

    # Get the color names in Russian; index i of the model output corresponds
    # to the i-th key of COLORS
    color_list = list(COLORS.keys())

    # A size mismatch would otherwise silently attach wrong color names
    if len(probs) != len(color_list):
        raise RuntimeError(
            f"Model returned {len(probs)} scores but {len(color_list)} colors are defined"
        )

    # Indices of the 5 most likely colors, most likely first. A stable sort
    # on the negated scores breaks ties by lowest index, so the first entry
    # is the same index np.argmax would return.
    top5_indices = np.argsort(-probs, kind='stable')[:5]

    # Convert to transliterated form
    best_color = RU_TO_TRANSLIT[color_list[top5_indices[0]]]

    # Convert colors to transliterated form in the result
    top5_colors = {RU_TO_TRANSLIT[color_list[idx]]: float(probs[idx]) for idx in top5_indices}

    return best_color, top5_colors

# Example usage:
# 1. Initialize the model once at the beginning
# model = initialize_model("vit_color_classifier.pth")
#
# 2. Make predictions as many times as needed without reloading the model
# best_color, top5_colors = predict_color("path/to/image.jpg", "столы")
# print(f"Best color: {best_color}")
# for color, prob in top5_colors.items():
#     print(f"  {color}: {prob:.4f}")
#
# 3. Make another prediction with the same model
# best_color2, top5_colors2 = predict_color("path/to/another_image.jpg", "стулья")


In [4]:
# Load the weights once. initialize_model() caches the result in the
# module-level MODEL, which predict_color() reads on every call.
model = initialize_model("/kaggle/input/macro_/pytorch/default/1/macro_weights.pth")

Loading model for the first time...
Loading model from standard weights...


  checkpoint = torch.load(model_path, map_location=device)


In [5]:
# Classify a single test image, passing the product category it belongs to.
best_color, top5_colors = predict_color("/kaggle/input/colors/dataset_colors/test_data/19762915377.png", "одежда для девочек")
print(f"Best color: {best_color}")  # transliterated label (e.g. 'bezhevyi'), not the Russian name
print("Top 5 colors:")
for color, prob in top5_colors.items():
    print(f"  {color}: {prob:.4f}")  # keys are transliterated labels, values are softmax probabilities


Best color: zelenyi
Top 5 colors:
  zelenyi: 0.9715
  chernyi: 0.0067
  raznocvetnyi: 0.0044
  korichnevyi: 0.0040
  sinii: 0.0034
