In [1]:
import torch

In [3]:
import torch.nn.functional as F

class SmallCNN(torch.nn.Module):
    """Toy convolutional encoder: a single 2-D convolution that keeps the spatial size.

    Maps ``(batch, in_channels, H, W)`` to ``(batch, out_channels, H, W)``, so every
    pixel gets a latent vector of length ``out_channels``. The defaults (3 -> 3
    channels, 3x3 kernel, padding 1) reproduce the original fixed layer.

    Args:
        in_channels: Channels of the input image (default 3).
        out_channels: Channels of the per-pixel latent map (default 3).
        kernel_size: Square kernel size. Padding is ``kernel_size // 2``, which keeps
            H and W unchanged for odd kernel sizes.
    """

    def __init__(self, in_channels=3, out_channels=3, kernel_size=3):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2
        )

    def forward(self, x):
        """Apply the convolution to ``x`` of shape ``(batch, in_channels, H, W)``."""
        return self.conv1(x)


smallCNN = SmallCNN()

In [4]:
class SmallLinear(torch.nn.Module):
    """Three-layer MLP with tanh activations between the layers.

    Layer sizes are ``in_features -> hidden1 -> hidden2 -> out_features``. The defaults
    (5 -> 5 -> 4 -> 3) reproduce the original network, whose 5 inputs are the per-point
    feature rows built in this notebook (x, y and a 3-channel latent vector). The last
    layer has no activation.

    Args:
        in_features: Size of each input row (default 5).
        hidden1: Width of the first hidden layer (default 5).
        hidden2: Width of the second hidden layer (default 4).
        out_features: Size of each output row (default 3).
    """

    def __init__(self, in_features=5, hidden1=5, hidden2=4, out_features=3):
        super().__init__()
        self.lin1 = torch.nn.Linear(in_features, hidden1)
        self.lin2 = torch.nn.Linear(hidden1, hidden2)
        self.lin3 = torch.nn.Linear(hidden2, out_features)
        # One parameter-free Tanh module is reused after lin1 and lin2.
        self.activation = torch.nn.Tanh()

    def forward(self, x):
        """Map ``x`` of shape ``(..., in_features)`` to ``(..., out_features)``."""
        x = self.activation(self.lin1(x))
        x = self.activation(self.lin2(x))
        return self.lin3(x)


smallLinear = SmallLinear()

In [5]:
# MSE loss; only referenced by the commented-out reconstruction loss below.
loss_fct = torch.nn.MSELoss()

In [94]:
# Random stand-in batch of shape (batch_size, 3, reshape_size, reshape_size).
batch_size = 20
reshape_size = 8
input_uvp = torch.rand((batch_size, 3, reshape_size, reshape_size))

In [95]:
# Per-pixel latent map; SmallCNN keeps the spatial size, so this is (20, 3, 8, 8).
latent_vectors = smallCNN(input_uvp)

In [96]:
# loss = loss_fct(input_uvp, latent_vectors)

In [97]:
# Idea: also feed the boundary conditions to the linear network, in the hope that the
# boundary information is conveyed and the model can solve the same equation for
# various boundary conditions.
# - Geometric boundary condition: a signed distance field.
# - Velocity boundary condition: a loss on the velocity at the inlet, outlet, and walls.

In [99]:
# Suppose each image comes with a segmentation map labelling its pixels:
# 0 = background, 1 = boundary, 2 = interior point.
map_square = torch.randint(0, 3, (batch_size, reshape_size, reshape_size)) # segmentation map for the image pixels

In [100]:
# Move channels last -> (batch, H, W, 3), so a (batch, H, W) boolean mask selects whole
# latent vectors and returns (N_points, 3) rows.
latent_vectors_perm = latent_vectors.permute(0, 2, 3, 1)
interior_points = latent_vectors_perm[map_square==2] # all interior points from the batch
boundary_points = latent_vectors_perm[map_square==1] # all boundary points from the batch

In [101]:
# Evenly spaced coordinates in [-1, 1], one per pixel along each axis.
xs_for_grid = torch.linspace(-1, 1, reshape_size)
ys_for_grid = torch.linspace(-1, 1, reshape_size)

In [102]:
# (H, W) coordinate grids. indexing="ij" is the current default, but calling meshgrid
# without it is deprecated and warns; passing it explicitly pins the behavior.
x_physical, y_physical = torch.meshgrid([xs_for_grid, ys_for_grid], indexing="ij")

In [103]:
# xs = [torch.cat([outs[i, :, i, i], x[:, 0], y[:, 0]]) for i in range(5)]

In [104]:
# Tile the (H, W) coordinate grids over the batch -> (batch_size, H, W), aligned with map_square.
x_physical_batch  = x_physical.repeat(batch_size, 1, 1)
y_physical_batch  = y_physical.repeat(batch_size, 1, 1)

In [105]:
x_physical_batch.shape

torch.Size([20, 8, 8])

In [106]:
# Coordinates of the same masked pixels as interior_points / boundary_points, as (N, 1)
# columns in the same row order (same mask, same batch/H/W ordering).
x_interior_points = x_physical_batch[map_square==2].view(-1, 1)
x_boundary_points = x_physical_batch[map_square==1].view(-1, 1)
y_interior_points = y_physical_batch[map_square==2].view(-1, 1)
y_boundary_points = y_physical_batch[map_square==1].view(-1, 1)

In [107]:
# Feature vector of one interior point in the order [latent (3), x, y].
# NOTE(review): the next cell uses the opposite order [x, y, latent].
i = 0
torch.cat([interior_points[i], x_interior_points[i], y_interior_points[i]])

tensor([-0.1533,  0.2114,  0.3398, -1.0000, -0.7143], grad_fn=<CatBackward0>)

In [108]:
# Per-point network input: column 0 = x, column 1 = y, columns 2-4 = latent vector.
input_features = torch.cat([x_interior_points, y_interior_points, interior_points], dim=1)


In [109]:
input_features.shape

torch.Size([446, 5])

In [110]:
# Split input_features into one (N, 1) view per column and re-join them, so x (inputs[0])
# and y (inputs[1]) are separate autograd nodes that x_ is computed from.
inputs = [input_features[..., i:i+1] for i in range(input_features.shape[-1])]
# for xx in inputs:
#     if not xx.requires_grad:
#         xx.requires_grad = True
x_ = torch.cat(inputs, axis=-1)

In [111]:
# NOTE(review): computed from input_features rather than x_, so it is not connected to
# inputs[...]; it is overwritten by the toy output a few cells below.
outputs = smallLinear(input_features)

In [112]:
# inputs[0].shape

In [113]:
# outputs[:,0].sum()

In [114]:
# x_.shape

In [115]:
# inputs = input_features
# x_ = input_features
# Toy output u = x**2 (column 0 of x_), shape (N, 1), to sanity-check the derivative code.
outputs = x_[:,[0]]**2

In [116]:
outputs.shape

torch.Size([446, 1])

In [117]:
# du/dx and du/dy via autograd (summing u is equivalent to grad_outputs=ones);
# create_graph=True keeps the graph so they can be differentiated again.
dudx, dudy = torch.autograd.grad(outputs[:,0].sum(), [inputs[0], inputs[1]], retain_graph=True, create_graph=True)
# dvdx, dvdy = torch.autograd.grad(outputs[:,1].sum(), [inputs[0], inputs[1]], retain_graph=True, create_graph=True)
# NOTE: these two are 1-tuples (no [0] indexing), unlike dudx/dudy above.
d2udx2 = torch.autograd.grad(dudx.sum(), inputs[0], retain_graph=True, create_graph=True)
d2udy2 = torch.autograd.grad(dudy.sum(), inputs[1], retain_graph=True, create_graph=True)

In [124]:
# NOTE(review): this penalises the first derivatives; a Laplace-type residual would use
# d2udx2[0] + d2udy2[0] instead — confirm which was intended.
loss = torch.mean((dudx + dudy)**2)

In [125]:
# retain_graph=True: later cells call torch.autograd.grad on `outputs` again. A plain
# backward() frees that shared graph, so on Restart & Run All they fail with
# "Trying to backward through the graph a second time".
loss.backward(retain_graph=True)

In [120]:
# dudxv2, dudyv2 = torch.autograd.grad(outputs[:,0], [inputs[0], inputs[1]], grad_outputs=torch.ones_like(inputs[0].squeeze(-1)), retain_graph=True, create_graph=True)


In [121]:
# Same first derivatives as the In [117] cell, using grad_outputs=ones instead of summing.
# The trailing .view(-1) is a no-op because outputs[:, 0] is already 1-D. This needs the graph
# of `outputs` to still be alive (a preceding loss.backward() frees it unless retain_graph=True).
dudxv2, dudyv2 = torch.autograd.grad(outputs[:, 0], [inputs[0], inputs[1]], grad_outputs=torch.ones_like(outputs[:, 0]).view(-1), retain_graph=True, create_graph=True)


In [122]:
# NOTE: functorch's grad is a function transform (grad(f)(x)), not torch.autograd.grad;
# functorch has since been folded into torch.func.
from functorch import grad

In [126]:
# grad(smallLinear)(input_features)

In [None]:
def _grad_or_zeros(output, inputs):
    """Return the VJP of ``output`` with an all-ones vector w.r.t. each input, using zeros where there is no dependence."""
    if not output.requires_grad:
        return [torch.zeros_like(inp) for inp in inputs]
    grads = torch.autograd.grad(
        output,
        inputs,
        grad_outputs=torch.ones_like(output),
        retain_graph=True,
        create_graph=True,
        allow_unused=True,
    )
    return [torch.zeros_like(inp) if g is None else g for g, inp in zip(grads, inputs)]


def take_grad(outs, ins):
    """First and second derivatives of ``outs[0]`` w.r.t. ``ins[0]`` (x) and ``ins[1]`` (y).

    Args:
        outs: Sequence (or tensor) whose first element ``u`` is the output to differentiate.
        ins: Sequence whose first two elements ``x`` and ``y`` are tensors ``u`` was computed
            from, i.e. already part of ``u``'s autograd graph.

    Returns:
        ``(dudx, dudy, d2udx2, d2udy2)``, each shaped like its input. Each derivative is a
        vector-Jacobian product with an all-ones vector, so for pointwise ``u`` it is the
        per-element derivative. A derivative with no dependence (e.g. d2u/dy2 of
        ``x**2 * y``) is returned as zeros instead of raising.

    The previous version printed debug flags and assigned ``requires_grad = True`` on its
    arguments. That assignment raises for computed (non-leaf) tensors such as any model output.
    On leaves it cannot attach them to an existing graph, so it has been removed.
    NOTE(review): torch.autograd.grad generally cannot run under torch.vmap; call this on the
    whole batch, or use torch.func.grad / jacrev for per-sample derivatives.
    """
    u, x, y = outs[0], ins[0], ins[1]
    dudx, dudy = _grad_or_zeros(u, [x, y])
    (d2udx2,) = _grad_or_zeros(dudx, [x])
    (d2udy2,) = _grad_or_zeros(dudy, [y])
    return dudx, dudy, d2udx2, d2udy2


In [None]:
# NOTE(review): torch.autograd.grad (called inside take_grad) generally cannot be transformed
# by vmap; use torch.func.grad / jacrev for per-sample derivatives, or call take_grad on the
# whole batch. Expect this cell to raise.
torch.vmap(take_grad)(outputs, x_)

In [None]:
# Derivatives of the first row of outputs only; with outputs = x_[:, [0]]**2, only row 0
# of dudx can be non-zero.
dudx, dudy = torch.autograd.grad(outputs[0], [inputs[0], inputs[1]], grad_outputs=torch.ones_like(outputs[0]), retain_graph=True, create_graph=True)


In [None]:
# Same as the previous cell, plus the second derivative w.r.t. inputs[0] (returned as a 1-tuple).
dudx, dudy = torch.autograd.grad(outputs[0], [inputs[0], inputs[1]], grad_outputs=torch.ones_like(outputs[0]), retain_graph=True, create_graph=True)
d2udx2 = torch.autograd.grad(dudx, inputs[0], grad_outputs=torch.ones_like(dudx), retain_graph=True, create_graph=True)

In [None]:
d2udx2[0].shape

In [None]:
# Random permutations of 0..4 as (1, 5) rows; .view makes them non-leaf views of leaf
# tensors that require grad.
x = torch.randperm(5, dtype=torch.float32, requires_grad=True).view(-1, 5)
y = torch.randperm(5, dtype=torch.float32, requires_grad=True).view(-1, 5)

In [None]:
# The original referenced `outs`, which is never defined in this notebook (NameError on a
# fresh kernel). The (batch, 3, H, W) CNN output `latent_vectors` is assumed here:
# latent_vectors[i, :, i, i] is a 3-vector, and with x[:, 0] and y[:, 0] each row has the
# 5 features SmallLinear expects, in the order [latent (3), x, y]. TODO confirm the intended tensor.
input_features = [torch.cat([latent_vectors[i, :, i, i], x[:, 0], y[:, 0]]) for i in range(5)]

In [None]:
# Stack the five per-point feature vectors into a (5, 5) matrix, one row per point.
input_features = torch.vstack(input_features)

In [None]:
input_features.shape

In [None]:
# lin_out = smallLinear(input_features[0])

In [None]:
# lin_out.shape

In [None]:
# lin_out, lin_in

In [None]:
# Differentiate the network output for a single point w.r.t. its individual features.
lin_in = input_features[0]
# One (1, 1) tensor per feature, so each is its own autograd node; x_ re-joins them into the
# (1, 5) row the network is evaluated on.
inputs = [x.view(-1, 1) for x in lin_in]
print(inputs)
x_ = torch.cat(inputs, axis=-1)
lin_out = smallLinear(x_)
# lin_out = x_**2 + torch.sum(x_)**2
print(lin_out.shape)
# Split the (1, 3) network output into one (1, 1) column per output.
outputs = [lin_out[..., i:i+1] for i in range(lin_out.shape[-1])]
# lin_out = torch.sum(lin_in)
# NOTE(review): with the feature order [latent (3), x, y] used to build input_features,
# inputs[0]/inputs[1] are latent channels and x/y are inputs[3]/inputs[4] — confirm.
dudx, dudy = torch.autograd.grad(outputs[0], [inputs[0], inputs[1]], grad_outputs=torch.ones_like(outputs[0]), retain_graph=True, create_graph=True)
d2udx2 = torch.autograd.grad(dudx, inputs[0], grad_outputs=torch.ones_like(dudx), retain_graph=True, create_graph=True)
# dvd = torch.autograd.grad(lin_out[1], lin_in, grad_outputs=torch.ones_like(lin_out[1]), retain_graph=True, create_graph=True)

# d2ud2 = torch.autograd.grad(dud[0][0], lin_in, grad_outputs=torch.ones_like(dud[0][0]), retain_graph=True, create_graph=True)


In [None]:
d2udx2

In [None]:
# One autograd node per feature COLUMN. The original iterated over shape[0] (rows), which only
# matched the column count because input_features happens to be 5x5. It also sliced a lin_out
# from an earlier single-row forward pass that these column views are not part of. Rebuild x_
# from the views and rerun the network, so every entry of `outputs` depends on `inputs`.
inputs = [input_features[..., i:i+1] for i in range(input_features.shape[-1])]
x_ = torch.cat(inputs, dim=-1)
lin_out = smallLinear(x_)
outputs = [lin_out[..., i:i+1] for i in range(lin_out.shape[-1])]

In [None]:
# torch_diff = lambda y, x: grad(y, x, grad_outputs=torch.ones_like(y), create_graph=True,
#                                allow_unused=True)[0]

In [None]:
# Differentiate w.r.t. both inputs so the result unpacks into two names. A single input
# returns a 1-tuple, which made the original two-name unpack raise ValueError.
dudx, dudy = torch.autograd.grad(outputs[0], [inputs[0], inputs[1]], grad_outputs=torch.ones_like(outputs[0]), retain_graph=True, create_graph=True)



In [None]:
# NOTE(review): shadowed by the functorch import in the next cell; keep a single vmap import.
from torch import vmap

In [None]:
from functorch import vmap
# Sanity check: vmap of an elementwise op matches applying it directly.
x = torch.randn(3)
y = vmap(torch.sin)(x)
assert torch.allclose(y, x.sin())

In [None]:
from functorch import jacrev, vmap

In [None]:
# Per-row Jacobians of the elementwise sin: one (5, 5) diagonal matrix for each of the 64 rows.
x = torch.randn(64, 5)
jacobian = vmap(jacrev(torch.sin))(x)
assert jacobian.shape == (64, 5, 5)

In [None]:
# x = torch.rand(4, 3, 2, 2)
# Four 2x2 single-channel "images" with values 0..15, flattened to (4, 4) rows.
x = torch.arange(4*1*2*2, dtype=torch.float32).reshape((4, 1, 2, 2))
x_fl = x.view(4, -1)

In [None]:
x_fl.shape

In [None]:
torch.square(x_fl).shape

In [None]:
# jacobian_fl = vmap(jacrev(torch.sin))(x_fl)
# Per-row Jacobian of the elementwise square: (4, 4, 4), diagonal entries 2 * x_fl.
jacobian_fl = vmap(jacrev(torch.square))(x_fl)

In [None]:
jacobian_fl.shape

In [None]:
# One 4x4 Jacobian per sample of x: (4, 1, 4, 4).
jacobian = jacobian_fl.reshape(4, 1, 4, 4)

In [None]:
# NOTE(review): repeats the shape check above; jacobian.shape was probably intended.
jacobian_fl.shape

In [None]:
# torch.diag(torch.cos(x[0, 0]))
# torch.cos(x[0, 0])

In [None]:
# The diagonal of each 4x4 Jacobian is d(x**2)/dx = 2x per pixel; reshape it back to 2x2.
vals = vmap(lambda x: torch.reshape(torch.diag(x.squeeze()), (2,2)))(jacobian)
vals

In [None]:
torch.diag(jacobian[0].squeeze()).reshape(2, 2)

In [None]:
torch.diag(jacobian[1].squeeze()).reshape(2, 2)

In [None]:
x.shape

In [None]:
# NOTE(review): xtst is redefined as a (1, 1, 2, 2) tensor that requires grad a few cells below.
xtst = x[0, 0]

In [None]:
# xtst = torch.rand(32, 32)

In [None]:
import functorch

In [None]:
# NOTE(review): duplicate import; F is already imported in the SmallCNN cell.
import torch.nn.functional as F

In [None]:
# One-hot basis (int64) of the 4 pixels of a 2x2 image; after the view below, dims[k]
# selects pixel k. Used as batched grad_outputs to extract per-pixel derivatives.
dims = F.one_hot(torch.arange(2*2))

In [None]:
dims = dims.view((-1, 2, 2))

In [None]:
dims.shape

In [None]:
# x.requires_grad = True
# xtst = torch.rand((1, 1, 32, 32), requires_grad=True)
# ytst = torch.rand((1, 1, 32, 32), requires_grad=True)

# 2x2 test fields shaped (1, 1, 2, 2); .view makes them non-leaf views that require grad.
xtst = torch.arange(2*2, dtype=torch.float32, requires_grad=True).view(1, 1, 2, 2)
ytst = torch.arange(2*2, dtype=torch.float32, requires_grad=True).view(1, 1, 2, 2)

# Toy field: each pixel squared, plus the top-left pixel (broadcast from (1, 1)), plus ytst.
y = xtst**2 + xtst[:, :, 0, 0] + ytst #+ xtst[:, :, 0, 0] + xtst[:, :, 0, 1] + xtst[:, :, 1, 0] + xtst[:, :, 1, 1] + ytst**3

In [None]:
# y = y.repeat(1, 3, 1, 1)
# y.shape

In [None]:
# functorch.vmap(torch.autograd.grad)(
#     y.repeat(2, 1, 1),
#     [[xtst, ytst],[xtst, ytst]],
#     grad_outputs=dims.repeat(2, 1, 1),
#     is_grads_batched=True,
#     retain_graph=True,
#     create_graph=True
# )


In [None]:
# torch.sum((dudx.squeeze())*dims, dim=0)

In [None]:
# %%timeit
# Batched VJPs, one per one-hot vector in dims: dudx[k] is the gradient of pixel k of y[0, 0]
# w.r.t. xtst, giving shape (4, 1, 1, 2, 2); likewise dudy w.r.t. ytst.
dudx, dudy = torch.autograd.grad(y[0, 0], [xtst, ytst], grad_outputs=dims, is_grads_batched=True, retain_graph=True, create_graph=True)
# dudx = torch.sum((dudx*dims), dim=0)
# dudy = torch.sum((dudy*dims), dim=0)

In [None]:
# Keep entry k of the k-th VJP (the Jacobian diagonal) and sum over k: a (2, 2) map of du/dx.
torch.sum((dudx.view(-1, 2, 2)*dims), dim=0)

In [None]:
# NOTE(review): xtsts is only defined a few cells below, so this raises NameError on a fresh kernel.
xtsts.shape

In [None]:
# NOTE(review): numpy is not used in the cells shown here.
import numpy as np

In [None]:
# ys = y.repeat(2, 1, 1)
# xtsts = xtst.repeat(2, 1, 1)
# ytsts = ytst.repeat(2, 1, 1)

# Two random 32x32 fields that require grad.
xtsts = torch.rand((2, 32, 32), requires_grad=True)
ytsts = torch.rand((2, 32, 32), requires_grad=True)

# Split into per-sample (1, 32, 32) tensors and re-join them, so each sample xs[i] / ys[i]
# is its own autograd node that xss / yss (and hence output) are computed from.
xs = [i.unsqueeze(dim=0) for i in xtsts]
ys = [i.unsqueeze(dim=0) for i in ytsts]

xss = torch.cat(xs, dim=0)
yss = torch.cat(ys, dim=0)

print(xtsts.requires_grad)


# Toy field u = x**2 + y**3, shape (2, 32, 32).
output = xss**2 + yss**3

In [None]:
%%timeit
for i in range(len(ys)):
    yi = output[i]
    dudx, dudy = torch.autograd.grad(yi, [xs[i], ys[i]], grad_outputs=dims, is_grads_batched=True, retain_graph=True, create_graph=True)
    dudx = torch.sum((dudx*dims), dim=0)
    dudy = torch.sum((dudy*dims), dim=0)
    print(dudx.requires_grad, dudy.requires_grad)

In [None]:
# NOTE(review): with is_grads_batched=True, grad_outputs must have shape (B, *dudx.shape).
# dims is (4, 2, 2), while dudx from the batched-grad cell above is (4, 1, 1, 2, 2); confirm
# which dudx this was written for.
d2udx2 = torch.autograd.grad(dudx, xtst, grad_outputs=dims, is_grads_batched=True, retain_graph=True, create_graph=True)
# d2udx2[0].requires_grad = True
print(d2udx2[0].requires_grad)
d2udx2 = torch.sum((d2udx2[0]*dims), dim=0)
d2udx2

In [None]:
d2udx2.requires_grad

In [None]:
loss = torch.sum(d2udx2)

In [None]:
loss.backward()

In [None]:
# NOTE(review): y is linear in ytst, so dudy likely carries no graph and this second derivative
# may raise "does not require grad"; the grad_outputs shape caveat above applies as well.
d2udy2 = torch.autograd.grad(dudy, ytst, grad_outputs=dims, is_grads_batched=True, retain_graph=True)
g = (d2udy2[0]*dims)
h = torch.sum(g, dim=0)
h

In [None]:
# d = torch.autograd.grad(c, xtst, grad_outputs=dims, is_grads_batched=True, retain_graph=True)
# e = (d[0]*dims)
# f = torch.sum(e, dim=0)
# f

In [None]:
# Float one-hot masks for the four pixels of a 2x2 image (a hand-written version of dims).
dim1 = torch.Tensor([[1, 0], [0, 0]])
dim2 = torch.Tensor([[0, 1], [0, 0]])
dim3 = torch.Tensor([[0, 0], [1, 0]])
dim4 = torch.Tensor([[0, 0], [0, 1]])

In [None]:
xtst

In [None]:
# Stack the masks into dim_grads of shape (4, 2, 2); dim_grads[k] selects pixel k.
dim_grads = torch.cat(
    [
        dim1.view(1, 2, 2),
        dim2.view(1, 2, 2),
        dim3.view(1, 2, 2),
        dim4.view(1, 2, 2),
    ],
    dim=0
)

In [None]:
# x.requires_grad = True
# xtst was redefined above as a (1, 1, 2, 2) tensor, so xtst[0, 1] raised IndexError, and the
# (4, 2, 2) dim_grads no longer matched its shape. This experiment, and the cells below that
# multiply by dim_grads, work on a 2x2 matrix, so drop the leading singleton dims first.
# view(2, 2) is idempotent, so re-running the cell is safe.
xtst = xtst.view(2, 2)
y = xtst**2 + xtst[0, 0] + xtst[0, 1] + xtst[1, 0] + 9*xtst[1, 1]
a = torch.autograd.grad(y, xtst, grad_outputs=dim_grads, is_grads_batched=True, retain_graph=True, create_graph=True)
a

In [None]:
# a[0].requires_grad = True
# Mask each VJP with its own direction: b[k] keeps only pixel k of a[0][k].
b = (a[0]*dim_grads)
b

In [None]:
# Summing over k gives the Jacobian diagonal, i.e. d y[i, j] / d xtst[i, j] as a 2x2 map.
c = torch.sum(b, dim=0)
c

In [None]:
# Repeat the masking trick on c to get the per-pixel second derivatives.
d = torch.autograd.grad(c, xtst, grad_outputs=dim_grads, is_grads_batched=True, retain_graph=True)
e = (d[0]*dim_grads)
f = torch.sum(e, dim=0)
f

In [None]:
# Each batched VJP next to the one-hot direction that produced it.
a[0][0], dim1

In [None]:
a[0][1], dim2

In [None]:
a[0][2], dim3

In [None]:
a[0][3], dim4

In [None]:
xtst

In [None]:
y

In [None]:
# x here is the (4, 1, 2, 2) arange tensor from the jacrev cells, a leaf, so its flag can be set.
x.requires_grad = True
# y.requires_grad = True

In [None]:
y = x**2

In [None]:
# Differentiating w.r.t. x[0, 0, 0, 0] fails: indexing creates a new tensor that y was not
# computed from ("appears to not have been used in the graph"). Differentiate w.r.t. x and pick
# the entry instead; only that element of the gradient is non-zero.
torch.autograd.grad(y[0, 0, 0, 0], x)[0][0, 0, 0, 0]

In [None]:
y

In [None]:
# Gradient of sum(y) w.r.t. x, i.e. 2 * x.
y = x**2
torch.autograd.grad(y, x, grad_outputs=torch.ones(x.shape))

In [None]:
x.shape

In [None]:
x.view(-1)

In [None]:
# (4, 1, 2, 2) -> (8, 2): eight 2-element rows.
x_flat = x.view(-1, 2)

In [None]:
x_flat.shape

In [None]:
# list_tens = [i.unsqueeze(dim=0) for i in x_flat]
# One (1, 2) view per row. forward() below consumes this list, and each element can be
# differentiated against individually.
list_tens = [i.unsqueeze(dim=0) for i in x_flat]

In [None]:
list_tens

In [None]:
torch.cat(list_tens, dim=0).shape

In [None]:
list_tens[0].requires_grad

In [None]:
# def forward(inp):
#     print(inp[0].requires_grad)
#     nt = torch.cat(inp, dim=0)
# #     no = nt[:, 0]**2 + 3*nt[:, 1]**2 #+ (nt[:, 0]**2 + nt[:, 1]**3).sum()
#     no = nt[0]**2 + 3*nt[1]**2
# #     oup = [i.unsqueeze(dim=0) for i in no] #
#     oup = no
    
#     dudx = torch.autograd.grad(oup, inp, grad_outputs=torch.ones(oup.shape), create_graph=True)
#     d2udx2 = torch.autograd.grad(dudx, inp, grad_outputs=torch.ones(dudx[0].shape), create_graph=True)
#     return list_out

In [None]:
def forward(list_tens):
    """Evaluate u = a**2 + 3*b**2 for every row [a, b] in ``list_tens``.

    The rows (e.g. the (1, 2) tensors in ``list_tens``) are concatenated into one tensor,
    so a single autograd graph links each output to its own input row.

    Returns:
        A list with one tensor of shape (1,) per input row.
    """
    stacked = torch.cat(list_tens, dim=0)
    first, second = stacked[:, 0], stacked[:, 1]
    values = first**2 + 3 * second**2
    return [values[k].unsqueeze(0) for k in range(values.shape[0])]

In [None]:
# Derivatives of out[idx] = a**2 + 3*b**2 w.r.t. its own row (a, b) = list_tens[idx]:
# first (2a, 6b), second (2, 6), third zero. Each result is a 1-tuple.
idx = 3
out = forward(list_tens)
dudx = torch.autograd.grad(out[idx], list_tens[idx], grad_outputs=torch.ones(out[idx].shape), create_graph=True)
d2udx2 = torch.autograd.grad(dudx, list_tens[idx], grad_outputs=torch.ones(dudx[0].shape), create_graph=True)
d3udx3 = torch.autograd.grad(d2udx2, list_tens[idx], grad_outputs=torch.ones(d2udx2[0].shape), retain_graph=True)

In [None]:
dudx, d2udx2, d3udx3

In [None]:
loss = (dudx[0] + d2udx2[0] + d3udx3[0]).sum()

In [None]:
# Accumulates into x.grad (list_tens are views of the leaf x).
loss.backward()

In [None]:
def compute_derivatives(inp, oup):
    """First and second derivatives of ``oup`` with respect to ``inp``.

    Both are vector-Jacobian products with an all-ones vector (the gradient of ``oup.sum()``,
    then of ``dudx.sum()``), so for pointwise ``oup`` they are per-element derivatives.

    Args:
        inp: Tensor that ``oup`` was computed from (it must be part of ``oup``'s graph).
        oup: Output tensor to differentiate.

    Returns:
        ``(dudx, d2udx2)``, both shaped like ``inp``. If ``dudx`` no longer depends on
        ``inp`` (``oup`` is linear in it), ``d2udx2`` is zeros instead of an error.

    NOTE(review): a later cell redefines ``compute_derivatives`` with a different signature,
    which shadows this one.
    """
    # ones_like (instead of torch.ones(shape)) keeps grad_outputs on oup's device and dtype.
    (dudx,) = torch.autograd.grad(oup, inp, grad_outputs=torch.ones_like(oup), create_graph=True)
    if not dudx.requires_grad:
        return dudx, torch.zeros_like(inp)
    (d2udx2,) = torch.autograd.grad(
        dudx, inp, grad_outputs=torch.ones_like(dudx), create_graph=True, allow_unused=True
    )
    if d2udx2 is None:
        d2udx2 = torch.zeros_like(inp)
    return dudx, d2udx2

In [None]:
# list(map(compute_derivatives, list_tens, out))

In [None]:
# One (dudx, d2udx2) pair per row: a[i] holds the derivatives of out[i] w.r.t. list_tens[i].
a = list(map(compute_derivatives, list_tens, out))

In [None]:
a[0][0]

In [None]:
# Row 0 flattened: [du/da, du/db, d2u/da2, d2u/db2].
torch.cat(a[0], dim=0).view(-1)

In [None]:
x.shape

In [None]:
# a is a list of (dudx, d2udx2) tuples, which torch.cat cannot take directly (TypeError).
# Concatenate each pair first (-> (2, 2) per row), then all rows, and view as
# (n_rows, 1, 2, 2): one [[du/da, du/db], [d2u/da2, d2u/db2]] block per row of x_flat.
torch.cat([torch.cat(pair, dim=0) for pair in a], dim=0).view(len(a), 1, 2, 2)

In [None]:
# Gather the first and second derivatives into (n_rows, 2) matrices, one row per row of x_flat.
# torch.cat(a, dim=2) raised TypeError because a holds (dudx, d2udx2) tuples, not tensors.
dudxs = torch.cat([first for first, _ in a], dim=0)
d2udx2s = torch.cat([second for _, second in a], dim=0)

In [None]:
# torch.cat([dudxs, d2udx2s], dim=0).view(8, 2, 2)

In [None]:
dudxs.shape

In [None]:
# Rows: du/da, du/db, d2u/da2, d2u/db2; one column per row of x_flat.
torch.cat([dudxs.T, d2udx2s.T], dim=0)

In [None]:
def compute_derivatives(lists):
    """Attempt first/second derivatives of ``lists[2]`` w.r.t. ``lists[:2]`` for one packed row.

    ``lists`` is a 1-D row ``[a, b, u]`` built by concatenating inputs and outputs, and is
    meant to be mapped over rows with ``vmap``.

    Returns:
        ``(dudx, d2udx2)`` as returned by ``torch.autograd.grad`` (1-tuples).

    Raises:
        RuntimeError: from autograd, for every input in the current design. ``lists[:2]``
            and ``lists[2]`` are two separate views of ``lists``, and ``u`` was not computed
            from the slice ``lists[:2]``. The slice is therefore never part of ``u``'s graph,
            and autograd cannot differentiate one w.r.t. the other.
            NOTE(review): to get these derivatives, recompute ``u`` from ``(a, b)`` inside
            the function (e.g. torch.func.grad of a function of the inputs) instead of
            passing a precomputed ``u``.

    The previous version printed debug shapes and flipped ``requires_grad`` inside a bare
    ``try/except: pass``. Flipping the flag cannot attach a tensor to an existing graph, and
    it raises for non-leaf tensors, so it only hid errors; it has been removed.
    """
    list_i = lists[:2]
    list_o = lists[2]
    dudx = torch.autograd.grad(list_o, list_i, grad_outputs=torch.ones_like(list_o), create_graph=True)
    d2udx2 = torch.autograd.grad(dudx, list_i, grad_outputs=torch.ones_like(dudx[0]), create_graph=True)
    return dudx, d2udx2

In [None]:
# compute_derivatives(list_tens[0], out[0])

In [None]:
# Pack each row as [a, b, u]: a is (8, 2), b holds the 8 outputs, c is (8, 3).
a, b = torch.cat(list_tens, dim=0), torch.cat(out, dim=0)
print(a.shape, b.shape)
c = torch.cat([a, b.view(-1, 1)], dim=-1)
print(c.shape)
# NOTE(review): u was computed before c was built, so no slice of c is part of u's graph, and
# torch.autograd.grad inside vmap is generally unsupported; expect this call to raise.
gradients = vmap(compute_derivatives)(c)


In [None]:
torch.cat(list_tens, dim=0).shape

In [None]:
list_tens