In [1]:
try:
    from nvidia.dali.pipeline import pipeline_def
    import nvidia.dali.types as types
    import nvidia.dali.fn as fn
    from nvidia.dali.plugin.pytorch import DALIGenericIterator
except:
    !pip install --extra-index-url https://developer.download.nvidia.com/compute/redist --upgrade nvidia-dali-cuda110
    from nvidia.dali.pipeline import pipeline_def
    import nvidia.dali.types as types
    import nvidia.dali.fn as fn
    from nvidia.dali.plugin.pytorch import DALIGenericIterator
    
from pathlib import Path
try:
    from deli import load_text
except:
    !pip install deli > None
    from deli import load_text
import numpy as np
import matplotlib.pyplot as plt
from random import shuffle
import torch
import time
from tqdm.notebook import tqdm
from satellite_connectome import Satellite # utility script, first time it will run pip install inside of it
from gc import collect

Looking in indexes: https://pypi.org/simple, https://developer.download.nvidia.com/compute/redist
Collecting nvidia-dali-cuda110
  Downloading https://developer.download.nvidia.com/compute/redist/nvidia-dali-cuda110/nvidia_dali_cuda110-1.26.0-8269290-py3-none-manylinux2014_x86_64.whl (488.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m488.5/488.5 MB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: nvidia-dali-cuda110
Successfully installed nvidia-dali-cuda110-1.26.0
[0m

In [2]:
from connectome import Source, Transform, Chain, meta, impure
from random import shuffle

class Patching(Transform):
    """
    Transform that crops an image into square patches built around its boxes.

    Patch boundaries are derived once per sample (privately, in
    ``_patch_boundaries``) and are consumed by two outputs: ``patches``
    (the cropped images) and ``box_list`` (the boxes of every patch,
    expressed in patch coordinates).
    """
    __inherit__ = ['coords', 'dataset', 'group', 'id', 'ids', 'mask']
    # Half of the patch side length in pixels (patches are 2 * _half_patch_shape wide).
    _half_patch_shape: int = 512
    # Drop patches that only contain boxes already covered by a previously kept patch.
    _avoid_duplicates: bool = True
    # Add as many box-free patches as there are patches with boxes.
    _sample_empty_patches: bool = True
    
    @impure
    def _patch_boundaries(image, coords, _half_patch_shape, _avoid_duplicates, _sample_empty_patches):
        """
        Takes image (just for its shape) and coords.
        Creates patch coordinates around every box center (the patch center is
        drawn at random so that the patch stays inside the picture).
        Makes one random patch if there are no boxes.
        
        Args :
        *image - np.array with shape (H, W, 3)
        *coords - None or np.array of coords of boxes in format of [..., [x, y, w, h], ...], where (x, y) - center
        *_half_patch_shape - half of patch size
        *_avoid_duplicates - bool, iteratively delete patches where boxes are already in the previous patches
        *_sample_empty_patches - bool, if True, samples as many patches without boxes as there are patches with boxes
        
        Returns : 
        *patch_boundaries - np.array of patches with format [..., [start, stop], ...] where start&stop = [y, x]
        """
        img_shape = image.shape[:2]
        patch_boundaries = []
        if coords is None:
            y_patch_center = np.random.randint(_half_patch_shape, img_shape[0] - _half_patch_shape)
            x_patch_center = np.random.randint(_half_patch_shape, img_shape[1] - _half_patch_shape)
            
            y_patch_start = y_patch_center - _half_patch_shape
            y_patch_stop = y_patch_center + _half_patch_shape
            x_patch_start = x_patch_center - _half_patch_shape
            x_patch_stop = x_patch_center + _half_patch_shape
            start = [y_patch_start, x_patch_start]
            stop = [y_patch_stop, x_patch_stop]
            patch_boundaries.append([start, stop])
            
            patch_boundaries = np.asarray(patch_boundaries).astype(np.uint16)
            
        else:
            for coord in coords:
                x_central, y_central, w, h =  coord
                x_central = int(x_central)
                y_central = int(y_central)

                # The lower bound only has to keep the patch inside the image.
                # It used to be additionally capped at 1e4, which on images
                # larger than ~10000 px produced patches far away from the box
                # (and turned the bound into a float).
                y_low = max(y_central - _half_patch_shape, _half_patch_shape)
                y_high = np.clip(y_central + _half_patch_shape, a_min=0, a_max=img_shape[0] - _half_patch_shape)
                y_patch_center = np.random.randint(y_low, y_high)
                x_low = max(x_central - _half_patch_shape, _half_patch_shape)
                x_high = np.clip(x_central + _half_patch_shape, a_min=0, a_max=img_shape[1] - _half_patch_shape)
                x_patch_center = np.random.randint(x_low, x_high)

                y_patch_start = y_patch_center - _half_patch_shape
                y_patch_stop = y_patch_center + _half_patch_shape
                x_patch_start = x_patch_center - _half_patch_shape
                x_patch_stop = x_patch_center + _half_patch_shape
                start = [y_patch_start, x_patch_start]
                stop = [y_patch_stop, x_patch_stop]
                patch_boundaries.append([start, stop])
            
            patch_boundaries = np.asarray(patch_boundaries).astype(np.uint16)
                
            if _avoid_duplicates:
                # Visit the patches in random order and keep a patch only if it
                # covers at least one box that no kept patch has covered yet.
                np.random.shuffle(patch_boundaries)
                coords_taken = np.zeros(len(coords), dtype=bool)
                remain_patches = np.zeros(len(patch_boundaries), dtype=bool)
                for patch_idx, boundary in enumerate(patch_boundaries):
                    start, stop = boundary
                    patch_y_start, patch_x_start = start
                    patch_y_stop, patch_x_stop = stop

                    for coord_idx, coord in enumerate(coords):
                        if if_coords_in_patch(coord, patch_y_start, patch_x_start, patch_y_stop, patch_x_stop) and (not coords_taken[coord_idx]):
                            remain_patches[patch_idx] = True
                            coords_taken[coord_idx] = True
                patch_boundaries = patch_boundaries[remain_patches]
            
            if _sample_empty_patches: 
                empty_patch_boundaries = []
                for _ in range(len(patch_boundaries)):
                    y_patch_center = np.random.randint(_half_patch_shape, img_shape[0] - _half_patch_shape)
                    x_patch_center = np.random.randint(_half_patch_shape, img_shape[1] - _half_patch_shape)
                    
                    y_patch_start = y_patch_center - _half_patch_shape
                    y_patch_stop = y_patch_center + _half_patch_shape
                    x_patch_start = x_patch_center - _half_patch_shape
                    x_patch_stop = x_patch_center + _half_patch_shape
                    
                    # Rejection sampling: redraw until no box center is inside the patch.
                    coords_in_patch = True
                    while coords_in_patch:
                        coords_in_patch = False
                        for coord in coords:
                            if if_coords_in_patch(coord, y_patch_start, x_patch_start, y_patch_stop, x_patch_stop):
                                coords_in_patch = True
                                break
                                
                        if coords_in_patch:
                            y_patch_center = np.random.randint(_half_patch_shape, img_shape[0] - _half_patch_shape)
                            x_patch_center = np.random.randint(_half_patch_shape, img_shape[1] - _half_patch_shape)
                            
                            y_patch_start = y_patch_center - _half_patch_shape
                            y_patch_stop = y_patch_center + _half_patch_shape
                            x_patch_start = x_patch_center - _half_patch_shape
                            x_patch_stop = x_patch_center + _half_patch_shape
                            
                    start = [y_patch_start, x_patch_start]
                    stop = [y_patch_stop, x_patch_stop]
                    empty_patch_boundaries.append([start, stop])
                empty_patch_boundaries = np.asarray(empty_patch_boundaries).astype(np.uint16)
                patch_boundaries = np.vstack((patch_boundaries, empty_patch_boundaries))
                
        return patch_boundaries
    
    def box_list(boxes, coords, _patch_boundaries, _half_patch_shape):
        """
        !!!Replaces 'boxes' with 'box_list'
        Takes boxes and previously made patch boundaries.
        Creates a list of boxes which corresponds to the patches,
        e.g. for _patch_boundaries[j] box_list[j] is np.array of boxes
        with format [..., [start, stop], ...], where start and stop = [y, x] and [start, stop]
        are written in _patch_boundaries[j] coordinates
        
        Args :
        *boxes - None or np.array of boxes with format [..., [start, stop], ...], where start and stop = [y, x]
        *coords - np.array of coords of boxes in format of [..., [x, y, w, h], ...], where (x, y) - center
        *_patch_boundaries - np.array of patches with format [..., [start, stop], ...] where start&stop = [y, x]
        *_half_patch_shape - half of patch size
        
        Returns : 
        *box_list - None or list of np.array of boxes with format [..., [start, stop], ...],
        where start and stop = [y, x] and [start, stop] with correspondence to _patch_boundaries
        """
        if boxes is None:
            return None
        else:
            box_list = []
            for boundary in _patch_boundaries:
                start, stop = boundary
                patch_y_start, patch_x_start = start
                patch_y_stop, patch_x_stop = stop
                
                patch_upper_left_corner = start.reshape(1, 2)
                clip_box_min = start.reshape(1, 2).astype(np.uint16)
                clip_box_max = stop.reshape(1, 2).astype(np.uint16)
                
                boxes_in_patch = []
                for box, coord in zip(boxes, coords):
                    # A box belongs to the patch when its center is inside it;
                    # the box is then cut to the patch and shifted to patch coordinates.
                    if if_coords_in_patch(coord, patch_y_start, patch_x_start, patch_y_stop, patch_x_stop):
                        new_box = np.clip(box, clip_box_min, clip_box_max)
                        new_box -= patch_upper_left_corner
                        boxes_in_patch.append(new_box)
                boxes_in_patch = np.asarray(boxes_in_patch).astype(np.uint16)
                box_list.append(boxes_in_patch)
            return box_list
                    
    def patches(image, _patch_boundaries):
        """
        !!!Replaces 'image' with 'patches'
        Takes image and previously made _patch_boundaries.
        Creates list of patches.
        
        Args:
        *image - np.array with shape (H, W, 3) (the first two axes are sliced)
        *_patch_boundaries - np.array of patches with format [..., [start, stop], ...] where start&stop = [y, x]
        
        Returns:
        *patches - list of np.array with shape (2*_half_patch_shape, 2*_half_patch_shape, 3)
        """
        patches = []
        for boundary in _patch_boundaries:
            start, stop = boundary
            y_start, x_start = start
            y_stop, x_stop = stop
            patches.append(image[y_start:y_stop,x_start:x_stop,:])
        return patches
    
def if_coords_in_patch(coord, patch_y_start, patch_x_start, patch_y_stop, patch_x_stop):
    """
    Tell whether the center of a box lies strictly inside a patch.

    Args :
    *coord - box in format [x, y, w, h], where (x, y) - center
    *patch_y_start, patch_x_start - upper left corner of the patch
    *patch_y_stop, patch_x_stop - lower right corner of the patch

    Returns :
    *True if the (integer-truncated) center is strictly between the patch
    borders on both axes, False otherwise (a center on a border is outside)
    """
    center_x, center_y, _width, _height = coord
    col = int(center_x)
    row = int(center_y)
    if not (patch_y_start < row < patch_y_stop):
        return False
    if not (patch_x_start < col < patch_x_stop):
        return False
    return True

In [3]:
class ExternalInputCallable:
    """
    Per-sample source for a DALI ``external_source``.

    Every call reads one encoded image found under ``directory``, loads the
    boxes from the matching label file and samples a square patch, returning
    the data needed to decode just that patch.

    Args :
    *directory - root folder searched recursively for images
    *batch_size - batch size, used only to compute the number of full batches per epoch
    *file_extension - extension (without dot) of the image files
    *half_patch_size - half of the side length of the sampled patch
    *prob_sample_with_boxes - probability to anchor the patch on a box when the image has boxes
    *pad_to - number of rows the per-patch box array is padded to
    *make_shuffle - if True the images are visited in a per-epoch random order,
    otherwise in sorted path order
    """
    def __init__(self,
                 directory="/kaggle/input/satelite-dataset/train",
                 batch_size=16, 
                 file_extension='JPG',
                 half_patch_size=512,
                 prob_sample_with_boxes=0.5,
                 pad_to=8,
                 make_shuffle=True
                ):
        self.batch_size = batch_size
        self.prob_sample_with_boxes = prob_sample_with_boxes
        self.pad_to = pad_to
        self.half_patch_size = half_patch_size
        # Previously accepted but ignored: the order was always shuffled.
        self.make_shuffle = make_shuffle
        # glob() yields files in a file-system dependent order; sorting makes
        # the seeded permutation below select the same files on every run.
        self.image_pathes = sorted(Path(directory).glob('**/*.{}'.format(file_extension)))

        self.full_iterations = len(self.image_pathes) // batch_size

        self.perm = None
        self.last_seen_epoch = None

    def _epoch_order(self, epoch_idx):
        """Return the order in which the images are visited during the given epoch."""
        if self.make_shuffle:
            return np.random.default_rng(seed=42 + epoch_idx).permutation(len(self.image_pathes))
        return np.arange(len(self.image_pathes))

    def __call__(self, sample_info):
        """
        Produce one sample.

        Args :
        *sample_info - object with the attributes iteration, epoch_idx and idx_in_epoch

        Returns :
        *tuple (encoded image bytes as uint8 array, patch upper left corner [x, y],
        padded boxes of the patch, padding mask, patch shape [w, h])

        Raises :
        *StopIteration - when all full batches of the epoch were produced
        *ValueError - when the image size can not be read from the file
        """
        if sample_info.iteration >= self.full_iterations:
            raise StopIteration
        if self.last_seen_epoch != sample_info.epoch_idx:
            self.last_seen_epoch = sample_info.epoch_idx
            self.perm = self._epoch_order(sample_info.epoch_idx)
        sample_idx = self.perm[sample_info.idx_in_epoch]

        image_path = self.image_pathes[sample_idx]
        with open(image_path, 'rb') as f:
            encoded_img = np.frombuffer(f.read(), dtype=np.uint8)

        img_shape = get_jpeg_size(encoded_img)
        if img_shape is False:
            # get_jpeg_size reports failure with False; fail here with the file
            # name instead of an obscure error further down.
            raise ValueError('can not read the image size from {}'.format(image_path))
        label_path = Path(str(image_path).replace("image", "label")).with_suffix(".txt")
        coords = get_coords(label_path, img_shape)
        patch_left_up_corner, coords_in_patch, mask = bernulli_sampling(
            img_shape, coords, self.half_patch_size, self.prob_sample_with_boxes, self.pad_to)

        patch_shape = np.array([2 * self.half_patch_size, 2 * self.half_patch_size]).astype(int)
        return encoded_img, patch_left_up_corner.astype(int), coords_in_patch.astype(int), mask.astype(int), patch_shape
    
def bernulli_sampling(img_shape, coords, half_patch_shape=512, prob_sample_with_boxes=0.5, pad_to=5):
    """
    Sample one patch: anchored on a random box or free of boxes.

    Args :
    *img_shape - image shape in format (H, W)
    *coords - boxes of the image in format [..., [x, y, w, h], ...]
    *half_patch_shape - half of patch size
    *prob_sample_with_boxes - threshold for the uniform draw; an anchored patch
    is taken when the image has boxes and the draw does not exceed it
    *pad_to - number of rows the boxes of the patch are padded to

    Returns :
    *tuple (patch upper left corner, padded boxes of the patch, padding mask)
    """
    # The uniform number is drawn only when the image has boxes at all.
    anchored = len(coords) != 0 and not (np.random.rand(1) > prob_sample_with_boxes)
    if anchored:
        anchor_coord = coords[np.random.randint(len(coords))]
        corner = get_patch_left_up_corner_anchor(img_shape, anchor_coord, half_patch_shape)
    else:
        # sample patch without box
        corner = get_empty_patch_left_up_corner(img_shape, coords, half_patch_shape)
    boxes_in_patch, box_mask = get_coords_in_patch(corner, half_patch_shape, coords, pad_to)
    return corner, boxes_in_patch, box_mask
        
def get_coords_in_patch(left_up_corner, half_patch_shape, coords, pad_to=5):
    """
    Collect the boxes whose center is inside the patch, in patch coordinates.

    Every selected box is cut to the patch, shifted so that the patch upper
    left corner becomes the origin and converted back to [x, y, w, h].

    Args :
    *left_up_corner - upper left corner of the patch in format [x, y]
    *half_patch_shape - half of patch size
    *coords - boxes of the image in format [..., [x, y, w, h], ...], where (x, y) - center
    *pad_to - number of rows of the result

    Returns :
    *coords_in_patch - integer np.array with shape (pad_to, 4), boxes first, zero rows after them
    *mask - np.array with shape (pad_to, 4), ones in the rows of real boxes, zeros in the padding

    Raises :
    *ValueError - if more than pad_to boxes are inside the patch
    """
    x_start_patch, y_start_patch = left_up_corner
    patch_side = 2 * half_patch_shape
    x_stop_patch, y_stop_patch = x_start_patch + patch_side, y_start_patch + patch_side
    # The clip bounds depend on the patch only, so they are built once.
    clip_box_min = np.asarray(left_up_corner).reshape(1, 2)
    clip_box_max = clip_box_min + patch_side

    rows = []
    for coord in coords:
        if not if_coords_in_patch(coord, y_start_patch, x_start_patch, y_stop_patch, x_stop_patch):
            continue
        x, y, w, h = coord
        box = np.array([[x - w / 2, y - h / 2], [x + w / 2, y + h / 2]])
        box = np.clip(box, clip_box_min, clip_box_max) - clip_box_min
        (x_min, y_min), (x_max, y_max) = box
        rows.append([(x_min + x_max) / 2, (y_min + y_max) / 2, x_max - x_min, y_max - y_min])

    num_boxes = len(rows)
    if pad_to < num_boxes:
        # The details go into the exception instead of being printed.
        raise ValueError(
            'pad_to can not be less than num boxes in patch (pad_to={}, boxes={})'.format(pad_to, num_boxes)
        )

    # One code path for the empty, partially filled and full case, so the
    # result always has the same dtype.
    coords_in_patch = np.zeros((pad_to, 4), dtype=int)
    mask = np.zeros((pad_to, 4))
    if num_boxes:
        coords_in_patch[:num_boxes] = np.asarray(rows, dtype=int)
        mask[:num_boxes] = 1

    return coords_in_patch, mask
        
def get_empty_patch_left_up_corner(img_shape, coords, half_patch_shape, max_attempts=10000):
    """
    Sample a patch that contains no box center (rejection sampling).

    Args :
    *img_shape - image shape in format (H, W)
    *coords - boxes of the image in format [..., [x, y, w, h], ...], where (x, y) - center
    *half_patch_shape - half of patch size
    *max_attempts - number of random patches tried before giving up

    Returns :
    *np.array [x, y] - upper left corner of the patch

    Raises :
    *RuntimeError - if no box-free patch was found in max_attempts draws
    (the search used to loop forever in that case)
    *ValueError - raised by np.random.randint if the image is smaller than the patch
    """
    height, width = int(img_shape[0]), int(img_shape[1])
    for _ in range(max_attempts):
        # randint excludes the upper bound; "+ 1" makes the patches that touch
        # the bottom / right image border reachable as well.
        y_patch_center = np.random.randint(half_patch_shape, height - half_patch_shape + 1)
        x_patch_center = np.random.randint(half_patch_shape, width - half_patch_shape + 1)

        y_patch_start = y_patch_center - half_patch_shape
        y_patch_stop = y_patch_center + half_patch_shape
        x_patch_start = x_patch_center - half_patch_shape
        x_patch_stop = x_patch_center + half_patch_shape

        has_box = any(
            if_coords_in_patch(coord, y_patch_start, x_patch_start, y_patch_stop, x_patch_stop)
            for coord in coords
        )
        if not has_box:
            return np.array([x_patch_start, y_patch_start], dtype=int)

    raise RuntimeError(
        'no patch without boxes found in {} attempts (image shape {}x{}, half patch {})'.format(
            max_attempts, height, width, half_patch_shape)
    )

def get_patch_left_up_corner_anchor(img_shape, anchor_coord, half_patch_shape):
    """
    Sample a patch inside the image that contains the center of the anchor box.

    Args :
    *img_shape - image shape in format (H, W)
    *anchor_coord - box in format [x, y, w, h], where (x, y) - center
    *half_patch_shape - half of patch size

    Returns :
    *np.array [x, y] - upper left corner of the patch

    Raises :
    *ValueError - if the image is smaller than the patch
    """
    x_central, y_central, w, h = anchor_coord
    x_central = int(x_central)
    y_central = int(y_central)

    # y is drawn before x, like the rest of the samplers in this file.
    y_patch_center = _sample_patch_center(y_central, int(img_shape[0]), half_patch_shape)
    x_patch_center = _sample_patch_center(x_central, int(img_shape[1]), half_patch_shape)

    y_patch_start = y_patch_center - half_patch_shape
    x_patch_start = x_patch_center - half_patch_shape
    return np.array([x_patch_start, y_patch_start])


def _sample_patch_center(box_center, extent, half_patch_shape):
    """Draw a patch center along one axis so that the patch fits into [0, extent] and holds box_center."""
    lowest = half_patch_shape
    highest = extent - half_patch_shape  # inclusive: the patch may end exactly at the image border
    if highest < lowest:
        raise ValueError(
            'image extent {} is smaller than the patch size {}'.format(extent, 2 * half_patch_shape)
        )
    # The box center has to be strictly between patch start and patch stop,
    # hence the "+ 1" / "- 1". There is no upper cap on the lower bound: the
    # former 1e4 cap detached the patch from the box on images above ~10000 px.
    low = max(box_center - half_patch_shape + 1, lowest)
    high = min(box_center + half_patch_shape - 1, highest)
    if low > high:
        # The box center is on or outside the image border, so no patch can hold it;
        # take the closest valid patch instead of failing.
        return min(max(box_center, lowest), highest)
    return np.random.randint(low, high + 1)
    
def get_coords(path, shape):
    """
    Read YOLO labels and convert them to pixel coordinates.

    Every non-empty line is either "class x y w h" or "x y w h", with the
    values relative to the image size. The class id, if present, is dropped.

    Args :
    *path - path to the label file
    *shape - image shape in format (H, W)

    Returns :
    *integer np.array of boxes in format [..., [x, y, w, h], ...], where (x, y) - center,
    with shape (0, 4) if the file has no boxes

    Raises :
    *ValueError - if a line has neither 4 nor 5 fields or a field is not a number
    """
    rows = []
    for line in load_text(path).splitlines():
        # Split into fields instead of deleting every "0 " substring: that also
        # removed the tail of values such as "0.500000 " and glued the numbers
        # of the line together.
        fields = line.split()
        if not fields:
            continue  # blank line
        if len(fields) == 5:
            fields = fields[1:]  # drop the class id
        if len(fields) != 4:
            raise ValueError('can not parse label line {!r} in {}'.format(line, path))
        rows.append([float(field) for field in fields])

    if not rows:  # no boxes
        return np.zeros((0, 4), dtype=int)

    y_shape, x_shape = shape
    y_shape, x_shape = int(y_shape), int(x_shape)
    _coords = [
        [round(x * x_shape), round(y * y_shape), round(w * x_shape), round(h * y_shape)]
        for x, y, w, h in rows
    ]
    return np.asarray(_coords).astype(int)
    
def coords2boxes(coords):
    """
    Convert boxes from [x, y, w, h] to [start, stop] format.

    Args :
    *coords - boxes in format [..., [x, y, w, h], ...], where (x, y) - center

    Returns :
    *integer np.array of the results of coords2box, one per box,
    or an empty uint16 np.array with shape (0, 0, 0) if there are no boxes
    """
    if len(coords) != 0:
        return np.asarray([coords2box(box) for box in coords]).astype(int)
    return np.array([], dtype=np.uint16).reshape(0, 0, 0)


def coords2box(coords):
    """
    Convert one box from [x, y, w, h] to (start, stop) corners.

    Args :
    *coords - box in format [x, y, w, h], where (x, y) - center

    Returns :
    *tuple (start, stop) of integer np.arrays in format [y, x]; the corners
    are not clipped, so they are negative for a box that crosses the
    top / left image border
    """
    # Signed integers: with uint16 a negative corner wrapped around to a huge
    # value (or raised OverflowError, depending on the NumPy version).
    center = np.array((coords[1], coords[0]), dtype=np.int64)
    half_size = np.array((coords[3] // 2, coords[2] // 2), dtype=np.int64)
    start = center - half_size
    stop = center + half_size
    return (start, stop)

def if_coords_in_patch(coord, patch_y_start, patch_x_start, patch_y_stop, patch_x_stop):
    """
    Check that the box center is strictly inside the patch.

    Args :
    *coord - box in format [x, y, w, h], where (x, y) - center
    *patch_y_start, patch_x_start - upper left corner of the patch
    *patch_y_stop, patch_x_stop - lower right corner of the patch

    Returns :
    *bool - True only if the integer part of the center is strictly between
    the patch borders along y and along x
    """
    x_central, y_central, _, _ = coord
    column, row = int(x_central), int(y_central)
    row_inside = patch_y_start < row < patch_y_stop
    return bool(row_inside and patch_x_start < column < patch_x_stop)
    
def get_jpeg_size(data):
    """
    Read the image size from the header of an encoded JPEG without decoding it.

    The segments that follow the start-of-image marker are walked until a
    start-of-frame segment is found. Files that start with an application
    segment other than JFIF (e.g. Exif) and non-baseline frames are handled
    too.

    Args :
    *data - encoded file as a sequence of bytes (bytes or uint8 np.array)

    Returns :
    *np.array (height, width) with dtype uint16, or False if no frame header was found
    """
    data_size = len(data)
    # 0xFF 0xD8 is the start-of-image marker.
    if data_size < 4 or int(data[0]) != 0xFF or int(data[1]) != 0xD8:
        return False

    i = 2
    while i + 3 < data_size:
        if int(data[i]) != 0xFF:
            return False
        # int() everywhere: arithmetic on uint8 elements with values such as
        # 256 overflows (an error on recent NumPy versions).
        marker = int(data[i + 1])
        if marker == 0xFF:
            i += 1  # fill byte in front of a marker
            continue
        if marker == 0x01 or 0xD0 <= marker <= 0xD7:
            i += 2  # markers without a length field
            continue
        if marker in (0xD9, 0xDA):
            return False  # image data / end of image reached without a frame header
        block_length = (int(data[i + 2]) << 8) | int(data[i + 3])
        if block_length < 2:
            return False
        # 0xC0..0xCF are start-of-frame markers, except 0xC4, 0xC8 and 0xCC.
        if 0xC0 <= marker <= 0xCF and marker not in (0xC4, 0xC8, 0xCC):
            if i + 8 >= data_size:
                return False
            # Layout: marker (2), length (2), precision (1), height (2), width (2).
            height = (int(data[i + 5]) << 8) | int(data[i + 6])
            width = (int(data[i + 7]) << 8) | int(data[i + 8])
            return np.asarray((height, width), dtype=np.uint16)
        i += 2 + block_length
    return False

In [4]:
# Source of the module that holds the external source callable. It is written
# to a file and imported because the pipeline runs the callable in worker
# processes started with 'spawn': the callable is referenced there through
# `external_input_tmp_module`, not through the notebook namespace.
# Fixes compared to the previous version of this module:
#   * get_coords splits the label lines into fields instead of deleting every
#     "0 " substring (which glued values such as "0.500000 0.3" together);
#   * make_shuffle is honored and the image list is sorted, so the seeded
#     per-epoch permutation is reproducible;
#   * get_jpeg_size does integer arithmetic on Python ints (uint8 arithmetic
#     overflows) and accepts non-JFIF / non-baseline files;
#   * patch sampling has no 1e4 cap, reaches the image border and can not
#     loop forever.
external_input_callable_def = """
import numpy as np
from deli import load_text
from pathlib import Path

class ExternalInputCallable:
    '''Per-sample source for a DALI external_source: one encoded image plus one sampled patch per call.'''
    def __init__(self,
                 directory="/kaggle/input/satelite-dataset/train",
                 batch_size=16,
                 file_extension='JPG',
                 half_patch_size=512,
                 prob_sample_with_boxes=0.5,
                 pad_to=8,
                 make_shuffle=True
                ):
        self.batch_size = batch_size
        self.prob_sample_with_boxes = prob_sample_with_boxes
        self.pad_to = pad_to
        self.half_patch_size = half_patch_size
        self.make_shuffle = make_shuffle
        # Sorted: glob() order depends on the file system.
        self.image_pathes = sorted(Path(directory).glob('**/*.{}'.format(file_extension)))

        self.full_iterations = len(self.image_pathes) // batch_size

        self.perm = None
        self.last_seen_epoch = None

    def _epoch_order(self, epoch_idx):
        '''Order in which the images are visited during the given epoch.'''
        if self.make_shuffle:
            return np.random.default_rng(seed=42 + epoch_idx).permutation(len(self.image_pathes))
        return np.arange(len(self.image_pathes))

    def __call__(self, sample_info):
        if sample_info.iteration >= self.full_iterations:
            raise StopIteration
        if self.last_seen_epoch != sample_info.epoch_idx:
            self.last_seen_epoch = sample_info.epoch_idx
            self.perm = self._epoch_order(sample_info.epoch_idx)
        sample_idx = self.perm[sample_info.idx_in_epoch]

        image_path = self.image_pathes[sample_idx]
        with open(image_path, 'rb') as f:
            encoded_img = np.frombuffer(f.read(), dtype=np.uint8)

        img_shape = get_jpeg_size(encoded_img)
        if img_shape is False:
            raise ValueError('can not read the image size from {}'.format(image_path))
        label_path = Path(str(image_path).replace("image", "label")).with_suffix(".txt")
        coords = get_coords(label_path, img_shape)
        patch_left_up_corner, coords_in_patch, mask = bernulli_sampling(
            img_shape, coords, self.half_patch_size, self.prob_sample_with_boxes, self.pad_to)

        patch_shape = np.array([2 * self.half_patch_size, 2 * self.half_patch_size]).astype(int)
        return encoded_img, patch_left_up_corner.astype(int), coords_in_patch.astype(int), mask.astype(int), patch_shape

def bernulli_sampling(img_shape, coords, half_patch_shape=512, prob_sample_with_boxes=0.5, pad_to=5):
    '''Sample one patch: anchored on a random box or free of boxes.'''
    if len(coords) == 0 or np.random.rand(1) > prob_sample_with_boxes:
        # sample patch without box
        patch_left_up_corner = get_empty_patch_left_up_corner(img_shape, coords, half_patch_shape)
    else:
        anchor_coord = coords[np.random.randint(len(coords))]
        patch_left_up_corner = get_patch_left_up_corner_anchor(img_shape, anchor_coord, half_patch_shape)
    coords_in_patch, mask = get_coords_in_patch(patch_left_up_corner, half_patch_shape, coords, pad_to)
    return patch_left_up_corner, coords_in_patch, mask

def get_coords_in_patch(left_up_corner, half_patch_shape, coords, pad_to=5):
    '''Boxes with the center inside the patch, in patch coordinates, padded to pad_to rows, plus the padding mask.'''
    x_start_patch, y_start_patch = left_up_corner
    patch_side = 2 * half_patch_shape
    x_stop_patch, y_stop_patch = x_start_patch + patch_side, y_start_patch + patch_side
    clip_box_min = np.asarray(left_up_corner).reshape(1, 2)
    clip_box_max = clip_box_min + patch_side

    rows = []
    for coord in coords:
        if not if_coords_in_patch(coord, y_start_patch, x_start_patch, y_stop_patch, x_stop_patch):
            continue
        x, y, w, h = coord
        box = np.array([[x - w / 2, y - h / 2], [x + w / 2, y + h / 2]])
        box = np.clip(box, clip_box_min, clip_box_max) - clip_box_min
        (x_min, y_min), (x_max, y_max) = box
        rows.append([(x_min + x_max) / 2, (y_min + y_max) / 2, x_max - x_min, y_max - y_min])

    num_boxes = len(rows)
    if pad_to < num_boxes:
        raise ValueError(
            'pad_to can not be less than num boxes in patch (pad_to={}, boxes={})'.format(pad_to, num_boxes)
        )

    coords_in_patch = np.zeros((pad_to, 4), dtype=int)
    mask = np.zeros((pad_to, 4))
    if num_boxes:
        coords_in_patch[:num_boxes] = np.asarray(rows, dtype=int)
        mask[:num_boxes] = 1

    return coords_in_patch, mask

def get_empty_patch_left_up_corner(img_shape, coords, half_patch_shape, max_attempts=10000):
    '''Upper left corner [x, y] of a random patch without box centers; RuntimeError after max_attempts draws.'''
    height, width = int(img_shape[0]), int(img_shape[1])
    for _ in range(max_attempts):
        # randint excludes the upper bound, "+ 1" makes the border patches reachable.
        y_patch_center = np.random.randint(half_patch_shape, height - half_patch_shape + 1)
        x_patch_center = np.random.randint(half_patch_shape, width - half_patch_shape + 1)

        y_patch_start = y_patch_center - half_patch_shape
        y_patch_stop = y_patch_center + half_patch_shape
        x_patch_start = x_patch_center - half_patch_shape
        x_patch_stop = x_patch_center + half_patch_shape

        has_box = any(
            if_coords_in_patch(coord, y_patch_start, x_patch_start, y_patch_stop, x_patch_stop)
            for coord in coords
        )
        if not has_box:
            return np.array([x_patch_start, y_patch_start], dtype=int)

    raise RuntimeError(
        'no patch without boxes found in {} attempts (image shape {}x{}, half patch {})'.format(
            max_attempts, height, width, half_patch_shape)
    )

def get_patch_left_up_corner_anchor(img_shape, anchor_coord, half_patch_shape):
    '''Upper left corner [x, y] of a random patch inside the image that holds the anchor box center.'''
    x_central, y_central, w, h = anchor_coord
    x_central = int(x_central)
    y_central = int(y_central)

    y_patch_center = _sample_patch_center(y_central, int(img_shape[0]), half_patch_shape)
    x_patch_center = _sample_patch_center(x_central, int(img_shape[1]), half_patch_shape)

    y_patch_start = y_patch_center - half_patch_shape
    x_patch_start = x_patch_center - half_patch_shape
    return np.array([x_patch_start, y_patch_start])

def _sample_patch_center(box_center, extent, half_patch_shape):
    '''Patch center along one axis so that the patch fits into [0, extent] and holds box_center.'''
    lowest = half_patch_shape
    highest = extent - half_patch_shape
    if highest < lowest:
        raise ValueError(
            'image extent {} is smaller than the patch size {}'.format(extent, 2 * half_patch_shape)
        )
    # The box center has to be strictly between patch start and patch stop.
    low = max(box_center - half_patch_shape + 1, lowest)
    high = min(box_center + half_patch_shape - 1, highest)
    if low > high:
        # Box center on or outside the image border: take the closest valid patch.
        return min(max(box_center, lowest), highest)
    return np.random.randint(low, high + 1)

def get_coords(path, shape):
    '''Read YOLO labels (class x y w h or x y w h per line) as pixel boxes [x, y, w, h].'''
    rows = []
    for line in load_text(path).splitlines():
        fields = line.split()
        if not fields:
            continue
        if len(fields) == 5:
            fields = fields[1:]  # drop the class id
        if len(fields) != 4:
            raise ValueError('can not parse label line {!r} in {}'.format(line, path))
        rows.append([float(field) for field in fields])

    if not rows:  # no boxes
        return np.zeros((0, 4), dtype=int)

    y_shape, x_shape = shape
    y_shape, x_shape = int(y_shape), int(x_shape)
    _coords = [
        [round(x * x_shape), round(y * y_shape), round(w * x_shape), round(h * y_shape)]
        for x, y, w, h in rows
    ]
    return np.asarray(_coords).astype(int)

def coords2boxes(coords):
    '''Convert boxes from [x, y, w, h] to [start, stop] format.'''
    if len(coords) == 0:
        return np.array([], dtype=np.uint16).reshape(0, 0, 0)
    slices = []
    for c in coords:
        slices.append(coords2box(c))
    return np.asarray(slices).astype(int)


def coords2box(coords):
    '''Convert one box from [x, y, w, h] to (start, stop) corners in format [y, x] (signed, not clipped).'''
    center = np.array((coords[1], coords[0]), dtype=np.int64)
    half_size = np.array((coords[3] // 2, coords[2] // 2), dtype=np.int64)
    return (center - half_size, center + half_size)

def if_coords_in_patch(coord, patch_y_start, patch_x_start, patch_y_stop, patch_x_stop):
    '''True if the box center is strictly inside the patch.'''
    x_central, y_central, w, h = coord
    x_central = int(x_central)
    y_central = int(y_central)
    if (patch_y_start < y_central < patch_y_stop) and (patch_x_start < x_central < patch_x_stop):
        return True
    else:
        return False

def get_jpeg_size(data):
    '''(height, width) as uint16 array read from the JPEG header, or False if there is no frame header.'''
    data_size = len(data)
    if data_size < 4 or int(data[0]) != 0xFF or int(data[1]) != 0xD8:
        return False

    i = 2
    while i + 3 < data_size:
        if int(data[i]) != 0xFF:
            return False
        marker = int(data[i + 1])
        if marker == 0xFF:
            i += 1  # fill byte in front of a marker
            continue
        if marker == 0x01 or 0xD0 <= marker <= 0xD7:
            i += 2  # markers without a length field
            continue
        if marker in (0xD9, 0xDA):
            return False
        block_length = (int(data[i + 2]) << 8) | int(data[i + 3])
        if block_length < 2:
            return False
        # 0xC0..0xCF are start-of-frame markers, except 0xC4, 0xC8 and 0xCC.
        if 0xC0 <= marker <= 0xCF and marker not in (0xC4, 0xC8, 0xCC):
            if i + 8 >= data_size:
                return False
            height = (int(data[i + 5]) << 8) | int(data[i + 6])
            width = (int(data[i + 7]) << 8) | int(data[i + 8])
            return np.asarray((height, width), dtype=np.uint16)
        i += 2 + block_length
    return False
"""

import importlib
import sys

with open("external_input_tmp_module.py", 'w') as f:
    f.write(external_input_callable_def)

# The file was created after the interpreter started: drop the cached
# directory listings so that the import system can see it.
importlib.invalidate_caches()
if "external_input_tmp_module" in sys.modules:
    # The cell was run before: a plain import would keep the stale module.
    external_input_tmp_module = importlib.reload(sys.modules["external_input_tmp_module"])
else:
    import external_input_tmp_module

In [5]:
class DALISatelliteDataloader:
    """
    Iterable over batches of decoded patches produced by a DALI pipeline.

    The loader creates the external source callable, hands it to
    ``parallel_pipeline``, builds the pipeline and wraps it into a
    ``DALIGenericIterator`` with the outputs 'image', 'coords' and 'mask'.
    Everything is created only when the class lives in the ``__main__``
    module (e.g. in the notebook itself); otherwise the loader stays empty.

    Args :
    *directory - root folder with the images
    *batch_size - batch size of the pipeline
    *num_threads, device_id, py_num_workers - passed to the pipeline
    *device - device of the image decoder
    *half_patch_size - half of patch size
    *prob_sample_with_boxes - probability to sample a patch anchored on a box
    *pad_to - number of rows the boxes of a patch are padded to
    *make_shuffle - passed to the external source callable
    """
    def __init__(
        self,
        directory="/kaggle/input/satelite-dataset/train",
        batch_size=16,
        num_threads=2,
        device='mixed',
        device_id=0,
        py_num_workers=4,
        half_patch_size=512,
        prob_sample_with_boxes=0.5,
        pad_to=8,
        make_shuffle=True
    ):
        # Defined up front so that __del__, __iter__ and __len__ also work on
        # a loader that was not (or only partially) set up.
        self.external_sourse = None
        self.pipeline = None
        self.dataloader = None
        if __name__ == '__main__':
            self.external_sourse = external_input_tmp_module.ExternalInputCallable(
                directory=directory,
                batch_size=batch_size, 
                half_patch_size=half_patch_size,
                prob_sample_with_boxes=prob_sample_with_boxes,
                pad_to=pad_to,
                make_shuffle=make_shuffle
            )
            self.pipeline = parallel_pipeline(
                batch_size=batch_size,
                num_threads=num_threads,
                device=device,
                device_id=device_id,
                py_num_workers=py_num_workers,
                py_start_method='spawn',
                # Without this the pipeline created its own callable with the
                # default settings and the arguments above had no effect.
                source=self.external_sourse
            )
            self.pipeline.build()
            self.dataloader = DALIGenericIterator(
                [self.pipeline],
                ['image', 'coords', 'mask'],
                auto_reset=True
            )

    def __del__(self):
        # Drop the references in the original order; pop() does not fail if
        # an attribute is already gone.
        for attribute in ('pipeline', 'dataloader', 'external_sourse'):
            self.__dict__.pop(attribute, None)
        collect()

    def __iter__(self):
        dataloader = getattr(self, 'dataloader', None)
        if dataloader is None:
            raise RuntimeError('the DALI pipeline of this loader was not created')
        return dataloader

    def __len__(self):
        """Number of full batches in one epoch (0 if the loader was not set up)."""
        source = getattr(self, 'external_sourse', None)
        return 0 if source is None else source.full_iterations

@pipeline_def
def parallel_pipeline(device='mixed', source=None):
    """
    Pipeline that decodes one patch per sample directly from the encoded image.

    Args :
    *device - device of the image decoder
    *source - per-sample callable for the external source returning
    (encoded image, patch corner, boxes, mask, patch shape); if None, an
    ExternalInputCallable with default settings is created for the batch
    size of the pipeline

    Returns :
    *decoded patch, boxes of the patch, mask of the boxes
    """
    if source is None:
        # `batch_size` is not a name visible inside this function (it is an
        # argument of the pipeline, not of the function), so it is taken from
        # the pipeline that is being defined.
        from nvidia.dali.pipeline import Pipeline
        source = external_input_tmp_module.ExternalInputCallable(
            batch_size=Pipeline.current().max_batch_size
        )
    encoded_img, patch_left_up_corner, coords_in_patch, mask, shape = fn.external_source(
        source=source, 
        num_outputs=5,
        batch=False,
        parallel=True)
    decode = fn.decoders.image_slice(
        encoded_img,
        patch_left_up_corner,
        shape,
        device=device,
        use_fast_idct=True,
        memory_stats=True,
        hw_decoder_load=1.,
        hybrid_huffman_threshold=1e6,
        jpeg_fancy_upsampling=False,
        preallocate_height_hint=0,
        preallocate_width_hint =0,
        affine=False
    )
    return decode, coords_in_patch, mask