In [4]:
# helper Module to convert tensor of input indices into corresponding tensor of token embeddings
import math
import torch
import torch.nn as nn
from torch.nn import Transformer
from dataclasses import dataclass

In [22]:
from torch.optim.lr_scheduler import _LRScheduler

In [24]:
class NoamLR(_LRScheduler):
    """
    # source: https://github.com/tugstugi/pytorch-saltnet/blob/master/utils/lr_scheduler.py
    
    Implements the Noam Learning rate schedule. This corresponds to increasing the learning rate
    linearly for the first ``warmup_steps`` training steps, and decreasing it thereafter proportionally
    to the inverse square root of the step number, scaled by the inverse square root of the
    dimensionality of the model. Time will tell if this is just madness or it's actually important.
    Parameters
    ----------
    warmup_steps: ``int``, required.
        The number of steps to linearly increase the learning rate.
    """
    def __init__(self, optimizer, warmup_steps):
        self.warmup_steps = warmup_steps
        super().__init__(optimizer)

    def get_lr(self):
        last_epoch = max(1, self.last_epoch)
        scale = self.warmup_steps ** 0.5 * min(last_epoch ** (-0.5), last_epoch * self.warmup_steps ** (-1.5))
        return [base_lr * scale for base_lr in self.base_lrs]

In [145]:
def get_sequences(model, dataloader : torch.utils.data.dataloader.DataLoader,  source_pad_id : int = 0, tgt_tokens_to_ids : Dict[str, int] =  None, max_len : int = 150,  DEVICE : str ='cuda:0'):
    """
    return relevant forcasted and sequences made by the model on the dataset.

    Args:
        model (torch.nn.Module): The model to be evaluated.
        val_dataloader (torch.utils.data.DataLoader): The validation dataloader.
        source_pad_id (int, optional): The padding token ID for the source input. Defaults to 0.
        DEVICE (str, optional): The device to run the evaluation on. Defaults to 'cuda:0'.
        tgt_tokens_to_ids (dict, optional): A dictionary mapping target tokens to their IDs. Defaults to None.
        max_len (int, optional): The maximum length of the generated target sequence. Defaults to 100.
    Returns:
        List[List[int]], List[List[int]]: The list of relevant and forecasted sequences.
    """

    model.eval()
    pred_trgs = []
    targets = []
    no_eov = {'pred_trgs':[], 'target':[]}
    with torch.inference_mode():
        for source_input_ids, target_input_ids in tqdm(dataloader, desc='scoring'):
            batch_pred_trgs = []
            batch_targets = []
            source_input_ids, target_input_ids = source_input_ids.to(DEVICE),target_input_ids.to(DEVICE)
            src_mask, source_padding_mask = create_source_mask(source_input_ids, source_pad_id, DEVICE) 
            memory = model.batch_encode(source_input_ids, src_mask, source_padding_mask)
            pred_trg = torch.tensor(tgt_tokens_to_ids['BOS'], device= DEVICE).repeat(source_input_ids.size(0)).unsqueeze(1)
            # generate target sequence one token at a time at batch level
            for i in range(max_len):
                trg_mask = generate_square_subsequent_mask(i+1, DEVICE)
                output = model.decode(pred_trg, memory, trg_mask)
                probs = model.generator(output[:, -1])
                pred_tokens = torch.argmax(probs, dim=1)
                eov_mask = pred_tokens == tgt_tokens_to_ids['EOV']
                if eov_mask.any():
                    # extend with sequences that have reached EOV
                    batch_pred_trgs.extend(torch.cat((pred_trg[eov_mask],torch.tensor(tgt_tokens_to_ids['EOV'], device = DEVICE).unsqueeze(0).repeat(eov_mask.sum(), 1)),dim = -1).cpu().tolist())
                    batch_targets.extend(target_input_ids[eov_mask].cpu().tolist())
                    # store corresponding target sequences
                    target_input_ids = target_input_ids[~eov_mask]
                    # break if all have reached EOV
                    if eov_mask.all():
                        break  
                    pred_trg = torch.cat((pred_trg[~eov_mask], pred_tokens[~eov_mask].unsqueeze(1)), dim=1)
                    memory = memory[~eov_mask]
                else:
                    pred_trg = torch.cat((pred_trg, pred_tokens.unsqueeze(1)), dim=1)
            if len(dataloader) != len(batch_pred_trgs):
                #print(f'len before{len(batch_pred_trgs)}')
                batch_pred_trgs.extend(pred_trg.cpu().tolist())
                batch_targets.extend(target_input_ids.cpu().tolist())
            pred_trgs.extend(batch_pred_trgs)
            targets.extend(batch_targets)
                
                
    return pred_trgs, targets

In [163]:
def get_gen_loss(crit_fake_pred):
    return  (-torch.mean(crit_fake_pred))

def get_crit_loss(crit_fake_pred, crit_real_pred):
    return  torch.mean(crit_real_pred) - torch.mean(crit_fake_pred)

In [32]:
@dataclass
class DataConfig:
    strategy : str = 'SDP'
    seed : int = 213033
    test_size : float = 0.05
    valid_size : float = 0.10
    predict_procedure : bool = None
    predict_drugs : bool = None
    input_max_length :int = 448
    target_max_length :int = 64
    source_vocab_size : int = None
    target_vocab_size : int = None
    target_pad_id : int = 0
    source_pad_id : int = 0
    batch_size : int = 4


In [33]:
@dataclass
class Config():
    gen_nlayers : int = 3
    gen_nheads : int = 8
    disc_nlayers : int = 1
    disc_nheads : int = 4
    ffn_hid_dim : int = 512
    hid_dim : int = 256
    dropout : float = 0.1
    label_smoothing : float = 0.0
    disc_clip : float = 0.1
    gen_clip : float = 1.0
    alpha : float = 0.3
    learning_rate : float = 4e-4
    warmup_steps : int = 30
    factor : int = 1


config = Config()
data_config = DataConfig()

In [18]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [66]:
generator = Generator(num_encoder_layers = config.gen_nlayers, num_decoder_layers = config.gen_nlayers,
                                  emb_size = config.hid_dim, nhead = config.gen_nheads, 
                                  src_vocab_size = data_config.source_vocab_size, tgt_vocab_size = data_config.target_vocab_size,
                                  dim_feedforward = config.ffn_hid_dim,
                                  dropout = config.dropout)

for p in generator.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)


generator = generator.to(DEVICE)

In [158]:
class PositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: torch.Tensor):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: torch.Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)
    

class NoPositionalEncoding(nn.Module):
    def __init__(self):
        super(NoPositionalEncoding, self).__init__()

    def forward(self, x):
        return x
    
# Seq2Seq Network
class Generator(nn.Module):
    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int,
                 dropout: float = 0.1,
                 positional_encoding : bool = True):
        
        super(Generator, self).__init__()
        self.transformer = Transformer(d_model = emb_size,
                                       nhead = nhead,
                                       num_encoder_layers = num_encoder_layers,
                                       num_decoder_layers = num_decoder_layers,
                                       dim_feedforward = dim_feedforward,
                                       dropout = dropout,
                                       batch_first = True, norm_first = True)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)

        self.positional_encoding = PositionalEncoding(emb_size, dropout = dropout, maxlen = max(src_vocab_size, tgt_vocab_size)+1) if positional_encoding else NoPositionalEncoding()


    def forward(self,
                src: torch.Tensor,
                trg: torch.Tensor,
                src_mask: torch.Tensor,
                tgt_mask: torch.Tensor,
                src_padding_mask: torch.Tensor,
                tgt_padding_mask: torch.Tensor,
                memory_key_padding_mask: torch.Tensor):
        
        outs = self.transformer(self.positional_encoding(self.src_tok_emb(src)), self.positional_encoding(self.tgt_tok_emb(trg)),
                                src_mask, tgt_mask, None,
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
        
        return self.transformer.encoder(
                            self.positional_encoding(self.src_tok_emb(src)),  mask=src_mask)

    def decode(self, tgt: torch.Tensor, memory: torch.Tensor, tgt_mask: torch.Tensor):
        return self.transformer.decoder(
                          self.positional_encoding(self.tgt_tok_emb(tgt)), memory = memory,
                          tgt_mask = tgt_mask)
    
    # We need to add source padding mask to avoid attending to source padding tokens
    def batch_encode(self, src: torch.Tensor, src_mask: torch.Tensor, src_key_padding_mask: torch.Tensor):
        return self.transformer.encoder(
                            self.positional_encoding(self.src_tok_emb(src)),  mask=src_mask, src_key_padding_mask=src_key_padding_mask)
    # No need for batch_decode as we're generating one token at a time

'''d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=<function relu>, layer_norm_eps=1e-05, batch_first=False, norm_first=False, bias=True, device=None, dtype=Non'''
class Discriminator(nn.Module):
    def __init__(self, emb_size, nhead, dim_feedforward, dropout, vocab_size, num_layers):
        super(Discriminator, self).__init__()
        self.tok_emb = TokenEmbedding(vocab_size, emb_size)

        self.positional_encoding = PositionalEncoding(emb_size, dropout = dropout, maxlen = vocab_size + 1)
        self.Encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(emb_size, nhead, dim_feedforward, dropout, batch_first=True, norm_first=True), num_layers)
        self.classifier = nn.Linear(emb_size, 1)

    def forward(self, x: torch.Tensor,):
        x = self.positional_encoding(self.tok_emb(x))
        x = self.Encoder(x)
        x = self.classifier(x.mean(dim=1))
        return x
    
    def batch_encode(self, src: torch.Tensor, mask: torch.Tensor, key_padding_mask: torch.Tensor):
        return self.classifier(self.Encoder(self.positional_encoding(self.tok_emb(src)),
                                            mask = mask, src_key_padding_mask = key_padding_mask).mean(dim=1))
    

In [159]:
discriminator = Discriminator(emb_size = config.hid_dim, nhead = config.disc_nheads, dim_feedforward = config.ffn_hid_dim,
                             dropout = config.dropout, vocab_size = data_config.target_vocab_size, num_layers = config.disc_nlayers)

for p in discriminator.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

discriminator = discriminator.to(DEVICE)

In [28]:
gen_opt = torch.optim.Adam(generator.parameters(), lr = config.learning_rate)
disc_opt = torch.optim.SGD(discriminator.parameters(), lr = config.learning_rate)

In [29]:
lr_schedulerG = NoamLR(gen_opt, warmup_steps=config.warmup_steps )
lr_schedulerD = NoamLR(disc_opt, warmup_steps=config.warmup_steps )

In [31]:
loss_fn = torch.nn.CrossEntropyLoss(ignore_index= data_config.target_pad_id, label_smoothing = config.label_smoothing)

In [35]:
cd PatientTrajectoryForecasting/

/home/sifal.klioui/PatientTrajectoryForecasting


In [41]:
cd ..

/home/sifal.klioui


In [42]:
train_dataloader, val_dataloader, test_dataloader, src_tokens_to_ids, tgt_tokens_to_ids, _, data_and_properties = get_data_loaders(train_batch_size=data_config.batch_size, eval_batch_size = 128,
                                                                                                                                       pin_memory=True, **asdict(data_config))

new_to_old_ids_source file not availble, mapping is the same as the old on


In [43]:
data_config.source_vocab_size = data_and_properties['embedding_size_source']
data_config.target_vocab_size = data_and_properties['embedding_size_target']
data_config.target_pad_id = tgt_tokens_to_ids['PAD']
data_config.source_pad_id = src_tokens_to_ids['PAD']

In [120]:
from utils.eval import mapk, get_sequences
from utils.train import get_data_loaders
from dataclasses import dataclass, asdict
from typing import Dict
from tqdm import tqdm
from utils.train import create_source_mask, generate_square_subsequent_mask
from torch.utils.data import TensorDataset, DataLoader

In [147]:
predictions, targets = get_sequences(generator, test_dataloader, data_config.target_pad_id, tgt_tokens_to_ids, max_len = 64, DEVICE = DEVICE)

scoring: 100%|██████████| 13/13 [00:10<00:00,  1.24it/s]


In [180]:
# Convert lists to tensors
predictions_tensor = torch.tensor(predictions, dtype=torch.float)
targets_tensor = torch.tensor(targets, dtype=torch.float)

# Create TensorDataset
dataset = TensorDataset(predictions_tensor, targets_tensor)

# Create DataLoader
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

In [181]:
crit_repeats = 5

In [182]:
 for prediction, target in tqdm(dataloader, desc='train'):
    prediction, target = prediction.to(DEVICE),target.to(DEVICE)
    prediction_mask, prediction_padding_mask = create_source_mask(prediction, data_config.target_pad_id, DEVICE) 
    target_mask, target_padding_mask = create_source_mask(target, data_config.target_pad_id, DEVICE)
    ## Update discriminator ##
    DisLoss =0
    for _ in range(crit_repeats):
        disc_opt.zero_grad()
        #  torch.Tensor, src_mask: torch.Tensor, src_key_padding_mask: torch.Tensor
        crit_fake_pred = discriminator.batch_encode(prediction, prediction_mask, prediction_padding_mask)
        crit_real_pred = discriminator.batch_encode(target, target_mask, target_padding_mask)
        disc_loss = get_crit_loss(crit_fake_pred, crit_real_pred)
        print(disc_loss)
        DisLoss += disc_loss.item()/crit_repeats
        disc_loss.backward(retain_graph=True)
        disc_opt.step()
        torch.nn.utils.clip_grad_value_(discriminator.parameters(), config.disc_clip)
    break

train:   0%|          | 0/26 [00:00<?, ?it/s]

tensor(-0.0685, device='cuda:0', grad_fn=<SubBackward0>)
tensor(-0.0942, device='cuda:0', grad_fn=<SubBackward0>)
tensor(-0.0583, device='cuda:0', grad_fn=<SubBackward0>)
tensor(-0.0699, device='cuda:0', grad_fn=<SubBackward0>)
tensor(-0.1012, device='cuda:0', grad_fn=<SubBackward0>)





In [169]:
crit_fake_pred.shape

torch.Size([128, 1])

In [None]:
crit_repeats = 5
vLoss = []
tLoss = []
for epoch in range(config.n_epochs):
    totalGen = 0
    totalDis = 0
    generator.train()
    discriminator.train()
    lr_schedulerG.step()
    lr_schedulerD.step()
    
    predictions, targets = get_sequences(generator, test_dataloader, data_config.target_pad_id, tgt_tokens_to_ids, max_len = 64, DEVICE = DEVICE)
        # Convert lists to tensors
    predictions_tensor = torch.tensor(predictions, dtype=torch.float)
    targets_tensor = torch.tensor(targets, dtype=torch.float)
    
    # Create TensorDataset
    dataset = TensorDataset(predictions_tensor, targets_tensor)
    
    # Create DataLoader
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
    
     for prediction, target in tqdm(dataloader, desc='train'):
        prediction, target = prediction.to(DEVICE),target.to(DEVICE)
        prediction_mask, prediction_padding_mask = create_source_mask(prediction, data_config.target_pad_id, DEVICE) 
        target_mask, target_padding_mask = create_source_mask(target, data_config.target_pad_id, DEVICE)
        ## Update discriminator ##
        DisLoss = 0
         
        for _ in range(crit_repeats):
            disc_opt.zero_grad()
            #  torch.Tensor, src_mask: torch.Tensor, src_key_padding_mask: torch.Tensor
            crit_fake_pred = discriminator.batch_encode(prediction, prediction_mask, prediction_padding_mask)
            crit_real_pred = discriminator.batch_encode(target, target_mask, target_padding_mask)
            disc_loss = get_crit_loss(crit_fake_pred, crit_real_pred)
            print(disc_loss)
            DisLoss += disc_loss.item()/crit_repeats
            disc_loss.backward(retain_graph=True)
            disc_opt.step()
            torch.nn.utils.clip_grad_value_(discriminator.parameters(), config.disc_clip)
    
        totalDis += DisLoss
        ## Update generator ##
        gen_opt.zero_grad()
        output, _ = generator(src, trg[:,:-1])
        _,predValues = torch.max(output,2)
        fake = joinfakeData(pair,convertGenOutput(predValues.tolist(),reverseOutTypes))
        fake_mask =  generator.make_src_mask(fake)
        fake, fake_mask =fake.to(device) , fake_mask.to(device)
        #print(f"generator training fake :{predValues}")
        disc_fake_pred = discriminator(fake,fake_mask)
        gen_loss1 = get_gen_loss(disc_fake_pred)

        output_dim = output.shape[-1]
        output = output.contiguous().view(-1, output_dim)
        trgs = trg[:,1:].contiguous().view(-1)

        gen_loss2 = criterion(output,trgs)
        gen_loss = (alpha * gen_loss1)  +  gen_loss2
        totalGen += gen_loss.item()
        gen_loss.backward()
        torch.nn.utils.clip_grad_norm_(generator.parameters(), gen_clip)
        gen_opt.step()
        #epoch_loss = gen_loss.item() + disc_loss.item()
    tLoss.append(totalGen/len(trainLoader))
    ## validating

    valid_loss = evaluate(generator, valLoader, criterion,device)
    vLoss.append(valid_loss)
    

    print(f'current learning rate : {lr_schedulerG.get_last_lr()}')
    #print(f'current learning rate Discriminator : {lr_schedulerD.get_last_lr()}')
    print(f'Epoch: {epoch+1:02}')
    print(f" Train loss {totalGen/len(trainLoader)} , validation loss :{valid_loss}")