In [1]:
from pe import PositionalEncoding
from model import Generator, DownsamplingLayer, FullNetwork

import torch
import torch.nn as nn
import numpy as np


In [2]:
def generate_wavelength_grid():
    grid = [
        (15050, 15850, 0.2),
        (15870, 16440, 0.2),
        (16475, 17005, 0.2),
        (4700, 4930, 0.05),
        (5650, 5880, 0.05),
        (6420, 6800, 0.05),
        (7500, 7920, 0.05)
    ]

    wavelength_grid = []
    for start, end, step in grid:
        wavelength_grid.extend(np.arange(start, end + step, step))
    
    return np.array(wavelength_grid)

def normalize_wavelengths(wavelengths, max_wavelength):
    return wavelengths / max_wavelength

In [3]:
# Generate wavelength grid (high resolution)
wavelength_grid = torch.tensor(generate_wavelength_grid(), dtype=torch.float16)
wavelength_grid = normalize_wavelengths(wavelength_grid, 17100)  # Normalize the wavelength grid


In [4]:
# Example configuration
latent_dim = 5  # Example latent dimension
output_dim = 25300  # Example output dimension (max length of wavelengths)
layers = [512, 512]  # Example hidden layers
activation_function = 'LeakyReLU'  # Example activation function

In [5]:
# Define Positional Encoding
pe_args = (
    1,  # Dimension of the input (wavelengths)
    10,  # pe_dim example value
    1.0,  # omega example value
    1.0,  # sigma example value
    True,  # pe_bias example value
    42  # seed example value
)

In [6]:
positional_encoding = PositionalEncoding(pe_args)

In [7]:
# Define Generator
input_dim = latent_dim + pe_args[1]  # latent_dim + pe_dim
generator = Generator(input_dim, output_dim, layers, activation_function)

# Define downsampling layer
downsampling_layer = DownsamplingLayer()

# Create the full network
full_network = FullNetwork(generator, downsampling_layer, positional_encoding)



In [8]:
# Example input tensors
batch_size = 5
positional_dim = 25300  # Adjust according to your needs
latent_z = torch.randn((batch_size, latent_dim))  # Example latent space vector


In [9]:
# Generate lower resolution real_wavelengths
real_wavelengths = torch.linspace(wavelength_grid.min(), wavelength_grid.max(), 8000)
real_wavelengths = real_wavelengths.unsqueeze(0).expand(batch_size, -1)  # Shape: [batch_size, 8000]


In [10]:

# Use positional encoding by reference
positional_encoding_ref = positional_encoding(wavelength_grid)

# Forward pass
generated_flux = full_network(latent_z, wavelength_grid, real_wavelengths)

# Print the shape of the generated flux
print(f"Generated flux shape: {generated_flux.shape}")

latent_z shape: torch.Size([5, 5])
wavelength_grid shape: torch.Size([34714])
real_wavelengths shape: torch.Size([5, 8000])
positional_encoding shape: torch.Size([5, 34714, 10])
latent_z_expanded shape: torch.Size([5, 34714, 5])
input_to_generator shape: torch.Size([5, 34714, 15])
high_res_flux shape: torch.Size([5, 878264200])
high_res_flux shape: torch.Size([5, 878264200])
high_res_wavelength shape: torch.Size([34714])
observed_wavelength shape: torch.Size([5, 8000])
Reshaped high_res_flux shape: torch.Size([5, 1, 878264200, 1])
Grid shape: torch.Size([5, 1, 8000, 1])


RuntimeError: grid_sampler(): expected grid to have size 2 in last dimension, but got grid with sizes [5, 1, 8000, 1]

In [None]:
# Define the weighted MSE loss function
def weighted_mse_loss(input, target, weight):
    assert input.shape == target.shape == weight.shape, f'Shapes of input {input.shape}, target {target.shape}, and weight {weight.shape} must match'
    loss = torch.mean(weight * (input - target) ** 2) 
    return loss

In [None]:
# Example target and weight tensors
real_flux = torch.randn((batch_size, positional_dim))  # Example real flux from dataset
weights = torch.ones((batch_size, positional_dim))  # Example weights

# Compute the weighted MSE loss
loss = weighted_mse_loss(generated_flux, real_flux, weights)

# Backpropagation
optimizer = torch.optim.Adam(full_network.parameters(), lr=0.001)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Print total loss
print(f"Total loss: {loss.item()}")