In [2]:
# ConvNeXtV2 + F1 criterion (soft-F1 loss) + augmentations

import torch
import torch.nn as nn
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.metrics import f1_score
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import os

# ==== Augmentations ====
def get_transforms(train=True, image_size=224):
    """Build the albumentations pipeline for one data split.

    Args:
        train: If True (the default, identical to the previous behaviour),
            return the augmenting training pipeline. If False, return a
            deterministic pipeline (resize + normalize + tensor conversion)
            for validation/test, so evaluation is not run on randomly
            flipped, shifted or cut-out images.
        image_size: Side length in pixels of the square resized image.

    Returns:
        An ``A.Compose`` called as ``pipeline(image=array)``; the transformed
        image is under the ``"image"`` key as the tensor produced by
        ``ToTensorV2``.
    """
    resize = A.Resize(image_size, image_size)
    # Normalize() uses albumentations' default (ImageNet) mean/std.
    finalize = [A.Normalize(), ToTensorV2()]

    if not train:
        return A.Compose([resize, *finalize])

    return A.Compose([
        resize,
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=15, p=0.5),
        A.RandomBrightnessContrast(p=0.5),
        # NOTE(review): max_holes/max_height/max_width is the older
        # albumentations API; newer releases use num_holes_range /
        # hole_height_range / hole_width_range instead — confirm against the
        # installed version.
        A.CoarseDropout(max_holes=8, max_height=16, max_width=16, p=0.5),
        *finalize,
    ])

# ==== Dataset ====
class CustomDataset(Dataset):
    """Map-style dataset pairing image files with their labels.

    Item ``idx`` is ``(transform(image=rgb_array)["image"], labels[idx])``,
    where ``rgb_array`` is the file at ``image_paths[idx]`` converted to RGB
    and turned into a numpy array.
    """

    def __init__(self, image_paths, labels, transform):
        # Parallel sequences: image_paths[i] is labelled labels[i].
        self.image_paths = image_paths
        self.labels = labels
        # Callable used as transform(image=array) and returning a dict with
        # an "image" entry (e.g. an albumentations Compose).
        self.transform = transform

    def __len__(self):
        # Length comes from the path list; labels is assumed to be the same length.
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Use a context manager so the underlying file handle is released
        # deterministically (important with many DataLoader workers and for
        # lazily-loaded / multi-frame files).
        with Image.open(self.image_paths[idx]) as img:
            rgb = np.array(img.convert("RGB"))
        image = self.transform(image=rgb)["image"]
        label = self.labels[idx]
        return image, label

# ==== Model ====
class F1ConvNeXtModel(nn.Module):
    """timm backbone (pooled features) followed by a small MLP classifier head.

    Args:
        num_classes: Number of output logits.
        model_name: timm model identifier for the backbone. Defaults to the
            original ConvNeXtV2-Base checkpoint.
        pretrained: Whether timm should load pretrained weights.
        hidden_dim: Width of the hidden layer of the head.
        dropout: Dropout probability applied after the hidden activation.
    """

    def __init__(self, num_classes, model_name="convnextv2_base.fcmae_ft_in1k",
                 pretrained=True, hidden_dim=512, dropout=0.3):
        super().__init__()
        # num_classes=0 makes timm replace its own classifier with an identity,
        # so the backbone outputs pooled features (same as reset_classifier(0)).
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        # Width of those features. Newer timm exposes it as head_hidden_size;
        # older releases only provide num_features. This works for any timm
        # model, unlike head.in_features which only some heads define.
        in_features = getattr(self.backbone, "head_hidden_size", None) or self.backbone.num_features
        self.head = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        """Return class logits (last dimension is ``num_classes``)."""
        features = self.backbone(x)
        return self.head(features)

# ==== F1 Loss ====
class F1Loss(nn.Module):
    """Differentiable ("soft") macro-F1 loss, ``1 - mean_c F1_c``.

    Per-class precision and recall are computed over the batch from softmax
    probabilities rather than hard argmax predictions, so the loss has
    gradients. A class with no true samples in the batch gets F1 = 0, which
    raises the attainable minimum of the loss for that batch.
    NOTE(review): consider masking such classes or using large batches.

    Args:
        eps: Small constant that keeps the divisions finite.
    """

    def __init__(self, eps=1e-7):
        super().__init__()
        self.eps = eps

    def forward(self, logits, labels):
        """Compute the soft macro-F1 loss.

        Args:
            logits: Unnormalized scores, (batch, num_classes).
            labels: Integer class indices, (batch,).

        Returns:
            Scalar tensor ``1 - mean(per-class soft F1)``.
        """
        num_classes = logits.size(1)
        y_pred = torch.softmax(logits, dim=1)
        # Build the one-hot targets directly from `labels` (on their own device)
        # and match the prediction dtype. The previous CPU torch.eye(...)[labels]
        # lookup fails when labels are on CUDA.
        y_true = nn.functional.one_hot(labels.long(), num_classes).to(
            device=y_pred.device, dtype=y_pred.dtype
        )

        tp = (y_true * y_pred).sum(dim=0)
        precision = tp / (y_pred.sum(dim=0) + self.eps)
        recall = tp / (y_true.sum(dim=0) + self.eps)
        f1 = 2 * precision * recall / (precision + recall + self.eps)
        return 1 - f1.mean()

# ==== Train loop ====
def train(model, dataloader, optimizer, criterion, device):
    """Run one training epoch and return the mean per-batch loss.

    Puts ``model`` in training mode and, for each ``(images, labels)`` batch,
    moves both to ``device``, does a forward/backward pass with ``criterion``
    and takes one ``optimizer`` step.

    Returns:
        Sum of ``loss.item()`` over batches divided by ``len(dataloader)``.
        Every batch has equal weight regardless of its size. An empty
        dataloader raises ``ZeroDivisionError``.
    """
    model.train()
    batch_losses = []
    for batch_images, batch_labels in dataloader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(batch_images), batch_labels)
        batch_loss.backward()
        optimizer.step()

        batch_losses.append(batch_loss.item())
    return sum(batch_losses) / len(dataloader)

# ==== Evaluate loop ====
def evaluate(model, dataloader, device):
    """Return the macro-averaged F1 score of ``model`` over ``dataloader``.

    Runs in eval mode without gradient tracking. Predictions are the argmax
    over dimension 1 of the model outputs. ``labels`` are read with
    ``.numpy()``, so they are expected to be CPU tensors. Scoring uses
    ``sklearn.metrics.f1_score(..., average="macro")`` over all batches at once.
    """
    model.eval()
    predicted_classes = []
    true_classes = []
    with torch.no_grad():
        for batch_images, batch_labels in dataloader:
            batch_logits = model(batch_images.to(device))
            predicted_classes.extend(batch_logits.argmax(dim=1).cpu().numpy())
            true_classes.extend(batch_labels.numpy())

    return f1_score(true_classes, predicted_classes, average="macro")


In [8]:
!pip install albumentations

[33mDEPRECATION: Loading egg at /opt/anaconda3/lib/python3.12/site-packages/anyio-3.7.1-py3.12.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330[0m[33m
[0m[33mDEPRECATION: Loading egg at /opt/anaconda3/lib/python3.12/site-packages/dirsearch-0.4.3-py3.12.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330[0m[33m
[0m[33mDEPRECATION: Loading egg at /opt/anaconda3/lib/python3.12/site-packages/requests_ntlm-1.3.0-py3.12.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330[0m[33m
[0m[33mDEPRECATION: Loading egg at /opt/anaconda3/lib/python3.12/site-packages