In [None]:
import json
import os
import torch
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import segmentation_models_pytorch as smp
import numpy as np
from PIL import Image

from pprint import pprint
from torch.utils.data import DataLoader

# Apply matplotlib's 'dark_background' style sheet to the plotting defaults.
plt.style.use('dark_background')

# Reference size used by AgroDataset.convert_labels_to_bboxes to normalize
# polygon coordinates: element 0 divides x values and element 1 divides
# y values, i.e. (width, height) -- presumably pixels of the source images.
ORIGINAL_IMAGE_SIZE = (848, 480)


class AgroDataset(torch.utils.data.Dataset):
    """Segmentation dataset built from a sample directory and a JSON label file.

    Each sample is expected on disk as::

        <images_path>/<sample>/images/<sample>.png
        <images_path>/<sample>/masks/<sample>.png

    which is the layout ``__getitem__`` reads.

    NOTE(review): ``self.labels`` (used by ``convert_labels_to_bboxes``) is
    loaded independently of ``self.filenames`` and is not split by ``mode``;
    nothing in this class links a label entry to a sample directory.
    """

    def __init__(self, images_path, labels_path, mode="train", transform=None):
        """Load the annotations and select the samples belonging to ``mode``.

        Args:
            images_path: Directory containing one sub-directory per sample.
            labels_path: Path to a JSON file whose top level is an object;
                every value must have a ``'regions'`` key. Entries whose
                ``'regions'`` value is empty are dropped.
            mode: One of ``"train"``, ``"valid"`` or ``"test"``.
            transform: Optional callable invoked as
                ``transform(image=..., mask=..., trimap=...)``; its return
                value becomes the sample.

        Raises:
            ValueError: If ``mode`` is not a supported split name.
        """
        # Explicit exception instead of ``assert`` so the check survives -O.
        if mode not in {"train", "valid", "test"}:
            raise ValueError(
                f"mode must be 'train', 'valid' or 'test', got {mode!r}"
            )

        self.mode = mode
        self.transform = transform
        # __getitem__ resolves every sample relative to self.root.
        self.root = images_path

        with open(labels_path, encoding="utf-8") as f:
            annotations = json.load(f)
        self.labels = [v for v in annotations.values() if v['regions']]

        # Previously self.filenames was sliced below without ever being
        # assigned, so construction always failed. Sample names are taken
        # from the sub-directories of images_path, sorted so that the splits
        # are reproducible across runs and platforms.
        # NOTE(review): assumes one directory per sample -- confirm layout.
        self.filenames = sorted(
            entry for entry in os.listdir(images_path)
            if os.path.isdir(os.path.join(images_path, entry))
        )

        # Actual split: first half -> train, the next 100 samples -> valid,
        # the 100 after that -> test. (Earlier comments advertised an
        # 80/10/10 split, which this code does not implement.)
        half = len(self.filenames) // 2
        if self.mode == "train":
            self.filenames = self.filenames[:half]
        elif self.mode == "valid":
            self.filenames = self.filenames[half:half + 100]
        else:  # "test"
            self.filenames = self.filenames[half + 100:half + 200]

    def convert_labels_to_bboxes(self):
        """Return one normalized ``[cx, cy, w, h]`` box per entry of ``self.labels``.

        The box is the axis-aligned bounding box of the points listed in
        ``label['regions']['all_points_x']`` / ``['all_points_y']``, with x
        divided by ``ORIGINAL_IMAGE_SIZE[0]`` and y by
        ``ORIGINAL_IMAGE_SIZE[1]``.

        NOTE(review): this assumes ``'regions'`` is a mapping that directly
        holds the two point lists -- verify against the annotation file.
        """
        size_x, size_y = ORIGINAL_IMAGE_SIZE

        bboxes = []
        for label in self.labels:
            region = label['regions']
            # Sentinels stay in place when the point lists are empty.
            min_x, max_x, min_y, max_y = 1e9, -1e9, 1e9, -1e9
            for x, y in zip(region['all_points_x'], region['all_points_y']):
                min_x = min(min_x, x)
                max_x = max(max_x, x)
                min_y = min(min_y, y)
                max_y = max(max_y, y)

            # Normalize the extremes first, then derive center and extent.
            min_x /= size_x
            max_x /= size_x
            min_y /= size_y
            max_y /= size_y
            bboxes.append([
                (min_x + max_x) / 2,
                (min_y + max_y) / 2,
                max_x - min_x,
                max_y - min_y,
            ])

        return bboxes

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Load the sample at ``idx``.

        Returns:
            A dict with ``image`` (RGB array), ``mask`` (output of
            ``_preprocess_mask``) and ``trimap`` (raw mask array), or the
            return value of ``self.transform`` when a transform is set.
        """
        folder = self.filenames[idx]
        filename = os.path.basename(folder)

        image_path = os.path.join(self.root, folder, 'images', f"{filename}.png")
        mask_path = os.path.join(self.root, folder, 'masks', f"{filename}.png")

        image = np.array(Image.open(image_path).convert("RGB"))

        trimap = np.array(Image.open(mask_path))
        mask = self._preprocess_mask(trimap)

        sample = dict(image=image, mask=mask, trimap=trimap)
        if self.transform is not None:
            sample = self.transform(**sample)

        return sample

    @staticmethod
    def _preprocess_mask(mask):
        """Return a float32 copy of ``mask`` thresholded around the value 2.

        Values below 2 become 0.0 and values above 2 become 1.0.

        NOTE(review): values exactly equal to 2 are left as 2.0, so the
        result is not strictly binary -- confirm whether that is intended.
        """
        mask = mask.astype(np.float32)
        mask[mask < 2.0] = 0.0
        mask[mask > 2.0] = 1.0
        return mask


class SimpleAgroDataset(AgroDataset):
    """AgroDataset variant that resizes every sample to 256x256, channel-first."""

    def __getitem__(self, *args, **kwargs):
        """Load a sample via the parent class, then resize and re-layout it.

        Returns:
            The parent's sample with ``image``, ``mask`` and ``trimap``
            replaced by 256x256 arrays: ``image`` has its last axis moved to
            the front (HWC -> CHW) and ``mask`` / ``trimap`` gain a leading
            axis of length 1.
        """
        sample = super().__getitem__(*args, **kwargs)

        # resize images
        # Image.LINEAR was a deprecated alias of Image.BILINEAR and no longer
        # exists in Pillow >= 10; BILINEAR is the same filter and is
        # available in every Pillow version.
        # Masks use NEAREST so that resizing never invents new label values.
        image = np.array(Image.fromarray(sample["image"]).resize((256, 256), Image.BILINEAR))
        mask = np.array(Image.fromarray(sample["mask"]).resize((256, 256), Image.NEAREST))
        trimap = np.array(Image.fromarray(sample["trimap"]).resize((256, 256), Image.NEAREST))

        # convert to other format HWC -> CHW
        sample["image"] = np.moveaxis(image, -1, 0)
        sample["mask"] = np.expand_dims(mask, 0)
        sample["trimap"] = np.expand_dims(trimap, 0)

        return sample