In [402]:
import os
import torch
from einops import rearrange

In [403]:
# Pin the process to physical GPU 5 unless the environment already selects a device,
# so the notebook can be launched on another GPU without editing this cell.
# CUDA reads this variable only when it is first initialized in the process, so it must
# be set before the first CUDA call (the device check follows in the next cell);
# changing it after CUDA is initialized in a live kernel has no effect.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '5')

In [404]:
# Use the first visible GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [405]:
# Reload edited modules (e.g. transformer_translation, tsf_utils) before each cell runs,
# so library changes are picked up without restarting the kernel.
# If the extension is already loaded, %load_ext only prints a notice (see output below).
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# load data

In [406]:
from transformer_translation.dataset import TagReportDataset, load_pickle
from torch.utils.data import DataLoader

In [407]:
# Root directory containing the preprocessed 'tags/' and 'reports/' pickles.
# Override with the TRANSF_DATA_PATH environment variable instead of editing a
# machine-specific absolute path (an in-repo alternative is
# 'transformer_translation/data/processed').
data_path = os.environ.get('TRANSF_DATA_PATH', r"/home/alex/data/nlp/agmir/transf_processed_data")

In [408]:
# Index splits (keys 'train' and 'val' are used below) and the count vectorizer
# passed to both datasets.
# NOTE(review): these file names resolve against the current working directory,
# not data_path — confirm that is intended. load_pickle is presumably pickle-based,
# so only load files you trust.
splits, countvec = (
    load_pickle(fname)
    for fname in ('20200525_splits.pkl', '20200528_countvec.pkl')
)

### train set

In [409]:
# Training split: the tag/report pickles restricted to splits['train'].
# The loader reshuffles every epoch; batch_size=1, so len(loader) == len(dataset).
num_tokens, max_seq_length = 2000, 96
dataset = TagReportDataset(
    os.path.join(data_path, 'tags/set_raw.pkl'),
    os.path.join(data_path, 'reports/set.pkl'),
    num_tokens,
    max_seq_length,
    idxs=splits['train'],
    countvec=countvec,
)
loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)
print(len(dataset))

85


### val set

In [410]:
# Validation split: same construction as the training set, restricted to splits['val'].
# shuffle=False gives a fixed, sequential iteration order. The order of the predictions
# collected from this loader, and of the examples printed from them, is then
# reproducible. Shuffling adds nothing to evaluation.
val_dataset = TagReportDataset(
    os.path.join(data_path, 'tags/set_raw.pkl')
    ,os.path.join(data_path, 'reports/set.pkl')
    ,num_tokens
    ,max_seq_length
    ,idxs=splits['val']
    ,countvec = countvec
)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)
print(len(val_dataset))

6


# train

In [411]:
from transformer_translation.model import LanguageTransformer, ReportTransformer

In [697]:
# Model hyperparameters.
# NOTE(review): the +4 presumably reserves ids for special tokens (SOS/EOS appear in the
# decoded outputs; the loss ignores id 0) — confirm. An earlier value, 1952, was left in a
# trailing comment. Also confirm how vocab_size relates to num_tokens=2000 given to the datasets.
vocab_size = 10000 + 4
nhead = 8
# The model width must be a multiple of nhead (torch's multi-head attention splits
# d_model evenly across heads). Round the target width 587 UP to the next multiple
# of nhead (592 for nhead=8). Unlike `587 - 587 % nhead + nhead`, this does not
# overshoot by a full nhead when the width is already divisible.
d_model = ((587 + nhead - 1) // nhead) * nhead
assert d_model % nhead == 0, (d_model, nhead)
num_encoder_layers = 6
num_decoder_layers = 6
dim_feedforward = 2048
pos_dropout = 0.1
trans_dropout = 0.1

model = ReportTransformer(
    vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward,
    max_seq_length, pos_dropout, trans_dropout
).to(device)

In [698]:
    # Re-initialize every weight matrix (parameters with more than one dim) with
    # Xavier-normal init, then wrap Adam in the project's ScheduledOptim, presumably a
    # warmup learning-rate schedule driven by d_model and n_warmup_steps.
    import torch.nn as nn
    from torch.optim import Adam
    from transformer_translation.Optim import ScheduledOptim

    for p in model.parameters():
        if p.dim() <= 1:
            continue  # biases / 1-D params keep their default init
        nn.init.xavier_normal_(p)

    n_warmup_steps = 4000
    optim = ScheduledOptim(
        Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-09),
        d_model,
        n_warmup_steps,
    )

    # Token id 0 is excluded from the loss — presumably the padding id; confirm against the dataset.
    criterion = nn.CrossEntropyLoss(ignore_index=0)

In [699]:
from tsf_infer_utils import prep_transf_inputs, infer
from torchtext.data.metrics import bleu_score
from tsf_utils import format_list_for_bleu, get_bleu_from_loader

In [700]:
    %%time
    print_every = 20
    num_epochs = 20
    early_stopping_flag = True

    lowest_val = 1e9
    val_losses = []
    total_step = 0
    
    for epoch in range(num_epochs):
        
        model.train()
        total_loss = 0

        for step, (src, src_key_padding_mask, tgt, tgt_key_padding_mask) in enumerate(iter(loader)):
            total_step += 1

            src, src_key_padding_mask, tgt, tgt_key_padding_mask, memory_key_padding_mask, tgt_inp, tgt_out, tgt_mask = prep_transf_inputs(
                src, src_key_padding_mask, tgt, tgt_key_padding_mask, device)
            
            optim.zero_grad()
            outputs = model(src, tgt_inp, src_key_padding_mask, tgt_key_padding_mask[:, :-1], memory_key_padding_mask, tgt_mask)
            loss = criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_out, 'b o -> (b o)'))

            loss.backward()
            optim.step_and_update_lr()

            total_loss += loss.item()
            if step % print_every == print_every - 1:
                print(f'Epoch [{epoch + 1} / {num_epochs}] \t Step [{step + 1} / {len(loader)}] \t '
                      'Train Loss: {:.3f}'.format(total_loss / print_every))
                total_loss = 0
                
        if early_stopping_flag:
            model.eval()
            print(f'Epoch [{epoch + 1} / {num_epochs}]:')
            print('{} BLEU: {:.2%}'.format(
                '\ttrain', get_bleu_from_loader(model, loader)))
            print('{} BLEU: {:.2%}'.format(
                '\tval', get_bleu_from_loader(model, val_loader)))
        print('\n')

Epoch [1 / 20] 	 Step [20 / 85] 	 Train Loss: 9.186
Epoch [1 / 20] 	 Step [40 / 85] 	 Train Loss: 8.761
Epoch [1 / 20] 	 Step [60 / 85] 	 Train Loss: 8.224
Epoch [1 / 20] 	 Step [80 / 85] 	 Train Loss: 7.830
Epoch [1 / 20]:
	train BLEU: 0.00%
	val BLEU: 0.00%


Epoch [2 / 20] 	 Step [20 / 85] 	 Train Loss: 7.419
Epoch [2 / 20] 	 Step [40 / 85] 	 Train Loss: 7.176
Epoch [2 / 20] 	 Step [60 / 85] 	 Train Loss: 6.908
Epoch [2 / 20] 	 Step [80 / 85] 	 Train Loss: 6.537
Epoch [2 / 20]:
	train BLEU: 0.00%
	val BLEU: 0.00%


Epoch [3 / 20] 	 Step [20 / 85] 	 Train Loss: 6.020
Epoch [3 / 20] 	 Step [40 / 85] 	 Train Loss: 5.682
Epoch [3 / 20] 	 Step [60 / 85] 	 Train Loss: 5.425
Epoch [3 / 20] 	 Step [80 / 85] 	 Train Loss: 5.194
Epoch [3 / 20]:
	train BLEU: 0.00%
	val BLEU: 0.00%


Epoch [4 / 20] 	 Step [20 / 85] 	 Train Loss: 4.874
Epoch [4 / 20] 	 Step [40 / 85] 	 Train Loss: 4.642
Epoch [4 / 20] 	 Step [60 / 85] 	 Train Loss: 4.428
Epoch [4 / 20] 	 Step [80 / 85] 	 Train Loss: 4.108
Epoch 

### Evaluate the model the previous way (teacher forcing: the decoder is fed the ground-truth target prefix)

In [701]:
from tsf_utils import get_tk_from_proba, format_list_for_bleu

In [702]:
    %%time
    pred_list = []
    tgt_list = []
    device = model.embed_tgt.weight.device
    
    for (src, src_key_padding_mask, tgt, tgt_key_padding_mask) in iter(val_loader):
        
        # prepare inputs
        src, src_key_padding_mask, tgt, tgt_key_padding_mask, memory_key_padding_mask, tgt_inp, tgt_out, tgt_mask = prep_transf_inputs(
            src, src_key_padding_mask, tgt, tgt_key_padding_mask, device)

        # run inference
        outputs = model(src, tgt_inp, src_key_padding_mask, tgt_key_padding_mask[:, :-1], memory_key_padding_mask, tgt_mask)

        # get predictions from proba
        pred = get_tk_from_proba(outputs)
        
        # get pred and ground truth ready for metric eval
        pred_list += [list(pred[row, :].cpu().numpy()) for row in range(pred.shape[0])]
        tgt_list += [list(tgt[row, :].cpu().numpy()) for row in range(pred.shape[0])]

CPU times: user 15.4 s, sys: 824 ms, total: 16.2 s
Wall time: 593 ms


In [703]:
# Corpus BLEU over the teacher-forced validation predictions collected in the previous cell.
pred_list_bleu, tgt_list_bleu = format_list_for_bleu(pred_list, tgt_list)
val_bleu = bleu_score(pred_list_bleu, tgt_list_bleu)
print('{} BLEU: {:.2%}'.format('\tval', val_bleu))

	val BLEU: 35.57%


### OOS eval

In [704]:
%%time
pred_list, tgt_list, tag_list = oos_infer(model, val_loader, max_seq_length)

CPU times: user 1h 58min 12s, sys: 3min 48s, total: 2h 2min
Wall time: 2min 46s


In [705]:
# Corpus BLEU for the oos_infer predictions (separate names so the teacher-forced
# pred_list_bleu / tgt_list_bleu are not overwritten).
oos_pred_bleu, oos_tgt_bleu = format_list_for_bleu(pred_list, tgt_list)
oos_bleu = bleu_score(oos_pred_bleu, oos_tgt_bleu)
print('{}BLEU: {:.2%}\n\n'.format('', oos_bleu))

BLEU: 11.83%




In [706]:
# Print each example in natural language: its tags, the target report and the prediction.
# NOTE(review): print_nl_pred_vs_tgt, reports_index2word and tags_index2word are not
# defined in any visible cell, so this relies on hidden kernel state.
print_nl_pred_vs_tgt(
    pred_list,
    tgt_list,
    reports_index2word,
    tag_list,
    tags_index2word,
)

TAGS:  deformity old_injury
TARGET:  SOS no acute findings heart size within normal limits stable mediastinal and hilar contours . no alveolar consolidation no findings of pleural effusion or pulmonary edema . chronic appearing contour deformity of the right posterolateral th rib again noted suggestive of old injury . EOS
PREDICTION:  SOS no acute cardiopulmonary abnormality . . . the lungs are clear bilaterally . specifically no evidence of focal consolidation pneumothorax or pleural effusion . . cardio mediastinal silhouette is unremarkable . visualized osseous structures of the thorax are without acute abnormality . EOS


TAGS:  degenerative_change
TARGET:  SOS no acute cardiopulmonary findings . there is no focal consolidation . there is no pneumothorax or large pleural effusion . the cardiomediastinal contours are grossly unremarkable . the heart size is within normal limits . there are mild thoracic spine degenerative changes . EOS
PREDICTION:  SOS no acute cardiopulmonary abnorm