# In [None]:
import os

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch

from checkpoint import load_checkpoint
from model import Generator
from utils import get_config, resolve_path

# In [None]:

def _build_generator(config, device):
    """Construct the Generator described by ``config`` and move it to ``device``."""
    latent_dim = config['training']['latent_dim']
    output_dim = config['model']['output_dim']
    generator_layers = config['model']['generator_layers']
    # The config stores the activation by its attribute name in torch.nn.
    activation_function = getattr(torch.nn, config['model']['activation_function'])
    return Generator(latent_dim, output_dim, generator_layers, activation_function).to(device)


def _load_best_weights(generator, config):
    """Load the best checkpoint's ``state_dict`` into ``generator``.

    Returns True on success and False when ``load_checkpoint`` returns a
    falsy value (the original code's signal for a failed load).
    """
    checkpoint_dir = resolve_path(config['paths']['checkpoints'])
    # os.path.join rather than string '+': it also accepts a pathlib.Path,
    # should resolve_path return one, and handles trailing separators.
    best_checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint_best.pth.tar')
    best_checkpoint = load_checkpoint(best_checkpoint_path)
    if not best_checkpoint:
        return False
    generator.load_state_dict(best_checkpoint['state_dict'])
    return True


def _sample_spectrum(generator, latent_dim, device):
    """Decode one standard-normal latent vector into a 1-D NumPy flux array."""
    latent_code = torch.randn(1, latent_dim, device=device)
    with torch.no_grad():
        # reshape(-1) instead of squeeze(): always 1-D, even if output_dim == 1
        # (squeeze would yield a 0-d array that cannot be plotted against an axis).
        return generator(latent_code).reshape(-1).cpu().numpy()


def _load_reference_spectrum(hdf5_file):
    """Return ``(flux, wavelength)`` from the first top-level entry of ``hdf5_file``.

    Returns None when the file has no top-level entries.
    """
    with h5py.File(hdf5_file, 'r') as file:
        first_key = next(iter(file), None)
        if first_key is None:
            return None
        # NOTE(review): assumes each top-level entry is a group holding
        # 'flux' and 'wavelength' datasets - confirm against the data-prep code.
        data = file[first_key]
        return data['flux'][:], data['wavelength'][:]


def _plot_spectra(wavelength, generated_flux, real_flux):
    """Overlay generated and real flux on a shared wavelength axis and show the figure."""
    plt.figure(figsize=(10, 6))
    plt.plot(wavelength, generated_flux, label='Generated Spectrum', color='blue')
    plt.plot(wavelength, real_flux, label='Real Spectrum', color='red', alpha=0.5, linewidth=0.5)
    plt.xlabel('Wavelength (Angstroms)')
    plt.ylabel('Flux')
    plt.title('Comparison of Generated and Real Stellar Spectra')
    plt.legend()
    plt.show()


def generate_and_plot_spectrum():
    """Sample one spectrum from the best saved generator and plot it against a real one.

    All settings come from ``get_config()``. Prints a message and returns None
    early if the best checkpoint cannot be loaded, if the HDF5 file has no
    entries, or if the generated spectrum's length differs from the real
    spectrum's wavelength grid.
    """
    config = get_config()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    latent_dim = config['training']['latent_dim']
    generator = _build_generator(config, device)
    if not _load_best_weights(generator, config):
        print("Error loading the best model.")
        return
    print("Best model loaded successfully.")

    generator.eval()
    generated_flux = _sample_spectrum(generator, latent_dim, device)

    reference = _load_reference_spectrum(resolve_path(config['paths']['hdf5_data']))
    if reference is None:
        print("Error: the HDF5 file contains no spectra.")
        return
    real_flux, wavelength = reference

    # The generated spectrum is drawn on the real spectrum's wavelength grid,
    # which is only meaningful when both have the same number of samples.
    if len(generated_flux) != len(wavelength):
        print(
            f"Error: generated spectrum has {len(generated_flux)} points but the "
            f"reference wavelength grid has {len(wavelength)}; cannot plot on a shared axis."
        )
        return

    _plot_spectra(wavelength, generated_flux, real_flux)

if __name__ == "__main__":
    # Run the generated-vs-real spectrum comparison only when executed as a
    # script, not when this module is imported.
    generate_and_plot_spectrum()