In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# =============================================================================
# Dummy Dataset
# =============================================================================
# Our simulated dataset creates random "galaxy images" and corresponding spectra.
# In real life, your data might be as elusive as a galaxy in a dark sky,
# but here we generate random tensors for demonstration.
class GalaxyDataset(Dataset):
    """Synthetic dataset of random "galaxy images" and paired spectra.

    Each item is a tuple ``(image, low_res_spec, high_res_spec)`` of tensors
    drawn from a standard normal distribution; no real data is read.

    Args:
        num_samples: Number of items reported by ``len()``.
        image_size: Shape of every generated image tensor.
        low_spec_length: Length of the low-resolution spectrum.
        high_spec_length: Length of the high-resolution spectrum.
        seed: Optional base seed. With ``None`` (the default) every access
            draws fresh values from the global RNG. With an integer, item
            ``i`` is generated from ``seed + i``, so the same index always
            returns the same tensors.
    """

    def __init__(self, num_samples=1000, image_size=(3, 64, 64), low_spec_length=100, high_spec_length=200, seed=None):
        self.num_samples = num_samples
        self.image_size = image_size
        self.low_spec_length = low_spec_length
        self.high_spec_length = high_spec_length
        self.seed = seed

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Return ``(image, low_res_spec, high_res_spec)`` for item ``idx``.

        Raises:
            IndexError: If ``idx`` is outside ``[-num_samples, num_samples)``.
        """
        index = int(idx)
        if index < 0:
            # Sequence-style negative indexing.
            index += self.num_samples
        # The bounds check matters: plain iteration (`for item in dataset`,
        # `list(dataset)`) relies on IndexError to stop, and would otherwise
        # run forever because samples are generated on demand.
        if not 0 <= index < self.num_samples:
            raise IndexError(
                f"index {idx} is out of range for dataset of size {self.num_samples}"
            )

        generator = None
        if self.seed is not None:
            # A dedicated per-item generator keeps items reproducible without
            # touching the global RNG state.
            generator = torch.Generator().manual_seed(self.seed + index)

        # Simulate a galaxy image (3-channel photometric image by default).
        image = torch.randn(self.image_size, generator=generator)
        # Simulate low-resolution (noisy) space-based spectrum (e.g., 3DHST GRISM).
        low_res_spec = torch.randn(self.low_spec_length, generator=generator)
        # Simulate high-resolution spectrum (e.g., Keck MOSDEF).
        high_res_spec = torch.randn(self.high_spec_length, generator=generator)
        return image, low_res_spec, high_res_spec

# =============================================================================
# Model Components
# =============================================================================

# Image Encoder: a simple CNN to extract features from galaxy images.
class ImageEncoder(nn.Module):
    """Small CNN that maps an image batch to fixed-size embeddings.

    Three 3x3 convolutions with stride 2 and padding 1 (each followed by a
    ReLU) roughly halve the spatial size at every stage; the result is
    flattened and projected to ``embedding_dim`` by a linear layer.

    Args:
        embedding_dim: Size of the output embedding.
        image_size: ``(channels, height, width)`` of the input images. The
            default ``(3, 64, 64)`` gives a flattened feature size of
            ``64 * 8 * 8``.
    """

    def __init__(self, embedding_dim=128, image_size=(3, 64, 64)):
        super().__init__()
        in_channels = image_size[0]
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, stride=2, padding=1),  # default: (3,64,64) -> (16,32,32)
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),           # default: (16,32,32) -> (32,16,16)
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),           # default: (32,16,16) -> (64,8,8)
            nn.ReLU(),
        )
        # Probe the conv stack with a dummy input so the linear layer matches
        # whatever image size was requested instead of assuming 64x64.
        with torch.no_grad():
            flat_features = self.conv(torch.zeros(1, *image_size)).flatten(1).size(1)
        self.fc = nn.Linear(flat_features, embedding_dim)

    def forward(self, x):
        """Encode ``x`` of shape (batch, channels, height, width) to (batch, embedding_dim)."""
        features = self.conv(x)
        # flatten() also handles non-contiguous activations, where view() raises.
        features = features.flatten(1)
        return self.fc(features)

# Spectrum Encoder: a simple MLP to extract features from 1D spectral data.
class SpectrumEncoder(nn.Module):
    """Two-layer MLP that turns a 1D spectrum into an embedding vector.

    Args:
        input_length: Number of values in each input spectrum.
        embedding_dim: Size of the output embedding.
    """

    def __init__(self, input_length, embedding_dim=128):
        super().__init__()
        hidden_width = 256
        layers = [
            nn.Linear(input_length, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, embedding_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` of shape (batch, input_length) to (batch, embedding_dim)."""
        return self.net(x)

# Conditional Diffusion Model: a simplified network that predicts the noise component
# added to the high-resolution spectrum, conditioned on the image embedding.
class ConditionalDiffusionModel(nn.Module):
    """Simplified noise predictor for spectra, conditioned on an embedding.

    The conditioning embedding is linearly projected to the spectrum length,
    added to the noisy spectrum, and the sum is passed through a two-layer
    feedforward network that outputs the predicted noise.

    Args:
        spectrum_length: Number of values in each spectrum.
        embedding_dim: Size of the conditioning embedding.
    """

    def __init__(self, spectrum_length, embedding_dim=128):
        super().__init__()
        hidden_width = 512
        # Projects the conditioning embedding into spectrum space.
        self.fc_cond = nn.Linear(embedding_dim, spectrum_length)
        # Feedforward noise predictor operating on the conditioned spectrum.
        self.net = nn.Sequential(
            nn.Linear(spectrum_length, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, spectrum_length),
        )

    def forward(self, x, cond):
        """Predict the noise in ``x``.

        Args:
            x: Noisy spectrum, shape (batch, spectrum_length).
            cond: Conditioning embedding, shape (batch, embedding_dim).

        Returns:
            Predicted noise, shape (batch, spectrum_length).
        """
        # Conditioning is injected additively before the network.
        conditioned = x + self.fc_cond(cond)
        return self.net(conditioned)

# Contrastive Loss (InfoNCE) to align image and spectrum embeddings.
# We want the matching image-spectrum pair to be more similar than non-matching pairs.
class ContrastiveLoss(nn.Module):
    """Symmetric InfoNCE loss between image and spectrum embeddings.

    Row ``i`` of each input is treated as a matching pair; every other row
    in the batch serves as a negative.

    Args:
        temperature: Positive scale that divides the cosine similarities
            before the softmax.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """

    def __init__(self, temperature=0.1):
        super().__init__()
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        self.temperature = temperature
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, image_emb, spec_emb):
        """Return the mean of the image->spectrum and spectrum->image losses.

        Args:
            image_emb: Image embeddings, shape (batch, embedding_dim).
            spec_emb: Spectrum embeddings, shape (batch, embedding_dim).
        """
        batch_size = image_emb.size(0)
        # normalize() clamps the norm with a small epsilon, so an all-zero
        # embedding yields zeros rather than the NaNs a raw division gives.
        image_unit = nn.functional.normalize(image_emb, dim=1)
        spec_unit = nn.functional.normalize(spec_emb, dim=1)
        # Cosine-similarity matrix (batch_size x batch_size), temperature-scaled.
        logits = torch.matmul(image_unit, spec_unit.T) / self.temperature
        # The matching pair for row i sits on the diagonal.
        labels = torch.arange(batch_size, device=logits.device)
        loss_i = self.cross_entropy(logits, labels)
        loss_s = self.cross_entropy(logits.T, labels)
        return (loss_i + loss_s) / 2

# =============================================================================
# Diffusion Noise Scheduler (Simplified)
# =============================================================================
# In full-fledged diffusion models, the noise schedule is more intricate.
# Here, we assume a simple linear schedule to add noise to the high-res spectrum.
def add_noise(x, noise, t):
    """Blend a clean signal with noise at noise level ``t``.

    Computes ``sqrt(1 - t) * x + sqrt(t) * noise``.

    Args:
        x: Clean signal tensor.
        noise: Noise tensor, broadcastable with ``x``.
        t: Noise level. Either a tensor broadcastable with ``x`` (e.g. shape
            (batch, 1) for per-sample levels) or a plain Python number.
            Values outside [0, 1] make one of the square roots NaN.

    Returns:
        The noised tensor.
    """
    if not torch.is_tensor(t):
        # Plain numbers have no .sqrt(); lift them to a tensor matching x.
        t = torch.as_tensor(t, dtype=x.dtype, device=x.device)
    signal_scale = (1 - t).sqrt()
    noise_scale = t.sqrt()
    return signal_scale * x + noise_scale * noise

# =============================================================================
# Training Loop
# =============================================================================
def train(model_components, dataloader, device, num_epochs=10, diffusion_loss_weight=1.0, lr=1e-3):
    """Jointly train the encoders and the diffusion model.

    Each step combines a contrastive loss between image and high-resolution
    spectrum embeddings with an MSE loss on the noise predicted by the
    diffusion model, then updates all three networks with one Adam optimizer.

    Args:
        model_components: Dict with keys ``'image_encoder'``,
            ``'spectrum_encoder'``, ``'diffusion_model'`` and
            ``'contrastive_loss_fn'``.
        dataloader: Iterable of ``(image, low_res_spec, high_res_spec)``
            batches that supports ``len()``. The low-resolution spectrum is
            not used by this loop.
        device: Device the batches are moved to.
        num_epochs: Number of passes over ``dataloader``.
        diffusion_loss_weight: Weight of the diffusion loss in the total loss.
        lr: Learning rate for the Adam optimizer.

    Returns:
        A list with the average total loss of each epoch.

    Raises:
        ValueError: If ``dataloader`` has no batches while ``num_epochs > 0``.
    """
    image_encoder = model_components['image_encoder']
    spectrum_encoder = model_components['spectrum_encoder']
    diffusion_model = model_components['diffusion_model']
    contrastive_loss_fn = model_components['contrastive_loss_fn']

    trainable_modules = (image_encoder, spectrum_encoder, diffusion_model)
    # Make sure train-time behaviour is active even if the caller last used
    # the modules in eval mode.
    for module in trainable_modules:
        module.train()

    # One optimizer over the parameters of all three networks.
    optimizer = optim.Adam(
        [param for module in trainable_modules for param in module.parameters()],
        lr=lr,
    )

    mse_loss = nn.MSELoss()

    num_batches = len(dataloader)
    if num_epochs > 0 and num_batches == 0:
        # Fail with a clear message instead of a ZeroDivisionError when the
        # epoch average is computed.
        raise ValueError("dataloader yields no batches; nothing to train on")

    epoch_averages = []
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for batch_idx, (image, _low_res_spec, high_res_spec) in enumerate(dataloader):
            image = image.to(device)
            high_res_spec = high_res_spec.to(device)

            # Compute embeddings.
            img_emb = image_encoder(image)              # Image embedding: (batch, embedding_dim)
            spec_emb = spectrum_encoder(high_res_spec)  # Spectrum embedding: (batch, embedding_dim)

            # Contrastive loss: align the image and high-res spectrum embeddings.
            contrast_loss = contrastive_loss_fn(img_emb, spec_emb)

            # Diffusion loss: train the diffusion model to predict the noise
            # added to the high-res spectra.
            batch_size = high_res_spec.size(0)
            # One random noise level in [0, 1) per sample; the trailing
            # singleton dim broadcasts over the spectrum axis.
            t = torch.rand(batch_size, 1, device=device)
            noise = torch.randn_like(high_res_spec)
            high_res_spec_noisy = add_noise(high_res_spec, noise, t)
            # Predict the noise, conditioned on the image embedding.
            noise_pred = diffusion_model(high_res_spec_noisy, img_emb)
            diffusion_loss = mse_loss(noise_pred, noise)

            # Total loss: weighted sum of contrastive and diffusion losses.
            total_loss = contrast_loss + diffusion_loss_weight * diffusion_loss

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            loss_value = total_loss.item()
            epoch_loss += loss_value

            if batch_idx % 10 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}/{num_batches}], Loss: {loss_value:.4f}")

        average_loss = epoch_loss / num_batches
        epoch_averages.append(average_loss)
        print(f"Epoch [{epoch+1}/{num_epochs}] Average Loss: {average_loss:.4f}")
    print("Training complete! May your galaxies shine bright and your spectra be ever detailed.")
    return epoch_averages

# =============================================================================
# Main Function
# =============================================================================
def main():
    """Build the models and the synthetic dataset, then run training."""
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)

    # Hyperparameters
    embedding_dim = 128
    image_size = (3, 64, 64)
    low_spec_length = 100
    high_spec_length = 200  # High-resolution spectra are assumed to have twice the length.

    # Model components, created in a fixed order so that weight
    # initialisation consumes the RNG in the same sequence every run.
    model_components = {
        'image_encoder': ImageEncoder(embedding_dim=embedding_dim).to(device),
        'spectrum_encoder': SpectrumEncoder(
            input_length=high_spec_length, embedding_dim=embedding_dim
        ).to(device),
        'diffusion_model': ConditionalDiffusionModel(
            spectrum_length=high_spec_length, embedding_dim=embedding_dim
        ).to(device),
        'contrastive_loss_fn': ContrastiveLoss(temperature=0.1),
    }

    # Synthetic data, served in shuffled mini-batches of 32.
    galaxies = GalaxyDataset(
        num_samples=1000,
        image_size=image_size,
        low_spec_length=low_spec_length,
        high_spec_length=high_spec_length,
    )
    loader = DataLoader(galaxies, batch_size=32, shuffle=True)

    # Run the training loop.
    train(model_components, loader, device, num_epochs=10, diffusion_loss_weight=1.0)

# Run the demo only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()

Epoch [1/10], Batch [0/32], Loss: 4.5226
Epoch [1/10], Batch [10/32], Loss: 4.5312
Epoch [1/10], Batch [20/32], Loss: 4.4677
Epoch [1/10], Batch [30/32], Loss: 4.4240
Epoch [1/10] Average Loss: 4.4883
Epoch [2/10], Batch [0/32], Loss: 4.4069
Epoch [2/10], Batch [10/32], Loss: 4.4066
Epoch [2/10], Batch [20/32], Loss: 4.3745
Epoch [2/10], Batch [30/32], Loss: 4.2908
Epoch [2/10] Average Loss: 4.3264
Epoch [3/10], Batch [0/32], Loss: 4.3100
Epoch [3/10], Batch [10/32], Loss: 4.2808
Epoch [3/10], Batch [20/32], Loss: 4.2502
Epoch [3/10], Batch [30/32], Loss: 4.2515
Epoch [3/10] Average Loss: 4.2357
Epoch [4/10], Batch [0/32], Loss: 4.2232
Epoch [4/10], Batch [10/32], Loss: 4.1993
Epoch [4/10], Batch [20/32], Loss: 4.1984
Epoch [4/10], Batch [30/32], Loss: 4.2005
Epoch [4/10] Average Loss: 4.1609
Epoch [5/10], Batch [0/32], Loss: 4.1441
Epoch [5/10], Batch [10/32], Loss: 4.1063
Epoch [5/10], Batch [20/32], Loss: 4.1141
Epoch [5/10], Batch [30/32], Loss: 4.1622
Epoch [5/10] Average Loss: 4.