In [1]:
from module import filteredAPOGEEDataset
import torch
from torch.utils.data import Dataset, DataLoader
from astropy.io import fits
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import torch
import torch.nn as nn


In [2]:
# Location of the APOGEE DR17 spectra, relative to this notebook.
directory = '../../../projects/k-pop/spectra/apogee/dr17'

# NOTE(review): `max_files`, `lower_bound` and `upper_bound` are interpreted by
# the project's filteredAPOGEEDataset; presumably a file cap and a wavelength
# window matching the high-resolution grid defined below -- confirm in `module`.
dataset = filteredAPOGEEDataset(
    directory,
    max_files=500,
    lower_bound=15250,
    upper_bound=15750,
)

# Shuffled mini-batches of 10 items.
dataloader = DataLoader(dataset, shuffle=True, batch_size=10)


In [3]:
# Common high-resolution wavelength grid: half-open interval
# [start_wavelength, end_wavelength) sampled every delta_wavelength.
# NOTE(review): units are presumably Angstrom (APOGEE H-band) -- confirm.
start_wavelength = 15250
end_wavelength = 15750
delta_wavelength = 0.02

# np.arange with a fractional step can gain or lose a trailing element through
# floating-point rounding. The generator's output size is taken from the length
# of this array, so fix the number of points explicitly and use linspace.
num_wavelengths = int(round((end_wavelength - start_wavelength) / delta_wavelength))
wavelengths_high_res = np.linspace(
    start_wavelength, end_wavelength, num_wavelengths, endpoint=False
)

In [8]:
def interpolate_spectra(wavelengths, flux, new_wavelengths):
    """
    Linearly interpolate a spectrum onto a new set of wavelength points.

    Parameters:
    - wavelengths (array-like): Original wavelength points (1-D). They may be
      in any order; they are sorted internally when not already ascending.
    - flux (array-like): Flux values corresponding to `wavelengths`.
    - new_wavelengths (array-like): Wavelength points to evaluate at.

    Returns:
    - np.array: Interpolated flux values at `new_wavelengths`. Points outside
      the original range take the flux of the nearest end point (np.interp's
      default behaviour).
    """
    wavelengths = np.asarray(wavelengths)
    flux = np.asarray(flux)

    # np.interp does not verify that its sample points are increasing and
    # silently returns nonsense when they are not, so reorder when needed.
    # Already-ascending input takes the fast path and is left untouched.
    if (
        wavelengths.ndim == 1
        and wavelengths.shape == flux.shape
        and wavelengths.size > 1
        and np.any(np.diff(wavelengths) < 0)
    ):
        order = np.argsort(wavelengths, kind='stable')
        wavelengths = wavelengths[order]
        flux = flux[order]

    return np.interp(new_wavelengths, wavelengths, flux)

In [10]:
class HighResGenerator(nn.Module):
    """Fully connected network mapping a latent vector to an output vector.

    The hidden stack is six blocks of Linear -> LeakyReLU(0.01) -> BatchNorm1d
    with widths 128, 256, 512, 512, 256, 128, followed by a final Linear layer
    of size `output_dim` with no activation.
    """

    # Output width of each hidden Linear layer, in order.
    _HIDDEN_WIDTHS = (128, 256, 512, 512, 256, 128)

    def __init__(self, latent_dim, output_dim):
        """Build the layer stack.

        Parameters:
        - latent_dim (int): Size of the input latent vector.
        - output_dim (int): Size of the generated output vector.
        """
        super().__init__()

        stack = []
        fan_in = latent_dim
        for width in self._HIDDEN_WIDTHS:
            # Layers are created in the same order as they appear in the
            # network, so Sequential indices (and state_dict keys) line up.
            stack.append(nn.Linear(fan_in, width))
            stack.append(nn.LeakyReLU(0.01))
            stack.append(nn.BatchNorm1d(width))
            fan_in = width

        # Output projection: deliberately left without an activation.
        stack.append(nn.Linear(fan_in, output_dim))

        self.model = nn.Sequential(*stack)

    def forward(self, z):
        """Run the latent batch `z` through the layer stack."""
        return self.model(z)


In [11]:
# Mean-squared-error criterion, averaged over every element ('mean' is the
# default reduction; it is spelled out here for readability).
mse_loss = nn.MSELoss(reduction='mean')

In [12]:
# ---- Hyperparameters ----
latent_dim = 50                          # length of each latent vector
output_dim = len(wavelengths_high_res)   # one output value per grid point
learning_rate_gen = 0.1                  # step size for the generator weights
learning_rate_latent = 0.01              # step size for the latent vectors
weight_decay = 1e-4                      # applied to the generator group only

# ---- Generator network ----
generator = HighResGenerator(latent_dim, output_dim)

# ---- One trainable latent vector per dataset item ----
latent_vectors = torch.randn(len(dataset), latent_dim, requires_grad=True)

# ---- A single Adam optimizer with two parameter groups ----
# The generator weights and the latent vectors are updated together but with
# different learning rates; weight decay is set on the generator group only.
optimizer = torch.optim.Adam([
    dict(params=generator.parameters(), lr=learning_rate_gen, weight_decay=weight_decay),
    dict(params=latent_vectors, lr=learning_rate_latent),
])


In [35]:
# Resample every dataset item onto the shared high-resolution grid and stack
# the results into one array, in dataset iteration order.
# Each item unpacks as (index, record); only the record's 'wavelength' and
# 'flux' entries are used here.
interpolated_fluxes = np.array([
    interpolate_spectra(record['wavelength'], record['flux'], wavelengths_high_res)
    for _, record in dataset
])


In [37]:
import matplotlib.pyplot as plt

def plot_spectra(real, generated, epoch, index, wavelengths=None):
    """
    Plot a real spectrum and its generated counterpart on the same axes.

    Parameters:
    - real (array-like): Flux values of the reference spectrum.
    - generated (array-like): Flux values produced by the generator.
    - epoch (int): Training epoch, shown in the title.
    - index (int): Identifier of the plotted sample, shown in the title.
    - wavelengths (array-like, optional): x-coordinates for both curves. When
      omitted, the curves are drawn against their sample position.
    """
    fig, ax = plt.subplots(figsize=(10, 5))

    if wavelengths is None:
        # No grid supplied: matplotlib uses 0..N-1 on the x-axis, so label it
        # as a position rather than a physical wavelength.
        ax.plot(real, label='Real Spectra', color='blue')
        ax.plot(generated, label='Generated Spectra', color='red', linestyle='--')
        ax.set_xlabel('Wavelength bin index')
    else:
        ax.plot(wavelengths, real, label='Real Spectra', color='blue')
        ax.plot(wavelengths, generated, label='Generated Spectra', color='red', linestyle='--')
        ax.set_xlabel('Wavelength')

    ax.set_title(
        f'Comparison of Real and Generated Spectra at Epoch {epoch} (sample {index})'
    )
    ax.set_ylabel('Flux')
    ax.legend()
    plt.show()

    # This is called repeatedly from the training loop; release the figure so
    # open figures do not pile up in memory.
    plt.close(fig)

In [None]:
num_epochs = 200  # Total number of epochs
batch_size = 10  # Informational only: the effective batch size is the one `dataloader` was built with
loss_values = []  # Loss of the last batch of each epoch

# Training targets on the generator's output grid. The generator emits
# len(wavelengths_high_res) values per spectrum, so it is fitted against the
# resampled spectra rather than the raw batch['flux'] arrays.
# NOTE(review): assumes the index yielded with each dataset item is its
# position in the dataset, i.e. the row order of `interpolated_fluxes` --
# confirm against filteredAPOGEEDataset.__getitem__.
target_spectra = torch.tensor(interpolated_fluxes, dtype=torch.float32)

# A stale loop that called the undefined names `model` and `loss_func` used to
# sit here and raised NameError on a fresh kernel; it has been removed.

for epoch in range(num_epochs):
    for batch_indices, batch in dataloader:
        optimizer.zero_grad()

        # `batch` also carries 'flux_mask' and 'variation', intended for a
        # future weighted loss; they are unused by the plain MSE below.
        real_spectra = target_spectra[batch_indices]

        # Each spectrum owns one latent vector, selected by its dataset index.
        generated_spectra = generator(latent_vectors[batch_indices])

        loss = mse_loss(generated_spectra, real_spectra)

        loss.backward()
        optimizer.step()

    # Store the loss value
    loss_values.append(loss.item())

    if epoch % 10 == 0:
        # Show the first spectrum of the last batch next to its reconstruction.
        plot_spectra(
            real_spectra[0].detach().numpy(),
            generated_spectra[0].detach().numpy(),
            epoch,
            0,
        )
        print(f'Epoch {epoch}/{num_epochs}, Loss: {loss.item()}')