In [1]:
from dataset import SFMNNDataset
from network import SFMNNEncoder
from simulate import SFMNNSimulation
from loss import SFMNNLoss

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
import json

from dotenv import load_dotenv



In [2]:
load_dotenv()
data_folder = os.getenv("DATA_FOLDER")

with open(data_folder + "simulation_sim_0_amb_0.json") as f:
    data = json.load(f)
with open(data_folder + "simulation_lookuptable.json") as f:
    lookup = json.load(f)

# Load the dataset
dataset = SFMNNDataset(data_folder + 'simulation_lookuptable.json', 'output/', patch_size=5)

# Load the dataloader
dataloader = DataLoader(dataset, batch_size=3, shuffle=True, drop_last=False)


Loading JSON files:   0%|          | 0/1350 [00:00<?, ?it/s]

Loading JSON files: 100%|██████████| 1350/1350 [00:15<00:00, 86.50it/s]


Total number of elements loaded: 1350
Total number of patches loaded: 60


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [4]:
# Number of spectral bands
n_spectral_bands = len(dataset.get_wl())
n_spectral_bands

3620

In [5]:
H, W, C = dataset[0].shape
H, W, C

(5, 5, 3623)

In [6]:
# Initialize the network
encoder = SFMNNEncoder(input_channels=C, latent_dim=C).to(device)

print(encoder)

SFMNNEncoder(
  (input_norm): BatchNorm1d(3623, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc_layers): Sequential(
    (0): Linear(in_features=3623, out_features=4096, bias=True)
    (1): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.1, inplace=False)
    (4): Linear(in_features=4096, out_features=8192, bias=True)
    (5): BatchNorm1d(8192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): Dropout(p=0.1, inplace=False)
    (8): Linear(in_features=8192, out_features=16384, bias=True)
    (9): BatchNorm1d(16384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): Linear(in_features=16384, out_features=16384, bias=True)
    (12): BatchNorm1d(16384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU()
  )
  (latent_proj): Linear(in_features=16384, out_features=32607, bias=True)
)


In [7]:
optimizer = optim.Adam(encoder.parameters(), lr=0.001, weight_decay=0.0001)
criterion = SFMNNLoss(torch.tensor(n_spectral_bands))

In [8]:
# test the network
for i, data in enumerate(dataloader):
    data = data.to(device)
    print('input:', data.shape)
    out = encoder(data)
    print('output:', out.shape)

input: torch.Size([3, 5, 5, 3623])
output: torch.Size([3, 5, 5, 9, 3623])
input: torch.Size([3, 5, 5, 3623])
output: torch.Size([3, 5, 5, 9, 3623])
input: torch.Size([3, 5, 5, 3623])
output: torch.Size([3, 5, 5, 9, 3623])
input: torch.Size([3, 5, 5, 3623])
output: torch.Size([3, 5, 5, 9, 3623])
input: torch.Size([3, 5, 5, 3623])
output: torch.Size([3, 5, 5, 9, 3623])
input: torch.Size([3, 5, 5, 3623])
output: torch.Size([3, 5, 5, 9, 3623])
input: torch.Size([3, 5, 5, 3623])
output: torch.Size([3, 5, 5, 9, 3623])
input: torch.Size([3, 5, 5, 3623])
output: torch.Size([3, 5, 5, 9, 3623])
input: torch.Size([3, 5, 5, 3623])
output: torch.Size([3, 5, 5, 9, 3623])
input: torch.Size([3, 5, 5, 3623])
output: torch.Size([3, 5, 5, 9, 3623])
input: torch.Size([3, 5, 5, 3623])
output: torch.Size([3, 5, 5, 9, 3623])
input: torch.Size([3, 5, 5, 3623])
output: torch.Size([3, 5, 5, 9, 3623])
input: torch.Size([3, 5, 5, 3623])
output: torch.Size([3, 5, 5, 9, 3623])
input: torch.Size([3, 5, 5, 3623])
out

In [9]:
torch.tensor(dataset.get_wl()).shape

torch.Size([3620])

In [15]:
simulation = SFMNNSimulation(torch.tensor(dataset.get_wl()).to(device))

encoder.train()
for data in tqdm(dataloader):
    # MEASUREMENT STEP
    data = data.to(device)

    # ENCODING STEP
    output = encoder(data)
    print("Encoder output shape:", output.shape)

    # SIMULATION STEP
    R           = output[..., 0, :]  # Reflectance parameters
    F           = output[..., 1, :]  # Fluorescence parameters
    t_1         = output[..., 2, :] 
    t_2         = output[..., 3, :]
    t_3         = output[..., 4, :]
    t_4         = output[..., 5, :]
    t_5         = output[..., 6, :]
    t_6         = output[..., 7, :]
    #d_lambda    = output[..., 8, :]  # Expected shape: [B, N] or [B, H, W]
    d_lambda    = output[..., 8, 0]  # Expected shape: [B, N] or [B, H, W]
    #d_lambda_avg = d_lambda.mean(dim=(1,2,3)).view(d_lambda.shape[0], 1)
    print("R shape:", R.shape)
    print("F shape:", F.shape)
    print("t_1 shape:", t_1.shape)
    print("t_2 shape:", t_2.shape)
    print("t_3 shape:", t_3.shape)
    print("t_4 shape:", t_4.shape)
    print("t_5 shape:", t_5.shape)
    print("t_6 shape:", t_6.shape)
    print("d_lambda shape:", d_lambda.shape)
    #print("d_lambda_avg shape:", d_lambda_avg.shape)
    

    
    sim_output = simulation(
        t_1, t_2, t_3, t_4, t_5, t_6,
        R,
        F,
        d_lambda,
        torch.tensor(10, device=device),
        torch.tensor(1, device=device),
        torch.tensor(0, device=device)
    )

    print("Simulation output type:", type(simulation))
    
    # LOSS COMPUTATION
    loss = criterion(sim_output, data)
    print("Loss:", loss.item())

    # BACKPROPAGATION
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Encoder output shape: torch.Size([3, 5, 5, 9, 3623])
R shape: torch.Size([3, 5, 5, 3623])
F shape: torch.Size([3, 5, 5, 3623])
t_1 shape: torch.Size([3, 5, 5, 3623])
t_2 shape: torch.Size([3, 5, 5, 3623])
t_3 shape: torch.Size([3, 5, 5, 3623])
t_4 shape: torch.Size([3, 5, 5, 3623])
t_5 shape: torch.Size([3, 5, 5, 3623])
t_6 shape: torch.Size([3, 5, 5, 3623])
d_lambda shape: torch.Size([3, 5, 5])
t1: torch.Size([3, 5, 5, 3623])
t2: torch.Size([3, 5, 5, 3623])
t3: torch.Size([3, 5, 5, 3623])
t4: torch.Size([3, 5, 5, 3623])
t5: torch.Size([3, 5, 5, 3623])
t6: torch.Size([3, 5, 5, 3623])
R: torch.Size([3, 5, 5, 3623])
F: torch.Size([3, 5, 5, 3623])
L_hr shape: torch.Size([3, 5, 5, 3623])





AttributeError: 'Tensor' object has no attribute 'conv1d'

In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F_func  # renamed to avoid shadowing by tensor 'F'
from tqdm import tqdm

# --- FourStreamSimulator ---
class FourStreamSimulator(nn.Module):
    def __init__(self, spectral_window=(750, 770), high_res=0.0055):
        super().__init__()
        self.register_buffer('lambda_hr', torch.arange(*spectral_window, high_res))
        self.mu_f = 737.0  # Fluorescence peak wavelength

    def forward(self, t1, t2, t3, t4, t5, t6, R, F, E_s, cos_theta_s):
        print(f"t1: {t1.shape}")
        print(f"t2: {t2.shape}")
        print(f"t3: {t3.shape}")
        print(f"t4: {t4.shape}")
        print(f"t5: {t5.shape}")
        print(f"t6: {t6.shape}")
        print(f"R: {R.shape}")
        print(f"F: {F.shape}")

        # Compute product terms (following Table 3 from the paper)
        t7 = t3 * t4
        t8 = t3 * t6
        t9 = t4 * t5
        t10 = t4 * t2
        t11 = t3 * t2

        LTOA = t1 * t2 + (t1 * t8 * R + t9 * R + t10 * R + t11 * R + t6 * F + t7 * F) / (1 - t3 * R)
        return LTOA

# --- HyPlantSensorSimulator ---
class HyPlantSensorSimulator(nn.Module):
    def __init__(self, sensor_wavelengths, high_res=0.0055):
        super().__init__()
        # sensor_wavelengths: expected shape [n_sensor]
        self.register_buffer('sensor_wavelengths', sensor_wavelengths)
        self.high_res = high_res
        self.register_buffer('wl_range', torch.tensor([sensor_wavelengths.min(), sensor_wavelengths.max()]))

    def forward(self, L_hr, delta_lambda, delta_sigma):
        """
        L_hr: [B, C_hr, H, W] where C_hr is the high-resolution spectral dimension.
        delta_lambda: [B, 1] (one wavelength shift per image)
        delta_sigma: scalar tensor
        """
        B, C_hr, H, W = L_hr.shape

        # 1. Create Gaussian SRF kernel
        sigma = (0.27 + delta_sigma) * 2.3548  # convert FWHM to sigma
        kernel_size = int(6 * sigma / self.high_res)
        x = torch.linspace(-3 * sigma, 3 * sigma, kernel_size, device=L_hr.device)
        kernel = torch.exp(-0.5 * (x / sigma) ** 2)
        kernel /= kernel.sum()

        # 2. Spectral convolution along the spectral dimension.
        # Reshape L_hr so that the spectral dimension becomes the width of a 1-pixel–high image.
        # L_hr: [B, C_hr, H, W] -> permute to [B, H, W, C_hr] then reshape to [B*H*W, 1, C_hr]
        L_hr_reshaped = L_hr.permute(0, 2, 3, 1).reshape(B * H * W, 1, C_hr)
        L_blur = F_func.conv1d(L_hr_reshaped, kernel.view(1, 1, -1), padding=kernel_size // 2)
        # Compute the new spectral dimension from conv1d output:
        new_C_hr = L_blur.shape[-1]
        # Reshape back to [B, H, W, new_C_hr]
        L_blur = L_blur.view(B, H, W, new_C_hr)

        # 3. Wavelength shift and interpolation.
        # sensor_wavelengths: [n_sensor]. After adding delta_lambda ([B,1]) and normalizing,
        # we obtain coordinates for each image.
        n_sensor = self.sensor_wavelengths.shape[0]
        shifted_wl = self.sensor_wavelengths.unsqueeze(0) + delta_lambda  # [B, n_sensor]
        normalized_wl = 2 * (shifted_wl - self.wl_range[0]) / (self.wl_range[1] - self.wl_range[0]) - 1  # [B, n_sensor]

        # Expand these coordinates to every spatial location.
        # normalized_wl_expanded: [B, H, W, n_sensor]
        normalized_wl_expanded = normalized_wl.unsqueeze(1).unsqueeze(1).expand(B, H, W, n_sensor)
        # Flatten spatial dimensions: [B*H*W, n_sensor]
        normalized_wl_flat = normalized_wl_expanded.reshape(B * H * W, n_sensor)

        # grid_sample expects a grid with last dimension 2 (for (y, x)).
        # Here, we create dummy y coordinates (all zeros) because our "image" height is 1.
        zero_y = torch.zeros_like(normalized_wl_flat)  # [B*H*W, n_sensor]
        # Form grid: [B*H*W, n_sensor, 2] with (y, x) coordinates.
        grid = torch.stack((zero_y, normalized_wl_flat), dim=-1)
        # Reshape grid to [B*H*W, 1, n_sensor, 2] (grid_sample expects shape [N, H_out, W_out, 2])
        grid = grid.view(B * H * W, 1, n_sensor, 2)

        # Reshape L_blur to [B*H*W, 1, 1, new_C_hr] so that the "image" has height 1 and width = new_C_hr.
        L_blur_reshaped = L_blur.view(B * H * W, 1, 1, new_C_hr)
        # Sample using grid_sample.
        L_sampled = F_func.grid_sample(L_blur_reshaped, grid, mode='bilinear', align_corners=False)
        # L_sampled: [B*H*W, 1, 1, n_sensor]. Reshape to [B, H, W, n_sensor]
        L_sampled = L_sampled.view(B, H, W, n_sensor)
        # Permute to [B, n_sensor, H, W] (the desired output shape).
        L_hyp = L_sampled.permute(0, 3, 1, 2)
        return L_hyp

# --- SFMNNSimulation ---
class SFMNNSimulation(nn.Module):
    def __init__(self, sensor_wavelengths):
        super().__init__()
        self.four_stream = FourStreamSimulator()
        self.sensor_sim = HyPlantSensorSimulator(sensor_wavelengths)

    def forward(self, t1, t2, t3, t4, t5, t6, R, F, delta_lambda, delta_sigma, E_s, cos_theta_s):
        # 1. Compute high-resolution radiance.
        L_hr = self.four_stream(t1, t2, t3, t4, t5, t6, R, F, E_s, cos_theta_s)
        print("L_hr shape:", L_hr.shape)
        # 2. Apply sensor simulation.
        L_hyp = self.sensor_sim(L_hr, delta_lambda, delta_sigma)
        return L_hyp

# --- Training Loop ---
# Assumes that the following objects are defined and on the correct device:
# - dataset: an instance of your dataset (with method get_wl())
# - dataloader: a PyTorch DataLoader for your dataset
# - encoder: your encoder network
# - criterion: your loss function
# - optimizer: the optimizer for your encoder parameters
# - device: e.g., torch.device('cuda') or torch.device('cpu')

# Initialize simulation with sensor wavelengths from the dataset.
sensor_wl = torch.tensor(dataset.get_wl()).to(device)
simulation = SFMNNSimulation(sensor_wl)

encoder.train()
for data in tqdm(dataloader):
    # MEASUREMENT STEP
    data = data.to(device)

    # ENCODING STEP
    output = encoder(data)
    print("Encoder output shape:", output.shape)

    # Unpack encoder output.
    # Assumes output has shape [B, 5, 5, N_params] with N_params >= 9.
    R_param       = output[..., 0, :]  # Reflectance parameters
    F_param       = output[..., 1, :]  # Fluorescence parameters
    t_1           = output[..., 2, :] 
    t_2           = output[..., 3, :]
    t_3           = output[..., 4, :]
    t_4           = output[..., 5, :]
    t_5           = output[..., 6, :]
    t_6           = output[..., 7, :]
    d_lambda      = output[..., 8, :]  # Expected shape: [B, 5, 5, 326070] (for example)
    print("Original d_lambda shape:", d_lambda.shape)

    # Average d_lambda over dimensions 1, 2, and 3 to obtain shape [B, 1]
    d_lambda_avg = d_lambda.mean(dim=(1, 2, 3)).view(d_lambda.shape[0], 1)
    print("Averaged d_lambda shape:", d_lambda_avg.shape)

    # SIMULATION STEP
    sim_output = simulation(
        t_1, t_2, t_3, t_4, t_5, t_6,
        R_param,
        F_param,
        d_lambda_avg,  # shape [B, 1]
        torch.tensor(10, device=device),
        torch.tensor(1, device=device),
        torch.tensor(0, device=device)
    )
    print("Simulation output type:", type(simulation))

    # LOSS COMPUTATION
    loss = criterion(sim_output, data)
    print("Loss:", loss.item())

    # BACKPROPAGATION
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Encoder output shape: torch.Size([3, 5, 5, 9, 3623])
Original d_lambda shape: torch.Size([3, 5, 5, 3623])
Averaged d_lambda shape: torch.Size([3, 1])
t1: torch.Size([3, 5, 5, 3623])
t2: torch.Size([3, 5, 5, 3623])
t3: torch.Size([3, 5, 5, 3623])
t4: torch.Size([3, 5, 5, 3623])
t5: torch.Size([3, 5, 5, 3623])
t6: torch.Size([3, 5, 5, 3623])
R: torch.Size([3, 5, 5, 3623])
F: torch.Size([3, 5, 5, 3623])
L_hr shape: torch.Size([3, 5, 5, 3623])
Simulation output type: <class '__main__.SFMNNSimulation'>





TypeError: SFMNNLoss.forward() missing 3 required positional arguments: 'outputs', 'E_s', and 'cos_theta_s'