# Initialization

## Imports

In [None]:
# Python
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

# Torch
import torch

# PhotonTorch
import photontorch as pt

# Progress Bars
from tqdm import tqdm

# Michelson Interferometer Cavity

## Schematic
![michelson interferometer](images/michelson.png)

## Simulation and Design Parameters

In [None]:
# Reproducibility: fix numpy's global RNG state before any random numbers are drawn.
np.random.seed(0)

# Batching
num_batches = 404  # number of parallel simulations to perform

# Optical parameters
neff = np.sqrt(12.1)  # passed as `neff` to every waveguide of the network
wl = 1.55e-6  # passed as `wl` to every simulation environment

# Time axis: samples spaced `dt` apart, starting at 0 and stopping before `total_time`.
total_time = 2e-6
dt = 0.5e-9
time = np.arange(0, total_time, dt)

# Network

We use the same network as in the [previous notebook](04_train_output.ipynb). This time, however, the network can be moved to the GPU with the `.to("cuda")` method.

In [None]:
# define network in the standard way:
class MichelsonCavity(pt.Network):
    """Michelson interferometer cavity assembled from Photontorch components.

    The network holds one source (``west``), three detectors (``north``,
    ``east``, ``south``), four mirrors with ``R=0.9``, four waveguides of
    different lengths and a single directional coupler (``coupling=0.5``).
    Two ``link`` chains wire them together: one running from ``west`` to
    ``east`` and one from ``north`` to ``south``, both through ``dc``.

    The waveguides read the module-level ``neff`` at construction time.
    """

    def __init__(self):
        super().__init__(copy_components=True)

        # Terminations. A single Detector / Mirror instance is bound to several
        # attribute names; `copy_components=True` is passed to the base class,
        # presumably so that each name ends up with its own copy (TODO confirm).
        self.west = pt.Source()
        self.north = self.east = self.south = pt.Detector()
        self.m_west = self.m_north = self.m_east = self.m_south = pt.Mirror(R=0.9)

        # One waveguide per arm, registered in west/north/east/south order.
        arm_lengths = (("west", 0.43), ("north", 0.60), ("east", 0.95), ("south", 1.12))
        for arm, length in arm_lengths:
            setattr(self, "wg_" + arm, pt.Waveguide(length, neff=neff))

        self.dc = pt.DirectionalCoupler(coupling=0.5)

        # west -> east chain
        self.link(
            'west:0',
            '0:m_west:1',
            '0:wg_west:1',
            '0:dc:2',
            '0:wg_east:1',
            '0:m_east:1',
            '0:east',
        )
        # north -> south chain
        self.link(
            'north:0',
            '0:m_north:1',
            '0:wg_north:1',
            '1:dc:3',
            '0:wg_south:1',
            '0:m_south:1',
            '0:south',
        )
    
# create network; run it on the GPU when CUDA is available, otherwise fall back
# to the CPU so the notebook still executes (just more slowly) on any machine
nw = MichelsonCavity().to("cuda" if torch.cuda.is_available() else "cpu")

# Simulation

Another strength of Photontorch is the massive parallelism that can be achieved by running multiple simulations at once (called batches):

In [None]:
# one random source weight per parallel simulation
batch_weights = np.random.random(num_batches)

with pt.Environment(wl=wl, t=time, num_batches=num_batches):
    # run all batches at once
    detected = nw(source=batch_weights)
    # keep all timesteps, the only wavelength, all detectors, all batches
    detected = detected[:, 0, :, :]
    # plot second and fourth batch
    nw.plot(detected[:, :, [1, 3]]);

# Training

Training Parameters:

In [None]:
# optimisation hyper-parameters
num_epochs = 2
learning_rate = 0.2

# loss: mean squared error (applied to detected vs. target in the training loop)
lossfunc = torch.nn.MSELoss()

# Adam over every parameter the network exposes
optimizer = torch.optim.Adam(params=nw.parameters(), lr=learning_rate)

# training environment
env = pt.Environment(
    wl=wl,
    t=time,
    num_batches=num_batches,
)

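We would like to train the network to reach a different steady state, in which every detector reads the same output (the mean over the detectors of the current steady-state output, computed separately for each batch):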

In [None]:
# Last simulated timestep of the previous run, taken from `.data` so it carries
# no autograd history. After the earlier `[:,0,:,:]` selection its axes are
# (detectors, batches).
final_state = detected.data[-1]

# Summed output over all detectors and batches at the last timestep. Only that
# timestep is copied to host memory instead of the whole time trace.
total_power_out = final_state.cpu().numpy().sum()

# Target: every detector should read the per-batch mean over the detectors.
# The mean row is repeated once per detector (no hard-coded detector count),
# and `torch.cat` already returns a fresh tensor, so no `torch.tensor(...)`
# copy-construction (which warns when given a tensor) is needed.
detector_mean = final_state.mean(0, keepdim=True)
target = torch.cat([detector_mean] * final_state.shape[0], dim=0).to(nw.device)

del detected, final_state, detector_mean # Free up GPU memory

Train (CUDA is recommended here...):

In [None]:
# Running speed without cuda: 27s/it
# Running speed with cuda: 3.5s/it
# Loop over the training cycles. The environment uses the same keywords as the
# simulation cells (`wl`, `t`, `num_batches`) so the detected output has the
# same layout as `target`; `enable_grad=True` is added so that gradients can be
# back-propagated through the simulation.
with pt.Environment(wl=wl, t=time, num_batches=num_batches, enable_grad=True):
    for epoch in tqdm(range(num_epochs)):
        optimizer.zero_grad() # clear gradients left over from the previous epoch
        detected = nw(source=batch_weights)[-1,0,:,:] # get the last timestep, the only wavelength, all detectors, all batches
        loss = lossfunc(detected, target) # calculate the loss (error) between detected and target
        loss.backward() # calculate the resulting gradients for all the parameters of the network
        optimizer.step() # update the networks parameters with the gradients
        del detected, loss # free up memory (important for GPU)

Do a final simulation:

In [None]:
# Re-simulate the network after training, using the same wavelength, time axis
# and number of batches as the first simulation.
with pt.Environment(wl=wl, t=time, num_batches=num_batches):
    # full, un-sliced result of the forward pass
    detected = nw(source=batch_weights)
    # select index 0 on the second axis (the only wavelength) and entries 1 and 3
    # on the last axis (batches), keeping all timesteps and all detectors
    nw.plot(detected[:,0,:,[1,3]]); #plot second and fourth batch