In [1]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from model import Generator
from dataset import APOGEEDataset
from torch.utils.data import DataLoader
from utils import resolve_path, get_config

In [2]:


def load_model(config, checkpoint_path, output_dim):
    latent_dim = config['training']['latent_dim']
    generator_layers = config['model']['generator_layers']
    activation_function = getattr(torch.nn, config['model']['activation_function'])
    
    # Initialize the generator model with the correct output_dim
    generator = Generator(latent_dim, output_dim, generator_layers, activation_function)
    
    # Load the saved model weights
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    generator.load_state_dict(checkpoint['state_dict'], strict=False)
    
    return generator



In [3]:
def generate_spectrum(generator, latent_dim):
    # Generate a random latent vector
    latent_vector = torch.randn(1, latent_dim)
    
    # Use the generator to produce a new spectrum
    with torch.no_grad():
        generated_spectrum = generator(latent_vector).cpu().numpy().flatten()
    
    return generated_spectrum

In [4]:
def plot_spectrum(wavelength, generated_flux, original_flux=None):
    import matplotlib.pyplot as plt
    import numpy as np

    # Ensure wavelength and generated_flux are numpy arrays
    wavelength = np.array(wavelength).flatten()
    generated_flux = np.array(generated_flux).flatten()
    
    # Print the shapes for debugging
    print("wavelength shape:", wavelength.shape)
    print("generated_flux shape:", generated_flux.shape)
    if original_flux is not None:
        original_flux = np.array(original_flux).flatten()
        print("original_flux shape:", original_flux.shape)
    
    # Sanity check on the sizes
    if wavelength.shape[0] != generated_flux.shape[0]:
        print("Mismatch in wavelength and generated_flux array sizes.")
        return

    if original_flux is not None and original_flux.shape[0] != wavelength.shape[0]:
        print("Mismatch in wavelength and original_flux array sizes.")
        return

    plt.figure(figsize=(10, 6))
    plt.plot(wavelength, generated_flux, label='Generated Flux', color='blue')
    if original_flux is not None:
        plt.plot(wavelength, original_flux, label='Original Flux', color='red', alpha=0.5)
    plt.xlabel('Wavelength')
    plt.ylabel('Flux')
    plt.title('Flux vs. Wavelength')
    plt.legend()
    plt.show()


In [6]:
if __name__ == "__main__":
    config = get_config()
    
    hdf5_path = resolve_path(config['paths']['hdf5_data'])
    dataset = APOGEEDataset("../data/hdf5/spectra3.hdf5", max_files=1)  # Load a single file to get the output_dim
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
    
    sample = next(iter(dataloader))
    flux_example = sample['flux']
    wavelength_example = sample['wavelength']
    output_dim = flux_example.size(0)  # Determine the output_dim based on the sample
    
    best_model_path = resolve_path(config['paths']['checkpoints']) + '/checkpoint_best.pth.tar'
    generator = load_model(config, best_model_path, output_dim)
    
    generated_spectrum = generate_spectrum(generator, config['training']['latent_dim'])
    plot_spectrum(wavelength_example.numpy(), generated_spectrum, original_flux=flux_example.numpy())


wavelength shape: (8575,)
generated_flux shape: (1,)
original_flux shape: (8575,)
Mismatch in wavelength and generated_flux array sizes.
