In [1]:
from pe import PositionalEncoding
import torch
import torch.nn as nn
import numpy as np

In [2]:
# Example config
config = {
    'model': {
        'pe_dim': 10,  # Example value for pe_dim, replace with your actual value
        'omega': 1.0,  # Example omega value, replace with your actual value
        'sigma': 1.0,  # Example sigma value, replace with your actual value
        'pe_bias': True,  # Example bias setting, replace with your actual value
        'seed': 42  # Example seed value, replace with your actual value
    }
}


In [3]:

# Extract positional encoding arguments from config
pe_args = (
    1,  # Dimension of the input (wavelengths)
    config['model']['pe_dim'],
    config['model']['omega'],
    config['model']['sigma'],
    config['model']['pe_bias'],
    config['model']['seed']
)


In [4]:
def generate_wavelength_grid():
    grid = [
        (15050, 15850, 0.2),
        (15870, 16440, 0.2),
        (16475, 17005, 0.2),
        (4700, 4930, 0.05),
        (5650, 5880, 0.05),
        (6420, 6800, 0.05),
        (7500, 7920, 0.05)
    ]

    wavelength_grid = []
    for start, end, step in grid:
        wavelength_grid.extend(np.arange(start, end + step, step))
    
    return np.array(wavelength_grid)

def normalize_wavelengths(wavelengths, max_wavelength):
    return wavelengths / max_wavelength

In [5]:
# Generate wavelength grid
wavelength_grid = generate_wavelength_grid()

# Normalize the wavelength grid
max_wavelength = 17100
normalized_wavelength_grid = normalize_wavelengths(wavelength_grid, max_wavelength)

# Initialize Positional Encoding with the arguments from config
positional_encoding = PositionalEncoding(pe_args)

# Generate positional encoding tensor
wavelength_grid_tensor = torch.tensor(normalized_wavelength_grid, dtype=torch.float16)
pe_tensor = positional_encoding(wavelength_grid_tensor)

print(pe_tensor.shape)  # Verify the shape

torch.Size([34714, 10])


In [6]:
print(pe_tensor)

tensor([[-0.4727,  0.9746, -0.3076,  ...,  0.9517,  0.8340, -0.6118],
        [-0.4736,  0.9746, -0.3083,  ...,  0.9512,  0.8340, -0.6143],
        [-0.4736,  0.9746, -0.3083,  ...,  0.9512,  0.8340, -0.6143],
        ...,
        [ 0.3794,  0.9937,  0.2966,  ...,  0.9551,  1.0000,  0.7573],
        [ 0.3794,  0.9937,  0.2966,  ...,  0.9551,  1.0000,  0.7573],
        [ 0.3794,  0.9937,  0.2966,  ...,  0.9551,  1.0000,  0.7573]],
       dtype=torch.float16)


In [7]:
print (wavelength_grid)

[15050.   15050.2  15050.4  ...  7919.95  7920.    7920.05]


In [8]:
# Check for NaNs in the pe_tensor
contains_nan = torch.isnan(pe_tensor).any()
print(f"Contains NaNs: {contains_nan}")


Contains NaNs: False


In [9]:
from dataset import IterableSpectraDataset, collate_fn
from torch.utils.data import DataLoader

In [10]:
hdf5_dir = '../data/healpixfiles_inter'
dataset = IterableSpectraDataset(hdf5_dir, n_samples_per_spectrum=1000, n_subspectra=5, yield_full_spectrum=True)
dataloader = DataLoader(dataset, batch_size=32, collate_fn=collate_fn)

In [11]:
# Get the first batch
first_batch = next(iter(dataloader))

# Print shapes of wavelength and flux tensors, and the spectrum IDs
print("Wavelength shape:", first_batch['wavelength'].shape)
print("Flux shape:", first_batch['flux'].shape)
print("Spectrum IDs:", first_batch['spectrum_id'])
print("Lengths:", first_batch['length'])

Wavelength shape: torch.Size([32, 25300])
Flux shape: torch.Size([32, 25300])
Spectrum IDs: ('63_apogee_121', '73_apogee_105', '73_apogee_270', '73_apogee_295', '20_apogee_104', '20_apogee_273', '20_apogee_274', '32_apogee_129', '32_apogee_214', '32_apogee_245', '185_galah_332', '185_galah_336', '185_galah_355', '185_galah_391', '182_apogee_102', '182_apogee_115', '182_apogee_145', '182_apogee_193', '182_apogee_197', '182_apogee_199', '182_apogee_202', '182_apogee_234', '182_apogee_239', '182_apogee_243', '182_apogee_250', '182_apogee_299', '182_galah_315', '182_galah_343', '182_galah_392', '59_apogee_114', '59_galah_324', '149_apogee_172')
Lengths: tensor([ 9495,  9495,  9495,  9495,  9495,  9495,  9495,  9495,  9495,  9495,
        20761, 20775, 20766, 20760,  9495,  9495,  9495,  9495,  9495,  9495,
         9495,  9495,  9495,  9495,  9495,  9495, 20758, 20776, 20767,  9495,
        20775,  9495])


In [12]:
from model import Generator, DownsamplingLayer, FullNetwork
from pe import PositionalEncoding

In [13]:
# Example configuration
latent_dim = 5  # Example latent dimension
output_dim = 25300  # Example output dimension (max length of wavelengths)
layers = [512, 512]  # Example hidden layers
activation_function = 'LeakyReLU'  # Example activation function

# Define Generator
generator = Generator(latent_dim, output_dim, layers, activation_function)

# Define downsampling layer
downsampling_layer = DownsamplingLayer()

# Create the full network
full_network = FullNetwork(generator, downsampling_layer)

# Define Positional Encoding
pe_args = (
    1,  # Dimension of the input (wavelengths)
    128,  # pe_dim example value
    1.0,  # omega example value
    1.0,  # sigma example value
    True,  # pe_bias example value
    42  # seed example value
)

positional_dim = 10

TypeError: FullNetwork.__init__() missing 1 required positional argument: 'positional_encoding'

In [None]:
positional_encoding = PositionalEncoding(pe_args)
print(positional_encoding)

In [None]:

# Generate positional encoding tensor
wavelength_grid = torch.tensor(generate_wavelength_grid(), dtype=torch.float32)
wavelength_grid = normalize_wavelengths(wavelength_grid, 17100)  # Normalize the wavelength grid

# Example input tensors
batch_size = 5
latent_z = first_batch['latent_code']  # Example latent space vector
observed_wavelength = torch.randn((batch_size, positional_dim))  # Example observed wavelengths

# Use positional encoding by reference
positional_encoding_ref = positional_encoding(wavelength_grid)

# Forward pass
generated_flux = full_network(latent_z, positional_encoding_ref, observed_wavelength)

# Print the shape of the generated flux
print(f"Generated flux shape: {generated_flux.shape}")

# Define the weighted MSE loss function
def weighted_mse_loss(input, target, weight):
    assert input.shape == target.shape == weight.shape, f'Shapes of input {input.shape}, target {target.shape}, and weight {weight.shape} must match'
    loss = torch.mean(weight * (input - target) ** 2) 
    return loss

# Example target and weight tensors
sampled_flux = first_batch['flux']  # Example sampled flux (real flux)
weights = first_batch['weight']  # Example weights

# Compute the weighted MSE loss
loss = weighted_mse_loss(generated_flux, sampled_flux, weights)

# Backpropagation
optimizer = torch.optim.Adam(full_network.parameters(), lr=0.001)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Print total loss
print(f"Total loss: {loss.item()}")