In [1]:
from pe import PositionalEncoding
from model import Generator, DownsamplingLayer, FullNetwork

import torch
import torch.nn as nn
import numpy as np

import torch.optim as optim
from torch.utils.data import DataLoader
import yaml
import os
from datetime import datetime
import numpy as np
from utils import get_config2, resolve_path
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import h5py
import csv
import time
import glob
from dataset import IterableSpectraDataset, collate_fn
from tqdm import tqdm


In [2]:
def initialize_device():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    return device

In [6]:
device = initialize_device()

Using device: cpu


In [3]:
# Define the weighted MSE loss function
def weighted_mse_loss(input, target, weight):
    assert input.shape == target.shape == weight.shape, f'Shapes of input {input.shape}, target {target.shape}, and weight {weight.shape} must match'
    loss = torch.mean(weight * (input - target) ** 2) 
    return loss

In [19]:
def load_configurations():
    config = get_config2()
    dataset_name = config['dataset_name']
    dataset_config = config['datasets'][dataset_name]
    data_path = resolve_path(dataset_config['path'])
   
    return (config, data_path)

In [20]:
(config, data_path) = load_configurations()

In [21]:
def prepare_datasets(config, data_path):
    n_samples_per_spectrum =  config['training']['n_samples_per_spectrum']
    n_subspectra = config['training']['n_subspectra']
    train_dataset = IterableSpectraDataset(data_path, is_validation=False,n_samples_per_spectrum=n_samples_per_spectrum, n_subspectra=n_subspectra)
    val_dataset = IterableSpectraDataset(data_path, is_validation=True, n_samples_per_spectrum=n_samples_per_spectrum, n_subspectra=n_subspectra)
    train_loader = DataLoader(train_dataset, batch_size=config['training']['batch_size'], collate_fn=collate_fn, num_workers=config['training']['num_workers'], pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=config['training']['batch_size'], collate_fn=collate_fn, num_workers=config['training']['num_workers'], pin_memory=True)
    return train_loader, val_loader

In [22]:
train_loader, val_loader = prepare_datasets(config, data_path)

In [23]:
# Get the first batch
first_batch = next(iter(train_loader))

# Print shapes of wavelength and flux tensors, and the spectrum IDs
print("Wavelength shape:", first_batch['wavelength'].shape)
print("Flux shape:", first_batch['flux'].shape)
print("Spectrum IDs:", first_batch['spectrum_id'])
print("Lengths:", first_batch['length'])

Wavelength shape: torch.Size([5, 25300])
Flux shape: torch.Size([5, 25300])
Spectrum IDs: ('43_apogee_215', '43_apogee_215', '43_apogee_253', '43_apogee_253', '43_apogee_296')
Lengths: tensor([7000, 7000, 7000, 7000, 7000])


In [26]:
def initialize_optimizers(config, generator, latent_codes):
    optimizer_g = optim.Adam(generator.parameters(), lr=config['training']['learning_rate'])
    optimizer_l = optim.Adam([latent_codes], lr=config['training']['latent_learning_rate'])
    return optimizer_g, optimizer_l

In [24]:
def generate_wavelength_grid():
    grid = [
        (15050, 15850, 0.2),
        (15870, 16440, 0.2),
        (16475, 17005, 0.2),
        (4700, 4930, 0.05),
        (5650, 5880, 0.05),
        (6420, 6800, 0.05),
        (7500, 7920, 0.05)
    ]

    wavelength_grid = []
    for start, end, step in grid:
        wavelength_grid.extend(np.arange(start, end + step, step))
    
    return np.array(wavelength_grid)

def normalize_wavelengths(wavelengths, max_wavelength):
    return wavelengths / max_wavelength

In [25]:
# Generate wavelength grid (high resolution)
wavelength_grid = torch.tensor(generate_wavelength_grid(), dtype=torch.float16)
wavelength_grid = normalize_wavelengths(wavelength_grid, 17100)  # Normalize the wavelength grid


In [28]:
# Example configuration
latent_dim = config['training']['latent_dim']  # Example latent dimension
output_dim = config['training']['n_samples_per_spectrum']  # Example output dimension (max length of wavelengths)
layers = [512, 512]  # Example hidden layers
activation_function = 'LeakyReLU'  # Example activation function

In [29]:
# Define Positional Encoding
pe_args = (
    1,  # Dimension of the input (wavelengths)
    10,  # pe_dim example value
    1.0,  # omega example value
    1.0,  # sigma example value
    True,  # pe_bias example value
    42  # seed example value
)

In [30]:
positional_encoding = PositionalEncoding(pe_args)

In [31]:
# Define Generator
generator_output_dim= 1
input_dim = latent_dim + pe_args[1]  # latent_dim + pe_dim
generator = Generator(input_dim, generator_output_dim, layers, activation_function)

# Define downsampling layer
downsampling_layer = DownsamplingLayer()

# Create the full network
full_network = FullNetwork(generator, downsampling_layer, positional_encoding)



In [32]:
# Example input tensors
batch_size = config['training']['batch_size']
positional_dim = config['training']['n_samples_per_spectrum']  # Adjust according to your needs
# latent_z = torch.randn((batch_size, latent_dim), device=device)  # Example latent space vector


latent_codes = torch.randn(1, config['training']['latent_dim'], device=device)  # Initialize with a single random vector
dict_latent_codes = {}


## latent z is the batch['latent_code']

In [33]:
# Generate lower resolution real_wavelengths
real_wavelengths = torch.linspace(wavelength_grid.min(), wavelength_grid.max(), 8000)
real_wavelengths = real_wavelengths.unsqueeze(0).expand(batch_size, -1)  # Shape: [batch_size, 8000]


real_wavelengths= first_batch['wavelength']
print (real_wavelengths.shape)

torch.Size([5, 25300])


In [10]:

# Use positional encoding by reference
positional_encoding_ref = positional_encoding(wavelength_grid)

# Forward pass
generated_flux = full_network(latent_z, wavelength_grid, real_wavelengths)

# Print the shape of the generated flux
print(f"Generated flux shape: {generated_flux.shape}")

latent_z shape: torch.Size([4, 5])
wavelength_grid shape: torch.Size([34714])
real_wavelengths shape: torch.Size([4, 8000])
positional_encoding shape: torch.Size([4, 34714, 10])
latent_z_expanded shape: torch.Size([4, 34714, 5])
input_to_generator shape: torch.Size([4, 34714, 15])
generator_output shape: torch.Size([4, 34714, 1])
generator_output squeezed shape:  torch.Size([4, 34714])
high_res_flux shape: torch.Size([4, 34714])
high_res_flux shape: torch.Size([4, 34714])
high_res_wavelength shape: torch.Size([34714])
observed_wavelength shape: torch.Size([4, 8000])
Reshaped high_res_flux shape: torch.Size([4, 1, 34714, 1])
Grid shape: torch.Size([4, 1, 8000, 2])
Sampled flux shape: torch.Size([4, 1, 1, 8000])
generated_flux shape: torch.Size([4, 1, 8000])
Generated flux shape: torch.Size([4, 1, 8000])


In [13]:
# Example target and weight tensors
real_flux = torch.randn((batch_size, positional_dim))  # Example real flux from dataset
weights = torch.ones((batch_size, positional_dim))  # Example weights

# Compute the weighted MSE loss
generated_flux = generated_flux.squeeze(1)
loss = weighted_mse_loss(generated_flux, real_flux, weights)

# Backpropagation
optimizer = torch.optim.Adam(full_network.parameters(), lr=0.001)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Print total loss
print(f"Total loss: {loss.item()}")

Total loss: 1.0041241645812988
