In [None]:
# Install PyTorch with uv.
# NOTE(review): `uv add` expects a uv project (pyproject.toml) and installs into that
# project's environment, which is not necessarily the kernel's environment (e.g. on Kaggle,
# where torch is usually preinstalled). `torch.accelerator`, used below, only exists in
# recent torch releases (2.6+, to be verified) — consider pinning a minimum version.
!uv add torch

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm

In [None]:
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"

In [None]:
# Directory holding the input arrays (Kaggle dataset mount); change this one value to run elsewhere.
DATA_DIR = "/kaggle/input/datasets/dagadam/pinn-test"

# Distance field on a regular (H, W) grid, as float32 on the compute device.
sdf = torch.tensor(np.load(f"{DATA_DIR}/distance_field.npy"), dtype=torch.float).to(device)
# NOTE(review): uv / vv are not used by the cells shown here — confirm they are still needed.
uv = torch.tensor(np.load(f"{DATA_DIR}/uv.npy"))
vv = torch.tensor(np.load(f"{DATA_DIR}/vv.npy"))

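# Define the neural-network path model
The model maps a path parameter $t$ to a 2-D point $(1-t)\,p_{\mathrm{start}} + t\,p_{\mathrm{end}} + t(1-t)\,\mathbf{n}$, where $\mathbf{n}$ is the network output. The $t(1-t)$ factor pins the path to `start_pos` at $t=0$ and to `end_pos` at $t=1$, whatever the network weights are.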

In [None]:
# Fixed path endpoints in distance-field grid coordinates: index 0 is x (column), index 1 is y (row).
# Stored as floats so they combine with the float32 network output without int promotion.
start_pos = torch.tensor([5.0, 5.0], device=device)
end_pos = torch.tensor([35.0, 35.0], device=device)


class PINN(nn.Module):
    """Small MLP that parameterizes a 2-D path by a scalar t.

    ``forward`` returns ``(1 - t) * start_pos + t * end_pos + t * (1 - t) * mlp(t)``,
    so the endpoints are hard constraints: t = 0 yields ``start_pos`` and t = 1 yields
    ``end_pos`` whatever the weights are; the MLP only bends the path in between.
    ``start_pos`` / ``end_pos`` are read from the notebook's global scope.
    """

    def __init__(self):
        super().__init__()
        self.dense1 = nn.Linear(1, 12)
        # TODO: consider a smoother activation (e.g. tanh); ReLU makes the path only
        # piecewise smooth in t.
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(12, 4)
        self.dense3 = nn.Linear(4, 2)

    def forward(self, t):
        """Map path parameter(s) t to point(s) on the path.

        Args:
            t: path parameter in [0, 1]; a tensor of shape (1,) for one point or (N, 1)
                for a batch. A Python float or 0-d tensor is treated as shape (1,).

        Returns:
            Tensor of shape (2,) or (N, 2) holding (x, y) coordinates.
        """
        # as_tensor keeps any autograd link to t (torch.tensor would copy and detach it),
        # which matters if derivatives of the path w.r.t. t are needed.
        t = torch.as_tensor(t, dtype=torch.float32, device=self.dense1.weight.device)
        if t.dim() == 0:
            t = t.reshape(1)
        # t already is the [0, 1] path parameter, so it is fed to the network unscaled.
        x = self.relu(self.dense1(t))
        x = self.relu(self.dense2(x))
        x = self.dense3(x)
        return (1 - t) * start_pos + t * end_pos + t * (1 - t) * x


model = PINN().to(device)
print(model)


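# Define the loss function
The loss samples the distance field at every path point (bilinear interpolation with `F.grid_sample`) and sums $1 / (d^2 + 10^{-4})$. Points where the distance $d$ is close to zero are therefore heavily penalized.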

In [None]:
class PathLoss(nn.Module):
    """Proximity penalty for a 2-D path sampled on a distance field.

    The field is sampled at every path point with bilinear interpolation
    (``F.grid_sample``) and the loss is ``sum(1 / (d**2 + eps))``, which grows as
    points approach locations where the field value d is zero.

    NOTE(review): the penalty depends only on d**2, so negative field values (if the
    field is signed, e.g. inside obstacles) are treated like positive ones — confirm
    whether that is intended.

    Args:
        eps: small constant keeping the penalty finite where the field is zero.
    """

    def __init__(self, eps=1e-4):
        super().__init__()
        self.eps = eps

    def forward(self, path, sdf):
        """Return the scalar path loss.

        Args:
            path: (N, 2) tensor of points in grid coordinates; column 0 is x (column
                index, 0..W-1) and column 1 is y (row index, 0..H-1).
            sdf: (H, W) distance-field tensor with the same dtype as ``path``.

        Returns:
            0-d tensor with the summed penalty, on ``path.device``.
        """
        H, W = sdf.shape

        # Normalize to [-1, 1] as grid_sample expects; with align_corners=True, -1 and +1
        # map to the first and last pixel centers. Built out-of-place to keep autograd simple.
        scale = path.new_tensor([W - 1, H - 1])
        grid = path / scale * 2 - 1

        sdf_input = sdf.reshape(1, 1, H, W)        # (N=1, C=1, H, W)
        grid_input = grid.reshape(1, 1, -1, 2)     # (1, 1, N, 2)

        # Sampling runs on CPU and the result is moved back; presumably a workaround for an
        # accelerator backend without grid_sample support — TODO confirm it is still needed.
        sdf_vals = F.grid_sample(sdf_input.cpu(), grid_input.cpu(),
                                 align_corners=True,
                                 padding_mode='border').to(path.device)  # (1, 1, 1, N)
        sdf_vals = sdf_vals.reshape(-1)            # (N,)
        return (1 / (sdf_vals ** 2 + self.eps)).sum()
# Instantiate the path loss; Module.to moves it onto the compute device in place.
loss = PathLoss()
loss.to(device)

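# Training function
Each optimization step evaluates the model at 100 evenly spaced values $t = 0, 0.01, \dots, 0.99$, computes the path loss and takes one optimizer step.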

In [None]:
# Adam over every parameter of the path network, learning rate 0.1.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.1)
def train(model, optimizer, device, sdf, loss_fn, training_steps=1000, n_points=100, log_every=100):
    """Optimize ``model`` so that the path it produces minimizes ``loss_fn``.

    Each step evaluates the model once on a batch of ``n_points`` path parameters
    t = 0, 1/n_points, ..., (n_points - 1)/n_points, computes ``loss_fn(path, sdf)``
    and takes one optimizer step.

    Args:
        model: module mapping a (n_points, 1) tensor of t values to an (n_points, 2) path.
        optimizer: optimizer over ``model``'s parameters.
        device: device on which the t values are created.
        sdf: distance field passed through to ``loss_fn``.
        loss_fn: callable ``loss_fn(path, sdf)`` returning a scalar tensor.
        training_steps: number of optimizer steps.
        n_points: number of path samples per step.
        log_every: print the loss every ``log_every`` steps.

    Returns:
        The loss of the last step as a Python float, or None if no step ran.
    """
    model.train()
    # Shaped (N, 1) so the whole path comes out of a single batched forward pass.
    t = (torch.arange(n_points, dtype=torch.float32, device=device) / n_points).unsqueeze(1)

    step_loss = None
    for step in range(training_steps):
        optimizer.zero_grad()
        path = model(t)  # expected shape (n_points, 2)
        # Call the loss object itself (not .forward) so nn.Module hooks still run.
        step_loss = loss_fn(path, sdf)
        step_loss.backward()
        optimizer.step()
        if step % log_every == 0:
            print(f"loss: {step_loss.item():>7f}")
    # Only convert the final loss to a host float, avoiding a device sync every step.
    return None if step_loss is None else step_loss.item()


# Run Training

In [None]:
# Repeat the training routine; the same optimizer (and its state) is reused each time,
# so every epoch continues from where the previous one stopped.
epochs = 3
for epoch in range(epochs):
    train(model=model, optimizer=optimizer, device=device, sdf=sdf, loss_fn=loss)

# Visualize path

In [None]:
# Sample the trained path at t = 0, 1/100, ..., 1 — endpoint included so the curve reaches
# end_pos — in one batched forward pass without building an autograd graph.
n_segments = 100
model.eval()
with torch.no_grad():
    t_plot = (torch.arange(n_segments + 1, dtype=torch.float32, device=device) / n_segments).unsqueeze(1)
    # .cpu() is required before .numpy() when the model lives on an accelerator.
    path_points = model(t_plot).cpu().numpy()  # (n_segments + 1, 2)
path_x = path_points[:, 0]
path_y = path_points[:, 1]

In [None]:
# Quiver plot of the distance-field gradient with the learned path drawn on top.
# The field is copied to a CPU NumPy array first: NumPy cannot read accelerator tensors.
# np.gradient returns the derivative along axis 0 (rows, y) first, then axis 1 (columns, x).
sdf_np = sdf.detach().cpu().numpy()
grad_y, grad_x = np.gradient(sdf_np)

fig, ax = plt.subplots()
ax.quiver(grad_x, grad_y, scale=30)
ax.plot(path_x, path_y, label="path")
ax.scatter(*start_pos.detach().cpu().numpy(), label="start")
ax.scatter(*end_pos.detach().cpu().numpy(), label="end")
ax.set_aspect("equal")
ax.set(title="Learned path over the distance-field gradient", xlabel="x (column)", ylabel="y (row)")
ax.legend();