In [0]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [0]:
class TrueGELU(nn.Module):
  """Sigmoid-gated activation: ``x * sigmoid(1.702 * x)``.

  This is the sigmoid-based approximation of GELU; it has no parameters
  and is applied element-wise, so the output has the same shape as the input.
  """

  def forward(self,x):
    """Return ``x`` scaled element-wise by ``sigmoid(1.702 * x)``."""
    gate=torch.sigmoid(x*1.702)
    return gate*x

In [0]:
def weights_init(m,ninp=None):
  """Initialise the parameters of a Linear, Embedding or LayerNorm module in place.

  Meant to be used as ``model.apply(weights_init)``; any other module type is
  left untouched.

  Weights are drawn from ``N(0, (0.1 / sqrt(ninp))**2)``. Linear biases are
  set to 0.0 and LayerNorm biases to 0.01.

  Args:
    m: the module to initialise.
    ninp: width used to scale the standard deviation. When omitted, the
      module's own width is used (``in_features`` for Linear,
      ``embedding_dim`` for Embedding, last normalised dimension for
      LayerNorm). To use one model-wide width, bind it explicitly, e.g.
      ``model.apply(lambda mod: weights_init(mod, ninp=embed_dim))``.

  Returns:
    None. The module is modified in place.
  """
  # This is a free function, so there is no ``self`` to read a width from:
  # derive it from the module unless the caller supplies one.
  if isinstance(m,nn.Linear):
    width,bias_value=m.in_features,0.0
  elif isinstance(m,nn.Embedding):
    width,bias_value=m.embedding_dim,None
  elif isinstance(m,nn.LayerNorm):
    width,bias_value=m.normalized_shape[-1],0.01
  else:
    return
  if ninp is not None:
    width=ninp
  std=float(0.1/np.sqrt(width))

  # weight is None for LayerNorm(elementwise_affine=False)
  if getattr(m,'weight',None) is not None:
    nn.init.normal_(m.weight,mean=0.0,std=std)
  # bias is None for Linear(bias=False); Embedding has no bias at all
  if bias_value is not None and getattr(m,'bias',None) is not None:
    nn.init.constant_(m.bias,bias_value)

In [0]:
class Boom(nn.Module):
  """Feed-forward "boom" layer: expand, apply GELU, then shrink back.

  The input's last dimension is projected from ``input_dim`` up to
  ``blowup_dim`` and passed through GELU and optional dropout. It is then
  reduced back to ``input_dim`` in one of two ways:

  * ``shortcut=False``: a second linear layer ``blowup_dim -> input_dim``.
  * ``shortcut=True``: no second layer; the expanded vector is cut into
    chunks of size ``input_dim`` (any remainder is dropped) and the chunks
    are summed.

  Args:
    input_dim: size of the last dimension of the input.
    blowup_dim: size of the expanded hidden representation.
    dropout: dropout probability; a falsy value disables dropout entirely.
    shortcut: choose the parameter-free chunk-sum reduction instead of ``lin2``.
  """

  def __init__(self,input_dim,blowup_dim=2048,dropout=0.1,shortcut=False):
    super(Boom,self).__init__()
    self.lin1=nn.Linear(input_dim,blowup_dim,bias=True)
    self.dropout=nn.Dropout(dropout) if dropout else None
    # forward() needs to know which reduction was configured
    self.shortcut=shortcut
    if not shortcut:
      self.lin2=nn.Linear(blowup_dim,input_dim)

  def forward(self,x):
    """Apply the layer to ``x`` (last dim ``input_dim``); output has the same shape."""
    x_dim=x.shape[-1]
    x=F.gelu(self.lin1(x))
    # dropout is None when it was disabled at construction time
    if self.dropout is not None:
      x=self.dropout(x)
    if not self.shortcut:
      return self.lin2(x)

    # Trim the tail so the last dimension is an exact multiple of x_dim,
    # split it into (n_chunks, x_dim) and add the chunks together.
    # NOTE: if blowup_dim < input_dim there are zero chunks and the sum is all zeros.
    n_chunks=x.shape[-1]//x_dim
    x=torch.narrow(x,-1,0,n_chunks*x_dim)
    x=x.reshape(*x.shape[:-1],n_chunks,x_dim)
    return x.sum(dim=-2)

In [0]:
# TODO: implement `class SimplifiedAttention(nn.Module)`; only DefaultAttention exists so far.

In [0]:
class DefaultAttention(nn.Module):
  """Thin wrapper around ``nn.MultiheadAttention``.

  Args:
    nhidden: embedding dimension of queries, keys and values.
    q, k, v: accepted for interface compatibility; currently unused.
    num_heads: number of attention heads.
    dropout: dropout probability applied to the attention weights.
  """

  def __init__(self,nhidden,q=True,k=False,v=False,num_heads=1,dropout=0.0):
    super().__init__()
    self.mha=nn.MultiheadAttention(embed_dim=nhidden,num_heads=num_heads,dropout=dropout)

  def forward(self,q,k,v,attn_mask=None):
    """Attend from ``q`` over ``k``/``v``.

    Inputs are sequence-first, since ``batch_first`` is left at its default.

    Returns:
      The ``(attention_output, attention_weights)`` tuple produced by
      ``nn.MultiheadAttention``.
    """
    # Must be passed by keyword: the 4th positional parameter of
    # nn.MultiheadAttention.forward is key_padding_mask, not attn_mask.
    return self.mha(q,k,v,attn_mask=attn_mask)

In [0]:
class Block(nn.Module):
  """One model layer: LayerNorm -> LSTM -> attention over memory -> Boom feed-forward.

  Inputs are sequence-first, because neither ``nn.LSTM`` nor the attention
  module is built with ``batch_first``.

  Args:
    embed_dim: feature size of the hidden state; also the LSTM hidden size.
    hidden_dim: expanded width used inside the Boom feed-forward layer.
    num_heads: number of attention heads.
    dropout: dropout probability shared by all dropout sites in the block.
    residual: if True, add the LSTM output to its (normalised) input;
      otherwise the LSTM output replaces it.
    simpified_attention: reserved for a simplified attention variant, which
      is not implemented. The spelling is kept so existing keyword callers
      keep working.

  Raises:
    NotImplementedError: if ``simpified_attention`` is truthy.
  """

  def __init__(self,embed_dim,hidden_dim,num_heads=1,dropout=0.0,residual=True,simpified_attention=False):
    super().__init__()
    if simpified_attention:
      # Fail at construction instead of leaving self.attention as None
      # and crashing later inside forward().
      raise NotImplementedError("SimplifiedAttention is not implemented")
    self.attention=DefaultAttention(embed_dim,num_heads=num_heads,dropout=dropout)
    self.boom=Boom(embed_dim,hidden_dim,dropout=dropout)
    self.dropout=nn.Dropout(dropout)
    self.gelu=TrueGELU()  # defined in this notebook; not used by forward()
    self.residual=residual
    self.lstm=nn.LSTM(input_size=embed_dim,hidden_size=embed_dim)

    # Separate LayerNorm modules: each one owns its own trainable parameters.
    self.norm0=nn.LayerNorm(embed_dim,eps=1e-12)  # block input
    self.norm1=nn.LayerNorm(embed_dim,eps=1e-12)  # attention query / residual stream
    self.norm2=nn.LayerNorm(embed_dim,eps=1e-12)  # attention memory (keys / values)
    self.norm3=nn.LayerNorm(embed_dim,eps=1e-12)  # residual stream before Boom
    self.norm4=nn.LayerNorm(embed_dim,eps=1e-12)  # Boom input

  def forward(self,h,pos_encoding,attn_mask,mem=None,hidden=None):
    """Run the block.

    Args:
      h: input hidden states with last dimension ``embed_dim``.
      pos_encoding: only its length is used, as the number of most recent
        positions kept in the returned memory.
      attn_mask: attention mask forwarded to the attention module.
      mem: optional memory from earlier segments, concatenated in front of
        the current keys along dim 0.
      hidden: optional initial LSTM state ``(h_0, c_0)``.

    Returns:
      ``(h, new_mem, new_hidden, attention_weights)``.
    """
    h=self.norm0(h)
    # Pass the carried-over recurrent state; None lets the LSTM start from zeros.
    x,new_hidden=self.lstm(h,hidden)

    # Fold the LSTM output into chunks of the input width and sum them.
    # With hidden_size == embed_dim this is a single chunk.
    h_dim=h.shape[-1]
    n_chunks=x.shape[-1]//h_dim
    z=torch.narrow(x,-1,0,n_chunks*h_dim)
    z=z.reshape(*x.shape[:-1],n_chunks,h_dim)
    x=self.dropout(z).sum(dim=-2)
    if self.residual:  # skip connection around the LSTM
      h=h+x
    else:
      h=x.float()

    # Attention: queries come from the current segment; keys and values also
    # cover the stored memory.
    h=self.norm1(h)
    mh=self.norm2(h)
    k=mh if mem is None else torch.cat([mem,mh],dim=0)
    new_mem=k[-len(pos_encoding):]  # keep at most len(pos_encoding) latest positions
    x,attention_weights=self.attention(q=h,k=k,v=k,attn_mask=attn_mask)
    x=self.dropout(x)
    h=x+h

    # Feed-forward (Boom) with a residual connection.
    h,x=self.norm3(h),self.norm4(x)
    x=self.boom(x)
    x=self.dropout(x)
    h=x+h

    return h,new_mem,new_hidden,attention_weights




In [0]:
class SHARNN(nn.Module):
  """Token embedding followed by a stack of ``Block`` layers.

  Args:
    num_embeddings: vocabulary size.
    embed_dim: embedding / hidden feature size.
    hidden_dim: expanded width of each block's feed-forward layer.
    num_layers: number of stacked blocks.
    dropout: dropout probability for the embedding, every block and the output.

  ``forward`` returns hidden states, not logits. ``self.decoder`` is created
  here but not applied by ``forward``.
  """

  def __init__(self,num_embeddings,embed_dim,hidden_dim,num_layers,dropout=0.5):
    super().__init__()
    self.embed_dim=embed_dim
    self.hidden_dim=hidden_dim
    self.num_layers=num_layers
    self.num_embeddings=num_embeddings
    self.max_positions=5000
    self.num_heads=1
    self.dropout=nn.Dropout(dropout)

    # nn.ModuleList registers each block's parameters with this module.
    self.block_list=nn.ModuleList([
      Block(self.embed_dim,self.hidden_dim,num_heads=1,dropout=dropout,residual=False)
      for _ in range(num_layers)
    ])
    # Placeholder: blocks only use its length, as the memory window size.
    self.positional_embedding=[0]*self.max_positions
    self.encoder=nn.Embedding(num_embeddings,embed_dim)
    self.decoder=nn.Linear(embed_dim,num_embeddings)

  def forward(self,x,hidden=None,mems=None):
    """Encode a segment of token ids.

    Args:
      x: integer token ids; dim 0 is treated as the sequence dimension.
      hidden: optional per-layer LSTM states from a previous call.
      mems: optional per-layer attention memories from a previous call.

    Returns:
      ``(h, new_hidden, new_mems)``: the final hidden states and the
      per-layer LSTM states and memories to carry into the next call.
    """
    h=self.dropout(self.encoder(x))
    seq_len=len(x)

    if mems is not None:
      # Keep only as much memory as still fits in the position window.
      # Guard keep == 0 explicitly: m[-0:] would keep the whole memory.
      keep=max(self.max_positions-seq_len,0)
      mems=[m[-keep:] if keep>0 else m[:0] for m in mems]

    # Causal mask: -inf above the diagonal, 0 elsewhere. It is built on the
    # activations' device and dtype.
    attn_mask=torch.full((seq_len,seq_len),float('-inf'),device=h.device,dtype=h.dtype)
    attn_mask=torch.triu(attn_mask,diagonal=1)
    if mems:
      # Every current position may attend to every memory slot.
      max_mems=max(len(m) for m in mems)
      z=torch.zeros((seq_len,max_mems),device=h.device,dtype=h.dtype)
      attn_mask=torch.cat([z,attn_mask],dim=-1)

    new_hidden=[]
    new_mems=[]
    for i,block in enumerate(self.block_list):
      memory=mems[i] if mems else None
      hid=hidden[i] if hidden else None
      h,nm,nh,_=block(h,self.positional_embedding,attn_mask,memory,hid)
      new_hidden.append(nh)
      new_mems.append(nm)

    h=self.dropout(h)

    return h,new_hidden,new_mems