In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import pyarrow.parquet as pq
import sys
import torch
import random
import torch.nn as nn
import torch.nn.functional as F
import math

sys.path.append('../../../')
from configs.data_configs.rosbank import data_configs
from configs.model_configs.mTAN.rosbank import model_configs
from src.data_load.dataloader import create_data_loaders
from src.models.mTAND.model import MegaEncoder

In [2]:
conf = data_configs()
model_conf = model_configs()

In [3]:
df = pd.read_parquet(conf.train_path)
df.head()

Unnamed: 0,cl_id,amount,event_time,mcc,channel_type,currency,trx_category,trx_count,target_target_flag,target_target_sum
0,10018,"[10.609081944147828, 10.596659732783579, 10.81...","[17120.38773148148, 17133.667800925927, 17134....","[13, 2, 13, 2, 1, 18, 13, 2, 13, 2, 5, 13, 9, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]","[5, 3, 5, 3, 1, 1, 5, 3, 5, 3, 1, 5, 5, 5, 5]",15,0,0.0
1,10030,"[4.61512051684126, 6.90875477931522, 10.598857...","[17141.0, 17141.0, 17145.0, 17147.0, 17147.0, ...","[9, 9, 21, 1, 25, 6, 14, 14, 3, 3, 3, 13, 1, 3...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[5, 5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 7, 1, 1, 3, ...",42,1,59.51
2,10038,"[7.4127640174265625, 7.370230641807081, 7.8180...","[17301.0, 17301.0, 17301.0, 17301.774780092594...","[1, 1, 1, 2, 2, 4, 2, 8, 1, 22, 8, 1, 8, 4, 2,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1, 1, 1, 2, 2, 1, 3, 1, 1, 1, 1, 1, 1, 1, 2, ...",111,0,0.0
3,10057,"[7.494708263135679, 7.736394428979239, 10.7789...","[17151.0, 17151.0, 17153.0, 17154.0, 17155.0, ...","[6, 21, 2, 6, 2, 4, 2, 22, 15, 2, 1, 35, 4, 2,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1, 1, 4, 1, 4, 1, 3, 1, 1, 3, 1, 1, 1, 4, 1, ...",61,1,62961.31
4,10062,"[8.31898612539206, 8.824824939175638, 6.509067...","[17143.0, 17143.0, 17143.0, 17144.0, 17144.0, ...","[80, 15, 37, 38, 11, 11, 2, 24, 7, 5, 5, 11, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, ...",82,1,107126.35


In [4]:
train_loader, valid_loader = create_data_loaders(conf)

2086it [00:00, 20848.94it/s]

9717it [00:00, 17738.15it/s]


In [5]:
for batch in train_loader:
    break

In [6]:
encoder = MegaEncoder(model_conf=model_conf, data_conf=conf)

In [7]:
out = encoder(batch[0])

In [8]:
out.size()

torch.Size([20, 64, 4])

In [9]:
encoder

MegaEncoder(
  (preprocessor): FeatureProcessor(
    (embed_layers): ModuleDict(
      (channel_type): Embedding(400, 16)
      (currency): Embedding(400, 16)
      (mcc): Embedding(400, 16)
      (trx_category): Embedding(400, 16)
    )
  )
  (encoder): enc_mtan_rnn(
    (att): multiTimeAttention(
      (linears): ModuleList(
        (0-1): 2 x Linear(in_features=16, out_features=16, bias=True)
        (2): Linear(in_features=130, out_features=16, bias=True)
      )
    )
    (gru_rnn): GRU(16, 16, batch_first=True, bidirectional=True)
    (hiddens_to_z0): Sequential(
      (0): Linear(in_features=32, out_features=32, bias=True)
      (1): ReLU()
      (2): Linear(in_features=32, out_features=4, bias=True)
    )
    (periodic): Linear(in_features=1, out_features=15, bias=True)
    (linear): Linear(in_features=1, out_features=1, bias=True)
  )
)

In [10]:
features = batch[0]

In [11]:
features.payload

{'amount': tensor([[ 8.6127,  5.8522,  7.8228,  ...,  8.1312,  6.4938,  6.7951],
         [ 7.8317,  8.6127,  5.8522,  ...,  0.0000,  0.0000,  0.0000],
         [ 7.9664,  9.9295,  6.9088,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [12.2061, 11.6784, 11.5129,  ...,  0.0000,  0.0000,  0.0000],
         [12.2061, 11.6784, 11.9184,  ...,  0.0000,  0.0000,  0.0000],
         [11.6784, 11.9184, 10.8198,  ...,  0.0000,  0.0000,  0.0000]],
        dtype=torch.float64),
 'mcc': tensor([[16,  4,  6,  ..., 78,  6,  4],
         [16, 16,  4,  ...,  0,  0,  0],
         [ 6, 99,  3,  ...,  0,  0,  0],
         ...,
         [ 3,  3,  3,  ...,  0,  0,  0],
         [ 3,  3,  3,  ...,  0,  0,  0],
         [ 3,  3,  3,  ...,  0,  0,  0]], dtype=torch.int32),
 'channel_type': tensor([[3, 3, 3,  ..., 3, 3, 3],
         [3, 3, 3,  ..., 0, 0, 0],
         [3, 3, 3,  ..., 0, 0, 0],
         ...,
         [2, 2, 2,  ..., 0, 0, 0],
         [2, 2, 2,  ..., 0, 0, 0],
         [2, 2, 2,  ..., 0

In [12]:
class multiTimeAttention(nn.Module):
    
    def __init__(self, input_dim, nhidden=16, 
                 embed_time=16, num_heads=1):
        super(multiTimeAttention, self).__init__()
        assert embed_time % num_heads == 0
        self.embed_time = embed_time
        self.embed_time_k = embed_time // num_heads
        self.h = num_heads
        self.dim = input_dim
        self.nhidden = nhidden
        self.linears = nn.ModuleList([nn.Linear(embed_time, embed_time), 
                                      nn.Linear(embed_time, embed_time),
                                      nn.Linear(input_dim*num_heads, nhidden)])
        
    def attention(self, query, key, value, mask=None, dropout=None):
        "Compute 'Scaled Dot Product Attention'"
        dim = value.size(-1)
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) \
                 / math.sqrt(d_k)
        scores = scores.unsqueeze(-1).repeat_interleave(dim, dim=-1)
        if mask is not None:
            scores = scores.masked_fill(mask.unsqueeze(-3) == 0, -1e9)
        p_attn = F.softmax(scores, dim = -2)
        if dropout is not None:
            p_attn = dropout(p_attn)
        return torch.sum(p_attn*value.unsqueeze(-3), -2), p_attn
    
    
    def forward(self, query, key, value, mask=None, dropout=None):
        "Compute 'Scaled Dot Product Attention'"
        batch, seq_len, dim = value.size()
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        value = value.unsqueeze(1)
        
        query, key = [l(x).view(x.size(0), -1, self.h, self.embed_time_k).transpose(1, 2)
                      for l, x in tqdm(zip(self.linears, (query, key)))]
        x, _ = self.attention(query, key, value, mask, dropout)
        x = x.transpose(1, 2).contiguous() \
             .view(batch, -1, self.h * dim)

        print(x.size())
        print(self.linears)
        print(self.linears[-1].weight.size())
        return self.linears[-1](x)

class enc_mtan_rnn(nn.Module):
    def __init__(self, input_dim, query, latent_dim=2, nhidden=16, 
                 embed_time=16, num_heads=1, learn_emb=False, device='cuda'):
        super(enc_mtan_rnn, self).__init__()
        self.embed_time = embed_time
        self.dim = input_dim
        self.device = device
        self.nhidden = nhidden
        self.query = query
        self.learn_emb = learn_emb
        self.att = multiTimeAttention(input_dim, nhidden, embed_time, num_heads)
        self.gru_rnn = nn.GRU(nhidden, nhidden, bidirectional=True, batch_first=True)
        self.hiddens_to_z0 = nn.Sequential(
            nn.Linear(2*nhidden, 50),
            nn.ReLU(),
            nn.Linear(50, latent_dim * 2))
        if learn_emb:
            self.periodic = nn.Linear(1, embed_time-1)
            self.linear = nn.Linear(1, 1)
        
    
    def learn_time_embedding(self, tt):
        tt = tt.to(self.device)
        tt = tt.unsqueeze(-1)
        out2 = torch.sin(self.periodic(tt))
        out1 = self.linear(tt)
        return torch.cat([out1, out2], -1)
    
    def fixed_time_embedding(self, pos):
        d_model=self.embed_time
        pe = torch.zeros(pos.shape[0], pos.shape[1], d_model)
        position = 48.*pos.unsqueeze(2)
        div_term = torch.exp(torch.arange(0, d_model, 2) *
                             -(np.log(10.0) / d_model))
        pe[:, :, 0::2] = torch.sin(position * div_term)
        pe[:, :, 1::2] = torch.cos(position * div_term)
        return pe
       
    def forward(self, x, time_steps):
        time_steps = time_steps.cpu()
        mask = x[:, :, self.dim:]
        mask = torch.cat((mask, mask), 2)
        if self.learn_emb:
            key = self.learn_time_embedding(time_steps).to(self.device)
            query = self.learn_time_embedding(self.query.unsqueeze(0)).to(self.device)
        else:
            key = self.fixed_time_embedding(time_steps).to(self.device)
            query = self.fixed_time_embedding(self.query.unsqueeze(0)).to(self.device)
        out = self.att(query, key, x, mask=None)
        out, _ = self.gru_rnn(out)
        out = self.hiddens_to_z0(out)
        return out

In [15]:
ref_points = torch.linspace(0, 1., 5)


am = features.payload['amount']
mcc = features.payload['mcc']

emb1 = nn.Embedding(384, 16)
mcce = emb1(mcc)

am = am.unsqueeze(-1)
x = torch.cat([mcce, am], dim=-1)

time_steps = features.payload['event_time']

In [19]:
input_dim = x.size(-1)
query = ref_points


enc = enc_mtan_rnn(input_dim, query, latent_dim=2, nhidden=16, 
                 embed_time=16, num_heads=2, learn_emb=True, device='cpu')

In [20]:
out = enc(x.float(), time_steps.float())

2it [00:00, 1668.05it/s]

torch.Size([20, 5, 34])
ModuleList(
  (0-1): 2 x Linear(in_features=16, out_features=16, bias=True)
  (2): Linear(in_features=34, out_features=16, bias=True)
)
torch.Size([16, 34])





In [21]:
out[0][0]

tensor([-0.0151,  0.0093, -0.0105,  0.0625], grad_fn=<SelectBackward0>)

In [13]:
class FeatureProcessor(nn.Module):

    def __init__(self, model_conf, data_conf):
        super(FeatureProcessor, self).__init__()
        self.model_conf = model_conf
        self.data_conf = data_conf

        self.emb_names = list(self.data_conf.features.embeddings.keys())
        self.init_embed_layers()

    def init_embed_layers(self):
        self.embed_layers = nn.ModuleDict()
        
        for name in self.emb_names:
            vocab_size = self.data_conf.features.embeddings[name]['max_value']
            self.embed_layers[name] = nn.Embedding(vocab_size, self.model_conf.features_emb_dim)

    def forward(self, padded_batch):
        numeric_values = []

        for key, values in padded_batch.payload.items():
            if key in self.emb_names:
                numeric_values.append(self.embed_layers[key](values))
            else:
                if key == 'event_time':
                    time_steps = values
                else:
                    numeric_values.append(values.unsqueeze(-1).float())
        
        x = torch.cat(numeric_values, dim=-1)
        return x, time_steps


In [14]:
processor = FeatureProcessor(model_conf=model_conf, data_conf=conf)

In [15]:
x, t = processor(features)

In [16]:
x.size()

torch.Size([20, 100, 65])

In [34]:
class MegaEncoder(nn.Module):

    def __init__(self, model_conf, data_conf):
        super(MegaEncoder, self).__init__()
        self.model_conf = model_conf
        self.data_conf = data_conf

        all_emb_size = self.model_conf.features_emb_dim * len(self.data_conf.features.embeddings)
        all_numeric_size = len(self.data_conf.features.numeric_values)
        self.input_dim = all_emb_size + all_numeric_size

        self.preprocessor = FeatureProcessor(model_conf=self.model_conf, data_conf=self.data_conf)

        self.ref_points = torch.linspace(0., 1., self.model_conf.num_ref_points)
        self.encoder = enc_mtan_rnn(
                        self.input_dim,
                        self.ref_points,
                        latent_dim=self.model_conf.latent_dim,
                        nhidden=self.model_conf.ref_point_dim, 
                        embed_time=self.model_conf.time_emb_dim,
                        num_heads=self.model_conf.num_heads_enc,
                        learn_emb=True,
                        device=self.model_conf.device)

    def forward(self, padded_batch):
        x, time_steps = self.preprocessor(padded_batch)
        out = self.encoder(x, time_steps.float())

        return out

In [35]:
enc = MegaEncoder(model_conf=model_conf, data_conf=conf)

In [36]:
batch[0]

<src.data_load.dataloader.PaddedBatch at 0x7fed44502440>

In [37]:
out = enc(batch[0])

2it [00:00, 1010.31it/s]




torch.Size([20, 64, 130])
ModuleList(
  (0-1): 2 x Linear(in_features=16, out_features=16, bias=True)
  (2): Linear(in_features=130, out_features=16, bias=True)
)
torch.Size([16, 130])


In [38]:
out.size()

torch.Size([20, 64, 4])

In [76]:
import math
def learn_time_embedding(tt):
    tt = tt.unsqueeze(-1)
    out2 = torch.sin(periodic(tt))
    out1 = linear(tt)
    return torch.cat([out1, out2], -1)

In [62]:
ref_points = torch.linspace(0, 1., 5)
periodic = nn.Linear(1, 16-1)
linear = nn.Linear(1, 1)

In [70]:
key = learn_time_embedding(time_steps.float())
query = learn_time_embedding(ref_points.unsqueeze(0))

In [73]:
key.size(), query.size()

(torch.Size([20, 100, 16]), torch.Size([1, 5, 16]))

In [116]:
class multiTimeAttention(nn.Module):
    
    def __init__(self, input_dim, nhidden=16, 
                 embed_time=16, num_heads=1):
        super(multiTimeAttention, self).__init__()
        assert embed_time % num_heads == 0
        self.embed_time = embed_time
        self.embed_time_k = embed_time // num_heads
        self.h = num_heads
        self.dim = input_dim
        self.nhidden = nhidden
        self.linears = nn.ModuleList([nn.Linear(embed_time, embed_time), 
                                      nn.Linear(embed_time, embed_time),
                                      nn.Linear(input_dim*num_heads, nhidden)])
        
    def attention(self, query, key, value, mask=None, dropout=None):
        "Compute 'Scaled Dot Product Attention'"
        dim = value.size(-1)
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) \
                 / math.sqrt(d_k)
        scores = scores.unsqueeze(-1).repeat_interleave(dim, dim=-1)
        if mask is not None:
            scores = scores.masked_fill(mask.unsqueeze(-3) == 0, -1e9)
        p_attn = F.softmax(scores, dim = -2)
        if dropout is not None:
            p_attn = dropout(p_attn)
        return torch.sum(p_attn*value.unsqueeze(-3), -2), p_attn
    
    
    def forward(self, query, key, value, mask=None, dropout=None):
        "Compute 'Scaled Dot Product Attention'"
        batch, seq_len, dim = value.size()
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        value = value.unsqueeze(1)
        
        query, key = [l(x).view(x.size(0), -1, self.h, self.embed_time_k).transpose(1, 2)
                      for l, x in tqdm(zip(self.linears, (query, key)))]
        x, _ = self.attention(query, key, value, mask, dropout)
        x = x.transpose(1, 2).contiguous() \
             .view(batch, -1, self.h * dim)

        print(x.size())
        print(self.linears)
        print(self.linears[-1].weight.size())
        return self.linears[-1](x)

In [117]:
2*dim

34

In [118]:
att = multiTimeAttention(dim, 16, 16, 2)

In [119]:
out = att(query, key, x.float())

2it [00:00, 2166.48it/s]

torch.Size([20, 5, 34])
ModuleList(
  (0-1): 2 x Linear(in_features=16, out_features=16, bias=True)
  (2): Linear(in_features=34, out_features=16, bias=True)
)
torch.Size([16, 34])





In [107]:
out.size()

torch.Size([20, 5, 16])