In [1]:
import os
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from multiprocessing import set_start_method, Manager, Pool
from module import APOGEEDataset
from tqdm import tqdm

# To avoid issues with non-fork-safe libraries
try:
    set_start_method('spawn')
except RuntimeError:
    pass

In [2]:
directory = '../../../projects/k-pop/spectra/apogee/dr17'
dataset = APOGEEDataset(directory, max_files=500)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=8)


# Model Definition

In [3]:
class Generator(nn.Module):
    def __init__(self, latent_dim, output_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128), nn.LeakyReLU(0.01), nn.BatchNorm1d(128),
            nn.Linear(128, 256), nn.LeakyReLU(0.01), nn.BatchNorm1d(256),
            nn.Linear(256, 512), nn.LeakyReLU(0.01), nn.BatchNorm1d(512),
            nn.Linear(512, 512), nn.LeakyReLU(0.01), nn.BatchNorm1d(512),
            nn.Linear(512, 256), nn.LeakyReLU(0.01), nn.BatchNorm1d(256),
            nn.Linear(256, 128), nn.LeakyReLU(0.01), nn.BatchNorm1d(128),
            nn.Linear(128, output_dim)
        )

    def forward(self, z):
        return self.model(z)


# Custom Loss Function

In [4]:
def weighted_min_loss(generator, z, batch, loss_func):
    real_data, sigma, mask = batch['flux'], batch['sigma'], batch['flux_mask']
    sigma_safe = sigma + 1e-8
    weights = mask / sigma_safe
    generated_data = generator(z)
    losses = loss_func(generated_data, real_data, reduction='none')
    weighted_losses = (weights * losses).sum(dim=1)
    min_loss = weighted_losses.min()
    return min_loss


# Hyperparameters and Optimizer Setup

In [5]:
latent_dim = 50
output_dim = 8575
learning_rate_gen = 0.1
learning_rate_latent = 0.01
weight_decay = 1e-4
loss_func = torch.nn.functional.mse_loss

generator = Generator(latent_dim, output_dim)
latent_vectors = torch.randn((len(dataset), latent_dim), requires_grad=True)
optimizer = torch.optim.Adam([
    {'params': generator.parameters(), 'lr': learning_rate_gen, 'weight_decay': weight_decay},
    {'params': latent_vectors, 'lr': learning_rate_latent}
])


# Multiprocessing Setup

In [6]:
def train_batch(batch, generator_state_dict, latent_vectors, optimizer_state_dict):
    """
    Train a single batch using the provided state dictionaries and return the loss.
    """
    local_generator = Generator(latent_dim, output_dim)
    local_generator.load_state_dict(generator_state_dict)
    local_generator.train()

    # Create a local optimizer for this process
    local_optimizer = torch.optim.Adam([
        {'params': local_generator.parameters(), 'lr': learning_rate_gen, 'weight_decay': weight_decay},
        {'params': latent_vectors, 'lr': learning_rate_latent}
    ])
    local_optimizer.load_state_dict(optimizer_state_dict)

    # Generate data
    z = latent_vectors[batch['index']]
    generated_data = local_generator(z)

    # Calculate the loss
    loss = weighted_min_loss(local_generator, z, batch, loss_func)
    loss.backward()
    local_optimizer.step()

    return local_generator.state_dict(), local_optimizer.state_dict(), loss.item()


In [7]:
def init_processes(dataset, num_processes, num_epochs):
    with get_context("spawn").Pool(num_processes) as pool:
        manager = Manager()
        losses_per_epoch = manager.list()

        # Initialize the generator and optimizer in the main process
        generator = Generator(latent_dim, output_dim)
        optimizer = torch.optim.Adam([
            {'params': generator.parameters(), 'lr': learning_rate_gen, 'weight_decay': weight_decay},
            {'params': latent_vectors, 'lr': learning_rate_latent}
        ])

        for epoch in range(num_epochs):
            epoch_losses = manager.list()

            # Process each batch in parallel
            batch_results = [pool.apply_async(train_batch, args=(batch, generator.state_dict(), latent_vectors, optimizer.state_dict())) for batch in dataloader]

            # Update generator and optimizer with the new states and collect losses
            for result in batch_results:
                generator_state_dict, optimizer_state_dict, batch_loss = result.get()
                generator.load_state_dict(generator_state_dict)
                optimizer.load_state_dict(optimizer_state_dict)
                epoch_losses.append(batch_loss)

            # Aggregate losses for the epoch
            avg_loss = sum(epoch_losses) / len(epoch_losses)
            losses_per_epoch.append(avg_loss)

            # Visualization and logging
            if epoch % 10 == 0:
                with torch.no_grad():
                    generator.eval()
                    sample_latent = torch.randn(1, latent_dim)
                    generated_spectra = generator(sample_latent).squeeze().numpy()
                    real_spectra = next(iter(dataloader))['flux'][0].numpy()  # Extracting real spectra for comparison
                    plot_spectra(real_spectra, generated_spectra, epoch, 0)
                    print(f'Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}')
                generator.train()

    return list(losses_per_epoch)

# Visualization and Logging


In [8]:

if __name__ == '__main__':
    latent_vectors = torch.randn((len(dataset), latent_dim), requires_grad=True)
    loss_values = init_processes(dataset, 16, 200)  # Start training with 16 processes

    # Plotting the training losses after all epochs
    plt.figure(figsize=(10, 5))
    plt.plot(loss_values, label='Training Loss')
    plt.title('Loss During Training')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

NameError: name 'get_context' is not defined