In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn, Tensor
import scipy.sparse as sparse
import math
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp


In [2]:
# Global variables
demb = 120
dropout = 0.2
batch_size = 16
device = 'mps' if torch.backends.mps.is_available() else (
    "cuda" if torch.cuda.is_available() else "cpu")
# device = 'cpu'

ns = 20
nr = 12
npairs_local = int(nr*(nr - 1)/2)
outlier_percentage = 0.5
noise_std = 0.2

index_to_pair = {}
counter = 0
for i in range(nr):
    for j in range(i+1, nr):
        index_to_pair[counter] = (i, j)
        counter += 1


In [3]:
def generate_problem(ns, nr, outlier_percentage, noise_std):
    sender_position_truth = 5*np.random.rand(ns, 3)
    receiver_position_truth = 5*np.random.rand(nr, 3)
    distance_truth = torch.Tensor(sp.spatial.distance.cdist(
        sender_position_truth, receiver_position_truth))
    npairs = int(nr*(nr-1)/2)
    measurements = torch.zeros(ns, npairs)
#    measurements = distance_truth.copy()
    counter = 0
    for i in range(nr):
        for j in range(i+1, nr):
            measurements[:, counter] = distance_truth[:, i] - \
                distance_truth[:, j]
            counter += 1

    outliers = np.random.rand(ns, npairs) < outlier_percentage
    # Introducing outliers by either doubling or halving the measurement, could be done in better way
    temp = np.ones((ns, npairs))
    temp[(np.random.rand(ns, npairs) < 0.5)] = -0.5
    measurements = measurements*temp*outliers + measurements
    # introducing noise
    measurements = measurements + noise_std * \
        np.median(distance_truth)*np.random.randn(ns, npairs)

    sender_position_truth = torch.tensor(sender_position_truth)
    receiver_position_truth = torch.tensor(receiver_position_truth)
    # translation
    receiver_position_truth = receiver_position_truth - \
        receiver_position_truth[0, :]
    sender_position_truth = sender_position_truth - \
        receiver_position_truth[0, :]
    # rotation
    R = torch.zeros(3, 3, dtype=torch.float64)
    R[0, :] = receiver_position_truth[1, :] / \
        receiver_position_truth[1, :].norm()
    R[1, :] = receiver_position_truth[2, :] - \
        R[0, :]*(receiver_position_truth[2, :]@R[0, :])
    R[1, :] = R[1, :]/R[1, :].norm()
    R[2, :] = torch.cross(R[0, :], R[1, :])
    receiver_position_truth = receiver_position_truth@R.T
    sender_position_truth = sender_position_truth@R.T

    return (measurements, sender_position_truth, receiver_position_truth, 1-outliers)

# prob = generate_problem(20,12,0.3,0.2)
# prob[0]


In [4]:
def getValueEncoding(value_to_encode, d):
    max_v = max(value_to_encode)
    denominator = max_v/np.pi/np.power(1.2, [i for i in range(int(d/2))])
    sin_coeff = np.sin(np.expand_dims(value_to_encode, 1) /
                       np.expand_dims(denominator, 0))
    cos_coeff = np.cos(np.expand_dims(value_to_encode, 1) /
                       np.expand_dims(denominator, 0))
    return np.concatenate([sin_coeff, cos_coeff], 1)


In [5]:

class SelfAttentionHead(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(demb, head_size, bias=False)
        self.query = nn.Linear(demb, head_size, bias=False)
        self.value = nn.Linear(demb, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.head_size = head_size
#        self.ffwd = nn.Sequential(
#            nn.Linear(head_size,head_size*2),
#            nn.ReLU(),
#            nn.Linear(head_size*2,head_size)
#            )

    def forward(self, x):
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)

        wei = (k @ q.transpose(2, 1)) * self.head_size**-0.5

        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        res = wei @ v
        return res


class MultiAttentionHead(nn.Module):

    def __init__(self, head_size, n_heads):
        super().__init__()
        self.heads = nn.ModuleList(
            [SelfAttentionHead(head_size) for _ in range(n_heads)])
        self.proj = nn.Linear(n_heads*head_size, demb)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.dropout(self.proj(x))


class Block(nn.Module):

    def __init__(self, demb, n_heads):
        super().__init__()
        head_size = demb // n_heads
        self.heads = MultiAttentionHead(head_size, n_heads)
        self.ffwd = nn.Sequential(
            nn.Linear(demb, 4*demb),
            nn.ReLU(),
            nn.Linear(4*demb, demb),
        )
        self.ln1 = nn.LayerNorm(demb)
        self.ln2 = nn.LayerNorm(demb)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = x + self.heads(self.ln1(x))
        return x + self.dropout(self.ffwd(self.ln2(x)))


class TransformerNetwork(nn.Module):

    def __init__(self, n_layers, n_heads):
        super().__init__()
        # self.pos_encode = nn.Embedding(2*data_size,demb, device=device)
        # self.tok_encode = nn.Embedding(2,demb, device=device)

        self.transformer = nn.Sequential(
            *[Block(demb, n_heads) for _ in range(n_layers)])
        self.condense = nn.Sequential(
            nn.Linear(demb, 2*demb),
            nn.ReLU(),
            nn.Linear(2*demb, 3),
        )
        self.ffwd = nn.Sequential(
            nn.Linear(3960, 400),
            nn.ReLU(),
            nn.Linear(400, (ns+nr)*3),
        )
        self.sigmoid = nn.Sigmoid(),
        self.ln1 = nn.LayerNorm(demb)
        self.ln2 = nn.LayerNorm(3)
        self.apply(self._init_weights)
        # self.pos = torch.Tensor(getPositionEncoding(
        #    seq_len=2*data_size, d=int(demb/2), n=data_size*2)).to(device)
        self.flatten = nn.Flatten()
        self.unflatten = nn.Unflatten(1, (ns+nr, 3))

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x):
        # pos_emb = self.pos_encode(torch.arange(T,device=device))
        # x = torch.cat((x.unsqueeze(2).repeat((1, 1, int(demb/2))),
        #              self.pos.unsqueeze(0).repeat((B, 1, 1))), dim=2)

        x = self.transformer(x)

        x = self.ln1(x)
        x = self.condense(x)
        x = self.ln2(x)
        x = self.flatten(x)

        x = self.ffwd(x)
        x = self.unflatten(x)
        return x


In [6]:
model = TransformerNetwork(4, 4).to(device)
loss_fn = torch.nn.HuberLoss()


def train(model, loss_fn, optimizer, generate_problem, batch_size, batches_in_epoch):
    model.train()
    btop = 10
    loss_summer = 0
    for batch in range(batches_in_epoch):
        # Compute prediction error

        X = torch.zeros(batch_size, ns*npairs_local, demb)
        y = torch.zeros(batch_size, ns+nr, 3)

        for i in range(batch_size):
            outlier_percentage_local = np.random.rand(1)*outlier_percentage
            noise_std_local = np.random.rand(1)*noise_std
            prob = generate_problem(
                ns, nr, outlier_percentage_local, noise_std_local)
            x_coord, y_coord = torch.meshgrid(torch.arange(
                ns), torch.arange(npairs_local), indexing='ij')
            X[i, :] = torch.concatenate([prob[0].flatten().unsqueeze(1)*torch.ones(prob[0].numel(), demb//2),
                                        torch.Tensor(getValueEncoding(
                                            x_coord.flatten(), d=demb/6)),
                                        torch.Tensor(getValueEncoding(y_coord.flatten().apply_(
                                            lambda x: index_to_pair[x][0]), d=demb/6)),
                                        torch.Tensor(getValueEncoding(y_coord.flatten().apply_(
                                            lambda x: index_to_pair[x][1]), d=demb/6))
                                         ], 1)

            y[i, :] = torch.cat([prob[1], prob[2]])

        X = X.to(device)
        y = y.to(device)

        pred = model(X)

        # print((X - y).norm(dim=2).shape)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_summer += loss.item()/btop
        if batch % btop == btop-1:
            loss = loss_summer
            loss_summer = 0
            print(f"loss: {loss:>7f}  [{batch:>5d}/{batches_in_epoch}]")


In [7]:
epochs = 5000
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
for t in range(epochs):
    # if t % 100 == 0:
    # torch.save(model, "model" + str(t))
    print(f"Epoch {t+1}\n-------------------------------")
    train(model, loss_fn, optimizer, generate_problem, batch_size, 1000)
    # test(test_dataloader, model, loss_fn)
print("Done!")


Epoch 1
-------------------------------
loss: 1.572395  [    9/1000]
loss: 1.533993  [   19/1000]
loss: 1.502200  [   29/1000]
loss: 1.527058  [   39/1000]
loss: 1.500924  [   49/1000]
loss: 1.509917  [   59/1000]
loss: 1.524668  [   69/1000]
loss: 1.504204  [   79/1000]
loss: 1.523578  [   89/1000]
loss: 1.506509  [   99/1000]
loss: 1.511704  [  109/1000]
loss: 1.508141  [  119/1000]
loss: 1.517792  [  129/1000]
loss: 1.503021  [  139/1000]
loss: 1.499395  [  149/1000]
loss: 1.519017  [  159/1000]
loss: 1.514807  [  169/1000]
loss: 1.496973  [  179/1000]
loss: 1.501586  [  189/1000]
loss: 1.506149  [  199/1000]
loss: 1.484012  [  209/1000]
loss: 1.504206  [  219/1000]
loss: 1.501755  [  229/1000]
loss: 1.517225  [  239/1000]
loss: 1.502507  [  249/1000]
loss: 1.498120  [  259/1000]
loss: 1.497862  [  269/1000]
loss: 1.468768  [  279/1000]
loss: 1.490078  [  289/1000]
loss: 1.491009  [  299/1000]
loss: 1.502275  [  309/1000]
loss: 1.490367  [  319/1000]
loss: 1.498806  [  329/1000]
los

KeyboardInterrupt: 