In [6]:
import numpy as np
import torch
import torch.nn as nn
import rasterio

In [7]:
# Report the PyTorch build and whether Apple's MPS backend is usable here.
for label, value in (
    ("PyTorch version", torch.__version__),
    ("MPS (Metal Performance Shaders) available", torch.backends.mps.is_available()),
    ("Built with MPS", torch.backends.mps.is_built()),
):
    print(f"{label}: {value}")

PyTorch version: 2.7.0.dev20250112
MPS (Metal Performance Shaders) available: True
Built with MPS: True


In [8]:
# Fully connected regressor used for per-pixel prediction
class NeuralNet(nn.Module):
    """Three-layer MLP: Linear -> ReLU -> Linear -> ReLU -> Linear.

    The attribute names ``fc1``, ``fc2`` and ``fc3`` determine the state_dict
    keys, so they must not change or previously saved checkpoints will not load.

    Args:
        input_size: Number of input features per sample.
        hidden_size: Width of both hidden layers.
        output_size: Number of outputs per sample.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()  # parameter-free, reused after both hidden layers
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a ``(..., input_size)`` tensor to ``(..., output_size)``."""
        hidden = self.relu(self.fc1(x))
        hidden = self.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [9]:
# Model (no loss or optimizer is needed: this notebook only runs inference).
# The sizes must match the architecture of the checkpoint loaded afterwards.
input_size = 5  # per-pixel features: normalized raster bands + NDVI, as built in predict_raster
hidden_size = 64  # width of both hidden layers
output_size = 3  # regression targets (N, P, K)

model = NeuralNet(
    input_size=input_size,
    hidden_size=hidden_size,
    output_size=output_size,
)

In [13]:
# weights_only=True limits unpickling to tensors and plain containers. It is the
# default since PyTorch 2.6 and is spelled out so older versions don't execute
# arbitrary pickle code. map_location='cpu' lets a checkpoint saved from an
# MPS/CUDA device load on any machine; the model here lives on the CPU.
state_dict = torch.load(
    './models/best_model_macro_npk.pth',
    map_location='cpu',
    weights_only=True,
)

In [14]:
# strict=True (the default) makes any missing/unexpected key raise instead of being skipped
model.load_state_dict(state_dict, strict=True)

<All keys matched successfully>

In [15]:
def predict_raster(raster, model, output_dir='./predicted', batch_size=1_000_000):
    """Run ``model`` on every pixel of a multiband raster and save each output as a GeoTIFF.

    Each band is min-max scaled to [0, 1] with this raster's own per-band
    minimum/maximum. An NDVI band computed from the raw band values is appended,
    and the per-pixel feature vectors are fed through ``model``. Output column
    ``i`` is written to ``{output_dir}/band_{i}.tif`` with the source CRS and
    transform.

    Args:
        raster: Path (or anything ``rasterio.open`` accepts) of the input raster.
        model: ``torch.nn.Module`` mapping float32 ``(n_pixels, n_features)`` to
            ``(n_pixels, n_targets)``. It is switched to eval mode, and inference
            runs on the device of its parameters.
        output_dir: Directory for the per-target GeoTIFFs; created if missing.
        batch_size: Pixels per forward pass, which bounds peak memory. ``None`` runs
            all pixels in one pass.

    Returns:
        numpy.ndarray of shape ``(height, width, n_targets)`` holding the predictions.

    Raises:
        ValueError: if the raster has fewer than 4 bands, or ``batch_size`` is not
            positive.
    """
    import os

    with rasterio.open(raster) as src:
        # Cast to float64 once. This avoids integer overflow in ``bands - min`` for
        # signed dtypes and gives NDVI the float arithmetic it needs.
        bands = src.read().astype(np.float64)  # (num_bands, height, width)
        crs, transform = src.crs, src.transform
    # NOTE(review): src.nodata is not masked, so nodata pixels are normalized and
    # predicted like valid ones. Confirm this is acceptable.

    num_bands, height, width = bands.shape
    if num_bands < 4:
        raise ValueError(
            f"expected at least 4 bands (red = band 3, NIR = band 4), got {num_bands}"
        )
    if batch_size is not None and batch_size <= 0:
        raise ValueError(f"batch_size must be positive or None, got {batch_size}")

    # Per-band min-max scaling to [0, 1] using this raster's own statistics.
    # NOTE(review): confirm this matches the preprocessing used during training.
    min_val = bands.min(axis=(1, 2), keepdims=True)
    max_val = bands.max(axis=(1, 2), keepdims=True)
    value_range = max_val - min_val
    value_range[value_range == 0] = 1.0  # constant band -> zeros instead of 0/0 = NaN
    normalized_bands = (bands - min_val) / value_range

    # NDVI from raw (un-normalized) values. This assumes 0-based index 2 is red
    # and index 3 is NIR; confirm against the sensor's band order.
    red_band = bands[2]
    nir_band = bands[3]
    ndvi = (nir_band - red_band) / (nir_band + red_band + 1e-8)

    # Feature order per pixel: normalized bands..., NDVI.
    dataset = np.concatenate((normalized_bands, ndvi[np.newaxis]), axis=0)
    # (features, H, W) -> (H*W, features); pixels in row-major (row, col) order.
    pixels = np.ascontiguousarray(dataset.reshape(dataset.shape[0], -1).T, dtype=np.float32)
    # Release the float64 intermediates before inference.
    del bands, normalized_bands, dataset, red_band, nir_band, ndvi

    inputs = torch.from_numpy(pixels)
    print(inputs.shape)

    # Run on whatever device the model's weights live on (CPU if it has none).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    step = batch_size or max(len(inputs), 1)

    model.eval()
    with torch.no_grad():  # inference only: don't build an autograd graph over every pixel
        chunks = [
            model(inputs[start:start + step].to(device)).cpu()
            for start in range(0, len(inputs), step)
        ]
    predictions = torch.cat(chunks)
    print(predictions.shape)

    n_targets = predictions.shape[1]
    result = predictions.numpy().reshape(height, width, n_targets)

    # Save each predicted target as its own single-band GeoTIFF.
    os.makedirs(output_dir, exist_ok=True)
    for i in range(n_targets):
        single_band = result[:, :, i]
        output_path = f'{output_dir}/band_{i}.tif'
        with rasterio.open(
            output_path,
            'w',
            driver='GTiff',
            height=height,
            width=width,
            count=1,
            dtype=single_band.dtype,
            crs=crs,
            transform=transform,
        ) as dst:
            dst.write(single_band, 1)
        # Report every band, not just the last one. The message counts from 1,
        # while file names count from 0 (kept for compatibility).
        print(f"Band {i + 1} saved to {output_path}")

    return result

In [17]:
# Predict N/P/K for the AAPA multispectral image; GeoTIFFs go to ./predicted/
raster_path = "../Datasets/CITRA_MS_AAPA/CITRA_MS_AAPA.tif"
result = predict_raster(raster_path, model=model)

torch.Size([40556174, 5])
torch.Size([40556174, 3])
Band 3 saved to ./predicted/band_2.tif
