In [1]:
from pe import PositionalEncoding
from model import Generator, DownsamplingLayer, FullNetwork

import torch
import torch.nn as nn
import numpy as np
from dataset import IterableSpectraDataset, collate_fn
from torch.utils.data import DataLoader

In [2]:
hdf5_dir = '../data/healpixfiles_inter'
dataset = IterableSpectraDataset(hdf5_dir, n_samples_per_spectrum=4000, n_subspectra=5, yield_full_spectrum=True)
dataloader = DataLoader(dataset, batch_size=10, collate_fn=collate_fn)

# Get the first batch
first_batch = next(iter(dataloader))

# Print shapes of wavelength and flux tensors, and the spectrum IDs
print("Wavelength shape:", first_batch['wavelength'].shape)
print("Flux shape:", first_batch['flux'].shape)
print("Spectrum IDs:", first_batch['spectrum_id'])
print("Lengths:", first_batch['length'])

Exception: The expanded size of the tensor (20762) must match the existing size (20765) at non-singleton dimension 0.  Target sizes: [20762].  Tensor sizes: [20765] in group 184_galah_317
Exception: The expanded size of the tensor (20762) must match the existing size (20778) at non-singleton dimension 0.  Target sizes: [20762].  Tensor sizes: [20778] in group 184_galah_322
Exception: The expanded size of the tensor (20762) must match the existing size (20767) at non-singleton dimension 0.  Target sizes: [20762].  Tensor sizes: [20767] in group 191_galah_328
Exception: The expanded size of the tensor (20762) must match the existing size (20764) at non-singleton dimension 0.  Target sizes: [20762].  Tensor sizes: [20764] in group 161_galah_344
Exception: The expanded size of the tensor (20762) must match the existing size (20769) at non-singleton dimension 0.  Target sizes: [20762].  Tensor sizes: [20769] in group 161_galah_351
Exception: The expanded size of the tensor (20762) must matc

In [3]:
def generate_wavelength_grid():
    grid = [
        (15050, 15850, 0.2),
        (15870, 16440, 0.2),
        (16475, 17005, 0.2),
        (4700, 4930, 0.05),
        (5650, 5880, 0.05),
        (6420, 6800, 0.05),
        (7500, 7920, 0.05)
    ]

    wavelength_grid = []
    for start, end, step in grid:
        wavelength_grid.extend(np.arange(start, end + step, step))
    
    return np.array(wavelength_grid)

def normalize_wavelengths(wavelengths, max_wavelength):
    return wavelengths / max_wavelength

In [4]:
# Generate wavelength grid (high resolution)
wavelength_grid = torch.tensor(generate_wavelength_grid(), dtype=torch.float16)
wavelength_grid = normalize_wavelengths(wavelength_grid, 17100)  # Normalize the wavelength grid


In [5]:
# Example configuration
latent_dim = 20  # Example latent dimension
output_dim = 20762  # Example output dimension (max length of wavelengths)
layers = [512, 512]  # Example hidden layers
activation_function = 'LeakyReLU'  # Example activation function

In [6]:
# Define Positional Encoding
pe_args = (
    1,  # Dimension of the input (wavelengths)
    10,  # pe_dim example value
    1.0,  # omega example value
    1.0,  # sigma example value
    True,  # pe_bias example value
    42  # seed example value
)

In [7]:
positional_encoding = PositionalEncoding(pe_args)

In [8]:
# Define Generator
generator_output_dim= 1
input_dim = latent_dim + pe_args[1]  # latent_dim + pe_dim
generator = Generator(input_dim, generator_output_dim, layers, activation_function)

# Define downsampling layer
downsampling_layer = DownsamplingLayer()

# Create the full network
full_network = FullNetwork(generator, downsampling_layer, positional_encoding)



In [9]:
# Example input tensors
batch_size = 10
positional_dim = 25300  # Adjust according to your needs
latent_z = first_batch['latent_code']  # Example latent space vector


In [10]:
# Generate lower resolution real_wavelengths
real_wavelengths = first_batch['wavelength']
# real_wavelengths = real_wavelengths.unsqueeze(0).expand(batch_size, -1)  # Shape: [batch_size, 8000]


In [11]:

# Use positional encoding by reference
positional_encoding_ref = positional_encoding(wavelength_grid)

# Forward pass
generated_flux = full_network(latent_z, wavelength_grid, real_wavelengths)

# Print the shape of the generated flux
print(f"Generated flux shape: {generated_flux.shape}")

latent_z shape: torch.Size([10, 20])
wavelength_grid shape: torch.Size([34714])
real_wavelengths shape: torch.Size([10, 20762])
positional_encoding shape: torch.Size([10, 34714, 10])
latent_z_expanded shape: torch.Size([10, 34714, 20])
input_to_generator shape: torch.Size([10, 34714, 30])
generator_output shape: torch.Size([10, 34714, 1])
generator_output squeezed shape:  torch.Size([10, 34714])
high_res_flux shape: torch.Size([10, 34714])
high_res_flux shape: torch.Size([10, 34714])
high_res_wavelength shape: torch.Size([34714])
observed_wavelength shape: torch.Size([10, 20762])
Reshaped high_res_flux shape: torch.Size([10, 1, 34714, 1])
Grid shape: torch.Size([10, 1, 20762, 2])
Sampled flux shape: torch.Size([10, 1, 1, 20762])
generated_flux shape: torch.Size([10, 1, 20762])
Generated flux shape: torch.Size([10, 1, 20762])


In [12]:
# Define the weighted MSE loss function
def weighted_mse_loss(input, target, weight):
    assert input.shape == target.shape == weight.shape, f'Shapes of input {input.shape}, target {target.shape}, and weight {weight.shape} must match'
    loss = torch.mean(weight * (input - target) ** 2) 
    return loss

In [13]:
# Example target and weight tensors
real_flux = first_batch['flux']  # Example real flux from dataset
weights = torch.ones((batch_size, positional_dim))  # Example weights

# Compute the weighted MSE loss
generated_flux = generated_flux.squeeze(1)
loss = weighted_mse_loss(generated_flux, real_flux, weights)

# Backpropagation
optimizer = torch.optim.Adam(full_network.parameters(), lr=0.001)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Print total loss
print(f"Total loss: {loss.item()}")

AssertionError: Shapes of input torch.Size([10, 20762]), target torch.Size([10, 20762]), and weight torch.Size([4, 20762]) must match