# Оценка качества изображений

**Постановка задачи**: Фотографии, загружаемые поставщиками WB, имеют разное качество: на одних может быть сложный фон, на других часть объекта не попала в кадр и т. п. Для последующей работы с такими данными, например при использовании алгоритмов поиска по фото, надо знать, какие типы дефектов/особенностей присутствуют на изображении.

<img src ="https://ml.gan4x4.ru/wb/quality/content/samples.png" width="800">


Всего 6 типов особенностей:  

* untidy,
* angle-composition,
* background,
* crop,
* text,
* multiple-objects

и один класс для изображений без дефектов
* good-image.

При этом изображение может содержать несколько видов дефектов.


Задача:

Требуется создать модель, которая будет определять список дефектов для изображения.

# Данные

По [ссылке](https://ml.gan4x4.ru/wb/quality/5000/student_5000.zip) доступен архив содержащий 5000 изображений и разметку.

Оригинальные изображения имели размер 900x1200, в датасете их разрешение уменьшено вдвое. Кроме изображений, в архиве находится csv-файл с разметкой.
В первой колонке — имя файла с изображением (без расширения), в остальных колонках — названия классов, к которым относится изображение:

```
  18715,text,multiple-objects,,
  5259,text,background,,
  8932,background,,,
  ...

```

# Порядок выполнения задания

Задание рекомендуется выполнять по шагам:

1. Познакомьтесь с данными
2. Выберите метрику для оценки результата
3. Проведите анализ состояния вопроса, изучите существующие модели, которые можно использовать для решения задачи
4. Проведите EDA, опишите особенности данных и проблемы, которые они могут за собой повлечь
5. Подготовьте данные для обучения
6. Выберите baseline модель, оцените качество её работы на данном датасете.
7. Попробуйте улучшить значение метрики используя другую модель. Возможно обучив/дообучив ее.
8. Оцените быстродействие выбранной модели
9. Дайте оценку полученному результату.


**Важно!**

Блокнот должен содержать весь необходимый код для запуска финальной модели. Если для запуска требуется подгрузка весов, все ссылки должны работать не только в вашем аккаунте, но и в аккаунте преподавателя.

## Import libraries | Set Configs | Implementation classes

In [None]:
# --------------Libraries --------------- #
import os, gc
import io
from IPython.display import clear_output
from contextlib import redirect_stdout
import time,random
from typing import Optional, Tuple, Union, Callable

import numpy as np
import pandas as pd
import pathlib
import inspect
import matplotlib.pyplot as plt
import cv2
from PIL import Image
from tqdm.notebook import tqdm

from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import accuracy_score, hamming_loss, ConfusionMatrixDisplay
from mlcm import mlcm


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader , random_split
from torch.utils.data.sampler import SubsetRandomSampler
from pytorch_multilabel_balanced_sampler.samplers import LeastSampledClassSampler

from torchvision.transforms import v2
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchvision.datasets import vision
from torchsummary import summary

from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
from torchvision import tv_tensors


import timm
import wandb

from pytorch_lightning import LightningDataModule, LightningModule, Trainer 
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

# Allow reduced-precision float32 matmuls ('medium') in exchange for speed.
torch.set_float32_matmul_precision(precision='medium')
print("Done!")

In [None]:
# ---Configuration class contains training configs---- #
#                                                      #
#                  &  model configs                    #
# ---------------------------------------------------- #
class Config:
    """Training and model hyper-parameters, stored as class attributes.

    Instantiating it only prints a confirmation; the helpers below report CUDA
    availability and seed every random number generator used in the notebook.
    """

    # --- optimisation ---
    BATCH_SIZE = 52
    SEED = 42
    LEARNING_RATE = 0.00878938382303729
    EPS = 0.005663097332083254
    WD = 0.08157514864673669

    # --- descriptive metadata ---
    TYPE = "Multi-label Image Classification"
    DATA_SOURCE = "multi-label-image-classification-dataset"
    MODEL_NAME = 'resnetv2_50'
    CRITERION_ = "Binary Cross Entropy"
    OPTIMIZER_ = "AdamW"
    DATA_TYPE = 'image'

    # --- training loop / data loading ---
    EPOCHS = 100
    NUM_WORKERS = 8

    PROJECT_NAME = 'WB intership'
    TASK_NAME = 'multi-label-classification'

    IMG_SIZE = (440, 440)

    # Per-channel (RGB) normalisation statistics used by v2.Normalize.
    mean = [0.5061, 0.4890, 0.4901]
    std = [0.4247, 0.4200, 0.4184]

    def __init__(self):
        print("configuration set!")

    def check_cuda(self):
        """Print whether torch can see a CUDA device."""
        print("Scanning for CUDA")
        message = ("GPU is available , training will be accelerated! : )\n"
                   if torch.cuda.is_available()
                   else "NO GPUs found : / \n")
        print(message)

    def seed_everything(self):
        """Seed Python, NumPy and torch RNGs and force deterministic cuDNN."""
        print("Seeding...")
        seed = self.SEED
        os.environ['PYTHONHASHSEED'] = str(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        print("Seeded everything!")
            print("Seeded everything!")

# Random augmentations. v2.RandomApply applies the listed transforms together as one
# group; no `p` is passed, so RandomApply's default probability is used.
Config.train_augmentations = v2.RandomApply(transforms=[
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomVerticalFlip(p=0.5),
    # NOTE(review): hue=0.5 permits very large hue shifts -- confirm this is intended.
    v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
    v2.RandomPerspective(distortion_scale=0.4, p=0.5),
    v2.RandomPosterize(bits=2),
])

# Training pipeline: augment, then resize and normalise with the dataset statistics.
Config.train_transforms = v2.Compose(transforms=[
    Config.train_augmentations,
    v2.Resize(size=Config.IMG_SIZE),
    v2.Normalize(mean=Config.mean, std=Config.std),
])

# Evaluation pipeline: deterministic resize + normalise only.
Config.test_transforms = v2.Compose(transforms=[
    v2.Resize(size=Config.IMG_SIZE),
    v2.Normalize(mean=Config.mean, std=Config.std),
])

CFG = Config()


CFG.check_cuda()
CFG.seed_everything()
# Plain dict of CFG's settings (class attributes included). Names starting with "_"
# are skipped: inspect.getmembers also returns __class__, __dict__, __doc__,
# __module__, __weakref__, ... which are not configuration values.
config = {
    name: value
    for name, value in inspect.getmembers(CFG, lambda a: not inspect.isroutine(a))
    if not name.startswith('_')
}

In [None]:
class Wildberries5000(vision.VisionDataset):
    """`Wildberries products <https://ml.gan4x4.ru/wb/quality/5000/student_5000.zip>`_ Dataset.

    The object always exposes *all* readable rows of the CSV: ``__len__`` and
    ``__getitem__`` work on positions in ``self.prepared``. The train/test split is
    only materialised as ``self.data`` (a ``Subset`` whose ``indices`` are those
    positions) and is meant to be consumed through a sampler.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where
            ``student_5000.zip`` exists or will be saved to if download is set to True.
        train (bool, optional): If True, ``self.data`` is the training part of the
            split, otherwise the test part.
        split (float, optional): Fraction of rows assigned to the test part. A falsy
            value disables splitting (``self.data`` is then the whole table).
        transform (callable, optional): Applied to the image tensor returned by
            ``__getitem__``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        seed (int, optional): Seed of the train/test split. Instances created with the
            same ``split`` and ``seed`` get complementary train/test parts. ``None`` uses
            the global torch RNG, so separately created train and test instances may
            overlap.
    """

    url = "https://ml.gan4x4.ru/wb/quality/5000/student_5000.zip"
    filename = "student_5000.zip"
    data_format = '.jpg'

    def __init__(
        self,
        root: Union[str, pathlib.Path] = '',
        train: bool = True,
        split: float = 0.3,
        transform: Optional[Callable] = None,
        download: bool = False,
        seed: Optional[int] = 42,
    ) -> None:

        super().__init__(root, transform=transform)
        self.train_csv_path = pathlib.Path(self.root, "./5000/5000.csv")
        self.train_dir = pathlib.Path(self.root, "./5000/images/")
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        # Column 0: file name without extension; columns 1-4: up to four class names
        # (empty cells stay empty strings because keep_default_na=False).
        raw = pd.read_csv(self.train_csv_path, header=None, usecols=[0, 1, 2, 3, 4],
                          dtype=str, keep_default_na=False)
        index = f'{self.train_dir}/' + raw[0] + Wildberries5000.data_format
        targets = raw[[1, 2, 3, 4]].apply(' '.join, axis=1).str.strip()

        self.full_dataset: pd.DataFrame = pd.DataFrame({'index': index, 'targets': targets})
        corrupted = self.data_sanity_check()
        # reset_index so that index labels equal row positions: __getitem__ and the
        # split indices are positional, and the one-hot table below is aligned to them.
        self.full_dataset = (self.full_dataset
                             .loc[~self.full_dataset["index"].isin(corrupted)]
                             .reset_index(drop=True))

        # Labels are derived from the *filtered* table, so targets_list and
        # one_hot_targets have exactly one row per usable image, in self.prepared order
        # (they are used together with the split indices by the training sampler).
        self.targets_list = self.full_dataset['targets'].apply(lambda x: x.split())

        self.mlb = MultiLabelBinarizer().fit(self.targets_list)
        self.one_hot_targets = pd.DataFrame(self.mlb.transform(self.targets_list),
                                            columns=self.mlb.classes_)
        self.full_dataset[self.one_hot_targets.columns] = self.one_hot_targets
        # 'good-image' is dropped from the model target: it corresponds to an all-zero vector.
        self.prepared = self.full_dataset.drop(['targets', 'good-image'], axis=1)

        if split:
            split_kwargs = {} if seed is None else {'generator': torch.Generator().manual_seed(seed)}
            train_data, test_data = random_split(self.full_dataset, lengths=[1 - split, split],
                                                 **split_kwargs)
            self.data = train_data if self.train else test_data
        else:
            self.data = self.full_dataset

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
            Args:
                idx (int): Position of the row in ``self.prepared``.

            Returns:
                tuple: (image, target) where image is a float tensor scaled to [0, 1] and
                target is the multi-hot vector of every class except 'good-image'.
        """
        img = cv2.imread(self.prepared.iloc[idx, 0])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img / 255.
        # NOTE(review): permute([2, 1, 0]) turns (H, W, C) into (C, W, H), i.e. the image
        # is transposed, not plain CHW. test_image() applies the same permutation and
        # models were trained with it, so change both to permute(2, 0, 1) together.
        img = torch.Tensor(img).permute([2, 1, 0])

        label = self.prepared.iloc[idx, 1:].astype(np.int8).values
        label = torch.Tensor(label)

        if self.transform:
            return self.transform(img), label
        return img, label

    def __len__(self) -> int:
        """Number of readable rows (the whole table, independent of the split)."""
        return len(self.prepared)

    def _check_integrity(self) -> bool:
        return os.path.isfile(pathlib.Path(self.root, Wildberries5000.filename))

    def download(self) -> None:
        """Download ``student_5000.zip`` into ``root`` and extract it there (skip if present)."""
        if self._check_integrity():
            print("Files already downloaded")
            return

        import subprocess  # only needed when a download actually happens

        root = pathlib.Path(self.root)
        root.mkdir(parents=True, exist_ok=True)
        archive = root / self.filename
        try:
            # `-O` sets wget's output file (`-o` would only write its log there).
            subprocess.run(['wget', Wildberries5000.url, '-O', str(archive)], check=True)
            subprocess.run(['unzip', str(archive), '-d', str(root)], check=True)
        except (OSError, subprocess.CalledProcessError):
            # _check_integrity() only tests that the archive exists, so a partial
            # file must not be left behind or the next run would skip the download.
            archive.unlink(missing_ok=True)
            raise

    def extra_repr(self) -> str:
        split = "Train" if self.train is True else "Test"
        return f"Split: {split}"

    def data_sanity_check(self):
        """
            Return the image paths of ``self.full_dataset`` that cv2 cannot read
            (missing or corrupted files), so they can be dropped before training
            instead of failing inside a DataLoader. Prints the elapsed time and the
            list of bad paths.
        """
        bad_paths = []
        start = time.time()
        for path in self.full_dataset.iloc[:, 0]:
            try:
                readable = cv2.imread(path) is not None
            except (cv2.error, TypeError, ValueError):
                readable = False
            if not readable:
                bad_paths.append(path)
        end = time.time()
        print(end - start)
        _ = gc.collect()
        print(bad_paths)
        return bad_paths

In [None]:
class WBData(LightningDataModule):
    """Lightning data module wiring a Wildberries5000-style dataset class to samplers.

    With a truthy ``split``:
      * ``train`` -- dataset instance (train transform); its ``data.indices`` are the
        training rows, drawn by ``LeastSampledClassSampler``;
      * ``val`` / ``test`` -- ``Subset``s of a second instance (validation transform)
        that partition the rows *not* used for training.
    With a falsy ``split`` (inference mode) only ``test`` is built, covering every row.
    """

    def __init__(self, dataset_class: torch.utils.data.Dataset,
                 batch_size=CFG.BATCH_SIZE,
                 split=0.3,  # None - is inference mode
                 train_transform=None, val_transform=None,
                 seed=CFG.SEED):
        super().__init__()

        params = dict(root=pathlib.Path(os.getcwd()) / "content",
                      download=True)

        if split:
            self.train = dataset_class(train=True, transform=train_transform, split=split, **params)
            train_indices = list(self.train.data.indices)
            self.train_sampler = LeastSampledClassSampler(
                labels=torch.tensor(self.train.one_hot_targets.values, dtype=torch.int32),
                indices=train_indices)

            dev = dataset_class(train=False, transform=val_transform, split=split, **params)
            # Both instances index the same rows, so the held-out rows are every position
            # not used for training. Deriving them from self.train (rather than splitting
            # `dev`, whose length is the whole table) keeps val/test disjoint from train
            # even if the two instances drew different random splits.
            train_set = set(train_indices)
            held_out = [i for i in range(len(dev)) if i not in train_set]
            generator = torch.Generator().manual_seed(seed)
            val_part, test_part = random_split(held_out, lengths=[1 - split, split],
                                               generator=generator)
            self.val = torch.utils.data.Subset(dev, [held_out[i] for i in val_part.indices])
            self.test = torch.utils.data.Subset(dev, [held_out[i] for i in test_part.indices])
            self.val_sampler = SubsetRandomSampler(list(self.val.indices))
            self.test_sampler = SubsetRandomSampler(list(self.test.indices))
        else:
            full = dataset_class(transform=val_transform, split=split, **params)
            # Wrap in a Subset so test_dataloader() (which reads `.dataset`) also works
            # here; indices are positions, matching the dataset's iloc-based __getitem__.
            self.test = torch.utils.data.Subset(full, list(range(len(full))))
            self.test_sampler = SubsetRandomSampler(list(self.test.indices))

        self.batch_size = batch_size

    def train_dataloader(self):
        """Training batches drawn by LeastSampledClassSampler over the training rows."""
        return torch.utils.data.DataLoader(self.train, sampler=self.train_sampler,
                          batch_size=self.batch_size, num_workers=CFG.NUM_WORKERS)

    def val_dataloader(self):
        """Validation batches (random order) over the validation rows."""
        return torch.utils.data.DataLoader(self.val.dataset, sampler=self.val_sampler,
                          batch_size=self.batch_size, num_workers=CFG.NUM_WORKERS)

    def test_dataloader(self):
        """Test samples one at a time (confusion_matrix relies on batch_size=1 output)."""
        return torch.utils.data.DataLoader(self.test.dataset, sampler=self.test_sampler,
                                           batch_size=1, num_workers=CFG.NUM_WORKERS)

## Анализ данных (EDA)

In [None]:
# Data module without transforms, used by the EDA cells below.
df = WBData(dataset_class=Wildberries5000, batch_size=CFG.BATCH_SIZE)

In [None]:
fig, ax = plt.subplots(figsize=(10, 4))
# Number of images carrying each label (one bar per one-hot column).
df.train.one_hot_targets.sum().plot.bar(ax=ax, title='Target Class Distribution')

- 60% составляют изображения без особенностей
- по 20% от всего датасета составляют изображения с фоном или с текстом
- остальные особенности («композиция», «обрезка», «множество объектов» и «неаккуратность») встречаются значительно реже — примерно по 5% изображений каждая

In [None]:
fig, ax = plt.subplots(figsize=(10, 4))
# Histogram of how many labels each image carries.
(df.train.one_hot_targets
    .sum(axis=1)
    .value_counts()
    .plot.bar(ax=ax, title='Distribution of Number of Labels per Image'))

- 88% изображений присвоена 1 метка (72% из которых принадлежат изображениям "без особенностей")
- только 1% изображений имеет по 2 метки
- 1,5% изображений имеют по 3 метки
- 0,08% изображений присвоено 4 метки

In [None]:
def plot(imgs, row_title=None, **imshow_kwargs):
    """Show images in a grid of matplotlib axes.

    Args:
        imgs: a list of images (one row) or a list of lists (several rows). Each item is
            an image accepted by ``v2.functional.to_image`` or an ``(image, target)``
            tuple, where ``target`` is a dict with optional "boxes"/"masks" entries or a
            ``tv_tensors.BoundingBoxes``.
        row_title: optional sequence of y-axis labels, one per row.
        **imshow_kwargs: forwarded to ``Axes.imshow``.

    Raises:
        ValueError: if a tuple item carries a target of an unsupported type.
    """
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0])
    _, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        for col_idx, img in enumerate(row):
            boxes = None
            masks = None
            if isinstance(img, tuple):
                img, target = img
                if isinstance(target, dict):
                    boxes = target.get("boxes")
                    masks = target.get("masks")
                elif isinstance(target, tv_tensors.BoundingBoxes):
                    boxes = target
                else:
                    raise ValueError(f"Unexpected target type: {type(target)}")
            img = v2.functional.to_image(img)
            if img.dtype.is_floating_point and img.min() < 0:
                # Rescale normalized images (e.g. after Normalize()) into [0, 1] for display.
                # Out-of-place, so a caller tensor that to_image() may share memory with is
                # not modified; a constant image is left at 0 instead of dividing by zero.
                img = img - img.min()
                img_max = img.max()
                if img_max > 0:
                    img = img / img_max

            img = v2.functional.to_dtype(img, torch.uint8, scale=True)
            if boxes is not None:
                img = draw_bounding_boxes(img, boxes, colors="yellow", width=3)
            if masks is not None:
                img = draw_segmentation_masks(img, masks.to(torch.bool), colors=["green"] * masks.shape[0], alpha=.65)

            ax = axs[row_idx, col_idx]
            ax.imshow(img.permute(1, 2, 0).numpy(), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()

In [None]:
def plot_img_grid(samples, df, classes):
    """For each class, show up to ``samples`` images, each followed by ``samples``
    random variants from ``CFG.train_augmentations``.

    ``df`` needs the image path in column 0, the space-separated labels in column 1
    and a 0/1 column per class name; the row title lists the image's labels.
    """
    augment = CFG.train_augmentations
    for class_name in classes:
        matches = df.loc[df[class_name] == 1].iloc[:samples]
        for path, tags in zip(matches.iloc[:, 0], matches.iloc[:, 1]):
            original = Image.open(path)
            augmented = [augment(original) for _ in range(samples)]
            plot([original, *augmented], ['\n'.join(tags.split())])


plot_img_grid(4, df.train.full_dataset, classes=df.train.mlb.classes_)

In [None]:
def plot_orig_imgs(rows, df, classes):
    """Grid of original images: one column per class, up to ``rows`` images each.

    ``df`` needs the image path in column 0, the space-separated labels in column 1
    (used as the subplot title) and a 0/1 column per class name.
    """
    n_cols = len(classes)
    fig = plt.figure(figsize=(3.5 * n_cols, 3 * rows))

    for col, class_name in enumerate(classes):
        matches = df.loc[df[class_name] == 1].iloc[:rows]
        for row, (path, tags) in enumerate(zip(matches.iloc[:, 0], matches.iloc[:, 1])):
            rgb = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
            ax = fig.add_subplot(rows, n_cols, row * n_cols + col + 1)
            ax.title.set_text('\n'.join(tags.split()))
            ax.set_xticks([])
            ax.set_yticks([])
            plt.imshow(rgb)
    plt.show()


plot_orig_imgs(4, df.train.full_dataset, classes=df.train.mlb.classes_)

1. angle-composition - включает фотографии позирующих моделей (в т.ч. снятых сзади), либо одежды/обуви, снятой с нетипичного ракурса (сверху, сзади, на ноге)
2. background — фотографии, снятые не на белом фоне, т. е. с естественным фоном комнаты/улицы или с искусственным фоном
3. crop - включает фотографии, где обрезан сам товар (на модели или без нее)
4. good-image — модели стоят прямо перед камерой, фон белый, товар показан полностью (в случае обуви — наполовину: 1 ботинок, 1 босоножка). Среди фотографий, размеченных этим классом, встречаются фото с серым фоном и с позирующими моделями
5. multiple-objects - фотографии с множеством объектов (одного и того же товара в разных ракурсах, одного товара в разной расцветке, на разных моделях), а также с миниатюрой/ами в углу фотографии этого товара
6. text - фотографии, содержащие текст, логотипы, значки
7. untidy - неопрятный товар на фотографии (непроглаженный, неаккуратно уложенный) или снятый не в студийных условиях/при плохом освещении

In [None]:
%%capture --no-display
from upsetplot import UpSet, from_memberships

# Counts of every label combination present in the data.
ax_dict = UpSet(
    from_memberships(df.train.targets_list, data=df.train.full_dataset),
    subset_size="count",
    show_counts=True,
).plot()

Топ-5 самых частых сочетаний меток на изображениях:
- 3165 - good_image
- 508 - text
- 456 - background
- 282 - text+background
- 112 - untidy

## Baseline

### Calculate mean/std

In [None]:
class ImageData(Wildberries5000):
    """Image-only view of Wildberries5000, used by calc_mean_std().

    ``transform`` is required and is called albumentations-style:
    ``transform(image=<RGB numpy array>)['image']``.
    """

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return only the transformed image at position ``idx`` (no label).

        The result is whatever ``transform`` produces -- a tensor when the pipeline
        ends with ``ToTensorV2()``, as in calc_mean_std().
        """
        img_path = self.prepared.iloc[idx, 0]
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = self.transform(image=img)['image']
        return img


In [None]:
def calc_mean_std(image_size, batch_size):
    """Per-channel mean and std of the training images resized to ``image_size``.

    Images go through Resize, then ``A.Normalize`` with mean 0 / std 1 (presumably only
    rescaling pixels to [0, 1] -- confirm), then ToTensorV2.

    Returns:
        ``(mean, std)`` -- float32 tensors of shape (3,).

    Raises:
        ValueError: if the training split contains no images.
    """
    image_dataset = ImageData(transform=A.Compose(
    [
        A.Resize(height=image_size[0], width=image_size[1]),
        A.Normalize(mean=(0, 0, 0), std=(1, 1, 1)),
        ToTensorV2(),
    ]), split=0.3, root=pathlib.Path(os.getcwd()) / "content")

    # Visit each training image once. A class-balancing sampler would over-represent
    # rare classes and bias the statistics away from the real image distribution.
    image_loader = DataLoader(
        dataset=image_dataset,
        batch_size=batch_size,
        sampler=SubsetRandomSampler(list(image_dataset.data.indices)),
        num_workers=CFG.NUM_WORKERS,
        pin_memory=True,
    )

    # float64 accumulators: hundreds of millions of pixel values are summed.
    psum = torch.zeros(3, dtype=torch.float64)
    psum_sq = torch.zeros(3, dtype=torch.float64)
    # Pixels per channel actually visited. len(image_dataset) counts every row of the
    # table, not just the sampled training rows, so it cannot be used here.
    count = 0

    for inputs in tqdm(image_loader):
        inputs = inputs.double()
        psum += inputs.sum(dim=[0, 2, 3])
        psum_sq += (inputs ** 2).sum(dim=[0, 2, 3])
        count += inputs.shape[0] * inputs.shape[2] * inputs.shape[3]

    if count == 0:
        raise ValueError("calc_mean_std: the training split contains no images")

    total_mean = psum / count
    # E[x^2] - E[x]^2 can dip marginally below zero from rounding; clamp before sqrt.
    total_var = (psum_sq / count - total_mean ** 2).clamp_min(0)
    total_std = torch.sqrt(total_var)

    return total_mean.float(), total_std.float()
   

In [None]:
def get_data(batch_size):
    """Print the normalisation statistics in use and build the training data module."""
    print(f"mean: {CFG.mean}")
    print(f"std:  {CFG.std}")

    return WBData(dataset_class=Wildberries5000,
                  batch_size=batch_size,
                  train_transform=CFG.train_transforms,
                  val_transform=CFG.test_transforms)

In [None]:
total_mean, total_std = calc_mean_std(image_size=CFG.IMG_SIZE, batch_size=CFG.BATCH_SIZE)
print(f"mean: {total_mean}")
print(f"std:  {total_std}")
# Values stored in Config.mean / Config.std:
# total_mean, total_std = [0.5061, 0.4890, 0.4901], [0.4247, 0.4200, 0.4184]

In [None]:
df = get_data(batch_size=CFG.BATCH_SIZE)
df.train  # display the training dataset's repr

In [None]:
def class_distribution(df, max_test=3):
    """Print per-class label counts for the first ``max_test`` training batches.

    For each batch it prints one count per label column followed by the number of
    rows with no label at all (the 'good-image' samples), then the spread max - min,
    which shows how evenly the training sampler covers the classes.

    Args:
        df: data module whose ``train_dataloader()`` yields batches where ``batch[1]``
            is a 2-D multi-hot label tensor (rows = samples, columns = classes).
        max_test: number of batches to inspect.
    """
    import itertools

    for batch in itertools.islice(df.train_dataloader(), max_test):
        labels = batch[1]
        print("Label counts per class:")

        # Count all-zero rows directly; taking the first count of unique(row sums)
        # miscounts batches that happen to contain no such row.
        good_images_count = (labels.sum(axis=1) == 0).sum()

        sum_ = list(labels.sum(axis=0))
        sum_.append(good_images_count)
        print(sum_)
        print("Difference between min and max")
        print(f"{max(sum_)} - {min(sum_)}: {max(sum_) - min(sum_)}", end="\n\n")

class_distribution(df=df, max_test=3)

### Model

In [None]:
class MLCNNet(nn.Module):
    """Multi-label classification head on top of a feature-extracting backbone.

    Pipeline: backbone -> LazyLinear(n_features) -> ReLU -> Dropout(p_dropout)
    -> Linear(n_features, n_classes). ``forward`` returns raw logits.

    Args:
        backbone: module producing a feature tensor for the head.
        frozen: if truthy, the backbone's parameters get ``requires_grad=False``.
        n_classes: number of output logits.
        n_features: width of the hidden layer.
        p_dropout: dropout probability before the output layer.
    """

    def __init__(self, backbone, frozen, n_classes, n_features, p_dropout):
        super().__init__()
        self.model = backbone
        if frozen:
            # Only the classification head is trained when the backbone is frozen.
            self.model.requires_grad_(False)

        head = [
            nn.LazyLinear(n_features),  # input width inferred on the first forward pass
            nn.ReLU(),
            nn.Dropout(p_dropout),
            nn.Linear(n_features, n_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        features = self.model(x)
        return self.classifier(features)

In [None]:
class LitMLCNet(LightningModule):
    """Lightning wrapper training a multi-label classifier with BCE-with-logits.

    For each stage ('train', 'val', 'test') it logs ``accuracy/<stage>`` (sklearn
    ``accuracy_score``) and ``hamming_loss/<stage>`` on the rounded sigmoid outputs;
    train and val also log ``loss/<stage>``. ``config`` must provide
    ``LEARNING_RATE``, ``EPS`` and ``WD`` for the AdamW optimizer.
    """

    def __init__(self, model, config, logger=None):
        super().__init__()
        self.model = model
        self.CFG = config
        if logger: self.custom_logger = logger

    def forward(self, X):
        return self.model(X)

    @staticmethod
    def _predict_labels(logits):
        """Threshold at probability 0.5 and move to CPU for sklearn/mlcm."""
        return torch.sigmoid(logits).round().detach().cpu()

    def eval_loss(self, batch, batch_idx, mode, logits=None):
        """Compute and log the batch-mean BCE loss; pass ``logits`` to reuse a forward pass."""
        x, y = batch
        outputs = self.model(x) if logits is None else logits
        loss = F.binary_cross_entropy_with_logits(outputs, y)
        # The loss is already averaged over the batch, so it is logged as-is
        # (dividing by len(y) again would make it depend on the batch size).
        self.log(f"loss/{mode}", loss.detach(), prog_bar=True, on_epoch=True)
        return loss

    def eval_accuracy(self, batch, batch_idx, mode, logits=None):
        """Compute and log sklearn accuracy_score of the thresholded predictions."""
        X, y = batch
        Out = self(X) if logits is None else logits
        accuracy = accuracy_score(y.detach().cpu(), self._predict_labels(Out))
        self.log(f"accuracy/{mode}", accuracy, prog_bar=True, on_epoch=True)
        return accuracy

    def hamming_loss(self, batch, batch_idx, mode, logits=None):
        """Compute and log sklearn's Hamming loss of the thresholded predictions."""
        X, y = batch
        Out = self(X) if logits is None else logits
        # `hamming_loss` below is the module-level sklearn function, not this method.
        loss = hamming_loss(y.detach().cpu(), self._predict_labels(Out))
        self.log(f"hamming_loss/{mode}", loss, prog_bar=True, on_epoch=True)
        return loss

    def _shared_step(self, batch, batch_idx, mode):
        """One forward pass feeding loss, accuracy and Hamming loss.

        Running a single forward per step (instead of one per metric) avoids repeated
        work and avoids updating BatchNorm running statistics several times per batch.
        """
        X, _ = batch
        logits = self(X)
        self.eval_accuracy(batch, batch_idx, mode, logits=logits)
        self.hamming_loss(batch, batch_idx, mode, logits=logits)
        return self.eval_loss(batch, batch_idx, mode, logits=logits)

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, batch_idx, 'train')

    def validation_step(self, batch, batch_idx):
        self._shared_step(batch, batch_idx, 'val')

    def test_step(self, batch, batch_idx):
        X, _ = batch
        logits = self(X)
        self.eval_accuracy(batch, batch_idx, 'test', logits=logits)
        self.hamming_loss(batch, batch_idx, 'test', logits=logits)

    def confusion_matrix(self, trainer, df):
        """Plot raw and normalised multi-label confusion matrices (mlcm) on the test set."""
        preds_labels = trainer.predict(self, df.test_dataloader())

        preds, labels = [], []
        for logits, targets in preds_labels:
            # Keep every sample of the batch, not only the first one.
            preds.extend(self._predict_labels(logits).numpy().tolist())
            labels.extend(targets.detach().cpu().numpy().tolist())

        # Multi-label confusion matrix computed with mlcm:
        conf_mat, normal_conf_mat = mlcm.cm(labels, preds)

        fig = plt.figure(figsize=(20, 20))

        ax1 = fig.add_subplot(2, 2, 1)
        # Model classes (without 'good-image') plus the extra row/column mlcm adds,
        # labelled 'good_image' here.
        matrix_classes = np.delete(df.train.mlb.classes_, np.where(df.train.mlb.classes_ == 'good-image'))
        matrix_classes = np.append(matrix_classes, 'good_image')
        disp = ConfusionMatrixDisplay(confusion_matrix=conf_mat, display_labels=matrix_classes)
        disp.plot(cmap=plt.cm.Blues, xticks_rotation=90, ax=ax1)
        plt.title('Raw confusion Matrix:')

        ax2 = fig.add_subplot(2, 2, 2)
        disp = ConfusionMatrixDisplay(confusion_matrix=normal_conf_mat, display_labels=matrix_classes)
        disp.plot(cmap=plt.cm.Blues, xticks_rotation=90, ax=ax2)
        plt.title('Normalized confusion Matrix (%)')
        plt.show()

        self.statistics(conf_mat, matrix_classes)

    def statistics(self, conf_mat, matrix_classes):
        """Parse the table printed by ``mlcm.stats`` into a DataFrame, log it and print it."""
        f = io.StringIO()
        with redirect_stdout(f):
            bins_conf_matrix = mlcm.stats(conf_mat, print_binary_mat=False)
        out = f.getvalue()

        stats = [row.split() for row in out.split('\n') if len(row) > 0]
        # The last three rows have two-word names; join them into one index label.
        stats[-3:] = [['-'.join(row[:2])] + row[2:] for row in stats[-3:]]
        stats = np.array(stats)
        index = np.concatenate([matrix_classes, stats[-3:, 0]])

        stats = pd.DataFrame(stats[1:, 1:], columns=stats[0][1:], index=index)
        if hasattr(self, 'custom_logger'):
            self.custom_logger.log_table("statistics_table", dataframe=stats)
        print(stats)
        self.confusion_matrix_per_classes(bins_conf_matrix, matrix_classes)

    def confusion_matrix_per_classes(self, bins_conf_matrix, matrix_classes):
        """Plot one binary confusion matrix per class."""
        fig = plt.figure(figsize=(20, 20))
        for idx, matrix in enumerate(bins_conf_matrix, 1):
            ax = fig.add_subplot(4, 4, idx)
            ax.title.set_text(matrix_classes[idx - 1])
            confusion_matrix = ConfusionMatrixDisplay(confusion_matrix=matrix)
            confusion_matrix.plot(ax=ax, cmap=plt.cm.Blues)
        plt.show()

    def predict_step(self, batch, batch_idx=0, dataloader_idx=0):
        """Return ``(logits, targets)`` so confusion_matrix() can compare them."""
        X, y = batch
        Out = self(X)
        return Out, y

    def configure_optimizers(self):
        optim = torch.optim.AdamW(self.parameters(), lr=self.CFG.LEARNING_RATE,
                                  eps=self.CFG.EPS, weight_decay=self.CFG.WD)
        return optim
    
timm.list_models(filter="resnet*")  # available ResNet architectures in timm

In [None]:
def _config_as_dict(config):
    """Turn ``config`` into a plain dict that ``wandb.init`` can record.

    ``Config`` keeps its settings as class attributes, so the instance ``__dict__`` of
    ``CFG`` is empty. Public, non-callable attributes are therefore collected
    explicitly. Values that are not plain scalars/sequences (e.g. transform objects)
    are stored as text. Dicts and ``None`` are returned unchanged.
    """
    if config is None or isinstance(config, dict):
        return config
    plain = (bool, int, float, str, list, tuple, type(None))
    return {
        name: value if isinstance(value, plain) else str(value)
        for name, value in inspect.getmembers(config, lambda a: not inspect.isroutine(a))
        if not name.startswith('_')
    }


def train(config=None, pl_Model=None):
    """Fit ``pl_Model``, log to W&B, store its weights as an artifact, then test it.

    Args:
        config: run configuration for ``wandb.init`` -- a dict or an object such as
            ``CFG`` whose public, non-callable attributes are recorded.
        pl_Model: the LightningModule (``LitMLCNet``) to train. Required.

    Returns:
        ``(trainer, pl_Model, logger)``.

    Raises:
        ValueError: if ``pl_Model`` is not given.
    """
    if pl_Model is None:
        raise ValueError("train() needs a LightningModule via the `pl_Model` argument")
    with wandb.init(config=_config_as_dict(config)) as run:
        logger = WandbLogger(project=CFG.PROJECT_NAME)
        pl_Model.custom_logger = logger
        trainer = Trainer(
            max_epochs=CFG.EPOCHS,
            accelerator='gpu',
            devices=1,
            # Stop on the validation loss: the training loss keeps falling while the
            # model overfits, so it rarely triggers early stopping.
            callbacks=[EarlyStopping(monitor="loss/val", mode="min", patience=10)],
            log_every_n_steps=5,
            logger=logger
        )
        df = get_data(CFG.BATCH_SIZE)
        trainer.fit(model=pl_Model,
                    train_dataloaders=df.train_dataloader(),
                    val_dataloaders=df.val_dataloader())
        os.makedirs("./weights/", exist_ok=True)
        torch.save(pl_Model.model.state_dict(), 'weights/model.pth')
        artifact = wandb.Artifact('model', type='model')
        artifact.add_file('weights/model.pth')
        run.log_artifact(artifact)

        trainer.test(pl_Model, df.test_dataloader())
        pl_Model.confusion_matrix(trainer, df)
    return trainer, pl_Model, logger

In [None]:
timm.list_models(filter='*mobilenet*', pretrained=True)  # MobileNet variants with pretrained weights

In [None]:
CFG.MODEL_NAME = 'resnetv2_50'
# Backbone without a classification head, used as the feature extractor of MLCNNet.
backbone = timm.create_model(CFG.MODEL_NAME, pretrained=True, num_classes=0)
model = MLCNNet(backbone=backbone, frozen=True, n_classes=6,
                n_features=512, p_dropout=0.5)
pl_Model = LitMLCNet(model=model, config=CFG)
pl_Model

In [None]:
# Module.to() returns the module itself, so it can be passed straight to summary().
summary(pl_Model.to('cuda'), (3, *CFG.IMG_SIZE))

In [None]:
trainer, pl_Model, logger = train(config=CFG, pl_Model=pl_Model)

In [None]:
##############---OPTIMIZATION-HP---##############
# Bayesian W&B sweep over optimiser / architecture hyper-parameters, maximising accuracy/val.
# !export WANDB_NOTEBOOK_NAME='EX2_Quality.ipynb'

sweep_config = {
    'method': 'bayes',
    'name': 'optimize',
    'project':CFG.PROJECT_NAME,
    'metric': {
        'name': 'accuracy/val',
        'goal': 'maximize'
    },
    'parameters': {
        'epochs': {
            'value': 5
            },
        'learning_rate': {
            # a flat distribution between 0 and 0.01
            'distribution': 'uniform',
            'min': 0,
            'max': 0.01
            },
        'eps': {
            # a flat distribution between 0 and 0.01
            'distribution': 'uniform',
            'min': 0,
            'max': 1e-2
            },
        'wd': {
            # a flat distribution between 0 and 0.1
            'distribution': 'uniform',
            'min': 0,
            'max': 0.1
            },
        'batch_size': {
            # values between 32 and 256, quantized with q=8,
            # # with evenly-distributed logarithms 
            'distribution': 'q_log_uniform_values',
            'q': 8,
            'min': 32,
            'max': 256,
            },
        'image_size': {
            # values between 32 and 512, quantized with q=8,
            # # with evenly-distributed logarithms 
            'distribution': 'q_log_uniform_values',
            'q': 8,
            'min': 32,
            'max': 512,
            },
        'fc_layer_size': {
            'values': [128, 256, 512]
            },
        'dropout': {
            'values': [0.3, 0.4, 0.5, 0.6]
            },
    }
}

sweep_id = wandb.sweep(sweep_config)
# NOTE(review): the agent calls train() with no arguments, but train() requires a
# `pl_Model` and fails without one. None of the parameters above are read inside
# train() either, so as written this sweep cannot tune anything.
# TODO: build the model/optimiser from wandb.config inside the sweep function.
wandb.agent(sweep_id, function=train, count=150)

## Метрики

## Решение

In [None]:
CFG.MODEL_NAME = "mobilenetv2_050.lamb_in1k"
# Backbone without a classification head; here it is fine-tuned (frozen=False).
backbone = timm.create_model(CFG.MODEL_NAME, pretrained=True, num_classes=0)
model = MLCNNet(backbone=backbone, frozen=False, n_classes=6,
                n_features=512, p_dropout=0.5)
pl_Model = LitMLCNet(model=model, config=CFG)
pl_Model


In [None]:
# Module.to() returns the module itself, so it can be passed straight to summary().
summary(pl_Model.to('cuda'), (3, *CFG.IMG_SIZE))

In [None]:
trainer, pl_Model, logger = train(config=CFG, pl_Model=pl_Model)

## Оценка результата

In [None]:
def test_image(df, idx):
    """Load row ``idx`` of ``df`` and run the global ``pl_Model`` on it.

    ``df`` is a slice of ``Wildberries5000.prepared``: column 0 is the image path and
    the remaining columns are the multi-hot target.

    Returns:
        ``(img, label, pred)`` -- the RGB image array for display, the true target as a
        list, and the rounded sigmoid prediction as a list.
    """
    img = cv2.imread(df.iloc[idx, 0])
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_prepared = img / 255.
    # NOTE(review): permute([2, 1, 0]) yields (C, W, H), i.e. a transposed image. It
    # mirrors Wildberries5000.__getitem__ so inputs match training; change both together.
    img_prepared = torch.Tensor(img_prepared).permute([2, 1, 0])

    label = df.iloc[idx, 1:].astype(np.int8).values
    label = torch.Tensor(label)
    X = CFG.test_transforms(img_prepared)
    with torch.no_grad():
        # Send the input to wherever the model currently lives (CPU or GPU).
        logits = pl_Model(X.unsqueeze(0).to(pl_Model.device))[0]
    pred = torch.sigmoid(logits.detach().cpu()).round().numpy().tolist()
    return img, label.numpy().tolist(), pred


def print_test(classes, df_test):
    """Plot up to ``len(classes)`` test images per class with true vs predicted labels.

    Column ``c`` shows images whose true labels include ``classes[c]``. The title is
    green when the predicted label set equals the true one, otherwise red showing
    "true --- predicted". An empty label set is displayed as 'good-image'.
    """
    len_classes = len(classes)
    fig = plt.figure(figsize=(len_classes*3, len_classes*4))
    fig.suptitle('true:\npredicted:', fontsize=16)
    for cl_idx, cl in enumerate(classes):
        df_cl = df_test.loc[df_test.loc[:, cl] == 1]
        # A class can have fewer test images than grid rows; plot only what exists.
        for idx in range(min(len_classes, len(df_cl))):
            img, label, pred = test_image(df_cl, idx)

            ax = fig.add_subplot(len_classes, len_classes, idx*len_classes + cl_idx + 1)
            true_labels = list(classes[list(map(bool, label))]) or ['good-image']
            predicted_labels = list(classes[list(map(bool, pred))]) or ['good-image']
            if true_labels == predicted_labels:
                img_title = '\n'.join(true_labels)
                plt.setp(ax.title, color='g')
            else:
                img_title = '\n'.join(true_labels) \
                            + "\n---\n" \
                            + '\n'.join(predicted_labels)
                plt.setp(ax.title, color='r')
            ax.title.set_text(img_title)
            plt.xticks([]) ; plt.yticks([])
            plt.imshow(img)
    plt.tight_layout()
    plt.show()



In [None]:
import wandb
with wandb.init() as run:
    # Fetch the trained weights from the W&B artifact registry
    model_v = 'model:v15'
    artifact = run.use_artifact(f'luzinsan/intership-ex2/{model_v}', type='model')
    artifact_dir = artifact.download()

    # No ImageNet download needed: the strict load_state_dict below overwrites
    # every parameter and buffer of the backbone and the head.
    backbone = timm.create_model("resnetv2_50", pretrained=False, num_classes=0)
    Model = MLCNNet(backbone, False, 6, 512, 0.5)
    logger = WandbLogger(project=CFG.PROJECT_NAME)
    pl_Model = LitMLCNet(Model, CFG, logger)
    # Load from the directory download() actually returned (not a hard-coded path),
    # mapping tensors to CPU so a GPU-saved checkpoint also loads without CUDA.
    pl_Model.model.load_state_dict(torch.load(f'{artifact_dir}/model.pth', map_location='cpu'))
    pl_Model.eval()

    trainer = Trainer(
                accelerator='gpu',
                devices=1,
                log_every_n_steps=5,
                logger=logger)
    df = get_data(CFG.BATCH_SIZE)
    clear_output()

    print("Test Size: ", len(df.test))
    trainer.test(pl_Model, df.test_dataloader())
    pl_Model.confusion_matrix(trainer, df)
    # Re-enter eval mode before the visual check: test/predict may leave the
    # module in train mode depending on the Lightning version.
    pl_Model.eval()
    classes = np.delete(df.train.mlb.classes_, np.where(df.train.mlb.classes_ == 'good-image'))
    df_test = df.test.dataset.prepared.iloc[df.test.indices]
    print_test(classes, df_test)

## Вывод

- 

# Тестовый блок для проверки

Поместите сюда весь необходимый код для тестирования вашей модели на новых данных. Убедитесь, что:

- Импортируются все библиотеки и классы
- Подгружаются веса с внешних ресурсов
- Происходит расчёт метрик
...

In [None]:
##################---LIBRARIES---##################
from torchvision.datasets import vision
from typing import Union, Optional, Callable, Tuple
import pathlib
import torch
from torch import nn
from lightning import LightningDataModule
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import accuracy_score, hamming_loss, ConfusionMatrixDisplay
from mlcm import mlcm
import matplotlib.pyplot as plt

from torch.utils.data.sampler import SubsetRandomSampler
import lightning as pl
from torchvision.transforms import v2
import cv2
import os, gc, io
import pandas as pd
import numpy as np
from IPython.display import clear_output
from contextlib import redirect_stdout

import wandb
import timm, time

# Authenticate with Weights & Biases (needed below to download the model
# artifact). anonymous='allow' permits an anonymous session when no user is
# logged in, so the notebook can presumably run outside the author's account.
wandb.login(anonymous='allow')

In [None]:
##################---DATA---##################
class Wildberries5000(vision.VisionDataset):
    """`Wildberries products <https://ml.gan4x4.ru/wb/quality/5000/student_5000.zip>`_ Dataset.

    Each item is ``(image, label)``. ``label`` is a float multi-hot vector over the
    defect classes. The ``good-image`` column is dropped, so an all-zero vector
    corresponds to an image without defects. Rows whose image OpenCV cannot read
    are removed at construction time.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where the archive
            ``student_5000.zip`` exists or will be saved to if download is set to True.
            The archive is extracted into ``root`` (data under ``root/5000/``).
        transform (callable, optional): A function/transform applied to the image
            tensor (float32, values in ``[0, 1]``) that returns a transformed
            version, e.g. ``v2.Resize``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: if the extracted CSV file or image directory is missing.
    """

    url = "https://ml.gan4x4.ru/wb/quality/5000/student_5000.zip"
    filename = "student_5000.zip"
    data_format = '.jpg'

    def __init__(
        self,
        root: Union[str, pathlib.Path] = '',
        transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:

        super().__init__(root, transform=transform)
        self.csv_path = pathlib.Path(self.root, "./5000/5000.csv")
        self.dir = pathlib.Path(self.root, "./5000/images/")

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        # Column 0: image id (read as str) -> full image path
        index = (f'{self.dir}/'
                 + pd.read_csv(self.csv_path, header=None, usecols=[0], dtype=str)
                 + Wildberries5000.data_format)[0]

        # Columns 1-4: class names; empty cells are read as '' (not NaN) so they join cleanly
        targets = pd.read_csv(self.csv_path, header=None, usecols=[1, 2, 3, 4], keep_default_na=False)
        targets = targets.apply(' '.join, axis=1)
        targets = targets.str.strip()

        self.full_dataset: pd.DataFrame = pd.DataFrame({'index': index, 'targets': targets})
        # Drop unreadable images. The surviving rows keep their original index labels,
        # so the one-hot columns assigned below align with them by label.
        self.full_dataset = self.full_dataset.loc[~self.full_dataset["index"].isin(self.data_sanity_check())]

        self.targets_list = targets.apply(lambda x: x.split())

        self.mlb = MultiLabelBinarizer().fit(self.targets_list)
        self.one_hot_targets = pd.DataFrame(self.mlb.transform(self.targets_list), columns=self.mlb.classes_)
        self.full_dataset[self.one_hot_targets.columns] = self.one_hot_targets
        # Model-facing table: [image path, one 0/1 column per defect class]
        self.prepared = self.full_dataset.drop(['targets', 'good-image'], axis=1)

        self.data = self.full_dataset

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(image, label)`` for the row at *position* ``idx`` of ``prepared``."""
        img = cv2.imread(self.prepared.iloc[idx, 0])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img / 255.
        # OpenCV gives HWC; permute([2, 1, 0]) yields (C, W, H), i.e. H and W are
        # swapped relative to CHW. NOTE(review): presumably kept so inputs match the
        # layout the published weights were trained on — confirm before changing.
        img = torch.Tensor(img).permute([2, 1, 0])

        label = torch.Tensor(self.prepared.iloc[idx, 1:].astype(np.int8).values)

        if self.transform:
            return self.transform(img), label
        return img, label

    def __len__(self) -> int:
        return len(self.prepared)

    def _check_integrity(self) -> bool:
        """True when the extracted annotation CSV and image directory both exist."""
        return self.csv_path.is_file() and self.dir.is_dir()

    def download(self) -> None:
        """Download ``filename`` from ``url`` into ``root`` and extract it there.

        Does nothing when the extracted data is already present. An archive that
        already exists in ``root`` is reused instead of being downloaded again.

        Raises:
            subprocess.CalledProcessError: if ``wget`` or ``unzip`` exits with an error.
        """
        import subprocess  # only needed when actually downloading

        if self._check_integrity():
            print("Files already downloaded")
            return

        root = pathlib.Path(self.root)
        root.mkdir(parents=True, exist_ok=True)
        archive = root / self.filename
        if not archive.is_file():
            try:
                # `-O` sets the output file; `-o` would only redirect wget's log there.
                subprocess.run(["wget", self.url, "-O", str(archive)], check=True)
            except subprocess.CalledProcessError:
                # Don't leave a partial archive that a later call would try to unzip
                archive.unlink(missing_ok=True)
                raise
        # Extract the archive we just saved (not a same-named file in the cwd);
        # `-o` overwrites leftovers of an interrupted extraction without prompting.
        subprocess.run(["unzip", "-o", str(archive), "-d", str(root)], check=True)

    def extra_repr(self) -> str:
        return "Split: Test"

    def data_sanity_check(self):
        """Return the paths in ``full_dataset`` whose image OpenCV cannot read.

        ``cv2.imread`` returns ``None`` (it does not raise) for missing, unreadable
        or undecodable files, so that is the condition tested. Also prints the
        elapsed time in seconds and the list of bad paths.
        """
        start = time.time()
        bad_paths = [path for path in self.full_dataset.iloc[:, 0] if cv2.imread(path) is None]
        print(time.time() - start)
        gc.collect()
        print(bad_paths)
        return bad_paths
    
class WBData(LightningDataModule):
    """Lightning data module exposing the Wildberries test set.

    Args:
        dataset_class: dataset *class* (e.g. ``Wildberries5000``), instantiated with
            ``root=<cwd>/content``, ``download=True`` and ``transform=val_transform``.
        batch_size: stored as ``self.batch_size``. NOTE: ``test_dataloader`` always
            uses ``batch_size=1``.
        val_transform: transform applied to each image tensor.
    """

    def __init__(self, dataset_class: type,
                 batch_size=52,
                 val_transform=None):
        super().__init__()

        params = dict(root=pathlib.Path(os.getcwd()) / "content",
                      download=True)

        self.test = dataset_class(transform=val_transform, **params)
        # The sampler must yield *positions* because __getitem__ uses .iloc.
        # full_dataset.index holds row labels. Once unreadable images are filtered
        # out, those labels have gaps and run past len(self.test) (IndexError).
        self.test_sampler = SubsetRandomSampler(list(range(len(self.test))))
        self.batch_size = batch_size

    def test_dataloader(self):
        """Return a loader over the test set: random order, batch_size=1, 8 workers.

        NOTE(review): batch_size is fixed at 1 rather than ``self.batch_size``.
        Code that reads only the first sample of each batch (``item[0][0]``)
        depends on this, so confirm before changing it.
        """
        return torch.utils.data.DataLoader(self.test, sampler=self.test_sampler,
                                           batch_size=1, num_workers=8)

In [None]:
##################---MODEL---##################
class MLCNNet(nn.Module):
    """Backbone followed by an MLP head that emits one raw logit per class.

    Args:
        backbone: feature extractor; its output is fed directly to the head.
        frozen: accepted for interface compatibility. NOTE(review): it is not used
            by this implementation, so the backbone is *not* frozen here.
        n_classes: number of output logits.
        n_features: width of the hidden layer of the head.
        p_dropout: dropout probability applied after the hidden ReLU.
    """

    def __init__(self, backbone, frozen, n_classes, n_features, p_dropout):
        super().__init__()
        self.model = backbone
        # Attribute names and layer order define the state_dict keys
        # (classifier.0.* and classifier.3.*) that saved checkpoints rely on.
        head_layers = [
            nn.LazyLinear(n_features),  # input width is inferred on the first forward
            nn.ReLU(),
            nn.Dropout(p_dropout),
            nn.Linear(n_features, n_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return raw logits (no sigmoid); the last dimension has size ``n_classes``."""
        features = self.model(x)
        return self.classifier(features)
    

class LitMLCNet(pl.LightningModule):
    """Lightning wrapper used to evaluate a multi-label classifier.

    Predictions are ``sigmoid(logits).round()``: probability > 0.5 gives 1, and
    exactly 0.5 rounds to 0. Provides test metrics (scikit-learn subset accuracy
    and Hamming loss) and multi-label confusion-matrix reports via ``mlcm``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, X):
        """Return the wrapped model's raw logits for ``X``."""
        return self.model(X)

    @staticmethod
    def _binarize(logits):
        """Map logits to 0./1. predictions by rounding the sigmoid probabilities."""
        return torch.sigmoid(logits).round()

    def _targets_and_predictions(self, batch):
        """Run ONE forward pass on ``batch`` and return ``(y_true, y_pred)`` on CPU."""
        X, y = batch
        y_hat = self._binarize(self(X))
        return y.detach().cpu(), y_hat.detach().cpu()

    def _log_accuracy(self, y_true, y_pred, mode):
        # On multi-hot targets sklearn's accuracy_score is subset (exact-match) accuracy
        accuracy = accuracy_score(y_true, y_pred)
        self.log(f"accuracy/{mode}", accuracy, prog_bar=True, on_epoch=True)
        return accuracy

    def _log_hamming_loss(self, y_true, y_pred, mode):
        # This `hamming_loss` is sklearn.metrics.hamming_loss: a method body does not
        # see class attributes, so the method of the same name does not shadow it.
        loss = hamming_loss(y_true, y_pred)
        self.log(f"hamming_loss/{mode}", loss, prog_bar=True, on_epoch=True)
        return loss

    def eval_accuracy(self, batch, batch_idx, mode):
        """Compute, log as ``accuracy/<mode>`` and return subset accuracy on ``batch``."""
        return self._log_accuracy(*self._targets_and_predictions(batch), mode)

    def hamming_loss(self, batch, batch_idx, mode):
        """Compute, log as ``hamming_loss/<mode>`` and return the Hamming loss on ``batch``."""
        return self._log_hamming_loss(*self._targets_and_predictions(batch), mode)

    def test_step(self, batch, batch_idx):
        """Log test accuracy and Hamming loss, sharing a single forward pass."""
        y_true, y_pred = self._targets_and_predictions(batch)
        self._log_accuracy(y_true, y_pred, 'test')
        self._log_hamming_loss(y_true, y_pred, 'test')

    def confusion_matrix(self, trainer, df):
        """Predict on ``df.test_dataloader()`` and plot raw and normalised mlcm matrices.

        Every sample of every predicted batch is used, so any test batch size works.
        Per-class statistics are printed afterwards via ``statistics``.
        """
        batches = trainer.predict(self, df.test_dataloader())

        preds, labels = [], []
        for logits, y in batches:
            preds.extend(self._binarize(logits).detach().cpu().numpy().tolist())
            labels.extend(y.detach().cpu().numpy().tolist())

        # Multi-label confusion matrix via mlcm
        conf_mat, normal_conf_mat = mlcm.cm(labels, preds)

        fig = plt.figure(figsize=(20, 20))

        matrix_classes = np.delete(df.test.mlb.classes_, np.where(df.test.mlb.classes_ == 'good-image'))
        # One extra name for mlcm's additional last row/column (samples with no label),
        # spelled like the class name used everywhere else.
        matrix_classes = np.append(matrix_classes, 'good-image')

        ax1 = fig.add_subplot(2, 2, 1)
        disp = ConfusionMatrixDisplay(confusion_matrix=conf_mat, display_labels=matrix_classes)
        disp.plot(cmap=plt.cm.Blues, xticks_rotation=90, ax=ax1)
        ax1.set_title('Raw confusion Matrix:')

        ax2 = fig.add_subplot(2, 2, 2)
        disp = ConfusionMatrixDisplay(confusion_matrix=normal_conf_mat, display_labels=matrix_classes)
        disp.plot(cmap=plt.cm.Blues, xticks_rotation=90, ax=ax2)
        ax2.set_title('Normalized confusion Matrix (%)')
        plt.show()

        self.statistics(conf_mat, matrix_classes)

    def statistics(self, conf_mat, matrix_classes):
        """Parse the table printed by ``mlcm.stats`` into a DataFrame and report it.

        The table is printed, logged to ``self.custom_logger`` when that attribute
        exists, and followed by the per-class binary confusion matrices.
        """
        buffer = io.StringIO()
        # mlcm.stats prints its table; capture the text so it can be parsed below
        with redirect_stdout(buffer):
            bins_conf_matrix = mlcm.stats(conf_mat, print_binary_mat=False)
        out = buffer.getvalue()

        stats = [row.split() for row in out.split('\n') if len(row) > 0]
        # The last three rows have a two-word name. Join those words with '-' so every
        # row splits into the same number of fields.
        stats[-3:] = [['-'.join(row[:2])] + row[2:] for row in stats[-3:]]
        stats = np.array(stats)
        index = np.concatenate([matrix_classes, stats[-3:, 0]])

        stats = pd.DataFrame(stats[1:, 1:], columns=stats[0][1:], index=index)
        if hasattr(self, 'custom_logger'):
            self.custom_logger.log_table("statistics_table", dataframe=stats)
        print(stats)
        self.confusion_matrix_per_classes(bins_conf_matrix, matrix_classes)

    def confusion_matrix_per_classes(self, bins_conf_matrix, matrix_classes):
        """Plot one binary confusion matrix per class, 4 per row."""
        # At least a 4x4 grid (the original layout); grow rows if there are >16 classes
        n_rows = max(4, -(-len(bins_conf_matrix) // 4))
        fig = plt.figure(figsize=(20, 20))
        for idx, matrix in enumerate(bins_conf_matrix, 1):
            ax = fig.add_subplot(n_rows, 4, idx)
            ax.title.set_text(matrix_classes[idx - 1])
            display = ConfusionMatrixDisplay(confusion_matrix=matrix)
            display.plot(ax=ax, cmap=plt.cm.Blues)
        plt.show()

    def predict_step(self, batch):
        """Return ``(logits, targets)`` so callers can pair predictions with labels."""
        X, y = batch
        Out = self(X)
        return Out, y
    

def test_image(df, idx, test_transforms):
    """Run the global ``pl_Model`` on one row of ``df`` and return data for plotting.

    Args:
        df: DataFrame whose column 0 is an image path and whose remaining columns
            are 0/1 class indicators (the layout of ``Wildberries5000.prepared``).
        idx: positional row index into ``df``.
        test_transforms: callable applied to the image tensor before inference.

    Returns:
        ``(img, label, pred)``: the RGB image as read by OpenCV (for display), the
        true label vector as a list of floats, and the predicted vector
        (``sigmoid(logits).round()``) as a list of floats.
    """
    img = cv2.imread(df.iloc[idx, 0])
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # OpenCV gives HWC; permute([2, 1, 0]) yields (C, W, H), i.e. H and W are
    # swapped relative to CHW. Kept as-is to mirror Wildberries5000.__getitem__.
    img_prepared = torch.Tensor(img / 255.).permute([2, 1, 0])

    label = torch.Tensor(df.iloc[idx, 1:].astype(np.int8).values)
    X = test_transforms(img_prepared)

    # Pure inference: no autograd graph, and feed the input on the model's device.
    device = next(pl_Model.parameters()).device
    with torch.no_grad():
        logits = pl_Model(X.unsqueeze(0).to(device))[0]
    pred = torch.sigmoid(logits.detach().cpu()).round().numpy().tolist()
    return img, label.numpy().tolist(), pred


def print_test(classes, df_test, test_transforms):
    """Plot a grid of sample predictions, one column per class.

    Each column shows up to ``len(classes)`` images whose indicator for that class
    is 1. A title lists the true labels in green when the prediction matches.
    Otherwise it lists the true labels, a ``---`` separator and the predicted
    labels, in red. An all-zero label vector is shown as ``good-image``.

    Args:
        classes: numpy array of defect class names, in the same order as the
            indicator columns of ``df_test``.
        df_test: DataFrame with an image path in column 0 followed by 0/1 columns.
        test_transforms: transform forwarded to ``test_image``.
    """
    n_classes = len(classes)
    fig = plt.figure(figsize=(n_classes * 3, n_classes * 4))
    fig.suptitle('true:\npredicted:', fontsize=16)
    for cl_idx, cl in enumerate(classes):
        df_cl = df_test.loc[df_test.loc[:, cl] == 1]
        # A class can have fewer positive rows than the grid has rows; plot what exists
        # instead of indexing past the end of df_cl.
        for row in range(min(n_classes, len(df_cl))):
            img, label, pred = test_image(df_cl, row, test_transforms)

            ax = fig.add_subplot(n_classes, n_classes, row * n_classes + cl_idx + 1)
            true_labels = list(classes[np.asarray(label, dtype=bool)]) or ['good-image']
            predicted_labels = list(classes[np.asarray(pred, dtype=bool)]) or ['good-image']
            if true_labels == predicted_labels:
                img_title, color = '\n'.join(true_labels), 'g'
            else:
                img_title = '\n'.join(true_labels) + "\n---\n" + '\n'.join(predicted_labels)
                color = 'r'
            ax.set_title(img_title, color=color)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.imshow(img)
    plt.tight_layout()
    plt.show()

In [None]:
with wandb.init() as run:
    # Test data source. Re-assigning these class attributes is where a different
    # archive with the same layout would be plugged in.
    Wildberries5000.url = "https://ml.gan4x4.ru/wb/quality/5000/student_5000.zip"
    Wildberries5000.filename = "student_5000.zip"
    Wildberries5000.data_format = '.jpg'

    test_transforms = v2.Compose([
            v2.Resize((224,224)),
            v2.Normalize([0.5061, 0.4890, 0.4901],
                         [0.4247, 0.4200, 0.4184]),
        ])
    df = WBData(Wildberries5000, 32,
                val_transform=test_transforms)

    # Fetch the trained weights from the W&B artifact registry
    model_v = 'model:v15'
    artifact = run.use_artifact(f'luzinsan/intership-ex2/{model_v}', type='model')
    artifact_dir = artifact.download()

    # Build the architecture; the strict load_state_dict below supplies every weight
    backbone = timm.create_model("resnetv2_50", num_classes=0)
    Model = MLCNNet(backbone, False, 6, 512, 0.5)
    pl_Model = LitMLCNet(Model)
    # Load from the directory download() actually returned (not a hard-coded path),
    # mapping tensors to CPU so a GPU-saved checkpoint also loads without CUDA.
    pl_Model.model.load_state_dict(torch.load(f'{artifact_dir}/model.pth', map_location='cpu'))
    pl_Model.eval()
    trainer = pl.Trainer()
    clear_output()

    # Accuracy and Hamming loss on the test set
    print("Test Size: ", len(df.test))
    trainer.test(pl_Model, df.test_dataloader())
    # Confusion matrices and per-class statistics
    pl_Model.confusion_matrix(trainer, df)

    # Visual check on sample images. Re-enter eval mode first: depending on the
    # Lightning version, test/predict may leave the module in train mode, which
    # would enable dropout for the predictions below.
    pl_Model.eval()
    classes = np.delete(df.test.mlb.classes_, np.where(df.test.mlb.classes_ == 'good-image'))
    print_test(classes, df.test.prepared, test_transforms)