<a href="https://colab.research.google.com/github/ML-Hesham/Calestiniks-Gym/blob/main/Copy_of_tries.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, repeat
from isab_pytorch import ISAB

# helpers

def exists(val):
    """Tell whether ``val`` holds a value, i.e. is anything other than ``None``."""
    if val is None:
        return False
    return True

def batched_index_select(values, indices):
    """Gather whole feature vectors from ``values`` along dim 1, per batch entry.

    ``indices`` is indexed as a 2-D tensor; every index is broadcast across the
    last dimension of ``values`` so that ``gather`` copies the complete vector
    stored at that position.
    """
    feature_dim = values.size(-1)
    # repeat each index over the feature axis, as required by gather
    expanded = indices[:, :, None].expand(-1, -1, feature_dim)
    return torch.gather(values, 1, expanded)

# helper classes

class Residual(nn.Module):
    """Wrap a callable so that its output is added back onto its input."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        # keyword arguments are handed through untouched to the wrapped callable
        branch = self.fn(x, **kwargs)
        return branch + x

class PreNorm(nn.Module):
    """Apply ``nn.LayerNorm(dim)`` to the input before calling the wrapped callable."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        # only ``x`` is normalised; keyword arguments reach ``fn`` unchanged
        normed = self.norm(x)
        return self.fn(normed, **kwargs)

class FeedForward(nn.Module):
    """Two-layer MLP: ``dim -> dim * mult -> dim`` with GELU and dropout in between."""

    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        hidden_dim = dim * mult
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x, **kwargs):
        # keyword arguments are accepted for call-site uniformity and ignored
        return self.net(x)

# adjacent attention class

class AdjacentAttention(nn.Module):
    """Multi-head attention in which each node attends only to gathered neighbours.

    Queries, keys and values come from one bias-free projection of the input.
    For every node the keys/values of its neighbours are gathered by index, a
    learned per-head "null" key/value is prepended so that a node may attend to
    nothing, and a softmax over that neighbour axis weights the values.

    Args:
        dim: feature size of the input and of the output.
        dim_head: size of each attention head.
        heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 4,
        dropout = 0.
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        hidden_dim = dim_head * heads

        self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias = False)
        self.to_out = nn.Linear(hidden_dim, dim)

        # learned slot every node can attend to instead of a real neighbour
        self.null_k = nn.Parameter(torch.randn(heads, dim_head))
        self.null_v = nn.Parameter(torch.randn(heads, dim_head))

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x,
        adj_kv_indices,
        mask
    ):
        """Attend from every node to its listed neighbours.

        Args:
            x: 3-D tensor unpacked as (batch, nodes, dim).
            adj_kv_indices: neighbour indices, rearranged as (batch, nodes, neighbours).
            mask: tensor rearranged as (batch, nodes, neighbours); after ``.bool()``
                a false entry marks a neighbour slot that is only padding.

        Returns:
            The result of ``to_out`` applied to the merged heads.
        """
        heads = self.heads
        batch, num_nodes, _ = x.shape

        # one flat index list per (batch, head) pair
        flat_indices = repeat(adj_kv_indices, 'b n a -> (b h) (n a)', h = heads)

        queries, keys, values = (
            rearrange(part, 'b n (h d) -> b h n d', h = heads)
            for part in self.to_qkv(x).chunk(3, dim = -1)
        )

        def gather_neighbors(t):
            # pull, for every node, the entries of its neighbours
            t = rearrange(t, 'b h n d -> (b h) n d')
            t = batched_index_select(t, flat_indices)
            return rearrange(t, '(b h) (n a) d -> b h n a d', h = heads, n = num_nodes)

        def null_slot(param):
            # broadcast the per-head null vector to one slot per node
            return rearrange(param, 'h d -> () h () () d').expand(batch, -1, num_nodes, 1, -1)

        # slot 0 of the neighbour axis is the null key / value
        keys = torch.cat((null_slot(self.null_k), gather_neighbors(keys)), dim = -2)
        values = torch.cat((null_slot(self.null_v), gather_neighbors(values)), dim = -2)
        # ... and that slot is always allowed
        mask = F.pad(mask, (1, 0), value = 1)

        sim = einsum('b h n d, b h n a d -> b h n a', queries, keys) * self.scale

        # padded neighbour slots get the most negative representable score
        fill_value = -torch.finfo(sim.dtype).max
        keep = rearrange(mask.bool(), 'b n a -> b () n a')
        sim.masked_fill_(~keep, fill_value)

        attn = self.dropout(sim.softmax(dim = -1))

        # weighted average of the neighbour values, then merge the heads
        out = einsum('b h n a, b h n a d -> b h n d', attn, values)
        merged = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(merged)

# adjacent network (layers of adjacent attention)

class AdjacentAttentionNetwork(nn.Module):
    """Stack of attention layers restricted to the edges of an adjacency matrix.

    Each of the ``depth`` layers applies, in order: a residual pre-norm
    ``AdjacentAttention``, an optional pre-norm ``ISAB`` global attention whose
    output is added to the stream (only when ``num_global_nodes > 0``), and a
    residual pre-norm ``FeedForward``.

    Args:
        dim: feature size of every node.
        depth: number of layers.
        dim_head: size of each head of the adjacent attention.
        heads: number of attention heads (also handed to ``ISAB``).
        num_neighbors_cutoff: optional hard limit on the neighbours kept per
            node; when the densest node exceeds it, neighbours are sub-sampled
            at random.
        num_global_nodes: number of induced points handed to ``ISAB``; ``0``
            disables the global attention.
        attn_dropout: dropout on the adjacent attention weights.
        ff_dropout: dropout inside the feed-forward blocks.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 4,
        num_neighbors_cutoff = None,
        num_global_nodes = 0,
        attn_dropout = 0.,
        ff_dropout = 0.
    ):
        super().__init__()
        self.num_neighbors_cutoff = num_neighbors_cutoff
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            global_attn = PreNorm(dim, ISAB(
                dim = dim,
                heads = heads,
                num_induced_points = num_global_nodes
            )) if num_global_nodes > 0 else None

            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, AdjacentAttention(
                    dim = dim,
                    dim_head = dim_head,
                    heads = heads,
                    dropout = attn_dropout
                ))),
                global_attn,
                Residual(PreNorm(dim, FeedForward(
                    dim = dim,
                    dropout = ff_dropout
                )))
            ]))

    def forward(self, x, adjacency_mat, mask = None):
        """Run every layer over the graph described by ``adjacency_mat``.

        Args:
            x: node features; dim 1 is taken as the number of nodes.
            adjacency_mat: adjacency matrix combined with a boolean identity via
                ``|`` (presumably a boolean (batch, nodes, nodes) tensor —
                confirm against callers). It is NOT modified: self-loops and
                padding masking are applied to a copy.
            mask: optional node mask indexed as (batch, nodes); edges touching a
                masked-out node are removed. It is also passed to the global
                attention when that is enabled.

        Returns:
            The node features after all layers.
        """
        device, n = x.device, x.shape[1]

        diag = torch.eye(adjacency_mat.shape[-1], device = device).bool()
        # nodes should pay attention to themselves (self-interacting).
        # Out-of-place on purpose: `|=` would write the self-loops into the
        # tensor owned by the caller.
        adjacency_mat = adjacency_mat | diag

        # zero out points on adjacency matrix
        # where the nodes are just padding
        if exists(mask):
            adjacency_mat = adjacency_mat & (mask[:, :, None] * mask[:, None, :])

        adj_mat = adjacency_mat.float()

        # if we don't set a hard limit to the number of neighbors:
        #   - get the maximum number of neighbors and pad the rest of the nodes with less than that number of neighbors
        # else:
        #   - randomly sample the cutoff number of neighbors for any node that exceeds the max
        #   - this would be similar to random sparse attention (bigbird)

        # get the maximum number of neighbors
        max_neighbors = int(adj_mat.sum(dim = -1).max())

        if exists(self.num_neighbors_cutoff) and max_neighbors > self.num_neighbors_cutoff:
            # to randomly sample the neighbors, add a small uniform noise to the mask and topk
            noise = torch.empty((n, n), device = device).uniform_(-0.01, 0.01)
            adj_mat = adj_mat + noise

            adj_mask, adj_kv_indices = adj_mat.topk(dim = -1, k = self.num_neighbors_cutoff)

            # cast the mask back to 0s and 1s
            adj_mask = (adj_mask > 0.5).float()
        else:
            # todo - get distribution of number of neighbors, and strategically break up attention (message passing) to multiple steps
            #      - start with a bimodal num neighbors test case, then generalize

            # use topk to get all the neighbors
            # also pass the mask into the attention, as some neighbors will be just padding and not actually neighbors
            adj_mask, adj_kv_indices = adj_mat.topk(dim = -1, k = max_neighbors)

        for attn, global_attn, ff in self.layers:
            x = attn(
                x,
                adj_kv_indices = adj_kv_indices,
                mask = adj_mask
            )

            if exists(global_attn):
                # the wrapped ISAB returns a pair; only its first element is used
                out, _ = global_attn(x, mask = mask)
                x = x + out

            x = ff(x)

        return x

### embedding


In [None]:
import numpy as np
import torch
from torch import nn


class PositionalEmbedding(nn.Module):
    def __init__(self,
                 max_seq_len,
                 dim_m,
                 vocab_size,
                 emb_size=None,
                 embeddings=None):
        """Token embeddings projected to the model size plus fixed positional codes.

        Args:
            max_seq_len (int): Max length of the sequence.
            dim_m (int): Model dimension.
            vocab_size (int): Vocabulary size; used only when `embeddings` is not given.
            emb_size (int, optional): Embedding size. Not needed when `embeddings` is given.
            embeddings (torch.Tensor, optional): Tensor `(vocab_size, emb_size)` of embedding weights.
              The embedding size is taken from its shape and the weights are stored with
              `requires_grad=False`.

        Inputs:
            - **input**: Long tensor of shape `(batch, seq_len)` - input sequence. Index `0` is treated
              as padding and receives the all-zero positional code.

        Outputs:
            - **output**: Float tensor of shape `(batch, seq_len, dim_m)` - output sequence.

        Notes:
            - Model dimension and embedding size do not have to be equal: a bias-free alignment layer
              projects the embedding to the model size.
        """
        super().__init__()
        self.dim_m = dim_m
        # one extra row: position 0 is reserved for padding
        self.max_seq_len = max_seq_len + 1

        self.positional = nn.Embedding(self.max_seq_len, dim_m, padding_idx=0)

        if embeddings is not None:
            emb_size = embeddings.shape[1]
            self.embedding = nn.Embedding(*embeddings.shape)
            self.embedding.weight = nn.Parameter(
                embeddings, requires_grad=False)
        else:
            self.embedding = nn.Embedding(vocab_size, emb_size)

        self.alignment = nn.Linear(emb_size, dim_m, bias=False)

        self.reset_parameters()

    def forward(self, input):
        is_padding = input == 0

        # positions count from 1; padded tokens are sent to position 0
        positions = self.position_mask(input)
        positions.masked_fill_(is_padding, 0)

        aligned = self.alignment(self.embedding(input))
        return aligned + self.positional(positions)

    def reset_parameters(self):
        """Replace the positional weights with a fixed, non-trainable sin/cos table."""
        dims = np.arange(0, self.dim_m)
        # Lookup table for position codes: (max_seq_len, dim_m)
        table = np.stack([
            self.sin_position_scale(position, dims)
            for position in range(self.max_seq_len)
        ])
        # row 0 (padding) is skipped and keeps its raw scaled values
        body = table[1:]
        body[:, ::2] = np.sin(body[:, ::2])
        body[:, 1::2] = np.cos(body[:, 1::2])
        self.positional.weight = nn.Parameter(
            torch.tensor(table, dtype=torch.float), requires_grad=False)

    def sin_position_scale(self, pos, i):
        """Position scaling :math:`pos / 10000^{i / dim_m}` for Sinusoidal Positional Encoding.

        Args:
            pos (int): Position index.
            i (numpy.ndarray): Dimension indexes.

        Returns:
            numpy.ndarray: Scaled value for every dimension index.
        """
        wavelength = np.power(1e4, i / self.dim_m)
        return pos / wavelength

    @staticmethod
    def position_mask(tensor):
        """Generate 1-based position indices for a batch of sequences.

        Args:
            tensor (torch.tensor): a 2-D tensor of shape `(batch_size, seq_len)`; only its shape
              and device are used.

        Returns:
            torch.tensor: a long tensor `(batch_size, seq_len)` holding `1..seq_len` in every row.

        """
        # Maybe it would be more productive to use a global buffer of positions `(max_batch_size, max_seq_len)`
        # and get a mask from this buffer using slicing.
        batch_size, seq_len = tensor.shape
        one_row = torch.arange(
            1, seq_len + 1, dtype=torch.long, device=tensor.device)
        return one_row.repeat(batch_size, 1)

### position_wise


In [None]:
from torch import nn


class PositionWise(nn.Module):
    def __init__(self, dim_m, dim_i, dropout=0.1):
        """Position-wise feed-forward sublayer with residual connection and layer norm.

        Args:
            dim_m (int): input and output dimension.
            dim_i (int): inner dimension.
            dropout (float, optional): dropout probability, applied after the second linear layer.

        Inputs:
            - **input** of shape `(batch, *, dim_m)`: a float tensor.

        Outputs:
            - **output** of shape `(batch, *, dim_m)`: a float tensor.
        """
        super().__init__()

        widen = nn.Linear(dim_m, dim_i)
        narrow = nn.Linear(dim_i, dim_m)
        self.feedforward = nn.Sequential(
            widen, nn.ReLU(), narrow, nn.Dropout(dropout))
        self.normalization = nn.LayerNorm(dim_m, eps=1e-12)

    def forward(self, input):
        # post-norm residual: normalise the sum of the transformed and the raw input
        transformed = self.feedforward(input)
        return self.normalization(transformed + input)

### transformer

In [None]:
import torch
from torch import nn




class TransformerEncoderLayer(nn.Module):
    def __init__(self, dim_m, dim_q_k, dim_v, n_heads, dim_i, dropout):
        """Transformer encoder layer built on `AdjacentAttentionNetwork`.

        Args:
            dim_m (int): Dimension of model.
            dim_q_k (int): Head dimension of the attention (passed as `dim_head`).
            dim_v (int): Accepted for interface compatibility but unused: `AdjacentAttention`
              uses a single head dimension for query, key and value.
            n_heads (int): Number of attention heads.
            dim_i (int): Inner dimension of feed-forward position-wise sublayer.
            dropout (float): Dropout probability.

        Inputs:
            - **input** of shape `(batch, enc_seq_len, dim_m)`, a float tensor, where `batch` is batch size,
              `enc_seq_len` is length of encoder sequence for this batch and `dim_m` is hidden size of model.
              Input embedding has `dim_m` size too.
            - **mask** (optional): assumed to have shape `(batch, enc_seq_len, enc_seq_len)` with
              non-zero entries marking ILLEGAL connections, the same convention as
              `Transformer.autoregressive_mask` - TODO confirm against callers.

        Outputs:
            - **output** of shape `(batch, seq_len, dim_m)`, a float tensor.
        """
        super().__init__()

        # AdjacentAttentionNetwork takes keyword-only arguments; a single inner
        # layer (depth=1) is used per encoder layer.
        self.attention = AdjacentAttentionNetwork(
            dim=dim_m,
            depth=1,
            dim_head=dim_q_k,
            heads=n_heads,
            attn_dropout=dropout,
            ff_dropout=dropout,
        )
        self.positionwise = PositionWise(dim_m, dim_i, dropout)

    def forward(self, input, mask=None, return_attn_weight=False):
        """Self-attend over `input`, then apply the position-wise sublayer.

        Raises:
            NotImplementedError: if `return_attn_weight` is true, because
              `AdjacentAttentionNetwork` does not return attention weights.
        """
        if return_attn_weight:
            raise NotImplementedError(
                'AdjacentAttentionNetwork does not expose attention weights')

        adjacency = self._legal_adjacency(input, mask)
        enc_att = self.attention(input, adjacency)
        return self.positionwise(enc_att)

    @staticmethod
    def _legal_adjacency(input, mask):
        """Build a fresh boolean adjacency matrix of the connections that are allowed."""
        if mask is None:
            # plain self-attention: every position may attend to every position
            batch, seq_len = input.shape[0], input.shape[1]
            return torch.ones(
                batch, seq_len, seq_len, dtype=torch.bool, device=input.device)
        # `~` allocates a new tensor, so the caller's mask is never written to
        return ~mask.bool()


class TransformerDecoderLayer(nn.Module):
    def __init__(self, dim_m, dim_q_k, dim_v, n_heads, dim_i, dropout):
        """Transformer decoder layer built on `AdjacentAttentionNetwork`.

        Args:
            dim_m (int): Dimension of model.
            dim_q_k (int): Head dimension of the attention (passed as `dim_head`).
            dim_v (int): Accepted for interface compatibility but unused: `AdjacentAttention`
              uses a single head dimension for query, key and value.
            n_heads (int): Number of attention heads.
            dim_i (int): Inner dimension of feed-forward position-wise sublayer.
            dropout (float): Dropout probability.

        Inputs:
            - **input** of shape `(batch, dec_seq_len, dim_m)`, a float tensor, where `batch` is batch size,
              `dec_seq_len` is length of decoder sequence for this batch and `dim_m` is hidden size of model.
              Input embedding has `dim_m` size too.
            - **encoder_output** of shape `(batch, enc_seq_len, dim_m)`, a float tensor, where `enc_seq_len` is length
              of encoder sequence.
            - **mask** of shape `(batch, dec_seq_len, dec_seq_len)`, a byte tensor whose non-zero entries
              mark illegal connections between decoder tokens. It is used to preserve the
              auto-regressive property. `None` allows every connection.
            - **extra_mask** (optional): assumed to have shape `(batch, dec_seq_len, enc_seq_len)` with
              non-zero entries marking illegal decoder-to-encoder connections - TODO confirm.

        Outputs:
            - **output** of shape `(batch, dec_seq_len, dim_m)`, a float tensor.

        Notes:
            - `AdjacentAttentionNetwork` only attends within one set of nodes. Attention over the encoder
              output is therefore emulated on a joint graph: decoder and encoder states are concatenated
              along the sequence axis, decoder nodes are connected to the encoder nodes, and only the
              decoder rows of the result are kept. The network itself adds a self-loop to every node.
        """
        super().__init__()

        # AdjacentAttentionNetwork takes keyword-only arguments; a single inner
        # layer (depth=1) is used for each attention stage.
        self.masked_attention = AdjacentAttentionNetwork(
            dim=dim_m,
            depth=1,
            dim_head=dim_q_k,
            heads=n_heads,
            attn_dropout=dropout,
            ff_dropout=dropout,
        )
        self.attention = AdjacentAttentionNetwork(
            dim=dim_m,
            depth=1,
            dim_head=dim_q_k,
            heads=n_heads,
            attn_dropout=dropout,
            ff_dropout=dropout,
        )
        self.positionwise = PositionWise(dim_m, dim_i, dropout)

    def forward(self, input, encoder_output, mask, extra_mask=None):
        batch, dec_len = input.shape[0], input.shape[1]
        enc_len = encoder_output.shape[1]
        device = input.device

        # stage 1: masked self-attention among the decoder tokens
        if mask is None:
            self_adjacency = torch.ones(
                batch, dec_len, dec_len, dtype=torch.bool, device=device)
        else:
            # `~` allocates a new tensor, so the caller's mask is never written to
            self_adjacency = ~mask.bool()
        dec_att = self.masked_attention(input, self_adjacency)

        # stage 2: decoder nodes attend to encoder nodes on a joint graph
        total = dec_len + enc_len
        adjacency = torch.zeros(
            batch, total, total, dtype=torch.bool, device=device)
        if extra_mask is None:
            adjacency[:, :dec_len, dec_len:] = True
        else:
            adjacency[:, :dec_len, dec_len:] = ~extra_mask.bool()

        nodes = torch.cat((dec_att, encoder_output), dim=1)
        # keep only the rows that belong to the decoder tokens
        adj_att = self.attention(nodes, adjacency)[:, :dec_len]

        output = self.positionwise(adj_att)

        return output


class Transformer(nn.Module):
    def __init__(self,
                 max_seq_len,
                 vocab_size,
                 emb_size=250,
                 embeddings=None,
                 n_layers=6,
                 dim_m=512,
                 dim_q_k=64,
                 dim_v=64,
                 n_heads=8,
                 dim_i=2048,
                 dropout=0.1):
        """Encoder-decoder Transformer after the 'Attention Is All You Need' paper.

        Args:
            max_seq_len (int): Maximum sequence length.
            vocab_size (int): Vocabulary size.
            emb_size (int, optional): Embedding size. You do not need to specify a value if you are using
              embedding weights.
            embeddings (torch.Tensor, optional): Tensor of shape `(vocab_size, emb_size)` - embedding weights.
              The embedding size is then taken from the shape of this tensor.
            n_layers (int, optional): Number of encoder layers and of decoder layers.
            dim_m (int, optional): Model hidden size.
            dim_q_k (int, optional): Dimension of `query` & `key` attention projections.
            dim_v (int, optional): Dimension of `value` attention projection.
            n_heads (int, optional): Number of attention heads.
            dim_i (int, optional): Inner dimension of feed-forward position-wise sublayer.
            dropout (float, optional): Dropout probability.

        Variables:
            - **encoder_state**: the cached output of the last encoder layer, or `None` when nothing
              is cached.

        Inputs:
            - **enc_seq** of shape `(batch, enc_seq_len)`, a long tensor encoder input sequence.
            - **dec_seq** of shape `(batch, dec_seq_len)`, a long tensor decoder input sequence.

        Outputs:
            - **output** of shape `(batch, dec_seq_len, vocab_size)`, a float tensor of unnormalised
              vocabulary scores. `self.softmax` is created here but is not applied by `forward`.

        Notes:
            - The encoder state is cached on the module and computed only while the cache is empty, so
              `enc_seq` is ignored on later calls. Call :func:`Transformer.reset_encoder_state` before
              processing a new batch.
        """
        super().__init__()

        self.positional_encoding = PositionalEmbedding(
            max_seq_len, dim_m, vocab_size, emb_size, embeddings)
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(dim_m, dim_q_k, dim_v, n_heads, dim_i,
                                    dropout) for _ in range(n_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            TransformerDecoderLayer(dim_m, dim_q_k, dim_v, n_heads, dim_i,
                                    dropout) for _ in range(n_layers)
        ])
        # Kept as a Sequential so that a smoother dim_m -> vocab_size transition
        # (extra layers) can be inserted without renaming the parameters.
        self.out = nn.Sequential(
            nn.Linear(dim_m, vocab_size),
        )
        self.softmax = nn.Softmax(-1)

        self.encoder_state = None

    def forward(self, enc_seq, dec_seq):
        # Calculate encoder state for batch (only while nothing is cached).
        if self.encoder_state is None:
            # Sum embeddings with positional encodings.
            state = self.positional_encoding(enc_seq)

            for enc_layer in self.encoder_layers:
                state = enc_layer(state)

            # Publish the state only once the whole stack has succeeded, so a
            # failing layer cannot leave a half-encoded tensor in the cache.
            self.encoder_state = state

        # Decoder block.
        # Apply positional encoding.
        dec_state = self.positional_encoding(dec_seq)

        mask = self.autoregressive_mask(dec_seq)

        for dec_layer in self.decoder_layers:
            dec_state = dec_layer(dec_state, self.encoder_state, mask)

        output = self.out(dec_state)

        return output

    def reset_encoder_state(self):
        """Drop the cached encoder state. Call this before processing a new batch."""
        self.encoder_state = None

    @staticmethod
    def autoregressive_mask(tensor):
        """Generate auto-regressive mask for tensor. It is used to preserve the auto-regressive property.

        Args:
            tensor (torch.Tensor): 2-D tensor of shape `(batch, seq_len)`; only its shape and device
              are used.

        Returns:
            torch.Tensor: a byte mask tensor of shape `(batch, seq_len, seq_len)` in which entry
            `[b, i, j]` is 1 when `j > i`, i.e. it marks the illegal attention connections from a
            decoder token to later tokens.

        """
        batch_size, seq_len = tensor.shape
        # strictly upper-triangular ones: position i must not see positions j > i
        x = torch.ones(
            seq_len, seq_len, device=tensor.device).tril(-1).transpose(0, 1)

        return x.repeat(batch_size, 1, 1).byte()