In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.cuda.amp import autocast, GradScaler
from torchinterp1d import interp1d

# Simplified Generator Network
class Generator(nn.Module):
    def __init__(self, latent_dim, output_dim):
        super(Generator, self).__init__()
        self.fc = nn.Linear(latent_dim, output_dim)
        self.apply(self.init_weights)  # Initialize weights

    def forward(self, z):
        return self.fc(z)

    @staticmethod
    def init_weights(m):
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

# Simplified Downsampling Layer
class DownsamplingLayer(nn.Module):
    def __init__(self, padding=5):
        super(DownsamplingLayer, self).__init__()
        self.padding = padding
        self.channels = [
            (4711, 4906, 0.05),
            (5647, 5875, 0.05),
            (6475, 6737, 0.05),
            (7583, 7885, 0.05),
            (15100, 17000, 0.2)
        ]
        self.extended_wavelength = self.create_extended_wavelength_grid()

    def create_extended_wavelength_grid(self):
        extended_wavelengths = []
        for start, end, step in self.channels:
            pre_pad = np.arange(start - self.padding * step, start, step)
            post_pad = np.arange(end + step, end + (self.padding + 1) * step, step)
            channel = np.arange(start, end + step, step)
            extended_wavelengths.extend(pre_pad)
            extended_wavelengths.extend(channel)
            extended_wavelengths.extend(post_pad)
        return torch.FloatTensor(np.unique(extended_wavelengths))

    def forward(self, high_res_flux, high_res_wavelength, observed_wavelengths, device):
        high_res_flux = high_res_flux.to(device)
        high_res_wavelength = high_res_wavelength.to(device)
        observed_wavelengths = observed_wavelengths.to(device)
        extended_wavelength = self.extended_wavelength.to(device)
        
        # Simplified interpolation using linear approximation for testing
        extended_flux = interp1d(extended_wavelength, high_res_wavelength, high_res_flux)
        observed_flux = interp1d(observed_wavelengths, extended_wavelength, extended_flux)
        return observed_flux

# Simplified FullNetwork
class FullNetwork(nn.Module):
    def __init__(self, generator, high_res_wavelength, device):
        super(FullNetwork, self).__init__()
        self.generator = generator
        self.downsampling_layer = DownsamplingLayer()
        self.high_res_wavelength = torch.tensor(high_res_wavelength, dtype=torch.float32)  # Convert to tensor
        self.device = device
    
    def forward(self, z, observed_wavelengths):
        high_res_flux = self.generator(z)  # Generate high-res flux from the generator
        downsampled_flux = self.downsampling_layer(high_res_flux, self.high_res_wavelength, observed_wavelengths, self.device)
        return downsampled_flux


Starting training step...
Error during training: x and y must have the same number of columns, and either the same number of row or one of them having only one row.
Allocated: 21.03 MB
Cached: 44.00 MB
Allocated: 20.79 MB
Cached: 44.00 MB
