In [6]:
import torch
import torch.nn as nn
import torch.optim as optim

class MatrixParameterNetwork(nn.Module):
    def __init__(self):
        super(MatrixParameterNetwork, self).__init__()
        self.fc1 = nn.Linear(16, 128)  # Input size is 16 (8 from x and 8 from b)
        self.fc2 = nn.Linear(128, 64)  # Hidden layer
        self.fc3 = nn.Linear(64, 48)   # Output layer (48 thetas)

    def forward(self, x, b):
        # Concatenate x and b as input to the network
        x = x.view(1, -1)  # Flatten to (1, 8) for batch size 1
        b = b.view(1, -1)  # Flatten to (1, 8)
        input_vector = torch.cat((x, b), dim=-1)
        
        x = torch.relu(self.fc1(input_vector))
        x = torch.relu(self.fc2(x))
        thetas = self.fc3(x)
        
        return thetas

# Function to create matrix A from manually initialized N, M, and P matrices
def create_A(theta):
    thetas = theta[27:]
    A = torch.eye(8, 8)  # Initialize matrix A
    I = torch.tensor(1j)  # Imaginary unit in PyTorch

    # Manually define the N, M, and P matrices as full 8x8 matrices
    for i in range(3):
        M_matrix = torch.tensor([
                                    [0, 0, 0, 0, 0, 0, 0, 0],
                                    [0, I * torch.exp(I * theta[7*i] / 2) * torch.sin(theta[7*i] / 2), I * torch.exp(I * theta[7*i] / 2) * torch.cos(theta[7*i] / 2), 0, 0, 0, 0, 0],
                                    [0, I * torch.exp(I * theta[7*i] / 2) * torch.cos(theta[7*i] / 2), -I * torch.exp(I * theta[7*i] / 2) * torch.sin(theta[7*i] / 2), 0, 0, 0, 0, 0],
                                    [0, 0, 0, I * torch.exp(I * theta[7*i+1] / 2) * torch.sin(theta[7*i+1] / 2), I * torch.exp(I * theta[7*i+1] / 2) * torch.cos(theta[7*i+1] / 2), 0, 0, 0],
                                    [0, 0, 0, I * torch.exp(I * theta[7*i+1] / 2) * torch.cos(theta[7*i+1] / 2), -I * torch.exp(I * theta[7*i+1] / 2) * torch.sin(theta[7*i+1] / 2), 0, 0, 0],
                                    [0, 0, 0, 0, 0, I * torch.exp(I * theta[7*i+2] / 2) * torch.sin(theta[7*i+2] / 2), I * torch.exp(I * theta[7*i+2] / 2) * torch.cos(theta[7*i+2] / 2), 0],
                                    [0, 0, 0, 0, 0, I * torch.exp(I * theta[7*i+2] / 2) * torch.cos(theta[7*i+2] / 2), -I * torch.exp(I * theta[7*i+2] / 2) * torch.sin(theta[7*i+2] / 2), 0],
                                    [0, 0, 0, 0, 0, 0, 0, 0],
                                ], dtype=torch.complex64)
        N_matrix = torch.tensor([
                                    [I * torch.exp(I * theta[7*i+3] / 2) * torch.sin(theta[7*i+3] / 2), I * torch.exp(I * theta[7*i+3] / 2) * torch.cos(theta[7*i+3] / 2), 0, 0, 0, 0, 0 ,0],
                                    [I * torch.exp(I * theta[7*i+3] / 2) * torch.cos(theta[7*i+3] / 2), -I * torch.exp(I * theta[7*i+3] / 2) * torch.sin(theta[7*i+3] / 2), 0, 0, 0, 0, 0, 0],
                                    [0, 0, I * torch.exp(I * theta[7*i+4] / 2) * torch.sin(theta[7*i+4] / 2), I * torch.exp(I * theta[7*i+4] / 2) * torch.cos(theta[7*i+4] / 2), 0, 0, 0, 0],
                                    [0, 0, I * torch.exp(I * theta[7*i+4] / 2) * torch.cos(theta[7*i+4] / 2), -I * torch.exp(I * theta[7*i+4] / 2) * torch.sin(theta[7*i+4] / 2), 0, 0, 0, 0],
                                    [0, 0, 0, 0, I * torch.exp(I * theta[7*i+5] / 2) * torch.sin(theta[7*i+5] / 2), I * torch.exp(I * theta[7*i+5] / 2) * torch.cos(theta[7*i+5] / 2), 0, 0],
                                    [0, 0, 0, 0, I * torch.exp(I * theta[7*i+5] / 2) * torch.cos(theta[7*i+5] / 2), -I * torch.exp(I * theta[7*i+5] / 2) * torch.sin(theta[7*i+5] / 2), 0, 0],
                                    [0, 0, 0, 0, 0, 0, I * torch.exp(I * theta[7*i+6] / 2) * torch.sin(theta[7*i+6] / 2), I * torch.exp(I * theta[7*i+6] / 2) * torch.cos(theta[7*i+6] / 2)],
                                    [0, 0, 0, 0, 0, 0, I * torch.exp(I * theta[7*i+6] / 2) * torch.cos(theta[7*i+6] / 2), -I * torch.exp(I * theta[7*i+6] / 2) * torch.sin(theta[7*i+6] / 2)],
                                ], dtype=torch.complex64)
    
        A = torch.matmul(A, torch.matmul(N_matrix, M_matrix))
    
    # Manually define the P matrix as a full 8x8 matrix
    P_matrix = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0],  # Example: full P matrix
                             [0, I * torch.exp(I * theta[21] / 2) * torch.cos(theta[21] / 2), 0, 0, 0, 0, 0, 0],
                             [0, 0, I * torch.exp(I * theta[22] / 2) * torch.cos(theta[22] / 2), 0, 0, 0, 0, 0],
                             [0, 0, 0, I * torch.exp(I * theta[23] / 2) * torch.cos(theta[23] / 2), 0, 0, 0, 0],
                             [0, 0, 0, 0, I * torch.exp(I * theta[24] / 2) * torch.cos(theta[24] / 2), 0, 0, 0],
                             [0, 0, 0, 0, 0, I * torch.exp(I * theta[25] / 2) * torch.cos(theta[25] / 2), 0, 0],
                             [0, 0, 0, 0, 0, 0, I * torch.exp(I * theta[26] / 2) * torch.cos(theta[26] / 2), 0],
                             [0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.complex64)
    A = torch.matmul(A, P_matrix)
    
    for i in range(3):
        M_matrix = torch.tensor([
                                    [0, 0, 0, 0, 0, 0, 0, 0],
                                    [0, I * torch.exp(I * thetas[7*i] / 2) * torch.sin(thetas[7*i] / 2), I * torch.exp(I * thetas[7*i] / 2) * torch.cos(thetas[7*i] / 2), 0, 0, 0, 0, 0],
                                    [0, I * torch.exp(I * thetas[7*i] / 2) * torch.cos(thetas[7*i] / 2), -I * torch.exp(I * thetas[7*i] / 2) * torch.sin(thetas[7*i] / 2), 0, 0, 0, 0, 0],
                                    [0, 0, 0, I * torch.exp(I * thetas[7*i+1] / 2) * torch.sin(thetas[7*i+1] / 2), I * torch.exp(I * thetas[7*i+1] / 2) * torch.cos(thetas[7*i+1] / 2), 0, 0, 0],
                                    [0, 0, 0, I * torch.exp(I * thetas[7*i+1] / 2) * torch.cos(thetas[7*i+1] / 2), -I * torch.exp(I * thetas[7*i+1] / 2) * torch.sin(thetas[7*i+1] / 2), 0, 0, 0],
                                    [0, 0, 0, 0, 0, I * torch.exp(I * thetas[7*i+2] / 2) * torch.sin(thetas[7*i+2] / 2), I * torch.exp(I * thetas[7*i+2] / 2) * torch.cos(thetas[7*i+2] / 2), 0],
                                    [0, 0, 0, 0, 0, I * torch.exp(I * thetas[7*i+2] / 2) * torch.cos(thetas[7*i+2] / 2), -I * torch.exp(I * thetas[7*i+2] / 2) * torch.sin(thetas[7*i+2] / 2), 0],
                                    [0, 0, 0, 0, 0, 0, 0, 0],
                                ], dtype=torch.complex64)
        N_matrix = torch.tensor([
                                    [I * torch.exp(I * thetas[7*i+3] / 2) * torch.sin(thetas[7*i+3] / 2), I * torch.exp(I * thetas[7*i+3] / 2) * torch.cos(thetas[7*i+3] / 2), 0, 0, 0, 0, 0 ,0],
                                    [I * torch.exp(I * thetas[7*i+3] / 2) * torch.cos(thetas[7*i+3] / 2), -I * torch.exp(I * thetas[7*i+3] / 2) * torch.sin(thetas[7*i+3] / 2), 0, 0, 0, 0, 0, 0],
                                    [0, 0, I * torch.exp(I * thetas[7*i+4] / 2) * torch.sin(thetas[7*i+4] / 2), I * torch.exp(I * thetas[7*i+4] / 2) * torch.cos(thetas[7*i+4] / 2), 0, 0, 0, 0],
                                    [0, 0, I * torch.exp(I * thetas[7*i+4] / 2) * torch.cos(thetas[7*i+4] / 2), -I * torch.exp(I * thetas[7*i+4] / 2) * torch.sin(thetas[7*i+4] / 2), 0, 0, 0, 0],
                                    [0, 0, 0, 0, I * torch.exp(I * thetas[7*i+5] / 2) * torch.sin(thetas[7*i+5] / 2), I * torch.exp(I * thetas[7*i+5] / 2) * torch.cos(thetas[7*i+5] / 2), 0, 0],
                                    [0, 0, 0, 0, I * torch.exp(I * thetas[7*i+5] / 2) * torch.cos(thetas[7*i+5] / 2), -I * torch.exp(I * thetas[7*i+5] / 2) * torch.sin(thetas[7*i+5] / 2), 0, 0],
                                    [0, 0, 0, 0, 0, 0, I * torch.exp(I * thetas[7*i+6] / 2) * torch.sin(thetas[7*i+6] / 2), I * torch.exp(I * thetas[7*i+6] / 2) * torch.cos(thetas[7*i+6] / 2)],
                                    [0, 0, 0, 0, 0, 0, I * torch.exp(I * thetas[7*i+6] / 2) * torch.cos(thetas[7*i+6] / 2), -I * torch.exp(I * thetas[7*i+6] / 2) * torch.sin(thetas[7*i+6] / 2)],
                                ], dtype=torch.complex64)
    
        A = torch.matmul(A, torch.matmul(N_matrix, M_matrix))
    
    return A

# Example of training the model
def train_model():
    # Assume x and b are given as 8x1 matrices
    x = torch.tensor([0,1,1,1,1,1,1,0], dtype=torch.complex64)
    b = torch.tensor([0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0], dtype=torch.complex64)  # Output vector b (8x1)
    
    # Initialize the model
    model = MatrixParameterNetwork()
    criterion = nn.MSELoss()  # MSE Loss function
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    num_epochs = 1000
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        
        # Forward pass
        thetas = model(x, b)
        
        # Create matrix A from thetas
        A = create_A(thetas)
        
        # Calculate Ax
        Ax = torch.matmul(A, x)
        
        # Calculate loss
        loss = criterion(Ax, b)
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        
        if (epoch + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
    
    return model

# Train the model
model = train_model()


RuntimeError: mat1 and mat2 must have the same dtype, but got ComplexFloat and Float