In [3]:
import os
from tqdm import tqdm
import pickle

import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt
import torch 
from torchvision.io import read_image 
from torch.utils.data import Dataset, random_split

# Getting Train and Test Indices

In [None]:
# making dataset class to get train test splits
class NaiveEuroSATDataset(Dataset):
    """Flat-index view over a dataset stored as one sub-folder per class.

    The sub-folders of ``data_path`` are taken in sorted-name order and their
    images are numbered consecutively, so index 0 is the first listed file of
    the first class and the last index is the last listed file of the last
    class.

    Attributes:
        data_path: root folder passed to the constructor.
        sorted_class_names: class folder names in sorted order; the position
            of a name in this list is its integer label.
        num_classes: number of class folders.
        num_img_per_class: int tensor with the number of entries in each
            class folder, in ``sorted_class_names`` order.
        cumsum_img_per_class: running total of ``num_img_per_class``; entry
            ``i`` is the first flat index that belongs to class ``i + 1``.
    """

    def __init__(self, data_path):
        """Scan ``data_path`` once and record how many images each class has."""
        self.data_path = data_path
        self.sorted_class_names = sorted(os.listdir(self.data_path))
        # reuse the listing above instead of hitting the filesystem again
        self.num_classes = len(self.sorted_class_names)
        self.num_img_per_class = torch.zeros(self.num_classes, dtype=torch.int)
        for i, land_class in enumerate(self.sorted_class_names):
            self.num_img_per_class[i] = len(os.listdir(os.path.join(self.data_path, land_class)))
        # __getitem__ needs the running totals to map a flat index to a class
        self.cumsum_img_per_class = torch.cumsum(self.num_img_per_class, dim=0)

    def __len__(self):
        """Return the total number of images as a plain Python int."""
        return int(torch.sum(self.num_img_per_class))

    def __getitem__(self, idx):
        """Return ``{'image': ..., 'land_use': ...}`` for flat index ``idx``.

        ``'image'`` is whatever ``read_image`` returns for the file and
        ``'land_use'`` is the class position as a 0-dim integer tensor.
        Negative indices count from the end.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        idx = int(idx)
        total = len(self)
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError(f'index {idx} is out of range for dataset of size {total}')

        # every class whose running total is <= idx ends before idx, so the
        # count of such classes is the position of the class that contains idx
        class_idx = torch.sum(self.cumsum_img_per_class <= idx)
        class_pos = int(class_idx)
        # convert the flat index into an offset inside the class folder
        if class_pos != 0:
            idx -= int(self.cumsum_img_per_class[class_pos - 1])

        # getting image tensor and class name
        class_name = self.sorted_class_names[class_pos]
        class_path = os.path.join(self.data_path, class_name)
        # files are taken in os.listdir order (not sorted) on purpose: the
        # statistics functions in this notebook number images the same way
        img_name = os.listdir(class_path)[idx]
        img_path = os.path.join(class_path, img_name)
        img = read_image(img_path)

        sample = {'image': img, 'land_use': class_idx}
        return sample

In [None]:
# Root folder of the EuroSAT RGB dataset, given relative to the notebook's
# working directory. The dataset class and the statistics functions in this
# notebook list it as one sub-folder per class.
data_path = '../EuroSAT_RGB'

In [None]:
# Split the dataset 80/20 into train+val and test with a fixed seed, so the
# same partition is produced on every run. Only the resulting index lists are
# kept; the statistics below are computed from the non-test images.
naive_eurosat = NaiveEuroSATDataset(data_path)
generator = torch.Generator()
generator.manual_seed(0)
train_val_set, test_set = random_split(
    dataset=naive_eurosat,
    lengths=[0.8, 0.2],
    generator=generator,
)
# positions (into naive_eurosat) that landed on each side of the split
train_val_idx, test_idx = train_val_set.indices, test_set.indices

# Calculating Mean and SD

In [None]:
# Switch for the run-or-load cell further down:
#   True  -> load previously pickled statistics instead of recomputing them
#   False -> recompute mean/sd by reading every image, then pickle the result
already_preprocessed = True

In [None]:
def calculate_channel_mean():
    """Compute the per-channel mean over all images that are not in ``test_idx``.

    Images are numbered the way ``NaiveEuroSATDataset`` numbers them: class
    folders in sorted-name order, files in ``os.listdir`` order within each
    folder. An image whose number appears in the module-level ``test_idx`` is
    skipped, so test data does not leak into the statistics.

    Reads the module-level ``data_path`` and ``test_idx``.

    Returns:
        tuple ``(channel_mean, num_img)``: ``channel_mean`` is a float32
        tensor of length 3 holding the average of the per-image channel
        means, and ``num_img`` is the number of images that contributed to
        it (skipped test images are not counted).
    """
    # set gives O(1) membership tests; test_idx itself is a plain list
    held_out = set(test_idx)
    # accumulate in float64 so the precision of the per-image means is kept
    means_sum = torch.zeros(3, dtype=torch.float64)
    dataset_idx = 0  # flat dataset index of the image currently visited
    num_img = 0      # images actually added to means_sum
    # sorted() keeps the numbering consistent with the dataset class
    for land_class in sorted(os.listdir(data_path)):
        print(land_class)
        land_class_path = os.path.join(data_path, land_class)
        # iterate over each image for each class
        for file in tqdm(os.listdir(land_class_path)):
            is_test_image = dataset_idx in held_out
            dataset_idx += 1
            if is_test_image:
                continue
            img_path = os.path.join(land_class_path, file)
            img = read_image(img_path).to(torch.float64)
            means_sum += torch.mean(img, dim=(1, 2))
            num_img += 1
        print('==========================================')
    # divide by the number of images summed, not by the number visited
    channel_mean = (means_sum / num_img).to(torch.float32)
    return channel_mean, num_img

In [None]:
# per-channel mean squared deviation of one image from the given channel means
def sample_var(x, channel_mean):
    """Return, for each leading-axis slice of ``x``, the mean of its squared
    deviations from the matching entry of ``channel_mean``.

    ``channel_mean`` gets two singleton axes inserted after its first axis so
    that it broadcasts against ``x``; the average is taken over axes 1 and 2.
    """
    broadcast_mean = channel_mean.unsqueeze(1).unsqueeze(1)
    squared_dev = torch.square(x - broadcast_mean)
    return squared_dev.mean(dim=(1, 2))

def calculate_channel_sd(channel_mean, num_img):
    """Compute the per-channel standard deviation over all images not in ``test_idx``.

    Images are numbered the way ``NaiveEuroSATDataset`` numbers them: class
    folders in sorted-name order, files in ``os.listdir`` order within each
    folder. An image whose number appears in the module-level ``test_idx`` is
    skipped.

    Reads the module-level ``data_path`` and ``test_idx`` and calls
    ``sample_var`` for the per-image squared deviations.

    Args:
        channel_mean: length-3 tensor of channel means to measure deviations
            from.
        num_img: kept for backward compatibility and not used; the divisor is
            the number of images this function actually reads.

    Returns:
        float32 tensor of length 3: square root of the average per-image
        mean squared deviation, per channel.
    """
    # set gives O(1) membership tests; test_idx itself is a plain list
    held_out = set(test_idx)
    # accumulate in float64 so the precision of sample_var's result is kept
    vars_sum = torch.zeros(3, dtype=torch.float64)
    dataset_idx = 0  # flat dataset index of the image currently visited
    num_used = 0     # images actually added to vars_sum
    # sorted() keeps the numbering consistent with the dataset class
    for land_class in sorted(os.listdir(data_path)):
        print(land_class)
        land_class_path = os.path.join(data_path, land_class)
        # iterate over each image for each class
        for file in tqdm(os.listdir(land_class_path)):
            is_test_image = dataset_idx in held_out
            dataset_idx += 1
            if is_test_image:
                continue
            img_path = os.path.join(land_class_path, file)
            img = read_image(img_path).to(torch.float64)
            vars_sum += sample_var(img, channel_mean)
            num_used += 1
        print('==========================================')
    # divide by the number of images summed, then take the square root
    channel_sd = torch.sqrt(vars_sum / num_used).to(torch.float32)
    return channel_sd

In [None]:
# run or load preprocessing statistics
# one cache location for both branches, so a file written by the compute
# branch is the file the load branch looks for
stats_path = './preprocessing_stats.pkl'
if not already_preprocessed:
    channel_mean, num_img = calculate_channel_mean()
    channel_sd = calculate_channel_sd(channel_mean, num_img)

    preprocessing_stats = {
        'mean': channel_mean,
        'sd': channel_sd,
        'num_img': num_img
    }
    with open(stats_path, 'wb') as f:
        pickle.dump(preprocessing_stats, f)
else:
    # CAUTION: pickle.load can execute arbitrary code; only load a statistics
    # file that this notebook (or another trusted source) produced
    with open(stats_path, 'rb') as f:
        preprocessing_stats = pickle.load(f)
    channel_mean = preprocessing_stats['mean']
    channel_sd = preprocessing_stats['sd']
    num_img = preprocessing_stats['num_img']

FileNotFoundError: [Errno 2] No such file or directory: './preprocessing_stats.p'