In [1]:
'''
data - колонки main_img, label (multi, single, unknown)

нужен классификатор который (что-то быстрое, но мощное, файнтюн кого-то)
принимает картинку, говорит сколько на ней товаров - много, мало, неизвестно (сюда попадаем, если модель не уверена в много/мало)

обучить с тестом и валидацией, выводить метрики

сохранить классификатор

написать функцию is_multi(model_dir, img_path) -> три вероятности (multi, single, unknown)
при этом она сама выгружает модель из директории

написать функцию is_multi_packimg(model_dir, img_pathes=[]) -> три вероятности (multi, single, unknown)
проверяет каждую картинку, и агрегирует результат, давая его для всего набора переданных картинок
'''


'\ndata - колонки main_img, label (multi, single, unknown)\n\nнужен классификатор который (что-то быстрое, но мощное, файнтюн кого-то)\nпринимает картинку, говорит сколько на ней товаров - много, мало, неизвестно (сюда попадаем, если модель не уверена в много/мало)\n\nобучить с тестом и валидацией, выводить метрики\n\nсохранить классификатор\n\nнаписать функцию is_multi(model_dir, img_path) -> три вероятности (multi, single, unknown)\nпри этом она сама выгружает модель из директории\n\nнаписать функцию is_multi_packimg(model_dir, img_pathes=[]) -> три вероятности (multi, single, unknown)\nпроверяет каждую картинку, и агрегирует результат, давая его для всего набора переданных картинок\n'

In [None]:
import pandas as pd
import os
from PIL import Image
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix, f1_score
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split
import timm
import torch.nn as nn
import numpy as np

# Class names in the order used for integer class ids (position == id).
CLASS_NAMES = ['multi', 'single', 'unknown']
# Lookup tables between a class name and its integer id, in both directions.
LABEL2IDX = dict(zip(CLASS_NAMES, range(len(CLASS_NAMES))))
IDX2LABEL = dict(enumerate(CLASS_NAMES))


class ProductCountDataset(Dataset):
    """Image dataset backed by a DataFrame of paths and class names.

    Each row provides an image path in the ``main_img`` column and a class
    name in the ``label`` column; the class name is converted to an integer
    id through ``LABEL2IDX``.

    Args:
        df: DataFrame with ``main_img`` and ``label`` columns. A copy with a
            fresh 0..N-1 index is stored so rows can be addressed by position.
        images_dir: Directory that relative ``main_img`` paths are joined to.
            Absolute paths are used unchanged.
        transform: Callable applied to the RGB PIL image. When ``None`` a
            default pipeline is used: resize to 224x224, convert to tensor,
            normalize per channel.
    """

    def __init__(self, df, images_dir, transform=None):
        self.df = df.reset_index(drop=True)
        self.images_dir = images_dir
        # Test against None explicitly so that any user-supplied callable is
        # honoured regardless of its truthiness.
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                # Presumably the ImageNet channel mean/std expected by the
                # pretrained backbone -- TODO confirm against the timm config.
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ])
        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(transformed_image, class_id)`` for the row at ``idx``.

        Raises:
            KeyError: If the row's ``label`` is not a key of ``LABEL2IDX``.
        """
        row = self.df.iloc[idx]
        img_path = row['main_img']
        if not os.path.isabs(img_path):
            img_path = os.path.join(self.images_dir, img_path)
        # Use a context manager so the underlying file is closed right away
        # instead of whenever the garbage collector gets to it; this matters
        # when many samples are read by several DataLoader workers.
        with Image.open(img_path) as img:
            x = self.transform(img.convert('RGB'))
        y = LABEL2IDX[row['label']]
        return x, y


# --- Data loading, splitting and loaders ---------------------------------

# NOTE(review): `IMAGES` was used below without being defined anywhere in the
# visible notebook, so a fresh-kernel run failed with NameError. Keep a value
# set by an earlier cell if there is one; otherwise resolve relative image
# paths against the current working directory.
IMAGES = globals().get('IMAGES', '.')

df = pd.read_csv('your_data.csv')

# Fail early with a readable message instead of a KeyError raised later from
# inside a DataLoader worker.
unexpected_labels = set(df['label'].unique()) - set(CLASS_NAMES)
if unexpected_labels:
    raise ValueError(
        f"Unexpected values in 'label' column: {sorted(map(str, unexpected_labels))}; "
        f"expected only {CLASS_NAMES}"
    )

# Stratified 70 / 15 / 15 split: 30% is held out first and then halved into
# validation and test, keeping the class proportions in every part.
train_df, temp_df = train_test_split(df, test_size=0.3, stratify=df['label'], random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, stratify=temp_df['label'], random_state=42)

print("Train:", len(train_df), "Val:", len(val_df), "Test:", len(test_df))

train_ds = ProductCountDataset(train_df, images_dir=IMAGES)
val_ds = ProductCountDataset(val_df, images_dir=IMAGES)
test_ds = ProductCountDataset(test_df, images_dir=IMAGES)

# Only the training loader shuffles; validation and test keep a fixed order.
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=2)
val_dl = DataLoader(val_ds, batch_size=32, shuffle=False, num_workers=2)
test_dl = DataLoader(test_ds, batch_size=32, shuffle=False, num_workers=2)


class ProductCountClassifier(nn.Module):
    """timm backbone followed by a single linear classification layer.

    The backbone is created with ``num_classes=0`` and ``global_pool='avg'``
    and its ``num_features`` attribute sizes the input of the linear layer.

    Args:
        backbone_name: Model name passed to ``timm.create_model``.
        num_classes: Number of output logits.
        pretrained: Forwarded to ``timm.create_model``. Defaults to ``True``
            (the previous hard-coded behaviour). Pass ``False`` to skip
            fetching pretrained weights, e.g. when a saved state dict is
            loaded right after construction.
    """

    def __init__(self, backbone_name='efficientnet_b0', num_classes=3, pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(
            backbone_name,
            pretrained=pretrained,
            num_classes=0,
            global_pool='avg',
        )
        self.fc = nn.Linear(self.backbone.num_features, num_classes)

    def forward(self, x):
        """Return raw class logits for the image batch ``x``."""
        features = self.backbone(x)
        return self.fc(features)


def train_epoch(model, dl, optimizer, criterion, device):
    """Run one optimisation pass over ``dl`` and collect training statistics.

    Args:
        model: Module to train; it is switched to train mode.
        dl: Iterable yielding ``(inputs, targets)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable invoked as ``criterion(logits, targets)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean_batch_loss, accuracy, predictions, labels)`` where the
        loss is the plain mean of the per-batch loss values and the last two
        items are NumPy arrays covering every sample seen.
    """
    model.train()
    batch_losses = []
    n_correct = 0
    n_seen = 0
    seen_preds = []
    seen_labels = []

    for inputs, targets in dl:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        batch_losses.append(batch_loss.item())
        batch_pred = logits.argmax(dim=1)
        n_correct += (batch_pred == targets).sum().item()
        n_seen += targets.size(0)
        seen_preds += list(batch_pred.detach().cpu().numpy())
        seen_labels += list(targets.detach().cpu().numpy())

    accuracy = n_correct / n_seen
    mean_loss = np.mean(batch_losses)
    return mean_loss, accuracy, np.array(seen_preds), np.array(seen_labels)

@torch.no_grad()
def eval_epoch(model, dl, criterion, device):
    """Evaluate ``model`` on ``dl`` without tracking gradients.

    Args:
        model: Module to evaluate; it is switched to eval mode.
        dl: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss callable invoked as ``criterion(logits, targets)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean_batch_loss, accuracy, predictions, labels)`` where the
        loss is the plain mean of the per-batch loss values and the last two
        items are NumPy arrays covering every sample seen.
    """
    model.eval()
    batch_losses = []
    n_correct = 0
    n_seen = 0
    seen_preds = []
    seen_labels = []

    for inputs, targets in dl:
        inputs = inputs.to(device)
        targets = targets.to(device)

        logits = model(inputs)
        batch_losses.append(criterion(logits, targets).item())

        batch_pred = logits.argmax(dim=1)
        seen_preds += list(batch_pred.cpu().numpy())
        seen_labels += list(targets.cpu().numpy())
        n_correct += (batch_pred == targets).sum().item()
        n_seen += targets.size(0)

    accuracy = n_correct / n_seen
    mean_loss = np.mean(batch_losses)
    return mean_loss, accuracy, np.array(seen_preds), np.array(seen_labels)


def plot_metrics(history):
    """Draw train/validation curves for loss, accuracy and macro F1.

    A single named figure is reused and cleared on every call. Previously each
    call opened a brand-new figure that was never closed, so one figure per
    epoch piled up in memory.

    Args:
        history: Mapping with one list per key among ``train_loss``,
            ``val_loss``, ``train_acc``, ``val_acc``, ``train_f1`` and
            ``val_f1``; each list is plotted against its index (the epoch).
    """
    figure_name = 'training_metrics'
    # (history key suffix, panel title, y-axis label), in subplot order.
    panels = [
        ('loss', 'Loss', 'Loss'),
        ('acc', 'Accuracy', 'Accuracy'),
        ('f1', 'F1 Score (Macro)', 'F1 Score'),
    ]

    if plt.fignum_exists(figure_name):
        # Selecting an existing figure makes it current; wipe the old curves.
        fig = plt.figure(figure_name)
        fig.clear()
    else:
        fig = plt.figure(figure_name, figsize=(14, 10))

    for position, (suffix, title, ylabel) in enumerate(panels, start=1):
        ax = fig.add_subplot(2, 2, position)
        ax.plot(history[f'train_{suffix}'], label='Train')
        ax.plot(history[f'val_{suffix}'], label='Validation')
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)

    fig.tight_layout()
    # Refresh the figure and give the GUI event loop a moment to repaint.
    plt.draw()
    plt.pause(0.1)


# --- Training configuration ------------------------------------------------
SEED = 42
NUM_EPOCHS = 10
MODEL_DIR = 'product_count_model'
MODEL_PATH = os.path.join(MODEL_DIR, 'model.pt')

# Seed torch so weight initialisation and batch shuffling are repeatable.
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ProductCountClassifier('efficientnet_b0', num_classes=3).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

history = {
    'train_loss': [],
    'val_loss': [],
    'train_acc': [],
    'val_acc': [],
    'train_f1': [],
    'val_f1': []
}

# --- Training loop -----------------------------------------------------------
# The checkpoint that is evaluated below must come from THIS run. Previously
# nothing was ever saved, so the later load either failed or silently picked up
# an unrelated file left on disk.
os.makedirs(MODEL_DIR, exist_ok=True)
best_val_f1 = float('-inf')

for epoch in range(NUM_EPOCHS):
    train_loss, train_acc, train_preds, train_labels = train_epoch(
        model, train_dl, optimizer, criterion, device
    )
    val_loss, val_acc, val_preds, val_labels = eval_epoch(
        model, val_dl, criterion, device
    )
    train_f1 = f1_score(train_labels, train_preds, average='macro')
    val_f1 = f1_score(val_labels, val_preds, average='macro')

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_acc'].append(train_acc)
    history['val_acc'].append(val_acc)
    history['train_f1'].append(train_f1)
    history['val_f1'].append(val_f1)

    # Keep the weights of the epoch with the best validation macro F1.
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        torch.save(model.state_dict(), MODEL_PATH)

    plot_metrics(history)

    print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
    print(f"Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f} | F1: {train_f1:.4f}")
    print(f"Val Loss:   {val_loss:.4f} | Acc: {val_acc:.4f} | F1: {val_f1:.4f}\n")


# --- Test evaluation ---------------------------------------------------------
# Restore the best checkpoint; map_location keeps the load working when the
# file was written on a different device than the current one.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
test_loss, test_acc, test_preds, test_labels = eval_epoch(model, test_dl, criterion, device)
test_f1 = f1_score(test_labels, test_preds, average='macro')

print("\nTest Results:")
print(f"Loss: {test_loss:.4f} | Accuracy: {test_acc:.4f} | F1: {test_f1:.4f}")
print(classification_report(test_labels, test_preds, target_names=CLASS_NAMES))
print("Confusion Matrix:\n", confusion_matrix(test_labels, test_preds))

plt.show()

# --- Inference ---------------------------------------------------------------
# These helpers used to live inside a string literal and were therefore never
# defined. They are real functions now; the public signatures are unchanged.


def _load_classifier(model_dir, device):
    """Build the classifier and load ``model.pt`` from ``model_dir`` in eval mode."""
    # NOTE: the classifier constructor requests pretrained backbone weights by
    # default; they are overwritten by the checkpoint loaded just below.
    model = ProductCountClassifier('efficientnet_b0', num_classes=3).to(device)
    state_dict = torch.load(os.path.join(model_dir, 'model.pt'), map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    return model


def _predict_probs(model, img_path, device):
    """Return the softmax probabilities (NumPy vector) for a single image file."""
    # Same preprocessing as the default transform of ProductCountDataset.
    tfm = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # Context manager closes the image file as soon as it has been decoded.
    with Image.open(img_path) as img:
        x = tfm(img.convert('RGB')).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(x)
    return torch.softmax(logits, dim=1).cpu().numpy()[0]


def is_multi(model_dir, img_path):
    """Classify one image with the model stored in ``model_dir``.

    Args:
        model_dir: Directory containing ``model.pt`` (a saved state dict).
        img_path: Path of the image to classify.

    Returns:
        Dict mapping each name in ``CLASS_NAMES`` to its probability (float).
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = _load_classifier(model_dir, device)
    probs = _predict_probs(model, img_path, device)
    return {name: float(probs[i]) for i, name in enumerate(CLASS_NAMES)}


def is_multi_packimg(model_dir, img_pathes):
    """Classify a set of images and average their class probabilities.

    The model is loaded once and reused for every image instead of being
    rebuilt and reloaded per image.

    Args:
        model_dir: Directory containing ``model.pt`` (a saved state dict).
        img_pathes: Iterable of image paths; must contain at least one path.

    Returns:
        Dict mapping each name in ``CLASS_NAMES`` to the mean probability
        over all given images.

    Raises:
        ValueError: If ``img_pathes`` is empty (the mean would otherwise be a
            division by zero producing NaN probabilities).
    """
    img_pathes = list(img_pathes)
    if not img_pathes:
        raise ValueError("img_pathes must contain at least one image path")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = _load_classifier(model_dir, device)
    per_image = np.stack([_predict_probs(model, path, device) for path in img_pathes])
    probs_avg = per_image.mean(axis=0)
    return {name: float(probs_avg[i]) for i, name in enumerate(CLASS_NAMES)}