In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
try:
  import einops
  from einops import rearrange,reduce,repeat
except ImportError:
  os.system('pip install einops')
  from einops import rearrange,reduce,repeat
import math

In [None]:
class MHA(nn.Module):
  """Multi-head self-attention with an additive positional embedding.

  The positional embedding is added to the input before the query and key
  projections only; values are projected from the raw input.

  Args:
    dim: model width; must be divisible by ``num_heads``.
    attention_dropout: dropout probability applied to the attention weights
      while the module is in training mode (disabled in eval mode).
    num_heads: number of attention heads.

  Raises:
    ValueError: if ``dim`` is not divisible by ``num_heads``.
  """

  def __init__(self,dim,attention_dropout,num_heads):
    super().__init__()
    if dim % num_heads != 0:
      raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
    self.dim=dim
    self.attention_dropout=attention_dropout
    self.num_heads=num_heads

    self.q=nn.Linear(dim,dim)
    self.k=nn.Linear(dim,dim)
    self.v=nn.Linear(dim,dim)
    self.out=nn.Linear(dim,dim)

  def _split_heads(self,t):
    # (B, T, H*Dh) -> (B, H, T, Dh). The head index is the slower-varying part
    # of the feature axis, i.e. the same "(H D)" layout _merge_heads undoes.
    B,T,_=t.shape
    return t.reshape(B,T,self.num_heads,-1).transpose(1,2)

  @staticmethod
  def _merge_heads(t):
    # (B, H, T, Dh) -> (B, T, H*Dh)
    B,H,T,Dh=t.shape
    return t.transpose(1,2).reshape(B,T,H*Dh)

  def forward(self,x,position_emb,is_casual=False):
    """Attend over ``x``.

    Args:
      x: input of shape (B, T, dim).
      position_emb: positional embedding added to ``x`` for queries/keys; must
        broadcast against ``x`` (e.g. (T, dim) or (B, T, dim)).
      is_casual: if True, apply a causal mask (the misspelt name is kept for
        backward compatibility with existing callers).

    Returns:
      Tensor of shape (B, T, dim).
    """
    assert position_emb is not None

    qk_in=x + position_emb  # shared by the query and key projections
    q=self._split_heads(self.q(qk_in))
    k=self._split_heads(self.k(qk_in))
    v=self._split_heads(self.v(x))

    attn=F.scaled_dot_product_attention(
        q,k,v,
        # SDPA applies dropout whenever dropout_p > 0, so gate it on the mode.
        dropout_p=self.attention_dropout if self.training else 0.0,
        is_causal=is_casual,
    )
    return self.out(self._merge_heads(attn))


In [None]:
class Cross_Attn(nn.Module):
  """Multi-head cross-attention from a query sequence to a key/value memory.

  Positional embeddings are added before the query and key projections;
  values are projected from the raw memory. No mask is applied.

  Args:
    dim: model width; must be divisible by ``num_heads``.
    attention_dropout: dropout probability applied to the attention weights
      while the module is in training mode (disabled in eval mode).
    num_heads: number of attention heads.

  Raises:
    ValueError: if ``dim`` is not divisible by ``num_heads``.
  """

  def __init__(self,dim,attention_dropout,num_heads):
    super().__init__()
    if dim % num_heads != 0:
      raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
    self.dim=dim
    self.attention_dropout=attention_dropout
    self.num_heads=num_heads

    self.k=nn.Linear(dim,dim)
    self.v=nn.Linear(dim,dim)
    self.q=nn.Linear(dim,dim)
    self.out=nn.Linear(dim,dim)

  def _split_heads(self,t):
    # (B, T, H*Dh) -> (B, H, T, Dh), using the same "(H D)" layout that
    # _merge_heads undoes.
    B,T,_=t.shape
    return t.reshape(B,T,self.num_heads,-1).transpose(1,2)

  @staticmethod
  def _merge_heads(t):
    # (B, H, T, Dh) -> (B, T, H*Dh)
    B,H,T,Dh=t.shape
    return t.transpose(1,2).reshape(B,T,H*Dh)

  def forward(self,kv,q,q_embedding,k_embedding):
    """Let each query position attend over the memory ``kv``.

    Args:
      kv: memory of shape (B, Tk, dim) used for keys and values.
      q: queries of shape (B, Tq, dim).
      q_embedding: positional embedding added to ``q``; must broadcast with it.
      k_embedding: positional embedding added to ``kv`` for the keys; must
        broadcast with it.

    Returns:
      Tensor of shape (B, Tq, dim).
    """
    k_heads=self._split_heads(self.k(kv + k_embedding))
    v_heads=self._split_heads(self.v(kv))
    q_heads=self._split_heads(self.q(q + q_embedding))

    attn=F.scaled_dot_product_attention(
        q_heads,k_heads,v_heads,
        # SDPA applies dropout whenever dropout_p > 0, so gate it on the mode.
        dropout_p=self.attention_dropout if self.training else 0.0,
        is_causal=False,
    )
    return self.out(self._merge_heads(attn))



In [None]:
class MLP(nn.Module):
  """Position-wise feed-forward network: Linear -> GELU -> Linear.

  Args:
    dim: input and output width.
    expansion: positive integer hidden-width multiplier; the hidden layer has
      ``dim * expansion`` units. Defaults to 2, the width used originally.

  Raises:
    ValueError: if ``expansion`` is not a positive integer.
  """
  def __init__(self,dim,expansion=2):
    super().__init__()
    if not isinstance(expansion,int) or expansion < 1:
      raise ValueError(f"expansion must be a positive int, got {expansion!r}")
    self.dim=dim
    self.expansion=expansion
    hidden=dim*expansion
    # Layout (and therefore state_dict keys net.0 / net.2) is unchanged.
    self.net=nn.Sequential(
        nn.Linear(dim,hidden),
        nn.GELU(),
        nn.Linear(hidden,dim)
    )
  def forward(self,x):
    """Apply the network along the last axis; output has the same shape as ``x``."""
    return self.net(x)


In [None]:
class Add_Norm(nn.Module):
  """Residual wrapper that normalises the sub-layer input.

  Computes ``x + module(LayerNorm(x), *args, **kwargs)``. Despite the name,
  the LayerNorm is applied before the wrapped module (pre-norm), and the
  residual sum itself is not normalised.

  Args:
    module: the wrapped sub-layer; its first argument receives the
      normalised ``x`` and any extra call arguments are forwarded to it.
    dim: size of the last axis of ``x`` (the LayerNorm width).
  """
  def __init__(self,module,dim):
    super().__init__()
    self.module=module
    self.ln=nn.LayerNorm(dim)
    self.dim=dim

  def forward(self,x,*args,**kwargs):
    normalized=self.ln(x)
    branch=self.module(normalized,*args,**kwargs)
    return x + branch

In [None]:
class Encoder_layer(nn.Module):
  """One encoder block: self-attention, then a feed-forward network.

  Both sub-layers are wrapped in ``Add_Norm`` (pre-norm residual).

  Args:
    dim: model width.
    n_heads: number of attention heads.
    attn_drop: attention dropout probability forwarded to ``MHA``.
  """
  def __init__(self,dim,n_heads,attn_drop=0.):
    super().__init__()
    self_attention=MHA(dim,attn_drop,num_heads=n_heads)
    self.MHA=Add_Norm(self_attention,dim)
    feed_forward=MLP(dim)
    self.ffn=Add_Norm(feed_forward,dim)

  def forward(self,x,position_emb):
    """Run one block; ``position_emb`` is added to queries/keys inside ``MHA``."""
    attended=self.MHA(x,position_emb)
    return self.ffn(attended)

In [None]:
class Decoder_layer(nn.Module):
  """One decoder block with pre-norm residuals on the query stream.

  The block applies three sub-layers in order:

  1. causal self-attention over the decoder tokens (omitted when ``first``);
  2. cross-attention from the decoder tokens to the encoder output;
  3. a feed-forward network.

  Args:
    dim: model width.
    n_heads: number of attention heads.
    attn_drop: attention dropout probability for both attention modules.
    first: if True, skip the self-attention sub-layer and feed ``dec_input``
      straight into cross-attention.
  """
  def __init__(self,dim,n_heads,attn_drop=0.,first=False):
    super().__init__()
    # None rather than nn.Identity(): forward passes extra positional args to
    # the self-attention sub-layer, which nn.Identity.forward cannot accept.
    self.MMHA= None if first else Add_Norm(MHA(dim,attn_drop,n_heads),dim)
    self.cross_attn=Cross_Attn(dim,attn_drop,n_heads)
    # Cross-attention is not wrapped in Add_Norm because Add_Norm would put the
    # residual on its first argument (the encoder memory). Instead, the query
    # stream is normalised and carries the residual; the memory gets its own
    # LayerNorm.
    self.cross_q_ln=nn.LayerNorm(dim)
    self.cross_kv_ln=nn.LayerNorm(dim)
    self.ffn=Add_Norm(MLP(dim),dim)

  def forward(self,dec_input,enc_input,
              q_embedding,k_embedding,k_d_embedding=None):
    """Run one decoder block.

    Args:
      dec_input: decoder tokens of shape (B, Tq, dim).
      enc_input: encoder output of shape (B, Tk, dim).
      q_embedding: positional embedding for the decoder tokens; must broadcast
        with ``dec_input``.
      k_embedding: positional embedding for the encoder keys; must broadcast
        with ``enc_input``.
      k_d_embedding: accepted for interface compatibility; not used.

    Returns:
      Tensor of shape (B, Tq, dim).
    """
    h=dec_input if self.MMHA is None else self.MMHA(dec_input,q_embedding,True)
    h=h + self.cross_attn(self.cross_kv_ln(enc_input),q=self.cross_q_ln(h),
                          q_embedding=q_embedding,k_embedding=k_embedding)
    return self.ffn(h)


In [None]:
class Sepsis_Transformer(nn.Module):
  """Encoder-decoder transformer that classifies a sequence of feature tokens.

  The input tokens pass through the encoder stack with a learnable positional
  embedding. A set of learnable query tokens then runs through the decoder
  stack, cross-attending to the encoder output. One decoder token is projected
  to class logits: the last one, or the mean over all tokens when
  ``mean=True``.

  Args:
    n_heads: attention heads per layer.
    n_encoder_layers: number of encoder blocks.
    n_decoder_layers: number of decoder blocks.
    attn_drop: attention dropout probability forwarded to every layer.
    dim: model width; the input's last axis must have this size.
    feature_count: number of positional-embedding / query rows. The input's
      token axis must broadcast against it.
    mean: pool decoder tokens by mean instead of taking the last token.
    num_classes: number of output logits.
  """

  def __init__(self,n_heads=4,n_encoder_layers=2,n_decoder_layers=2,attn_drop=0.,dim=128,feature_count=1,mean=False,num_classes=2):
    super().__init__()
    self.n_heads=n_heads
    self.n_encoder_layers=n_encoder_layers
    self.n_decoder_layers=n_decoder_layers
    self.attn_drop=attn_drop
    self.dim=dim
    self.feature_count=feature_count
    self.mean=mean  # was hard-coded to False, silently ignoring the argument
    self.learnable_query=nn.Parameter(torch.zeros(feature_count,dim),requires_grad=True)

    self.position_embedding=nn.Parameter(torch.randn(feature_count,dim),requires_grad=True)
    # attn_drop was previously never forwarded, so layers always used 0.
    self.encoder_network=nn.ModuleList([
        Encoder_layer(dim,n_heads,attn_drop) for _ in range(n_encoder_layers)
    ])

    self.decoder_network=nn.ModuleList([
        Decoder_layer(dim,n_heads,attn_drop) for _ in range(n_decoder_layers)
    ])

    self.classification= nn.Sequential(
        nn.Linear(dim,num_classes)
    )

  def forward(self,x):
    """Classify ``x`` of shape (B, T, dim); returns logits of shape (B, num_classes)."""
    B,_,_=x.shape
    for layer in self.encoder_network:
      x=layer(x,position_emb=self.position_embedding)
    encoder_out=x

    dec_in=repeat(self.learnable_query,pattern="T D -> B T D",B=B)
    # Iterate the decoder layers themselves. The old code indexed with the
    # stale encoder loop variable, so it ran one layer repeatedly.
    for layer in self.decoder_network:
      # The learnable query is the fixed per-position query embedding. The old
      # code passed the evolving decoder state, which doubled it inside
      # self-attention.
      dec_in=layer(dec_input=dec_in,enc_input=encoder_out,
                   q_embedding=self.learnable_query,k_d_embedding=dec_in,
                   k_embedding=self.position_embedding)

    final=dec_in.mean(dim=1) if self.mean else dec_in[:,-1,:]
    return self.classification(final)


In [None]:
# Number of feature tokens per sample; each sample is fed to the model as
# (batch, Features, dim).
Features=18
# Seed before construction so parameter initialisation is reproducible on
# Restart & Run All.
torch.manual_seed(0)
model=Sepsis_Transformer(feature_count=Features)

In [None]:
# Smoke test on one random sample. The token width is taken from the model's
# positional embedding instead of a hard-coded 128, so it tracks `dim`.
with torch.no_grad():
  logits=model(torch.randn(1,Features,model.position_embedding.shape[-1]))
logits.shape

torch.Size([1, 18, 128])