In [1]:
import torch
import segmentation_models_pytorch as smp
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image
import os
import glob
import re
import imageio.v2 as imageio  # Explicitly use version 2 API
from torch.utils.data import get_worker_info
from sklearn.preprocessing import MinMaxScaler
import torch.utils.data as data
from torch.amp import autocast, GradScaler
import rasterio
from rasterio.transform import from_origin
from rasterio.crs import CRS


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# inputs
transaction_ID = '200010' 
threshold = 0.5

In [3]:
# Set device (use GPU if available, otherwise fallback to CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


Using device: cpu


In [4]:
# Define U-Net model
ENCODER = 'efficientnet-b7'
#ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['solar_panel']
ACTIVATION = 'sigmoid'

model = smp.Unet(
    in_channels = 4, #4 for all bands
    encoder_name=ENCODER, 
    encoder_weights=None,  # No pretraining, since we are loading trained weights
    classes=len(CLASSES), 
    activation=ACTIVATION,
)

# Move model to device
model = model.to(device)


In [5]:
# Load trained weights
#Regular input images (256x256)
#weights_path = os.path.join(os.path.expanduser("~"), "satellite-ml-solarp-detection","models", "weights", "u-net_efficientnet-b7_v1", "unet-seed23_weights.pth")
#BiCubic Inter images x2 (512x512)
weights_path = os.path.join(os.path.expanduser("~"), "satellite-ml-solarp-detection","models", "weights", "u-net_efficientnet-b7_vBiC_intx2", "unet-seed23_wDA&Int_weights.pth")

if os.path.exists(weights_path):
    model.load_state_dict(torch.load(weights_path, map_location=device))
    print("Model weights loaded successfully.")
else:
    print("Error: Weights file not found!")
    
# Set model to evaluation mode
model.eval()


Model weights loaded successfully.


Unet(
  (encoder): EfficientNetEncoder(
    (_conv_stem): Conv2dStaticSamePadding(
      4, 64, kernel_size=(3, 3), stride=(2, 2), bias=False
      (static_padding): ZeroPad2d((0, 1, 0, 1))
    )
    (_bn0): BatchNorm2d(64, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
    (_blocks): ModuleList(
      (0): MBConvBlock(
        (_depthwise_conv): Conv2dStaticSamePadding(
          64, 64, kernel_size=(3, 3), stride=[1, 1], groups=64, bias=False
          (static_padding): ZeroPad2d((1, 1, 1, 1))
        )
        (_bn1): BatchNorm2d(64, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
        (_se_reduce): Conv2dStaticSamePadding(
          64, 16, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          16, 64, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_project_conv): Conv2dStaticSamePaddi

In [6]:
# Function to extract the numeric value from filenames
def numeric_sort_key(filepath):
    match = re.search(r'\d+', filepath)
    return int(match.group()) if match else 0
    

In [7]:
# Get all the input images and sort them numerically

#OLD, getting images from acquisition (without image Enhancement / SR)
# folder_data_input = sorted(
#     glob.glob(os.path.join(os.path.expanduser("~"), "satellite-ml-solarp-detection","acquisition", transaction_ID, "*tif")),
#     key=numeric_sort_key
# )
folder_data_input = sorted(
    glob.glob(os.path.join(os.path.expanduser("~"), "satellite-ml-solarp-detection","image_enhancement", transaction_ID, "*tif")),
    key=numeric_sort_key
)



input_image_paths = folder_data_input[:]

In [8]:
print(input_image_paths)
print(120*"-")
print(f'Number of files in the acquisition ID "{transaction_ID}": {len(input_image_paths)}')

['/home/sagemaker-user/satellite-ml-solarp-detection/image_enhancement/200010/200010_000_000.tif', '/home/sagemaker-user/satellite-ml-solarp-detection/image_enhancement/200010/200010_001_000.tif', '/home/sagemaker-user/satellite-ml-solarp-detection/image_enhancement/200010/200010_002_000.tif', '/home/sagemaker-user/satellite-ml-solarp-detection/image_enhancement/200010/200010_003_000.tif', '/home/sagemaker-user/satellite-ml-solarp-detection/image_enhancement/200010/200010_000_001.tif', '/home/sagemaker-user/satellite-ml-solarp-detection/image_enhancement/200010/200010_001_001.tif', '/home/sagemaker-user/satellite-ml-solarp-detection/image_enhancement/200010/200010_002_001.tif', '/home/sagemaker-user/satellite-ml-solarp-detection/image_enhancement/200010/200010_003_001.tif', '/home/sagemaker-user/satellite-ml-solarp-detection/image_enhancement/200010/200010_000_002.tif', '/home/sagemaker-user/satellite-ml-solarp-detection/image_enhancement/200010/200010_001_002.tif', '/home/sagemaker-us

In [9]:
def plot_histograms(image_paths, num_samples=5):

    """
    Plot histograms for each band (R, G, B, NIR) for a given sample of images.
    """
    red_values = []
    green_values = []
    blue_values = []
    nir_values = []
    
    # Iterate over a subset of images
    for img_path in image_paths[:num_samples]:  
        image = imageio.imread(img_path)
        
        # Separate the bands
        red_values.extend(image[:, :, 0].flatten())
        green_values.extend(image[:, :, 1].flatten())
        blue_values.extend(image[:, :, 2].flatten())
        nir_values.extend(image[:, :, 3].flatten())
    
    # Plot histograms
    plt.figure(figsize=(16, 8))
    
    plt.subplot(2, 2, 1)
    plt.hist(red_values, bins=50, color='red', alpha=0.7)
    plt.title("Red Band Histogram")
    plt.xlabel("Pixel Intensity")
    plt.ylabel("Frequency")
    
    plt.subplot(2, 2, 2)
    plt.hist(green_values, bins=50, color='green', alpha=0.7)
    plt.title("Green Band Histogram")
    plt.xlabel("Pixel Intensity")
    plt.ylabel("Frequency")
    
    plt.subplot(2, 2, 3)
    plt.hist(blue_values, bins=50, color='blue', alpha=0.7)
    plt.title("Blue Band Histogram")
    plt.xlabel("Pixel Intensity")
    plt.ylabel("Frequency")
    
    plt.subplot(2, 2, 4)
    plt.hist(nir_values, bins=50, color='purple', alpha=0.7)
    plt.title("NIR Band Histogram")
    plt.xlabel("Pixel Intensity")
    plt.ylabel("Frequency")
    
    plt.tight_layout()
    plt.show()


In [10]:
# Call the function with your training image paths
#plot_histograms(input_image_paths, num_samples=10)

In [9]:
class CustomDataset(data.Dataset):
    def __init__(self, image_paths, transform=None, band=None):

        self.image_paths = image_paths
        self.transform = transform
        self.band = band # Specify which band to use (0: R, 1: G, 2: B, 3: NIR, None: all bands)
        self.scaler = MinMaxScaler() 
                
    def __getitem__(self, index):
        try:
            image = imageio.imread(self.image_paths[index]).astype(np.float32)
    
            # Select a specific band if specified
            if self.band is not None:
                image = image[:, :, self.band] #Select only the specified band
                image = image[:, :, np.newaxis]
                
            # # Normalize the image
            # image_reshaped = image.reshape(-1, image.shape[-1])
            # image_scaled = self.scaler.fit_transform(image_reshaped)
            # image = image_scaled.reshape(image.shape)
    
            # Reshape for MinMaxScaler and apply normalization
            image_reshaped = image.reshape(-1, 4)
            image_scaled = self.scaler.fit_transform(image_reshaped)
            image = image_scaled.reshape(image.shape)
            
            # Apply the transformation to both image and mask if self.transform is set
            if self.transform:
                image = self.transform(image)  # Pass both to transform if synchronized
            return image, self.image_paths[index]    # Return image and corresponding filename
            
        except Exception as e:
            print(f"Error loading data at index {index}: {e}")
            raise e

    def __len__(self):
        return len(self.image_paths)

class ToTensor:
    def __call__(self, image):
        return torch.tensor(image, dtype=torch.float32).permute(2, 0, 1)  # Convert image to [C, H, W]
        

In [10]:
# Initialize the CustomDataset objects
use_all_bands = True
model_band = None # 0:red, 1:green, 2:blue, 3:NIR, None:all 4 bands

input_dataset = CustomDataset(input_image_paths, transform=ToTensor(), band=model_band)

image = input_dataset[0]  # Load the first item
print(f"Transformed image shape: {image[0].shape}, dtype: {image[0].dtype}")


Transformed image shape: torch.Size([4, 512, 512]), dtype: torch.float32


In [13]:
def visualize_images(dataset, num_samples=3):
    for i in range(num_samples):
        image = dataset[i]
        
        # Display input image and mask side-by-side
        plt.figure(figsize=(10, 5))
        
        plt.subplot(1, 2, 1)
        plt.imshow(image[0].permute(1, 2, 0))  # Convert CHW to HWC for display
        plt.title("Input Image")
        
        plt.show()

In [14]:
#visualize_images(input_dataset, num_samples=15)



In [11]:
# Replace DataLoader setup for training, validation, and testing
input_loader = torch.utils.data.DataLoader(input_dataset, batch_size=40, shuffle=False, num_workers=0)


In [12]:
# Create directory for saving predictions

output_dir = os.path.join(os.path.expanduser("~"), "satellite-ml-solarp-detection", "prediction", transaction_ID)
os.makedirs(output_dir, exist_ok=True)

In [13]:
# Prediction loop
with torch.no_grad():
    for batch_idx, (images, image_paths) in enumerate(input_loader):
        images = images.to(device)
        predictions = model(images)
        predictions = (predictions > threshold).int().cpu().numpy()
        
        # Save each prediction with CRS information
        for i in range(len(images)):
            filename = os.path.basename(image_paths[i])
            save_path = os.path.join(output_dir, f"pred_{filename}")
            pred_image = predictions[i].squeeze().astype(np.float32)
            
            # Read CRS and transform from source image
            with rasterio.open(image_paths[i]) as src:
                profile = src.profile.copy()
                profile.update(
                    dtype=rasterio.float32,
                    count=1,
                    compress='lzw'
                )
                
                with rasterio.open(save_path, 'w', **profile) as dst:
                    dst.write(pred_image, 1)
                    print(f"Saved with CRS: {save_path}")

print(f"All predictions saved in {output_dir}")

# # Prediction loop
# with torch.no_grad():
#     for batch_idx, (images, image_paths) in enumerate(input_loader):
#         images = images.to(device)
#         predictions = model(images)
#         predictions = (predictions > threshold).int().cpu().numpy()
        
#         # Save each prediction
#         for i in range(len(images)):
#             filename = os.path.basename(image_paths[i])
#             save_path = os.path.join(output_dir, f"pred_{filename}")
#             pred_image = predictions[i].squeeze().astype(np.float32)
#             imageio.imwrite(save_path, pred_image)
#             print(f"Saved: {save_path}")

# print(f"All predictions saved in {output_dir}")


Saved with CRS: /home/sagemaker-user/satellite-ml-solarp-detection/prediction/200010/pred_200010_000_000.tif
Saved with CRS: /home/sagemaker-user/satellite-ml-solarp-detection/prediction/200010/pred_200010_001_000.tif
Saved with CRS: /home/sagemaker-user/satellite-ml-solarp-detection/prediction/200010/pred_200010_002_000.tif
Saved with CRS: /home/sagemaker-user/satellite-ml-solarp-detection/prediction/200010/pred_200010_003_000.tif
Saved with CRS: /home/sagemaker-user/satellite-ml-solarp-detection/prediction/200010/pred_200010_000_001.tif
Saved with CRS: /home/sagemaker-user/satellite-ml-solarp-detection/prediction/200010/pred_200010_001_001.tif
Saved with CRS: /home/sagemaker-user/satellite-ml-solarp-detection/prediction/200010/pred_200010_002_001.tif
Saved with CRS: /home/sagemaker-user/satellite-ml-solarp-detection/prediction/200010/pred_200010_003_001.tif
Saved with CRS: /home/sagemaker-user/satellite-ml-solarp-detection/prediction/200010/pred_200010_000_002.tif
Saved with CRS: /ho