<b>Назначение:</b> <br>
Подготовка кастомного Dataset-класса и тестирование Dataloader-класса

In [1]:
import pandas as pd
import os 
import numpy as np
from tqdm import tqdm
import random
import matplotlib.pyplot as plt
import cv2
from transformers import AutoImageProcessor
import torch

from datasets import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image

In [2]:
class CustomCarDataset(Dataset):
    def __init__(self, part_name, fronts_info_path, processor):
        data_info = pd.read_csv(fronts_info_path, sep=';')
        self._data = data_info.loc[data_info['part'] == part_name, :].reset_index(drop=True)
        self.uniq_labels = self._data['label'].unique()
        self.labels_map = {label: i for i, label in enumerate(self.uniq_labels)}
        self.processor = processor

    def __len__(self):
        return self._data.shape[0]

    def __getitem__(self, idx):
        image_path = f"{self._data['relative_path'][idx]}/{self._data['image_name'][idx]}"
        image = cv2.imread(image_path)[...,::-1]
        image = torch.tensor(self.processor(image)['pixel_values'][0])
        label = self.labels_map[self._data['label'][idx]]

        return image, label
    
    def __getitems__(self, idxs):
        return [self.__getitem__(idx) for idx in idxs]
        

def custom_collate(data):

    images = torch.cat([torch.unsqueeze(item[0], 0) for item in data], 0)
    labels = torch.tensor([item[1] for item in data])

    return {
        "images": images, 
        "labels": labels
    }

In [3]:
TT_INFO = './data/tt_union_fronts_info.csv'

In [5]:
resnet_processor

ConvNextImageProcessor {
  "crop_pct": 0.875,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.485,
    0.456,
    0.406
  ],
  "image_processor_type": "ConvNextImageProcessor",
  "image_std": [
    0.229,
    0.224,
    0.225
  ],
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "shortest_edge": 224
  }
}

In [4]:
resnet_processor = AutoImageProcessor.from_pretrained("./base_models/microsoft_resnet50")
m2f_processor = AutoImageProcessor.from_pretrained("./base_models/facebook-m2f_swin_large")

In [31]:
# Datasets
train_dataset = CustomCarDataset('train', TT_INFO, m2f_processor)
eval_dataset = CustomCarDataset('eval', TT_INFO, m2f_processor)

In [32]:
# Dataloader
train_dataloader = DataLoader(train_dataset, batch_size=16, 
                              shuffle=True, 
                              collate_fn=custom_collate)
eval_dataloader = DataLoader(eval_dataset, batch_size=16,
                              collate_fn=custom_collate)

In [35]:
# TEST

print(next(iter(train_dataloader))['images'].shape)
for batch in tqdm(train_dataloader):
    break

print(next(iter(eval_dataloader))['images'].shape)
for batch in tqdm(eval_dataloader):
    break

torch.Size([16, 3, 384, 384])


  0%|          | 0/700 [00:00<?, ?it/s]


torch.Size([16, 3, 384, 384])


  0%|          | 0/175 [00:00<?, ?it/s]
