In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from sklearn.model_selection import train_test_split
import random
import numpy as np
from tqdm import tqdm
import argparse
import wandb
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
import rasterio
from model import *
from dataset import *
from utils import *
import sys
import rasterio
from rasterio.merge import merge
from rasterio.plot import show
import os
import glob
from sklearn.metrics import precision_score, recall_score, f1_score

In [31]:
run = wandb.init(project="Gully_Classification")
artifact = run.use_artifact('tousi-team/Gully_Classification/model_epoch_100:v0', type='model')
artifact_dir = artifact.download("./artifacts/models/")

[34m[1mwandb[0m: Currently logged in as: [33mstmmc[0m ([33mtousi-team[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Downloading large artifact model_epoch_100:v0, 114.00MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:16.2


In [2]:
class RGB_RasterTilesDataset_Geo(Dataset):
    def __init__(self, dem_dir, so_dir, rgb_dir, transform=None):
        self.dem_dir = dem_dir
        self.so_dir = so_dir
        self.rgb_dir = rgb_dir
        self.transform = transform
        # Assume all DEM, SO, and RGB files share the same tile identifiers
        self.tile_identifiers = [f.split('_')[-1].split('.')[0] for f in os.listdir(dem_dir) if 'dem_tile' in f]

    def __len__(self):
        return len(self.tile_identifiers)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        tile_id = self.tile_identifiers[idx]
        dem_file = os.path.join(self.dem_dir, f'dem_tile_{tile_id}.tif')
        so_file = os.path.join(self.so_dir, f'dem_tile_{tile_id}.tif')
        rgb_files = [os.path.join(self.rgb_dir, f'rgb{k}_tile_{tile_id}.tif') for k in range(6)]

        # Prepare sample dictionary
        sample = {}

        # Read DEM file and extract the transform
        with rasterio.open(dem_file) as src:
            dem_image = src.read(1)  # Read the first band
            dem_transform = src.transform
            sample['DEM'] = dem_image
            sample['DEM_transform'] = dem_transform

        # Read SO file and extract the transform
        with rasterio.open(so_file) as src:
            so_image = src.read(1)
            so_transform = src.transform
            sample['SO'] = so_image
            sample['SO_transform'] = so_transform

        # Read RGB files and extract their transforms
        rgb_images = []
        rgb_transforms = []
        for file in rgb_files:
            with rasterio.open(file) as src:
                rgb_images.append(src.read([1, 2, 3]))  # Read the RGB bands
                rgb_transforms.append(src.transform)
        sample['RGB'] = rgb_images
        sample['RGB_transforms'] = rgb_transforms

        if self.transform:
            sample = self.transform(sample)

        return sample

In [3]:
np.random.seed(0)
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [4]:
from rasterio.transform import from_origin

def save_as_geotiff(data, profile, output_path, single_value=False, image_size=128):
    """
    Save a numpy array or single value as a GeoTIFF file with a given raster profile.
    
    :param data: Numpy array to save (2D or 3D array where 3D includes bands) or a single value.
    :param profile: Dictionary containing raster metadata like crs, transform, etc.
    :param output_path: Path where the GeoTIFF file will be saved
    """
    # Check if data is a single value and create a 2D array with the same value
    if single_value:

        data = np.full((image_size, image_size), data, dtype=profile.get('dtype', 'float32'))

    # Update the profile to accommodate the data dimensions
    profile.update({
        'dtype': data.dtype,
        'height': data.shape[0],
        'width': data.shape[1],
        'count': 1 if data.ndim == 2 else data.shape[0],  # Assuming bands are the first dimension if 3D
        'driver': 'GTiff',
        'nodata': None  # Set this to the appropriate nodata value if required
    })
#     print(profile)

    # Save the data to a GeoTIFF file
    with rasterio.open(output_path, 'w', **profile) as dst:
        if data.ndim == 2:
#             print('2')
            dst.write(data, 1)  # Write data as the first band
        else:
            for i in range(data.shape[0]):
                dst.write(data[i], i + 1)  # Write each band data
                
def calculate_stats(folder_path):
    means = []
    stds = []
    
    # Loop through all files in the directory
    for filename in os.listdir(folder_path):
        if filename.lower().endswith(".tif"):
            file_path = os.path.join(folder_path, filename)
            # Open the TIFF file
            with rasterio.open(file_path) as src:
                # Read data, assuming it's a single band
                array = src.read(1)
                # Combine conditions for NoData and zero values
                if src.nodata is not None:
                    mask = (array != src.nodata) & (array != 0)
                else:
                    mask = (array != 0)
                
                # Apply mask
                valid_data = array[mask]
                
                # Calculate mean and std dev and append to lists if valid data exists
                if valid_data.size > 0:
                    means.append(np.mean(valid_data))
                    stds.append(np.std(valid_data))
                else:
                    print(f"Warning: No valid data in file {filename} after masking. Skipping statistics.")

    # Calculate overall statistics
    overall_mean = np.mean(means) if means else 0
    overall_std = np.mean(stds) if stds else 0

    return overall_mean, overall_std

In [66]:
field_name = 'Gotman'

In [5]:
# dem_dir = f'/home/macula/SMATousi/Gullies/SO_Paper/data/{field_name}/dem/'
# so_dir = f'/home/macula/SMATousi/Gullies/SO_Paper/data/{field_name}/dem/'
# rgb_dir = f'/home/macula/SMATousi/Gullies/SO_Paper/data/{field_name}/rgb/'
# pretrained_model_path = '/home/macula/SMATousi/cluster/docker-images/dem2so_more_data/pre_models/B3_rn50_moco_0099_ckpt.pth'

dem_dir = f'/home/macula/SMATousi/Gullies/ground_truth/organized_data/tiled_HUCs/HUC_070801030408/dem/'
so_dir = f'/home/macula/SMATousi/Gullies/ground_truth/organized_data/tiled_HUCs/HUC_070801030408/dem/'
rgb_dir = f'/home/macula/SMATousi/Gullies/ground_truth/organized_data/tiled_HUCs/HUC_070801030408/rgb/'
pretrained_model_path = '/home/macula/SMATousi/cluster/docker-images/dem2so_more_data/pre_models/B3_rn50_moco_0099_ckpt.pth'


mean, std = calculate_stats(dem_dir)
print("Overall Mean:", mean)
print("Overall Standard Deviation:", std)

class RGB_RasterTransform_Geo:
    """
    A custom transform class for raster data.
    """
    def __init__(self):
        pass

    def __call__(self, sample):
        dem, so, rgb = sample['DEM'], sample['SO'], sample['RGB']
        dem_meta, so_meta, rgb_meta = sample['DEM_transform'], sample['SO_transform'], sample['RGB_transforms']

        # Random horizontal flipping
        # if torch.rand(1) > 0.5:
        #     dem = TF.hflip(dem)
        #     so = TF.hflip(so)

        # # Random vertical flipping
        # if torch.rand(1) > 0.5:
        #     dem = TF.vflip(dem)
        #     so = TF.vflip(so)

        # Convert numpy arrays to tensors
        dem = TF.to_tensor(dem)
        so = TF.to_tensor(so)
        rgb_images = [TF.to_tensor(image) for image in rgb]
        float_rgb_images = [image.float() for image in rgb_images]
        # rgb_images = rgb_images.float()

        dem = TF.normalize(dem, mean, std)

        so = so.long()

        return {'DEM': dem, 'SO': so.squeeze(), 'RGB': float_rgb_images,
                'DEM_transform' : dem_meta, 'SO_transform' : so_meta, 'RGB_transforms' : rgb_meta}

transform = RGB_RasterTransform_Geo()
    
dataset = RGB_RasterTilesDataset_Geo(dem_dir=dem_dir, so_dir=so_dir, rgb_dir=rgb_dir, transform=transform)
# dataset = RGB_RasterTilesDataset_Geo(dem_dir=dem_dir, so_dir=so_dir, rgb_dir=rgb_dir)

batch_size = 1
learning_rate = 0.0001
epochs = 1
number_of_workers = 0
image_size = 128
val_percent = 0.0

n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train, val = random_split(dataset, [n_train, n_val])
train_loader = DataLoader(train, batch_size=batch_size, shuffle=False, num_workers=number_of_workers, pin_memory=True)
val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=number_of_workers, pin_memory=True, drop_last=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)

model = Gully_Classifier(input_size=6*2048, hidden_size=512, output_size=1).to(device)

# model = RGB_DEM_to_SO(resnet_output_size=(8, 8), 
#                             fusion_output_size=(128, 128), 
#                             model_choice = "Unet_1", 
#                             resnet_saved_model_path=pretrained_model_path,
#                             dropout_rate=0.5).to(device)

from torch.optim import Adam
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)

state_dict = torch.load('./artifacts/models/model_epoch_100.pth')
# state_dict_new = torch.load('./artifacts/new_loss/model_epoch_600.pth')

new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
model.load_state_dict(new_state_dict)

print(len(train_loader))

model.eval()

for i, batch in enumerate(tqdm(train_loader)):
# for i, batch in enumerate(train_loader):
    
    dem = batch['DEM'].to(device)
    so = batch['SO'].to(device)
    rgbs = [batch['RGB'][k].to(device) for k in range(6)]

    permute_rgbs = [torch.permute(image,(0,2,1,3)) for image in rgbs]

    
    output = model(permute_rgbs)
    
#     print(output)
    
#     break
    
#     val_outputs = model(dem, permute_rgbs)

#     pred = F.softmax(val_outputs, dim=1)              
#     pred = torch.argmax(pred, dim=1).squeeze(1)
    
    profile = {
    'transform': batch['SO_transform'],  # Example values: (west, north, xsize, ysize)
    'crs': 'EPSG:4326',  # Standard WGS84 CRS
    }

    # Path to save the GeoTIFF file
    root_path = f'/home/macula/SMATousi/Gullies/ground_truth/organized_data/tiled_HUCs/HUC_070801030408/classification_results'
    os.makedirs(root_path, exist_ok=True)
    output_file_path = f'{root_path}/{i}.tif'
    

    # Call the function to save the file
#     print(output.float().cpu().detach().numpy().squeeze())
    save_as_geotiff(output.float().cpu().detach().numpy().squeeze(), profile, output_file_path, single_value=True)
    
#     break


Overall Mean: 226.09827
Overall Standard Deviation: 1.8296707
cuda


  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


3696


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 3696/3696 [15:47<00:00,  3.90it/s]


In [59]:
profile

{'transform': Affine(tensor([1.0000e-05], dtype=torch.float64), tensor([0.], dtype=torch.float64), tensor([-91.0065], dtype=torch.float64),
        tensor([0.], dtype=torch.float64), tensor([-1.0000e-05], dtype=torch.float64), tensor([39.1192], dtype=torch.float64)),
 'crs': 'EPSG:4326'}

In [6]:
def merge_tiffs(input_folder, output_filepath):
    """
    Merge multiple GeoTIFF files into a single larger TIFF file.

    :param input_folder: Folder containing all TIFF files to merge.
    :param output_filepath: Path to save the merged TIFF file.
    """
    # Search for TIFF files in the folder
    search_criteria = "*.tif"
    query = os.path.join(input_folder, search_criteria)
    tif_files = glob.glob(query)

    # List to hold open datasets
    src_files_to_mosaic = []

    # Open and append each TIFF file to the list
    for filepath in tif_files:
        src = rasterio.open(filepath)
        src_files_to_mosaic.append(src)

    # Merge function from rasterio
    mosaic, out_trans = merge(src_files_to_mosaic)

    # Copy the metadata
    out_meta = src_files_to_mosaic[0].meta.copy()

    # Update the metadata to reflect the number of layers
    out_meta.update({
        "driver": "GTiff",
        "height": mosaic.shape[1],
        "width": mosaic.shape[2],
        "transform": out_trans,
        "crs": src_files_to_mosaic[0].crs
    })

    # Write the mosaic raster to the new file
    with rasterio.open(output_filepath, "w", **out_meta) as dest:
        for i in range(1, mosaic.shape[0]+1):
            dest.write(mosaic[i-1], i)

    # Close all rasterio opened files
    for src in src_files_to_mosaic:
        src.close()

    print("Merge completed successfully. Output saved at:", output_filepath)
    
input_folder = f'/home/macula/SMATousi/Gullies/ground_truth/organized_data/tiled_HUCs/HUC_070801030408/classification_results'  # Update this path to your folder path
output_filepath = f'/home/macula/SMATousi/Gullies/ground_truth/organized_data/tiled_HUCs/HUC_070801030408/merged_classification_results.tif'  # Set your output file path

merge_tiffs(input_folder, output_filepath)

Merge completed successfully. Output saved at: /home/macula/SMATousi/Gullies/ground_truth/organized_data/tiled_HUCs/HUC_070801030408/merged_classification_results.tif


  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "




AttributeError: 'list' object has no attribute 'lower'