In [None]:
import torch as th
import torch.nn.functional as F
import math
import time
import torch.nn as nn
from torch import Tensor

def scaled_dot_product(q, k, v, mask=None):
    """Scaled dot-product attention: softmax(q @ k^T / sqrt(d_k)) @ v.

    Args:
        q: query tensor whose last dimension is d_k.
        k: key tensor whose last two dimensions are (L_k, d_k).
        v: value tensor whose last two dimensions are (L_k, d_v).
        mask: optional tensor broadcastable to the logits (..., L_q, L_k);
            positions where it equals 0 get a logit of -9e15, i.e. ~zero weight.

    Returns:
        The attention-weighted values, shape (..., L_q, d_v).
    """
    scores = th.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    if mask is not None:
        # Large finite negative instead of -inf so fully masked rows do not yield NaN.
        scores = scores.masked_fill(mask == 0, -9e15)
    weights = F.softmax(scores, dim=-1)
    return th.matmul(weights, v)

def batched_scaled_dot_product(q, k, v, mask=None, num_batches=2):
    """Memory-saving variant of `scaled_dot_product` that processes queries in chunks.

    `q` is split along its query-length dimension (-2) into up to `num_batches`
    chunks (torch.chunk may return fewer). Each chunk attends to the full `k`/`v`,
    and the partial results are concatenated back along dim -2. Only one
    (chunk_len, L_k) slice of the attention logits exists at a time.

    Args:
        q, k, v: as for `scaled_dot_product`.
        mask: optional mask broadcastable to the logits (..., L_q, L_k). If its
            dim -2 has size L_q it is chunked exactly like `q`. Otherwise
            (size 1, or rank < 2) it is broadcast unchanged to every chunk.
        num_batches: requested number of query chunks.

    Returns:
        Same as `scaled_dot_product(q, k, v, mask)`, shape (..., L_q, d_v).
    """
    query_chunks = q.chunk(num_batches, dim=-2)
    if mask is not None and mask.dim() >= 2 and mask.size(-2) == q.size(-2):
        # Same dim length + same chunk count -> identical split boundaries as q.
        mask_chunks = mask.chunk(num_batches, dim=-2)
    else:
        mask_chunks = [mask] * len(query_chunks)

    # Collect all partial results and concatenate once (avoids O(n^2) re-copying).
    outputs = [
        scaled_dot_product(q_chunk, k, v, mask=mask_chunk)
        for q_chunk, mask_chunk in zip(query_chunks, mask_chunks)
    ]
    return th.cat(outputs, dim=-2)

In [None]:
# Disabled experiment: scaled_dot_product on large CUDA tensors (single query vs.
# 10000 keys). Kept as a string literal so it does not execute; as the cell's last
# expression the string itself is what Jupyter displays.
'''batch_size = 10
lenght = 10000
dim=2048

q = th.zeros([batch_size,1,dim], device='cuda')
k = th.zeros([batch_size,lenght,dim], device='cuda')
v = th.zeros([batch_size,lenght,dim], device='cuda')

values2 = scaled_dot_product(q,k,v)'''

In [147]:
from urllib.parse import non_hierarchical


class MultiheadAttention(nn.Module):
    """Multi-head self-attention with optional query chunking and last-n query truncation.

    Args:
        input_dim: feature size of the input `x`.
        embed_dim: total attention width, split evenly across `num_heads`.
        num_heads: number of attention heads.
        num_chunks: number of query chunks used when `batched` is True.
        batched: if True, attention is computed by `batched_scaled_dot_product`
            in `num_chunks` query chunks (lower peak memory). Otherwise a single
            `scaled_dot_product` call is made.
        last_n: if not None, only the last `last_n` positions act as queries
            (keys/values still cover the whole sequence), so the output has at
            most `last_n` positions.
    """

    def __init__(self, input_dim, embed_dim, num_heads, num_chunks, batched, last_n=None):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be 0 modulo number of heads."

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.num_chunks = num_chunks
        self.batched = batched
        self.last_n = last_n

        # Stack all weight matrices 1...h together for efficiency
        # Note that in many implementations you see "bias=False" which is optional
        self.qkv_proj = nn.Linear(input_dim, 3*embed_dim)
        self.o_proj = nn.Linear(embed_dim, embed_dim)

        self._reset_parameters()

    def _reset_parameters(self):
        # Original Transformer initialization, see PyTorch documentation
        nn.init.xavier_uniform_(self.qkv_proj.weight)
        self.qkv_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.o_proj.weight)
        self.o_proj.bias.data.fill_(0)

    def forward(self, x, mask=None):
        """Apply self-attention to `x`.

        Args:
            x: input of shape (batch, seq_len, input_dim).
            mask: optional mask broadcastable to (batch, heads, q_len, seq_len);
                entries equal to 0 are masked out. If `last_n` truncates the
                queries and the mask has a full-length query dim, it is
                truncated the same way.

        Returns:
            Tensor of shape (batch, q_len, embed_dim), where q_len is seq_len,
            or min(last_n, seq_len) when `last_n` is set.
        """
        batch_size, seq_length, _ = x.size()
        qkv = self.qkv_proj(x)

        # Separate Q, K, V from linear output
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3*self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3) # [Batch, Head, SeqLen, Dims]
        q, k, v = qkv.chunk(3, dim=-1)

        if self.last_n is not None:
            q = q[:, :, -self.last_n:]
            if mask is not None and mask.dim() >= 2 and mask.size(-2) == seq_length:
                mask = mask[..., -self.last_n:, :]
        # Use the actual query length: slicing clamps when last_n > seq_len,
        # so `last_n` itself is not a safe reshape size.
        q_length = q.size(2)

        # Determine value outputs
        if self.batched:
            values = batched_scaled_dot_product(q, k, v, mask=mask, num_batches=self.num_chunks)
        else:
            values = scaled_dot_product(q, k, v, mask=mask)
        values = values.permute(0, 2, 1, 3) # [Batch, SeqLen, Head, Dims]
        values = values.reshape(batch_size, q_length, self.embed_dim)
        return self.o_proj(values)

In [172]:
from unittest import result
from xml.etree.ElementTree import tostring


class EncoderBlock(nn.Module):
    """Attention + MLP block whose output is the normalized input with the MLP output appended.

    Unlike a standard Transformer encoder layer, `forward` has no residual
    connection. It concatenates the scaled, position-encoded, layer-normalized
    input with the MLP output along the sequence dimension (dim 1).
    """

    def __init__(self, input_dim, num_heads, dim_feedforward, dropout=0.0, batched=True, last_n=None, num_chunks = 10, pe_dropout=0.1):
        """
        Inputs:
            input_dim - Dimensionality of the input
            num_heads - Number of heads to use in the attention block
            dim_feedforward - Dimensionality of the hidden layer in the MLP
            dropout - Dropout probability used inside the MLP
            batched - Forwarded to MultiheadAttention: chunk queries to save memory
            last_n - Forwarded to MultiheadAttention: only the last `last_n` positions are queries
            num_chunks - Number of query chunks used when `batched` is True
            pe_dropout - Dropout probability of the positional encoding. It is
                independent of `dropout`; the default 0.1 matches the
                previously hard-coded value.
        """
        super().__init__()

        # Attention layer (self-attention: input width == embedding width == input_dim)
        self.self_attn = MultiheadAttention(input_dim, input_dim, num_heads, num_chunks=num_chunks, batched=batched, last_n=last_n)

        # Two-layer MLP
        self.linear_net = nn.Sequential(
            nn.Linear(input_dim, dim_feedforward),
            nn.Dropout(dropout),
            nn.ReLU(inplace=True),
            nn.Linear(dim_feedforward, input_dim)
        )

        self.input_dim = input_dim

        # Layers to apply in between the main layers.
        # norm1 and dropout are currently not used by forward(). They are kept so
        # parameter lists and saved state_dicts stay compatible.
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout)
        self.batched = batched
        self.last_n = last_n
        self.num_chunks = num_chunks
        self.positional_encoding = PositionalEncoding(d_model=input_dim, dropout=pe_dropout)
        self.dim_feedforward = dim_feedforward

    def forward(self, x):
        """Encode `x` and append the MLP output to the normalized input.

        Args:
            x: tensor of shape (batch, seq_len, input_dim).

        Returns:
            Tensor of shape (batch, seq_len + q_len, input_dim). The first
            seq_len positions are the scaled, position-encoded, layer-normalized
            input; the remaining q_len positions are the MLP output for the
            attended queries (q_len = seq_len, or the last `last_n` positions).
        """
        # NOTE(review): the standard Transformer scales embeddings by sqrt(d_model)
        # (here `input_dim`), but this scales by sqrt(dim_feedforward). Kept as-is — confirm intent.
        x = x * math.sqrt(self.dim_feedforward)
        # Call modules (not .forward) so registered hooks run.
        x = self.positional_encoding(x)
        # Attention part (no residual / pre-norm: norm1 is unused)
        attn_out = self.self_attn(x)
        # MLP part
        linear_out = self.linear_net(attn_out)
        x = self.norm2(x)
        return th.cat((x, linear_out), dim=1)

class PositionalEncoding(nn.Module):
    """Adds the sinusoidal positional encoding of Vaswani et al. (2017) to the input, then applies dropout.

    Args:
        d_model: feature size of the inputs (odd sizes are supported).
        dropout: dropout probability applied after adding the encoding.
        max_len: longest sequence length that can be encoded.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 50000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = th.arange(max_len).unsqueeze(1)  # (max_len, 1)
        # One frequency 1 / 10000^(2i / d_model) per even feature index 2i.
        div_term = th.exp(th.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = th.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = th.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column,
        # so only the first d_model // 2 frequencies are used for the cosines.
        pe[0, :, 1::2] = th.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim]

        Returns:
            `x` plus the encoding of positions 0..seq_len-1, after dropout (same shape).

        Raises:
            ValueError: if seq_len exceeds the `max_len` given at construction.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {max_len}"
            )
        x = x + self.pe[:, :seq_len]
        return self.dropout(x)

In [173]:
# Small CPU smoke-test block: 2-dim input, one head, no query truncation.
eb = EncoderBlock(
    input_dim=2,
    num_heads=1,
    dim_feedforward=10,
    num_chunks=1,
    batched=True,
    last_n=None,
)

In [182]:
class LongHorizonLearner:
    """Online trainer that repeatedly feeds a growing "horizon" tensor through an encoder block.

    At each step, the current horizon (detached) with one observation appended is
    passed through the encoder block. The block's output, with the step's label
    appended, becomes the new horizon. The horizon is never reset, not even
    between epochs.
    NOTE(review): it therefore eventually exceeds any fixed-length positional
    encoding inside the block.
    """

    def __init__(self, encoder_block: "EncoderBlock") -> None:
        self.encoder_block = encoder_block

    def forward(self, inpt, obsv):
        """Append `obsv` to `inpt` along dim 1 and run the encoder block on the result."""
        x = th.cat((inpt, obsv), dim=1)
        # Call the module (not .forward) so registered hooks run.
        return self.encoder_block(x)

    def learn_seqence(self, sequence: th.Tensor, label: th.Tensor, num_epochs: int = 10000,
                      chunk_size: int = 10000, log_every: int = 1000, lr: float = 1e-3):
        """Train the encoder block online, one time step at a time.

        Args:
            sequence: observations indexed as sequence[:, i] for step i; dim 0 is batch.
            label: targets indexed the same way as `sequence`.
            num_epochs: number of passes over the time steps (previously hard-coded 10000).
            chunk_size: the attention uses max(horizon_len // chunk_size, 1) query chunks.
            log_every: print loss, chunk count and horizon shape every this many steps.
            lr: Adam learning rate.
        """
        # Create the initial horizon on the data's device instead of hard-coding CUDA.
        horizon = th.zeros(sequence.size(0), 1, self.encoder_block.input_dim, device=sequence.device)
        optimizer = th.optim.Adam(self.encoder_block.parameters(), lr=lr)
        steps_per_epoch = sequence.size(1)
        for epoch in range(num_epochs):
            for i in range(steps_per_epoch):
                num_chunks = max(horizon.size(1) // chunk_size, 1)
                self.encoder_block.self_attn.num_chunks = num_chunks
                # Detach so gradients only flow through the current step.
                horizon = self.forward(horizon.detach(), sequence[:, i].unsqueeze(1))
                horizon = th.cat((horizon, label[:, i].unsqueeze(1)), dim=1)
                # [-2] is the encoder's last output position, [-1] the label just appended.
                loss = ((horizon[:, -2] - horizon[:, -1]) ** 2).mean()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if (epoch * steps_per_epoch + i) % log_every == 0:
                    print(loss)
                    print(num_chunks)
                    print(horizon.shape)
                    print('_____________________')
            


In [183]:
# Fall back to CPU when no GPU is present instead of hard-coding CUDA.
DEVICE = 'cuda' if th.cuda.is_available() else 'cpu'
seq_len = 100
batch_size = 2
inpt_dim = 100
eb = EncoderBlock(input_dim=inpt_dim, num_heads=2, dim_feedforward=64, batched=False, last_n=1, num_chunks=1).to(DEVICE)
LHL = LongHorizonLearner(eb)
# Random binary inputs whose target is their bitwise complement.
# NOTE(review): `sequence`/`label` are not used by the training call below,
# which uses `inpt_seq`/`outpt_seq` — confirm which data was intended.
sequence = th.randint(0, 2, [batch_size,seq_len,inpt_dim], device=DEVICE, dtype=th.bool)
label = ~sequence
sequence = sequence.type(th.float)
label = label.type(th.float)


In [184]:
from ActiveCritic.tests.test_utils.utils import make_seq_encoding_data

In [185]:
# Sequence-encoding data from the ActiveCritic test utilities.
# Note: seq_len=5 here, not the `seq_len` variable defined above.
inpt_seq, outpt_seq = make_seq_encoding_data(
    batch_size=2,
    seq_len=5,
    ntoken=inpt_dim,
    d_out=inpt_dim,
)

In [186]:
# First target sequence, kept as a batch of one.
outpt_seq[0:1]

tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
          1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
          1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
          1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
          1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
          1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,


In [187]:
# Runs 10000 epochs and grows the horizon every step — expect to interrupt manually.
LHL.learn_seqence(inpt_seq, outpt_seq)

tensor(0.6391, device='cuda:0', grad_fn=<MeanBackward0>)
1
torch.Size([2, 4, 100])
_____________________
tensor(0.1533, device='cuda:0', grad_fn=<MeanBackward0>)
1
torch.Size([2, 3004, 100])
_____________________
tensor(0.1614, device='cuda:0', grad_fn=<MeanBackward0>)
1
torch.Size([2, 6004, 100])
_____________________
tensor(0.1188, device='cuda:0', grad_fn=<MeanBackward0>)
1
torch.Size([2, 9004, 100])
_____________________
tensor(0.1499, device='cuda:0', grad_fn=<MeanBackward0>)
1
torch.Size([2, 12004, 100])
_____________________
tensor(0.1449, device='cuda:0', grad_fn=<MeanBackward0>)
1
torch.Size([2, 15004, 100])
_____________________
tensor(0.1772, device='cuda:0', grad_fn=<MeanBackward0>)
1
torch.Size([2, 18004, 100])
_____________________
tensor(0.1513, device='cuda:0', grad_fn=<MeanBackward0>)
2
torch.Size([2, 21004, 100])
_____________________
tensor(0.1684, device='cuda:0', grad_fn=<MeanBackward0>)
2
torch.Size([2, 24004, 100])
_____________________
tensor(0.1518, device='cud

KeyboardInterrupt: 

In [None]:
# `eb` was rebound above to a 100-dim CUDA block, so running it on this 2-dim CPU
# input fails under Restart & Run All. Build a matching small block explicitly.
eb_small = EncoderBlock(input_dim=2, num_heads=1, dim_feedforward=10, batched=True, last_n=None, num_chunks=1)
inpt = th.ones([2,2,2])
result = eb_small(inpt)

In [None]:
# Shape of the encoder output (same torch.Size as `.shape`).
result.size()

In [None]:
# Toy task: 10 sequences with one input channel. The first 5 carry a 1 at
# position 0, and their target has a 1 at the final position, channel 0.
seq_len = 1000
inpt = th.zeros(10, seq_len, 1, device='cuda')
inpt[:5, 0, 0] = 1

result = th.zeros(10, seq_len, 10, device='cuda')
result[:5, -1, 0] = 1

In [None]:
# `mha` is otherwise never defined (its construction is commented out in an
# earlier cell). MultiheadAttention now takes `batched` / `last_n` as constructor
# arguments; forward() does not accept them and would raise TypeError.
mha = MultiheadAttention(1, embed_dim=10, num_heads=2, num_chunks=10, batched=True, last_n=1).to('cuda')
res = mha(inpt)

In [None]:
# Output shape of the attention probe (same torch.Size as `.shape`).
res.size()

In [None]:
# Adam over the attention module's parameters (name `optimiter` is used by later cells).
optimiter = th.optim.Adam(params=mha.parameters(), lr=1e-3)


In [None]:
# Time `num_epochs` training steps with the chunked (batched) attention path.
# CUDA events measure GPU time; synchronize before reading them.
num_epochs = 10
mha.batched = True   # forward() no longer takes a `batched` kwarg; it is a module attribute
mha.last_n = None    # NOTE(review): assumes the old forward() defaulted to last_n=None — confirm

start_event = th.cuda.Event(enable_timing=True)
end_event = th.cuda.Event(enable_timing=True)
start_event.record()

for i in range(num_epochs):
    erg = mha(inpt)
    loss = ((erg[:, -1, 0].reshape(-1) - result[:, -1, 0].reshape(-1)) ** 2).mean()
    optimiter.zero_grad()
    loss.backward()
    optimiter.step()

    if i % 10 == 0:
        print(loss)
end_event.record()
th.cuda.synchronize()  # Wait for the events to be recorded!
elapsed_time_ms = start_event.elapsed_time(end_event)
print(elapsed_time_ms)

In [None]:
# NOTE(review): this cell repeats the previous timing run verbatim (batched path).
# If a comparison against the unbatched path was intended, set
# `mha.batched = False` here — TODO confirm.
num_epochs = 10
mha.batched = True   # forward() no longer takes a `batched` kwarg; it is a module attribute
mha.last_n = None    # NOTE(review): assumes the old forward() defaulted to last_n=None — confirm

start_event = th.cuda.Event(enable_timing=True)
end_event = th.cuda.Event(enable_timing=True)
start_event.record()

for i in range(num_epochs):
    erg = mha(inpt)
    loss = ((erg[:, -1, 0].reshape(-1) - result[:, -1, 0].reshape(-1)) ** 2).mean()
    optimiter.zero_grad()
    loss.backward()
    optimiter.step()

    if i % 10 == 0:
        print(loss)
end_event.record()
th.cuda.synchronize()  # Wait for the events to be recorded!
elapsed_time_ms = start_event.elapsed_time(end_event)
print(elapsed_time_ms)

In [None]:
# Train with the unchunked attention path and time each forward pass.
num_epochs = 1000
mha.batched = False  # forward() no longer takes a `batched` kwarg; it is a module attribute
mha.last_n = None    # NOTE(review): assumes the old forward() defaulted to last_n=None — confirm
optimiter = th.optim.Adam(mha.parameters(), lr=1e-3)

forward_times = []
for i in range(num_epochs):
    if inpt.is_cuda:
        # CUDA kernels launch asynchronously; drain queued work so the timer starts clean.
        th.cuda.synchronize()
    h = time.perf_counter()
    erg = mha(inpt)
    if inpt.is_cuda:
        # Wait for the forward kernels, otherwise only launch overhead is measured.
        th.cuda.synchronize()
    forward_times.append(time.perf_counter() - h)

    loss = ((erg[:, -1, 0].reshape(-1) - result[:, -1, 0].reshape(-1)) ** 2).mean()
    optimiter.zero_grad()
    loss.backward()
    optimiter.step()
    if i % 10 == 0:
        print(loss)
        print(forward_times[-1])
        print('_______________________')

if forward_times:
    print(f'mean forward time over {len(forward_times)} steps: {sum(forward_times) / len(forward_times):.6f} s')

In [None]:
# Prediction for sample 0 at the final position, channel 0.
erg[0, -1, 0]

In [None]:
# Matching target entry (set to 1 for samples 0-4 in the toy-data cell).
result[0, -1, 0]