In [1]:
# Render any matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
import os
import torch
from einops import rearrange

In [3]:
# Reproducibility + compute device.
# Seeding torch's global RNG makes the weight initialisation and the DataLoader
# shuffle order (which draws its base seed from that RNG) repeatable across runs.
SEED = 0
torch.manual_seed(SEED)

# GPU to use; override with the GPU_INDEX environment variable instead of editing
# this cell. Fall back to the CPU when CUDA is unavailable or the requested GPU does
# not exist — otherwise the bad ordinal only fails later, at model.to(device).
GPU_INDEX = int(os.environ.get("GPU_INDEX", "2"))
use_cuda = torch.cuda.is_available() and 0 <= GPU_INDEX < torch.cuda.device_count()
device = torch.device(f"cuda:{GPU_INDEX}" if use_cuda else "cpu")
if torch.cuda.is_available() and not use_cuda:
    print(f"GPU_INDEX={GPU_INDEX} not available "
          f"({torch.cuda.device_count()} GPU(s) visible); using CPU")
print(device)

cuda:2


In [4]:
# Reload edited modules (e.g. the transformer_translation package imported below)
# automatically before each cell runs, so source changes are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

# Load the tag/report dataset

In [5]:
from transformer_translation.dataset import TagReportDataset
from torch.utils.data import DataLoader

In [6]:
# Root directory of the preprocessed tag/report data. Override it with the
# AGMIR_DATA_PATH environment variable rather than editing this cell — e.g.
# AGMIR_DATA_PATH=transformer_translation/data/processed for the repo-local copy.
data_path = os.environ.get(
    "AGMIR_DATA_PATH", r"/home/alex/data/nlp/agmir/transf_processed_data")

In [22]:
# Dataset settings. num_tokens is TagReportDataset's third argument (presumably a
# token/vocabulary cap — confirm in transformer_translation.dataset); max_seq_length
# is also handed to the model constructor further down.
num_tokens = 2000
max_seq_length = 96

# Paired files: tag sets and the corresponding report texts.
tags_file = os.path.join(data_path, 'tags/set_raw.pkl')
reports_file = os.path.join(data_path, 'reports/set.pkl')
dataset = TagReportDataset(tags_file, reports_file, num_tokens, max_seq_length)

# One example per batch, reshuffled every epoch, loaded by 4 worker processes.
loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

# Train

In [23]:
from transformer_translation.model import LanguageTransformer, ReportTransformer

In [24]:
# Transformer hyper-parameters.
# NOTE(review): the "+ 4" presumably reserves ids for special tokens on top of a
# 10000-entry vocabulary — confirm against the dataset. (1952 was tried previously.)
vocab_size = 10000 + 4
nhead = 8

# PyTorch multi-head attention requires the model width to be divisible by the number
# of heads, so round the desired width up to the nearest multiple of nhead
# (587 -> 592 for nhead = 8). Ceil-division leaves an already-divisible width unchanged,
# unlike the previous "w - w % nhead + nhead", which overshot by nhead in that case.
d_model_target = 587
d_model = -(-d_model_target // nhead) * nhead
assert d_model % nhead == 0, (d_model, nhead)

num_encoder_layers = 6
num_decoder_layers = 6
dim_feedforward = 2048
pos_dropout = 0.1
trans_dropout = 0.1

# Build the encoder-decoder model from the hyper-parameters above, then move its
# parameters to the selected device.
model = ReportTransformer(
    vocab_size,
    d_model,
    nhead,
    num_encoder_layers,
    num_decoder_layers,
    dim_feedforward,
    max_seq_length,
    pos_dropout,
    trans_dropout,
)
model = model.to(device)

In [25]:
    from transformer_translation.Optim import ScheduledOptim
    import torch.nn as nn
    from torch.optim import Adam
    
    # Xavier-normal re-initialisation of every parameter with more than one
    # dimension (weight matrices, embedding tables); 1-D parameters such as biases
    # keep the initialisation they were created with.
    weight_tensors = [param for param in model.parameters() if param.dim() > 1]
    for weight in weight_tensors:
        nn.init.xavier_normal_(weight)

    # Adam with betas (0.9, 0.98) and eps 1e-9, wrapped by the project's ScheduledOptim,
    # which receives d_model and a warm-up length (presumably a warm-up learning-rate
    # schedule — confirm in transformer_translation/Optim.py).
    n_warmup_steps = 4000
    base_optimizer = Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-09)
    optim = ScheduledOptim(base_optimizer, d_model, n_warmup_steps)

    # Token id 0 is excluded from the loss (presumably the padding id — confirm).
    criterion = nn.CrossEntropyLoss(ignore_index=0)

In [50]:
from tsf_train_utils import prep_transf_inputs

In [26]:
    %%time
    print_every = 15
    num_epochs = 20
    model.train()

    lowest_val = 1e9
    val_losses = []
    total_step = 0
    
    for epoch in range(num_epochs):
        
        total_loss = 0

        for step, (src, src_key_padding_mask, tgt, tgt_key_padding_mask) in enumerate(iter(loader)):
            total_step += 1

            src, src_key_padding_mask, tgt, tgt_key_padding_mask, memory_key_padding_mask, tgt_inp, tgt_out, tgt_mask = prep_transf_inputs(
                src, src_key_padding_mask, tgt, tgt_key_padding_mask, device)
            
            optim.zero_grad()
            outputs = model(src, tgt_inp, src_key_padding_mask, tgt_key_padding_mask[:, :-1], memory_key_padding_mask, tgt_mask)
            loss = criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_out, 'b o -> (b o)'))

            loss.backward()
            optim.step_and_update_lr()

            total_loss += loss.item()
            if step % print_every == print_every - 1:
                print(f'Epoch [{epoch + 1} / {num_epochs}] \t Step [{step + 1} / {len(loader)}] \t '
                      f'Train Loss: {total_loss / print_every}')
                total_loss = 0

Epoch [1 / 20] 	 Step [15 / 98] 	 Train Loss: 9.219210116068522
Epoch [1 / 20] 	 Step [30 / 98] 	 Train Loss: 8.983137957255046
Epoch [1 / 20] 	 Step [45 / 98] 	 Train Loss: 8.581922403971355
Epoch [1 / 20] 	 Step [60 / 98] 	 Train Loss: 8.177903493245443
Epoch [1 / 20] 	 Step [75 / 98] 	 Train Loss: 7.957618872324626
Epoch [1 / 20] 	 Step [90 / 98] 	 Train Loss: 7.712112903594971
Epoch [2 / 20] 	 Step [15 / 98] 	 Train Loss: 7.318874835968018
Epoch [2 / 20] 	 Step [30 / 98] 	 Train Loss: 7.026069831848145
Epoch [2 / 20] 	 Step [45 / 98] 	 Train Loss: 6.909017086029053
Epoch [2 / 20] 	 Step [60 / 98] 	 Train Loss: 6.666585063934326
Epoch [2 / 20] 	 Step [75 / 98] 	 Train Loss: 6.358024470011393
Epoch [2 / 20] 	 Step [90 / 98] 	 Train Loss: 6.102291520436605
Epoch [3 / 20] 	 Step [15 / 98] 	 Train Loss: 5.751453653971354
Epoch [3 / 20] 	 Step [30 / 98] 	 Train Loss: 5.386170546213786
Epoch [3 / 20] 	 Step [45 / 98] 	 Train Loss: 5.143763033548991
Epoch [3 / 20] 	 Step [60 / 98] 	 Train 

# Assess natural-language output

Print each target report next to the model's predicted tokens for one random batch.

In [36]:
# Position -> term lookup for the tag vocabulary.
# NOTE(review): if countvec.vocabulary is a term -> index mapping (like scikit-learn's
# CountVectorizer vocabulary), this pairs positions with keys in insertion order rather
# than with the stored indices — verify the ids line up.
tags_index2word = dict(enumerate(dataset.countvec.vocabulary))

In [43]:
from transformer_translation.dataset import load_report_voc
from utils import get_sent_from_tk
# Report vocabulary used to turn predicted/target token ids back into words.
# NOTE(review): the vocabulary is stored as a .pkl file; if load_report_voc unpickles
# it, only point data_path at trusted data.
report_voc_file = os.path.join(data_path, 'reports', 'voc.pkl')
reports_index2word = load_report_voc(report_voc_file)

In [51]:
            # Qualitative check on one random batch: the model is fed tgt_inp (teacher
            # forcing) and each target report is printed next to the highest-scoring token
            # at every output position.
            (src, src_key_padding_mask, tgt, tgt_key_padding_mask) = next(iter(loader))

            src, src_key_padding_mask, tgt, tgt_key_padding_mask, memory_key_padding_mask, tgt_inp, tgt_out, tgt_mask = prep_transf_inputs(
                src, src_key_padding_mask, tgt, tgt_key_padding_mask, device)

            # Run in eval mode (disables the pos/trans dropout) without building an autograd
            # graph, then restore the previous train/eval mode so a later training cell is
            # unaffected even if the forward pass raises.
            was_training = model.training
            model.eval()
            try:
                with torch.no_grad():
                    outputs = model(src, tgt_inp, src_key_padding_mask, tgt_key_padding_mask[:, :-1], memory_key_padding_mask, tgt_mask)
            finally:
                model.train(was_training)

            # outputs is indexed (batch, position, vocab), as in the training cell; keep the
            # highest-scoring token id at each position.
            outputs = outputs.argmax(dim=2)

            # NOTE(review): predictions are produced for tgt_out while the full tgt is printed
            # as TARGET, so the two lines may be offset by one token — confirm against
            # prep_transf_inputs.
            for row in range(outputs.shape[0]):

                # print target
                nl_tgt = get_sent_from_tk(
                    tgt[row, :]
                    ,reports_index2word)
                print('TARGET: ',' '.join(nl_tgt))

                # print predictions; decoding stays best-effort so one bad row does not stop
                # the loop, but the actual error is reported instead of being hidden.
                try:
                    nl_outputs = get_sent_from_tk(
                                outputs[row, :]
                                ,reports_index2word)
                    print('PREDICTION: ',' '.join(nl_outputs))
                except Exception as err:
                    print(f'OOV ({type(err).__name__}: {err})')
                print('\n')

TARGET:   emphysema and scarring without acute disease the heart is normal in size . the mediastinum is unremarkable . the lungs are hyperinflated with xxxx xxxx opacities compatible with pleural parenchymal scarring . there is no acute infiltrate or effusion . .
PREDICTION:   and scarring are acute cardiopulmonary . heart is normal in size . the mediastinum is unremarkable . the lungs are hyperinflated . chronic opacities opacities compatible with scarring scarring scarring . no is no focal infiltrate or pleural . there


TARGET:   moderate cardiomegaly with pulmonary vascular congestion early interstitial edema . there is moderate cardiomegaly . there are bilateral interstitial opacities increased since the previous exam . no focal airspace consolidation pleural effusions or pneumothorax . no acute bony abnormalities . .
PREDICTION:   to . pulmonary edema congestion and interstitial edema . chronic is mild to . mild is interstitial interstitial opacities . interstitial the previous e