In [1]:
# Ignore warnings
import warnings
warnings.filterwarnings('ignore')

# Standard ML Libraries
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random

# Deep learning utilities
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader

# Torch data manipulation
import torchvision
from torchvision import transforms
import albumentations as A

# Metrics
import torchmetrics as tm

# GeoTIFF image
import rasterio as rio
from rasterio.plot import show

In [2]:
# Set device variable.
# `torch.cuda.is_available` has to be *called*: the bare function object is
# always truthy, which would make the 'cpu' fallback unreachable.
# 'cuda' is the device type string torch understands ('gpu' is rejected by
# `torch.device` / `.to(...)`).
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Random control: single seed value shared by everything that needs one
RANDOM_SEED = 42

# Main path: root directory of the dataset that is walked for image/mask files
MAIN_PATH = 'C://Users//vishal-sharma//Downloads//archive (2)//AMAZON//AMAZON'

In [3]:
def prepare_path_lists(main_path: str,
                      data_name: str):
    """Collect the image and mask ``.tif`` paths belonging to one data split.

    ``main_path`` is walked recursively. Every file whose name ends with
    ``.tif`` is classified by a substring test on its full path:

    * contains ``<data_name>/image``                      -> image list
    * otherwise contains ``<data_name>/label``, ``<data_name>/mask``
      or ``<data_name>/masks``                            -> mask list

    (the separator is whatever ``os.path.join`` produces on this platform).
    Files matching neither pattern are ignored.

    Both lists are returned sorted. ``os.walk`` yields directory entries in
    arbitrary order, so without sorting the two lists could come back in
    different relative orders and any positional pairing of images with
    masks would be unreliable.
    NOTE(review): positional pairing still assumes that an image and its mask
    sort into the same position (e.g. share a file name) - confirm for the
    dataset in use.

    Args:
        main_path: Root directory to search.
        data_name: Name of the split directory, e.g. ``'Training'``.

    Returns:
        Tuple ``(image_list, mask_list)`` of sorted full file paths.
    """
    # Path fragments that identify image / mask files of the requested split
    image_fragment = os.path.join(data_name, 'image')
    mask_fragments = [os.path.join(data_name, option)
                      for option in ('label', 'mask', 'masks')]

    image_list = []
    mask_list = []

    # Loop through all files to build both lists
    for root, _, files in os.walk(main_path):
        for file in files:
            # Only GeoTIFF files are of interest
            if not file.endswith('.tif'):
                continue
            full_path = os.path.join(root, file)
            if image_fragment in full_path:
                image_list.append(full_path)
            elif any(fragment in full_path for fragment in mask_fragments):
                mask_list.append(full_path)

    # Deterministic order, independent of filesystem enumeration order
    image_list.sort()
    mask_list.sort()

    return image_list, mask_list

In [4]:
# Gather (image paths, mask paths) for the Training, Test and Validation
# splits, in that order, by scanning MAIN_PATH once per split.
(
    (training_image_list, training_mask_list),
    (test_image_list, test_mask_list),
    (val_image_list, val_mask_list),
) = (
    prepare_path_lists(main_path=MAIN_PATH, data_name=split_name)
    for split_name in ('Training', 'Test', 'Validation')
)

In [None]:
# One DataFrame per split with an image path and a mask path on every row.
# Rows are formed by pairing the two lists positionally; `zip` stops at the
# shorter list, so surplus entries of the longer one are dropped.
train_df, test_df, val_df = (
    pd.DataFrame(list(zip(image_paths, mask_paths)),
                 columns=['image_path', 'mask_path'])
    for image_paths, mask_paths in (
        (training_image_list, training_mask_list),
        (test_image_list, test_mask_list),
        (val_image_list, val_mask_list),
    )
)

# Report how many image/mask pairs each split ended up with
print(f'Training data contains: {len(train_df)} files')
print(f'Test data contains: {len(test_df)} files')
print(f'Validation data contains: {len(val_df)} files')