In [2]:
%%capture
!pip install albumentations

In [3]:
%%capture
!pip install --upgrade -q wandb

In [4]:
%%capture
!pip install lightning

In [5]:
import torch
from torch import nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
import os
from PIL import Image
import numpy as np
import pandas as pd
from pathlib import Path
import albumentations as A
import lightning as l
from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryPrecision, BinaryRecall
from lightning.pytorch.callbacks import TQDMProgressBar
from torchmetrics.detection import MeanAveragePrecision
from torchvision.ops import box_iou
import math

  check_for_updates()


In [6]:
# Only the logger class is needed at this point: the configured W&B logger is
# created after wandb.login() further down. Importing from `lightning.pytorch`
# matches the `lightning` package the Trainer comes from.
from lightning.pytorch.loggers import WandbLogger

In [7]:
import os

import wandb
from lightning.pytorch.loggers import WandbLogger

# Never hard-code API keys in a notebook: they end up in every saved copy.
# The key is taken from the WANDB_API_KEY environment variable (e.g. a Kaggle
# secret exported before this cell). When it is unset, key=None lets wandb use
# its normal credential lookup instead.
# NOTE: any key that was ever pasted into this notebook should be rotated.
wandb.login(key = os.environ.get('WANDB_API_KEY'))

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mmaximmalahovsky14[0m ([33mmaximmalahovsky14-ITMO University[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [8]:
class Yolov1Custom(nn.Module):
    """YOLOv1-style detector: convolutional backbone followed by a fully connected head.

    Args:
        S: grid size. The head expects the backbone to output an ``S x S x 1024``
            map; with this layer stack that happens for a 480x480 input and S = 7.
        B: number of boxes predicted per grid cell.
        C: number of classes.

    ``forward`` returns raw, unactivated predictions of shape
    ``(N, S, S, 5 * B + C)``. For each of the B box slots the layout is
    ``[conf, x, y, w, h]``, and the C class scores follow at ``5 * B``.
    Confidence and class channels are logits: pass them through a sigmoid or
    a ``*_with_logits`` loss.
    """
    def __init__(self, S, B, C):
        super().__init__()
        self.B = B
        self.C = C
        self.S = S
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size = 7, stride = 2, padding = 3),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.1),
            nn.MaxPool2d(kernel_size = 2),

            nn.Conv2d(64, 192, kernel_size = 3, padding = 1),
            nn.LeakyReLU(0.1),
            nn.MaxPool2d(kernel_size = 2),

            nn.Conv2d(192, 128, kernel_size = 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(128, 256, kernel_size = 3, padding = 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(256, 256, kernel_size = 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(256, 512, kernel_size = 3, padding = 1),
            nn.LeakyReLU(0.1),
            nn.MaxPool2d(kernel_size = 2),

            nn.Conv2d(512, 256, kernel_size = 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(256, 512, kernel_size = 3, padding = 1),
            nn.LeakyReLU(0.1),

            nn.Conv2d(512, 256, kernel_size = 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(256, 512, kernel_size = 3, padding = 1),
            nn.LeakyReLU(0.1),

            nn.Conv2d(512, 256, kernel_size = 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(256, 512, kernel_size = 3, padding = 1),
            nn.LeakyReLU(0.1),

            nn.Conv2d(512, 256, kernel_size = 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(256, 512, kernel_size = 3, padding = 1),
            nn.LeakyReLU(0.1),

            nn.Conv2d(512, 512, kernel_size = 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(512, 1024, kernel_size = 3, padding = 1),
            nn.LeakyReLU(0.1),
            nn.MaxPool2d(kernel_size = 2),

            nn.Conv2d(1024, 512, kernel_size = 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(512, 1024, kernel_size = 3, padding = 1),
            nn.LeakyReLU(0.1),

            nn.Conv2d(1024, 512, kernel_size = 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(512, 1024, kernel_size = 3, padding = 1),
            nn.LeakyReLU(0.1),

            nn.Conv2d(1024, 1024, kernel_size = 3, padding = 1),
            nn.LeakyReLU(0.1),
            # stride-2 conv: 15x15 -> 8x8 for a 480x480 input
            nn.Conv2d(1024, 1024, kernel_size = 3, stride = 2, padding = 1),
            nn.LeakyReLU(0.1),

            nn.Conv2d(1024, 1024, kernel_size = 3, padding = 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(1024, 1024, kernel_size = 3, padding = 1),
            nn.LeakyReLU(0.1),
            # kernel 4 / padding 1 shrinks 8x8 -> 7x7, the S x S grid
            nn.Conv2d(1024, 1024, kernel_size = 4, padding = 1),
            nn.LeakyReLU(0.1)
        )

        self.preds = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.S * self.S * 1024, 4096),
            nn.ReLU(),
            # No activation on the output layer. A ReLU here would clamp every
            # logit to >= 0, i.e. sigmoid(conf) >= 0.5 for every box, and would
            # make the no-object confidence term impossible to minimise.
            nn.Linear(4096, self.S * self.S * (self.C + self.B * 5)),
        )

    def forward(self, x):
        x = self.backbone(x)
        x = self.preds(x)
        return x.view(-1, self.S, self.S, 5 * self.B + self.C)

In [9]:
# Empty table of (frame file name, label file name) pairs.
df = pd.DataFrame(columns = pd.Index(['image_path', 'label_path']))

In [10]:
# File names of the frames and of the label files, each list sorted by name.
images, labels = (
    sorted(os.listdir(folder))
    for folder in ('/kaggle/input/pigs-data/tmp/frames',
                   '/kaggle/input/pigs-data/tmp/obj_Train_data')
)

In [11]:
# Pair every frame with the label file that has the same base name (text before
# the first '.'). Only the first MAX_FRAMES frames, in sorted order, are used.
# Matching by name rather than by list position means a single missing or extra
# file cannot shift all following pairs out of alignment. The table is built
# once from a list of records instead of growing it row by row.
MAX_FRAMES = 520
label_by_stem = {name.split('.')[0]: name for name in labels}
records = []
for image_name in images[:MAX_FRAMES]:
    label_name = label_by_stem.get(image_name.split('.')[0])
    if label_name is None:
        print('no label for', image_name)
        continue
    records.append({'image_path': image_name, 'label_path': label_name})
df = pd.DataFrame(records, columns = ['image_path', 'label_path'])

In [12]:
class PigsDataset(Dataset):
    """Pig-detection dataset yielding ``(image, target_grid)`` pairs for YOLOv1.

    Args:
        df: table with ``image_path`` / ``label_path`` columns holding file names
            relative to the frames / label directories below. Rows are addressed
            by position.
        S: grid size. B: boxes per cell. C: number of classes.
        transform: albumentations pipeline applied to the image only. Boxes are
            not passed through it, so transforms that move pixels (flips, crops,
            rotations) would desynchronise image and labels. Defaults to a
            480x480 resize, normalization and ``ToTensorV2``.

    Each label file line is parsed as ``class x_center y_center width height``.
    Coordinates are expected to be normalized to [0, 1], since they are
    multiplied by S to pick the grid cell.

    Target grid of shape ``(S, S, 5 * B + C)``:
        channel 0: object flag (1.0 in the cell holding a box centre)
        channels 1:5: ``[x_cell, y_cell, w, h]``. x/y are the centre offset
            inside the cell; w/h are the plain normalized sizes. They are not
            square-rooted: the loss applies the square root itself, and
            ``calc_coords`` reads these channels as widths/heights.
        channels 5:5+C: one-hot class vector
    Only the first box slot is filled. If two objects share a cell, the later
    line overwrites the earlier one.
    """
    def __init__(self, df, S, B, C, transform = None):
        self.df = df
        self.S = S
        self.C = C
        self.B = B
        self.path_im = Path('/kaggle/input/pigs-data/tmp/frames')
        self.path_lab = Path('/kaggle/input/pigs-data/tmp/obj_Train_data')
        if transform is None:
            transform = A.Compose([
                A.Resize(480, 480),
                A.Normalize(),
                A.ToTensorV2()
            ])
        # Use the caller's transform when given, otherwise the default pipeline.
        self.transform = transform

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # `with` closes the file handle once the pixels have been copied out.
        with Image.open(self.path_im / row['image_path']) as image:
            im_arr = np.asarray(image.convert('RGB'))
        im_arr = self.transform(image = im_arr)['image']

        labels_tensor = torch.zeros(self.S, self.S, self.B * 5 + self.C)

        with open(self.path_lab / row['label_path'], 'r') as file:
            for line in file:
                parts = line.split()
                if not parts:
                    continue  # tolerate blank lines, e.g. a trailing newline

                class_id = int(float(parts[0]))
                x_center_norm, y_center_norm, width_norm, height_norm = map(float, parts[1:5])

                # Grid cell that owns the box centre (a centre at exactly 1.0
                # is assigned to the last cell).
                x_grid = min(int(x_center_norm * self.S), self.S - 1)
                y_grid = min(int(y_center_norm * self.S), self.S - 1)

                # Centre offset inside that cell.
                x_cell = x_center_norm * self.S - x_grid
                y_cell = y_center_norm * self.S - y_grid

                labels_tensor[y_grid, x_grid, :5] = torch.tensor(
                    [1.0, x_cell, y_cell, width_norm, height_norm], dtype = torch.float32)
                labels_tensor[y_grid, x_grid, 5 : 5 + self.C] = 0.0
                labels_tensor[y_grid, x_grid, 5 + class_id] = 1.0

        return im_arr, labels_tensor

In [13]:
def calc_coords(inp):
    """Convert boxes from centre format to clipped corner format.

    Args:
        inp: tensor whose last dimension starts with ``[x_c, y_c, w, h]``.
            Extra trailing channels are ignored. Any number of leading
            dimensions is accepted, including a single 1-D box.

    Returns:
        Tensor of shape ``inp.shape[:-1] + (4,)`` holding
        ``[x_min, y_min, x_max, y_max]``. Negative widths and heights are
        treated as 0, and every corner is clamped to [0, 1].
    """
    x_c, y_c, w, h = inp[..., 0], inp[..., 1], inp[..., 2], inp[..., 3]

    # Raw network outputs can be negative; treat such sizes as empty boxes.
    half_w = torch.clamp(w, min = 0.0) / 2
    half_h = torch.clamp(h, min = 0.0) / 2

    corners = torch.stack([x_c - half_w, y_c - half_h, x_c + half_w, y_c + half_h], dim = -1)
    return torch.clamp(corners, 0.0, 1.0)

In [14]:
def yolo_loss(pred, label, alpha_coord, alpha_noo, B = 2):
    """YOLOv1-style detection loss.

    Args:
        pred: raw network output of shape (N, S, S, 5 * B + C). Box slot ``b``
            occupies channels ``5*b : 5*b + 5`` as ``[conf_logit, x, y, w, h]``.
            The C class logits follow at ``5 * B``.
        label: target grid of shape (N, S, S, 5 * B + C) with this expected layout:
            channel 0 is the object flag; channels 1:5 hold ``[x_cell, y_cell, w, h]``
            (x/y relative to the cell, w/h as plain image-relative sizes; the
            square root is taken here); the one-hot class vector sits at ``5 : 5 + C``.
        alpha_coord: weight of the centre and size terms (lambda_coord).
        alpha_noo: weight of the no-object confidence term (lambda_noobj).
        B: number of boxes predicted per cell.

    Returns:
        Scalar loss tensor. When the batch contains no objects, only the
        no-object term is returned, because an empty selection would make
        ``mse_loss`` return NaN.
    """
    grid_h, grid_w = pred.shape[1], pred.shape[2]
    C = pred.shape[-1] - 5 * B
    obj_mask = label[..., 0] > 0
    noobj_mask = label[..., 0] == 0
    conf_idx = torch.arange(B, device = pred.device) * 5  # confidence channel of each slot

    # --- cells without objects: push every predictor's confidence towards 0 ---
    noobj_conf = pred[noobj_mask][:, conf_idx]
    if noobj_conf.numel() > 0:
        conf_no = F.binary_cross_entropy_with_logits(noobj_conf, torch.zeros_like(noobj_conf))
    else:
        conf_no = pred.sum() * 0.0  # zero that keeps the graph, so backward() still works

    # --- cells with objects ---
    obj_p = pred[obj_mask]   # (M, 5B + C)
    obj_l = label[obj_mask]  # (M, 5B + C)
    if obj_p.shape[0] == 0:
        return alpha_noo * conf_no

    # Grid (row, col) of every object cell, in the same order as pred[obj_mask].
    _, rows, cols = obj_mask.nonzero(as_tuple = True)

    def to_image_xywh(xywh):
        """Convert cell-relative [x, y, w, h] to image-relative centres."""
        return torch.stack([
            (cols + xywh[:, 0]) / grid_w,
            (rows + xywh[:, 1]) / grid_h,
            xywh[:, 2],
            xywh[:, 3],
        ], dim = -1)

    # IoU of each predictor with the target, computed in a single (image)
    # coordinate frame. It only selects the responsible predictor and serves
    # as the confidence target, so no gradient flows through it.
    with torch.no_grad():
        target_xyxy = calc_coords(to_image_xywh(obj_l[:, 1:5]))
        iou_res = torch.stack([
            torch.diag(box_iou(calc_coords(to_image_xywh(obj_p[:, 5 * b + 1 : 5 * b + 5])), target_xyxy))
            for b in range(B)
        ], dim = 1)  # (M, B)
    best_iou, best_box = iou_res.max(dim = 1)

    start = best_box * 5
    bbox_idx = (start + 1).unsqueeze(1) + torch.arange(4, device = pred.device)
    bbox_pred = obj_p.gather(1, bbox_idx)                       # (M, 4): x, y, w, h
    conf_pred = obj_p.gather(1, start.unsqueeze(1)).squeeze(1)  # (M,)

    center_loss = alpha_coord * (
        F.mse_loss(bbox_pred[:, 0], obj_l[:, 1])
        + F.mse_loss(bbox_pred[:, 1], obj_l[:, 2])
    )

    # sign(t) * sqrt(|t| + eps) is defined for negative raw outputs and has a
    # finite gradient at 0, unlike a plain sqrt.
    def signed_sqrt(t):
        return torch.sign(t) * torch.sqrt(torch.abs(t) + 1e-6)

    size_loss = alpha_coord * (
        F.mse_loss(signed_sqrt(bbox_pred[:, 2]), torch.sqrt(obj_l[:, 3]))
        + F.mse_loss(signed_sqrt(bbox_pred[:, 3]), torch.sqrt(obj_l[:, 4]))
    )

    conf_loss = F.binary_cross_entropy_with_logits(conf_pred, best_iou)

    # Class logits live at 5*B in the prediction but at 5 in the target.
    class_loss = F.binary_cross_entropy_with_logits(obj_p[:, 5 * B : 5 * B + C],
                                                    obj_l[:, 5 : 5 + C])

    return center_loss + size_loss + conf_loss + class_loss + alpha_noo * conf_no

In [15]:
class LightningYolov1(l.LightningModule):
    """Lightning wrapper that trains ``base_model`` with a YOLO-style loss.

    Args:
        base_model: network returning predictions of shape (N, S, S, 5 * B + C).
            It must expose integer attributes ``B`` and ``C``.
        loss: callable ``loss(preds, target, alpha_coord, alpha_noo)`` returning
            a scalar tensor.
        alpha_coord: box-regression weight forwarded to ``loss``.
        alpha_noo: no-object confidence weight forwarded to ``loss``.

    Validation and test log the loss plus precision/recall of:
      * objectness: every box confidence vs. its cell's object flag, and
      * class scores vs. the one-hot class channels of the target.
    """
    def __init__(self, base_model, loss, alpha_coord, alpha_noo):
        super().__init__()
        self.base_model = base_model
        self.loss = loss
        self.alpha_coord = alpha_coord
        self.alpha_noo = alpha_noo
        self.metrics = MetricCollection({
            "obj_prec": BinaryPrecision(threshold=0.5),
            "obj_rec": BinaryRecall(threshold=0.5),
            "class_prec": BinaryPrecision(threshold=0.5),
            "class_rec": BinaryRecall(threshold=0.5),
        })
        self.val_metrics = self.metrics.clone(prefix = 'val_')
        self.test_metrics = self.metrics.clone(prefix = 'test_')

    def forward(self, x):
        return self.base_model(x)

    def _update_metrics(self, metrics, preds, y):
        """Update objectness and class precision/recall in ``metrics`` for one batch."""
        B, C = self.base_model.B, self.base_model.C

        # Objectness: each of the B confidence logits of a cell (channels 0, 5, ...)
        # is compared with that cell's object flag (target channel 0). The flag is
        # expanded along the same last axis, so both tensors flatten in the same
        # (cell, box) order.
        conf_idx = torch.arange(B, device = preds.device) * 5
        obj_scores = torch.sigmoid(preds[..., conf_idx]).flatten()
        obj_target = y[..., :1].expand(*y.shape[:-1], B).flatten().int()

        # Classes: predictions keep the class logits at 5*B : 5*B + C, while the
        # target keeps the one-hot class vector at 5 : 5 + C.
        cls_scores = torch.sigmoid(preds[..., 5 * B : 5 * B + C]).flatten()
        cls_target = y[..., 5 : 5 + C].flatten().int()

        metrics["obj_prec"].update(obj_scores, obj_target)
        metrics["obj_rec"].update(obj_scores, obj_target)
        metrics["class_prec"].update(cls_scores, cls_target)
        metrics["class_rec"].update(cls_scores, cls_target)

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        loss = self.loss(preds, y, self.alpha_coord, self.alpha_noo)
        self.log('train_loss', loss, on_step = True, on_epoch = True, prog_bar = True, logger = True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        loss = self.loss(preds, y, self.alpha_coord, self.alpha_noo)
        self.log('val_loss', loss, on_epoch = True, prog_bar = True, logger = True)
        self._update_metrics(self.val_metrics, preds, y)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        loss = self.loss(preds, y, self.alpha_coord, self.alpha_noo)
        self.log('test_loss', loss, on_epoch = True, prog_bar = True, logger = True)
        self._update_metrics(self.test_metrics, preds, y)
        return loss

    def on_test_epoch_end(self):
        self.log_dict(self.test_metrics.compute(), prog_bar = True, on_epoch = True)
        self.test_metrics.reset()

    def on_validation_epoch_end(self):
        self.log_dict(self.val_metrics.compute(), prog_bar = True, on_epoch = True)
        self.val_metrics.reset()

    def configure_optimizers(self):
        return torch.optim.AdamW(self.base_model.parameters(), lr = 1e-4)

In [16]:
def nms(preds, conf_thresh = 0.2, iou_thresh = 0.5, B = 2, image_coords = False):
    """Greedy, class-agnostic non-maximum suppression over YOLOv1 grid predictions.

    Args:
        preds: raw predictions of shape (S, S, 5 * B + C) or (N, S, S, 5 * B + C).
            Each box slot is ``[conf_logit, x, y, w, h]``, with x/y relative to
            the cell and w/h relative to the image.
        conf_thresh: boxes with ``sigmoid(conf) <= conf_thresh`` are discarded.
        iou_thresh: a box is suppressed when its IoU with a higher-scoring kept
            box is greater than this value.
        B: number of boxes per cell.
        image_coords: if True, the returned x/y are image-relative centres.
            By default they are returned cell-relative, exactly as stored.

    Returns:
        A (K, 5) tensor of kept boxes ``[sigmoid(conf), x, y, w, h]`` in
        descending confidence when the input holds one image. Otherwise, a list
        with one such tensor per image.
    """
    if preds.dim() == 3:
        preds = preds.unsqueeze(0)

    batch_nms_results = []

    for image_preds in preds:
        grid_h, grid_w = image_preds.shape[0], image_preds.shape[1]

        # Clone so the in-place sigmoid never writes into the caller's tensor.
        boxes = image_preds[..., :B * 5].reshape(grid_h, grid_w, B, 5).clone()
        boxes[..., 0] = torch.sigmoid(boxes[..., 0])

        # Same boxes with image-relative centres. Boxes from different cells are
        # only comparable once the cell offset has been added.
        rows = torch.arange(grid_h, device = boxes.device, dtype = boxes.dtype).view(-1, 1, 1)
        cols = torch.arange(grid_w, device = boxes.device, dtype = boxes.dtype).view(1, -1, 1)
        boxes_img = boxes.clone()
        boxes_img[..., 1] = (boxes[..., 1] + cols) / grid_w
        boxes_img[..., 2] = (boxes[..., 2] + rows) / grid_h

        boxes = boxes.reshape(-1, 5)
        boxes_img = boxes_img.reshape(-1, 5)

        # Candidate indices, highest confidence first, above the threshold.
        order = torch.argsort(boxes[:, 0], descending = True)
        order = order[boxes[order, 0] > conf_thresh]

        kept = []
        while order.numel() > 0:
            best = order[0]
            kept.append(best)
            rest = order[1:]
            if rest.numel() == 0:
                break
            # calc_coords expects a 2-D (K, 4) input, hence unsqueeze(0) for the single box.
            ious = box_iou(calc_coords(boxes_img[best, 1:].unsqueeze(0)),
                           calc_coords(boxes_img[rest, 1:])).squeeze(0)
            order = rest[ious <= iou_thresh]

        source = boxes_img if image_coords else boxes
        if kept:
            batch_nms_results.append(source[torch.stack(kept)])
        else:
            batch_nms_results.append(torch.empty(0, 5, device = preds.device))

    if preds.shape[0] == 1:
        return batch_nms_results[0]
    return batch_nms_results

In [17]:
class PigsDataModule(l.LightningDataModule):
    """DataModule that splits PigsDataset 70/20/10 into train/val/test subsets.

    Args:
        df: frame/label file-name table passed to PigsDataset.
        S, B, C: grid size, boxes per cell and number of classes of the targets.
        batch_size: batch size of every dataloader.
        seed: seed of the random split, so that every ``setup`` call (fit,
            validate, test) sees the same partition. ``None`` uses the global
            torch RNG.
        num_workers: DataLoader worker processes.
    """
    def __init__(self, df, S=7, B=2, C=1, batch_size = 64, seed = 42, num_workers = 4):
        # The base initializer sets up the attributes Lightning relies on. It also
        # assigns allow_zero_length_dataloader_with_multiple_devices = False, so
        # that name must not be overridden by a read-only property here.
        super().__init__()
        self.S = S
        self.C = C
        self.B = B
        self.df = df
        self.batch_size = batch_size
        self.seed = seed
        self.num_workers = num_workers
        self._log_hyperparams = False
        self.data = None

    def setup(self, stage:str):
        # Lightning calls setup() for each stage ("fit", "test", ...). Split only
        # once, so the test subset never overlaps the training subset.
        if self.data is not None:
            return
        self.data = PigsDataset(df = self.df, S = self.S, B = self.B, C = self.C)
        n = len(self.data)
        n_train, n_train_val = int(0.7 * n), int(0.9 * n)
        generator = None if self.seed is None else torch.Generator().manual_seed(self.seed)
        self.ind = torch.randperm(n, generator = generator)
        self.train_ind = self.ind[:n_train]
        self.val_ind = self.ind[n_train : n_train_val]
        self.test_ind = self.ind[n_train_val : n]
        self.train = torch.utils.data.Subset(self.data, self.train_ind.tolist())
        self.val = torch.utils.data.Subset(self.data, self.val_ind.tolist())
        self.test = torch.utils.data.Subset(self.data, self.test_ind.tolist())

    def train_dataloader(self):
        # Reshuffle the training subset every epoch.
        return DataLoader(self.train, batch_size = self.batch_size, shuffle = True,
                          num_workers = self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size = self.batch_size, num_workers = self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size = self.batch_size, num_workers = self.num_workers)

In [18]:
def calculate_map50(preds_batch, targets_batch, B, num_classes, conf_thresh_nms=0.05, iou_thresh_nms=0.5, iou_ap_threshold=0.5):
    """Mean average precision at one IoU threshold (mAP@0.5 by default) for a batch.

    Args:
        preds_batch: per-image predictions. ``preds_batch[i]`` is passed to
            ``nms_for_map`` (presumably the raw (S, S, 5*B + C) grid output;
            TODO confirm).
        targets_batch: target grids indexed as ``[img, row, col, channel]``.
            Channel 0 is the object flag, 1:5 the box, and 5: the one-hot
            class vector (argmax gives the class id).
        B: boxes per cell, forwarded to ``nms_for_map``.
        num_classes: AP is computed for class ids ``0 .. num_classes - 1``.
        conf_thresh_nms, iou_thresh_nms: forwarded to ``nms_for_map``.
        iou_ap_threshold: minimum IoU for a detection to count as a true positive.

    Returns:
        Mean of the per-class APs, or 0.0 when no class contributes. Classes
        with neither ground truth nor detections are skipped; a class with only
        one of the two scores 0.0.

    NOTE(review): ``nms_for_map`` is not defined in this part of the notebook.
    The ``nms`` function here returns 5 columns, while column 5 is read below as
    the class id. Unless it is defined elsewhere, calling this raises NameError.
    NOTE(review): the target channels 1:5 are passed straight to ``calc_coords``.
    In PigsDataset, x/y there are relative to the grid cell, so boxes from
    different cells are compared in a mixed frame. Confirm whether targets (and
    detections) should be converted to image-relative centres first.
    """
    all_detections_flat = [] 
    all_ground_truths_flat = []

    # Collect every ground-truth box of the batch in corner format.
    for img_idx in range(targets_batch.shape[0]):
        gt_image = targets_batch[img_idx]
        
        gt_object_mask = gt_image[:, :, 0] > 0.0 
        gt_object_cells = gt_image[gt_object_mask]
        
        if gt_object_cells.shape[0] == 0:
            continue

        gt_bbox_xywh = gt_object_cells[:, 1:5]
        gt_class_idx = torch.argmax(gt_object_cells[:, 5:], dim=-1)
        
        gt_bbox_xyxy = calc_coords(gt_bbox_xywh)

        for i in range(gt_bbox_xyxy.shape[0]):
            all_ground_truths_flat.append({
                'image_idx': img_idx,
                'bbox': gt_bbox_xyxy[i],
                'class_id': gt_class_idx[i].item(),
                'matched': False
            })

    # Run NMS per image. Expected row layout of nms_for_map's output:
    # [conf, x, y, w, h, class_id].
    for img_idx in range(preds_batch.shape[0]):
        detections_for_image = nms_for_map(preds_batch[img_idx], conf_thresh_nms, iou_thresh_nms, B, num_classes)
        
        if detections_for_image.shape[0] == 0:
            continue

        det_bbox_xywh = detections_for_image[:, 1:5]
        det_bbox_xyxy = calc_coords(det_bbox_xywh)
        
        det_conf = detections_for_image[:, 0]
        det_class_idx = detections_for_image[:, 5].long()

        for i in range(det_bbox_xyxy.shape[0]):
            all_detections_flat.append({
                'image_idx': img_idx,
                'confidence': det_conf[i].item(),
                'bbox': det_bbox_xyxy[i],
                'class_id': det_class_idx[i].item()
            })
    
    # Global confidence order; the per-class filtered lists below keep it.
    all_detections_flat.sort(key=lambda x: x['confidence'], reverse=True)

    aps = []

    for class_id in range(num_classes):
        class_detections = [d for d in all_detections_flat if d['class_id'] == class_id]
        class_ground_truths = [gt for gt in all_ground_truths_flat if gt['class_id'] == class_id]

        num_gt_for_class = len(class_ground_truths)
        num_det_for_class = len(class_detections)

        if num_gt_for_class == 0 and num_det_for_class == 0:
            continue
        if num_det_for_class == 0:
            aps.append(0.0)
            continue
        if num_gt_for_class == 0:
            aps.append(0.0)
            continue

        TP = torch.zeros(num_det_for_class, device=preds_batch.device)
        FP = torch.zeros(num_det_for_class, device=preds_batch.device)
        
        for gt in class_ground_truths:
            gt['matched'] = False 

        # Greedy matching: each detection (highest confidence first) takes the
        # still-unmatched ground truth of the same image with the largest IoU.
        for det_idx, detection in enumerate(class_detections):
            best_iou = 0.0
            best_gt_index_in_list = -1

            for gt_list_idx, gt in enumerate(class_ground_truths):
                if gt['matched']: 
                    continue
                
                if detection['image_idx'] != gt['image_idx']:
                    continue

                iou = box_iou(detection['bbox'].unsqueeze(0), gt['bbox'].unsqueeze(0)).item()
                if iou > best_iou:
                    best_iou = iou
                    best_gt_index_in_list = gt_list_idx
            
            if best_iou >= iou_ap_threshold and best_gt_index_in_list != -1:
                TP[det_idx] = 1
                class_ground_truths[best_gt_index_in_list]['matched'] = True 
            else:
                FP[det_idx] = 1

        cum_TP = torch.cumsum(TP, dim=0)
        cum_FP = torch.cumsum(FP, dim=0)

        precision = cum_TP / (cum_TP + cum_FP)
        recall = cum_TP / num_gt_for_class

        precision = torch.cat((torch.tensor([1.0], device=precision.device), precision))
        recall = torch.cat((torch.tensor([0.0], device=recall.device), recall))
        
        # All-point interpolation: make precision non-increasing from right to
        # left, then sum precision times recall increment over all steps.
        for i in range(recall.shape[0] - 1, 0, -1):
            precision[i-1] = torch.max(precision[i-1], precision[i])

        ap = torch.sum((recall[1:] - recall[:-1]) * precision[1:])
        
        aps.append(ap.item())

    if len(aps) == 0:
        return 0.0
        
    return sum(aps) / len(aps)

In [19]:
# 7x7 grid, 2 boxes per cell, 1 class.
base_model = Yolov1Custom(S = 7, B = 2, C = 1)
model = LightningYolov1(
    base_model = base_model,
    loss = yolo_loss,
    alpha_coord = 5,      # weight of the box-regression terms
    alpha_noo = 0.005,    # weight of the no-object confidence term
)

In [25]:
# W&B logger handed to the Trainer; local files go under ./yolov1_training.
wandb_logger = WandbLogger(
    name = 'yolov1_train_pigs',
    project = 'yolov1-pigs-dataset',
    save_dir = 'yolov1_training',
)

In [26]:
# torch.cuda.is_available must be *called*: the bare function object is always
# truthy, which selected 'cuda:1' even on CPU-only machines.
if torch.cuda.is_available():
    # GPU index 1 when there are at least two GPUs, otherwise the only one.
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')
# NOTE: the Trainer cell picks its GPU via `devices=[1]` and does not read this variable.
device

device(type='cuda', index=1)

In [27]:
# Train/val/test loaders over the paired frame/label table, 32 images per batch.
datamodule = PigsDataModule(df = df, batch_size = 32)

In [28]:
trainer = l.Trainer(
    devices = [1],                     # GPU with index 1
    logger = wandb_logger,
    callbacks = [TQDMProgressBar()],
    log_every_n_steps = 10,
    max_epochs = 30,
)

In [29]:
# Train with the splits and loaders provided by the data module.
trainer.fit(model, datamodule = datamodule)

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
INFO: 
  | Name         | Type             | Params | Mode 
----------------------------------------------------------
0 | base_model   | Yolov1Custom     | 284 M  | train
1 | metrics      | MetricCollection | 0      | train
2 | val_metrics  | MetricCollection | 0      | train
3 | test_metrics | MetricCollection | 0      | train
----------------------------------------------------------
284 M     Trainable params
0         Non-trainable params
284 M     Total params
1,138.670 Total estimated model params size (MB)
78        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

# 