In [1]:
%matplotlib inline

In [2]:
import os
import torch
from einops import rearrange

In [3]:
# Preferred GPU for this notebook. The index is only honoured when the host
# actually exposes that many CUDA devices; otherwise fall back to GPU 0 so the
# later `.to(device)` calls do not fail with an invalid device ordinal.
# Without CUDA, everything runs on the CPU.
preferred_gpu_index = 7
if torch.cuda.is_available():
    gpu_index = preferred_gpu_index if preferred_gpu_index < torch.cuda.device_count() else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")
print(device)

cuda:7


In [4]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# load data

In [5]:
from transformer_translation.dataset import TagReportDataset
from torch.utils.data import DataLoader

In [6]:
# Root directory of the preprocessed data; the dataset and vocabulary files
# below are resolved relative to it.
# NOTE(review): this is an absolute, machine-specific path — adjust it (or use
# the commented-out repository-relative alternative) when running elsewhere.
data_path = r"/home/alex/data/nlp/agmir/transf_processed_data"
#data_path = 'transformer_translation/data/processed'

In [7]:
# Settings forwarded to TagReportDataset as its third and fourth arguments;
# max_seq_length is also reused when the model is built.
# (presumably a token-count cap and a maximum sequence length — confirm
# against TagReportDataset.)
num_tokens = 2000
max_seq_length = 96

# Pickled tag and report sets located under data_path.
tags_pkl = os.path.join(data_path, 'tags/set_raw.pkl')
reports_pkl = os.path.join(data_path, 'reports/set.pkl')

dataset = TagReportDataset(tags_pkl, reports_pkl, num_tokens, max_seq_length)

# One sample per batch, reshuffled on every pass, loaded by 4 worker
# processes into pinned memory.
loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

# train

In [8]:
from transformer_translation.model import LanguageTransformer, ReportTransformer

In [9]:
# --- Model hyper-parameters -------------------------------------------------
# 10004 output classes (presumably 10000 words plus 4 special tokens — confirm);
# an alternative value of 1952 was noted here previously.
vocab_size = 10_000 + 4
nhead = 8
# Smallest multiple of nhead strictly greater than 587 (592 for nhead = 8),
# so the embedding width divides evenly across the attention heads.
d_model = (587 // nhead + 1) * nhead
num_encoder_layers = num_decoder_layers = 6
dim_feedforward = 2048
pos_dropout = trans_dropout = 0.1

# Build the transformer and move it to the selected device.
model = ReportTransformer(
    vocab_size,
    d_model,
    nhead,
    num_encoder_layers,
    num_decoder_layers,
    dim_feedforward,
    max_seq_length,
    pos_dropout,
    trans_dropout,
).to(device)

In [10]:
    from transformer_translation.Optim import ScheduledOptim
    import torch.nn as nn
    from torch.optim import Adam
    
    # Xavier-normal initialisation for every parameter with more than one
    # dimension; parameters with at most one dimension keep their constructed values.
    for param in model.parameters():
        if param.dim() <= 1:
            continue
        nn.init.xavier_normal_(param)

    # Adam wrapped in the project's ScheduledOptim, which is given the model
    # width and the number of warm-up steps.
    n_warmup_steps = 4000
    base_optimizer = Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-09)
    optim = ScheduledOptim(base_optimizer, d_model, n_warmup_steps)

    # Positions whose target class is 0 do not contribute to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=0)

In [11]:
from tsf_train_utils import prep_transf_inputs

In [12]:
    %%time
    print_every = 15
    num_epochs = 20
    model.train()

    lowest_val = 1e9
    val_losses = []
    total_step = 0
    
    for epoch in range(num_epochs):
        
        total_loss = 0

        for step, (src, src_key_padding_mask, tgt, tgt_key_padding_mask) in enumerate(iter(loader)):
            total_step += 1

            src, src_key_padding_mask, tgt, tgt_key_padding_mask, memory_key_padding_mask, tgt_inp, tgt_out, tgt_mask = prep_transf_inputs(
                src, src_key_padding_mask, tgt, tgt_key_padding_mask, device)
            
            optim.zero_grad()
            outputs = model(src, tgt_inp, src_key_padding_mask, tgt_key_padding_mask[:, :-1], memory_key_padding_mask, tgt_mask)
            loss = criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_out, 'b o -> (b o)'))

            loss.backward()
            optim.step_and_update_lr()

            total_loss += loss.item()
            if step % print_every == print_every - 1:
                print(f'Epoch [{epoch + 1} / {num_epochs}] \t Step [{step + 1} / {len(loader)}] \t '
                      f'Train Loss: {total_loss / print_every}')
                total_loss = 0

Epoch [1 / 20] 	 Step [15 / 98] 	 Train Loss: 9.233571688334147
Epoch [1 / 20] 	 Step [30 / 98] 	 Train Loss: 8.908067321777343
Epoch [1 / 20] 	 Step [45 / 98] 	 Train Loss: 8.530613327026368
Epoch [1 / 20] 	 Step [60 / 98] 	 Train Loss: 8.099655564626058
Epoch [1 / 20] 	 Step [75 / 98] 	 Train Loss: 7.864101696014404
Epoch [1 / 20] 	 Step [90 / 98] 	 Train Loss: 7.6330458958943685
Epoch [2 / 20] 	 Step [15 / 98] 	 Train Loss: 7.228667704264323
Epoch [2 / 20] 	 Step [30 / 98] 	 Train Loss: 7.1179145812988285
Epoch [2 / 20] 	 Step [45 / 98] 	 Train Loss: 6.805792840321859
Epoch [2 / 20] 	 Step [60 / 98] 	 Train Loss: 6.505191167195638
Epoch [2 / 20] 	 Step [75 / 98] 	 Train Loss: 6.318790435791016
Epoch [2 / 20] 	 Step [90 / 98] 	 Train Loss: 5.965014330546061
Epoch [3 / 20] 	 Step [15 / 98] 	 Train Loss: 5.686200396219889
Epoch [3 / 20] 	 Step [30 / 98] 	 Train Loss: 5.332015514373779
Epoch [3 / 20] 	 Step [45 / 98] 	 Train Loss: 5.136171913146972
Epoch [3 / 20] 	 Step [60 / 98] 	 Trai

# assess performance

### run inference

In [13]:
from tsf_infer_utils import infer

In [14]:
%%time
# Run the model over `loader` and collect the predictions together with their
# corresponding targets.
# NOTE(review): `loader` is the same DataLoader the training loop iterated
# over, so the results below are measured on training data — confirm whether
# a held-out split was intended.
pred_list, tgt_list = infer(model, loader)

CPU times: user 2min 52s, sys: 6.38 s, total: 2min 58s
Wall time: 4.36 s


### compute BLEU

In [15]:
from torchtext.data.metrics import bleu_score
from tsf_utils import format_list_for_bleu

In [16]:
%%time
# Convert predictions and targets into the layout used for BLEU scoring.
# NOTE(review): torchtext's bleu_score presumably expects (candidates,
# references) in that order — confirm against format_list_for_bleu's output.
pred_list_bleu, tgt_list_bleu = format_list_for_bleu(pred_list, tgt_list)
# Last expression of the cell, so the score is shown as the cell output.
bleu_score(pred_list_bleu, tgt_list_bleu)

CPU times: user 18.2 s, sys: 80 ms, total: 18.2 s
Wall time: 18.2 s


0.4738668307737861

### create vocab dicts 

In [17]:
# Map each position (0, 1, 2, ...) to the entry found at that position when
# iterating over the count vectorizer's vocabulary.
tags_index2word = dict(enumerate(dataset.countvec.vocabulary))

In [18]:
from transformer_translation.dataset import load_report_voc
# Load the report vocabulary from its pickle under data_path.
# (presumably an index -> word mapping, as the variable name suggests — confirm)
report_voc_path = os.path.join(data_path, 'reports', 'voc.pkl')
reports_index2word = load_report_voc(report_voc_path)

### assess natural-language performance

In [19]:
from tsf_utils import print_nl_pred_vs_tgt

In [20]:
import random

# Number of prediction/target pairs to inspect.
k = 10
# Draw distinct indices: random.sample picks without replacement, whereas
# random.choices could return the same example several times. The sample size
# is capped at the number of predictions so a short (or empty) list does not raise.
sel_idx = random.sample(range(len(pred_list)), k=min(k, len(pred_list)))

In [21]:
# Gather the sampled predictions and their targets, then print them using the
# report vocabulary mapping.
sampled_preds = [pred_list[idx] for idx in sel_idx]
sampled_tgts = [tgt_list[idx] for idx in sel_idx]
print_nl_pred_vs_tgt(sampled_preds, sampled_tgts, reports_index2word)

TARGET:   normal chest no evidence of tuberculosis heart size normal . lungs are clear . xxxx are normal . no pneumonia effusions edema pneumothorax adenopathy nodules or masses . .
PREDICTION:  no chest heart evidence of tuberculosis heart size normal . lungs are clear . xxxx are normal . no pneumonia effusions edema pneumothorax adenopathy nodules or masses . .


TARGET:   no acute cardiopulmonary abnormality . mediastinal contours are normal . heart size is within normal limits . multiple scattered calcified pulmonary nodules xxxx sequela of prior granulomatous disease . otherwise lungs are clear . . there is no pneumothorax or large pleural effusion . no bony abnormality . .
PREDICTION:   acute cardiopulmonary abnormality . . contours are within . no size is within normal limits . no calcified calcified granulomas nodules are . of prior granulomatous disease . no lungs are clear without no no is no pneumothorax or large pleural effusion . there acute abnormality . .


TARGET:   no 