In [1]:
import sys

sys.path.append('../GSL')

import pickle
import os
from glob import glob
import yaml
from easydict import EasyDict as edict

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
from math import sqrt

from models.graph_learning_Attention.probsparseattention import ProbAttention
from models.message_passing.MPNN import InterCorrealtionStack
from torch_geometric.utils import to_dense_adj, dense_to_sparse
from models.ic_pn_beats_model import IC_PN_BEATS_model

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Locate the experiment configuration; glob() returns a list, so [0] raises
# IndexError if the file is missing.
config_file = glob('../GSL/config/ic_pnbeats_general.yaml')[0]

# Read it inside a context manager so the file handle is closed right away
# (the bare open(...) previously left it open until garbage collection).
# NOTE(review): yaml.safe_load is the stricter choice if the config only holds
# plain scalars, lists and mappings.
with open(config_file, 'r') as config_fh:
    config = edict(yaml.load(config_fh, Loader=yaml.FullLoader))

In [23]:
# Override the batch size from the YAML config before building the model.
config.train.batch_size = 1
# Instantiate the model from the (modified) config.
model = IC_PN_BEATS_model(config)

In [24]:
# Load the pre-built ECL dataset splits.
# SECURITY: pickle.load can execute arbitrary code; only load files you trust.
# The context manager closes the file handle as soon as loading finishes.
with open('../GSL/data/ECL/temporal_signal_12_12_128.pickle', 'rb') as dataset_fh:
    dataset = pickle.load(dataset_fh)

# NOTE(review): despite its name, `train_dataset` holds the 'test' split —
# confirm this is intended.
train_dataset = dataset['test']

# Take only the first batch; it is used below for shape inspection.
data_batch = next(iter(train_dataset))

In [25]:
data_batch

DataBatch(x=[321, 5, 12], edge_index=[2, 103041], y=[321, 12], batch=[321])

In [26]:
# Forward pass on the batch's node features. The model returns three values:
# backcast, forecast, and a dict of intermediate outputs.
# NOTE(review): the meaning of the positional `True` flag is not visible here —
# presumably it enables returning the extra `outputs`; confirm against the model.
backcast, forecast, outputs = model(data_batch.x, True)

In [27]:
backcast.shape

torch.Size([321, 12])

In [28]:
forecast.shape

torch.Size([321, 12])

In [29]:
outputs.keys()

dict_keys(['per_trend_backcast', 'per_trend_forecast', 'per_seasonality_backcast', 'per_seasonality_forecast', 'singual_backcast', 'singual_forecast', 'attention_matrix'])

In [30]:
outputs['per_trend_backcast'].shape

(3, 321, 12)

In [31]:
outputs['per_trend_forecast'].shape

(3, 321, 12)

In [32]:
outputs['per_seasonality_backcast'].shape

(3, 321, 12)

In [33]:
outputs['per_seasonality_forecast'].shape

(3, 321, 12)

In [34]:
outputs['singual_backcast'].shape

(1, 321, 12)

In [35]:
outputs['singual_forecast'].shape

(1, 321, 12)

In [36]:
outputs['attention_matrix'].shape

(4, 1, 4, 321, 321)

In [38]:
np.log(321)*2

11.542882246260032

In [None]:
# Reshape the node features to (batch, 321, 5, 12).
# The batch axis is inferred with -1: the hard-coded 128 cannot work for the
# batch displayed above, whose x is [321, 5, 12].
temp = data_batch.x.view(-1, 321, 5, 12)

In [None]:
temp.permute(0, 1, 3, 2).shape

In [None]:
data_batch.x.permute(0,2,1).shape

In [None]:
layer = nn.Linear(5,1)

In [None]:
out = layer(temp.permute(0, 1, 3, 2))

In [None]:
out.shape

In [None]:
# Drop the trailing size-1 axis left by the Linear(5, 1) layer.
# -1 infers the batch size instead of assuming 128.
out = out.view(-1, 321, 12)

In [None]:
out.shape

In [None]:
# NOTE(review): `gl` is not defined in any visible cell, so this raises
# NameError on a fresh kernel. Presumably it is a ProbAttention instance
# (imported at the top but never instantiated) — define it above and confirm.
# It is called with `out` passed as both arguments.
gl_out = gl(out, out)

In [None]:
gl_out.shape

In [None]:
# Build a list holding the same `gl_out` tensor 16 times.
# NOTE(review): 16 is a magic number — presumably the number of stacks; confirm.
attention_matrix = [gl_out for _stack in range(16)]

In [None]:
attention_matrix = torch.stack(attention_matrix, dim=0)

In [None]:
attention_matrix.shape

In [None]:
# Swap the first two axes of `gl_out`.
# NOTE(review): the result is bound back to the same name, so re-running this
# cell swaps the axes back again; a new name would make the cell idempotent.
gl_out = gl_out.permute(1,0,2,3)

In [None]:
# Convert each slice along the first axis of `gl_out` to sparse form and keep
# only the first element of dense_to_sparse's return value (the edge index).
temp = [dense_to_sparse(gl_out[head_idx])[0] for head_idx in range(gl_out.shape[0])]

In [None]:
# Stack the per-head edge indices along a new leading axis.
# NOTE(review): torch.stack needs every entry to have the same shape, i.e. each
# head must yield the same number of edges — verify this holds for `gl_out`.
temp = torch.stack(temp, axis=0)

In [None]:
temp.shape

In [None]:
temp

In [None]:
# Message-passing stack with hidden_dim 12, message_norm and GLU switched on.
mpglu = InterCorrealtionStack(hidden_dim=12, message_norm=True, GLU=True)

In [None]:
mpglu

In [None]:
inpts = layer(data_batch.x.permute(0, 2, 1))

In [None]:
# Remove only the trailing size-1 axis produced by the Linear(5, 1) layer.
# A bare squeeze() would also drop any other axis that happens to have size 1,
# silently changing the tensor's rank.
inpts = inpts.squeeze(-1)

In [None]:
x = mpglu(inpts, temp[0])

In [None]:
x.shape

In [None]:
# Run the same message-passing stack once per head, each with that head's edges.
_multi_head = [mpglu(inpts, temp[head]) for head in range(temp.shape[0])]

In [None]:
_multi_head = torch.stack(_multi_head,axis=0)

In [None]:
_multi_head.shape

In [None]:
_multi_head.device

In [None]:
weight = nn.Linear(4,1, bias=False)

In [None]:
# Create the parameter in this cell: it used to be defined only in a later
# cell, so assigning it here raised NameError after Restart & Run All.
__parameter = nn.Parameter(torch.randn(4))
# Replace the layer's (1, 4) weight with the 1-D, length-4 parameter.
weight.weight = __parameter

In [None]:
weight.weight

In [None]:
__parameter = nn.Parameter(torch.randn(4))

In [None]:
__parameter

## Outer Attention Layer

In [None]:
# Attention Layer Preprocess input X -> Q, K, V

# Demo sizes (none of these names were defined in any visible cell).
# batch=3 and L_Q=L_K=28 match the hard-coded [3, 2, 28, 28] tensor built in
# the "Update context" section; d_model is an arbitrary demo width and must be
# divisible by the number of heads.
batch, L_Q, L_K, d_model = 3, 28, 28, 16

# torch.Tensor(*sizes) returns uninitialized memory (it may hold NaN/inf and
# differs between runs). Draw random values instead so the softmax / top-k
# steps further down operate on well-defined numbers.
queries = torch.randn(batch, L_Q, d_model)
keys = torch.randn(batch, L_K, d_model)
values = torch.randn(batch, L_K, d_model)

In [None]:
# Make Head

# n_heads was not defined in any visible cell; 2 matches the head axis that the
# "Update context" section hard-codes ([3, 2, 28, 28] and torch.arange(2)).
n_heads = 2

B, L, _ = queries.shape
_, S, _ = keys.shape
H = n_heads

# Split the last axis into H heads: (B, len, d_model) -> (B, len, H, d_model // H).
q = queries.view(B, L, H, -1)
k = keys.view(B, S, H, -1)
v = values.view(B, S, H, -1)

In [None]:
q.shape

### Inner Attention

In [None]:
B, L_Q, H, D = q.shape
B, L_K, _, _ = k.shape

In [None]:
# Swap axes 1 and 2 of q/k/v so the head axis comes before the sequence axis:
# (B, len, H, D) -> (B, H, len, D).
queries, keys, values = (t.transpose(1, 2) for t in (q, k, v))

In [None]:
# Sampling num
# NOTE(review): `factor` is not defined in any visible cell — define it before
# running this cell.
U_part = factor * int(np.ceil(np.log(L_K)))  # c*ln(L_k)
u = factor * int(np.ceil(np.log(L_Q)))  # c*ln(L_q)

In [None]:
# Cap both sample counts at the corresponding sequence length.
U_part = min(U_part, L_K)
u = min(u, L_Q)

In [None]:
U_part

In [None]:
u

### Get Sparsity Measurement with random Key sample and Query sample

In [None]:
Q = queries
K = keys

K.shape

In [None]:
B, H, L_K, E = keys.shape
B, _, L_Q, _ = queries.shape

In [None]:
# calculate the sampled Q_K
# Insert a size-1 query axis in front of the key axis, then expand it to L_Q:
# (B, H, L_K, E) -> (B, H, L_Q, L_K, E).
K_expand = K[:, :, None, :, :].expand(B, H, L_Q, L_K, E)
K_expand.shape

In [None]:
# For every query position draw U_part random key indices in [0, L_K).
# NOTE(review): no seed is set in the visible cells, so the sample changes on
# each run.
index_sample = torch.randint(L_K, (L_Q, U_part)) # real U = U_part(factor*ln(L_k))*L_q
index_sample.shape

In [None]:
# Gather, for each query position i, the keys at index_sample[i].
# arange(L_Q)[:, None] is paired with index_sample (L_Q, U_part) by advanced
# indexing, so the result is (B, H, L_Q, U_part, E).
K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample,:]
K_sample.shape

In [None]:
# Dot each query with its own sampled keys:
# (B, H, L_Q, 1, E) @ (B, H, L_Q, E, U_part), then drop the singleton axis.
Q_K_sample = (Q.unsqueeze(-2) @ K_sample.transpose(-2, -1)).squeeze(-2)
Q_K_sample.shape

In [None]:
# find the Top_k query with sparsity measurement
# Per query: largest sampled score minus (sum of sampled scores / L_K).
M = Q_K_sample.max(dim=-1).values - Q_K_sample.sum(dim=-1) / L_K
# Indices of the u queries with the largest measurement (order not sorted).
M_top = M.topk(u, sorted=False).indices
M_top.shape

In [None]:
# Keep only the selected queries. The batch and head index tensors broadcast
# against M_top, so each (batch, head) pair picks its own query rows.
Q_reduce = Q[torch.arange(B)[:, None, None],
             torch.arange(H)[None, :, None],
             M_top, :] # factor*ln(L_q)

Q_reduce.shape

In [None]:
# Scores of the selected queries against every key: Q_reduce @ K^T.
Q_K = Q_reduce @ K.transpose(-2, -1)  # factor*ln(L_q)*L_k
Q_K.shape

In [None]:
scores_top = Q_K
index = M_top

### Get initial context

In [None]:
v.shape

In [None]:
# NOTE(review): `v` is the pre-transpose tensor from the "Make Head" cell, laid
# out as (B, S, H, D); the head-first layout lives in `values`. Unpacking `v`
# in this order therefore binds H to the sequence length and L_V to the head
# count, and overwrites the global B/H/D — confirm `values` was intended.
B, H, L_V, D = v.shape

In [None]:
# NOTE(review): with `v` laid out as (B, S, H, D) (see the "Make Head" cell),
# dim=-2 is the head axis, so this averages over heads rather than over the
# sequence — confirm whether `values` (B, H, S, D) was intended.
v_sum = v.mean(dim=-2)

In [None]:
v_sum.shape

In [None]:
v.shape

In [None]:
# Initial context = mean of V over the key axis, broadcast to every query.
# Recompute from `values` (head-first layout (B, H, L_V, D)) rather than from
# `v`, which is still (B, S, H, D): using `v` averaged over heads and left H
# bound to the sequence length.
B, H, L_V, D = values.shape
v_sum = values.mean(dim=-2)
# clone() materialises the expanded view into its own memory.
contex = v_sum.unsqueeze(-2).expand(B, H, L_Q, D).clone()

In [None]:
contex.shape

### Update context

In [None]:
# Normalise the selected queries' scores over the key axis.
# NOTE(review): the scores are not scaled before the softmax; `sqrt` is
# imported at the top but unused, so a 1/sqrt(D) factor may have been intended
# — confirm.
attn = torch.softmax(scores_top, dim=-1)

In [None]:
# Default attention map: uniform weight 1/L_K over the keys for every query.
# Sizes come from `attn` (batch, heads, selected queries, keys) and L_Q instead
# of the hard-coded [3, 2, 28, 28] / 28, so the cell follows whatever shapes
# were built above.
attns = (
    torch.ones([attn.shape[0], attn.shape[1], L_Q, attn.shape[-1]]) / attn.shape[-1]
).type_as(attn)

In [None]:
attns.shape

In [None]:
# Overwrite the rows of the selected queries (M_top) with their softmax scores.
# The batch and head index ranges are taken from `attn` rather than from the
# global B and a hard-coded 2, so they always agree with the tensor assigned.
attns[torch.arange(attn.shape[0])[:, None, None],
      torch.arange(attn.shape[1])[None, :, None],
      M_top, :] = attn

In [None]:
attns.shape