In [4]:
%matplotlib inline

In [5]:
import os
import torch
from einops import rearrange

In [6]:
# Pick the compute device. GPU index 2 is the preferred card, but that index is
# only valid on hosts with at least three GPUs, so fall back to the first GPU
# and finally to the CPU instead of failing later when tensors are moved.
preferred_gpu = 2
if not torch.cuda.is_available():
    device = torch.device("cpu")
elif torch.cuda.device_count() > preferred_gpu:
    device = torch.device(f"cuda:{preferred_gpu}")
else:
    device = torch.device("cuda:0")
print(device)

cuda:2


In [7]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# load data

In [10]:
from transformer_translation.dataset import TagReportDataset
from torch.utils.data import DataLoader

In [11]:
# Root directory of the preprocessed data; other cells join 'tags/...' and
# 'reports/...' sub-paths onto it.
# NOTE(review): absolute, machine-specific path - adjust before running elsewhere.
data_path = r"/home/alex/data/nlp/agmir/transf_processed_data"
# Alternative repository-relative location (currently disabled):
#data_path = 'transformer_translation/data/processed'

In [12]:
# Dataset settings; both values are handed positionally to TagReportDataset.
num_tokens = 2000
max_seq_length = 96

# Pickled inputs: raw tag sets and the matching reports.
tags_file = os.path.join(data_path, 'tags/set_raw.pkl')
reports_file = os.path.join(data_path, 'reports/set.pkl')
dataset = TagReportDataset(tags_file, reports_file, num_tokens, max_seq_length)

# NOTE(review): batch_size=1 and the training loop indexes [0] on most items,
# so each dataset item presumably already holds a full batch - TODO confirm.
loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

# train

In [13]:
from transformer_translation.model import LanguageTransformer, ReportTransformer
from transformer_translation.translate_sentence import gen_nopeek_mask

In [14]:
# --- Transformer hyper-parameters ---
# NOTE(review): the "+ 4" presumably reserves ids for special tokens, and an
# earlier value of 1952 was noted next to this line - confirm both.
vocab_size = 10000 + 4
nhead = 8
# Next multiple of nhead strictly above 587, so the embedding width divides
# evenly across the attention heads (evaluates to 592 for nhead = 8).
d_model = (587 // nhead + 1) * nhead
num_encoder_layers = 6
num_decoder_layers = 6
dim_feedforward = 2048
pos_dropout = 0.1
trans_dropout = 0.1

# Build the model (max_seq_length comes from the dataset cell) and move it to
# the selected device.
transformer = ReportTransformer(
    vocab_size,
    d_model,
    nhead,
    num_encoder_layers,
    num_decoder_layers,
    dim_feedforward,
    max_seq_length,
    pos_dropout,
    trans_dropout,
)
model = transformer.to(device)

In [15]:
    from transformer_translation.Optim import ScheduledOptim
    import torch.nn as nn
    from torch.optim import Adam

    # Xavier-normal initialisation for every parameter with more than one
    # dimension; parameters with a single dimension keep their default init.
    for p in model.parameters():
        if p.dim() <= 1:
            continue
        nn.init.xavier_normal_(p)

    # Adam wrapped in the project's ScheduledOptim, which is driven through
    # zero_grad() / step_and_update_lr() in the training loop.
    n_warmup_steps = 4000
    adam = Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-09)
    optim = ScheduledOptim(adam, d_model, n_warmup_steps)

    # Targets equal to 0 are excluded from the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=0)

In [21]:
    %%time
    print_every = 15
    num_epochs = 2
    model.train()

    lowest_val = 1e9
    val_losses = []
    total_step = 0
    
    for epoch in range(num_epochs):
        
        total_loss = 0

        for step, (src, src_key_padding_mask, tgt, tgt_key_padding_mask) in enumerate(iter(loader)):
            total_step += 1

            src, src_key_padding_mask = src.to(device), src_key_padding_mask[0].to(device)
            tgt, tgt_key_padding_mask = tgt[0].to(device), tgt_key_padding_mask[0].to(device)

            memory_key_padding_mask = src_key_padding_mask.clone()
            tgt_inp, tgt_out = tgt[:, :-1], tgt[:, 1:]
            tgt_mask = gen_nopeek_mask(tgt_inp.shape[1]).to(device)
            
            optim.zero_grad()
            outputs = model(src, tgt_inp, src_key_padding_mask, tgt_key_padding_mask[:, :-1], memory_key_padding_mask, tgt_mask)
            loss = criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_out, 'b o -> (b o)'))

            loss.backward()
            optim.step_and_update_lr()

            total_loss += loss.item()
            if step % print_every == print_every - 1:
                print(f'Epoch [{epoch + 1} / {num_epochs}] \t Step [{step + 1} / {len(loader)}] \t '
                      f'Train Loss: {total_loss / print_every}')
                total_loss = 0
                try:
                    print(get_sent_from_tk(
                        rearrange(outputs, 'b t v -> (b t) v').max(dim=1)[1].unique()
                        ,reports_index2word))
                except:
                    print('OOV')
                    pass

Epoch [1 / 2] 	 Step [15 / 98] 	 Train Loss: 3.3642006238301594
['.', 'right', 'no', 'acute', 'cardiopulmonary', 'abnormality', 'heart', 'size', 'are', 'normal', 'limits', 'pulmonary', 'is', 'focal', 'airspace', 'consolidation', 'pleural', 'effusion', 'or', 'pneumothorax', 'changes', 'the', 'spine', 'of', 'disease', 'xxxx', 'silhouette', 'clear', 'lungs']
Epoch [1 / 2] 	 Step [30 / 98] 	 Train Loss: 3.4836791515350343
['', '.', 'right', 'no', 'acute', 'cardiopulmonary', 'abnormality', 'heart', 'size', 'and', 'are', 'within', 'normal', 'limits', 'pulmonary', 'is', 'focal', 'airspace', 'consolidation', 'there', 'pleural', 'effusion', 'or', 'pneumothorax', 'changes', 'the', 'spine', 'of', 'lung', 'disease', 'lobe', 'xxxx', 'silhouette', 'clear', 'lungs', 'vascularity', 'structures']
Epoch [1 / 2] 	 Step [45 / 98] 	 Train Loss: 3.2779112180074055
['.', 'no', 'acute', 'cardiopulmonary', 'abnormality', 'heart', 'size', 'and', 'are', 'within', 'normal', 'limits', 'pulmonary', 'is', 'focal', '

# get natural language (decode token ids back to words)

In [16]:
import pickle


def _load_vocab(vocab_file):
    """Unpickle a word->index mapping and return it with its index->word inverse.

    Ids 0 and 1 of the inverse mapping are overwritten with the labels
    "SOS" and "EOS".
    """
    # NOTE: pickle.load executes arbitrary code; only open trusted files.
    with open(vocab_file, 'rb') as handle:
        word2index = pickle.load(handle)
    index2word = {index: word for word, index in word2index.items()}
    index2word[0], index2word[1] = "SOS", "EOS"
    return word2index, index2word


# Vocabulary for the tags (source side) ...
tags_word2index, tags_index2word = _load_vocab(
    os.path.join(data_path, 'tags', 'voc.pkl'))

# ... and for the reports (target side).
reports_word2index, reports_index2word = _load_vocab(
    os.path.join(data_path, 'reports', 'voc.pkl'))

In [17]:
def get_sent_from_tk(tensor_tk, index2word):
    """Map each token id in ``tensor_tk`` to its word.

    Args:
        tensor_tk: iterable of elements exposing ``.item()`` (e.g. a 1-D tensor
            of token ids).
        index2word: mapping from token id to word.

    Returns:
        List of words, in the same order as the ids.

    Raises:
        KeyError: if an id is missing from ``index2word``.
    """
    words = []
    for token in tensor_tk:
        token_id = token.item()
        words.append(index2word[token_id])
    return words

In [17]:
# Decode the last target batch left over from the training loop: flatten it to
# one dimension, then look every id up in the report vocabulary.
flat_tgt = rearrange(tgt, 'b o -> (b o)')
get_sent_from_tk(flat_tgt, reports_index2word)

['',
 'negative',
 'chest',
 'x',
 'xxxx',
 '.',
 'cardiac',
 'and',
 'mediastinal',
 'contours',
 'are',
 'within',
 'normal',
 'limits',
 '.',
 'prior',
 'granulomatous',
 'disease',
 '.',
 'the',
 'lungs',
 'are',
 'otherwise',
 'clear',
 '.',
 'bony',
 'structures',
 'are',
 'intact',
 '.',
 '.',
 '',
 'no',
 'acute',
 'cardiopulmonary',
 'abnormality',
 '.',
 'heart',
 'size',
 'and',
 'mediastinal',
 'contour',
 'within',
 'normal',
 'limits',
 '.',
 'no',
 'focal',
 'airspace',
 'consolidation',
 'pneumothorax',
 'or',
 'large',
 'pleural',
 'effusion',
 '.',
 'no',
 'acute',
 'osseous',
 'abnormality',
 '.',
 '.',
 '',
 'no',
 'acute',
 'cardiopulmonary',
 'finding',
 '.',
 'the',
 'heart',
 'and',
 'cardiomediastinal',
 'silhouette',
 'are',
 'normal',
 '.',
 'there',
 'is',
 'no',
 'focal',
 'airspace',
 'opacity',
 'pleural',
 'effusion',
 'pneumothorax',
 '.',
 'the',
 'osseous',
 'structures',
 'are',
 'intact',
 '.',
 '.',
 '',
 'no',
 'acute',
 'cardiopulmonary',
 'abnor

In [71]:
# Decode the last source batch left over from the training loop: flatten it to
# one dimension, then look every id up in the tag vocabulary.
flat_src = rearrange(src, 'b o -> (b o)')
get_sent_from_tk(flat_src, tags_index2word)

['infiltrates',
 'cardiomegaly',
 'granuloma',
 'granuloma',
 'aorta',
 'left_ventricle',
 'atelectases',
 'atelectasis',
 'granuloma',
 'granuloma',
 'granuloma',
 'granuloma',
 'rib_fracture',
 'rib_fractures',
 'plate_like_atelectasis',
 'pulmonary_atelectasis',
 'atelectases',
 'atelectasis',
 'lymph',
 'lymph_nodes',
 'scarring',
 'scarring',
 'scolioses',
 'scoliosis',
 'tortuous_aorta',
 'aorta',
 'pneumonia',
 'pneumonia',
 'sternotomy',
 'sternotomy',
 'hiatal_hernia',
 'hiatal_hernia',
 'spinal_osteophytosis',
 'thoracic_spondylosis',
 'clip',
 'gallbladder',
 'right_lower_lobe_pneumonia',
 'pneumonia',
 'granuloma',
 'granuloma',
 'aorta',
 'aortic_diseases',
 'apical_granuloma',
 'granuloma',
 'spondylosis',
 'spondylosis',
 'spondylosis',
 'spondylosis',
 'atelectases',
 'atelectasis',
 'tortuous_aorta',
 'aorta',
 'heart_valve_prosthesis',
 'heart_valve_prosthesis_implantation',
 'granulomatous_disease',
 'thoracic_spondylosis',
 'aorta_tortuous',
 'aorta',
 'atherosclero

592

In [70]:
# Distinct tokens predicted for the last batch: flatten (batch, time), take the
# index of the highest score per position, de-duplicate, then decode.
flat_scores = rearrange(outputs, 'b t v -> (b t) v')
predicted_ids = flat_scores.max(dim=1)[1].unique()
get_sent_from_tk(predicted_ids, reports_index2word)

['.']

In [75]:
# Scratch check: 592 is what the d_model expression evaluates to with nhead = 8,
# so this is the width available to each attention head.
592/8

74.0