In [None]:
import torch
import matplotlib.pyplot as plt

In [None]:
# Problem size.
n = 8    # number of microphones
m = 128  # directions sampled per batch (used for both targets and interferers)
f = 128  # frequencies sampled per batch

SEED = 0                 # fixes the initial layout and every sampled batch (reproducible runs)
N_ITERS = 800            # optimisation steps; also used as the LR-schedule horizon
SPEED_OF_SOUND = 340000  # 340 m/s expressed in mm/s -- presumably positions are in mm; TODO confirm
JITTER_STD = 2.0         # std of the random perturbation added to every mic position (same units as p)
MAX_SEPARATION = 80 * torch.pi / 180  # angular separation (rad) at which the weight saturates at +1

torch.manual_seed(SEED)

p = torch.randn(n, 2, dtype=torch.float64, device='cpu')  # initial planar mic positions (x, y)

# Enable gradients on the positions and set up the optimiser.
p.requires_grad = True
optim = torch.optim.AdamW([p], lr=5, weight_decay=1e-2)
# total_iters equals the loop length so the LR actually decays to end_factor * lr.
# (With total_iters=2000 and an 800-step loop the LR was still ~60 % of its
# initial value when training stopped.)
scheduler = torch.optim.lr_scheduler.LinearLR(optim, start_factor=1.0, end_factor=2e-4, total_iters=N_ITERS)


def _random_unit_vectors(shape, scale, mean, device, dtype):
    """Sample 3-vectors ~ N(mean, diag(scale)**2) with leading dims `shape`, then L2-normalise them.

    The anisotropic scale/mean make the resulting directions non-uniform on the sphere.
    Returns a tensor of shape (*shape, 3) with unit-norm rows.
    """
    scale = torch.tensor(scale, device=device, dtype=dtype)
    mean = torch.tensor(mean, device=device, dtype=dtype)
    raw = torch.randn(*shape, 3, device=device, dtype=dtype) * scale + mean
    return torch.nn.functional.normalize(raw, p=2, dim=-1)


def _interference_loss(mic_xy, freq, u, v, jitter_std=JITTER_STD, c=SPEED_OF_SOUND, max_sep=MAX_SEPARATION):
    """Weighted soft-max of array-response correlation between target and interferer directions.

    Args:
        mic_xy: (n, 2) planar mic positions; z is taken as 0.
        freq:   (f,) frequencies.
        u:      (f, m, 1, 3) unit target directions.
        v:      (f, 1, m, 3) unit interferer directions.
        jitter_std: std of the per-sample position perturbation.
        c: propagation speed used to turn path differences into delays.
        max_sep: angular separation (rad) at which the weight reaches +1.

    Returns:
        Scalar tensor. For every (freq, target) it takes the soft-max over interferers of
        weight * |mean_k exp(-2*pi*j*freq*(u - v).pos_k / c)|, then averages over freq and target.
    """
    diff = u - v  # broadcasts to (f, m, m, 3)
    z = torch.zeros(mic_xy.shape[0], 1, device=mic_xy.device, dtype=mic_xy.dtype)
    # Independent position jitter for every (freq, target, interferer) triple. This tensor
    # alone holds f*m*m*n*3 float64 values (~400 MB at the default sizes).
    pos = torch.concat([mic_xy, z], dim=-1) + jitter_std * torch.randn(
        *diff.shape[:-1], mic_xy.shape[0], 3, device=mic_xy.device, dtype=mic_xy.dtype)
    delay = torch.einsum('...i, ...ji -> ...j', diff, pos) / c  # (f, m, m, n)
    corr = torch.exp(-2 * torch.pi * freq[:, None, None, None] * 1j * delay).mean(dim=-1).abs()  # (f, m, m)

    # Clamp before acos: independently normalised vectors can give |cos| slightly > 1
    # through rounding, which would make acos (and the whole loss) NaN.
    cos_sep = torch.einsum('...i, ...i -> ...', u, v).clamp(-1.0, 1.0)
    ratio = cos_sep.acos() / max_sep
    # w = 2 * min(ratio**2, 1) - 1: -1 for identical directions, 0 at ~0.707 * max_sep,
    # +1 at or beyond max_sep. Negative weights mean correlation with nearby
    # directions lowers the loss instead of raising it.
    w = 2 * (ratio ** 2).clamp(max=1.0) - 1
    weighted = w * corr
    # Soft approximation of the max over interferers (last dim), then mean over freq and target.
    return (torch.softmax(weighted, dim=-1) * weighted).sum(dim=-1).mean()


his = []  # loss history
for i in range(N_ITERS):
    optim.zero_grad()

    # Random frequencies plus target (u) and interferer (v) directions for this batch.
    freq = 50 + (6000 - 50) * torch.rand(f, device=p.device, dtype=p.dtype)
    u = _random_unit_vectors((f, m, 1), [250, 300, 100], [200, 0, 25], p.device, p.dtype)
    v = _random_unit_vectors((f, 1, m), [500, 600, 200], [100, 0, 0], p.device, p.dtype)

    loss = _interference_loss(p, freq, u, v)

    # Gradient step on the mic positions.
    loss.backward()
    optim.step()
    scheduler.step()

    # Record and periodically report progress.
    his.append(loss.item())
    if (i + 1) % 200 == 0:
        print(f'iter {i + 1}: loss {loss.item():.6f}')
        fig, ax = plt.subplots()
        ax.plot(his)
        ax.set(xlabel='iteration', ylabel='loss', title='training loss')
        plt.show()

In [None]:
# Lift the optimised planar layout into 3-D (z = 0) for plotting and export.
# Slicing to the first two columns keeps this cell idempotent: re-running it no
# longer appends another zero column. `detach` drops the autograd graph
# inherited from the optimised leaf tensor.
p = torch.concat([p.detach()[:, :2], torch.zeros(p.shape[0], 1, device=p.device, dtype=p.dtype)], dim=-1)

In [None]:
# 3-D view of the optimised microphone layout (expects p to have at least 3 columns).
xyz = p.numpy(force=True)
fig, ax = plt.subplots(subplot_kw={'projection': '3d'})
ax.scatter(xyz[:, 0], xyz[:, 1], xyz[:, 2])
ax.set(xlabel='x', ylabel='y', zlabel='z', title='optimised microphone positions')

print(p)
# Same positions rounded to 2 decimals with brace delimiters -- presumably for
# pasting into a C/C++-style array initialiser; confirm with the consumer.
print(str(p.round(decimals=2).tolist()).replace('[', '{').replace(']', '}'))