In [1]:
from pe import PositionalEncoding
from model import Generator, DownsamplingLayer, FullNetwork

import torch
import torch.nn as nn
import numpy as np


In [2]:
# Example configuration shared by the cells below.
batch_size = 4          # number of examples pushed through the network at once
latent_dim = 3          # size of each latent vector
# Number of values the generator emits; intended to equal the length of the
# high-resolution wavelength grid.
output_dim = 34714
layers = [512] * 2      # two hidden layer widths, both 512
activation_function = "LeakyReLU"


In [3]:
# Generate wavelength grid (high resolution)
def generate_wavelength_grid(batch_size):
    """Return the high-resolution wavelength grid as one flat NumPy array.

    The grid is the concatenation, in the order listed below, of seven
    evenly spaced segments described by ``(start, end, step)`` triples.

    Args:
        batch_size: Accepted for call compatibility; it is not used here.

    Returns:
        A 1-D ``numpy.ndarray`` holding every segment back to back.

    NOTE(review): each segment is built with ``np.arange(start, end + step,
    step)`` using a fractional step. With non-integer steps the number of
    samples ``np.arange`` returns is sensitive to floating-point rounding, so
    a segment may contain a sample past ``end``. Any size that must match
    this grid should be checked against the length actually returned.
    """
    segments = (
        (15050, 15850, 0.2),
        (15870, 16440, 0.2),
        (16475, 17005, 0.2),
        (4700, 4930, 0.05),
        (5650, 5880, 0.05),
        (6420, 6800, 0.05),
        (7500, 7920, 0.05),
    )
    # One array per segment; the stop value is pushed out by one step so that
    # ``end`` itself is part of the segment.
    pieces = [np.arange(lo, hi + spacing, spacing) for lo, hi, spacing in segments]
    return np.concatenate(pieces)


def normalize_wavelengths(wavelengths, max_wavelength):
    """Scale wavelengths by dividing them by ``max_wavelength``.

    Args:
        wavelengths: Value(s) to scale; anything that supports ``/`` with
            ``max_wavelength`` (scalar, NumPy array, tensor, ...).
        max_wavelength: Divisor applied to every wavelength.

    Returns:
        The result of ``wavelengths / max_wavelength``.
    """
    scaled = wavelengths / max_wavelength
    return scaled

In [4]:
# Build the high-resolution wavelength input for the network.
#
# float32 is used instead of float16: float16 keeps only 11 significant bits,
# so values in the 4700-17005 range are rounded to multiples of 4-16 and grid
# points that are 0.05 / 0.2 apart collapse onto the same number before they
# are even normalized. float32 is also the default dtype produced by
# torch.randn / torch.linspace for the other network inputs.
wavelength_grid = torch.tensor(generate_wavelength_grid(batch_size), dtype=torch.float32)
# Divide by 17100 (above the largest segment end, 17005), add a leading batch
# axis, and broadcast it to batch_size rows. expand() returns a view, so no
# per-row copy of the grid is made.
wavelength_grid = normalize_wavelengths(wavelength_grid, 17100).unsqueeze(0).expand(batch_size, -1)


In [5]:

# Positional encoding configuration, passed to PositionalEncoding as a single
# tuple. Element 1 (10) is the encoding width that is later added to
# latent_dim to size the generator input.
# NOTE(review): the meaning of the remaining elements is defined in pe.py,
# which is not visible here — confirm against PositionalEncoding.
pe_args = (
    1,
    10,
    1.0,
    1.0,
    True,
    42,
)
positional_encoding = PositionalEncoding(pe_args)


In [6]:

# Generator: its input width is the latent vector size plus the positional
# encoding width stored at pe_args[1].
input_dim = pe_args[1] + latent_dim
generator = Generator(
    input_dim,
    output_dim,
    layers,
    activation_function,
)


In [7]:

# Assemble the full network from the generator, the downsampling layer and
# the positional encoding.
downsampling_layer = DownsamplingLayer()
full_network = FullNetwork(generator, downsampling_layer, positional_encoding)

# Example latent vectors: one standard-normal row of length latent_dim per
# batch element.
latent_z = torch.randn(batch_size, latent_dim)

# Lower-resolution wavelengths: 8000 evenly spaced values from 0 to 1,
# broadcast to every batch row -> shape [batch_size, 8000].
real_wavelengths = torch.linspace(0, 1, 8000)[None, :].expand(batch_size, -1)


In [8]:

# Run the assembled network once on the example inputs.
# NOTE(review): the output saved with this notebook ends in an IndexError
# ("Dimension out of range"), so this call did not complete when last run;
# the cause is inside the imported model code and needs to be checked there.
generated_flux = full_network(
    latent_z,
    wavelength_grid,
    real_wavelengths,
)

# Report the shape of the result.
print(f"Generated flux shape: {generated_flux.shape}")


latent_z shape: torch.Size([4, 3])
wavelength_grid shape: torch.Size([34714])
real_wavelengths shape: torch.Size([4, 8000])
positional_encoding shape: torch.Size([10])


IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)