In [1]:
import os 
import random 
random.seed(97)
import scipy.io as io 
import numpy as np 
import cv2 
import matplotlib.pyplot as plt

In [2]:
import torch
from torch.utils.data import Dataset 

In [56]:
class HSI_Dataset_Train(Dataset):
    """Paired RGB / hyperspectral training dataset with random square cropping.

    Expected directory layout::

        <train_data_dir>/hsi/<name>.mat             (MATLAB file with a "cube" variable)
        <train_data_dir>/rgb/<name>_RealWorld.jpg

    Each item is ``(rgb, hsi, name)``. ``rgb`` and ``hsi`` are float32 tensors
    in channel-first (C, H, W) layout, cropped to the same random window.

    Args:
        train_data_dir: root directory containing the ``rgb/`` and ``hsi/`` folders
            (with or without a trailing path separator).
        input_image_shape: indexable; only ``input_image_shape[0]`` is used, as the
            side length of the square crop.
        data_transforms: stored on the instance but NOT applied by ``__getitem__``.
    """

    def __init__(self, train_data_dir, input_image_shape, data_transforms=None):
        self.train_data_dir = train_data_dir
        # NOTE(review): kept for interface compatibility; never applied below.
        self.data_transforms = data_transforms
        # os.path.join tolerates a missing trailing separator on train_data_dir;
        # the final "" keeps the trailing separator so "+ filename" still works.
        self.rgb_path = os.path.join(train_data_dir, "rgb", "")
        self.hsi_path = os.path.join(train_data_dir, "hsi", "")
        self.input_image_shape = input_image_shape
        self.input_image_size = self.input_image_shape[0]

        # Only .mat files are samples (skips e.g. .DS_Store). Sorting makes the
        # index -> sample mapping deterministic; os.listdir order is arbitrary.
        self.img_root_names = sorted(
            file_name[: -len(".mat")]
            for file_name in os.listdir(self.hsi_path)
            if file_name.endswith(".mat")
        )
        self.length = len(self.img_root_names)

        # File names of the spectral cubes and their paired JPEG images.
        self.rgb_image_files = [name + "_RealWorld.jpg" for name in self.img_root_names]
        self.hsi_image_files = [name + ".mat" for name in self.img_root_names]

    @staticmethod
    def _random_crop(rgb_img, hsi_img, size):
        """Crop the same random ``size`` x ``size`` window from both arrays.

        Arrays no larger than ``size`` on both sides are returned unchanged. A side
        shorter than ``size`` is kept whole instead of cropped. Offsets come from
        the module-level ``random`` generator.
        """
        # NOTE(review): assumes hsi_img is spatially aligned with rgb_img (same H, W).
        h, w = rgb_img.shape[:2]
        if h > size or w > size:
            # max(..., 0): random.randint(0, negative) raises ValueError, which
            # used to crash whenever only one side exceeded the crop size.
            top = random.randint(0, max(h - size, 0))
            left = random.randint(0, max(w - size, 0))
            rgb_img = rgb_img[top:top + size, left:left + size, :]
            hsi_img = hsi_img[top:top + size, left:left + size, :]
        return rgb_img, hsi_img

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(rgb, hsi, img_root_name)``.

        ``rgb`` is scaled by 1/255 and ``hsi`` is the raw "cube" values, both as
        float32 (C, H, W) tensors.

        Raises:
            FileNotFoundError: if OpenCV cannot read the RGB image.
        """
        img_root_name = self.img_root_names[index]
        rgb_img_path = self.rgb_path + self.rgb_image_files[index]
        hsi_img_path = self.hsi_path + self.hsi_image_files[index]

        hsi_img = io.loadmat(hsi_img_path)["cube"]
        # -1 == cv2.IMREAD_UNCHANGED. NOTE(review): OpenCV returns colour channels
        # in BGR order; confirm the model expects BGR rather than RGB.
        rgb_img = cv2.imread(rgb_img_path, -1)
        if rgb_img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"could not read RGB image: {rgb_img_path}")

        # Crop first so only the patch is converted; the conversions are
        # elementwise, so the result is identical to converting then cropping.
        rgb_img, hsi_img = self._random_crop(rgb_img, hsi_img, self.input_image_size)

        rgb_img = rgb_img.astype(np.float64) / 255.0  # 8-bit intensities -> [0, 1]

        # HWC numpy -> CHW float32 tensors.
        rgb_img = torch.from_numpy(rgb_img.astype(np.float32).transpose(2, 0, 1)).contiguous()
        hsi_img = torch.from_numpy(hsi_img.astype(np.float32).transpose(2, 0, 1)).contiguous()

        return rgb_img, hsi_img, img_root_name

    def __len__(self):
        """Number of ``.mat`` samples found in ``hsi/``."""
        return self.length

In [None]:
# NOTE(review): this cell re-defines HSI_Dataset_Train from an earlier cell; the
# later definition silently shadows the earlier one. Keep a single copy.
class HSI_Dataset_Train(Dataset):
    """Paired RGB / hyperspectral training dataset with random square cropping.

    Expected directory layout::

        <train_data_dir>/hsi/<name>.mat             (MATLAB file with a "cube" variable)
        <train_data_dir>/rgb/<name>_RealWorld.jpg

    Each item is ``(rgb, hsi, name)``. ``rgb`` and ``hsi`` are float32 tensors
    in channel-first (C, H, W) layout, cropped to the same random window.

    Args:
        train_data_dir: root directory containing the ``rgb/`` and ``hsi/`` folders
            (with or without a trailing path separator).
        input_image_shape: indexable; only ``input_image_shape[0]`` is used, as the
            side length of the square crop.
        data_transforms: stored on the instance but NOT applied by ``__getitem__``.
    """

    def __init__(self, train_data_dir, input_image_shape, data_transforms=None):
        self.train_data_dir = train_data_dir
        # NOTE(review): kept for interface compatibility; never applied below.
        self.data_transforms = data_transforms
        # os.path.join tolerates a missing trailing separator on train_data_dir;
        # the final "" keeps the trailing separator so "+ filename" still works.
        self.rgb_path = os.path.join(train_data_dir, "rgb", "")
        self.hsi_path = os.path.join(train_data_dir, "hsi", "")
        self.input_image_shape = input_image_shape
        self.input_image_size = self.input_image_shape[0]

        # Only .mat files are samples (skips e.g. .DS_Store). Sorting makes the
        # index -> sample mapping deterministic; os.listdir order is arbitrary.
        self.img_root_names = sorted(
            file_name[: -len(".mat")]
            for file_name in os.listdir(self.hsi_path)
            if file_name.endswith(".mat")
        )
        self.length = len(self.img_root_names)

        # File names of the spectral cubes and their paired JPEG images.
        self.rgb_image_files = [name + "_RealWorld.jpg" for name in self.img_root_names]
        self.hsi_image_files = [name + ".mat" for name in self.img_root_names]

    @staticmethod
    def _random_crop(rgb_img, hsi_img, size):
        """Crop the same random ``size`` x ``size`` window from both arrays.

        Arrays no larger than ``size`` on both sides are returned unchanged. A side
        shorter than ``size`` is kept whole instead of cropped. Offsets come from
        the module-level ``random`` generator.
        """
        # NOTE(review): assumes hsi_img is spatially aligned with rgb_img (same H, W).
        h, w = rgb_img.shape[:2]
        if h > size or w > size:
            # max(..., 0): random.randint(0, negative) raises ValueError, which
            # used to crash whenever only one side exceeded the crop size.
            top = random.randint(0, max(h - size, 0))
            left = random.randint(0, max(w - size, 0))
            rgb_img = rgb_img[top:top + size, left:left + size, :]
            hsi_img = hsi_img[top:top + size, left:left + size, :]
        return rgb_img, hsi_img

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(rgb, hsi, img_root_name)``.

        ``rgb`` is scaled by 1/255 and ``hsi`` is the raw "cube" values, both as
        float32 (C, H, W) tensors.

        Raises:
            FileNotFoundError: if OpenCV cannot read the RGB image.
        """
        img_root_name = self.img_root_names[index]
        rgb_img_path = self.rgb_path + self.rgb_image_files[index]
        hsi_img_path = self.hsi_path + self.hsi_image_files[index]

        hsi_img = io.loadmat(hsi_img_path)["cube"]
        # -1 == cv2.IMREAD_UNCHANGED. NOTE(review): OpenCV returns colour channels
        # in BGR order; confirm the model expects BGR rather than RGB.
        rgb_img = cv2.imread(rgb_img_path, -1)
        if rgb_img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"could not read RGB image: {rgb_img_path}")

        # Crop first so only the patch is converted; the conversions are
        # elementwise, so the result is identical to converting then cropping.
        rgb_img, hsi_img = self._random_crop(rgb_img, hsi_img, self.input_image_size)

        rgb_img = rgb_img.astype(np.float64) / 255.0  # 8-bit intensities -> [0, 1]

        # HWC numpy -> CHW float32 tensors.
        rgb_img = torch.from_numpy(rgb_img.astype(np.float32).transpose(2, 0, 1)).contiguous()
        hsi_img = torch.from_numpy(hsi_img.astype(np.float32).transpose(2, 0, 1)).contiguous()

        return rgb_img, hsi_img, img_root_name

    def __len__(self):
        """Number of ``.mat`` samples found in ``hsi/``."""
        return self.length