In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import yaml
import os
from datetime import datetime
import numpy as np
from utils import get_config, resolve_path
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt


# Custom modules
from dataset import APOGEEDataset
from model2 import Generator
from tqdm import tqdm
from checkpoint import save_checkpoint, load_checkpoint


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:
def weighted_mse_loss(input, target, weight):
    assert input.shape == target.shape == weight.shape, f'Shapes of input {input.shape}, target {target.shape}, and weight {weight.shape} must match'
    loss = torch.mean(weight * (input - target) ** 2)
    return loss

In [4]:
# Load configuration
config = get_config()


# Config paths and training params
data_path = resolve_path(config['paths']['hdf5_data'])
checkpoints_path = resolve_path(config['paths']['checkpoints'])
latent_path = resolve_path(config['paths']['latent'])
plots_path = resolve_path(config['paths']['plots'])
tensorboard_path = resolve_path(config['paths']['tensorboard'])

batch_size = config['training']['batch_size']
num_workers = config['training']['num_workers']
num_epochs = config['training']['num_epochs']
learning_rate = config['training']['learning_rate']
latent_learning_rate = config['training']['latent_learning_rate']
latent_dim = config['training']['latent_dim']
checkpoint_interval = config['training']['checkpoint_interval']


In [5]:
dataset = APOGEEDataset(data_path, max_files=config['training']['max_files'])
train_indices, val_indices = train_test_split(list(range(len(dataset))), test_size=config['training']['split_ratios'][1])

train_dataset = torch.utils.data.Subset(dataset, train_indices)
val_dataset = torch.utils.data.Subset(dataset, val_indices)

train_loader = DataLoader(train_dataset, batch_size=config['training']['batch_size'], shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=config['training']['batch_size'], shuffle=False, num_workers=num_workers)



In [6]:
# Initialize generator and latent codes
generator = Generator(latent_dim, config['model']['output_dim'], config['model']['generator_layers'], config['model']['activation_function']).to('cuda')
latent_codes = torch.randn(len(train_loader.dataset), latent_dim, requires_grad=True, device='cuda')

# Optimizers
optimizer_g = optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_l = optim.LBFGS([latent_codes], lr=latent_learning_rate)


In [7]:
# Define the path to the checkpoint files
latest_checkpoint_path = os.path.join(checkpoints_path, 'checkpoint_latest.pth.tar')
best_checkpoint_path = os.path.join(checkpoints_path, 'checkpoint_best.pth.tar')

# Attempt to load the latest checkpoint
latest_checkpoint = load_checkpoint(latest_checkpoint_path)
if latest_checkpoint:
    try:
        generator.load_state_dict(latest_checkpoint['generator_state'])
        latent_codes.data = latest_checkpoint['latent_codes']
        optimizer_g.load_state_dict(latest_checkpoint['optimizer_g_state'])
        optimizer_l.load_state_dict(latest_checkpoint['optimizer_l_state'])
        start_epoch = latest_checkpoint['epoch'] + 1  
        print("Loaded latest checkpoint.")
    except KeyError as e:
        print(f"Error loading state dictionaries from latest checkpoint: {e}")
else:
    # Initialize everything for a fresh start if no latest checkpoint is found
    generator.apply(Generator.init_weights)
    latent_codes = torch.randn(len(train_loader.dataset), latent_dim, device=device, requires_grad=True)
    optimizer_g = torch.optim.Adam(generator.parameters(), lr=learning_rate)
    optimizer_l = torch.optim.LBFGS([latent_codes], lr=latent_learning_rate)
    start_epoch = 0

# Initialize the learning rate scheduler
scheduler_g = torch.optim.lr_scheduler.StepLR(optimizer_g, step_size=config['training']['scheduler_step_size'], gamma=config['training']['scheduler_gamma'])
scheduler_l = torch.optim.lr_scheduler.StepLR(optimizer_l, step_size=config['training']['scheduler_step_size'], gamma=config['training']['scheduler_gamma'])  # Assuming similar parameters

# Attempt to load the best checkpoint
best_checkpoint = load_checkpoint(best_checkpoint_path)
if best_checkpoint:
    try:
        best_val_loss = best_checkpoint['best_val_loss']
        print(f"Best validation loss from checkpoint: {best_val_loss}")
    except KeyError as e:
        print(f"Error retrieving best validation loss from checkpoint: {e}")
        best_val_loss = float('inf')
else:
    best_val_loss = float('inf')


No checkpoint found at '/arc/home/Amirabezine/deepSpectra/checkpoints/checkpoint_latest.pth.tar'
No checkpoint found at '/arc/home/Amirabezine/deepSpectra/checkpoints/checkpoint_best.pth.tar'


In [None]:
# Initialize loss history storage
loss_history = {
    'train': [],
    'val': []
}

# Training loop
for epoch in range(start_epoch, num_epochs):
    generator.train()
    epoch_losses = []
    train_bar = tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{num_epochs}")

    for batch in train_bar:
        indices = batch['index'].to(device)
        flux = batch['flux'].to(device)
        mask = batch['flux_mask'].to(device)
        ivar = batch['sigma'].to(device)  # Assuming sigma is the inverse variance

        # Step 1: Optimize generator weights
        optimizer_g.zero_grad()
        generated = generator(latent_codes[indices])
        loss_g = weighted_mse_loss(generated, flux, mask)
        loss_g.backward()
        optimizer_g.step()

        train_bar.set_postfix({"Batch Weight Loss": loss_g.item()})

        # Step 2: Freeze generator weights and optimize latent codes
        for param in generator.parameters():
            param.requires_grad = False

        def closure():
            optimizer_l.zero_grad()
            generated = generator(latent_codes[indices])
            
            # # Apply mask, ASK SEB ABOUT THIS
            # masked_generated = generated * mask
            
            loss_l = weighted_mse_loss(generated, flux, mask)
            loss_l.backward()
            epoch_losses.append(loss_l.item())
            train_bar.set_postfix({"Batch Latent Loss": loss_l.item()})
            return loss_l
        
        optimizer_l.step(closure)

        # Unfreeze generator weights
        for param in generator.parameters():
            param.requires_grad = True

    # Calculate and store average loss for the epoch
    average_train_loss = np.mean(epoch_losses)
    loss_history['train'].append(average_train_loss)
    print(f'Epoch {epoch+1} Average Train Loss: {average_train_loss:.4f}')

    # Validation phase
    generator.eval()
    val_losses = []
    val_bar = tqdm(val_loader, desc=f"Validation Epoch {epoch + 1}/{num_epochs}")
    with torch.no_grad():
        for batch in val_bar:
            indices = batch['index'].to(device)
            flux = batch['flux'].to(device)
            mask = batch['flux_mask'].to(device)
            
            generated = generator(latent_codes[indices])
            val_loss = weighted_mse_loss(generated, flux, mask)
            val_losses.append(val_loss.item())
            val_bar.set_postfix({"Batch Val Loss": val_loss.item()})

    average_val_loss = np.mean(val_losses)
    loss_history['val'].append(average_val_loss)
    print(f'Epoch {epoch+1} Average Validation Loss: {average_val_loss:.4f}')

    # Update learning rate scheduler
    scheduler_g.step()
    scheduler_l.step()

    # Checkpoint handling and latent codes saving
    checkpoint_state = {
        'epoch': epoch + 1,
        'state_dict': generator.state_dict(),
        'latent_codes': latent_codes,
        'optimizer_g_state': optimizer_g.state_dict(),
        'optimizer_l_state': optimizer_l.state_dict(),
        'train_loss': average_train_loss,
        'val_loss': average_val_loss,
        'best_loss': best_val_loss
    }
    save_checkpoint(checkpoint_state, filename=os.path.join(checkpoints_path, f'checkpoint_latest.pth.tar'))
    if average_val_loss < best_val_loss:
        best_val_loss = average_val_loss
        save_checkpoint(checkpoint_state, filename=best_checkpoint_path)
    save_checkpoint(checkpoint_state, filename=os.path.join(checkpoints_path, f'checkpoint_epoch_{epoch+1}.pth.tar'))

    # Save all latent codes with their indices after every epoch
    all_latent_data = {'latent_codes': latent_codes.detach().cpu().numpy(), 'indices': torch.arange(latent_codes.size(0)).numpy()}
    np.save(os.path.join(latent_path, f'latent_codes_epoch_{epoch+1}.npy'), all_latent_data)

# After training, save the loss history for later analysis or plotting
np.save(os.path.join(checkpoints_path, 'loss_history.npy'), loss_history)


  self.pid = os.fork()
  self.pid = os.fork()
Training Epoch 1/60: 100%|██████████| 41/41 [00:15<00:00,  2.73it/s, Batch Latent Loss=0.00342]


Epoch 1 Average Train Loss: 0.0609


Validation Epoch 1/60: 100%|██████████| 10/10 [00:01<00:00,  8.17it/s, Batch Val Loss=0.00529]


Epoch 1 Average Validation Loss: 0.0046


Training Epoch 2/60: 100%|██████████| 41/41 [00:10<00:00,  3.98it/s, Batch Latent Loss=0.00386] 


Epoch 2 Average Train Loss: 0.0037


Validation Epoch 2/60: 100%|██████████| 10/10 [00:01<00:00,  8.31it/s, Batch Val Loss=0.00192]


Epoch 2 Average Validation Loss: 0.0025


Training Epoch 3/60: 100%|██████████| 41/41 [00:03<00:00, 11.44it/s, Batch Latent Loss=0.00163] 


Epoch 3 Average Train Loss: 0.0035


Validation Epoch 3/60: 100%|██████████| 10/10 [00:01<00:00,  6.19it/s, Batch Val Loss=0.0019] 


Epoch 3 Average Validation Loss: 0.0026


Training Epoch 4/60: 100%|██████████| 41/41 [00:03<00:00, 11.36it/s, Batch Latent Loss=0.00296] 


Epoch 4 Average Train Loss: 0.0035


Validation Epoch 4/60: 100%|██████████| 10/10 [00:01<00:00,  8.33it/s, Batch Val Loss=0.00237]


Epoch 4 Average Validation Loss: 0.0031


Training Epoch 5/60: 100%|██████████| 41/41 [00:04<00:00,  9.73it/s, Batch Latent Loss=0.00289] 


Epoch 5 Average Train Loss: 0.0039


Validation Epoch 5/60: 100%|██████████| 10/10 [00:01<00:00,  8.68it/s, Batch Val Loss=0.00241]


Epoch 5 Average Validation Loss: 0.0028


Training Epoch 6/60: 100%|██████████| 41/41 [00:03<00:00, 12.10it/s, Batch Latent Loss=0.00182] 


Epoch 6 Average Train Loss: 0.0037


Validation Epoch 6/60: 100%|██████████| 10/10 [00:02<00:00,  4.44it/s, Batch Val Loss=0.00188]


Epoch 6 Average Validation Loss: 0.0025


Training Epoch 7/60: 100%|██████████| 41/41 [00:04<00:00,  9.52it/s, Batch Latent Loss=0.00272]


Epoch 7 Average Train Loss: 0.0045


Validation Epoch 7/60: 100%|██████████| 10/10 [00:01<00:00,  7.95it/s, Batch Val Loss=0.00188]


Epoch 7 Average Validation Loss: 0.0070


Training Epoch 8/60: 100%|██████████| 41/41 [00:03<00:00, 11.84it/s, Batch Latent Loss=0.00678]


Epoch 8 Average Train Loss: 0.0059


Validation Epoch 8/60: 100%|██████████| 10/10 [00:01<00:00,  7.69it/s, Batch Val Loss=0.00799]


Epoch 8 Average Validation Loss: 0.0082


Training Epoch 9/60: 100%|██████████| 41/41 [00:03<00:00, 12.13it/s, Batch Latent Loss=0.00605] 


Epoch 9 Average Train Loss: 0.0062


Validation Epoch 9/60: 100%|██████████| 10/10 [00:01<00:00,  7.36it/s, Batch Val Loss=0.0104]


Epoch 9 Average Validation Loss: 0.0077


Training Epoch 10/60: 100%|██████████| 41/41 [00:03<00:00, 12.81it/s, Batch Latent Loss=0.00228]


Epoch 10 Average Train Loss: 0.0045


Validation Epoch 10/60: 100%|██████████| 10/10 [00:01<00:00,  8.04it/s, Batch Val Loss=0.00296]


Epoch 10 Average Validation Loss: 0.0031


Training Epoch 11/60: 100%|██████████| 41/41 [00:03<00:00, 10.66it/s, Batch Latent Loss=0.00141] 


Epoch 11 Average Train Loss: 0.0032


Validation Epoch 11/60: 100%|██████████| 10/10 [00:01<00:00,  8.33it/s, Batch Val Loss=0.0017]


Epoch 11 Average Validation Loss: 0.0025


Training Epoch 12/60: 100%|██████████| 41/41 [00:03<00:00, 12.48it/s, Batch Latent Loss=0.0014]  


Epoch 12 Average Train Loss: 0.0035


Validation Epoch 12/60: 100%|██████████| 10/10 [00:01<00:00,  8.21it/s, Batch Val Loss=0.00157]


Epoch 12 Average Validation Loss: 0.0023


Training Epoch 13/60: 100%|██████████| 41/41 [00:03<00:00, 11.62it/s, Batch Latent Loss=0.00103] 


Epoch 13 Average Train Loss: 0.0033


Validation Epoch 13/60: 100%|██████████| 10/10 [00:01<00:00,  8.13it/s, Batch Val Loss=0.00168]


Epoch 13 Average Validation Loss: 0.0024


Training Epoch 14/60: 100%|██████████| 41/41 [00:03<00:00, 10.32it/s, Batch Latent Loss=0.00145] 


Epoch 14 Average Train Loss: 0.0038


Validation Epoch 14/60: 100%|██████████| 10/10 [00:01<00:00,  8.32it/s, Batch Val Loss=0.00396]


Epoch 14 Average Validation Loss: 0.0036


Training Epoch 15/60: 100%|██████████| 41/41 [00:03<00:00, 11.71it/s, Batch Latent Loss=0.000698]


Epoch 15 Average Train Loss: 0.0033


Validation Epoch 15/60: 100%|██████████| 10/10 [00:01<00:00,  7.77it/s, Batch Val Loss=0.00158]


Epoch 15 Average Validation Loss: 0.0023


Training Epoch 16/60: 100%|██████████| 41/41 [00:03<00:00, 12.95it/s, Batch Latent Loss=0.00884] 


Epoch 16 Average Train Loss: 0.0041


Validation Epoch 16/60: 100%|██████████| 10/10 [00:01<00:00,  8.21it/s, Batch Val Loss=0.00176]


Epoch 16 Average Validation Loss: 0.0026


Training Epoch 17/60: 100%|██████████| 41/41 [00:03<00:00, 12.97it/s, Batch Latent Loss=0.00263] 


Epoch 17 Average Train Loss: 0.0034


Validation Epoch 17/60: 100%|██████████| 10/10 [00:01<00:00,  8.25it/s, Batch Val Loss=0.00153]


Epoch 17 Average Validation Loss: 0.0023


Training Epoch 18/60: 100%|██████████| 41/41 [00:03<00:00, 12.32it/s, Batch Latent Loss=0.00106]


Epoch 18 Average Train Loss: 0.0032


Validation Epoch 18/60: 100%|██████████| 10/10 [00:01<00:00,  8.57it/s, Batch Val Loss=0.00175]


Epoch 18 Average Validation Loss: 0.0023


Training Epoch 19/60: 100%|██████████| 41/41 [00:03<00:00, 11.54it/s, Batch Latent Loss=0.00225] 


Epoch 19 Average Train Loss: 0.0032


Validation Epoch 19/60: 100%|██████████| 10/10 [00:01<00:00,  8.33it/s, Batch Val Loss=0.00155]


Epoch 19 Average Validation Loss: 0.0023


Training Epoch 20/60: 100%|██████████| 41/41 [00:03<00:00, 12.86it/s, Batch Latent Loss=0.00279] 


Epoch 20 Average Train Loss: 0.0033


Validation Epoch 20/60: 100%|██████████| 10/10 [00:01<00:00,  8.59it/s, Batch Val Loss=0.00166]


Epoch 20 Average Validation Loss: 0.0023


Training Epoch 21/60: 100%|██████████| 41/41 [00:03<00:00, 11.69it/s, Batch Latent Loss=0.000628]


Epoch 21 Average Train Loss: 0.0033


Validation Epoch 21/60: 100%|██████████| 10/10 [00:01<00:00,  7.25it/s, Batch Val Loss=0.00181]


Epoch 21 Average Validation Loss: 0.0026


Training Epoch 22/60: 100%|██████████| 41/41 [00:03<00:00, 12.38it/s, Batch Latent Loss=0.00285] 


Epoch 22 Average Train Loss: 0.0037


Validation Epoch 22/60: 100%|██████████| 10/10 [00:01<00:00,  8.83it/s, Batch Val Loss=0.00261]


Epoch 22 Average Validation Loss: 0.0036


Training Epoch 23/60: 100%|██████████| 41/41 [00:03<00:00, 12.34it/s, Batch Latent Loss=0.00283]


Epoch 23 Average Train Loss: 0.0042


Validation Epoch 23/60: 100%|██████████| 10/10 [00:01<00:00,  7.86it/s, Batch Val Loss=0.00269]


Epoch 23 Average Validation Loss: 0.0035


Training Epoch 24/60: 100%|██████████| 41/41 [00:04<00:00,  9.99it/s, Batch Latent Loss=0.0101]  


Epoch 24 Average Train Loss: 0.0040


Validation Epoch 24/60: 100%|██████████| 10/10 [00:01<00:00,  6.78it/s, Batch Val Loss=0.00168]


Epoch 24 Average Validation Loss: 0.0024


Training Epoch 25/60: 100%|██████████| 41/41 [00:03<00:00, 10.88it/s, Batch Latent Loss=0.0108]  


Epoch 25 Average Train Loss: 0.0040


Validation Epoch 25/60: 100%|██████████| 10/10 [00:01<00:00,  7.55it/s, Batch Val Loss=0.00193]


Epoch 25 Average Validation Loss: 0.0028


Training Epoch 26/60: 100%|██████████| 41/41 [00:04<00:00, 10.24it/s, Batch Latent Loss=0.0014]  


Epoch 26 Average Train Loss: 0.0036


Validation Epoch 26/60: 100%|██████████| 10/10 [00:01<00:00,  8.30it/s, Batch Val Loss=0.00366]


Epoch 26 Average Validation Loss: 0.0032


Training Epoch 27/60: 100%|██████████| 41/41 [00:04<00:00, 10.08it/s, Batch Latent Loss=0.00312]


Epoch 27 Average Train Loss: 0.0040


Validation Epoch 27/60: 100%|██████████| 10/10 [00:01<00:00,  6.00it/s, Batch Val Loss=0.00411]


Epoch 27 Average Validation Loss: 0.0046


Training Epoch 28/60: 100%|██████████| 41/41 [00:03<00:00, 10.96it/s, Batch Latent Loss=0.00102]


Epoch 28 Average Train Loss: 0.0046


Validation Epoch 28/60: 100%|██████████| 10/10 [00:01<00:00,  8.60it/s, Batch Val Loss=0.00244]


Epoch 28 Average Validation Loss: 0.0028


Training Epoch 29/60: 100%|██████████| 41/41 [00:03<00:00, 12.54it/s, Batch Latent Loss=0.00117] 


Epoch 29 Average Train Loss: 0.0034


Validation Epoch 29/60: 100%|██████████| 10/10 [00:01<00:00,  8.28it/s, Batch Val Loss=0.00174]


Epoch 29 Average Validation Loss: 0.0027


Training Epoch 30/60: 100%|██████████| 41/41 [00:03<00:00, 12.17it/s, Batch Latent Loss=0.00124] 


Epoch 30 Average Train Loss: 0.0037


Validation Epoch 30/60: 100%|██████████| 10/10 [00:01<00:00,  8.15it/s, Batch Val Loss=0.0025]


Epoch 30 Average Validation Loss: 0.0030


Training Epoch 31/60: 100%|██████████| 41/41 [00:03<00:00, 11.39it/s, Batch Latent Loss=0.00214]


Epoch 31 Average Train Loss: 0.0041


Validation Epoch 31/60: 100%|██████████| 10/10 [00:01<00:00,  7.06it/s, Batch Val Loss=0.00221]


Epoch 31 Average Validation Loss: 0.0032


Training Epoch 32/60: 100%|██████████| 41/41 [00:03<00:00, 10.25it/s, Batch Latent Loss=0.00402] 


Epoch 32 Average Train Loss: 0.0040


Validation Epoch 32/60: 100%|██████████| 10/10 [00:01<00:00,  8.73it/s, Batch Val Loss=0.00527]


Epoch 32 Average Validation Loss: 0.0113


Training Epoch 33/60: 100%|██████████| 41/41 [00:03<00:00, 12.35it/s, Batch Latent Loss=0.00439]


Epoch 33 Average Train Loss: 0.0070


Validation Epoch 33/60: 100%|██████████| 10/10 [00:01<00:00,  9.12it/s, Batch Val Loss=0.00293]


Epoch 33 Average Validation Loss: 0.0035


Training Epoch 34/60: 100%|██████████| 41/41 [00:03<00:00, 13.47it/s, Batch Latent Loss=0.00301]


Epoch 34 Average Train Loss: 0.0066


Validation Epoch 34/60: 100%|██████████| 10/10 [00:01<00:00,  8.84it/s, Batch Val Loss=0.00426]


Epoch 34 Average Validation Loss: 0.0045


Training Epoch 35/60: 100%|██████████| 41/41 [00:03<00:00, 12.65it/s, Batch Latent Loss=0.0013] 


Epoch 35 Average Train Loss: 0.0049


Validation Epoch 35/60: 100%|██████████| 10/10 [00:01<00:00,  8.49it/s, Batch Val Loss=0.00207]


Epoch 35 Average Validation Loss: 0.0026


Training Epoch 36/60: 100%|██████████| 41/41 [00:03<00:00, 10.90it/s, Batch Latent Loss=0.00103] 


Epoch 36 Average Train Loss: 0.0033


Validation Epoch 36/60: 100%|██████████| 10/10 [00:01<00:00,  8.63it/s, Batch Val Loss=0.00159]


Epoch 36 Average Validation Loss: 0.0023


Training Epoch 37/60: 100%|██████████| 41/41 [00:03<00:00, 13.06it/s, Batch Latent Loss=0.000794]


Epoch 37 Average Train Loss: 0.0032


Validation Epoch 37/60: 100%|██████████| 10/10 [00:01<00:00,  8.54it/s, Batch Val Loss=0.00146]


Epoch 37 Average Validation Loss: 0.0022


Training Epoch 38/60: 100%|██████████| 41/41 [00:03<00:00, 11.77it/s, Batch Latent Loss=0.000964]


Epoch 38 Average Train Loss: 0.0032


Validation Epoch 38/60: 100%|██████████| 10/10 [00:01<00:00,  6.34it/s, Batch Val Loss=0.00155]


Epoch 38 Average Validation Loss: 0.0023


Training Epoch 39/60: 100%|██████████| 41/41 [00:03<00:00, 13.25it/s, Batch Latent Loss=0.00173] 


Epoch 39 Average Train Loss: 0.0033


Validation Epoch 39/60: 100%|██████████| 10/10 [00:01<00:00,  8.05it/s, Batch Val Loss=0.00155]


Epoch 39 Average Validation Loss: 0.0023


Training Epoch 40/60: 100%|██████████| 41/41 [00:03<00:00, 11.37it/s, Batch Latent Loss=0.00197] 


Epoch 40 Average Train Loss: 0.0033


Validation Epoch 40/60: 100%|██████████| 10/10 [00:01<00:00,  8.46it/s, Batch Val Loss=0.00185]


Epoch 40 Average Validation Loss: 0.0024


Training Epoch 41/60: 100%|██████████| 41/41 [00:03<00:00, 10.97it/s, Batch Latent Loss=0.0015]  


Epoch 41 Average Train Loss: 0.0033


Validation Epoch 41/60: 100%|██████████| 10/10 [00:01<00:00,  8.70it/s, Batch Val Loss=0.0016] 


Epoch 41 Average Validation Loss: 0.0023


Training Epoch 42/60: 100%|██████████| 41/41 [00:03<00:00, 11.85it/s, Batch Latent Loss=0.00194] 


Epoch 42 Average Train Loss: 0.0039


Validation Epoch 42/60: 100%|██████████| 10/10 [00:01<00:00,  8.46it/s, Batch Val Loss=0.00186]


Epoch 42 Average Validation Loss: 0.0024


Training Epoch 43/60: 100%|██████████| 41/41 [00:06<00:00,  6.25it/s, Batch Latent Loss=0.0043]  


Epoch 43 Average Train Loss: 0.0032


Validation Epoch 43/60: 100%|██████████| 10/10 [00:01<00:00,  5.08it/s, Batch Val Loss=0.00164]


Epoch 43 Average Validation Loss: 0.0024


Training Epoch 44/60: 100%|██████████| 41/41 [00:03<00:00, 11.98it/s, Batch Latent Loss=0.000488]


Epoch 44 Average Train Loss: 0.0032


Validation Epoch 44/60: 100%|██████████| 10/10 [00:02<00:00,  4.86it/s, Batch Val Loss=0.00157]


Epoch 44 Average Validation Loss: 0.0022


Training Epoch 45/60: 100%|██████████| 41/41 [00:03<00:00, 11.48it/s, Batch Latent Loss=0.00139] 


Epoch 45 Average Train Loss: 0.0032


Validation Epoch 45/60: 100%|██████████| 10/10 [00:01<00:00,  7.19it/s, Batch Val Loss=0.00168]


Epoch 45 Average Validation Loss: 0.0023


Training Epoch 46/60: 100%|██████████| 41/41 [00:03<00:00, 12.66it/s, Batch Latent Loss=0.00171] 


Epoch 46 Average Train Loss: 0.0032


Validation Epoch 46/60: 100%|██████████| 10/10 [00:01<00:00,  8.56it/s, Batch Val Loss=0.00156]


Epoch 46 Average Validation Loss: 0.0023


Training Epoch 47/60: 100%|██████████| 41/41 [00:04<00:00,  9.03it/s, Batch Latent Loss=0.00194] 


Epoch 47 Average Train Loss: 0.0033


Validation Epoch 47/60: 100%|██████████| 10/10 [00:01<00:00,  8.80it/s, Batch Val Loss=0.00198]


Epoch 47 Average Validation Loss: 0.0025


Training Epoch 48/60: 100%|██████████| 41/41 [00:03<00:00, 12.51it/s, Batch Latent Loss=0.000906]


Epoch 48 Average Train Loss: 0.0032


Validation Epoch 48/60: 100%|██████████| 10/10 [00:01<00:00,  8.62it/s, Batch Val Loss=0.00177]


Epoch 48 Average Validation Loss: 0.0023


Training Epoch 49/60: 100%|██████████| 41/41 [00:03<00:00, 12.11it/s, Batch Latent Loss=0.00237] 


Epoch 49 Average Train Loss: 0.0032


Validation Epoch 49/60: 100%|██████████| 10/10 [00:02<00:00,  4.22it/s, Batch Val Loss=0.00155]


Epoch 49 Average Validation Loss: 0.0023


Training Epoch 50/60: 100%|██████████| 41/41 [00:04<00:00,  9.61it/s, Batch Latent Loss=0.00414] 


Epoch 50 Average Train Loss: 0.0033


Validation Epoch 50/60: 100%|██████████| 10/10 [00:01<00:00,  6.30it/s, Batch Val Loss=0.00145]


Epoch 50 Average Validation Loss: 0.0023


Training Epoch 51/60: 100%|██████████| 41/41 [00:03<00:00, 10.77it/s, Batch Latent Loss=0.000747]


Epoch 51 Average Train Loss: 0.0032


Validation Epoch 51/60: 100%|██████████| 10/10 [00:01<00:00,  6.66it/s, Batch Val Loss=0.00161]


Epoch 51 Average Validation Loss: 0.0023


Training Epoch 52/60: 100%|██████████| 41/41 [00:03<00:00, 12.34it/s, Batch Latent Loss=0.00248] 


Epoch 52 Average Train Loss: 0.0032


Validation Epoch 52/60: 100%|██████████| 10/10 [00:01<00:00,  9.31it/s, Batch Val Loss=0.00168]


Epoch 52 Average Validation Loss: 0.0023


Training Epoch 53/60: 100%|██████████| 41/41 [00:03<00:00, 13.57it/s, Batch Latent Loss=0.000927]


Epoch 53 Average Train Loss: 0.0032


Validation Epoch 53/60: 100%|██████████| 10/10 [00:01<00:00,  8.94it/s, Batch Val Loss=0.0015] 


Epoch 53 Average Validation Loss: 0.0022


Training Epoch 54/60: 100%|██████████| 41/41 [00:03<00:00, 12.93it/s, Batch Latent Loss=0.00121] 


Epoch 54 Average Train Loss: 0.0032


Validation Epoch 54/60: 100%|██████████| 10/10 [00:01<00:00,  9.08it/s, Batch Val Loss=0.00164]


Epoch 54 Average Validation Loss: 0.0023


Training Epoch 55/60: 100%|██████████| 41/41 [00:03<00:00, 12.86it/s, Batch Latent Loss=0.0127]  


Epoch 55 Average Train Loss: 0.0037


Validation Epoch 55/60: 100%|██████████| 10/10 [00:01<00:00,  9.17it/s, Batch Val Loss=0.00299]


Epoch 55 Average Validation Loss: 0.0035


Training Epoch 56/60: 100%|██████████| 41/41 [00:03<00:00, 12.32it/s, Batch Latent Loss=0.00175] 


Epoch 56 Average Train Loss: 0.0034


Validation Epoch 56/60: 100%|██████████| 10/10 [00:01<00:00,  7.58it/s, Batch Val Loss=0.00171]


Epoch 56 Average Validation Loss: 0.0023


Training Epoch 57/60:  88%|████████▊ | 36/41 [00:03<00:00, 12.10it/s, Batch Latent Loss=0.00194] 

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def plot_losses(loss_history_path):
    # Load the loss history
    loss_history = np.load(loss_history_path, allow_pickle=True).item()
    
    # Extract training and validation losses
    train_losses = loss_history['train']
    val_losses = loss_history['val']
    
    # Create the plot
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.title('Training and Validation Losses Over Epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.show()

plot_losses(os.path.join(checkpoints_path, 'loss_history.npy'))

In [None]:
def plot_latent_evolution(latent_dir, spectrum_index, num_epochs):
    latent_evolution = []
    
    # Load latent codes for each epoch and append the specific spectrum's latent code to the list
    for epoch in range(1, num_epochs + 1):
        latent_path = os.path.join(latent_dir, f'latent_codes_epoch_{epoch}.npy')
        latent_data = np.load(latent_path, allow_pickle=True).item()
        latent_codes = latent_data['latent_codes']
        latent_evolution.append(latent_codes[spectrum_index])

    latent_evolution = np.array(latent_evolution)

    # Plot each component of the latent code
    plt.figure(figsize=(12, 8))
    for i in range(latent_evolution.shape[1]):  # Assuming latent_codes have the same dimensionality across epochs
        plt.plot(latent_evolution[:, i], label=f'Latent Dimension {i + 1}')
    
    plt.title(f'Evolution of Latent Space for Spectrum Index {spectrum_index}')
    plt.xlabel('Epoch')
    plt.ylabel('Latent Value')
    plt.legend()
    plt.grid(True)
    plt.show()


plot_latent_evolution(latent_path, 0, 15)

In [None]:
def plot_real_vs_generated(checkpoints_path, latent_path, data_loader, generator, device, spectrum_index):
    """
    Plots the real versus generated spectrum for a given spectrum index.
    
    Args:
    - checkpoints_path (str): Path to the directory containing model checkpoints.
    - latent_path (str): Path to the directory containing saved latent codes.
    - data_loader (DataLoader): DataLoader containing the dataset with real spectra.
    - generator (torch.nn.Module): The generator model.
    - device (torch.device): The device on which PyTorch operations should be performed.
    - spectrum_index (int): The index of the spectrum to plot.
    """
    # Load the last latent code
    latent_files = sorted(os.listdir(latent_path))
    last_latent_file = os.path.join(latent_path, latent_files[-1])
    latents = np.load(last_latent_file, allow_pickle=True).item()
    latent_code = torch.tensor(latents['latent_codes'][spectrum_index]).to(device)

    # Generate the spectrum using the latent code
    generator.eval()
    with torch.no_grad():
        generated_spectrum = generator(latent_code.unsqueeze(0)).squeeze(0).cpu().numpy()

    # Load the real flux data
    for batch in data_loader:
        indices = batch['index']
        if spectrum_index in indices:
            real_spectrum = batch['flux'][indices == spectrum_index].squeeze(0).cpu().numpy()
            break

    # Plotting
    plt.figure(figsize=(10, 5))
    plt.plot(real_spectrum, label='Real Spectrum', color='blue')
    plt.plot(generated_spectrum, label='Generated Spectrum', color='red')
    plt.title(f'Comparison of Real and Generated Spectra for Spectrum Index {spectrum_index}')
    plt.xlabel('Wavelength Index')
    plt.ylabel('Flux')
    plt.legend()
    plt.grid(True)
    plt.show()

# Example of usage
# Assuming you have 'generator', 'device' setup, and a 'data_loader' ready
plot_real_vs_generated(checkpoints_path, latent_path, train_loader, generator, device, spectrum_index=0)
