<a href="https://colab.research.google.com/github/ahmedhisham73/AI_in_healthcare/blob/main/Transformerfromscratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch 
import torch.nn as nn

In [None]:
class SelfAttention(nn.Module):
  """Multi-head dot-product attention over (values, keys, queries).

  The embedding dimension is split evenly into `heads` slices of size
  `head_dim`; each slice is projected, attended over independently, and the
  per-head results are concatenated and passed through a final linear layer.

  Args:
    embed_size: Size of the embedding (last) dimension of the inputs.
    heads: Number of attention heads. Must divide `embed_size` exactly.

  Raises:
    AssertionError: If `embed_size` is not divisible by `heads`.
  """

  def __init__(self, embed_size, heads):
    super().__init__()
    self.embed_size = embed_size
    self.heads = heads
    self.head_dim = embed_size // heads

    assert (
        self.head_dim * heads == embed_size
    ), "Embedding size needs to be divisible by heads"

    # Projections act on the last (head_dim) axis, so the same weights are
    # shared by every head.
    self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
    self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
    self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
    # Mixes the concatenated heads back into an embed_size vector.
    self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

  def forward(self, values, keys, queries, mask):
    """Compute attention output for a batch.

    Args:
      values: Tensor reshaped here to (N, value_len, heads, head_dim), i.e.
        expected as (N, value_len, embed_size).
      keys: Tensor expected as (N, key_len, embed_size). `key_len` must equal
        `value_len` for the attention/value contraction below.
      queries: Tensor expected as (N, query_len, embed_size).
      mask: Optional tensor broadcastable to (N, heads, query_len, key_len).
        Positions where `mask == 0` are excluded from attention. Pass `None`
        to attend everywhere.

    Returns:
      Tensor of shape (N, query_len, embed_size).
    """
    # Number of examples in the batch.
    batch = queries.shape[0]
    value_len, key_len, query_len = (
        values.shape[1],
        keys.shape[1],
        queries.shape[1],
    )

    # Split the embedding into self.heads pieces of size head_dim.
    values = values.reshape(batch, value_len, self.heads, self.head_dim)
    keys = keys.reshape(batch, key_len, self.heads, self.head_dim)
    queries = queries.reshape(batch, query_len, self.heads, self.head_dim)

    # Apply the learned per-head projections declared in __init__.
    values = self.values(values)
    keys = self.keys(keys)
    queries = self.queries(queries)

    # Dot product of every query with every key, per example and per head.
    # queries: (N, query_len, heads, head_dim)
    # keys:    (N, key_len, heads, head_dim)
    # energy:  (N, heads, query_len, key_len)
    energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

    if mask is not None:
      # A very large negative score becomes ~0 probability after softmax.
      energy = energy.masked_fill(mask == 0, float("-1e20"))

    # Normalise over the key axis.
    # NOTE: scores are scaled by sqrt(embed_size), as in the original code,
    # rather than by the per-head sqrt(head_dim).
    attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)

    # attention: (N, heads, query_len, key_len)
    # values:    (N, value_len, heads, head_dim)
    # result:    (N, query_len, heads, head_dim), then the last two axes are
    #            flattened back into a single embedding axis.
    out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(
        batch, query_len, self.heads * self.head_dim
    )

    return self.fc_out(out)

  # Original method name, kept so existing callers of `forward_pass` still work.
  forward_pass = forward




class transformer_block(nn.Module):
  """Attention sub-layer followed by a feed-forward sub-layer.

  Each sub-layer's output is added to its input (skip connection), passed
  through LayerNorm, and then through dropout.

  Args:
    embed_size: Size of the embedding dimension.
    heads: Number of attention heads passed to `SelfAttention`.
    dropout: Dropout probability passed to `nn.Dropout`.
    forward_expansion: Multiplier for the hidden width of the feed-forward
      network (`forward_expansion * embed_size`).
  """

  def __init__(self, embed_size, heads, dropout, forward_expansion):
    super().__init__()
    self.attention = SelfAttention(embed_size, heads)
    self.norm1 = nn.LayerNorm(embed_size)
    self.norm2 = nn.LayerNorm(embed_size)
    self.feed_forward = nn.Sequential(
        nn.Linear(embed_size, forward_expansion * embed_size),
        nn.ReLU(),
        nn.Linear(forward_expansion * embed_size, embed_size),
    )
    self.dropout = nn.Dropout(dropout)

  def forward(self, values, keys, queries, mask):
    """Run one block.

    Args:
      values: Value tensor forwarded to the attention layer.
      keys: Key tensor forwarded to the attention layer.
      queries: Query tensor; also used as the residual input, so the
        attention output must have the same shape.
      mask: Attention mask forwarded to the attention layer (may be `None`).

    Returns:
      Tensor with the same shape as `queries`.
    """
    # Use the attention submodule built in __init__ (constructing a new
    # SelfAttention here would be wrong). `forward_pass` is the entry point
    # SelfAttention defines in this file.
    attention = self.attention.forward_pass(values, keys, queries, mask)

    # Skip connection around attention, then normalise and apply dropout.
    x = self.dropout(self.norm1(attention + queries))
    ff_out = self.feed_forward(x)
    # Skip connection around the feed-forward net, then normalise and dropout.
    out = self.dropout(self.norm2(ff_out + x))
    return out


# PascalCase alias: other code in this file refers to the class by this name.
TransformerBlock = transformer_block



class Encoder(nn.Module):
  """Container for the encoder's embeddings and stack of transformer blocks.

  Args:
    src_vocab_size: Number of rows in the word-embedding table.
    embed_size: Embedding dimension used by both embedding tables and by
      every block.
    num_layers: Number of `transformer_block` instances to stack.
    heads: Number of attention heads per block.
    device: Stored on the instance as `self.device`; not otherwise used in
      the constructor.
    forward_expansion: Feed-forward width multiplier passed to each block.
    dropout: Dropout probability for the blocks and for `self.dropout`.
    max_length: Number of rows in the position-embedding table.
  """

  def __init__(
      self,
      src_vocab_size,
      embed_size,
      num_layers,
      heads,
      device,
      forward_expansion,
      dropout,
      max_length,
  ):
    super().__init__()
    self.embed_size = embed_size
    self.device = device
    self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
    self.position_embedding = nn.Embedding(max_length, embed_size)

    # The block class is defined in this file as `transformer_block`; the
    # previous `TransformerBlock` reference did not match that definition.
    self.layers = nn.ModuleList(
        [
            transformer_block(
                embed_size,
                heads,
                dropout=dropout,
                forward_expansion=forward_expansion,
            )
            for _ in range(num_layers)
        ]
    )

    self.dropout = nn.Dropout(dropout)


     




