In [8]:
from osgeo import gdal
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
from torch import permute
from torch import nan_to_num
from torchvision.transforms import ToTensor
import rasterio
import numpy as np

In [5]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")


Using device: mps


In [51]:

import os
import random

def check_maskfolder(mask_folder):
    # List all GeoTIFF files in the mask folder
    tiff_files = [f for f in os.listdir(mask_folder) if f.lower().endswith((".tif", ".tiff"))]
    if not tiff_files:
        print("No TIFF files in folder:", mask_folder)
        return
    # Compute statistics for all TIFF files in the mask folder
    all_class_counts = {}
    total_pixels = 0
    shapes = []
    for tiff_file in tiff_files:
        tiff_path = os.path.join(mask_folder, tiff_file)
        with rasterio.open(tiff_path) as src:
            mask_data = src.read(1)
            shapes.append(mask_data.shape)
        unique, counts = np.unique(mask_data, return_counts=True)
        for cls, cnt in zip(unique, counts):
            all_class_counts[cls] = all_class_counts.get(cls, 0) + cnt
        total_pixels += mask_data.size

    print(f"Processed {len(tiff_files)} mask files.")
    if shapes:
        first_shape = shapes[0]
        all_same_shape = all(shape == first_shape for shape in shapes)
        print(f"All images have same shape: {all_same_shape}. Example shape: {first_shape}")
    print("Aggregated class statistics for the folder:")
    for cls, cnt in sorted(all_class_counts.items()):
        print(f"  Class {cls}: {cnt} pixels ({(cnt/total_pixels)*100:.2f}%)")
    print(f"Total pixels in all masks: {total_pixels}")


In [59]:
check_maskfolder("/Users/omaralshatti/Documents/KuwaitRaster/Dataset/tiles/imagenet_full_size/061417/train/class")

Processed 1548 mask files.
All images have same shape: False. Example shape: (512, 512)
Aggregated class statistics for the folder:
  Class 0: 7795452 pixels (1.93%)
  Class 10: 868515 pixels (0.21%)
  Class 20: 1695024 pixels (0.42%)
  Class 30: 1780776 pixels (0.44%)
  Class 40: 4433244 pixels (1.10%)
  Class 50: 15479553 pixels (3.83%)
  Class 60: 371333865 pixels (91.90%)
  Class 70: 2792 pixels (0.00%)
  Class 80: 565065 pixels (0.14%)
  Class 90: 112018 pixels (0.03%)
Total pixels in all masks: 404066304


In [None]:
# import the necessary packages
class SegmentationDataset(Dataset):
	def __init__(self, imagePaths, maskPaths, transforms):
		# store the image and mask filepaths, and augmentation
		# transforms
		self.imagePaths = imagePaths
		self.maskPaths = maskPaths
		self.transforms = transforms
	def __len__(self):
		# return the number of total samples contained in the dataset
		return len(self.imagePaths)
	def __getitem__(self, idx):
		# grab the image path from the current index
		imagePath = self.imagePaths[idx]
		image = rasterio.open(imagePath)
		image = image.read()
		image = ToTensor()(image)
		image = permute(image, (1,2,0))
		#print(image.shape)
		
		mask = rasterio.open(self.maskPaths[idx])
		mask = mask.read()
		mask = ToTensor()(mask)
		nan_to_num(mask, nan=0.0)
		#print(mask.shape)
		mask = permute(mask, (1,2,0))
        
		# check to see if we are applying any transformations
		if self.transforms is not None:
			# apply the transformations to both image and its mask
			image = self.transforms(image)
			mask = self.transforms(mask)
		# return a tuple of the image and its mask
		return (image, mask)