In [1]:
%matplotlib inline

In [2]:
import os
import torch
from einops import rearrange

In [3]:
device = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")
print(device)

cuda:4


In [4]:
%load_ext autoreload
%autoreload 2

# load data

In [5]:
from transformer_translation.dataset import ParallelLanguageDataset, load_pickle
from torch.utils.data import DataLoader

In [6]:
data_path = r"/home/alex/data/nlp/agmir/transf_processed_data"
#data_path = 'transformer_translation/data/processed'

In [7]:
splits = load_pickle('20200525_splits.pkl')
countvec = load_pickle('20200525_countvec.pkl')

### train set

In [8]:
num_tokens = 2000
max_seq_length = 96
dataset = ParallelLanguageDataset(
    os.path.join(data_path, 'tags/set.pkl')
    ,os.path.join(data_path, 'reports/set.pkl')
    ,num_tokens
    ,max_seq_length
    ,idxs=splits['train']
)
loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4, pin_memory=True)
print(len(dataset))

79


### val set

In [9]:
val_dataset = ParallelLanguageDataset(
    os.path.join(data_path, 'tags/set.pkl')
    ,os.path.join(data_path, 'reports/set.pkl')
    ,num_tokens
    ,max_seq_length
    ,idxs=splits['val']
)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=True, num_workers=4, pin_memory=True)
print(len(val_dataset))

6


# train

In [10]:
from transformer_translation.model import LanguageTransformer, ReportTransformer2

In [11]:
vocab_size = 10000 + 4#1952#
nhead = 8
d_model = 256
num_encoder_layers = 2
num_decoder_layers = 2
dim_feedforward = 512
pos_dropout = 0.1
trans_dropout = 0.1

model = ReportTransformer2(
    vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward,
    max_seq_length, pos_dropout, trans_dropout
).to(device)

In [12]:
    from transformer_translation.Optim import ScheduledOptim
    import torch.nn as nn
    from torch.optim import Adam
    
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_normal_(p)

    n_warmup_steps = 4000
    optim = ScheduledOptim(
        Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-09),
        d_model, n_warmup_steps)

    criterion = nn.CrossEntropyLoss(ignore_index=0)

In [13]:
from tsf_infer_utils import prep_transf_inputs2, infer2
from torchtext.data.metrics import bleu_score
from tsf_utils import format_list_for_bleu, get_bleu_from_loader2

In [14]:
    %%time
    print_every = 15
    num_epochs = 30
    early_stopping_flag = True

    lowest_val = 1e9
    val_losses = []
    total_step = 0
    
    for epoch in range(num_epochs):
        
        model.train()
        total_loss = 0

        for step, (src, src_key_padding_mask, tgt, tgt_key_padding_mask) in enumerate(iter(loader)):
            total_step += 1

            src, src_key_padding_mask, tgt, tgt_key_padding_mask, memory_key_padding_mask, tgt_inp, tgt_out, tgt_mask = prep_transf_inputs2(
                src, src_key_padding_mask, tgt, tgt_key_padding_mask, device)
            
            optim.zero_grad()
            outputs = model(src, tgt_inp, src_key_padding_mask, tgt_key_padding_mask[:, :-1], memory_key_padding_mask, tgt_mask)
            loss = criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_out, 'b o -> (b o)'))

            loss.backward()
            optim.step_and_update_lr()

            total_loss += loss.item()
            if step % print_every == print_every - 1:
                print(f'Epoch [{epoch + 1} / {num_epochs}] \t Step [{step + 1} / {len(loader)}] \t '
                      'Train Loss: {:.3f}'.format(total_loss / print_every))
                total_loss = 0
                
        if early_stopping_flag:
            model.eval()
            print(f'Epoch [{epoch + 1} / {num_epochs}]:')
            print('{} BLEU: {:.2%}'.format(
                '\ttrain', get_bleu_from_loader2(model, loader)))
            print('{} BLEU: {:.2%}'.format(
                '\tval', get_bleu_from_loader2(model, val_loader)))
        print('\n')

Epoch [1 / 30] 	 Step [15 / 79] 	 Train Loss: 9.230
Epoch [1 / 30] 	 Step [30 / 79] 	 Train Loss: 9.196
Epoch [1 / 30] 	 Step [45 / 79] 	 Train Loss: 9.097
Epoch [1 / 30] 	 Step [60 / 79] 	 Train Loss: 8.976
Epoch [1 / 30] 	 Step [75 / 79] 	 Train Loss: 8.796
Epoch [1 / 30]:
	train BLEU: 0.00%
	val BLEU: 0.00%


Epoch [2 / 30] 	 Step [15 / 79] 	 Train Loss: 8.575
Epoch [2 / 30] 	 Step [30 / 79] 	 Train Loss: 8.379
Epoch [2 / 30] 	 Step [45 / 79] 	 Train Loss: 8.217
Epoch [2 / 30] 	 Step [60 / 79] 	 Train Loss: 7.996
Epoch [2 / 30] 	 Step [75 / 79] 	 Train Loss: 7.806
Epoch [2 / 30]:
	train BLEU: 0.00%
	val BLEU: 0.00%


Epoch [3 / 30] 	 Step [15 / 79] 	 Train Loss: 7.461
Epoch [3 / 30] 	 Step [30 / 79] 	 Train Loss: 7.267
Epoch [3 / 30] 	 Step [45 / 79] 	 Train Loss: 6.994
Epoch [3 / 30] 	 Step [60 / 79] 	 Train Loss: 6.765
Epoch [3 / 30] 	 Step [75 / 79] 	 Train Loss: 6.522
Epoch [3 / 30]:
	train BLEU: 0.00%
	val BLEU: 0.00%


Epoch [4 / 30] 	 Step [15 / 79] 	 Train Loss: 6.154
Epoch 

Epoch [26 / 30]:
	train BLEU: 43.90%
	val BLEU: 26.72%


Epoch [27 / 30] 	 Step [15 / 79] 	 Train Loss: 0.994
Epoch [27 / 30] 	 Step [30 / 79] 	 Train Loss: 1.179
Epoch [27 / 30] 	 Step [45 / 79] 	 Train Loss: 1.132
Epoch [27 / 30] 	 Step [60 / 79] 	 Train Loss: 1.217
Epoch [27 / 30] 	 Step [75 / 79] 	 Train Loss: 1.290
Epoch [27 / 30]:
	train BLEU: 45.72%
	val BLEU: 27.13%


Epoch [28 / 30] 	 Step [15 / 79] 	 Train Loss: 1.065
Epoch [28 / 30] 	 Step [30 / 79] 	 Train Loss: 1.010
Epoch [28 / 30] 	 Step [45 / 79] 	 Train Loss: 1.110
Epoch [28 / 30] 	 Step [60 / 79] 	 Train Loss: 1.239
Epoch [28 / 30] 	 Step [75 / 79] 	 Train Loss: 1.161
Epoch [28 / 30]:
	train BLEU: 46.78%
	val BLEU: 27.39%


Epoch [29 / 30] 	 Step [15 / 79] 	 Train Loss: 1.186
Epoch [29 / 30] 	 Step [30 / 79] 	 Train Loss: 1.056
Epoch [29 / 30] 	 Step [45 / 79] 	 Train Loss: 0.969
Epoch [29 / 30] 	 Step [60 / 79] 	 Train Loss: 1.098
Epoch [29 / 30] 	 Step [75 / 79] 	 Train Loss: 0.971
Epoch [29 / 30]:
	train BLEU: 47.6

In [15]:
    %%time
    print_every = 15
    num_epochs = 30
    early_stopping_flag = True

    lowest_val = 1e9
    val_losses = []
    total_step = 0
    
    for epoch in range(num_epochs):
        
        model.train()
        total_loss = 0

        for step, (src, src_key_padding_mask, tgt, tgt_key_padding_mask) in enumerate(iter(loader)):
            total_step += 1

            src, src_key_padding_mask, tgt, tgt_key_padding_mask, memory_key_padding_mask, tgt_inp, tgt_out, tgt_mask = prep_transf_inputs2(
                src, src_key_padding_mask, tgt, tgt_key_padding_mask, device)
            
            optim.zero_grad()
            outputs = model(src, tgt_inp, src_key_padding_mask, tgt_key_padding_mask[:, :-1], memory_key_padding_mask, tgt_mask)
            loss = criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_out, 'b o -> (b o)'))

            loss.backward()
            optim.step_and_update_lr()

            total_loss += loss.item()
            if step % print_every == print_every - 1:
                print(f'Epoch [{epoch + 1} / {num_epochs}] \t Step [{step + 1} / {len(loader)}] \t '
                      'Train Loss: {:.3f}'.format(total_loss / print_every))
                total_loss = 0
                
        if early_stopping_flag:
            model.eval()
            print(f'Epoch [{epoch + 1} / {num_epochs}]:')
            print('{} BLEU: {:.2%}'.format(
                '\ttrain', get_bleu_from_loader2(model, loader)))
            print('{} BLEU: {:.2%}'.format(
                '\tval', get_bleu_from_loader2(model, val_loader)))
        print('\n')

Epoch [1 / 30] 	 Step [15 / 79] 	 Train Loss: 0.945
Epoch [1 / 30] 	 Step [30 / 79] 	 Train Loss: 0.892
Epoch [1 / 30] 	 Step [45 / 79] 	 Train Loss: 1.048
Epoch [1 / 30] 	 Step [60 / 79] 	 Train Loss: 0.863
Epoch [1 / 30] 	 Step [75 / 79] 	 Train Loss: 1.075
Epoch [1 / 30]:
	train BLEU: 51.85%
	val BLEU: 27.97%


Epoch [2 / 30] 	 Step [15 / 79] 	 Train Loss: 0.983
Epoch [2 / 30] 	 Step [30 / 79] 	 Train Loss: 1.002
Epoch [2 / 30] 	 Step [45 / 79] 	 Train Loss: 0.880
Epoch [2 / 30] 	 Step [60 / 79] 	 Train Loss: 0.866
Epoch [2 / 30] 	 Step [75 / 79] 	 Train Loss: 0.946
Epoch [2 / 30]:
	train BLEU: 53.16%
	val BLEU: 28.39%


Epoch [3 / 30] 	 Step [15 / 79] 	 Train Loss: 0.842
Epoch [3 / 30] 	 Step [30 / 79] 	 Train Loss: 0.875
Epoch [3 / 30] 	 Step [45 / 79] 	 Train Loss: 0.969
Epoch [3 / 30] 	 Step [60 / 79] 	 Train Loss: 0.834
Epoch [3 / 30] 	 Step [75 / 79] 	 Train Loss: 0.938
Epoch [3 / 30]:
	train BLEU: 53.66%
	val BLEU: 27.57%


Epoch [4 / 30] 	 Step [15 / 79] 	 Train Loss: 0.754


Epoch [26 / 30] 	 Step [75 / 79] 	 Train Loss: 0.457
Epoch [26 / 30]:
	train BLEU: 73.12%
	val BLEU: 29.28%


Epoch [27 / 30] 	 Step [15 / 79] 	 Train Loss: 0.412
Epoch [27 / 30] 	 Step [30 / 79] 	 Train Loss: 0.437
Epoch [27 / 30] 	 Step [45 / 79] 	 Train Loss: 0.429
Epoch [27 / 30] 	 Step [60 / 79] 	 Train Loss: 0.448
Epoch [27 / 30] 	 Step [75 / 79] 	 Train Loss: 0.455
Epoch [27 / 30]:
	train BLEU: 74.32%
	val BLEU: 29.12%


Epoch [28 / 30] 	 Step [15 / 79] 	 Train Loss: 0.385
Epoch [28 / 30] 	 Step [30 / 79] 	 Train Loss: 0.398
Epoch [28 / 30] 	 Step [45 / 79] 	 Train Loss: 0.422
Epoch [28 / 30] 	 Step [60 / 79] 	 Train Loss: 0.418
Epoch [28 / 30] 	 Step [75 / 79] 	 Train Loss: 0.465
Epoch [28 / 30]:
	train BLEU: 73.66%
	val BLEU: 28.74%


Epoch [29 / 30] 	 Step [15 / 79] 	 Train Loss: 0.389
Epoch [29 / 30] 	 Step [30 / 79] 	 Train Loss: 0.389
Epoch [29 / 30] 	 Step [45 / 79] 	 Train Loss: 0.387
Epoch [29 / 30] 	 Step [60 / 79] 	 Train Loss: 0.438
Epoch [29 / 30] 	 Step [75 / 79] 	

# assess perf

### run inference

In [17]:
test_dataset = TagReportDataset(
    os.path.join(data_path, 'tags/set_raw.pkl')
    ,os.path.join(data_path, 'reports/set.pkl')
    ,num_tokens
    ,max_seq_length
    ,idxs=splits['test']
    ,countvec = countvec
)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=4, pin_memory=True)
print(len(test_dataset))

6


In [18]:
%%time
pred_list, tgt_list = infer(model, test_loader)

CPU times: user 19.5 s, sys: 1.32 s, total: 20.8 s
Wall time: 906 ms


### compute BLEU

In [23]:
%%time
pred_list_bleu, tgt_list_bleu = format_list_for_bleu(pred_list, tgt_list)
print('{} BLEU: {:.2%}'.format(
                'test', bleu_score(pred_list_bleu, tgt_list_bleu)))

test BLEU: 40.81%
CPU times: user 1.58 s, sys: 20 ms, total: 1.6 s
Wall time: 1.6 s


### create vocab dicts 

In [17]:
tags_index2word = dict(zip(range(len(dataset.countvec.vocabulary)), dataset.countvec.vocabulary))

In [18]:
from transformer_translation.dataset import load_report_voc
reports_index2word = load_report_voc(
    os.path.join(data_path, 'reports', 'voc.pkl'))

### assess nat lg perf

In [19]:
from tsf_utils import print_nl_pred_vs_tgt

In [23]:
import random
k = 10
sel_idx = random.sample(range(len(pred_list)), k=k)

In [21]:
print_nl_pred_vs_tgt(
                [pred_list[i] for i in sel_idx]
                ,[tgt_list[i] for i in sel_idx]
                ,reports_index2word
            )

TARGET:   normal chest no evidence of tuberculosis heart size normal . lungs are clear . xxxx are normal . no pneumonia effusions edema pneumothorax adenopathy nodules or masses . .
PREDICTION:  no chest heart evidence of tuberculosis heart size normal . lungs are clear . xxxx are normal . no pneumonia effusions edema pneumothorax adenopathy nodules or masses . .


TARGET:   no acute cardiopulmonary abnormality . mediastinal contours are normal . heart size is within normal limits . multiple scattered calcified pulmonary nodules xxxx sequela of prior granulomatous disease . otherwise lungs are clear . . there is no pneumothorax or large pleural effusion . no bony abnormality . .
PREDICTION:   acute cardiopulmonary abnormality . . contours are within . no size is within normal limits . no calcified calcified granulomas nodules are . of prior granulomatous disease . no lungs are clear without no no is no pneumothorax or large pleural effusion . there acute abnormality . .


TARGET:   no 