In [1]:
# -*- coding: utf-8 -*-
"""how-tsp-should-be.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1InE1iW8ARzndPpvqH_9y22s81sOiHxPs
"""

from tqdm import tqdm
import torch
import torch.nn as nn
import matplotlib as mpl
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset

from math import sqrt
from collections import deque
import os
import random
import pickle
import ipdb

#  torch.manual_seed(30)
#  random.seed(30)
torch.manual_seed(33)
random.seed(33)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# assert device.type == "cuda", "CUDA is not available. Please check your GPU setup."

NVTXS = 6
MAXDIST = NVTXS+1
AVGDEG = 2
SEQLEN = NVTXS + 1
HIDDENDIM = 4*NVTXS+2

# 0: ANSFLAG
# 1:NVTXS+1 NBRS
# NVTXS+1: 2*NVTXS+1 REACH
# 2*NVTXS+1: 3*NVTXS+1 SELF
# -1 NOTANSFLAG

START_REACH = NVTXS+1
START_OUT = 2*NVTXS+1
START_SELF = 3*NVTXS+1
SRC_FLAG_IDX = START_SELF
SOURCE = 1
TARGET = 2
ANS_FLAG_IDX = 0
NOTANS_FLAG_IDX = -1

def print_everything(data):
    print("NBRS")
    print(data[0, 1:, 1:1+NVTXS])
    print("REACH")
    print(data[0, 1:, START_REACH:START_REACH+NVTXS])
    print("ANSFLAG")
    print(data[0, :, 0])
    print("MORE FLAGS")
    print(data[0, :, -1])
    print("SELF")
    print(data[0, 1:, START_SELF:START_SELF+NVTXS])
    print("OUT")
    print(data[0, 0, START_OUT:START_OUT+NVTXS])


def random_graph():
    data = torch.zeros((SEQLEN, HIDDENDIM))

    for i in range(1,NVTXS+1):
        data[i, START_SELF-1+i] = 1

    adj_list = [set() for _ in range(SEQLEN)]
    indices = [random.randint(1, NVTXS) for _ in range(AVGDEG * NVTXS)]
    for i in range(0, len(indices), 2):
        u = indices[i]
        v = indices[i + 1]
        if u != v:
            data[v,u] = 1
            data[u,v] = 1
            data[v,NVTXS+u] = 1
            data[u,NVTXS+v] = 1
            adj_list[u].add(v)
            adj_list[v].add(u)

    data[0, ANS_FLAG_IDX] = 1
    data[1:, NOTANS_FLAG_IDX] = 1

    # TODO: this is kind of a hack
    data[0, START_REACH:START_REACH+NVTXS] = 1
    return data, adj_list

"""
input: G, represented as an adjacency list
output: distance from SOURCE to TARGET
"""
def SSSP(G):
    dist = [MAXDIST for _ in G]
    dist[SOURCE] = 0
    frontier = deque()
    frontier.append(SOURCE)
    while len(frontier) > 0:
        vtx = frontier.popleft()
        for x in G[vtx]:
            if dist[x] == MAXDIST:
                dist[x] = 1 + dist[vtx]
                frontier.append(x)
                if x == TARGET:
                    return dist[TARGET]
    return MAXDIST

def mkbatch(size):
    graphs1 = []
    distance1 = []

    for i in range(size):
        data, adj_list = random_graph()
        dist = SSSP(adj_list)
        graphs1.append(data)
        distance1.append(dist)

        print(adj_list)

    data = torch.stack(graphs1)
    labels = torch.tensor(distance1, dtype=torch.float16)
    return data, labels

"""
TODO: WRAP EVERYTHING in nn.Parameter(torch.zeros((1, HIDDENDIM)))
and then do my perturbing parameters experiment

TODO:
    USE activation magic to bring everything back to the 0/1 realm instead of possibly being 0/2 valued
"""

class SillyTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.most_KQVs = []
        for head in range(1,NVTXS+1):
          Q = torch.zeros((2, HIDDENDIM))
          Q[0, START_REACH-1+head] = 1000
          Q[1, NOTANS_FLAG_IDX] = 1

          K = torch.zeros((2, HIDDENDIM))
          K[0, head] = 1
          K[1, ANS_FLAG_IDX] = 200

          V = torch.zeros((NVTXS,HIDDENDIM))
          for i in range(NVTXS):
              V[i, START_SELF+i] = 1

          self.most_KQVs.append((K, Q, V))

        self.weird_KQVs = []
        for layer in range(NVTXS):
            K = torch.zeros((3, HIDDENDIM))
            K[0, NOTANS_FLAG_IDX] = -1000
            K[0, SRC_FLAG_IDX] = +1100
            K[1, NOTANS_FLAG_IDX] = -1000
            K[1, NVTXS+TARGET] = +1100
            K[1, ANS_FLAG_IDX] = -1100
            K[2, ANS_FLAG_IDX] = 10

            Q = torch.zeros((3, HIDDENDIM))
            Q[:, ANS_FLAG_IDX] = 1

            V = torch.zeros((NVTXS, HIDDENDIM))
            V[layer, SRC_FLAG_IDX] = 1

            self.weird_KQVs.append((K, Q, V))

    def forward(self, src):
      for layer in range(NVTXS):
        allKQVs = [self.weird_KQVs[layer]] + self.most_KQVs
        head_outputs = []
        for (K, Q, V) in allKQVs:
            ksrc = torch.matmul(src, K.unsqueeze(0).transpose(-2, -1))
            qsrc = torch.matmul(src, Q.unsqueeze(0).transpose(-2, -1))
            vsrc = torch.matmul(src, V.unsqueeze(0).transpose(-2, -1))

            scores = torch.matmul(qsrc, ksrc.transpose(-2, -1))
            attention_weights = torch.softmax(scores, dim=-1)
            head_output = torch.matmul(attention_weights, vsrc)
            head_outputs.append(head_output)

        new_reaches = sum(head_outputs[1:])
        BSZ = new_reaches.shape[0]

        nodelta_nbrs = torch.zeros((BSZ, SEQLEN, NVTXS+1))
        morepadlol = torch.zeros((BSZ, SEQLEN, 1+NVTXS))

        DIFF = torch.cat((nodelta_nbrs, new_reaches, head_outputs[0], morepadlol), dim=2)
        src += torch.cat((nodelta_nbrs, new_reaches, head_outputs[0], morepadlol), dim=2)
        src[:, :, START_REACH:START_REACH+NVTXS] = 2*torch.sigmoid(src[:,:, START_REACH:START_REACH+NVTXS]*1000)-1

        #  print("SRC")
        #  print_everything(src)

      canreach = src[:,0,START_OUT:START_OUT+NVTXS]
      #  __import__('ipdb').set_trace()
      final_output = 1+torch.sum(1-canreach,dim=1)
      return final_output

model = SillyTransformer()
model.to(device)

data, labels = mkbatch(10)
assert torch.all(model(data) == labels)



[set(), set(), {5, 6}, {4}, {3}, {2, 6}, {2, 5}]
[set(), {6}, set(), {4, 5, 6}, {3}, {3, 6}, {1, 3, 5}]
[set(), {4}, set(), {4, 5}, {1, 3}, {3, 6}, {5}]
[set(), {2, 6}, {1, 6}, {6}, set(), set(), {1, 2, 3}]
[set(), {3}, {3}, {1, 2, 5, 6}, {5}, {3, 4}, {3}]
[set(), {3, 6}, {4}, {1}, {2}, {6}, {1, 5}]
[set(), {2, 3}, {1, 3, 6}, {1, 2, 4}, {3}, set(), {2}]
[set(), {4}, set(), {4}, {1, 3, 5, 6}, {4}, {4}]
[set(), {3, 4, 5}, {6}, {1}, {1, 6}, {1}, {2, 4}]
[set(), {5, 6}, {6}, {6}, {5, 6}, {1, 4}, {1, 2, 3, 4}]
