In [1]:
import os

import numpy as np

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [2]:
class ResidualBlock(nn.Module):
    def __init__(self, hidden_dim):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.LeakyReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim)
        )

    def forward(self, x):
        return x + self.block(x)

In [3]:
class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(MLP, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.model(x)

In [4]:
class ResNetGenerator(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_blocks, output_dim=50, num_choices=50):
        super(ResNetGenerator, self).__init__()
        self.initial_mlp = MLP(input_dim, hidden_dim, hidden_dim)
        self.resnet_blocks = nn.Sequential(*[ResidualBlock(hidden_dim) for n in range(num_blocks)])
        self.final_mlp = MLP(hidden_dim, hidden_dim, output_dim * num_choices)

        self.coordinates_range = torch.linspace(0.1, 0.6, output_dim)

        self.num_choices = num_choices

        # Hyperparamters
        self.alpha = 0.001
        self.alpha_sup = 0.01

    def update_alpha(self, normIter):
        self.alpha = (normIter/0.05) * self.alpha_sup + 1
        return self.alpha
    
    def forward(self, z, alpha):
        x = self.initial_mlp(z)
        x = self.resnet_blocks(x)
        x = self.final_mlp(x)
        
        # Reshape the output into [batch_size, 50, num_choices]
        coordinates = x.view(-1, 50, self.num_choices)  # [batch_size, 50, 50]

        # Apply softmax to generate a distribution over the choices
        coordinates = F.softmax(coordinates * alpha, dim=-1)

        # Perform the weighted sum over the predefined range
        predicted_coords = torch.sum(coordinates * self.coordinates_range, dim=-1)  # [batch_size, 50]
        
        return predicted_coords

In [5]:
# Define the save directory for the coordinates
save_dir = 'fsp_files'
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

# Load the pre-trained model
model_path = r'C:\Users\Massee\Desktop\QCGCnew\Lumerical Adjoint_397\model_checkpoints\resnet_generator.pth'

model = ResNetGenerator(input_dim=10, hidden_dim=128, num_blocks=4, output_dim=50, num_choices=50)
model.load_state_dict(torch.load(model_path))
model.eval()  # Set the model to evaluation mode

batch_size = 1000
input_dim = 10
alpha = 1e12  # as per the image

z = torch.randn(batch_size, input_dim)  # Generating random inputs
predicted_coords = model(z, alpha)  # Get the predicted coordinates from the model

# Save the predicted coordinates as .fsp files in numerical order
for i in range(predicted_coords.size(0)):
    coords = predicted_coords[i].detach().numpy()  # Convert to numpy array
    
    # Save each set of coordinates to a unique file
    save_path = os.path.join(save_dir, f'coords_{i+1}.fsp')
    np.savetxt(save_path, coords, delimiter=',')
    print(f'Saved coordinates to {save_path}')

print(f'{batch_size} coordinate files saved in {save_dir}')

  model.load_state_dict(torch.load(model_path))


Saved coordinates to fsp_files\coords_1.fsp
Saved coordinates to fsp_files\coords_2.fsp
Saved coordinates to fsp_files\coords_3.fsp
Saved coordinates to fsp_files\coords_4.fsp
Saved coordinates to fsp_files\coords_5.fsp
Saved coordinates to fsp_files\coords_6.fsp
Saved coordinates to fsp_files\coords_7.fsp
Saved coordinates to fsp_files\coords_8.fsp
Saved coordinates to fsp_files\coords_9.fsp
Saved coordinates to fsp_files\coords_10.fsp
Saved coordinates to fsp_files\coords_11.fsp
Saved coordinates to fsp_files\coords_12.fsp
Saved coordinates to fsp_files\coords_13.fsp
Saved coordinates to fsp_files\coords_14.fsp
Saved coordinates to fsp_files\coords_15.fsp
Saved coordinates to fsp_files\coords_16.fsp
Saved coordinates to fsp_files\coords_17.fsp
Saved coordinates to fsp_files\coords_18.fsp
Saved coordinates to fsp_files\coords_19.fsp
Saved coordinates to fsp_files\coords_20.fsp
Saved coordinates to fsp_files\coords_21.fsp
Saved coordinates to fsp_files\coords_22.fsp
Saved coordinates t