In [1]:
import os
import torch
from einops import rearrange

In [2]:
# Expose only physical GPU 3 to this process; the device cell below then addresses it as "cuda:0".
os.environ['CUDA_VISIBLE_DEVICES']='3'

In [3]:
# Prefer the first visible GPU; fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [4]:
%load_ext autoreload
%autoreload 2

# load data

In [5]:
from transformer_translation.dataset import TagReportDataset, load_pickle
from torch.utils.data import DataLoader

In [6]:
# Root folder holding the preprocessed tag/report pickles.
# NOTE(review): machine-specific absolute path; adjust before running elsewhere.
data_path = r"/home/alex/data/nlp/agmir/transf_processed_data"
#data_path = 'transformer_translation/data/processed'

In [7]:
# Both pickles are read relative to the current working directory (not from data_path).
# `splits` is indexed with 'train' / 'val' when the datasets are built;
# `countvec` is passed to both datasets.
splits = load_pickle('20200525_splits.pkl')
countvec = load_pickle('20200528_countvec.pkl')

### train set

In [8]:
num_tokens = 2000
max_seq_length = 96

# Training split: built from the tag and report pickles, restricted to splits['train'].
dataset = TagReportDataset(
    os.path.join(data_path, 'tags/set_raw.pkl'),
    os.path.join(data_path, 'reports/set.pkl'),
    num_tokens,
    max_seq_length,
    idxs=splits['train'],
    countvec=countvec,
)
loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)
print(len(dataset))

85


### val set

In [9]:
# Validation split: same pickles as the training set, restricted to splits['val'].
val_dataset = TagReportDataset(
    os.path.join(data_path, 'tags/set_raw.pkl')
    ,os.path.join(data_path, 'reports/set.pkl')
    ,num_tokens
    ,max_seq_length
    ,idxs=splits['val']
    ,countvec = countvec
)
# No shuffling for evaluation: batches come back in a fixed order, so
# `next(iter(val_loader))` and any printed examples are reproducible between runs.
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)
print(len(val_dataset))

6


# train

In [10]:
from transformer_translation.model import LanguageTransformer, ReportTransformer

In [24]:
# Hyper-parameters handed positionally to ReportTransformer below.
# NOTE(review): vocab_size (10004) differs from num_tokens (2000) used for the datasets;
# presumably 10000 report words + 4 special tokens — TODO confirm which value is intended.
vocab_size = 10000 + 4#1952#
nhead = 8
# Rounds 587 up to a multiple of nhead: 587 - (587 % 8) + 8 = 592.
# NOTE(review): 587 is presumably the tag vocabulary size — TODO confirm. The formula adds a
# full extra nhead when the base value is already a multiple of nhead.
d_model = 587 - (587 % nhead) + nhead
num_encoder_layers = 6
num_decoder_layers = 6
dim_feedforward = 2048
pos_dropout = 0.1
trans_dropout = 0.1

# max_seq_length comes from the dataset cell; the model is moved to `device` right away.
model = ReportTransformer(
    vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward,
    max_seq_length, pos_dropout, trans_dropout
).to(device)

In [25]:
    from transformer_translation.Optim import ScheduledOptim
    import torch.nn as nn
    from torch.optim import Adam
    
    # Xavier-normal initialisation for every parameter with more than one dimension;
    # 1-D parameters are left untouched.
    for weight in filter(lambda param: param.dim() > 1, model.parameters()):
        nn.init.xavier_normal_(weight)

    # Adam wrapped in the project's ScheduledOptim, configured with d_model and the warm-up length.
    n_warmup_steps = 4000
    optim = ScheduledOptim(
        Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-09),
        d_model,
        n_warmup_steps,
    )

    # Cross-entropy over tokens; targets equal to 0 are ignored.
    criterion = nn.CrossEntropyLoss(ignore_index=0)

In [26]:
from tsf_infer_utils import prep_transf_inputs, infer
from torchtext.data.metrics import bleu_score
from tsf_utils import format_list_for_bleu, get_bleu_from_loader

In [78]:
    %%time
    # Training driver: for every epoch, run one pass over `loader`, then (optionally)
    # evaluate BLEU on the train and validation loaders.
    # Relies on names created in other cells: model, optim, criterion, loader, val_loader,
    # device, max_seq_length, and oos_infer_batched (defined in a cell further down).
    print_every = 20  # report the mean training loss every `print_every` steps
    num_epochs = 10
    # NOTE(review): despite its name this flag only switches the per-epoch evaluation
    # on or off; nothing in this cell stops training early.
    early_stopping_flag = True

    # NOTE(review): lowest_val, train_losses and total_step are written but never read in this cell.
    lowest_val = 1e9
    train_losses, val_losses, train_bleu, val_bleu  = {}, [], [], []
    total_step = 0
    
    for epoch in range(num_epochs):
        
        model.train()
        # running sum of losses since the last report (reset every `print_every` steps)
        total_loss = 0

        for step, (src, src_key_padding_mask, tgt, tgt_key_padding_mask) in enumerate(iter(loader)):
            total_step += 1

            # project helper: returns the batch tensors plus tgt_inp / tgt_out and the masks
            src, src_key_padding_mask, tgt, tgt_key_padding_mask, memory_key_padding_mask, tgt_inp, tgt_out, tgt_mask = prep_transf_inputs(
                src, src_key_padding_mask, tgt, tgt_key_padding_mask, device)
            
            optim.zero_grad()
            # the target padding mask is trimmed by one position,
            # presumably to line up with tgt_inp — TODO confirm
            outputs = model(src, tgt_inp, src_key_padding_mask, tgt_key_padding_mask[:, :-1], memory_key_padding_mask, tgt_mask)
            # flatten the (b, t) axes so the criterion sees one row of v scores per target token
            loss = criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_out, 'b o -> (b o)'))

            loss.backward()
            optim.step_and_update_lr()

            total_loss += loss.item()
            if step % print_every == print_every - 1:
                # the f-string and the plain literal are concatenated first; .format then
                # fills the remaining '{:.3f}' placeholder
                print(f'Epoch [{epoch + 1} / {num_epochs}] \t Step [{step + 1} / {len(loader)}] \t '
                      'Train Loss: {:.3f}'.format(total_loss / print_every))
                total_loss = 0
            
        # steps after the last full `print_every` window of an epoch are never reported
        if early_stopping_flag:
            model.eval()
            print(f'Epoch [{epoch + 1} / {num_epochs}]:')
            
            # train (the loss_per_batch returned here is overwritten below without being used)
            pred_list, tgt_list, loss_per_batch = infer(model, loader)
            pred_list_bleu, tgt_list_bleu = format_list_for_bleu(pred_list, tgt_list)
            train_bleu.append(bleu_score(pred_list_bleu, tgt_list_bleu))
            print('{} BLEU in-sample: {:.2%}'.format(
                '\ttrain', train_bleu[-1]))
            
            # val IS (in-sample, via `infer`); also records the mean validation loss
            pred_list, tgt_list, loss_per_batch = infer(model, val_loader)
            val_losses.append(sum(loss_per_batch) / len(loss_per_batch))
            pred_list_bleu, tgt_list_bleu = format_list_for_bleu(pred_list, tgt_list)
            val_bleu.append(bleu_score(pred_list_bleu, tgt_list_bleu))
            print('{} BLEU in-sample: {:.2%}'.format(
                '\tval', val_bleu[-1]))
            
            # val OOS (out-of-sample): the model decodes from SOS without seeing the target
            pred_list, tgt_list, tag_list = oos_infer_batched(model, val_loader, max_seq_length)
            print('{} BLEU out-of-sample: {:.2%}\n\n'.format(
                '\tval',bleu_score(*format_list_for_bleu(pred_list, tgt_list))
            ))
            
        print('\n')

Epoch [1 / 10] 	 Step [20 / 85] 	 Train Loss: 1.616
Epoch [1 / 10] 	 Step [40 / 85] 	 Train Loss: 1.797
Epoch [1 / 10] 	 Step [60 / 85] 	 Train Loss: 1.542
Epoch [1 / 10] 	 Step [80 / 85] 	 Train Loss: 1.444
Epoch [1 / 10]:
	train BLEU in-sample: 41.15%
	val BLEU in-sample: 31.13%
	val BLEU out-of-sample: 14.26%




Epoch [2 / 10] 	 Step [20 / 85] 	 Train Loss: 1.411
Epoch [2 / 10] 	 Step [40 / 85] 	 Train Loss: 1.583
Epoch [2 / 10] 	 Step [60 / 85] 	 Train Loss: 1.505
Epoch [2 / 10] 	 Step [80 / 85] 	 Train Loss: 1.707
Epoch [2 / 10]:
	train BLEU in-sample: 43.75%
	val BLEU in-sample: 32.46%
	val BLEU out-of-sample: 13.88%




Epoch [3 / 10] 	 Step [20 / 85] 	 Train Loss: 1.560
Epoch [3 / 10] 	 Step [40 / 85] 	 Train Loss: 1.408
Epoch [3 / 10] 	 Step [60 / 85] 	 Train Loss: 1.488
Epoch [3 / 10] 	 Step [80 / 85] 	 Train Loss: 1.562
Epoch [3 / 10]:
	train BLEU in-sample: 45.82%
	val BLEU in-sample: 34.11%
	val BLEU out-of-sample: 13.18%




Epoch [4 / 10] 	 Step [20 / 85] 	 Train Loss: 

### test the model with our previous way to eval

In [28]:
from tsf_utils import get_tk_from_proba, format_list_for_bleu

In [29]:
    %%time
    pred_list = []
    tgt_list = []
    # run on whichever device the model's target embedding lives on
    device = model.embed_tgt.weight.device

    # Evaluation must not use dropout and does not need gradients:
    # switch to eval mode and disable autograd for the whole pass.
    model.eval()
    with torch.no_grad():
        for (src, src_key_padding_mask, tgt, tgt_key_padding_mask) in iter(val_loader):

            # prepare inputs
            src, src_key_padding_mask, tgt, tgt_key_padding_mask, memory_key_padding_mask, tgt_inp, tgt_out, tgt_mask = prep_transf_inputs(
                src, src_key_padding_mask, tgt, tgt_key_padding_mask, device)

            # run inference (the model is fed tgt_inp, as in the training loop)
            outputs = model(src, tgt_inp, src_key_padding_mask, tgt_key_padding_mask[:, :-1], memory_key_padding_mask, tgt_mask)

            # get predictions from proba
            pred = get_tk_from_proba(outputs)

            # get pred and ground truth ready for metric eval: one list of token ids per row
            pred_list += [list(pred[row, :].cpu().numpy()) for row in range(pred.shape[0])]
            tgt_list += [list(tgt[row, :].cpu().numpy()) for row in range(pred.shape[0])]

CPU times: user 14.6 s, sys: 848 ms, total: 15.4 s
Wall time: 569 ms


In [30]:
# Corpus BLEU of the predictions collected in the previous cell.
pred_list_bleu, tgt_list_bleu = format_list_for_bleu(pred_list, tgt_list)
print('{} BLEU: {:.2%}'.format('\tval', bleu_score(pred_list_bleu, tgt_list_bleu)))

	val BLEU: 0.00%


### OOS eval - one sentence

In [31]:
def forward_model(model, src, pred_sentence, device):
    """Run the model once on the tokens decoded so far.

    Args:
        model: callable invoked as ``model(src, tgt, ...)`` with every padding mask set to None.
        src: source tensor; moved to ``device`` before the call.
        pred_sentence: token ids decoded so far (anything ``torch.tensor`` accepts).
        device: device on which the forward pass runs.

    Returns:
        Whatever ``model`` returns for this single-sequence input.
    """
    # Add one leading axis: a flat list of ids becomes a (1, len(pred_sentence)) tensor.
    decoded = torch.tensor(pred_sentence).unsqueeze(0).to(device)
    # Mask sized to the current decoded length (presumably a look-ahead mask — see gen_nopeek_mask).
    nopeek_mask = gen_nopeek_mask(decoded.shape[1]).to(device)

    return model(
        src.to(device),
        decoded,
        src_key_padding_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        tgt_mask=nopeek_mask,
    )

def oos_infer_sent(model, src, tgt, max_seq_length, device): 
    """Decode one sentence token by token, starting from SOS.

    Decoding stops as soon as the last generated token is ``IDX_EOS`` or after
    ``max_seq_length`` generated tokens, whichever comes first.

    Args:
        model: model passed through to ``forward_model``.
        src: source tensor for the sentence; indexed as 3-D when extracting the tags.
        tgt: ground-truth tensor for the sentence; only used to build the returned target list.
        max_seq_length: maximum number of tokens generated after SOS.
        device: device on which the forward passes run.

    Returns:
        ``(pred_list, tgt_list, tag_list)``, each a one-element list: the decoded ids and the
        non-zero target ids (both as ``np.int64``), and the last-axis indices of the
        non-zero entries of ``src``.
    """
    # start from the SOS token only: the model never sees the target
    pred = [IDX_SOS]

    # Decoding does not need gradients; no_grad avoids building an autograd graph at every step.
    with torch.no_grad():
        for _ in range(max_seq_length):
            if int(pred[-1]) == IDX_EOS:
                break
            next_tokens = get_tk_from_proba(
                forward_model(model, src, pred, device))
            # keep only the token predicted for the last position
            pred.append(next_tokens[0][-1].item())

    # format outputs
    pred_list = [to_list_npint64(pred)]
    tgt_list = [to_list_npint64(pop_padding_ts(tgt).flatten().tolist())]
    tag_list = [src.nonzero()[:, 2]]

    return pred_list, tgt_list, tag_list

def pop_padding(tk_list):
    """Strip trailing padding tokens (id 0) from ``tk_list`` in place.

    Unlike the previous version, a list without trailing padding is returned
    unchanged (its last real token is no longer dropped) and an empty list no
    longer raises ``IndexError``.

    Args:
        tk_list: mutable list of token ids; modified in place.

    Returns:
        The same list object, with every trailing 0 removed.
    """
    while tk_list and tk_list[-1] == 0:
        tk_list.pop()
    return tk_list

def pop_padding_ts(tk_tensor):
    """Select the last-axis positions of a 3-D tensor that hold non-zero values.

    The positions are the third coordinate of every non-zero entry, so every zero
    (not only trailing ones) is dropped. The notebook calls this on slices such as
    ``tgt[:, [i], :]``; with several rows the indices of all rows are concatenated.
    """
    nonzero_coords = tk_tensor.nonzero()
    kept_positions = nonzero_coords[:, 2]
    return tk_tensor[:, :, kept_positions]

def to_list_npint64(list_int):
    """Return a new list with every element of ``list_int`` converted to ``np.int64``."""
    return list(map(np.int64, list_int))

In [32]:
from transformer_translation.translate_sentence import gen_nopeek_mask
from transformer_translation.dataset import IDX_EOS, IDX_SOS
import numpy as np

In [33]:
# Take the first batch yielded by a fresh iterator over the validation loader.
(src, src_key_padding_mask, tgt, tgt_key_padding_mask) = next(iter(val_loader))

In [34]:
%%time
# select single sentence
# (hard-coded index 2 along dim 1; indexing with a list keeps that dimension, so both
# slices stay 3-D. The batch fetched above must therefore hold at least 3 entries along dim 1.)
i = 2
src_i = src[:,[i],:]
tgt_i = tgt[:,[i],:]

# run inference
pred_out, tgt_out, tag_out = oos_infer_sent(model, src_i, tgt_i, max_seq_length, device)

CPU times: user 1min 38s, sys: 4.26 s, total: 1min 42s
Wall time: 2.54 s


In [39]:
# print BLEU
print('{}BLEU: {:.2%}\n\n'.format(
    '',bleu_score(*format_list_for_bleu(pred_out, tgt_out))
))

# print sent lengths
print('target length: {} words, pred legnth: {} words\n'.format(len(tgt_out[0]), len(pred_out[0])))

# print in nat lg
# NOTE: print_nl_pred_vs_tgt, reports_index2word and tags_index2word come from the
# "create vocab dicts" cells further down; those cells must have been run first.
print_nl_pred_vs_tgt(pred_out, tgt_out, reports_index2word, tag_out, tags_index2word)

BLEU: 0.00%


target length: 10 words, pred legnth: 97 words

TAGS:  normal
TARGET:  SOS heart size normal . lungs are clear . EOS
PREDICTION:  SOS . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .




### OOS eval - one loader, batched

In [59]:
def get_EOS_indices(pred, eos_idx=None):
    """Return, for each row of ``pred``, the column of its first EOS token.

    Fixes over the previous greedy scan:
      * an EOS found only in row 0 is no longer treated as "no EOS anywhere"
        (``any()`` over row indices is false when every index is 0);
      * the result always has exactly one entry per row, even when several
        leading or trailing rows contain no EOS.

    Args:
        pred: 2-D tensor of token ids, one decoded sequence per row.
        eos_idx: token id to look for; defaults to the module-level ``IDX_EOS``.

    Returns:
        A list with ``pred.shape[0]`` ints: the column of the first EOS in each
        row, or -2 for rows without any EOS.
    """
    if eos_idx is None:
        eos_idx = IDX_EOS

    # Sentinel kept from the original implementation; oos_infer_batched slices
    # pred[i, :idx], so -2 drops the last two tokens of a row that has no EOS.
    no_eos = -2

    rows, cols = torch.where(pred == eos_idx)
    first_eos = [no_eos] * pred.shape[0]
    for row, col in zip(rows.tolist(), cols.tolist()):
        # keep the smallest column seen for each row (independent of the order of the hits)
        if first_eos[row] == no_eos or col < first_eos[row]:
            first_eos[row] = col
    return first_eos

In [60]:
def oos_infer_batched(model, loader, max_seq_length):
    """Decode every batch of ``loader`` from SOS, without feeding the target to the model.

    For each batch, all sequences are extended one token at a time until every
    row contains ``IDX_EOS`` or ``max_seq_length`` tokens have been generated.

    Args:
        model: model exposing ``embed_tgt.weight`` (used to pick the device) and called
            as ``model(src, pred, ..., tgt_mask=...)`` with all padding masks set to None.
        loader: iterable yielding ``(src, src_key_padding_mask, tgt, tgt_key_padding_mask)``.
        max_seq_length: maximum number of tokens generated after SOS.

    Returns:
        ``(pred_list, tgt_list, tag_list)`` aggregated over all batches, one entry per sequence:
        decoded ids cut at the first EOS (EOS re-appended), non-zero target ids (both as
        ``np.int64``), and the last-axis indices of the non-zero entries of ``src``.
    """
    device = model.embed_tgt.weight.device
    pred_list_lg, tgt_list_lg, tag_list_lg = [], [], []

    # Decoding does not need gradients; no_grad avoids building an autograd graph
    # for each of the (up to max_seq_length) forward passes per batch.
    with torch.no_grad():
        for (src, src_key_padding_mask, tgt, tgt_key_padding_mask) in iter(loader):

            src = src.to(device)

            # one SOS token per sequence; tgt.shape[1] gives the number of rows to decode
            pred = IDX_SOS * torch.ones((tgt.shape[1], 1), dtype=torch.long, device=device)

            while not (pred == IDX_EOS).any(1).all() and pred.shape[1] < max_seq_length + 1:
                # mask sized to the current decoded length
                pred_mask = gen_nopeek_mask(pred.shape[1]).to(device)
                output = model(src, pred, src_key_padding_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None, tgt_mask=pred_mask)
                # append the token predicted for the last position of every row
                pred = torch.cat([pred, get_tk_from_proba(output)[:, [-1]]], dim=1)

            # format pred sentence output: cut each row at its EOS index, then re-append EOS
            idx_eos = get_EOS_indices(pred)
            pred_list = [to_list_npint64(pred[i, :idx_eos[i]].tolist() + [IDX_EOS]) for i in range(pred.shape[0])]

            # format tgt sentence output
            tgt_list = [to_list_npint64(pop_padding_ts(tgt[:, [i], :]).flatten().tolist()) for i in range(tgt.shape[1])]

            # format tags: last-axis indices of the non-zero entries, grouped by dim-1 index
            tag_idx = src.nonzero()
            tag_list = [tag_idx[tag_idx[:, 1] == i, 2] for i in range(src.shape[1])]

            # aggregate results
            pred_list_lg += pred_list
            tgt_list_lg += tgt_list
            tag_list_lg += tag_list

    return pred_list_lg, tgt_list_lg, tag_list_lg

In [69]:
    # This cell used to be a line-for-line copy of the body of `oos_infer_batched`
    # (defined in the cell above); call the function instead so the two cannot drift apart.
    pred_list_lg, tgt_list_lg, tag_list_lg = oos_infer_batched(model, val_loader, max_seq_length)

In [73]:
tgt_list_lg

[[1,
  74,
  383,
  51,
  16,
  17,
  1783,
  69,
  1309,
  62,
  24,
  333,
  83,
  84,
  390,
  372,
  88,
  4,
  42,
  15,
  26,
  22,
  51,
  16,
  4,
  42,
  227,
  26,
  167,
  4,
  52,
  171,
  147,
  94,
  26,
  166,
  149,
  84,
  607,
  219,
  1047,
  1006,
  4,
  31,
  26,
  10,
  38,
  4,
  1133,
  62,
  24,
  333,
  224,
  317,
  51,
  16,
  17,
  1783,
  94,
  310,
  56,
  174,
  403,
  4,
  42,
  1383,
  381,
  485,
  51,
  42,
  52,
  425,
  26,
  115,
  5,
  317,
  4,
  31,
  26,
  10,
  35,
  36,
  4,
  2],
 [1,
  7,
  225,
  93,
  29,
  88,
  303,
  65,
  66,
  352,
  51,
  42,
  469,
  523,
  468,
  4,
  15,
  16,
  17,
  18,
  97,
  20,
  22,
  4,
  24,
  134,
  26,
  22,
  4,
  31,
  26,
  32,
  274,
  69,
  42,
  7,
  243,
  94,
  260,
  42,
  144,
  320,
  4,
  31,
  26,
  115,
  32,
  819,
  69,
  42,
  7,
  211,
  94,
  4,
  29,
  88,
  51,
  923,
  231,
  69,
  7,
  225,
  93,
  115,
  168,
  260,
  42,
  145,
  320,
  56,
  10,
  35,
  36,
  4,
  52,
  87,
 

In [72]:
pred_list_lg

[[1,
  3,
  4,
  10,
  11,
  12,
  13,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  10,
  28,
  30,
  35,
  36,
  37,
  38,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  10,
  11,
  188,
  13,
  4,
  2],
 [1,
  3,
  4,
  10,
  11,
  12,
  13,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  10,
  28,
  29,
  30,
  35,
  36,
  37,
  38,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  10,
  11,
  188,
  13,
  4,
  2],
 [1,
  3,
  4,
  10,
  11,
  12,
  13,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  10,
  28,
  30,
  35,
  36,
  37,
  38,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  10,
  11,
  188,
  290,
  69,
  42,
  7,
  63,
  93,
  26,
  14,
  26,
  14,
  4,
  2],
 [1,
  10,
  11,
  12,
  13,
  4,
  4,
  4,
  4,
  42,
  15,
  16,
  26,
  22,
  4,
  42,
  125,
  20,
  99,
  4,


In [75]:
%%time
# Put the model in eval mode, then decode the whole validation loader in batches.
model.eval()
pred_list, tgt_list, tag_list = oos_infer_batched(model, val_loader, max_seq_length)

CPU times: user 11min 18s, sys: 28.8 s, total: 11min 47s
Wall time: 16.7 s


In [76]:
# print BLEU
# NOTE(review): this reads pred_list_lg / tgt_list_lg / tag_list_lg, which are filled by the
# inline decoding cell, not the pred_list / tgt_list / tag_list returned by the cell just above.
print('{}BLEU: {:.2%}\n\n'.format(
    '',bleu_score(*format_list_for_bleu(pred_list_lg, tgt_list_lg))
))

# print in nat lg
print_nl_pred_vs_tgt(pred_list_lg, tgt_list_lg, reports_index2word, tag_list_lg, tags_index2word)

BLEU: 13.51%


TAGS:  metastatic_disease lung_neoplasms nodule
TARGET:  SOS interval increase in size and number of innumerable bilateral pulmonary nodules consistent with worsening metastatic disease . the heart is normal in size . the mediastinum is stable . left sided chest xxxx is again visualized with tip at cavoatrial junction . there is no pneumothorax . numerous bilateral pulmonary nodules have increased in size and number xxxx compared to prior study . the dominant nodule mass in the left midlung is also mildly increased . there is no pleural effusion . EOS
PREDICTION:  SOS  . no acute cardiopulmonary abnormality . . . . . . . . . . . . . . . . . . no focal consolidation pleural effusion or pneumothorax . . . . . . . . . . no acute bony abnormality . EOS


TAGS:  degenerative_change pneumonia
TARGET:  SOS right middle lobe airspace disease which could represent pneumonia in the appropriate clinical setting . heart size and mediastinal contour are normal . pulmonary vascularity

### OOS eval - one loader, sent-by-sent

In [41]:
from tsf_infer_utils import oos_infer

In [42]:
%%time
# Put the model in eval mode, then decode the validation loader with the imported
# `oos_infer` helper (the sentence-by-sentence variant named in the heading above).
model.eval()
pred_list, tgt_list, tag_list = oos_infer(model, val_loader, max_seq_length)

CPU times: user 2h 3min 52s, sys: 3min 53s, total: 2h 7min 46s
Wall time: 2min 54s


In [43]:
# print BLEU of the sentence-by-sentence predictions from the previous cell
print('{}BLEU: {:.2%}\n\n'.format('', bleu_score(*format_list_for_bleu(pred_list, tgt_list))))

BLEU: 13.79%




### create vocab dicts 

In [36]:
# Map each position in the count vectoriser's vocabulary to the entry found there.
tags_index2word = dict(enumerate(dataset.countvec.vocabulary))

In [37]:
from transformer_translation.dataset import load_report_voc
# Report vocabulary loaded from <data_path>/reports/voc.pkl; it is passed to
# print_nl_pred_vs_tgt as the lookup for report tokens.
reports_index2word = load_report_voc(
    os.path.join(data_path, 'reports', 'voc.pkl'))

In [38]:
from tsf_utils import print_nl_pred_vs_tgt

In [706]:
# print in nat lang
# (uses the pred_list / tgt_list / tag_list left in the kernel by the last inference cell that was run)
print_nl_pred_vs_tgt(pred_list, tgt_list, reports_index2word, tag_list, tags_index2word)

TAGS:  deformity old_injury
TARGET:  SOS no acute findings heart size within normal limits stable mediastinal and hilar contours . no alveolar consolidation no findings of pleural effusion or pulmonary edema . chronic appearing contour deformity of the right posterolateral th rib again noted suggestive of old injury . EOS
PREDICTION:  SOS no acute cardiopulmonary abnormality . . . the lungs are clear bilaterally . specifically no evidence of focal consolidation pneumothorax or pleural effusion . . cardio mediastinal silhouette is unremarkable . visualized osseous structures of the thorax are without acute abnormality . EOS


TAGS:  degenerative_change
TARGET:  SOS no acute cardiopulmonary findings . there is no focal consolidation . there is no pneumothorax or large pleural effusion . the cardiomediastinal contours are grossly unremarkable . the heart size is within normal limits . there are mild thoracic spine degenerative changes . EOS
PREDICTION:  SOS no acute cardiopulmonary abnorm