In [3]:
import os
import json
import torch
import torchvision.transforms as transforms
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm

In [None]:
# === Constants ===
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 64  # mini-batch size used by the training DataLoader
EPOCHS = 12  # number of full passes over the training set
LEARNING_RATE = 0.001  # learning rate handed to the Adam optimizer
TRAIN_DIR = "train_data"  # directory holding training images named <id>.jpg
TEST_DIR = "test_data"  # directory holding test images named <id>.jpg
TRAIN_CSV = "train_data.csv"  # training table; its "id" and "target" columns are read
TEST_CSV = "test_data.csv"  # test table; "id" and (if present) "category" are read
MODEL_PATH = "model.pth"  # file the trained state_dict is saved to and loaded from

In [None]:
# Colour classes (transliterated colour names). The position of a name in this
# list is the integer label the model is trained to predict.
CLASSES = [
    "bezhevyi",
    "belyi",
    "biryuzovyi",
    "bordovyi",
    "goluboi",
    "zheltyi",
    "zelenyi",
    "zolotoi",
    "korichnevyi",
    "krasnyi",
    "oranzhevyi",
    "raznocvetnyi",
    "rozovyi",
    "serebristyi",
    "seryi",
    "sinii",
    "fioletovyi",
    "chernyi",
]
# Forward lookup: colour name -> integer label.
CLASS_TO_INDEX = dict(zip(CLASSES, range(len(CLASSES))))
# Reverse lookup: integer label -> colour name.
INDEX_TO_CLASS = dict(enumerate(CLASSES))

In [6]:
# Image preprocessing pipeline handed to ProductDataset for both training and prediction.
transform = transforms.Compose([
    # Force a fixed 224x224 size so that samples can be stacked into one batch tensor.
    transforms.Resize((224, 224)),
    # PIL image -> float tensor in (C, H, W) layout with values scaled to [0, 1].
    transforms.ToTensor(),
    # Per-channel mean/std normalisation; these are the values conventionally used
    # with torchvision's ImageNet-pretrained models.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [None]:
# === Dataset ===
class ProductDataset(Dataset):
    """Image dataset backed by a CSV table and a directory of ``<id>.jpg`` files.

    Rows whose image file is missing on disk are dropped at construction time.

    Items returned by ``__getitem__``:
        * ``train=True``  -> ``(image, label)`` where ``label`` is the index of
          the row's ``"target"`` value in ``CLASS_TO_INDEX``.
        * ``train=False`` -> ``(image, img_id, category)`` where ``img_id`` is a
          string and ``category`` falls back to ``"unknown"`` when the column is
          absent or the value is missing.
        * ``None`` when the sample cannot be used (the image fails to load, or
          in training mode the target is not a known class). The DataLoader must
          use a collate function that drops ``None`` entries.

    Args:
        csv_file: path to a CSV file with at least an ``"id"`` column.
        img_dir: directory that contains the ``<id>.jpg`` images.
        transform: optional callable applied to the loaded RGB PIL image.
        train: selects which of the two item layouts above is produced.
    """

    def __init__(self, csv_file, img_dir, transform=None, train=True):
        self.data = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform
        self.train = train  # True -> yield labels, False -> yield ids/categories
        # Keep only rows whose image exists. astype(bool) keeps the mask a valid
        # boolean indexer even when the table is empty.
        has_image = self.data["id"].apply(self._image_exists).astype(bool)
        self.data = self.data[has_image].reset_index(drop=True)

    def _image_exists(self, img_id):
        """Return True when ``<img_dir>/<img_id>.jpg`` exists on disk."""
        img_path = os.path.join(self.img_dir, str(img_id) + ".jpg")
        return os.path.exists(img_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Read the id through its own column: taking a whole row via
        # ``self.data.iloc[idx]`` upcasts an integer id to float whenever all
        # columns are numeric, which produced paths such as "7.0.jpg".
        img_id = str(self.data["id"].iloc[idx])
        img_path = os.path.join(self.img_dir, f"{img_id}.jpg")

        # Load and preprocess the image. Loading is best-effort: a broken file
        # must not abort a whole epoch, so the sample is reported as unusable.
        try:
            image = Image.open(img_path).convert("RGB")
            if self.transform:
                image = self.transform(image)
        except Exception as e:
            print(f"Ошибка при загрузке {img_path}: {e}")
            # None (instead of a zero image with label -1) keeps the item layout
            # consistent in both modes and keeps invalid labels out of the loss.
            return None

        if self.train:
            label = CLASS_TO_INDEX.get(self.data["target"].iloc[idx])
            if label is None:
                # Unknown or missing target: there is nothing to train on.
                return None
            return image, label

        if "category" in self.data.columns:
            category = self.data["category"].iloc[idx]
            if pd.isna(category):
                category = "unknown"
        else:
            category = "unknown"
        return image, img_id, category

In [8]:
# === Batch filter ===
def collate_fn(batch):
    """Collate a batch while skipping samples that are ``None``.

    Returns ``None`` when no usable sample is left in the batch; otherwise the
    result of PyTorch's default collate applied to the remaining samples.
    """
    usable = list(filter(lambda sample: sample is not None, batch))
    if not usable:
        # Every sample was dropped (or the batch was empty to begin with).
        return None
    return torch.utils.data.dataloader.default_collate(usable)

# === Model ===
def get_model(pretrained=True):
    """Build an EfficientNet-B0 classifier with one output per entry in ``CLASSES``.

    Args:
        pretrained: when True (the default, matching the previous behaviour) the
            network starts from torchvision's default EfficientNet-B0 weights.
            Pass False to build the architecture with random weights, e.g. when
            a saved state_dict is loaded right afterwards and fetching the
            pretrained weights would be wasted work.

    Returns:
        The model, already moved to ``DEVICE``.
    """
    weights = models.EfficientNet_B0_Weights.DEFAULT if pretrained else None
    model = models.efficientnet_b0(weights=weights)
    # Swap the final linear layer so the output size matches the number of
    # colour classes instead of the original classification head.
    head = model.classifier[1]
    model.classifier[1] = nn.Linear(head.in_features, len(CLASSES))
    return model.to(DEVICE)


In [None]:

# === Training ===
def train():
    """Fine-tune the classifier on the training set and save it to ``MODEL_PATH``.

    Runs ``EPOCHS`` passes over ``ProductDataset(TRAIN_CSV, TRAIN_DIR)`` with
    Adam and cross-entropy loss, prints the mean loss per epoch and finally
    writes the model's ``state_dict`` to disk.

    Samples with a negative label (the dataset's marker for an image that failed
    to load or an unknown target) are removed from each batch before the forward
    pass, because ``CrossEntropyLoss`` rejects targets outside
    ``[0, num_classes)``.
    """
    dataset = ProductDataset(TRAIN_CSV, TRAIN_DIR, transform, train=True)
    train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, collate_fn=collate_fn)

    model = get_model()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0
        batches_done = 0  # batches that actually contributed to running_loss

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
            if batch is None:
                continue  # collate_fn found no usable sample in this batch

            images, labels = batch

            # Drop placeholder samples (label < 0) so they neither crash the
            # loss nor train the model on all-zero images.
            valid = labels >= 0
            if not valid.any():
                continue
            if not valid.all():
                images, labels = images[valid], labels[valid]

            images, labels = images.to(DEVICE), labels.to(DEVICE)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            batches_done += 1

        # Average over the batches that were really processed; skipped batches
        # must not dilute the figure, and an empty loader must not divide by zero.
        mean_loss = running_loss / batches_done if batches_done else float("nan")
        print(f"Epoch {epoch+1}: Loss = {mean_loss}")

    torch.save(model.state_dict(), MODEL_PATH)
    print("✅ Модель сохранена!")



In [10]:


# === Prediction ===
def predict():
    """Run the saved model over the test set and write ``submission.csv``.

    For every test image one row is written with the columns ``id``,
    ``category``, ``predict_proba`` (JSON object mapping each class name to its
    softmax probability) and ``predict_color`` (the most probable class).
    Samples that could not be loaded are skipped.
    """
    dataset = ProductDataset(TEST_CSV, TEST_DIR, transform, train=False)
    test_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=collate_fn)

    model = get_model()
    # NOTE(review): torch.load unpickles the file, so only load checkpoints from
    # a trusted source (newer PyTorch versions accept weights_only=True here).
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    model.eval()

    results = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Predicting"):
            # Skip empty batches and the two-element (image, label) placeholder
            # the dataset may produce for an image that failed to load.
            if batch is None or len(batch) != 3:
                continue

            images, img_ids, categories = batch
            logits = model(images.to(DEVICE))
            # One device-to-host transfer per batch instead of one per class.
            batch_probs = torch.softmax(logits, dim=1).cpu().tolist()

            for img_id, category, row_probs in zip(img_ids, categories, batch_probs):
                # Numeric categories are collated into tensors; unwrap them so
                # the CSV holds the plain value rather than "tensor(...)".
                if torch.is_tensor(category):
                    category = category.item()

                probs = dict(zip(CLASSES, row_probs))
                predicted_color = max(probs, key=probs.get)

                results.append({
                    "id": img_id,
                    "category": category,
                    "predict_proba": json.dumps(probs, ensure_ascii=False),
                    "predict_color": predicted_color
                })

    # Explicit columns keep the CSV header intact even when nothing was predicted.
    df = pd.DataFrame(results, columns=["id", "category", "predict_proba", "predict_color"])
    df.to_csv("submission.csv", index=False)
    print("✅ Предсказания сохранены в submission.csv")



In [None]:
train()  # Train the model


Epoch 1/12: 100%|██████████| 521/521 [09:59<00:00,  1.15s/it]


Epoch 1: Loss = 1.2072897588909244


Epoch 2/12: 100%|██████████| 521/521 [09:34<00:00,  1.10s/it]


Epoch 2: Loss = 0.9779191263120143


Epoch 3/12: 100%|██████████| 521/521 [09:33<00:00,  1.10s/it]


Epoch 3: Loss = 0.863504676695291


Epoch 4/12: 100%|██████████| 521/521 [09:37<00:00,  1.11s/it]


Epoch 4: Loss = 0.7596533661726111


Epoch 5/12: 100%|██████████| 521/521 [09:37<00:00,  1.11s/it]


Epoch 5: Loss = 0.6513175717814184


Epoch 6/12: 100%|██████████| 521/521 [09:39<00:00,  1.11s/it]


Epoch 6: Loss = 0.5380434147341466


Epoch 7/12: 100%|██████████| 521/521 [09:38<00:00,  1.11s/it]


Epoch 7: Loss = 0.44627581273639955


Epoch 8/12: 100%|██████████| 521/521 [09:39<00:00,  1.11s/it]


Epoch 8: Loss = 0.36023641184630184


Epoch 9/12: 100%|██████████| 521/521 [09:37<00:00,  1.11s/it]


Epoch 9: Loss = 0.3011816359498679


Epoch 10/12: 100%|██████████| 521/521 [09:38<00:00,  1.11s/it]


Epoch 10: Loss = 0.2476881112681698


Epoch 11/12: 100%|██████████| 521/521 [09:39<00:00,  1.11s/it]


Epoch 11: Loss = 0.22394631241739596


Epoch 12/12: 100%|██████████| 521/521 [09:38<00:00,  1.11s/it]

Epoch 12: Loss = 0.21181644923312878
✅ Модель сохранена!





In [13]:
predict() # Run prediction on the test set

Predicting: 100%|██████████| 344/344 [00:06<00:00, 53.85it/s]

✅ Предсказания сохранены в submission.csv



