In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import yaml
import os
from datetime import datetime
import numpy as np
from utils import get_config, resolve_path
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt


# Custom modules
from dataset import APOGEEDataset
from model2 import Generator
from tqdm import tqdm
from checkpoint import save_checkpoint, load_checkpoint


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:
def weighted_mse_loss(input, target, weight):
    assert input.shape == target.shape == weight.shape, f'Shapes of input {input.shape}, target {target.shape}, and weight {weight.shape} must match'
    loss = torch.mean(weight * (input - target) ** 2)
    return loss

In [4]:
# Load configuration
config = get_config()


# Config paths and training params
data_path = resolve_path(config['paths']['hdf5_data'])
checkpoints_path = resolve_path(config['paths']['checkpoints'])
latent_path = resolve_path(config['paths']['latent'])
plots_path = resolve_path(config['paths']['plots'])
tensorboard_path = resolve_path(config['paths']['tensorboard'])

batch_size = config['training']['batch_size']
num_workers = config['training']['num_workers']
num_epochs = config['training']['num_epochs']
learning_rate = config['training']['learning_rate']
latent_learning_rate = config['training']['latent_learning_rate']
latent_dim = config['training']['latent_dim']
checkpoint_interval = config['training']['checkpoint_interval']


In [5]:
dataset = APOGEEDataset(data_path, max_files=config['training']['max_files'])
train_indices, val_indices = train_test_split(list(range(len(dataset))), test_size=config['training']['split_ratios'][1])

train_dataset = torch.utils.data.Subset(dataset, train_indices)
val_dataset = torch.utils.data.Subset(dataset, val_indices)

train_loader = DataLoader(train_dataset, batch_size=config['training']['batch_size'], shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=config['training']['batch_size'], shuffle=False, num_workers=num_workers)



In [6]:
# Initialize generator and latent codes
generator = Generator(latent_dim, config['model']['output_dim'], config['model']['generator_layers'], config['model']['activation_function']).to('cuda')
latent_codes = torch.randn(len(train_loader.dataset), latent_dim, requires_grad=True, device='cuda')

# Optimizers
optimizer_g = optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_l = optim.LBFGS([latent_codes], lr=latent_learning_rate)


In [7]:
# Define the path to the checkpoint files
latest_checkpoint_path = os.path.join(checkpoints_path, 'checkpoint_latest.pth.tar')
best_checkpoint_path = os.path.join(checkpoints_path, 'checkpoint_best.pth.tar')

# Attempt to load the latest checkpoint
latest_checkpoint = load_checkpoint(latest_checkpoint_path)
if latest_checkpoint:
    try:
        generator.load_state_dict(latest_checkpoint['generator_state'])
        latent_codes.data = latest_checkpoint['latent_codes']
        optimizer_g.load_state_dict(latest_checkpoint['optimizer_g_state'])
        optimizer_l.load_state_dict(latest_checkpoint['optimizer_l_state'])
        start_epoch = latest_checkpoint['epoch'] + 1  
        print("Loaded latest checkpoint.")
    except KeyError as e:
        print(f"Error loading state dictionaries from latest checkpoint: {e}")
else:
    # Initialize everything for a fresh start if no latest checkpoint is found
    generator.apply(Generator.init_weights)
    latent_codes = torch.randn(len(train_loader.dataset), latent_dim, device=device, requires_grad=True)
    optimizer_g = torch.optim.Adam(generator.parameters(), lr=learning_rate)
    optimizer_l = torch.optim.LBFGS([latent_codes], lr=latent_learning_rate)
    start_epoch = 0

# Initialize the learning rate scheduler
scheduler_g = torch.optim.lr_scheduler.StepLR(optimizer_g, step_size=config['training']['scheduler_step_size'], gamma=config['training']['scheduler_gamma'])
scheduler_l = torch.optim.lr_scheduler.StepLR(optimizer_l, step_size=config['training']['scheduler_step_size'], gamma=config['training']['scheduler_gamma'])  # Assuming similar parameters

# Attempt to load the best checkpoint
best_checkpoint = load_checkpoint(best_checkpoint_path)
if best_checkpoint:
    try:
        best_val_loss = best_checkpoint['best_val_loss']
        print(f"Best validation loss from checkpoint: {best_val_loss}")
    except KeyError as e:
        print(f"Error retrieving best validation loss from checkpoint: {e}")
        best_val_loss = float('inf')
else:
    best_val_loss = float('inf')


No checkpoint found at '/arc/home/Amirabezine/deepSpectra/checkpoints/checkpoint_latest.pth.tar'
No checkpoint found at '/arc/home/Amirabezine/deepSpectra/checkpoints/checkpoint_best.pth.tar'


In [None]:
# Initialize loss history storage
loss_history = {
    'train': [],
    'val': []
}

# Training loop
for epoch in range(start_epoch, num_epochs):
    generator.train()
    epoch_losses = []
    train_bar = tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{num_epochs}")

    for batch in train_bar:
        indices = batch['index'].to(device)
        flux = batch['flux'].to(device)
        mask = batch['flux_mask'].to(device)
        ivar = batch['sigma'].to(device)  # Assuming sigma is the inverse variance

        # Step 1: Optimize generator weights
        optimizer_g.zero_grad()
        generated = generator(latent_codes[indices])
        loss_g = weighted_mse_loss(generated, flux, mask)
        loss_g.backward()
        optimizer_g.step()

        train_bar.set_postfix({"Batch Weight Loss": loss_g.item()})

        # Step 2: Freeze generator weights and optimize latent codes
        for param in generator.parameters():
            param.requires_grad = False

        def closure():
            optimizer_l.zero_grad()
            generated = generator(latent_codes[indices])
            
            # # Apply mask, ASK SEB ABOUT THIS
            # masked_generated = generated * mask
            
            loss_l = weighted_mse_loss(generated, flux, mask)
            loss_l.backward()
            epoch_losses.append(loss_l.item())
            train_bar.set_postfix({"Batch Latent Loss": loss_l.item()})
            return loss_l
        
        optimizer_l.step(closure)

        # Unfreeze generator weights
        for param in generator.parameters():
            param.requires_grad = True

    # Calculate and store average loss for the epoch
    average_train_loss = np.mean(epoch_losses)
    loss_history['train'].append(average_train_loss)
    print(f'Epoch {epoch+1} Average Train Loss: {average_train_loss:.4f}')

    # Validation phase
    generator.eval()
    val_losses = []
    val_bar = tqdm(val_loader, desc=f"Validation Epoch {epoch + 1}/{num_epochs}")
    with torch.no_grad():
        for batch in val_bar:
            indices = batch['index'].to(device)
            flux = batch['flux'].to(device)
            mask = batch['flux_mask'].to(device)
            
            generated = generator(latent_codes[indices])
            val_loss = weighted_mse_loss(generated, flux, mask)
            val_losses.append(val_loss.item())
            val_bar.set_postfix({"Batch Val Loss": val_loss.item()})

    average_val_loss = np.mean(val_losses)
    loss_history['val'].append(average_val_loss)
    print(f'Epoch {epoch+1} Average Validation Loss: {average_val_loss:.4f}')

    # Update learning rate scheduler
    scheduler_g.step()
    scheduler_l.step()

    # Checkpoint handling and latent codes saving
    checkpoint_state = {
        'epoch': epoch + 1,
        'state_dict': generator.state_dict(),
        'latent_codes': latent_codes,
        'optimizer_g_state': optimizer_g.state_dict(),
        'optimizer_l_state': optimizer_l.state_dict(),
        'train_loss': average_train_loss,
        'val_loss': average_val_loss,
        'best_loss': best_val_loss
    }
    save_checkpoint(checkpoint_state, filename=os.path.join(checkpoints_path, f'checkpoint_latest.pth.tar'))
    if average_val_loss < best_val_loss:
        best_val_loss = average_val_loss
        save_checkpoint(checkpoint_state, filename=best_checkpoint_path)
    save_checkpoint(checkpoint_state, filename=os.path.join(checkpoints_path, f'checkpoint_epoch_{epoch+1}.pth.tar'))

    # Save all latent codes with their indices after every epoch
    all_latent_data = {'latent_codes': latent_codes.detach().cpu().numpy(), 'indices': torch.arange(latent_codes.size(0)).numpy()}
    np.save(os.path.join(latent_path, f'latent_codes_epoch_{epoch+1}.npy'), all_latent_data)

# After training, save the loss history for later analysis or plotting
np.save(os.path.join(checkpoints_path, 'loss_history.npy'), loss_history)


  self.pid = os.fork()
  self.pid = os.fork()
Training Epoch 1/60: 100%|██████████| 41/41 [00:04<00:00,  8.36it/s, Batch Latent Loss=0.00564]


Epoch 1 Average Train Loss: 0.2759


Validation Epoch 1/60: 100%|██████████| 10/10 [00:01<00:00,  6.54it/s, Batch Val Loss=0.00924]


Epoch 1 Average Validation Loss: 0.0092


Training Epoch 2/60: 100%|██████████| 41/41 [00:04<00:00,  8.53it/s, Batch Latent Loss=0.00102]

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def plot_losses(loss_history_path):
    # Load the loss history
    loss_history = np.load(loss_history_path, allow_pickle=True).item()
    
    # Extract training and validation losses
    train_losses = loss_history['train']
    val_losses = loss_history['val']
    
    # Create the plot
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.title('Training and Validation Losses Over Epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.show()

plot_losses(os.path.join(checkpoints_path, 'loss_history.npy'))

In [None]:
def plot_latent_evolution(latent_dir, spectrum_index, num_epochs):
    latent_evolution = []
    
    # Load latent codes for each epoch and append the specific spectrum's latent code to the list
    for epoch in range(1, num_epochs + 1):
        latent_path = os.path.join(latent_dir, f'latent_codes_epoch_{epoch}.npy')
        latent_data = np.load(latent_path, allow_pickle=True).item()
        latent_codes = latent_data['latent_codes']
        latent_evolution.append(latent_codes[spectrum_index])

    latent_evolution = np.array(latent_evolution)

    # Plot each component of the latent code
    plt.figure(figsize=(12, 8))
    for i in range(latent_evolution.shape[1]):  # Assuming latent_codes have the same dimensionality across epochs
        plt.plot(latent_evolution[:, i], label=f'Latent Dimension {i + 1}')
    
    plt.title(f'Evolution of Latent Space for Spectrum Index {spectrum_index}')
    plt.xlabel('Epoch')
    plt.ylabel('Latent Value')
    plt.legend()
    plt.grid(True)
    plt.show()


plot_latent_evolution(latent_path, 0, 15)

In [None]:
def plot_real_vs_generated(checkpoints_path, latent_path, data_loader, generator, device, spectrum_index):
    """
    Plots the real versus generated spectrum for a given spectrum index.
    
    Args:
    - checkpoints_path (str): Path to the directory containing model checkpoints.
    - latent_path (str): Path to the directory containing saved latent codes.
    - data_loader (DataLoader): DataLoader containing the dataset with real spectra.
    - generator (torch.nn.Module): The generator model.
    - device (torch.device): The device on which PyTorch operations should be performed.
    - spectrum_index (int): The index of the spectrum to plot.
    """
    # Load the last latent code
    latent_files = sorted(os.listdir(latent_path))
    last_latent_file = os.path.join(latent_path, latent_files[-1])
    latents = np.load(last_latent_file, allow_pickle=True).item()
    latent_code = torch.tensor(latents['latent_codes'][spectrum_index]).to(device)

    # Generate the spectrum using the latent code
    generator.eval()
    with torch.no_grad():
        generated_spectrum = generator(latent_code.unsqueeze(0)).squeeze(0).cpu().numpy()

    # Load the real flux data
    for batch in data_loader:
        indices = batch['index']
        if spectrum_index in indices:
            real_spectrum = batch['flux'][indices == spectrum_index].squeeze(0).cpu().numpy()
            break

    # Plotting
    plt.figure(figsize=(10, 5))
    plt.plot(real_spectrum, label='Real Spectrum', color='blue')
    plt.plot(generated_spectrum, label='Generated Spectrum', color='red')
    plt.title(f'Comparison of Real and Generated Spectra for Spectrum Index {spectrum_index}')
    plt.xlabel('Wavelength Index')
    plt.ylabel('Flux')
    plt.legend()
    plt.grid(True)
    plt.show()

# Example of usage
# Assuming you have 'generator', 'device' setup, and a 'data_loader' ready
plot_real_vs_generated(checkpoints_path, latent_path, train_loader, generator, device, spectrum_index=0)
