In [None]:
import os

import cv2 as cv2
import numpy as np
import rasterio as rio
import torch
from matplotlib import pyplot as plt
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

# Seed PyTorch's and NumPy's global random number generators with the same
# fixed value so that random draws repeat from run to run.
torch.manual_seed(3)
np.random.seed(3)

In [None]:
class SmokePlumeDataset(Dataset):
    """Dataset of multi-band GeoTIFF tiles labelled by smoke-plume presence.

    Files ending in ``.tif`` are collected recursively from ``datadir``. A
    file is a positive example when its directory path contains
    ``'positive'`` and a negative example when it contains ``'negative'``.
    Files in any other directory are ignored so that ``imgfiles`` and
    ``labels`` always line up one-to-one.

    Each item is a dict with the keys ``'idx'``, ``'img'`` (the mean over
    the selected bands, unscaled), ``'lbl'`` and ``'imgfile'``.

    Args:
        datadir: Root directory that is walked for ``.tif`` files.
        mult: Integer replication factor. When greater than 1, the file and
            label arrays are repeated ``mult`` times; other values leave
            them unchanged.
        transform: Optional callable applied to every sample dict.
        save_dir: Directory an 8-bit grayscale PNG of the item is written to
            on every item access. Pass ``None`` to skip writing.
    """

    # 1-based raster band numbers read for every tile. Band 11 is skipped --
    # presumably on purpose; confirm against the data source.
    _BANDS = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 13)

    def __init__(self, datadir='path/to/images', mult=1, transform=None,
                 save_dir='path/to/save/gray_images'):
        self.datadir = datadir
        self.transform = transform
        self.save_dir = save_dir

        imgfiles = []
        labels = []
        for root, _dirs, files in os.walk(datadir):
            # The label comes from the directory path; 'positive' is checked
            # first, so it wins if both words occur.
            if 'positive' in root:
                label = True
            elif 'negative' in root:
                label = False
            else:
                # Unlabelled directory: skip it so files and labels stay
                # aligned.
                continue
            for filename in files:
                if filename.endswith('.tif'):
                    imgfiles.append(os.path.join(root, filename))
                    labels.append(label)

        self.imgfiles = np.array(imgfiles)
        self.labels = np.array(labels, dtype=bool)

        # increase data set size by factor `mult`
        if mult > 1:
            self.imgfiles = np.tile(self.imgfiles, mult)
            self.labels = np.tile(self.labels, mult)

        # Derive the index arrays from the final label array so they also
        # cover the replicated copies.
        self.positive_indices = np.flatnonzero(self.labels)
        self.negative_indices = np.flatnonzero(~self.labels)

    def __len__(self):
        """Return the number of samples, including replicated copies."""
        return len(self.imgfiles)

    @staticmethod
    def _to_unit_range(values):
        """Linearly rescale ``values`` to [0, 1]; constant input maps to zeros."""
        lowest = np.min(values)
        span = np.max(values) - lowest
        if span == 0:
            # A constant image would otherwise divide by zero and yield NaNs.
            return np.zeros_like(values, dtype=float)
        return (values - lowest) / span

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as a dict.

        The selected bands are averaged into a single 2-D image. When
        ``save_dir`` is set, a 0-255 scaled copy is also written to
        ``<save_dir>/image_<idx>.png``.
        """
        # The context manager closes the raster handle after reading.
        with rio.open(self.imgfiles[idx]) as src:
            bands = np.array([src.read(band) for band in self._BANDS])

        # Convert image to grayscale
        gray_imgdata = np.mean(bands, axis=0)

        if self.save_dir is not None:
            preview = (self._to_unit_range(gray_imgdata) * 255).astype(np.uint8)
            save_path = os.path.join(self.save_dir, f'image_{idx}.png')
            # NOTE(review): the directory is not created here, and the return
            # value of imwrite is ignored, so a failed write goes unnoticed.
            cv2.imwrite(save_path, preview)

        sample = {
            'idx': idx,
            'img': gray_imgdata,
            'lbl': self.labels[idx],
            'imgfile': self.imgfiles[idx]
        }

        if self.transform:
            sample = self.transform(sample)

        return sample

    def display(self, idx, target_size=(256, 256)):
        """Show sample ``idx`` as a grayscale image.

        Args:
            idx: Index of the sample to show.
            target_size: Accepted for backward compatibility; not used.
        """
        imgdata = np.asarray(self[idx]['img'])

        # The item is normally already a 2-D grayscale image; only average
        # over the leading axis if a transform produced a stacked array.
        if imgdata.ndim > 2:
            imgdata = imgdata.mean(axis=0)

        imgdata = self._to_unit_range(imgdata)

        f, ax = plt.subplots(1, 1, figsize=(3, 3))
        ax.imshow(imgdata, cmap='gray')
        ax.axis('off')

        plt.show()


class ToTensor(object):
    """Transform that replaces a sample's image array with a torch Tensor."""

    def __call__(self, sample):
        """Return a new sample dict whose ``'img'`` entry is a Tensor.

        The ``'idx'``, ``'lbl'`` and ``'imgfile'`` entries are passed through
        unchanged; any other keys of ``sample`` are dropped.
        """
        converted = {}
        for key in ('idx', 'img', 'lbl', 'imgfile'):
            value = sample[key]
            if key == 'img':
                # torch.from_numpy shares memory with its input, so copy
                # first to keep the tensor independent of the caller's array.
                value = torch.from_numpy(value.copy())
            converted[key] = value
        return converted
