In [8]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from complexPyTorch.complexLayers import ComplexBatchNorm2d, ComplexConv2d, ComplexLinear
from complexPyTorch.complexFunctions import complex_relu, complex_max_pool2d
from utils import *

In [9]:
def random_complex_number(r, c):
    """Return an (r, c) complex128 array with real and imaginary parts drawn from U[0, 1).

    Draws from numpy's global RNG: the real part first, then the imaginary part.
    """
    real_part = np.random.rand(r, c)
    imag_part = np.random.rand(r, c)
    return real_part + imag_part * 1j

In [10]:
# Seed numpy's global RNG so the synthetic dataset (and every downstream
# number) is reproducible under Restart Kernel -> Run All.
np.random.seed(0)
N_SAMPLES = 100

# Synthetic regression target y = spatial**2 + time, computed in complex128
# before casting everything to complex64 for the model.
# NOTE: `time` shadows the stdlib module name; kept because later cells use it.
spatial = random_complex_number(N_SAMPLES, 1)
time = random_complex_number(N_SAMPLES, 1)
y = spatial**2 + time
# requires_grad on the inputs allows differentiating predictions w.r.t. them later.
spatial = torch.tensor(spatial, dtype=torch.complex64).requires_grad_(True)
time = torch.tensor(time, dtype=torch.complex64).requires_grad_(True)
y = torch.tensor(y, dtype=torch.complex64).reshape(-1, 1)

In [11]:
def cat(v1, v2):
    """Concatenate two tensors along their last dimension."""
    pair = (v1, v2)
    return torch.cat(pair, dim=-1)

class ComplexNet(nn.Module):
    """Two-layer complex-valued MLP: ComplexLinear -> tanh -> ComplexLinear.

    Args:
        in_features: size of the last input dimension.
        hidden_features: width of the hidden layer.
        out_features: size of the last output dimension.
    """

    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        self.linear1 = ComplexLinear(in_features, hidden_features)
        self.linear2 = ComplexLinear(hidden_features, out_features)

    def forward(self, x):
        """Apply linear1, then elementwise torch.tanh, then linear2."""
        # NOTE(review): torch.tanh on complex input has poles at i*pi/2 + k*pi*i;
        # fine for small pre-activations, but worth keeping in mind.
        hidden = torch.tanh(self.linear1(x))
        return self.linear2(hidden)

In [12]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(0)  # reproducible weight initialisation
model = ComplexNet(2, 50, 1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
epochs = 15000

# The network input never changes: build it once, on the model's device
# (otherwise CUDA runs fail with a device mismatch). detach() so backward()
# does not also compute and accumulate gradients into spatial.grad /
# time.grad every epoch -- only the model weights are being trained.
train_inputs = torch.cat([spatial, time], dim=-1).detach().to(device)
train_y = y.to(device)
# mean(|pred - y|**2), written as an MSE of the residual magnitude vs. zeros.
target = torch.zeros(y.shape, device=device)

model.train()
for i in range(epochs):
    optimizer.zero_grad()
    residual = model(train_inputs) - train_y
    loss = F.mse_loss(residual.abs(), target, reduction='mean')
    loss.backward()
    if i % 5000 == 0:
        print(loss.item())
    optimizer.step()

3.2557153701782227
4.236429958837107e-06
2.091291435135645e-06


In [18]:
model.eval()
# The input must live on the model's device. `time` itself stays the original
# leaf tensor: .to() is differentiable (and a no-op on CPU), so gradients
# w.r.t. `time` can still be taken.
predictions = model(cat(spatial, time).to(device))
# Check d(prediction)/d(time): the target is y = spatial**2 + time, so the
# real part of the derivative should be close to 1 everywhere.
# NOTE(review): `diff` comes from utils (not visible here); the [0] index
# suggests it returns a tuple like torch.autograd.grad -- confirm.
dpred_dtime = torch.real(diff(predictions, time)[0])
((dpred_dtime - torch.ones_like(dpred_dtime))**2).mean()

tensor(3.4451e-05, grad_fn=<MeanBackward0>)