## 1.Importing Libraries

In [6]:
import os,json,cv2,numpy as np,matplotlib.pyplot as plt 

import torch 
from torch.utils.data import Dataset,DataLoader 

import torchvision 
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.transforms import functional as F 

import albumentations as A

In [2]:
import transforms,utils,engine,train 
from utils import collate_fn 
from engine import train_one_epoch,evaluate

## 2.Augmentations

In [12]:
def train_transforms():
    return A.Compose([
        A.Sequential([
            A.RandomRotate90(p=1),
            A.HorizontalFlip(p=0.5), 
            A.RandomBrightnessContrast(p=0.3)
        ],p=1)
    ],
    keypoint_params = A.KeypointParams(format='xy'),
    bbox_params = A.BboxParams(format='pascal_voc',label_fields=['mouth'])
    )

## 3.Dataset Class

In [10]:
class ClassDataset(Dataset):
    def __init__(self,root,transform=None,demo=False):
        self.root = root 
        self.transform = transform 
        self.demo = demo 
        self.imgs_files = sorted(os.listdir(os.path.join(root,"images"))) 
        self.annotations_files = sorted(os.listdir(os.path.join(root,"annotations"))) 

    def __len__(self):
            return len(self.imgs_files)
        
    def __getitem__(self,idx):
        img_path = os.path.join(self.root,"images",self.imgs_files[idx])
        annotations_path = os.path.join(self.root,"annotations",self.annotations_files[idx]) 
        
        img_original = cv2.imread(img_path) 
        img_original = cv2.cvtColor(img_original,cv2.COLOR_BGR2RGB) 
        
        with open(annotations_path) as f:
            data = json.load(f) 
            bboxes_original = data['bboxes']
            keypoints_original = data['keypoints'] 

            bboxes_labels_original = ['mouth' for _ in bboxes_original] 

        if self.transform:
            keypoints_original_flattened = [el[0:2] for kp in keypoints_original for el in kp] 
            transformed = self.transform(image=img_original,bboxes=bboxes_original,bboxes_label=bboxes_labels_original,keypoints=keypoints_original)
            img = transformed['image']
            bboxes = transformed['bboxes']

            keypoints_transformed_unflattened = np.reshape(np.array(transformed['keypoints']),(-1,2,2)).tolist() 

            keypoints=[] 
            for o_idx,obj in enumerate(keypoints_transformed_unflattened):
                obj_keypoints = [] 
                for k_idx,kp in enumerate(obj):
                    obj_keypoints.append(kp+[keypoints_original[o_idx][k_idx][2]])
                keypoints.append(obj_keypoints) 
        else:
            img,bboxes,keypoints = img_original,bboxes_original,keypoints_original 

        bboxes = torch.as_tensor(bboxes,dtype=torch.float32) 
        target = {} 
        target["boxes"] = bboxes 
        target["labels"] = torch.as_tensor([1 for _ in bboxes],dtype=torch.int64)
        target["image_id"] = torch.tensor([idx])
        target["area"] = (bboxes[:,3]-bboxes[:,1])*(bboxes[:,2]-bboxes[:,0])
        target["iscrowd"] = torch.zeros(len(bboxes),dtype=torch.int64)
        target["keypoints"] = torch.as_tensor(keypoints,dtype=torch.float32) 
        img = F.to_tensor(img) 

        bboxes_original =torch.as_tensor(bboxes_original,dtype=torch.float32) 
        target_original = {} 
        target_original["boxes"] = bboxes_original 
        target_original["labels"] = torch.as_tensor([1 for _ in bboxes_original],dtype=torch.int64)
        target_original["image_id"] = torch.tensor([idx])
        target_original["area"] = (bboxes_original[:,3]-bboxes_original[:,1])*(bboxes_original[:,2]-bboxes[:,0])
        target_original["iscrowd"] = torch.zeros(len(bboxes_original),dtype=torch.int64)
        target_original["keypoints"] = torch.as_tensor(keypoints_original,dtype=torch.float32) 
        img_original = F.to_tensor(img_original) 

        if self.demo:
            return img,target,img_original,target_original 
        else:
            return img,target 
        
        


## 4.Visualize example

In [14]:
KEYPOINTS_FOLDER_TRAIN ='./train'
dataset = ClassDataset(KEYPOINTS_FOLDER_TRAIN,transform=train_transforms(),demo=True)
data_loader = DataLoader(dataset,batch_size=1,shuffle=True,collate_fn=collate_fn) 

iterator = iter(data_loader) 
batch = next(iterator)

print("Original")
print(batch[3])

print("Transformed")
print(batch[1])

ValueError: Your 'label_fields' are not valid - them must have same names as params in dict