# How to train a model on the dataset?

In [1]:
# Lets first install hub
from IPython.display import clear_output
!pip install hub
clear_output()

In [2]:
# Import necessary packages
import functools

import albumentations as A
import hub
import torch
from albumentations.pytorch.transforms import ToTensorV2
from matplotlib import pyplot as plt
from torchvision import transforms

In [3]:
# Root of the Hub-format Food Recognition 2022 dataset on the Kaggle input mount;
# change this one constant to run the notebook against another copy of the data.
DATA_ROOT = '/kaggle/input/food-recognition-2022/hub'

ds_train = hub.load(f'{DATA_ROOT}/train')
ds_val = hub.load(f'{DATA_ROOT}/val')
ds_test = hub.load(f'{DATA_ROOT}/test')

In [4]:
# Class names stored in the `info` metadata of the training split's `categories` tensor.
# (A `global` statement at notebook/module level is a no-op, so none is needed here.)
class_labels = ds_train.categories.info.class_names

In [9]:
@functools.lru_cache(maxsize=None)
def _build_transform_pipeline(crop_height, crop_width):
    """Return the albumentations pipeline used by `transform` for one crop size.

    Cached so the `A.Compose` object is built once per crop size instead of
    once per sample.
    """
    return A.Compose(
        [
            A.CenterCrop(crop_height, crop_width, p=1),
            ToTensorV2(),
        ],
        # Boxes are interpreted in COCO format; registering 'class_labels' as a
        # label field makes albumentations filter the labels together with any
        # boxes the crop removes, keeping the two aligned.
        bbox_params={
            'format': 'coco',
            'label_fields': ['class_labels'],
        },
    )


def transform(data, crop_height=100, crop_width=100):
    """Center-crop one dataset sample and convert it to tensors.

    Args:
        data: sample mapping with 'images', 'boxes', 'masks' and 'categories'
            entries (the tensors requested via `tensors=` when building the
            dataloader).
        crop_height: height of the center crop (default keeps the original 100).
        crop_width: width of the center crop (default keeps the original 100).

    Returns:
        Tuple ``(image, boxes, masks, categories)`` taken from the same
        albumentations call, so boxes and categories stay consistent with the
        cropped image and with each other.

    NOTE(review): albumentations' CenterCrop expects each image to be at least
    crop_height x crop_width - confirm this holds for every sample.
    """
    pipeline = _build_transform_pipeline(crop_height, crop_width)
    transformed = pipeline(
        image=data['images'],
        mask=data['masks'],
        bboxes=data['boxes'],
        class_labels=data['categories'],
    )

    # Return the *cropped* boxes: the un-cropped input boxes would not match the
    # cropped image, and their count can differ from the filtered categories.
    return (
        transformed['image'],
        transformed['bboxes'],
        transformed['mask'],
        transformed['class_labels'],
    )

def collate_fn(batch):
    """Collate a list of ``(image, boxes, masks, categories)`` samples into a dict.

    Images are stacked into a single tensor (so they must share a shape, which
    the fixed-size center crop in `transform` provides). Boxes, masks and
    categories vary in count per sample and are therefore kept as per-sample lists.
    """
    return {
        "images": torch.stack([sample[0] for sample in batch]),
        # reshape(-1, 4) keeps a consistent (num_boxes, 4) shape even when a
        # sample has no boxes left (torch.tensor([]) alone would be shape (0,)).
        "boxes": [torch.tensor(sample[1]).reshape(-1, 4) for sample in batch],
        "masks": [sample[2] for sample in batch],
        "categories": [torch.tensor(sample[3]) for sample in batch],
    }

In [10]:
# PyTorch dataloader over the training split, requesting only the four tensors
# that `transform` reads; `transform` and `collate_fn` come from the previous cell.
dataloader = ds_train.pytorch(
    num_workers=0,
    batch_size=1,
    transform=transform,
    tensors=['images', 'boxes', 'masks', 'categories'],
    collate_fn=collate_fn,
)

# Fetch a single batch as a smoke test of the whole loading pipeline.
batch = next(iter(dataloader))