In [None]:
  import torch
import torch.nn as nn
import torch.nn.functional as F
import os
try:
  import einops
  from einops import rearrange,reduce,repeat
except ImportError:
  os.system('pip install einops')
  from einops import rearrange,reduce,repeat
import math
from torch.utils.data import DataLoader,Dataset,Sampler,WeightedRandomSampler
from torch.nn.utils.rnn import pad_packed_sequence,pack_sequence
import numpy as np
from tqdm.notebook import tqdm

In [None]:
class MHA(nn.Module):
  """Multi-head self-attention with a position embedding added to queries and keys.

  Args:
    dim: width of the input/output features (also the total width across heads).
    attention_dropout: dropout probability applied to the attention weights
      while the module is in training mode.
    num_heads: number of attention heads; ``dim`` must be divisible by it.
  """
  def __init__(self,dim,attention_dropout,num_heads):
    super().__init__()
    self.dim=dim
    self.attention_dropout=attention_dropout
    self.num_heads=num_heads

    self.q=nn.Linear(dim,dim)
    self.k=nn.Linear(dim,dim)
    self.v=nn.Linear(dim,dim)
    self.out=nn.Linear(dim,dim)

  def forward(self,x,position_emb,padding_mask,is_casual=False):
    """Attend over ``x``.

    Args:
      x: input of shape (B, T, dim).
      position_emb: tensor broadcastable to ``x``; added before the q/k projections.
      padding_mask: attention mask handed to ``scaled_dot_product_attention``
        (boolean: True = may attend, or an additive float mask). Must be
        broadcastable to (B, H, T, T).
      is_casual: when True a lower-triangular (causal) constraint is merged
        into ``padding_mask``.

    Returns:
      Tensor of shape (B, T, dim).
    """
    H=self.num_heads
    #B T D
    assert position_emb is not None
    assert padding_mask is not None

    # NOTE: heads are split as "(D H)" and merged back as "(H D)"; both sit
    # next to a learned Linear, and the layout is kept for checkpoint compatibility.
    q=rearrange(self.q(x + position_emb),pattern="B T (D H) -> B H T D",H=H)
    k=rearrange(self.k(x + position_emb),pattern="B T (D H) -> B H T D",H=H)
    v=rearrange(self.v(x),pattern="B T (D H) -> B H T D",H=H)

    attn_mask=padding_mask
    if is_casual:
      # SDPA raises when attn_mask and is_causal are both given, so the causal
      # constraint is folded into the explicit mask instead.
      T=x.shape[1]
      causal=torch.ones(T,T,dtype=torch.bool,device=x.device).tril()
      if attn_mask.dtype==torch.bool:
        attn_mask=attn_mask & causal
      else:
        causal_bias=torch.zeros(T,T,dtype=attn_mask.dtype,device=x.device)
        attn_mask=attn_mask + causal_bias.masked_fill(~causal,float("-inf"))

    # Dropout on attention weights is only active during training.
    dropout_p=self.attention_dropout if self.training else 0.0
    attn=F.scaled_dot_product_attention(q,k,v,attn_mask=attn_mask,dropout_p=dropout_p,is_causal=False)
    attn=rearrange(tensor=attn,pattern="B H T D -> B T (H D)")
    attn=self.out(attn)

    return attn


In [None]:
class Cross_Attn(nn.Module):
  """Multi-head cross-attention: queries from one sequence, keys/values from another.

  Args:
    dim: feature width of both sequences.
    attention_dropout: dropout probability applied to the attention weights
      while the module is in training mode.
    num_heads: number of attention heads; ``dim`` must be divisible by it.
  """
  def __init__(self,dim,attention_dropout,num_heads):
    super().__init__()
    self.dim=dim
    self.attention_dropout=attention_dropout
    self.num_heads=num_heads

    self.k=nn.Linear(dim,dim)
    self.v=nn.Linear(dim,dim)
    self.q=nn.Linear(dim,dim)
    self.out=nn.Linear(dim,dim)

    self.norm=nn.LayerNorm(dim)

  def forward(self,kv,q,q_embedding,k_embedding):
    """Attend from ``q`` (plus ``q_embedding``) over ``kv`` (keys also get ``k_embedding``).

    Returns:
      The projected query (heads merged back to ``dim``) plus the
      layer-normalised attention output; shape (B, T_q, dim).
    """
    H=self.num_heads

    k_heads=rearrange(tensor=self.k(kv + k_embedding),pattern="B T (D H) -> B H T D",H=H)
    v_heads=rearrange(tensor=self.v(kv),pattern="B T (D H) -> B H T D",H=H)
    # Kept under a separate name so the `q` argument is no longer shadowed.
    q_heads=rearrange(tensor=self.q(q + q_embedding),pattern="B T (D H) -> B H T D",H=H)

    # Dropout on attention weights is only active during training.
    dropout_p=self.attention_dropout if self.training else 0.0
    attn=F.scaled_dot_product_attention(q_heads,k_heads,v_heads,dropout_p=dropout_p,is_causal=False)
    attn=rearrange(tensor=attn,pattern="B H T D -> B T (H D)")
    attn=self.out(attn)

    # NOTE(review): the residual branch is the *projected* query (not the raw
    # `q` input), merged with "(H D)" although it was split with "(D H)".
    # Behaviour is preserved here; confirm this is the intended skip connection.
    residual=rearrange(tensor=q_heads,pattern="B H T D-> B T (H D)")
    return residual+self.norm(attn)



In [None]:
class MLP(nn.Module):
  """Position-wise feed-forward network: Linear(dim -> 2*dim), GELU, Linear(2*dim -> dim)."""
  def __init__(self,dim):
    super().__init__()
    self.dim=dim
    hidden=dim*2
    # Kept as an nn.Sequential named `net` so parameter names stay net.0.* / net.2.*.
    stages=[nn.Linear(dim,hidden),nn.GELU(),nn.Linear(hidden,dim)]
    self.net=nn.Sequential(*stages)

  def forward(self,x):
    """Apply the feed-forward stack to the last dimension of ``x``."""
    projected=self.net(x)
    return projected


In [None]:
class Add_Norm(nn.Module):
  """Residual wrapper computing ``x + LayerNorm(module(x, ...))``.

  Args:
    module: the wrapped sub-module; its output must have ``dim`` as last dimension.
    dim: feature width normalised by the LayerNorm.
  """
  def __init__(self,module,dim):
    super().__init__()
    self.module=module
    self.dim=dim
    self.ln=nn.LayerNorm(dim)

  def forward(self,x,*args,**kwargs):
    """Run the wrapped module (extra arguments are forwarded), normalise, add the skip."""
    branch=self.module(x,*args,**kwargs)
    normalised=self.ln(branch)
    return x + normalised

In [None]:
class Encoder_layer(nn.Module):
  """One encoder block: position-aware self-attention then a feed-forward net, each wrapped in Add_Norm."""
  def __init__(self,dim,n_heads,attn_drop=0.):
    super().__init__()
    attention=MHA(dim,attn_drop,num_heads=n_heads)
    self.MHA=Add_Norm(attention,dim)
    feed_forward=MLP(dim)
    self.ffn=Add_Norm(feed_forward,dim)

  def forward(self,x,position_emb,padding_mask):
    """Return the block output for ``x``; ``position_emb`` and ``padding_mask`` go to the attention layer."""
    attended=self.MHA(x,position_emb,padding_mask)
    return self.ffn(attended)

In [None]:
class Decoder_layer(nn.Module):
  """Decoder block: (optional) causal self-attention, cross-attention over the encoder output, feed-forward.

  Args:
    dim: feature width.
    n_heads: number of attention heads.
    attn_drop: attention dropout probability handed to the attention layers.
    first: when True the self-attention stage is replaced by an identity.
  """
  def __init__(self,dim,n_heads,attn_drop=0.,first=False):
    super().__init__()
    self.MMHA= nn.Identity() if first else Add_Norm(MHA(dim,attn_drop,n_heads),dim)
    self.cross_attn = Cross_Attn(dim,attn_drop,n_heads)
    self.ffn=Add_Norm(MLP(dim),dim)

  def forward(self,dec_input,enc_input,
              q_embedding,k_embedding,k_d_embedding):
    """Run one decoder step.

    Args:
      dec_input: decoder-side sequence, shape (B, T_q, dim).
      enc_input: encoder output used as keys/values.
      q_embedding: embedding added to the decoder queries.
      k_embedding: embedding added to the encoder keys.
      k_d_embedding: accepted for interface compatibility; currently unused.
    """
    if isinstance(self.MMHA,nn.Identity):
      # nn.Identity.forward accepts a single argument only.
      dec_out=dec_input
    else:
      # MHA expects an attention mask as its third argument (the previous
      # positional `True` landed in that slot). Build an explicit boolean
      # lower-triangular mask (True = may attend) to get causal attention.
      T=dec_input.shape[1]
      causal_mask=torch.ones(T,T,dtype=torch.bool,device=dec_input.device).tril()
      dec_out=self.MMHA(dec_input,q_embedding,causal_mask)

    mlp_out=self.cross_attn(enc_input,q=dec_out,q_embedding=q_embedding,k_embedding=k_embedding)
    out=self.ffn(mlp_out)

    return out


In [None]:
class MMHA_normal(nn.Module):
  """Causally masked multi-head self-attention (no position embedding input).

  Args:
    dim: feature width (total across heads).
    n_heads: number of attention heads; ``dim`` must be divisible by it.
    attn_drop: dropout probability applied to the attention weights while
      the module is in training mode.
  """
  def __init__(self,dim,n_heads,attn_drop=0):
    super().__init__()
    self.dim=dim
    self.attention_dropout=attn_drop
    self.num_heads=n_heads

    self.q=nn.Linear(dim,dim)
    self.k=nn.Linear(dim,dim)
    self.v=nn.Linear(dim,dim)
    self.out=nn.Linear(dim,dim)

  def forward(self,x):
    """Return causal self-attention over ``x`` of shape (B, T, dim); output has the same shape."""
    H=self.num_heads

    # NOTE: heads are split as "(D H)" and merged as "(H D)"; the layout is
    # kept unchanged so existing checkpoints remain compatible.
    k=rearrange(tensor=self.k(x),pattern="B T (D H) -> B H T D",H=H)
    v=rearrange(tensor=self.v(x),pattern="B T (D H) -> B H T D",H=H)
    q=rearrange(tensor=self.q(x),pattern="B T (D H) -> B H T D",H=H)

    # The configured dropout was previously stored but never applied.
    dropout_p=self.attention_dropout if self.training else 0.0
    attn=F.scaled_dot_product_attention(q,k,v,dropout_p=dropout_p,is_causal=True)
    attn=rearrange(tensor=attn,pattern="B H T D -> B T (H D)")
    attn=self.out(attn)

    return attn
class Decoder_block_normal(nn.Module):
  """Decoder-only block: causal self-attention (or identity when ``first``) followed by a feed-forward net."""
  def __init__(self,dim,n_heads,attn_drop=0.,first=False):
    super().__init__()
    if first:
      self.MMHA=nn.Identity()
    else:
      self.MMHA=Add_Norm(MMHA_normal(dim,n_heads,attn_drop),dim)
    self.ffn=Add_Norm(MLP(dim),dim)

  def forward(self,x):
    """Apply the attention stage and then the feed-forward stage to ``x``."""
    return self.ffn(self.MMHA(x))

In [None]:
class Embedding_layer(nn.Module):
  """Embed per-time-step feature tokens and add a time-step (position) embedding.

  Each time step carries ``num_features`` token ids. Every token is embedded,
  the feature axis is collapsed to one vector per step by a 1x1 Conv1d, and a
  position embedding is added.

  Args:
    num_features: number of token ids per time step (Conv1d input channels).
    tokens_count: vocabulary size excluding the padding token.
    dim: embedding width.
    padding_index: token id whose embedding is held at zero.
    learnable_timestep: use a learned time-step table instead of sinusoids.
    time_step_count: size of the learned time-step table.
    device: kept for interface compatibility; the position table now follows
      the device of the input features.
  """

  def __init__(self,
               num_features,
               tokens_count,
               dim,
               padding_index=0,
               learnable_timestep=False,time_step_count=335,device="cuda"):
    super().__init__()
    self.device=device
    self.embedding_layer=nn.Embedding(tokens_count+1,embedding_dim=dim,padding_idx=padding_index)

    self.learnable_timestep=learnable_timestep
    if learnable_timestep:
      self.time_step_embedding=nn.Embedding(time_step_count,embedding_dim=dim,padding_idx=padding_index)

    self.conv1=nn.Conv1d(num_features,1,kernel_size=1)

    self.apply(self._init_weight)

  def forward(self,x):
    """Embed ``x`` (integer ids, unpacked below as (B, T, F)).

    Returns:
      ``(features + time_step, time_step)`` where ``features`` is (B, T, D)
      and ``time_step`` is (T, D).
    """
    features=self.embedding_layer(x)

    # Local name avoids shadowing the module-level `F` (torch.nn.functional).
    B,T,n_feat,D=features.shape
    features=rearrange(tensor=features,pattern="B T F D-> (B T) F D")

    # 1x1 convolution mixes the F feature embeddings into a single channel.
    features=self.conv1(features)
    features= rearrange(tensor=features,pattern="(B T) F D-> B (T F) D",B=B,T=T,F=1)

    if self.learnable_timestep:
      # Index the table by time-step position, not by token id.
      # NOTE(review): requires T <= time_step_count; the table was built with
      # padding_idx=padding_index, so that position's vector stays zero.
      positions=torch.arange(T,device=features.device)
      time_step=self.time_step_embedding(positions)
    else:
      # Follow the input's device instead of a hard-coded one.
      time_step=self.Sinusoidal_position_Emdedding(length=T,d_model=D).to(features.device)

    return features + time_step,time_step

  def _init_weight(self,module):
    """Normal(0, 0.02) init for embeddings and Conv1d weights; zero Conv1d bias."""

    if isinstance(module,nn.Embedding):
      std=0.02
      torch.nn.init.normal_(module.weight, mean=0.0, std=std)
      # normal_ overwrites the zero row nn.Embedding created for padding_idx;
      # restore it so padded positions contribute a zero vector.
      if module.padding_idx is not None:
        with torch.no_grad():
          module.weight[module.padding_idx].zero_()

    if isinstance(module,nn.Conv1d):
      nn.init.normal_(module.weight,mean=0.0,std=0.02)
      if module.bias is not None:
        nn.init.zeros_(module.bias)

  @staticmethod
  def Sinusoidal_position_Emdedding(length=1024,d_model=768,learnable=False):
    """Build a (length, d_model) sinusoidal position table.

    Even columns hold sines and odd columns cosines. Odd ``d_model`` is
    supported (the last sine column has no cosine partner).
    Returns an ``nn.Parameter`` when ``learnable`` is True, else a plain tensor.
    """
    pe = torch.zeros(length, d_model)
    position = torch.arange(0, length).unsqueeze(1).float()
    div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) *
                         -(math.log(10000.0) / d_model)))
    angles = position * div_term
    pe[:, 0::2] = torch.sin(angles)
    # For odd d_model there is one fewer cosine column than sine columns.
    pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])

    return nn.Parameter(pe) if learnable else pe


In [None]:
class Sepsis_Transformer_decoderonly(nn.Module):
  """Decoder-only transformer producing per-time-step class logits.

  Args:
    num_features: number of token ids per time step.
    tokens_count: vocabulary size excluding the padding token.
    padding_index: id of the padding token.
    n_heads: attention heads per block.
    n_layers: number of decoder blocks.
    attn_drop: attention dropout probability forwarded to every block.
    dim: model width.
    num_classes: number of output classes.
    device: device the sub-modules are moved to at construction.
  """
  def __init__(self, num_features,
               tokens_count,
               padding_index=0,
               n_heads=4,
               n_layers=2,
               attn_drop=0.,
               dim=128,
               num_classes=2,device="cuda"):
    super().__init__()
    self.device=device
    self.n_heads=n_heads
    self.n_layers=n_layers
    self.attn_drop=attn_drop

    self.embedding_layer=Embedding_layer(
               num_features,
               tokens_count,
               dim,
               padding_index,device=self.device).to(self.device)

    # attn_drop was previously accepted but never handed to the blocks.
    self.decoder_network=nn.ModuleList([
          Decoder_block_normal(dim,n_heads,attn_drop) for _ in range(n_layers)
      ]).to(self.device)


    self.classification= nn.Sequential(
          nn.Linear(dim,num_classes)
      ).to(self.device)
    self.apply(self._init_layers)

  def forward(self,x):
    """Return logits of shape (B*T, num_classes) for token ids ``x``."""

    x,_=self.embedding_layer(x)

    B,T,D=x.shape

    for block in self.decoder_network:
      x = block(x)

    # reshape (rather than view) also copes with non-contiguous activations.
    x=x.reshape(-1,D)

    x=self.classification(x)

    return x

  def _init_layers(self,module):
    """Normal init for Linear weights with std 0.02 / sqrt(2 * n_layers); zero biases."""
    if isinstance(module, nn.Linear):
            std = 0.02
            std*=(2*self.n_layers) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)



In [None]:
class Sepsis_Transformer(nn.Module):
  """Encoder-decoder transformer: encoder over the token sequence, decoder over 20 learnable queries.

  Args:
    num_features: number of token ids per time step.
    tokens_count: vocabulary size excluding the padding token.
    padding_index: id of the padding token.
    n_heads: attention heads per layer.
    n_encoder_layers / n_decoder_layers: depth of each stack.
    attn_drop: stored on the model (the layers are built with their defaults).
    dim: model width.
    mean: when True the decoder queries are averaged before classification;
      otherwise every query is classified separately.
    num_classes: number of output classes.
  """

  def __init__(self,
               num_features,
               tokens_count,
               padding_index=0,
               n_heads=4,
               n_encoder_layers=2,
               n_decoder_layers=2,
               attn_drop=0.,
               dim=128,
               mean=False,
               num_classes=2):

    super().__init__()

    self.n_heads=n_heads
    self.n_encoder_layers=n_encoder_layers
    self.n_decoder_layers=n_decoder_layers
    self.attn_drop=attn_drop
    self.mean=mean
    self.embedding_layer=Embedding_layer(
               num_features,
               tokens_count,
               dim,
               padding_index)
    self.learnable_query=nn.Parameter(torch.zeros(20,dim),requires_grad=True)

    self.encoder_network=nn.ModuleList([
        Encoder_layer(dim,n_heads) for _ in range(n_encoder_layers)
    ])

    self.enc_layers=n_encoder_layers

    self.decoder_network=nn.ModuleList([
        Decoder_layer(dim,n_heads) for _ in range(n_decoder_layers)
    ])

    self.dec_layers=n_decoder_layers

    self.classification= nn.Sequential(
        nn.Linear(dim,num_classes)
    )

    self.encoder_network.apply(self._init_enc_layers)
    self.decoder_network.apply(self._init_dec_layers)


  def forward(self,x, padding_mask , final_tokens=None):
    """Compute class logits.

    Args:
      x: integer token ids for the embedding layer.
      padding_mask: attention mask forwarded to every encoder layer.
      final_tokens: accepted for interface compatibility; currently unused.

    Returns:
      Logits of shape (B, num_classes) when ``mean`` is True, otherwise
      (B * 20, num_classes), one row per learnable query.
    """
    x,pos_emb=self.embedding_layer(x)

    B,T,D=x.shape

    for encoder_layer in self.encoder_network:
      x=encoder_layer(x,position_emb=pos_emb,padding_mask=padding_mask)

    encoder_out=x
    dec_in=repeat(self.learnable_query,pattern="T D -> B T D",B=B)
    # Iterate the decoder layers themselves; the old loop indexed with the
    # encoder loop variable and so re-applied a single decoder layer.
    for decoder_layer in self.decoder_network:
      dec_in=decoder_layer(dec_input=dec_in,enc_input=encoder_out,
                           q_embedding=dec_in,k_d_embedding=dec_in,k_embedding=pos_emb)

    if not self.mean:
      # Flatten the queries so each one is classified on its own.
      final=dec_in.reshape(-1,D)
    else:
      final=dec_in.mean(dim=1)

    classes = self.classification(final)

    return classes

  @staticmethod
  def _init_linear(module,n_layers):
    """Normal init for a Linear layer with std 0.02 / sqrt(2 * n_layers); zero bias."""
    if isinstance(module, nn.Linear):
      std = 0.02 * (2*n_layers) ** -0.5
      torch.nn.init.normal_(module.weight, mean=0.0, std=std)
      if module.bias is not None:
        torch.nn.init.zeros_(module.bias)

  def _init_enc_layers(self,module):
    """Initialise encoder Linear layers, scaled by the encoder depth."""
    self._init_linear(module,self.enc_layers)

  def _init_dec_layers(self,module):
    """Initialise decoder Linear layers, scaled by the decoder depth."""
    self._init_linear(module,self.dec_layers)


In [None]:
# Level codes listed per feature column. The vocabulary-building cell in this
# notebook walks each list in order and, for every column except 'Gender',
# maps the final entry to that column's "<column>_nan" token.
# NOTE(review): some lists skip a level (e.g. 'Temp' has no 1.0, 'Hgb' no 5,
# 'WBC'/'Platelets' no 8) — presumably levels absent from the data; confirm.
d={'HR': [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
 'O2Sat': [0.0, 1.0, 2.0, 3.0, 4.0, 5.0],
 'Temp': [0.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
 'MAP': [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
 'Resp': [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
 'BUN': [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
 'Chloride': [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
 'Creatinine': [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
 'Glucose': [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
 'Hct': [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
 'Hgb': [0, 1, 2, 3, 4, 6],
 'WBC': [0, 1, 2, 3, 4, 5, 6, 7, 9],
 'Platelets': [0, 1, 2, 3, 4, 5, 6, 7, 9],
 'Age': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
 'Gender': [0, 1]}

In [None]:
# Vocabulary: one integer id per (column, level) pair; id 0 is reserved for padding.
tokens={
    'padding_index':0,
}

index=1
for column,levels in d.items():
  last_position=len(levels)-1
  for position,level in enumerate(levels):
    # Keys use the level *value* rather than its list position, so columns
    # whose level list has gaps (e.g. 'Temp' without 1.0) still get a key that
    # matches the value looked up at encoding time. Previously such columns
    # had shifted keys and a real level silently fell back to "<column>_nan".
    # NOTE: this changes the ids of affected columns; checkpoints trained with
    # the old vocabulary need retraining.
    if position<last_position or column=="Gender":
      tokens[f"{column}_{int(level)}"]=index
    else:
      # The final listed level of every other column is the missing-value token.
      tokens[f"{column}_nan"]=index

    index+=1

In [None]:
tokens

{'padding_index': 0,
 'HR_0': 1,
 'HR_1': 2,
 'HR_2': 3,
 'HR_3': 4,
 'HR_4': 5,
 'HR_5': 6,
 'HR_6': 7,
 'HR_7': 8,
 'HR_8': 9,
 'HR_nan': 10,
 'O2Sat_0': 11,
 'O2Sat_1': 12,
 'O2Sat_2': 13,
 'O2Sat_3': 14,
 'O2Sat_4': 15,
 'O2Sat_nan': 16,
 'Temp_0': 17,
 'Temp_1': 18,
 'Temp_2': 19,
 'Temp_3': 20,
 'Temp_4': 21,
 'Temp_5': 22,
 'Temp_nan': 23,
 'MAP_0': 24,
 'MAP_1': 25,
 'MAP_2': 26,
 'MAP_3': 27,
 'MAP_4': 28,
 'MAP_5': 29,
 'MAP_6': 30,
 'MAP_7': 31,
 'MAP_8': 32,
 'MAP_9': 33,
 'MAP_nan': 34,
 'Resp_0': 35,
 'Resp_1': 36,
 'Resp_2': 37,
 'Resp_3': 38,
 'Resp_4': 39,
 'Resp_5': 40,
 'Resp_6': 41,
 'Resp_nan': 42,
 'BUN_0': 43,
 'BUN_1': 44,
 'BUN_2': 45,
 'BUN_3': 46,
 'BUN_4': 47,
 'BUN_5': 48,
 'BUN_6': 49,
 'BUN_7': 50,
 'BUN_nan': 51,
 'Chloride_0': 52,
 'Chloride_1': 53,
 'Chloride_2': 54,
 'Chloride_3': 55,
 'Chloride_4': 56,
 'Chloride_5': 57,
 'Chloride_nan': 58,
 'Creatinine_0': 59,
 'Creatinine_1': 60,
 'Creatinine_2': 61,
 'Creatinine_3': 62,
 'Creatinine_4': 63,
 'Cre

In [None]:
def impute_row(val,row):
  """Map one cell value of column ``row`` to its token id.

  Values equal to 0.0 .. 9.0 are looked up as "<row>_<level>" via ``get_val``
  (which falls back to the column's nan token when the key is missing); any
  other value (including NaN) maps directly to "<row>_nan".
  """
  for level in range(10):
    if val==float(level):
      return get_val(tokens,f"{row}_{level}")
  return tokens[f'{row}_nan']

def get_val(tokens,val):
  """Return ``tokens[val]``, falling back to the column's "<column>_nan" token.

  Args:
    tokens: mapping from "<column>_<level>" keys to integer ids.
    val: key to look up.

  Raises:
    KeyError: if neither ``val`` nor the column's nan key exists.
  """
  try:
    return tokens[val]
  except KeyError:
    # Strip only the trailing "_<level>" so column names that themselves
    # contain underscores resolve to the right nan token.
    row=val.rsplit('_',1)[0]
    return tokens[f"{row}_nan"]
def impute_logic(df):
  """Replace, in place, every column of ``df`` by the token ids of its values (returns None)."""
  for column in df.columns:
    encoded=df[column].apply(impute_row,row=column)
    df[column]=encoded


In [None]:
class sepsis_dataset(Dataset):
  """Per-patient dataset: one item is a patient's hourly rows encoded as token ids plus labels.

  Args:
    df: frame with 'Patient_ID', 'Hour', 'SepsisLabel' and the feature columns.
    idx: indexable collection of patient ids; defines length and item order.
    device: device the returned tensors are placed on. Defaults to CUDA when
      available (the previous hard-coded behaviour), otherwise CPU.
  """

  def __init__(self,df,idx,device=None):

    self.df=df

    self.idx=idx

    if device is None:
      device="cuda" if torch.cuda.is_available() else "cpu"
    self.device=device


  def __len__(self):
    return len(self.idx)

  def __getitem__(self,ids):
    """Return ``(X, y)``: long tensors of shape (T, n_features) and (T,), ordered by 'Hour'."""

    # `patient_id` instead of `id`, which shadowed the builtin.
    patient_id=self.idx[ids]

    patient_df=self.df[self.df['Patient_ID']==patient_id]
    patient_df=patient_df.sort_values("Hour")
    patient_df=patient_df.drop(["Hour","Patient_ID"],axis=1)

    # drop() returns a new frame, so the in-place encoding below does not touch self.df.
    X = patient_df.drop('SepsisLabel',axis=1)
    impute_logic(X)

    y = patient_df['SepsisLabel'].to_numpy()

    features=torch.from_numpy(X.to_numpy()).long().to(self.device)
    labels=torch.from_numpy(y).long().to(self.device)
    return features,labels


In [None]:
def train_on_epoch(model,optimizer,train_loader,device,current_epoch):
  """Early draft of the training loop for a model called as ``model(x_pad, x_mask)``.

  NOTE(review): a function with the same name but a different signature is
  defined again later in this notebook and shadows this one. This draft has
  several visible problems, flagged inline; it looks superseded — confirm and
  delete if so.
  """
  running_loss=0.
  accuracy=0.

  for idx,(x_pad,y_pad,x_len,y_len) in enumerate(train_loader):
    # NOTE(review): create_encoder_masks as defined in this notebook takes a
    # second required argument D, so this call raises TypeError.
    x_mask=create_encoder_masks(x_len)
    y_mask=create_mask(y_len)
    with torch.autocast(device_type=device,dtype=torch.bfloat16):

      y_pred=model(x_pad,x_mask)

      """Evaluvate each time step"""

      B,D=y_pred.shape
      y_pred=y_pred.view(-1,D)
      # NOTE(review): the target is derived from the predictions, not from y_pad.
      y=y_pred.view(-1,)
      # Default reduction yields a scalar, so the mask below only rescales it.
      prediction=F.cross_entropy(y_pred,y)

      loss = prediction * y_mask

    # NOTE(review): `loss` is mask-shaped here; backward() and .item() need a scalar.
    loss.backward()

    running_loss+= loss.item()

    class_labels=y_pred.argmax(dim=-1)

    # NOTE(review): compares predicted labels against the logits, not the targets.
    total_correct= (class_labels==y_pred).sum()

    accuracy += total_correct / y_pred.shape[0]

    if idx%100==0:
      print(f"""
        EPOCH : {current_epoch} , \n
        idx : {idx}, \n
        loss:{loss.item()}, \n
        accuracy: {total_correct / y_pred.shape[0]} \n
      """)

  # NOTE(review): `optimizer` is never stepped or zeroed in this draft.
  return running_loss, accuracy





In [None]:
def validate_one_epoch(model,validation_loader,device,current_epoch):
  """Evaluate ``model`` on ``validation_loader`` without gradient tracking.

  The loss is the per-sequence mean of the masked cross-entropy, averaged over
  the batch; accuracy counts only non-padded time steps. Every 100th batch is
  logged to ``ValLog{current_epoch}.txt``.

  NOTE: a function with the same name is defined again later in this notebook
  and shadows this one.

  Returns:
    ``(mean loss per validation batch, mean per-batch accuracy)``.
  """
  model.eval()
  running_loss=0.
  acc=[]
  n_batches=len(validation_loader)
  with torch.no_grad():
    for idx,(x_pad,y_pad,x_len,y_len)  in enumerate(tqdm(validation_loader,total=n_batches)):
      B,T,_=x_pad.shape
      # Honour the `device` argument instead of relying on the dataset placing tensors.
      x_pad,y_pad=x_pad.to(device),y_pad.to(device)
      mask=create_mask(y_len).to(device)
      with torch.autocast(device_type=device,dtype=torch.bfloat16):

        y_pred= model(x_pad)

        _,D=y_pred.shape
        loss= F.cross_entropy(y_pred.view(-1,D),y_pad.view(-1,),reduction='none')

        masked_loss=loss.view(B,-1)*mask
        avg_loss = torch.mean(masked_loss.sum(dim=1)/ mask.sum(dim=1))

      running_loss+=avg_loss.item()

      # Masked accuracy over the valid (non-padded) steps of the whole batch.
      y_pred=y_pred.view(B,T,D).argmax(dim=-1)
      batch_acc=((y_pred==y_pad)*mask).sum()/mask.sum()

      acc.append(batch_acc.item())
      if idx%100==0:
        s=f"""
            EPOCH : {current_epoch} , \n
            idx : {idx}, \n
            loss:{avg_loss.item()}, \n
            accuracy: {batch_acc} \n
          """
        mode='w' if idx==0 else 'a'
        with open(f"ValLog{current_epoch}.txt",mode) as f:
          f.write(s)

    # Average over the validation batches (previously divided by the length of
    # the global train_loader, under-reporting the loss).
    return running_loss/n_batches, sum(acc)/len(acc)

In [None]:
from torch.nn.utils.rnn import pad_sequence
def collate_fn(batch):
  """Pad a list of ``(x, y)`` sequence pairs into batch-first tensors.

  Returns:
    ``(x_padded, y_padded, x_lengths, y_lengths)``; padding value is 0 and the
    lengths are long tensors holding the original sequence lengths.
  """
  features,labels = zip(*batch)
  feature_lengths=torch.tensor([len(seq) for seq in features],dtype=torch.long)
  label_lengths=torch.tensor([len(seq) for seq in labels],dtype=torch.long)
  padded_features=pad_sequence(features,batch_first=True,padding_value=0)
  padded_labels=pad_sequence(labels,batch_first=True,padding_value=0)

  return padded_features,padded_labels,feature_lengths,label_lengths


In [None]:
def create_mask(y_lens):
  """Build a (B, max_len) float mask on the CPU: 1 for the first ``y_lens[i]`` steps of row i, else 0.

  ``y_lens`` is a 1-D integer tensor of sequence lengths.
  """
  batch_size=len(y_lens)
  longest=int(y_lens.max())

  # Compare each step index against every row's length in one shot.
  steps=torch.arange(longest).unsqueeze(0)
  valid=steps < y_lens.cpu().unsqueeze(1)

  mask=torch.zeros(batch_size,longest)
  mask[valid]=1.

  return mask

def create_encoder_masks(x_lens,D):
  """Build a (B, max_len, D) boolean mask on the CPU: True on the first ``x_lens[i]`` steps of row i.

  NOTE(review): a (B, T, D) mask is not generally broadcastable to the
  (B, H, T, T) attention-mask shape expected by scaled_dot_product_attention;
  confirm how callers consume it.
  """
  longest=int(x_lens.max())

  steps=torch.arange(longest).unsqueeze(0)
  valid=steps < x_lens.cpu().unsqueeze(1)

  # Materialise the trailing dimension so the result is a contiguous tensor.
  return valid.unsqueeze(-1).repeat(1,1,D)

In [None]:
# 15 = num_features (keys in `d`), 123 = tokens_count (sum of the level-list lengths in `d`).
# NOTE: `model` is re-created in the optimizer set-up cell further down.
model=Sepsis_Transformer_decoderonly(15,123,n_heads=4,attn_drop=0.1)

In [None]:
import gdown
import pandas as pd

# Define the Google Drive file IDs
file_ids = {
    "train_idx_df": "1JRWgOjcafwIw-FZ0KE2AN9end1LVSu4Q",
    "train_idxs": "1TYwxzVuiY6BgMObPruJxOY7_5ZrKKKys",
    "val_idx_df": "16hbvVRacYSXGPLysvlOXbD80FcWbYsFt",
    "val_idxs": "14MdOgJ0AAnYSHXGnTz812wYQkSlEpbPe"
}

# Download the files (skipped when a copy already exists, so re-running the
# notebook does not hit the network again; delete the CSV to force a refresh)
for name, file_id in file_ids.items():
    output = f"{name}.csv"
    if os.path.exists(output):
        continue
    url = f"https://drive.google.com/uc?id={file_id}"
    gdown.download(url, output, quiet=False)

# Read the downloaded files
train_idx_df = pd.read_csv("train_idx_df.csv")
train_idxs = pd.read_csv("train_idxs.csv")
val_idx_df = pd.read_csv("val_idx_df.csv")
val_idxs = pd.read_csv("val_idxs.csv")

Downloading...
From: https://drive.google.com/uc?id=1JRWgOjcafwIw-FZ0KE2AN9end1LVSu4Q
To: /content/train_idx_df.csv
100%|██████████| 2.00M/2.00M [00:00<00:00, 204MB/s]
Downloading...
From: https://drive.google.com/uc?id=1TYwxzVuiY6BgMObPruJxOY7_5ZrKKKys
To: /content/train_idxs.csv
100%|██████████| 92.1k/92.1k [00:00<00:00, 60.8MB/s]
Downloading...
From: https://drive.google.com/uc?id=16hbvVRacYSXGPLysvlOXbD80FcWbYsFt
To: /content/val_idx_df.csv
100%|██████████| 313k/313k [00:00<00:00, 83.2MB/s]
Downloading...
From: https://drive.google.com/uc?id=14MdOgJ0AAnYSHXGnTz812wYQkSlEpbPe
To: /content/val_idxs.csv
100%|██████████| 34.0k/34.0k [00:00<00:00, 23.5MB/s]


In [None]:
# Colab-only: mounts Google Drive at /content/drive.
# NOTE(review): nothing in the visible cells reads from the mount — confirm it is still needed.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Per-class patient counts ordered by class label. value_counts() alone orders
# by frequency, which only matches label order while class 0 is the majority;
# the weights derived from `vals` are indexed by label.
vals=train_idxs['HasSepsis'].value_counts().sort_index().values

In [None]:
# Weight of each class = 1 - (its share of the patients), so the rarer class gets the larger weight.
class_weights=1.0- vals /sum(vals)

In [None]:
class_weights

array([0.16431063, 0.83568937])

In [None]:
# One sampling weight per patient, in the row order of train_idxs.
# Each patient id must occur once for a per-row lookup to be well defined.
assert train_idxs['Patient_ID'].is_unique
patient_labels=train_idxs['HasSepsis'].to_numpy().astype(int)
# Vectorised lookup replaces a per-row DataFrame filter (quadratic in the
# number of patients). Each weight stays wrapped in a one-element list so the
# resulting tensor keeps its (N, 1) shape.
weights=[[w] for w in class_weights[patient_labels].tolist()]

In [None]:
# torch is already imported in the first cell; this re-import is a no-op.
import torch
# Per-patient sampling weights as a CUDA tensor (requires a GPU).
weight=torch.tensor(weights).cuda()

In [None]:
weight.shape

torch.Size([10894, 1])

In [None]:
# Class-balanced sampling over patients (with replacement).
# num_samples is the number of patients — the unit the dataset and `weight`
# are indexed by. It previously used len(train_idx_df), the number of hourly
# rows, which made one "epoch" draw several times the patient count.
# NOTE(review): restore len(train_idx_df) if that oversampling was intentional.
sampler=WeightedRandomSampler(
    weights=weight.view(-1,),
    num_samples=len(train_idxs),
    replacement=True
)

In [None]:
# Batch size shared by both loaders.
B=16
# Training batches are drawn through the weighted sampler; collate_fn pads
# the variable-length patient sequences.
train_loader=DataLoader(
    sepsis_dataset(train_idx_df,train_idxs['Patient_ID'].values),batch_size=B,collate_fn=collate_fn,sampler=sampler
)

# Validation iterates the patients once, in the order of val_idxs.
validation_loader=DataLoader(
    sepsis_dataset(val_idx_df,val_idxs['Patient_ID'].values),batch_size=B,collate_fn=collate_fn
)

In [None]:
model

Sepsis_Transformer_decoderonly(
  (embedding_layer): Embedding_layer(
    (embedding_layer): Embedding(124, 128, padding_idx=0)
    (conv1): Conv1d(15, 1, kernel_size=(1,), stride=(1,))
  )
  (decoder_network): ModuleList(
    (0-1): 2 x Decoder_block_normal(
      (MMHA): Add_Norm(
        (module): MMHA_normal(
          (q): Linear(in_features=128, out_features=128, bias=True)
          (k): Linear(in_features=128, out_features=128, bias=True)
          (v): Linear(in_features=128, out_features=128, bias=True)
          (out): Linear(in_features=128, out_features=128, bias=True)
        )
        (ln): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      )
      (ffn): Add_Norm(
        (module): MLP(
          (net): Sequential(
            (0): Linear(in_features=128, out_features=256, bias=True)
            (1): GELU(approximate='none')
            (2): Linear(in_features=256, out_features=128, bias=True)
          )
        )
        (ln): LayerNorm((128,), eps=1e-05, ele

In [None]:
# Scratch check for the masked-accuracy computation: random predictions,
# labels and sequence lengths (no seed is set, so values differ between runs).
y_pred=torch.randn(16,15,2).argmax(dim=-1)
y=torch.randint(low=0,high=2,size=(16,15))

y_lens=torch.randint(low=0,high=10,size=(16,)).long()

y_lens

tensor([3, 8, 5, 4, 5, 7, 8, 2, 3, 0, 8, 2, 7, 2, 2, 3])

In [None]:
len(y_pred)

16

In [None]:
# Count correct predictions over only the first y_lens[i] steps of each row.
s=0
for i in range(len(y_pred)):
  s+=(y_pred[i,:y_lens[i]]==y[i,:y_lens[i]]).sum()
s

tensor(38)

In [None]:
def train_on_epoch(model,optimizer,train_loader,gradient_accum,device,current_epoch):
  """Train ``model`` for one pass over ``train_loader`` with gradient accumulation.

  The loss is the per-sequence mean of the masked cross-entropy, averaged over
  the batch. The optimizer steps after every ``gradient_accum`` batches and
  once more for a trailing partial window. Each optimizer step appends a log
  entry to ``TrainLog{current_epoch}.txt``.

  Args:
    model: called as ``model(x_pad)`` and expected to return (B*T, n_classes) logits.
    optimizer: optimizer over the model parameters.
    train_loader: yields ``(x_pad, y_pad, x_len, y_len)`` batches.
    gradient_accum: number of batches accumulated per optimizer step.
    device: device string used for tensors and autocast (e.g. "cuda").
    current_epoch: epoch number, used in the log file name.

  Returns:
    ``(mean unscaled loss per batch, mean per-batch accuracy)``.
  """
  model.train()
  running_loss=0.
  acc=[]
  gradient_accumulation_steps=gradient_accum
  n_batches=len(train_loader)
  log_mode='w'  # the first write of the epoch truncates the log, later ones append

  # Start from clean gradients in case a previous run left some behind.
  optimizer.zero_grad()
  for idx,(x_pad,y_pad,x_len,y_len)  in enumerate(tqdm(train_loader,total=n_batches)):

    B,T,_=x_pad.shape
    x_pad,y_pad,x_len,y_len=x_pad.to(device),y_pad.to(device),x_len.to(device),y_len.to(device)
    mask=create_mask(y_len).to(device)

    with torch.autocast(device_type=device,dtype=torch.bfloat16):
      y_pred= model(x_pad)

      _,D=y_pred.shape
      loss= F.cross_entropy(y_pred.view(-1,D),y_pad.view(-1,),reduction='none')


      masked_loss=loss.view(B,-1)*mask
      avg_loss = torch.mean(masked_loss.sum(dim=1)/ mask.sum(dim=1))

    batch_loss=avg_loss.item()
    running_loss+=batch_loss
    # Scale so the accumulated gradient is the mean over the window.
    (avg_loss/gradient_accumulation_steps).backward()

    # Masked accuracy over the valid (non-padded) steps of the whole batch.
    with torch.no_grad():
      predicted=y_pred.view(B,T,D).argmax(dim=-1)
      batch_acc=((predicted==y_pad)*mask).sum()/mask.sum()

    acc.append(batch_acc.item())

    # Step once a full window has been accumulated, and also on the final
    # batch so a trailing partial window is not discarded. (Previously the
    # first step happened after a single batch and leftovers were dropped.)
    window_full=(idx+1) % gradient_accumulation_steps ==0
    last_batch=(idx+1)==n_batches
    if window_full or last_batch:
        # The unscaled batch loss is logged (it used to be the scaled value).
        s=f"""
            EPOCH : {current_epoch} , \n
            idx : {idx}, \n
            loss:{batch_loss}, \n
            accuracy: {batch_acc} \n
          """
        with open(f"TrainLog{current_epoch}.txt",log_mode) as f:
          f.write(s)
        log_mode='a'

        optimizer.step()
        optimizer.zero_grad()
        model.zero_grad()



  return running_loss/n_batches, sum(acc)/len(acc)

In [None]:
def validate_one_epoch(model,validation_loader,device,current_epoch):
  """Evaluate ``model`` on ``validation_loader`` without gradient tracking.

  The loss is the per-sequence mean of the masked cross-entropy, averaged over
  the batch; accuracy counts only non-padded time steps. Every 100th batch is
  logged to ``ValLog{current_epoch}.txt``.

  Args:
    model: called as ``model(x_pad)`` and expected to return (B*T, n_classes) logits.
    validation_loader: yields ``(x_pad, y_pad, x_len, y_len)`` batches.
    device: device string used for tensors and autocast (e.g. "cuda").
    current_epoch: epoch number, used in the log file name.

  Returns:
    ``(mean loss per validation batch, mean per-batch accuracy)``.
  """
  model.eval()
  running_loss=0.
  acc=[]
  n_batches=len(validation_loader)
  with torch.no_grad():
    for idx,(x_pad,y_pad,x_len,y_len)  in enumerate(tqdm(validation_loader,total=n_batches)):
      B,T,_=x_pad.shape
      # Honour the `device` argument instead of relying on the dataset placing tensors.
      x_pad,y_pad=x_pad.to(device),y_pad.to(device)
      mask=create_mask(y_len).to(device)
      with torch.autocast(device_type=device,dtype=torch.bfloat16):

        y_pred= model(x_pad)

        _,D=y_pred.shape
        loss= F.cross_entropy(y_pred.view(-1,D),y_pad.view(-1,),reduction='none')

        masked_loss=loss.view(B,-1)*mask
        avg_loss = torch.mean(masked_loss.sum(dim=1)/ mask.sum(dim=1))

      running_loss+=avg_loss.item()

      # Masked accuracy over the valid (non-padded) steps of the whole batch,
      # replacing the per-sample Python loop.
      y_pred=y_pred.view(B,T,D).argmax(dim=-1)
      batch_acc=((y_pred==y_pad)*mask).sum()/mask.sum()

      acc.append(batch_acc.item())
      if idx%100==0:
        s=f"""
            EPOCH : {current_epoch} , \n
            idx : {idx}, \n
            loss:{avg_loss.item()}, \n
            accuracy: {batch_acc} \n
          """
        mode='w' if idx==0 else 'a'
        with open(f"ValLog{current_epoch}.txt",mode) as f:
          f.write(s)

    return running_loss/n_batches, sum(acc)/len(acc)

In [None]:
# Allow TF32 kernels for matmul and cuDNN.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import torch._dynamo
# NOTE(review): torch.compile is not called in the visible cells, so this setting looks unused.
torch._dynamo.config.suppress_errors = True

# Fresh model for training: 15 features per step, 123 vocabulary tokens.
model=Sepsis_Transformer_decoderonly(15,123,n_heads=4,attn_drop=0.1).cuda()

optimizer=torch.optim.AdamW(model.parameters(),lr=1e-4,betas=(0.9,0.95),weight_decay=0.01)
# NOTE(review): this scheduler is never stepped in the visible cells, so the
# learning rate stays at its initial value — confirm whether that is intended.
cosine_annealing_lr=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer,T_0=500)

In [None]:
# Train for `total_epoch` epochs; after each one validate, report the metrics
# and save a checkpoint named modelparameters{i}.pt in the working directory.
# NOTE: the loop variable `i` stays defined afterwards and is reused by the
# final validation cell of this notebook.
total_epoch=3
for i in range(total_epoch):
  train_loss,accuracy=train_on_epoch(
      model,optimizer,train_loader,device="cuda",current_epoch=i,gradient_accum=5
  )

  validation_loss,validation_accuracy=validate_one_epoch(
      model,validation_loader,device="cuda",current_epoch=i
  )
  print(f"""Training loss and accuracy after epoch {i}
          Training Loss{train_loss}
          Accuracy {accuracy}
        Validation Loss {validation_loss}
        Validation Accuracy{validation_accuracy}""")
  torch.save(model.state_dict(),f"modelparameters{i}.pt")


  0%|          | 0/2050 [00:00<?, ?it/s]

  0%|          | 0/252 [00:00<?, ?it/s]

Training loss and accuracy after epoch 0
          Training Loss0.3969107009506807
          Accuracy 0.7831862219659294
        Validation Loss 0.07183019594811811
        Validation Accuracy0.7133097350597382


  0%|          | 0/2050 [00:00<?, ?it/s]

  0%|          | 0/252 [00:00<?, ?it/s]

Training loss and accuracy after epoch 1
          Training Loss0.371152789381946
          Accuracy 0.7968206761813745
        Validation Loss 0.07483130834814979
        Validation Accuracy0.7142178260144734


  0%|          | 0/2050 [00:00<?, ?it/s]

  0%|          | 0/252 [00:00<?, ?it/s]

Training loss and accuracy after epoch 2
          Training Loss0.36294117117073477
          Accuracy 0.801701791315544
        Validation Loss 0.0780895491380517
        Validation Accuracy0.7050688522202628


In [None]:
# Two fresh models with the training architecture; they receive the epoch-1
# and epoch-2 checkpoints for weight averaging in the cells below.
model1 = Sepsis_Transformer_decoderonly(15,123,n_heads=4,attn_drop=0.1)
model2 = Sepsis_Transformer_decoderonly(15,123,n_heads=4,attn_drop=0.1)

In [None]:
# Load the checkpoints from the working directory, where the training loop
# saved them with relative paths (the absolute '/content/...' paths only
# resolved on Colab). map_location keeps loading independent of the device
# the tensors were saved from; load_state_dict copies them onto the model.
model1.load_state_dict(torch.load('modelparameters1.pt',map_location="cpu"))
model2.load_state_dict(torch.load('modelparameters2.pt',map_location="cpu"))

<All keys matched successfully>

In [None]:
# Parameter dictionaries of both checkpoints, used for the weighted average below.
model2_dict = model2.state_dict()
model1_dict = model1.state_dict()

In [None]:
# Weighted average of the two checkpoints: 20% epoch-1 weights, 80% epoch-2 weights.
for key in model1.state_dict():
  model2_dict[key] = 0.2*model1_dict[key] + 0.8*model2_dict[key]

In [None]:
# Write the averaged parameters back into model2.
model2.load_state_dict(model2_dict)

<All keys matched successfully>

In [None]:
# Persist the averaged model next to the per-epoch checkpoints.
torch.save(model2.state_dict(),"modelparameters_final.pt")

In [None]:
# Validate the averaged model.
# NOTE(review): `i` is the variable left over from the training loop, so this
# run overwrites the ValLog file of the last training epoch — confirm intended.
validation_loss,validation_accuracy=validate_one_epoch(
      model2,validation_loader,device="cuda",current_epoch=i
  )

  0%|          | 0/252 [00:00<?, ?it/s]

In [None]:
validation_loss,validation_accuracy

(0.630071083202012, 0.7058831313772808)