In [17]:
import glob
from PIL import Image
import json
import numpy as np
import os
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
import torch
import random

In [10]:
class NeutronDataset(Dataset):
    """Map-style dataset pairing images with integer label masks.

    Each item is returned as ``(image, one_hot_label)`` float32 tensors; the
    one-hot class axis of the label is moved to the front.

    Args:
        data: indexable collection of images; each item must be (or be transformed
            into) a numpy array.
        target: indexable collection of integer label masks, aligned index-by-index
            with ``data``.
        n_classes: number of classes used for the one-hot encoding.
        transform: optional callable applied to BOTH the image and the label mask.
    """

    def __init__(self, data, target, n_classes=21, transform=None):
        self.n_classes = n_classes
        self.data = data
        self.target = target
        self.transform = transform

    def get_one_hot(self, targets, nb_classes):
        """One-hot encode integer class ids, appending a class axis of size ``nb_classes``.

        Args:
            targets: array-like of integer class ids (any shape; lists are accepted).
            nb_classes: number of classes.

        Returns:
            float64 ndarray of shape ``targets.shape + (nb_classes,)``.

        Raises:
            ValueError: if any id lies outside ``[0, nb_classes)``. Without this check
                negative ids would silently wrap around to the last classes.
        """
        targets = np.asarray(targets)
        if targets.size and (targets.min() < 0 or targets.max() >= nb_classes):
            raise ValueError(
                f"class ids must lie in [0, {nb_classes}), got range "
                f"[{targets.min()}, {targets.max()}]"
            )
        res = np.eye(nb_classes)[targets.reshape(-1)]
        return res.reshape(list(targets.shape) + [nb_classes])

    def __getitem__(self, index):
        """Return ``(image, one_hot_label)`` for ``index`` as float tensors."""
        x = self.data[index]
        y = self.target[index]

        if self.transform:
            # NOTE(review): the same callable is applied to the label mask. A random
            # augmentation would draw different parameters for image and mask, and an
            # intensity transform would corrupt the class ids — confirm the intended
            # transforms are deterministic and purely geometric.
            x = self.transform(x)
            y = self.transform(y)

        # (..., n_classes) -> (n_classes, ...) so the class axis comes first.
        y = np.moveaxis(self.get_one_hot(y.astype(int), self.n_classes), -1, 0)
        return torch.from_numpy(x).float(), torch.from_numpy(y).float()

    def __len__(self):
        return len(self.data)


class NeutronDataLoader(pl.LightningDataModule):
    """Lightning data module loading paired image/label files from one directory.

    Images are found with ``IMAGE_SUFFIX`` and labels with ``LABEL_SUFFIX`` (both are
    glob patterns). The i-th image in sorted order is paired with the i-th label.
    The data is split 80% / 10% / 10% into train / validation / test, in sorted
    filename order (no shuffling).
    """

    def __init__(self, data_dir: str = "C:/Users/Tobias/Downloads/HIDA-ufz_image_challenge/photos_annotated",
                 batch_size: int = 8,
                 num_workers: int = 1, transform=None):
        super().__init__()

        # Glob patterns (despite the *_SUFFIX names) used to find files in data_dir.
        self.LABEL_SUFFIX = "*.png"
        self.IMAGE_SUFFIX = "*.jpg"

        self.channels = 3
        self.image_shape = None  # (height, width); filled in by setup()
        self.train_data = None
        self.valid_data = None
        self.test_data = None

        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = transform

    @staticmethod
    def _read_label(label_file):
        """Read one label image as a 2-D array (first channel of multi-channel masks)."""
        with Image.open(label_file) as img:
            label = np.array(img)
        if label.ndim == 2:
            return label
        if label.ndim == 3:
            return label[:, :, 0]
        raise ValueError(f"Unexpected label shape {label.shape} in {label_file}")

    def load_data(self):
        """Load every image/label pair from ``data_dir``.

        Returns:
            ``(images, labels)``: images divided by 255 with the last (channel) axis
            moved to position 1, and labels stacked into one array. Assumes each image
            decodes to (H, W, C) — TODO confirm for non-RGB inputs.

        Raises:
            FileNotFoundError: if no image matches ``IMAGE_SUFFIX``.
            ValueError: if image and label counts differ, or a label has rank other than 2 or 3.
        """
        images = sorted(glob.glob(os.path.join(self.data_dir, self.IMAGE_SUFFIX)))
        labels = sorted(glob.glob(os.path.join(self.data_dir, self.LABEL_SUFFIX)))

        if not images:
            raise FileNotFoundError(
                f"No files matching {self.IMAGE_SUFFIX!r} in {self.data_dir!r}")
        # Pairing is purely positional: zip() would silently drop the unmatched tail and
        # shift every later pair, so refuse mismatched counts outright.
        if len(images) != len(labels):
            raise ValueError(
                f"Found {len(images)} images but {len(labels)} labels in {self.data_dir!r}")

        image_array = []
        label_array = []
        for image_file, label_file in zip(images, labels):
            # Read the label first so a bad label never leaves an orphaned image behind.
            label = self._read_label(label_file)
            with Image.open(image_file) as img:
                image_array.append(np.array(img))
            label_array.append(label)

        # Move the channel axis from last to position 1 (channels-first layout).
        image_array = np.moveaxis(np.array(image_array), -1, 1)
        return image_array / 255, np.array(label_array)

    def setup(self, stage=None):
        """Load the data and build the datasets for ``stage`` ('fit', 'test', or None for both)."""
        image_array, label_array = self.load_data()

        # image_array is (N, C, H, W) after load_data(); the spatial shape is the last two axes.
        self.image_shape = (image_array.shape[2], image_array.shape[3])

        length = image_array.shape[0]
        train_end = int(length * 0.8)
        valid_end = int(length * 0.9)

        if stage == 'fit' or stage is None:
            self.train_data = NeutronDataset(image_array[:train_end],
                                             label_array[:train_end],
                                             transform=self.transform)
            self.valid_data = NeutronDataset(image_array[train_end:valid_end],
                                             label_array[train_end:valid_end],
                                             transform=self.transform)

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.test_data = NeutronDataset(image_array[valid_end:],
                                            label_array[valid_end:],
                                            transform=self.transform)

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return DataLoader(self.train_data, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        """Deterministic (unshuffled) loader over the validation split."""
        return DataLoader(self.valid_data, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)

    def test_dataloader(self):
        """Deterministic (unshuffled) loader over the test split."""
        return DataLoader(self.test_data, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)


In [15]:
# NOTE(review): despite the name, this globs the *annotations* directory — the same
# pattern `labels` uses — so both lists hold identical .json paths. That appears to be
# deliberate (it makes it easy to check that the paired shuffle keeps entries aligned);
# point it at the real image directory before using it for training.
images = glob.glob("/home/robert/ds-wildfire/data/processed_data/extracted/annotations/*")
images.sort()

In [21]:
# Sorted annotation paths; sorting gives a stable order before any pairing/shuffling.
labels = glob.glob("/home/robert/ds-wildfire/data/processed_data/extracted/annotations/*")
labels.sort()

In [23]:
# Peek at entries 20-29 of the sorted list, for comparison with the same window of `labels`.
images[20:30]

['/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_MD_2015-07-25.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_MD_2015-07-28.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_MD_2015-08-04.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_MD_2015-08-14.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_MD_2015-08-17.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_MD_2015-08-24.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_MD_2015-08-27.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_MD_2015-09-03.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_MD_2015-09-13.json',
 '/home/robert/ds-wildfire/data/proce

In [24]:
# Same window of the sorted label list; it should line up entry-for-entry with `images`.
labels[20:30]

['/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_MD_2015-07-25.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_MD_2015-07-28.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_MD_2015-08-04.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_MD_2015-08-14.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_MD_2015-08-17.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_MD_2015-08-24.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_MD_2015-08-27.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_MD_2015-09-03.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_MD_2015-09-13.json',
 '/home/robert/ds-wildfire/data/proce

In [25]:
# Shuffle images and labels with ONE shared permutation so every image keeps its label.
# A dedicated, seeded RNG makes the order reproducible across kernel restarts.
# Note: re-running this cell permutes the already-shuffled order again.
SHUFFLE_SEED = 42
if len(images) != len(labels):
    # zip() would silently drop the unmatched tail instead of failing.
    raise ValueError(f"cannot pair {len(images)} images with {len(labels)} labels")
paired = list(zip(images, labels))
random.Random(SHUFFLE_SEED).shuffle(paired)
# zip(*[]) yields nothing, so unpacking it into two names would fail on empty input.
images, labels = zip(*paired) if paired else ((), ())

In [26]:
# Entries 20-29 after the paired shuffle; compare with the same window of `labels`.
images[20:30]

('/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_ND_2016-10-30.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_NC_2016-09-07.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_T_NE_2015-12-22.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_T_PE_2015-11-12.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_PD_2015-10-20.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_T_ME_2015-07-28.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_NC_2016-05-13.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_PD_2016-04-20.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_T_PE_2015-07-25.json',
 '/home/robert/ds-wildfire/data/proce

In [27]:
# Same window of `labels` after the shuffle; matching the `images` window confirms the pairing survived.
labels[20:30]

('/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_ND_2016-10-30.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_NC_2016-09-07.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_T_NE_2015-12-22.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_T_PE_2015-11-12.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_PD_2015-10-20.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_T_ME_2015-07-28.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_NC_2016-05-13.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_PD_2016-04-20.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_T_PE_2015-07-25.json',
 '/home/robert/ds-wildfire/data/proce

In [18]:
# A shuffled copy of `images` alone. It is NOT permuted together with `labels`,
# so it must not be used to build image/label training pairs.
images_shuffle = random.sample(images, k=len(images))

In [20]:
# NOTE(review): displays the whole list; a slice such as `images[:10]` keeps the output small.
# Executed as In [20], before the paired shuffle (In [25]), so the output shows sorted order.
images

['/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_MC_2015-07-08.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_MC_2015-07-15.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_MC_2015-07-25.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_MC_2015-07-28.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_MC_2015-08-04.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_MC_2015-08-14.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_MC_2015-08-17.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_MC_2015-08-24.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_MC_2015-08-27.json',
 '/home/robert/ds-wildfire/data/proce

In [19]:
# NOTE(review): displays the whole shuffled list; a slice such as `images_shuffle[:10]` keeps the output small.
images_shuffle

['/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_PD_2015-12-19.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_T_NF_2015-12-25.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_MC_2015-08-04.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_T_ME_2015-07-25.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_PD_2016-10-14.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_ND_2016-11-06.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_S_NC_2015-07-25.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_T_PF_2015-09-13.json',
 '/home/robert/ds-wildfire/data/processed_data/extracted/annotations/annotations_29_T_PF_2015-08-24.json',
 '/home/robert/ds-wildfire/data/proce

In [41]:
# NOTE(review): `QGDataLoader` is not defined anywhere in this notebook, so this cell fails
# on a fresh kernel. It presumably comes from a deleted cell or an external module —
# TODO add the import/definition. It is not interchangeable with `NeutronDataLoader`:
# that module's datasets return a one-hot label with a leading class axis, whereas the
# `y` inspected below has shape (600, 800).
dl = QGDataLoader()
dl.setup()

In [42]:
# Grab the first training sample and one-hot encode its label mask.
# The one-hot encoding is inlined as `np.eye(n)[labels]` (equivalent to `get_one_hot`,
# which is only defined in a later cell) so this cell also works on a fresh kernel.
x, y = next(iter(dl.train_data))
foo = np.eye(17)[y.numpy().astype(int)]  # 17 = number of label classes (TODO confirm against the dataset)

In [43]:
# Shape of the one-hot encoded label: the label's shape plus a trailing class axis.
foo.shape

(600, 800, 17)

In [44]:
def get_one_hot(targets, nb_classes):
    """One-hot encode an array of integer class ids.

    Args:
        targets: array-like of integer class ids (any shape; lists are accepted too).
        nb_classes: number of classes, i.e. the size of the appended one-hot axis.

    Returns:
        float64 ndarray of shape ``targets.shape + (nb_classes,)``.

    Raises:
        ValueError: if any id is outside ``[0, nb_classes)``. Without this check
            negative ids would silently wrap around to the last classes.
    """
    targets = np.asarray(targets)
    if targets.size and (targets.min() < 0 or targets.max() >= nb_classes):
        raise ValueError(
            f"class ids must lie in [0, {nb_classes}), got range "
            f"[{targets.min()}, {targets.max()}]"
        )
    # Integer-array indexing into the identity matrix already appends the class axis.
    return np.eye(nb_classes)[targets]

In [45]:
# Raw label mask shape of the first training sample (before one-hot encoding).
y.shape

torch.Size([600, 800])