In [6]:
import os
import numpy as np
import rasterio
from rasterio.plot import reshape_as_image

import torch
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms, datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchsummary import summary
from torch.utils.data import DataLoader, random_split

from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
from tqdm import tqdm

In [7]:
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=2):
        super(SimpleCNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(11, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # 224 -> 112

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # 112 -> 56

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # 56 -> 28

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # 28 -> 14
        )
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256 * 14 * 14, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )
    
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

In [8]:
class SmallCNN4(nn.Module):
    def __init__(self, input_channels=10, num_classes=2):
        super(SmallCNN4, self).__init__()

        # First CNN Layer
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # Downsample by 2

        # Depthwise Separable Convolution
        self.depthwise = nn.Conv2d(32, 32, kernel_size=3, padding=1, groups=32)
        self.pointwise = nn.Conv2d(32, 64, kernel_size=1)

        # Global Average Pooling to (2, 2)
        self.gap = nn.AdaptiveAvgPool2d((4, 4))

        # Fully Connected Layer
        self.fc = nn.Linear(64 * 4 * 4, num_classes)  # 64 channels with spatial size (2, 2)

    def forward(self, x):
        # First CNN Layer
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool1(x)

        # Depthwise Separable Convolution
        x = F.relu(self.pointwise(self.depthwise(x)))

        # Global Average Pooling
        x = self.gap(x).view(x.size(0), -1)

        # Fully Connected Layer
        x = self.fc(x)

        return x

In [9]:
# make prediction map
# Define custom Dataset for the TIFF image


# Define custom Dataset for the TIFF image
class TiffDataset(Dataset):
    def __init__(self, image, tile_size, stride):
        """
        Args:
            image (np.ndarray): Input multi-band image (H, W, C).
            tile_size (int): Size of the patches.
            stride (int): Stride for sliding window.
        """
        self.image = image
        self.tile_size = tile_size
        self.stride = stride
        self.tiles = []
        
        # Generate tiles with sliding window
        for y in range(0, image.shape[0] - tile_size + 1, stride):
            for x in range(0, image.shape[1] - tile_size + 1, stride):
                self.tiles.append((y, x))
    
    def __len__(self):
        return len(self.tiles)
    
    def __getitem__(self, idx):
        """
        Returns:
            torch.Tensor: Normalized patch of shape (10, tile_size, tile_size).
        """
        y, x = self.tiles[idx]
        patch = self.image[y:y + self.tile_size, x:x + self.tile_size, :]

        # Extract the bands
        blue, green, red, re, nir, swir = patch[..., 0], patch[..., 1], patch[..., 2], patch[..., 3], patch[..., 4], patch[..., 5]

        # Calculate indices
        # NDVI
        ndvi = (nir - red) / (nir + red + 1e-6)
        # SAVI (L=0.5)
        L = 0.5
        savi = ((nir - red) / (nir + red + L)) * (1 + L)
        # VARI
        vari = (green - red) / (green + red - blue + 1e-6)
        # EXG
        exg = 2 * green - red - blue
        # NDRE
        ndre = (nir - re) / (nir + re + 1e-6)

        # Stack bands into a 10-band patch
        # patch_10band = np.stack([blue, green, red, re, nir, swir, ndvi, savi, vari, exg, ndre], axis=-1)
        patch_10band = np.stack([blue, green, red, re, nir, ndvi, savi, vari, exg, ndre], axis=-1)

        # Normalize each band independently
        patch_10band = (patch_10band - patch_10band.mean(axis=(0, 1), keepdims=True)) / \
                       (patch_10band.std(axis=(0, 1), keepdims=True) + 1e-8)

        # Convert to PyTorch tensor and rearrange to (C, H, W)
        patch_10band = torch.from_numpy(patch_10band).float().permute(2, 0, 1)
        return patch_10band, y, x, savi[self.tile_size // 2, self.tile_size // 2]


In [11]:
# tiff_src = "/blue/changzhao/zhou.tang/botanical_composition/data/pad2_july.tif"
# tiff_src = "/blue/changzhao/zhou.tang/botanical_composition/data/All_Paddock_26_JUL_2024_ortho_bgrent.tiff"
# tiff_src = "/blue/changzhao/zhou.tang/botanical_composition/data/All_Paddock_22_SEP_2024_ortho_bgrent.tiff"
tiff_src = "/blue/changzhao/zhou.tang/botanical_composition/data/2023/Paddock_9_July_2023.tif"
# Parameters
tile_size = 50
stride = 1
batch_size = 7200  # Adjust based on GPU memory
weight_url = 'best_model_20241221_50_smallmodel4.pth'
out_tif_url = 'predictions_pad9_2023July_20241227_smallCNN4_50px_filtersavi0.6.tiff'
n_worker = 36
savi_thred = 0.6
predict_none = 3.0

# Load the TIFF image
with rasterio.open(tiff_src) as src:
    image = src.read()  # Read all bands
    transform = src.transform
    crs = src.crs

# Ensure shape: (H, W, C)
image = np.moveaxis(image, 0, -1)

# Padding the image
pad_h = (image.shape[0] + tile_size - 1) // tile_size * tile_size
pad_w = (image.shape[1] + tile_size - 1) // tile_size * tile_size
padded_image = np.pad(image, ((0, pad_h - image.shape[0]), (0, pad_w - image.shape[1]), (0, 0)), mode='reflect')

# Create the Dataset and DataLoader
dataset = TiffDataset(padded_image, tile_size, stride)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=n_worker, pin_memory=True)


# Load the trained model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = SimpleCNN(num_classes=2).to(device)
# model.load_state_dict(torch.load('best_model_20241116_0.94.pth'))
model = SmallCNN4(num_classes=2).to(device)
model.load_state_dict(torch.load(weight_url))

model.eval().to(device)

# Prepare the output prediction array
output_predictions = np.zeros((padded_image.shape[0], padded_image.shape[1]), dtype=np.float32)

# Compute the offset to place the prediction in the center of the patch
center_offset = tile_size // 2

# Perform prediction
with torch.no_grad():
    for tiles, ys, xs, savis in tqdm(dataloader, desc="Processing tiles"):
        tiles = tiles.to(device)  # Move tiles to GPU
        tiles = tiles.contiguous()
        # print(torch.cuda.memory_summary())
        preds = model(tiles)  # Adjust based on your model's output shape
        preds = preds.argmax(dim=1)
        # print(f"get {len(ys)} prediction")

        for i in range(len(ys)):
            y, x, savi = ys[i].item(), xs[i].item(), savis[i].numpy()
            # print(f"savi is {savi}")
            # print(f"x is {x}, y is {y}")
            center_y = y + center_offset
            center_x = x + center_offset
            if savi >= savi_thred:
                output_predictions[center_y, center_x] = preds[i].cpu().item()
            else:
                output_predictions[center_y, center_x] = predict_none
                    

# Crop the predictions back to the original image size
output_predictions = output_predictions[:image.shape[0], :image.shape[1]]

# Save the predictions as a new TIFF
with rasterio.open(
    out_tif_url,
    'w',
    driver='GTiff',
    height=output_predictions.shape[0],
    width=output_predictions.shape[1],
    count=1,
    dtype='float32',
    crs=crs,
    transform=transform,
) as dst:
    dst.write(output_predictions, 1)

print("Prediction TIFF has been saved successfully.")

Processing tiles: 100%|██████████| 7261/7261 [43:58<00:00,  2.75it/s]  


Prediction TIFF has been saved successfully.
