# Training an ONN with a simple MZI model



In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np

# Define a class for an MZI-based linear layer
class MZILayer(nn.Module):
    """Bias-free linear layer whose weights are parameterised by two phase matrices.

    The effective weight matrix is ``cos(theta) * cos(phi) - sin(theta) * sin(phi)``,
    which by the angle-sum identity is exactly ``cos(theta + phi)``. Every weight
    therefore lies in ``[-1, 1]`` and depends on the two phases only through
    their sum.

    Args:
        input_size: Size of the last dimension of the input.
        output_size: Size of the last dimension of the output.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        # Phases start uniformly in [0, 2*pi). theta is drawn before phi so a
        # fixed RNG seed keeps producing the same initial parameters.
        self.theta = nn.Parameter(torch.rand(input_size, output_size) * 2 * np.pi)  # phase shifts
        self.phi = nn.Parameter(torch.rand(input_size, output_size) * 2 * np.pi)    # phase shifts

    def forward(self, x):
        """Apply the phase-parameterised linear map to ``x``.

        Args:
            x: Tensor whose last dimension equals ``input_size``.

        Returns:
            ``x @ W`` with ``W = cos(theta + phi)``; the last dimension has
            size ``output_size``.
        """
        # One cosine of the summed phases replaces four trig evaluations and
        # three elementwise ops; values and gradients match the expanded form
        # up to floating-point rounding.
        weight_matrix = torch.cos(self.theta + self.phi)
        return torch.matmul(x, weight_matrix)  # input * weight matrix

# Define the optical neural network model using MZI layers
class OpticalNN(nn.Module):
    """Three stacked MZILayer stages (784 -> 128 -> 64 -> 10) with ReLU between them."""

    def __init__(self):
        super().__init__()
        # Built in the order fc1, fc2, fc3 so the random phase initialisation
        # consumes the RNG in a fixed sequence.
        self.fc1 = MZILayer(28 * 28, 128)
        self.fc2 = MZILayer(128, 64)
        self.fc3 = MZILayer(64, 10)

    def forward(self, x):
        """Return 10 raw class scores per sample.

        Args:
            x: Batch whose non-batch dimensions flatten to 784 values.

        Returns:
            Tensor of shape ``(batch, 10)``; no activation is applied after
            the last stage.
        """
        hidden = x.flatten(1)  # keep dim 0 (batch), merge everything else
        for stage in (self.fc1, self.fc2):
            hidden = torch.relu(stage(hidden))
        return self.fc3(hidden)

# MNIST training split stored under ../data (download=True fetches it when
# needed); ToTensor turns each image into a tensor.
train_dataset = datasets.MNIST(
    '../data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)
# Shuffled mini-batches of 64 samples.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

# Network and loss function. No optimizer object is created here.
model = OpticalNN()
criterion = nn.CrossEntropyLoss()

# Define the random noise matrix F for error modulation
def generate_random_matrix(input_size: int, output_size: int, scale: float = 0.01) -> torch.Tensor:
    """Return a random matrix of shape ``(output_size, input_size)``.

    Entries are drawn from a standard normal distribution and multiplied by
    ``scale``.

    Args:
        input_size: Number of columns of the result.
        output_size: Number of rows of the result.
        scale: Multiplier applied to the standard-normal samples. The default
            of 0.01 reproduces the value that used to be hard-coded.

    Returns:
        A new tensor of shape ``(output_size, input_size)``.
    """
    return torch.randn(output_size, input_size) * scale

F = generate_random_matrix(28*28, 10)  # random projection from the 10 class errors to the 784 input values; shape (10, 784)

# Training loop: 10 passes over train_loader. Every step runs the model twice,
# first on the clean batch to obtain an error signal, then on the batch
# perturbed by that error projected through F, and finally takes a plain
# gradient-descent step on the loss of the second pass.
for epoch in range(10):
    for batch_idx, (data, target) in enumerate(train_loader):
        # First (clean) forward pass
        output = model(data)

        # Loss of the clean pass. It is never back-propagated; it is only read
        # by the per-epoch print below. Softmax turns the scores into probabilities.
        loss = criterion(output, target)
        softmax_output = torch.softmax(output, dim=1)

        # Error signal: one-hot target minus predicted probabilities, shape (batch, 10)
        e = torch.nn.functional.one_hot(target, num_classes=10).float() - softmax_output

        # Project the error into input space: (batch, 10) @ (10, 784) -> (batch, 784),
        # then reshape to (batch, 1, 28, 28) so it can be added to the input batch.
        # NOTE(review): e is not detached, so delta_x stays attached to the graph of
        # the first forward pass and backward() below also propagates through it.
        # Confirm this is intended.
        delta_x = torch.matmul(e, F).view(-1, 1, 28, 28)

        # Modulated input: original input + delta_x
        # (assumes data is (batch, 1, 28, 28) - TODO confirm)
        modulated_input = data + delta_x

        # Second forward pass, on the modulated input
        modulated_output = model(modulated_input)

        # Loss of the modulated pass; this is the quantity that is differentiated
        modulated_loss = criterion(modulated_output, target)

        # Populate param.grad for every model parameter from the modulated loss
        modulated_loss.backward()

        # Manual gradient descent on the model parameters (the theta and phi phase
        # shifts). no_grad keeps the in-place updates out of the autograd graph.
        with torch.no_grad():
            for param in model.parameters():
                param -= 0.01 * param.grad  # simple gradient descent with learning rate 0.01
                param.grad.zero_()  # reset gradients so they do not accumulate across steps

    # Reports the clean-pass loss of the last batch of the epoch only (not an epoch average)
    print(f'Epoch {epoch + 1}, Loss: {loss.item()}')

print('Training completed.')


Epoch 1, Loss: 0.6813176870346069
Epoch 2, Loss: 1.559756875038147
Epoch 3, Loss: 1.1564719676971436
Epoch 4, Loss: 0.7610306143760681
Epoch 5, Loss: 0.8801256418228149
Epoch 6, Loss: 0.8372388482093811
Epoch 7, Loss: 0.6285146474838257
Epoch 8, Loss: 0.6099349856376648
Epoch 9, Loss: 0.4673929810523987
Epoch 10, Loss: 0.35434451699256897
Training completed.


# Training a Large-Scale ONN with a simple MZI model

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np

# Define an MZI-based linear layer
class MZILayer(nn.Module):
    """Linear map without bias, with weights ``cos(theta + phi)``.

    ``theta`` and ``phi`` are trainable phase matrices of shape
    ``(input_size, output_size)``. The weight matrix is
    ``cos(theta) * cos(phi) - sin(theta) * sin(phi)``, i.e. ``cos(theta + phi)``
    by the angle-sum identity, so each weight is bounded to ``[-1, 1]``.

    Args:
        input_size: Size of the last dimension of the input.
        output_size: Size of the last dimension of the output.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        # Uniform initial phases in [0, 2*pi); theta is sampled first, then phi,
        # so seeded runs keep the same initial parameters.
        self.theta = nn.Parameter(torch.rand(input_size, output_size) * 2 * np.pi)
        self.phi = nn.Parameter(torch.rand(input_size, output_size) * 2 * np.pi)

    def forward(self, x):
        """Multiply ``x`` by the phase-derived weight matrix.

        Args:
            x: Tensor whose last dimension equals ``input_size``.

        Returns:
            ``x @ cos(theta + phi)``, last dimension of size ``output_size``.
        """
        # A single cosine of the summed phases is mathematically identical to
        # the expanded product form and avoids three extra trig evaluations
        # (results agree up to floating-point rounding).
        weight_matrix = torch.cos(self.theta + self.phi)
        return torch.matmul(x, weight_matrix)

# Define the model with MZI layers
class Model(nn.Module):
    """Two banks of 28 small MZILayers followed by a 28 -> 10 MZILayer.

    The input is viewed as ``(batch, 28, 28)``. Slice ``i`` along the last axis
    goes through ``fc_row_1[i]`` (28 -> 28) and then ``fc_row_2[i]`` (28 -> 1),
    each followed by ReLU. The 28 scalars are concatenated and mapped to 10
    outputs by ``linear``.
    """

    def __init__(self):
        super().__init__()
        # First bank: 28 independent 28 -> 28 layers, one per input slice
        self.fc_row_1 = nn.ModuleList([MZILayer(28, 28) for _ in range(28)])
        # Second bank: 28 independent 28 -> 1 layers, one per first-bank output
        self.fc_row_2 = nn.ModuleList([MZILayer(28, 1) for _ in range(28)])
        # Maps the 28 concatenated scalars to 10 class scores
        self.linear = MZILayer(28, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return 10 raw class scores per sample.

        Args:
            x: Tensor that can be viewed as ``(batch, 28, 28)``.

        Returns:
            Tensor of shape ``(batch, 10)``.
        """
        images = x.view(-1, 28, 28)
        # images[:, :, idx] fixes the LAST axis, so each slice is (batch, 28).
        # NOTE(review): the attribute names say "row", but for a
        # (batch, height, width) layout this selects column idx - confirm intent.
        first_stage = [
            self.relu(layer(images[:, :, idx]))
            for idx, layer in enumerate(self.fc_row_1)
        ]
        # Each (batch, 28) feature vector is reduced to (batch, 1)
        second_stage = [
            self.relu(layer(features))
            for layer, features in zip(self.fc_row_2, first_stage)
        ]
        merged = torch.cat(second_stage, dim=1)  # (batch, 28)
        return self.linear(merged)

# MNIST training split stored under ../data (download=True fetches it when
# needed); ToTensor turns each image into a tensor.
train_dataset = datasets.MNIST(
    '../data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)
# Shuffled mini-batches of 64 samples.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

# Network and loss function. No optimizer object is created here.
model = Model()
criterion = nn.CrossEntropyLoss()

# Define the random noise matrix F for error modulation
def generate_random_matrix(input_size: int, output_size: int, scale: float = 0.01) -> torch.Tensor:
    """Build a scaled standard-normal matrix of shape ``(output_size, input_size)``.

    Args:
        input_size: Number of columns of the result.
        output_size: Number of rows of the result.
        scale: Factor multiplied onto the standard-normal samples. Defaults to
            0.01, the value that was previously hard-coded.

    Returns:
        A new tensor of shape ``(output_size, input_size)``.
    """
    return torch.randn(output_size, input_size) * scale

F = generate_random_matrix(28*28, 10)  # random projection from the 10 class errors to the 784 input values; shape (10, 784)

# Training loop with PEPITA: 10 passes over train_loader. Every step runs the
# model twice, first on the clean batch to obtain an error signal, then on the
# batch perturbed by that error projected through F, and finally takes a plain
# gradient-descent step on the loss of the second pass.
for epoch in range(10):
    for batch_idx, (data, target) in enumerate(train_loader):
        # First (clean) forward pass
        output = model(data)

        # Loss of the clean pass. It is never back-propagated; it is only read
        # by the per-epoch print below. Softmax turns the scores into probabilities.
        loss = criterion(output, target)
        softmax_output = torch.softmax(output, dim=1)

        # Error signal: one-hot target minus predicted probabilities, shape (batch, 10)
        e = torch.nn.functional.one_hot(target, num_classes=10).float() - softmax_output

        # Project the error into input space: (batch, 10) @ (10, 784) -> (batch, 784),
        # then reshape to (batch, 1, 28, 28) so it can be added to the input batch.
        # NOTE(review): e is not detached, so delta_x stays attached to the graph of
        # the first forward pass and backward() below also propagates through it.
        # Confirm this is intended.
        delta_x = torch.matmul(e, F).view(-1, 1, 28, 28)

        # Modulated input: original input + delta_x
        # (assumes data is (batch, 1, 28, 28) - TODO confirm)
        modulated_input = data + delta_x

        # Second forward pass, on the modulated input
        modulated_output = model(modulated_input)

        # Loss of the modulated pass; this is the quantity that is differentiated
        modulated_loss = criterion(modulated_output, target)

        # Populate param.grad for every model parameter from the modulated loss
        modulated_loss.backward()

        # Manual gradient descent on the model parameters (the theta and phi phase
        # shifts). no_grad keeps the in-place updates out of the autograd graph.
        with torch.no_grad():
            for param in model.parameters():
                param -= 0.01 * param.grad  # simple gradient descent with learning rate 0.01
                param.grad.zero_()  # reset gradients so they do not accumulate across steps

    # Reports the clean-pass loss of the last batch of the epoch only (not an epoch average)
    print(f'Epoch {epoch + 1}, Loss: {loss.item()}')

print('Training completed.')


Epoch 1, Loss: 1.5184834003448486
Epoch 2, Loss: 1.4888489246368408
Epoch 3, Loss: 1.5177112817764282
Epoch 4, Loss: 1.304443120956421
