In [None]:
import os

import cv2 as cv2
import numpy as np
import rasterio as rio
import torch
from matplotlib import pyplot as plt
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

# Seed the torch and NumPy random number generators with a fixed value
# so that code drawing from these generators is repeatable between runs.
torch.manual_seed(3)
np.random.seed(3)

In [None]:
class SmokePlumeDataset(Dataset):
    """Grayscale PNG image dataset labelled by directory name.

    ``datadir`` is walked recursively. A ``.png`` file is a positive example
    (label ``True``) when the path of its directory contains ``'positive'``
    and a negative example (label ``False``) when it contains ``'negative'``
    (``'positive'`` is checked first). PNG files in any other directory are
    skipped so that ``imgfiles`` and ``labels`` always stay aligned.

    Attributes:
        datadir: root directory that was scanned.
        transform: optional callable applied to every sample dict.
        imgfiles: array of image file paths, one entry per sample.
        labels: boolean array, ``labels[i]`` is the label of ``imgfiles[i]``.
        positive_indices: integer array of indices ``i`` with
            ``labels[i] == True``.
        negative_indices: integer array of indices ``i`` with
            ``labels[i] == False``.
    """

    def __init__(self, datadir='path/to/gray_images', mult=1, transform=None):
        """Collect image file names and labels below ``datadir``.

        Args:
            datadir: directory to scan recursively for ``.png`` files.
            mult: if greater than 1, the file and label lists are repeated
                ``mult`` times to enlarge the data set.
            transform: optional callable that receives and returns the
                sample dict built in ``__getitem__``.
        """
        self.datadir = datadir
        self.transform = transform

        found_files = []
        found_labels = []
        for root, _dirs, files in os.walk(datadir):
            # The label only depends on the directory, so decide it once
            # per directory and skip directories without a label.
            if 'positive' in root:
                label = True   # smoke plume present
            elif 'negative' in root:
                label = False  # no smoke plume present
            else:
                continue
            for filename in files:
                if filename.endswith('.png'):
                    found_files.append(os.path.join(root, filename))
                    found_labels.append(label)

        # Explicit dtypes keep empty data sets usable (an empty list would
        # otherwise turn into a float64 array).
        self.imgfiles = np.array(found_files, dtype=str)
        self.labels = np.array(found_labels, dtype=bool)

        # increase data set size by factor `mult`
        if mult > 1:
            self.imgfiles = np.tile(self.imgfiles, mult)
            self.labels = np.tile(self.labels, mult)

        # Derive the index arrays from the final label array so that they
        # cover every (possibly repeated) sample of the data set.
        self.positive_indices = np.flatnonzero(self.labels)
        self.negative_indices = np.flatnonzero(~self.labels)

    def __len__(self):
        """Return the number of samples (including repetitions)."""
        return len(self.imgfiles)

    def __getitem__(self, idx):
        """Load sample ``idx`` from disk.

        Returns:
            A dict with keys ``'idx'``, ``'img'`` (grayscale image data as
            read by ``cv2.imread``), ``'lbl'`` and ``'imgfile'``; if a
            transform was given, the result of applying it to that dict.

        Raises:
            OSError: if the image file could not be read.
        """
        imgfile = self.imgfiles[idx]
        imgdata = cv2.imread(imgfile, cv2.IMREAD_GRAYSCALE)
        if imgdata is None:
            # cv2.imread signals failure by returning None instead of raising
            raise OSError('could not read image file: {}'.format(imgfile))

        sample = {
            'idx': idx,
            'img': imgdata,
            'lbl': self.labels[idx],
            'imgfile': imgfile
        }

        if self.transform:
            sample = self.transform(sample)

        return sample

    def display(self, idx, target_size=(256, 256)):
        """Show the image of sample ``idx`` in a small matplotlib figure.

        Args:
            idx: index of the sample to show.
            target_size: currently unused; kept so existing calls keep
                working.
        """
        imgdata = np.asarray(self[idx]['img'], dtype=float)

        # Transformed samples carry a leading channel axis of length one;
        # drop it so that imshow receives a 2-D image.
        if imgdata.ndim == 3 and imgdata.shape[0] == 1:
            imgdata = imgdata[0]

        # Scale image data to [0, 1]; a constant image has no value range
        # to stretch and is shown as all zeros instead of dividing by zero.
        lowest = imgdata.min()
        highest = imgdata.max()
        if highest > lowest:
            imgdata = (imgdata - lowest) / (highest - lowest)
        else:
            imgdata = np.zeros_like(imgdata)

        f, ax = plt.subplots(1, 1, figsize=(3, 3))
        ax.imshow(imgdata, cmap='gray')
        ax.axis('off')

        plt.show()


class ToTensor(object):
    """Turn the image array of a sample into a torch tensor.

    The returned dict is a new dict with the keys ``'idx'``, ``'img'``,
    ``'lbl'`` and ``'imgfile'``. Only ``'img'`` is converted: the array is
    copied, wrapped in a tensor and given a leading axis of length one.
    All other entries are passed through unchanged.
    """
    def __call__(self, sample):
        converted = {}
        for key in ('idx', 'img', 'lbl', 'imgfile'):
            value = sample[key]
            if key == 'img':
                # copy first so the tensor does not share memory with the
                # input array, then add the extra dimension for the channel
                value = torch.from_numpy(value.copy()).unsqueeze(0)
            converted[key] = value
        return converted

class Normalize(object):
    """Shift and scale the pixel values of ``sample['img']``.

    Every pixel value ``p`` is replaced by
    ``(p - channel_mean) / channel_std``. With the default values of 127.5
    for both parameters, an input range of 0-255 is mapped to [-1, +1].

    Args:
        channel_mean: value subtracted from every pixel.
        channel_std: value every shifted pixel is divided by; must be
            non-zero.

    Raises:
        ValueError: if ``channel_std`` is (or contains) zero.
    """
    def __init__(self, channel_mean=127.5, channel_std=127.5):
        if np.any(np.asarray(channel_std) == 0):
            raise ValueError('channel_std must be non-zero')
        self.channel_mean = channel_mean  # default assumes the range is 0-255
        self.channel_std = channel_std

    def __call__(self, sample):
        """Normalize ``sample['img']`` in place and return the same dict."""
        sample['img'] = (sample['img'] - self.channel_mean) / self.channel_std
        return sample

def create_dataset(*args, apply_transforms=True, **kwargs):
    """Build a ``SmokePlumeDataset``, optionally with the default transforms.

    Positional and keyword arguments are forwarded to ``SmokePlumeDataset``.
    When ``apply_transforms`` is true, the data set is given a pipeline that
    runs ``Normalize`` followed by ``ToTensor``; otherwise its transform is
    ``None``.
    """
    pipeline = None
    if apply_transforms:
        pipeline = transforms.Compose([Normalize(), ToTensor()])

    return SmokePlumeDataset(*args, **kwargs, transform=pipeline)
