<a href="https://colab.research.google.com/github/DannyinDelft/INR/blob/main/February_10th_2025_1544_pointsnotimages.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# -*- coding: utf-8 -*-
"""
Point-based Training and Testing for Custom SIREN Model with Georeferencing and Distinct Outputs
"""
!pip install rasterio geopandas torch torchmetrics

import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import geopandas as gpd
import rasterio
from rasterio.transform import from_bounds
from google.colab import drive
import re
from torchmetrics.image import PeakSignalNoiseRatio

# Mount Google Drive so the /content/drive/... paths used by main() resolve.
drive.mount('/content/drive', force_remount=True)

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Running on device: {device}")

# 🔹 Fourier Encoding for High-Frequency Details
class FourierEncoding(nn.Module):
    """Random Fourier feature mapping ``x -> [sin(2*pi*x@B), cos(2*pi*x@B)]``.

    ``B`` is a fixed ``(input_dim, mapping_size)`` matrix of Gaussian samples
    multiplied by ``scale``. It is registered as a buffer so that it is stored
    in ``state_dict`` and moves with the module on ``.to(device)``. As a plain
    attribute, a reloaded model would draw a fresh random ``B``, and its
    trained weights would no longer match the features they were fit to.

    Args:
        input_dim: size of the last dimension of ``x``.
        mapping_size: number of random frequencies; the output has
            ``2 * mapping_size`` features (sines then cosines).
        scale: multiplier applied to the standard-normal entries of ``B``.
    """

    def __init__(self, input_dim, mapping_size=128, scale=40.0):
        super().__init__()
        self.register_buffer("B", torch.randn((input_dim, mapping_size)) * scale)

    def forward(self, x):
        # B normally already lives on x's device (it moves with the module);
        # the .to() only matters if inputs were moved without the module.
        x_proj = 2 * np.pi * x @ self.B.to(x.device)
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)

# 🔹 SIREN Layer
class SirenLayer(nn.Module):
    """Affine map followed by a scaled sine: ``sin(omega * (W x + b))``.

    Only the weight matrix is re-initialised. It is drawn uniformly from
    ``[-1/in, 1/in]`` for the first layer and from
    ``[-sqrt(6/in)/omega, sqrt(6/in)/omega]`` for later layers. The bias keeps
    PyTorch's default ``nn.Linear`` initialisation.
    """

    def __init__(self, in_channels, out_channels, omega=30.0, is_first=False):
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels)
        self.omega = omega
        self.init_weights(is_first)

    def init_weights(self, is_first):
        """Redraw ``self.linear.weight`` in place from the uniform range above."""
        fan_in = self.linear.in_features
        if is_first:
            limit = 1 / fan_in
        else:
            limit = np.sqrt(6 / fan_in) / self.omega
        with torch.no_grad():
            nn.init.uniform_(self.linear.weight, -limit, limit)

    def forward(self, x):
        """Apply the affine map, then the ``omega``-scaled sine."""
        pre_activation = self.linear(x)
        return torch.sin(self.omega * pre_activation)

# 🔹 Total Variation Loss for Smoother Images
def total_variation_loss(img):
    """Mean absolute difference between neighbouring pixels on a square grid.

    The first dimension of ``img`` is taken as ``side * side`` pixels in
    row-major order. The tensor is reshaped with ``view(-1, 1, side, side)``.
    The mean absolute differences along the row axis and along the column
    axis are then summed. This is intended for ``(N, 1)`` inputs.

    Args:
        img: tensor whose first dimension holds the pixels.

    Returns:
        A scalar tensor. Inputs with fewer than two dimensions (0-d scalars
        and flat 1-D vectors such as squeezed point predictions) have no grid
        structure and yield ``0.0``.

    Raises:
        ValueError: if ``img.shape[0]`` is not a perfect square.
    """
    # ndim <= 1 (not == 1): a 0-d tensor has no shape[0] and previously crashed.
    if img.ndim <= 1:
        return torch.tensor(0.0, device=img.device)

    n_pixels = img.shape[0]
    side = int(round(n_pixels ** 0.5))
    if side * side != n_pixels:
        raise ValueError(
            f"total_variation_loss expects a square number of pixels along dim 0, got {n_pixels}"
        )

    img = img.view(-1, 1, side, side)
    # Differences between vertically (dim 2) and horizontally (dim 3) adjacent pixels.
    d_rows = torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]).mean()
    d_cols = torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]).mean()
    return d_rows + d_cols

# 🔹 Custom Loss Function (MSE + TV Loss)
class CustomLoss(nn.Module):
    """Mean-squared error plus an ``alpha``-weighted total-variation penalty.

    NOTE(review): ``total_variation_loss`` returns 0 for 1-D inputs. For
    flattened point predictions, therefore, only the MSE term contributes.

    Args:
        alpha: weight applied to ``total_variation_loss(outputs)``.
    """

    def __init__(self, alpha=0.001):
        super().__init__()
        self.mse = nn.MSELoss()
        self.alpha = alpha

    def forward(self, outputs, targets):
        """Return ``MSE(outputs, targets) + alpha * TV(outputs)``."""
        data_term = self.mse(outputs, targets)
        smoothness_term = total_variation_loss(outputs)
        return data_term + self.alpha * smoothness_term

# 🔹 SIREN Model (Fourier features + site embedding -> sine MLP)
class CustomSirenModel(nn.Module):
    """Coordinate network: Fourier features + learned site embedding -> SIREN MLP.

    Each input point is encoded with ``FourierEncoding`` (``2 * fourier_dim``
    features). The encoding is concatenated with the embedding of its site id
    and passed through ``1 + hidden_layers`` ``SirenLayer`` blocks, followed by
    a final linear layer.

    Args:
        input_dim: coordinate dimensionality (last dim of ``x``).
        fourier_dim: number of random Fourier frequencies.
        site_embedding_dim: size of each site embedding vector.
        hidden_features: width of the sine layers.
        hidden_layers: number of sine layers after the first one.
        out_features: size of the output per point.
        num_sites: number of rows in the site embedding table; site ids must
            lie in ``[0, num_sites)``.
        fourier_scale: ``scale`` passed to ``FourierEncoding``.
    """

    def __init__(self, input_dim=2, fourier_dim=256, site_embedding_dim=128, hidden_features=512,
                 hidden_layers=8, out_features=1, num_sites=20, fourier_scale=40.0):
        super(CustomSirenModel, self).__init__()
        self.fourier = FourierEncoding(input_dim, mapping_size=fourier_dim, scale=fourier_scale)
        self.site_embeddings = nn.Embedding(num_sites, site_embedding_dim)

        self.layers = nn.ModuleList([SirenLayer(fourier_dim * 2 + site_embedding_dim, hidden_features, is_first=True)])
        self.layers.extend([SirenLayer(hidden_features, hidden_features) for _ in range(hidden_layers)])
        self.final_layer = nn.Linear(hidden_features, out_features)

    def forward(self, x, site_id):
        """Predict ``(N, out_features)`` values for ``N`` points.

        Args:
            x: ``(N, input_dim)`` coordinates.
            site_id: long tensor holding either one site id (any shape with a
                single element), which is broadcast to all points, or ``N``
                per-point ids.

        Raises:
            ValueError: if ``site_id`` holds neither 1 nor ``N`` ids.
        """
        fourier_encoded = self.fourier(x)
        n_points = x.size(0)

        # Flatten ids so the embedding is always (K, D), whatever shape came in.
        site_embed = self.site_embeddings(site_id.reshape(-1))
        if site_embed.size(0) == 1:
            site_embed = site_embed.expand(n_points, -1)
        elif site_embed.size(0) != n_points:
            raise ValueError(
                f"site_id must hold 1 or {n_points} ids, got {site_embed.size(0)}"
            )

        h = torch.cat([fourier_encoded, site_embed], dim=-1)
        for layer in self.layers:
            h = layer(h)
        return self.final_layer(h)

# 🔹 Sample Ground Truth Pixel Values
def sample_ground_truth(raster_path, coords, raster_crs, raster_bounds):
    """Sample band 1 of a raster at point locations.

    Args:
        raster_path: path opened with ``rasterio.open``.
        coords: GeoDataFrame/GeoSeries of point geometries. The points are
            reprojected to the raster's CRS in two cases: when their own CRS
            is set and differs from the raster's, or when ``raster_crs``
            differs from the file's CRS.
        raster_crs: CRS the caller associates with the raster.
        raster_bounds: indexed as ``(xmin, ymin, xmax, ymax)`` in the raster
            CRS. Points outside these bounds are skipped.

    Returns:
        ``(values, valid_coords)``:
        - ``values``: a float32 tensor on ``device`` of the sampled pixels.
        - ``valid_coords``: the matching list of ``(x, y)`` tuples in the
          raster CRS, in the same order.

        Points are dropped when their pixel is NaN, equals the raster's
        nodata value, or lies outside the pixel grid.
    """
    with rasterio.open(raster_path) as src:
        # Compare the points' CRS with the raster's so mismatched inputs are reprojected.
        points_crs = getattr(coords, "crs", None)
        if (points_crs is not None and points_crs != src.crs) or raster_crs != src.crs:
            coords = coords.to_crs(src.crs)

        # Read the band once rather than once per point.
        band = src.read(1)
        n_rows, n_cols = band.shape
        nodata = src.nodata
        xmin, ymin, xmax, ymax = raster_bounds[0], raster_bounds[1], raster_bounds[2], raster_bounds[3]

        values = []
        valid_coords = []
        for point in coords.geometry:
            x, y = point.x, point.y
            if not (xmin <= x <= xmax and ymin <= y <= ymax):
                continue
            row, col = src.index(x, y)
            # Points on the right/bottom edge map one past the last pixel, and
            # negative indices would silently wrap in numpy; skip both.
            if not (0 <= row < n_rows and 0 <= col < n_cols):
                continue
            val = band[row, col]
            if np.isnan(val) or (nodata is not None and val == nodata):
                continue
            values.append(val)
            valid_coords.append((x, y))
    return torch.tensor(values, dtype=torch.float32).to(device), valid_coords

# 🔹 Training Function
def _prepare_site_batches(train_loader, raster_loader):
    """Sample ground truth once per site and return ready-to-use tensors.

    Returns a list of ``(site_id, coords_tensor, targets)`` tuples in sorted
    ``site_id`` order. Only sites that have a raster entry and at least one
    valid sampled point are included.
    """
    batches = []
    for site_id in sorted(train_loader.keys()):
        coords, _file_name = train_loader[site_id]
        if site_id not in raster_loader:
            continue

        raster_path, raster_crs, raster_bounds, _ = raster_loader[site_id]
        targets, valid_coords = sample_ground_truth(raster_path, coords, raster_crs, raster_bounds)
        if not valid_coords:
            continue

        coords_tensor = torch.tensor(valid_coords, dtype=torch.float32).to(device)
        batches.append((site_id, coords_tensor, targets))
    return batches


def train(model, train_loader, raster_loader, epochs=500, *, lr=5e-4, weight_decay=1e-6, tv_alpha=0.001):
    """Fit ``model`` site by site with AdamW, printing loss and mean PSNR per epoch.

    Args:
        model: a ``CustomSirenModel``-like module called as ``model(coords, site_id)``.
        train_loader: mapping ``site_id -> (points, file_name)``.
        raster_loader: mapping ``site_id -> (raster_path, raster_crs, raster_bounds, _)``.
        epochs: number of passes over all sites.
        lr, weight_decay: AdamW hyper-parameters.
        tv_alpha: weight of the total-variation term in ``CustomLoss``.

    Ground truth is sampled once before training; it does not change between
    epochs. The printed "Loss" is the sum over sites. "PSNR" is averaged over
    the sites actually trained on, and is ``nan`` when there are none.

    NOTE(review): coordinates reach the model in raw raster-CRS units
    (``main`` uses EPSG:25832). With ``FourierEncoding``'s default scale of 40,
    the sine arguments become very large in float32. Normalising coordinates
    to ``raster_bounds`` is probably wanted. That change is left out here
    because the inference path (not in this cell) would have to match it.
    """
    model.train()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn = CustomLoss(alpha=tv_alpha)
    # The metric keeps its state in tensors; keep it on the same device as the outputs.
    psnr_metric = PeakSignalNoiseRatio().to(device)

    site_batches = _prepare_site_batches(train_loader, raster_loader)
    n_sites = len(site_batches)

    for epoch in range(epochs):
        total_loss = 0.0
        total_psnr = 0.0
        for site_id, coords_tensor, targets in site_batches:
            site_tensor = torch.tensor([site_id], dtype=torch.long).to(device)
            optimizer.zero_grad()
            # squeeze(-1) keeps a 1-D (N,) vector even when N == 1, matching `targets`.
            outputs = model(coords_tensor, site_tensor).squeeze(-1)
            loss = loss_fn(outputs, targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            total_psnr += psnr_metric(outputs.detach(), targets).item()

        avg_psnr = total_psnr / n_sites if n_sites else float("nan")
        print(f"🔥 Epoch [{epoch+1}/{epochs}], Loss: {total_loss:.6f}, PSNR: {avg_psnr:.2f}")

# 🔹 Main Function
def main():
    """Build the model and train it on the Rastertopoint training folders.

    NOTE(review): ``point_data_loader`` and ``raster_data_loader`` are not
    defined in this cell; presumably they come from elsewhere in the notebook.
    ``test_input_path`` and ``output_dir`` are not used yet, so no testing or
    output writing happens here.
    """
    model = CustomSirenModel().to(device)

    train_input_path = "/content/drive/My Drive/Thesis_imagery/Ecostress/1000mpatches/Rastertopoint/Training_input"
    train_gt_path = "/content/drive/My Drive/Thesis_imagery/Ecostress/1000mpatches/Rastertopoint/Training_ground_truth"
    test_input_path = "/content/drive/My Drive/Thesis_imagery/Ecostress/1000mpatches/Rastertopoint/Testing_input"
    output_dir = "/content/drive/My Drive/Thesis_imagery/Ecostress/1000mpatches/Rastertopoint/Testing_output"

    point_sites = point_data_loader(train_input_path, "EPSG:25832")
    raster_sites = raster_data_loader(train_gt_path)
    train(model, point_sites, raster_sites, epochs=500)

# Entry point: run main() when this file/cell is executed directly.
if __name__ == "__main__":
    main()


Mounted at /content/drive
Running on device: cpu
🔥 Epoch [1/500], Loss: 35.631991, PSNR: -1.26
🔥 Epoch [2/500], Loss: 11.631541, PSNR: 2.25
🔥 Epoch [3/500], Loss: 6.158193, PSNR: 5.04
🔥 Epoch [4/500], Loss: 3.326389, PSNR: 7.73
🔥 Epoch [5/500], Loss: 2.401338, PSNR: 9.07
🔥 Epoch [6/500], Loss: 2.095396, PSNR: 9.67
🔥 Epoch [7/500], Loss: 1.787045, PSNR: 10.37
🔥 Epoch [8/500], Loss: 1.624843, PSNR: 10.79
🔥 Epoch [9/500], Loss: 1.475496, PSNR: 11.18
🔥 Epoch [10/500], Loss: 1.316284, PSNR: 11.71
🔥 Epoch [11/500], Loss: 1.232681, PSNR: 11.98
🔥 Epoch [12/500], Loss: 1.183088, PSNR: 12.16
🔥 Epoch [13/500], Loss: 1.063981, PSNR: 12.62
🔥 Epoch [14/500], Loss: 0.999276, PSNR: 12.89
🔥 Epoch [15/500], Loss: 0.905716, PSNR: 13.32
🔥 Epoch [16/500], Loss: 0.854789, PSNR: 13.55
🔥 Epoch [17/500], Loss: 0.790212, PSNR: 13.91
🔥 Epoch [18/500], Loss: 0.733145, PSNR: 14.22
🔥 Epoch [19/500], Loss: 0.699699, PSNR: 14.43
🔥 Epoch [20/500], Loss: 0.613211, PSNR: 15.00
🔥 Epoch [21/500], Loss: 0.582366, PSNR: 15.