In [326]:
%matplotlib inline

In [327]:
import os
import torch
from einops import rearrange

In [328]:
os.environ['CUDA_VISIBLE_DEVICES']='7'

In [329]:
# Use the first visible CUDA device when one is usable, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda:0


In [330]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# load data

In [331]:
from transformer_translation.dataset import TagReportDataset, load_pickle
from torch.utils.data import DataLoader

In [332]:
# Root directory of the pre-processed data; the dataset cells join
# 'tags/set_raw.pkl', 'reports/set.pkl' and 'reports/voc.pkl' onto it.
# NOTE(review): this is a machine-specific absolute path — adjust it (or switch
# to the repo-relative alternative kept below) when running elsewhere.
data_path = r"/home/alex/data/nlp/agmir/transf_processed_data"
#data_path = 'transformer_translation/data/processed'

In [361]:
# Both pickles are resolved relative to the current working directory, not data_path.
# `splits` is indexed with 'train' / 'val' / 'test' by the dataset cells.
# NOTE(review): presumably each entry holds the sample indices of that split — confirm.
splits = load_pickle('20200525_splits.pkl')
# `countvec` is handed to every TagReportDataset; its `.vocabulary` is later used
# to build the tag index -> word lookup.
countvec = load_pickle('20200528_countvec.pkl')

### train set

In [362]:
num_tokens = 2000
max_seq_length = 96

# Training split, restricted to the indices in splits['train'].
# NOTE(review): presumably tags are the source and reports the target — confirm
# against TagReportDataset.
dataset = TagReportDataset(
    os.path.join(data_path, 'tags/set_raw.pkl'),
    os.path.join(data_path, 'reports/set.pkl'),
    num_tokens,
    max_seq_length,
    idxs=splits['train'],
    countvec=countvec,
)
loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)
print(len(dataset))

85


### val set

In [363]:
# Validation split, restricted to the indices in splits['val'].
val_dataset = TagReportDataset(
    os.path.join(data_path, 'tags/set_raw.pkl'),
    os.path.join(data_path, 'reports/set.pkl'),
    num_tokens,
    max_seq_length,
    idxs=splits['val'],
    countvec=countvec,
)
# Evaluation gains nothing from shuffling; a fixed order keeps
# `next(iter(val_loader))` and any per-example inspection reproducible across runs.
val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)
print(len(val_dataset))

6


# train

In [372]:
from transformer_translation.model import LanguageTransformer, ReportTransformer

In [373]:
# --- Transformer hyper-parameters ---
nhead = 8
# Bump 587 up to the next multiple of nhead (592 with nhead = 8).
# NOTE(review): presumably needed so the embedding splits evenly across heads — confirm.
d_model = (587 // nhead + 1) * nhead
# NOTE(review): the "+ 4" presumably accounts for special tokens — confirm.
# 1952 was left in the original as a commented-out alternative value.
vocab_size = 10000 + 4
num_encoder_layers = 6
num_decoder_layers = 6
dim_feedforward = 2048
pos_dropout = 0.1
trans_dropout = 0.1

# Build the model and move it to the selected device.
model = ReportTransformer(
    vocab_size,
    d_model,
    nhead,
    num_encoder_layers,
    num_decoder_layers,
    dim_feedforward,
    max_seq_length,
    pos_dropout,
    trans_dropout,
)
model = model.to(device)

In [374]:
    from transformer_translation.Optim import ScheduledOptim
    import torch.nn as nn
    from torch.optim import Adam

    # Xavier-normal initialisation for every parameter with more than one
    # dimension; 1-D parameters are skipped and keep whatever they had.
    for weight in model.parameters():
        if weight.dim() <= 1:
            continue
        nn.init.xavier_normal_(weight)

    # Adam wrapped in the project's ScheduledOptim, which is given d_model and
    # the number of warm-up steps.
    n_warmup_steps = 4000
    base_optimizer = Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-09)
    optim = ScheduledOptim(base_optimizer, d_model, n_warmup_steps)

    # Targets equal to 0 are excluded from the loss.
    # NOTE(review): presumably 0 is the PAD index — confirm against the dataset.
    criterion = nn.CrossEntropyLoss(ignore_index=0)

In [375]:
from tsf_infer_utils import prep_transf_inputs, infer
from torchtext.data.metrics import bleu_score
from tsf_utils import format_list_for_bleu, get_bleu_from_loader

In [376]:
    %%time
    print_every = 20
    num_epochs = 20
    early_stopping_flag = True

    lowest_val = 1e9
    val_losses = []
    total_step = 0
    
    for epoch in range(num_epochs):
        
        model.train()
        total_loss = 0

        for step, (src, src_key_padding_mask, tgt, tgt_key_padding_mask) in enumerate(iter(loader)):
            total_step += 1

            src, src_key_padding_mask, tgt, tgt_key_padding_mask, memory_key_padding_mask, tgt_inp, tgt_out, tgt_mask = prep_transf_inputs(
                src, src_key_padding_mask, tgt, tgt_key_padding_mask, device)
            
            optim.zero_grad()
            outputs = model(src, tgt_inp, src_key_padding_mask, tgt_key_padding_mask[:, :-1], memory_key_padding_mask, tgt_mask)
            loss = criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_out, 'b o -> (b o)'))

            loss.backward()
            optim.step_and_update_lr()

            total_loss += loss.item()
            if step % print_every == print_every - 1:
                print(f'Epoch [{epoch + 1} / {num_epochs}] \t Step [{step + 1} / {len(loader)}] \t '
                      'Train Loss: {:.3f}'.format(total_loss / print_every))
                total_loss = 0
                
        if early_stopping_flag:
            model.eval()
            print(f'Epoch [{epoch + 1} / {num_epochs}]:')
            print('{} BLEU: {:.2%}'.format(
                '\ttrain', get_bleu_from_loader(model, loader)))
            print('{} BLEU: {:.2%}'.format(
                '\tval', get_bleu_from_loader(model, val_loader)))
        print('\n')

Epoch [1 / 20] 	 Step [20 / 85] 	 Train Loss: 9.181
Epoch [1 / 20] 	 Step [40 / 85] 	 Train Loss: 8.737
Epoch [1 / 20] 	 Step [60 / 85] 	 Train Loss: 8.212
Epoch [1 / 20] 	 Step [80 / 85] 	 Train Loss: 7.880
Epoch [1 / 20]:
	train BLEU: 0.00%
	val BLEU: 0.00%


Epoch [2 / 20] 	 Step [20 / 85] 	 Train Loss: 7.473
Epoch [2 / 20] 	 Step [40 / 85] 	 Train Loss: 7.190
Epoch [2 / 20] 	 Step [60 / 85] 	 Train Loss: 6.867
Epoch [2 / 20] 	 Step [80 / 85] 	 Train Loss: 6.471
Epoch [2 / 20]:
	train BLEU: 0.00%
	val BLEU: 0.00%


Epoch [3 / 20] 	 Step [20 / 85] 	 Train Loss: 6.037
Epoch [3 / 20] 	 Step [40 / 85] 	 Train Loss: 5.709
Epoch [3 / 20] 	 Step [60 / 85] 	 Train Loss: 5.346
Epoch [3 / 20] 	 Step [80 / 85] 	 Train Loss: 5.179
Epoch [3 / 20]:
	train BLEU: 0.00%
	val BLEU: 0.00%


Epoch [4 / 20] 	 Step [20 / 85] 	 Train Loss: 4.754
Epoch [4 / 20] 	 Step [40 / 85] 	 Train Loss: 4.394
Epoch [4 / 20] 	 Step [60 / 85] 	 Train Loss: 4.397
Epoch [4 / 20] 	 Step [80 / 85] 	 Train Loss: 4.241
Epoch 

### test the model with our previous evaluation method (teacher forcing)

In [377]:
from tsf_utils import get_tk_from_proba, format_list_for_bleu

In [378]:
    %%time
    pred_list = []
    tgt_list = []
    device = model.embed_tgt.weight.device
    
    for (src, src_key_padding_mask, tgt, tgt_key_padding_mask) in iter(val_loader):
        
        # prepare inputs
        src, src_key_padding_mask, tgt, tgt_key_padding_mask, memory_key_padding_mask, tgt_inp, tgt_out, tgt_mask = prep_transf_inputs(
            src, src_key_padding_mask, tgt, tgt_key_padding_mask, device)

        # run inference
        outputs = model(src, tgt_inp, src_key_padding_mask, tgt_key_padding_mask[:, :-1], memory_key_padding_mask, tgt_mask)

        # get predictions from proba
        pred = get_tk_from_proba(outputs)
        
        # get pred and ground truth ready for metric eval
        pred_list += [list(pred[row, :].cpu().numpy()) for row in range(pred.shape[0])]
        tgt_list += [list(tgt[row, :].cpu().numpy()) for row in range(pred.shape[0])]

CPU times: user 16.8 s, sys: 1.02 s, total: 17.8 s
Wall time: 692 ms


In [379]:
# Convert the token id lists into the format expected by bleu_score, then report validation BLEU.
pred_list_bleu, tgt_list_bleu = format_list_for_bleu(pred_list, tgt_list)
val_bleu = bleu_score(pred_list_bleu, tgt_list_bleu)
print('{} BLEU: {:.2%}'.format('\tval', val_bleu))

	val BLEU: 35.74%


### build the new evaluation method (autoregressive decoding)

In [380]:
from transformer_translation.translate_sentence import gen_nopeek_mask

In [381]:
def forward_model(model, src, pred_sentence):
    """Run one forward pass of `model` on `src` and a partial target sequence.

    Args:
        model: the transformer; it is called with no padding masks and a
            no-peek mask over the target.
        src: source tensor, moved to the model's device and passed through as is.
        pred_sentence: flat list of target token ids generated so far.

    Returns:
        Whatever `model(...)` returns for this single-sequence batch.
    """
    # Use the device the model lives on rather than the notebook-level `device`
    # global, which other cells rebind.
    model_device = next(model.parameters()).device

    # prepare inputs: a batch of one sequence, shape (1, len(pred_sentence))
    tgt = torch.tensor(pred_sentence).unsqueeze(0).to(model_device)
    src = src.to(model_device)
    tgt_mask = gen_nopeek_mask(tgt.shape[1]).to(model_device)

    # run inference
    return model(
        src,
        tgt,
        src_key_padding_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        tgt_mask=tgt_mask,
    )

In [398]:
from transformer_translation.dataset import IDX_EOS, IDX_SOS

In [400]:
# get data example
(src, src_key_padding_mask, tgt, tgt_key_padding_mask) = next(iter(val_loader))

# select single sentence
src = src[:,[0],:]
tgt = tgt[:,[0],:]

# replace by initial tgt for out-of-sample decoding
pred_sentence = [IDX_SOS]

# run out-of-sample decoding
i = 0
while int(pred_sentence[-1]) != IDX_EOS and i < max_seq_length:
    output_tk = get_tk_from_proba(
        forward_model(model, src, pred_sentence))
    pred_sentence.append(output_tk[0][-1].item())
    output_tk = torch.tensor(pred_sentence).unsqueeze(0).unsqueeze(0).to(device)
    i += 1

print('TAGS: ',[tags_index2word[idx.item()] for idx in src[0,0,:].nonzero()])
print_nl_pred_vs_tgt(
    [list(output_tk[0][row, :].cpu().numpy()) for row in range(output_tk[0].shape[0])]
    ,[list(tgt[0][row, :].cpu().numpy()) for row in range(tgt[0].shape[0])]
    ,reports_index2word
)

TAGS:  ['normal']
TARGET:  SOS no active cardiopulmonary disease . EOS PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD
PREDICTION:  SOS no acute cardiopulmonary abnormality . . the lungs are clear bilaterally . specifically no evidence of focal consolidation pneumothorax or pleural effusion . . cardio mediastinal silhouette is unremarkable . visualized osseous structures of the thorax are without acute abnormality . EOS




# assess perf

### run inference

In [17]:
# Test split, restricted to the indices in splits['test'].
test_dataset = TagReportDataset(
    os.path.join(data_path, 'tags/set_raw.pkl'),
    os.path.join(data_path, 'reports/set.pkl'),
    num_tokens,
    max_seq_length,
    idxs=splits['test'],
    countvec=countvec,
)
# Evaluation gains nothing from shuffling; a fixed order keeps the prediction
# list aligned with the dataset order and reproducible across runs.
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)
print(len(test_dataset))

6


In [18]:
%%time
# Run the project's `infer` helper (from tsf_infer_utils) over the test loader.
# This rebinds pred_list / tgt_list, replacing the validation results computed earlier.
pred_list, tgt_list = infer(model, test_loader)

CPU times: user 19.5 s, sys: 1.32 s, total: 20.8 s
Wall time: 906 ms


### compute BLEU

In [23]:
%%time
pred_list_bleu, tgt_list_bleu = format_list_for_bleu(pred_list, tgt_list)
print('{} BLEU: {:.2%}'.format(
                'test', bleu_score(pred_list_bleu, tgt_list_bleu)))

test BLEU: 40.81%
CPU times: user 1.58 s, sys: 20 ms, total: 1.6 s
Wall time: 1.6 s


### create vocab dicts 

In [384]:
# Position -> tag lookup, numbering the count-vectorizer vocabulary in iteration order.
# NOTE(review): assumes iteration order matches the column order of the tag vectors — confirm.
tags_index2word = dict(enumerate(dataset.countvec.vocabulary))

In [385]:
from transformer_translation.dataset import load_report_voc

# Report vocabulary lookup, read from <data_path>/reports/voc.pkl.
report_voc_path = os.path.join(data_path, 'reports', 'voc.pkl')
reports_index2word = load_report_voc(report_voc_path)

### assess natural-language performance

In [106]:
from tsf_utils import print_nl_pred_vs_tgt

In [23]:
import random

# Pick up to 10 random examples to inspect. k is clamped to the number of
# available predictions because random.sample raises ValueError when asked for
# more items than the population holds (the test split above has only 6).
k = min(10, len(pred_list))
sel_idx = random.sample(range(len(pred_list)), k=k)

In [21]:
# Print the sampled predictions next to their targets, decoded with the report vocabulary.
print_nl_pred_vs_tgt(
    [pred_list[idx] for idx in sel_idx],
    [tgt_list[idx] for idx in sel_idx],
    reports_index2word,
)

TARGET:   normal chest no evidence of tuberculosis heart size normal . lungs are clear . xxxx are normal . no pneumonia effusions edema pneumothorax adenopathy nodules or masses . .
PREDICTION:  no chest heart evidence of tuberculosis heart size normal . lungs are clear . xxxx are normal . no pneumonia effusions edema pneumothorax adenopathy nodules or masses . .


TARGET:   no acute cardiopulmonary abnormality . mediastinal contours are normal . heart size is within normal limits . multiple scattered calcified pulmonary nodules xxxx sequela of prior granulomatous disease . otherwise lungs are clear . . there is no pneumothorax or large pleural effusion . no bony abnormality . .
PREDICTION:   acute cardiopulmonary abnormality . . contours are within . no size is within normal limits . no calcified calcified granulomas nodules are . of prior granulomatous disease . no lungs are clear without no no is no pneumothorax or large pleural effusion . there acute abnormality . .


TARGET:   no 