In [3]:
import torch
import torch.nn as nn
import photontorch as pt
from torchvision import datasets
import torchvision.transforms as transforms
from deep_learning_utils import DLModule

In [4]:
import torch.nn as nn
import torch
from ring_net import RingNet

class Model(nn.Module):
    """Hybrid classifier: an electrical front-end followed by an optical ``RingNet``.

    The module owns two stages:

    * electrical front-end: ``Flatten`` -> ``Linear(1 * 28 * 28, 8)`` -> ``ReLU``
    * optical back-end: ``RingNet`` driven on 8 wavelengths

    ``forward`` accepts either image-like input with more than two dimensions
    (for example MNIST batches of shape ``(batch, 1, 28, 28)``), which is first
    reduced to 8 features per sample by the front-end, or a 2-D tensor of shape
    ``(batch, 8)`` that already holds the 8 features and goes straight to the
    optical layer.
    """

    def __init__(self):
        super(Model, self).__init__()
        # Electrical front-end (applied by forward() to image-like input only).
        self.layer_flat = torch.nn.Flatten()  # output: (batch, 1 * 28 * 28) for 28x28 single-channel images
        self.layer_el = torch.nn.Linear(1 * 28 * 28, 8)  # output: (batch, 8)
        self.layer_relu = torch.nn.ReLU()
        # Optical back-end.
        # NOTE(review): the positional arguments 10 and 8 presumably are the number
        # of sources/outputs and of wavelengths -- confirm against ring_net.RingNet.
        self.layer_ol = RingNet(10, 8, mode=0, wavelength_list=[1.3e-6, 1.35e-6, 1.4e-6, 1.45e-6, 1.5e-6, 1.55e-6, 1.6e-6, 1.65e-6])

    def forward(self, x):
        """Run the optical layer on 8 features per sample.

        The 8 values of every sample are carried by 8 wavelengths, multiplexed,
        and injected through the same light-source port.

        Args:
            x: either a ``(batch, 8)`` tensor of features, or an image-like
               tensor with more than two dimensions that the electrical
               front-end can flatten to ``(batch, 1 * 28 * 28)``.

        Returns:
            The optical layer's response summed over the wavelength dimension
            and transposed; for a ``(1, 8)`` input this is a ``(1, 10)`` tensor.
        """
        # Inputs with more than two dimensions could never pass the 3-name
        # rename below (it needs exactly three dimensions after stacking), so
        # sending them through the electrical front-end first only affects calls
        # that used to raise. 2-D input keeps its original path.
        if x.dim() > 2:
            x = self.layer_relu(self.layer_el(self.layer_flat(x)))

        # (batch, 8) -> 8 chunks of (batch, 1) -> (8, batch, 1) -> (8, batch)
        x = torch.chunk(x, 8, dim=-1)
        x = torch.stack(x, dim=0)
        x = torch.squeeze(x, dim=-1)
        # Replicate the signal 10 times along a new leading dimension and name
        # the dimensions 's', 'w', 'b' (presumably source, wavelength, batch).
        x = torch.stack([x] * 10, dim=0).rename('s', 'w', 'b')
        # Keep only the last entry along the first dimension of the 4-D result
        # (presumably the final time step -- confirm against RingNet).
        x = self.layer_ol(source=x)[-1, :, :, :]
        x = torch.sum(x, dim=0)  # sum over the wavelength dimension
        x = torch.transpose(x, 0, 1)
        return x

In [7]:
# Simulation environment shared by all photontorch components: the same eight
# wavelengths the model's optical layer is built with (values presumably in
# metres, i.e. 1.30 um .. 1.65 um -- confirm), a 0 .. 1e-11 time window with
# 1e-12 steps, gradients enabled, frequency-domain mode on.
wavelength_list = [
    1.3e-6, 1.35e-6, 1.4e-6, 1.45e-6,
    1.5e-6, 1.55e-6, 1.6e-6, 1.65e-6,
]
env = pt.Environment(
    t_start=0,
    t_end=1e-11,
    dt=1e-12,
    wl=wavelength_list,
    grad=True,
    freqdomain=True,
)
pt.set_environment(env)

In [20]:
# Build the network
model = Model()

# Hyper-parameters
train_batch, test_batch = 16, 5
loss_fn = nn.CrossEntropyLoss()
optim = torch.optim.Adam(params=model.parameters(), lr=0.005)

# Datasets: MNIST converted to tensors, stored under the current directory
trainset = datasets.MNIST(
    root="./",
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
testset = datasets.MNIST(
    root="./",
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)
# First 160 training samples / first 10 test samples.
# NOTE(review): these subsets are created here, but the DLModule below is
# given the full trainset/testset -- confirm which one is intended.
subtrainset = torch.utils.data.Subset(trainset, range(160))
subtestset = torch.utils.data.Subset(testset, range(10))

# Training / evaluation helper
dl = DLModule(
    model=model,
    loss_fn=loss_fn,
    optim=optim,
    train_set=trainset,
    test_set=testset,
    train_batch=train_batch,
    test_batch=test_batch,
)
best_accuracy = 0

In [8]:
# Single toy sample: one row of 8 feature values, shape (1, 8) = (batch, features).
# torch.tensor() takes exactly one data argument, so the values have to be
# wrapped in a list; the extra outer list adds the batch dimension that
# Model.forward needs before it splits the last dimension into 8 chunks.
input = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])
# Target class index for that sample, shape (1,).
label = torch.tensor([5])

In [9]:
# Start again from a freshly initialised model, switched to training mode
# (Module.train() returns the module itself), with its own loss and optimizer.
model = Model().train()
loss_fn = nn.CrossEntropyLoss()
optim = torch.optim.Adam(params=model.parameters(), lr=0.005)


In [10]:
# Fit the model to the single (input, label) pair for 20 optimizer steps,
# printing the loss after every step.
total_train_loss, total_train_step = 0, 0
for i in range(20):
    # Forward pass and loss
    predict = model(input)
    loss = loss_fn(predict, label)

    # Backward pass and parameter update
    optim.zero_grad()
    loss.backward()
    optim.step()

    # Bookkeeping
    total_train_loss += loss.item()
    total_train_step += 1
    print(f"Train steps : {total_train_step}",  f"Loss : {loss.item()}")

# Model output for the same sample after training
print(model(input))

torch.linalg.solve has its arguments reversed and does not return the LU factorization.
To get the LU factorization see torch.lu, which can be used with torch.lu_solve or torch.lu_unpack.
X = torch.solve(B, A).solution
should be replaced with
X = torch.linalg.solve(A, B) (Triggered internally at  ../aten/src/ATen/native/BatchLinearAlgebra.cpp:760.)
  ret = func(*args, **kwargs)


In [18]:
# Run the sample through the model once more and show the raw scores
# followed by the index of the highest score in each row.
out = model(input)
print(out)
print(torch.argmax(out, dim=1))

tensor([[-115.3332, -115.3332, -115.3332, -115.3332, -115.3331,  -83.4024,
         -115.3332, -115.3332, -115.3332, -115.3332]], grad_fn=<SubBackward0>)
tensor([5])
