In [1]:
# Render matplotlib figures inline in the notebook output.
# NOTE(review): no plotting call is visible in this notebook -- confirm this magic is still needed.
%matplotlib inline

In [2]:
import os
import torch

In [3]:
# GPU ordinal used on the original multi-GPU machine.
preferred_gpu = 7

# Pick the compute device:
#   * no CUDA at all            -> CPU
#   * preferred ordinal exists  -> that GPU (original behaviour)
#   * fewer GPUs than expected  -> first visible GPU, instead of failing later
#     with an invalid-device error when tensors are moved to "cuda:7"
if not torch.cuda.is_available():
    device = torch.device("cpu")
elif preferred_gpu < torch.cuda.device_count():
    device = torch.device(f"cuda:{preferred_gpu}")
else:
    device = torch.device("cuda:0")
print(device)

cuda:7


In [4]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# modules (e.g. the local transformer_translation package) are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

# load data

In [5]:
from transformer_translation.dataset import ParallelLanguageDataset
from torch.utils.data import DataLoader

In [6]:
# Root directory of the preprocessed data. Later cells read 'tags/set.pkl',
# 'reports/set.pkl' and the per-side 'voc.pkl' vocabularies beneath it.
# NOTE(review): absolute, machine-specific path -- adjust when running elsewhere.
data_path = r"/home/alex/data/nlp/agmir/transf_processed_data"
# Alternative location inside the repository (disabled):
#data_path = 'transformer_translation/data/processed'

In [7]:
# NOTE(review): the meaning of num_tokens is defined by ParallelLanguageDataset;
# presumably a per-batch token budget, since the loader below uses batch_size=1
# and the training loop indexes element 0 of every item -- confirm.
num_tokens = 2000
# Also passed to the model constructor further down.
max_seq_length = 96

# Source side = tag sequences, target side = report sequences.
# (The disabled alternatives were 'en/train.pkl' and 'fr/train.pkl'.)
tags_set_path = os.path.join(data_path, 'tags/set.pkl')
reports_set_path = os.path.join(data_path, 'reports/set.pkl')
dataset = ParallelLanguageDataset(
    tags_set_path,
    reports_set_path,
    num_tokens,
    max_seq_length,
)

loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

In [8]:
# Flatten every sequence of each side into one flat list of token ids,
# so the overall id range can be inspected in the next cell.
tk1 = []
for seq in dataset.data_1:
    tk1.extend(seq)

tk2 = []
for seq in dataset.data_2:
    tk2.extend(seq)

In [9]:
# Largest token id on each side (first and second dataset file respectively).
max(tk1), max(tk2)

(589, 1952)

# train

In [25]:
from transformer_translation.model import LanguageTransformer
from transformer_translation.translate_sentence import gen_nopeek_mask

In [11]:
# Transformer hyper-parameters.
# NOTE(review): the largest token ids printed earlier are 589 and 1952, so this
# vocabulary size leaves a lot of headroom -- confirm it is intentional.
vocab_size = 10000 + 4
d_model, nhead = 512, 8
num_encoder_layers = num_decoder_layers = 6
dim_feedforward = 2048
pos_dropout = trans_dropout = 0.1

# Build the model with positional arguments in the order the constructor is
# called with in the original cell, then move it to the selected device.
model = LanguageTransformer(
    vocab_size,
    d_model,
    nhead,
    num_encoder_layers,
    num_decoder_layers,
    dim_feedforward,
    max_seq_length,
    pos_dropout,
    trans_dropout,
)
model = model.to(device)

In [12]:
    from transformer_translation.Optim import ScheduledOptim
    import torch.nn as nn
    from torch.optim import Adam
    
    # Xavier-normal initialisation for every parameter with more than one
    # dimension; parameters with at most one dimension keep their existing values.
    for p in model.parameters():
        is_matrix = p.dim() > 1
        if not is_matrix:
            continue
        nn.init.xavier_normal_(p)

    # Adam wrapped in the project's ScheduledOptim, which receives the model
    # width and the number of warm-up steps.
    n_warmup_steps = 4000
    base_optimizer = Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-09)
    optim = ScheduledOptim(base_optimizer, d_model, n_warmup_steps)

    # Targets equal to 0 are excluded from the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=0)

In [21]:
    %%time
    # Training loop: teacher-forced next-token prediction of the target
    # sequence, reporting the mean loss every `print_every` steps.
    print_every = 15
    num_epochs = 10
    model.train()

    # Not updated anywhere in this cell; kept so the notebook globals are unchanged.
    lowest_val = 1e9
    val_losses = []
    total_step = 0

    for epoch in range(num_epochs):

        # Running loss of the current reporting window (reset after each print).
        total_loss = 0

        for step, (src, src_key_padding_mask, tgt, tgt_key_padding_mask) in enumerate(loader):
            total_step += 1

            # The loader runs with batch_size=1, so element 0 of each item is taken
            # before moving it to the device.
            src, src_key_padding_mask = src[0].to(device), src_key_padding_mask[0].to(device)
            tgt, tgt_key_padding_mask = tgt[0].to(device), tgt_key_padding_mask[0].to(device)

            memory_key_padding_mask = src_key_padding_mask.clone()
            # Decoder input drops the last position, the expected output drops the
            # first, so position t is trained to predict token t + 1.
            tgt_inp, tgt_out = tgt[:, :-1], tgt[:, 1:]
            # NOTE(review): presumably a causal ("no peek") attention mask -- see gen_nopeek_mask.
            tgt_mask = gen_nopeek_mask(tgt_inp.shape[1]).to(device)

            optim.zero_grad()
            outputs = model(src, tgt_inp, src_key_padding_mask, tgt_key_padding_mask[:, :-1], memory_key_padding_mask, tgt_mask)
            # Flatten (batch, time, vocab) logits to (batch * time, vocab) and
            # (batch, time) targets to (batch * time,). Plain torch reshape is used
            # because `rearrange` is never imported in this notebook.
            loss = criterion(outputs.reshape(-1, outputs.size(-1)), tgt_out.reshape(-1))

            loss.backward()
            optim.step_and_update_lr()

            total_loss += loss.item()
            if step % print_every == print_every - 1:
                print(f'Epoch [{epoch + 1} / {num_epochs}] \t Step [{step + 1} / {len(loader)}] \t '
                      f'Train Loss: {total_loss / print_every}')
                total_loss = 0

Epoch [1 / 10] 	 Step [15 / 92] 	 Train Loss: 1.3208194176355998
Epoch [1 / 10] 	 Step [30 / 92] 	 Train Loss: 1.6033766587575278
Epoch [1 / 10] 	 Step [45 / 92] 	 Train Loss: 1.5051896969477336
Epoch [1 / 10] 	 Step [60 / 92] 	 Train Loss: 1.8119399348894756
Epoch [1 / 10] 	 Step [75 / 92] 	 Train Loss: 1.4665009458859761
Epoch [1 / 10] 	 Step [90 / 92] 	 Train Loss: 1.6129396160443623
Epoch [2 / 10] 	 Step [15 / 92] 	 Train Loss: 1.4453423062960307
Epoch [2 / 10] 	 Step [30 / 92] 	 Train Loss: 1.3141321460405986
Epoch [2 / 10] 	 Step [45 / 92] 	 Train Loss: 1.4676307280858358
Epoch [2 / 10] 	 Step [60 / 92] 	 Train Loss: 1.6809895277023315
Epoch [2 / 10] 	 Step [75 / 92] 	 Train Loss: 1.5172077020009358
Epoch [2 / 10] 	 Step [90 / 92] 	 Train Loss: 1.5481433153152466
Epoch [3 / 10] 	 Step [15 / 92] 	 Train Loss: 1.3572155793507894
Epoch [3 / 10] 	 Step [30 / 92] 	 Train Loss: 1.3383694112300872
Epoch [3 / 10] 	 Step [45 / 92] 	 Train Loss: 1.4461939374605814
Epoch [3 / 10] 	 Step [60

# decode token ids back to natural language

In [16]:
import pickle


def _load_vocab(vocab_dir):
    """Load ``voc.pkl`` from ``data_path/<vocab_dir>`` and build its inverse map.

    Args:
        vocab_dir: sub-directory of ``data_path`` that contains ``voc.pkl``.

    Returns:
        ``(word2index, index2word)``: the unpickled word -> index mapping and
        its inverse. Entries 0 and 1 of the inverse map are overwritten with
        the labels "SOS" and "EOS".
    """
    # NOTE(security): pickle.load can execute arbitrary code -- only open
    # vocabulary files from a trusted source.
    with open(os.path.join(data_path, vocab_dir, 'voc.pkl'), 'rb') as f:
        word2index = pickle.load(f)
    index2word = {index: word for word, index in word2index.items()}
    # NOTE(review): the training criterion uses ignore_index=0, which suggests
    # index 0 may be padding rather than a start token -- confirm these labels.
    index2word[0], index2word[1] = "SOS", "EOS"
    return word2index, index2word


tags_word2index, tags_index2word = _load_vocab('tags')
reports_word2index, reports_index2word = _load_vocab('reports')

In [17]:
def get_sent_from_tk(tensor_tk, index2word):
    """Map a 1-D sequence of token ids to their words.

    Args:
        tensor_tk: 1-D tensor of token ids, or any iterable of ints / 0-d tensors.
        index2word: mapping from token id to word.

    Returns:
        List of words, one per token id, in input order.

    Raises:
        KeyError: if a token id is missing from ``index2word``.
    """
    if hasattr(tensor_tk, "tolist"):
        # Convert the whole tensor at once instead of calling .item() per element.
        token_ids = tensor_tk.tolist()
    else:
        token_ids = [idx.item() if hasattr(idx, "item") else idx for idx in tensor_tk]
    return [index2word[idx] for idx in token_ids]

In [18]:
# Decode the target batch left over from the last training step: the 2-D
# (batch, position) id tensor is flattened row by row with torch's reshape
# (`rearrange` is never imported in this notebook).
get_sent_from_tk(
    tgt.reshape(-1),
    reports_index2word,
)

['',
 'no',
 'acute',
 'cardiopulmonary',
 'abnormality',
 '.',
 'normal',
 'heart',
 'size',
 'and',
 'mediastinal',
 'contours',
 '.',
 'no',
 'focal',
 'airspace',
 'consolidation',
 '.',
 'no',
 'pleural',
 'effusion',
 'or',
 'pneumothorax',
 '.',
 'chronic',
 'appearing',
 'right',
 'mid',
 'clavicle',
 'injury',
 '.',
 'visualized',
 'bony',
 'structures',
 'otherwise',
 'unremarkable',
 '.',
 '.',
 'SOS',
 '',
 'no',
 'acute',
 'cardiopulmonary',
 'disease',
 '.',
 '.',
 'the',
 'cardiomediastinal',
 'silhouette',
 'is',
 'normal',
 'size',
 'and',
 'configuration',
 '.',
 'pulmonary',
 'vasculature',
 'within',
 'normal',
 'limits',
 '.',
 'the',
 'lungs',
 'are',
 'well',
 'aerated',
 '.',
 'there',
 'is',
 'no',
 'pneumothorax',
 'pleural',
 'effusion',
 'or',
 'focal',
 'consolidation',
 '.',
 '.',
 '',
 'no',
 'acute',
 'cardiopulmonary',
 'disease',
 '.',
 '.',
 'the',
 'cardiomediastinal',
 'silhouette',
 'is',
 'normal',
 'size',
 'and',
 'configuration',
 '.',
 'pulmon

In [71]:
# Decode the source (tags) batch left over from the last training step: the 2-D
# (batch, position) id tensor is flattened row by row with torch's reshape
# (`rearrange` is never imported in this notebook).
get_sent_from_tk(
    src.reshape(-1),
    tags_index2word,
)

['infiltrates',
 'cardiomegaly',
 'granuloma',
 'granuloma',
 'aorta',
 'left_ventricle',
 'atelectases',
 'atelectasis',
 'granuloma',
 'granuloma',
 'granuloma',
 'granuloma',
 'rib_fracture',
 'rib_fractures',
 'plate_like_atelectasis',
 'pulmonary_atelectasis',
 'atelectases',
 'atelectasis',
 'lymph',
 'lymph_nodes',
 'scarring',
 'scarring',
 'scolioses',
 'scoliosis',
 'tortuous_aorta',
 'aorta',
 'pneumonia',
 'pneumonia',
 'sternotomy',
 'sternotomy',
 'hiatal_hernia',
 'hiatal_hernia',
 'spinal_osteophytosis',
 'thoracic_spondylosis',
 'clip',
 'gallbladder',
 'right_lower_lobe_pneumonia',
 'pneumonia',
 'granuloma',
 'granuloma',
 'aorta',
 'aortic_diseases',
 'apical_granuloma',
 'granuloma',
 'spondylosis',
 'spondylosis',
 'spondylosis',
 'spondylosis',
 'atelectases',
 'atelectasis',
 'tortuous_aorta',
 'aorta',
 'heart_valve_prosthesis',
 'heart_valve_prosthesis_implantation',
 'granulomatous_disease',
 'thoracic_spondylosis',
 'aorta_tortuous',
 'aorta',
 'atherosclero

In [70]:
# Distinct words the model predicts for the last training batch: flatten the
# (batch, time, vocab) logits to (batch * time, vocab), take the arg-max id per
# position and keep the unique ids. Uses torch's reshape because `rearrange`
# is never imported in this notebook.
get_sent_from_tk(
    outputs.reshape(-1, outputs.size(-1)).max(dim=1)[1].unique(),
    reports_index2word,
)

['.']