In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Remember the directory the notebook was launched from on the FIRST run only, so
# re-executing this cell does not keep climbing one more directory each time
# (a bare os.chdir("..") is not idempotent).
NOTEBOOK_DIR = globals().get("NOTEBOOK_DIR", os.getcwd())
# Work from the parent directory — presumably the repo root, so the project-local
# imports below (config, models, dataset_word, evaluate) and relative paths such as
# 'save/encoder_word.pt' resolve; confirm against the repository layout.
os.chdir(os.path.dirname(NOTEBOOK_DIR))

In [3]:
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
import apex
import csv
import models
import ast
from sklearn.model_selection import train_test_split
from fastprogress.fastprogress import master_bar, progress_bar
from tqdm.auto import trange

from config import config
import dataset_word as data

[nltk_data] Downloading package punkt to /home/krishna/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


### Data Loading

In [4]:
def parse_list(input_str):
    """Convert a string holding a Python literal (e.g. "['a', 'b']") into that object.

    Uses ast.literal_eval, so only literal syntax is accepted; arbitrary
    expressions are rejected rather than executed.
    """
    value = ast.literal_eval(input_str)
    return value

def load_reports(path):
    """Read the cleaned-reports CSV into a dict keyed by report uid.

    The first row is treated as a header and skipped. Every following row is
    expected to have a leading index column and then exactly four fields:
    uid, problems, findings, impression. `problems` is a Python-literal list
    string and is parsed with `parse_list`.

    Returns:
        dict mapping str(uid) -> (problems_list, findings, impression).
    """
    loaded = {}
    # newline="" is required by the csv module so quoted fields that contain
    # line breaks (e.g. multi-line findings text) are parsed correctly.
    with open(path, newline="") as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=",")
        next(csv_reader, None)  # skip the header row (no-op on an empty file)
        for row in csv_reader:
            uid, problems, findings, impression = row[1:]
            loaded[str(uid)] = (parse_list(problems), findings, impression)
    return loaded


reports = load_reports(config.cleaned_reports)

In [5]:
def create_report_splits(reports, seed=1337):
    """Partition reports into train / validation / test dicts (roughly 80/10/10).

    The uid list is first split 80/20 into train and a held-out pool; the pool
    is then halved into validation and test. Both splits use `random_state=seed`,
    so the partition is reproducible.

    Returns:
        [train_reports, valid_reports, test_reports], each a dict mapping
        str(uid) -> the corresponding value from `reports`.
    """
    all_uids = list(reports.keys())
    train_uids, holdout_uids = train_test_split(all_uids, test_size=0.2, random_state=seed)
    valid_uids, test_uids = train_test_split(holdout_uids, test_size=0.5, random_state=seed)

    return [
        {str(uid): reports[str(uid)] for uid in uid_split}
        for uid_split in (train_uids, valid_uids, test_uids)
    ]

# The test split is held out and not used in this notebook. Avoid assigning to `_`,
# which IPython uses as its last-output cache.
train_reports, valid_reports, _test_reports = create_report_splits(reports)
IMAGE_SIZE = 768
# Channel statistics for transforms.Normalize. These are the usual ImageNet values,
# presumably matching a pretrained backbone inside EncoderCNN — confirm.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]
# NOTE(review): worker count was tied to config.batch_size. Prefer a dedicated
# config.num_workers when it exists; otherwise keep the previous value.
NUM_WORKERS = getattr(config, "num_workers", config.batch_size)


def _make_transform(train):
    """Build the image pipeline.

    Steps: resize (shorter side to IMAGE_SIZE), square center-crop, a random
    horizontal flip for training only, tensor conversion, then normalization.
    """
    steps = [
        transforms.Resize(IMAGE_SIZE),
        transforms.CenterCrop((IMAGE_SIZE, IMAGE_SIZE)),
    ]
    if train:
        steps.append(transforms.RandomHorizontalFlip())
    steps += [
        transforms.ToTensor(),
        transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ]
    return transforms.Compose(steps)


train_dataset = data.XRayDataset(reports=train_reports, transform=_make_transform(train=True))
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    collate_fn=data.collate_fn,
    pin_memory=True,
    shuffle=True,
    batch_size=config.batch_size,
    num_workers=NUM_WORKERS,
)

valid_dataset = data.XRayDataset(reports=valid_reports, transform=_make_transform(train=False))
# No shuffling for validation, so batch order and reported metrics are reproducible.
valid_dataloader = torch.utils.data.DataLoader(
    valid_dataset,
    collate_fn=data.collate_fn,
    pin_memory=True,
    shuffle=False,
    batch_size=config.batch_size,
    num_workers=NUM_WORKERS,
)

### Build Model

In [6]:
memory_format = torch.channels_last
num_classes = len(train_dataset.classes)

# Both networks are moved to config.device in channels_last layout. That layout
# only affects 4-D tensors, so it presumably has no effect on the RNN decoder.
encoder = models.EncoderCNN(config.emb_dim, num_classes).to(config.device, memory_format=memory_format)
decoder = models.DecoderRNN_Word(
    config.emb_dim, config.hidden_dim, train_dataset.tokenizer, config.num_layers
).to(config.device, memory_format=memory_format)

# Class head: BCE on raw logits (presumably multi-label).
# Caption decoder: cross-entropy over the vocabulary.
classes_loss = torch.nn.BCEWithLogitsLoss()
outputs_loss = torch.nn.CrossEntropyLoss()

# A single optimizer over all parameters, decoder first, then encoder.
params = [*decoder.parameters(), *encoder.parameters()]
optimizer = apex.optimizers.FusedAdam(params, lr=config.learning_rate)

# Mixed precision, opt level O1: automatic casts around PyTorch functions and Tensor methods.
[encoder, decoder], optimizer = apex.amp.initialize([encoder, decoder], optimizer, opt_level="O1")

Embeddings: 1946 x 300
Loading embedding file: ./vectors/glove.6B.300d.txt
Pre-trained: 1611 (82.79%)
Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


### Train Model

In [7]:
def train_one_epoch(dataloader, batch_size, encoder, decoder, classes_loss, outputs_loss, optimizer, train=True):
    """Run one full pass over `dataloader`, either training or evaluating.

    Args:
        dataloader: yields (images, class_labels, captions, lengths) batches.
        batch_size: batch size used to build the decoder's initial recurrent state.
        encoder: called on images; returns (class logits, image features).
        decoder: called as decoder(features, captions, lengths, (h, c)); returns
            (outputs, (h, c)).
        classes_loss: loss between class logits and class_labels.
        outputs_loss: loss between decoder outputs and the packed caption tokens.
        optimizer: amp-initialized optimizer; only stepped when train=True.
        train: if True, enable gradients, backprop and step; otherwise run in eval
            mode without gradients.

    Returns:
        (mean class loss, mean caption loss), averaged over the batches actually processed.
    """
    encoder.train(train)  # train(False) is equivalent to eval()
    decoder.train(train)
    running_c_loss = 0.0
    running_o_loss = 0.0
    n_batches = 0
    # NOTE(review): the recurrent state is carried from one batch to the next
    # (detached during training). With shuffled, independent reports this leaks
    # context across samples. A smaller final batch may also not match a state built
    # for `batch_size`. Confirm DecoderRNN_Word handles both cases.
    state_h, state_c = decoder.zero_state(batch_size)
    state_h = state_h.to(config.device, non_blocking=True)
    state_c = state_c.to(config.device, non_blocking=True)
    with torch.set_grad_enabled(train):
        for i, (images, class_labels, captions, lengths) in enumerate(progress_bar(dataloader)):
            images = images.to(config.device, non_blocking=True).contiguous(memory_format=memory_format)
            captions = captions.to(config.device, non_blocking=True)
            class_labels = class_labels.to(config.device, non_blocking=True)
            # Flatten the padded captions into packed (time-major) token order.
            targets = torch.nn.utils.rnn.pack_padded_sequence(
                captions, lengths, batch_first=True, enforce_sorted=False
            )[0]
            logits, features = encoder(images)
            c_loss = classes_loss(logits, class_labels)
            outputs, (state_h, state_c) = decoder(features, captions, lengths, (state_h, state_c))
            o_loss = outputs_loss(outputs, targets)
            if train:
                encoder.zero_grad()
                decoder.zero_grad()
                # One backward on the summed loss gives the same gradients as two
                # separate backward passes, without needing retain_graph=True.
                with apex.amp.scale_loss(c_loss + o_loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()
                # Cut the graph at the batch boundary so the next backward does
                # not reach into this batch.
                state_h = state_h.detach()
                state_c = state_c.detach()
            # Accumulate plain Python floats. Summing loss tensors would keep every
            # batch's autograd graph alive for the whole epoch.
            c_val = c_loss.item()
            o_val = o_loss.item()
            running_c_loss += c_val
            running_o_loss += o_val
            n_batches += 1
            if train and i % 10 == 0:
                print(
                    "train_loss - ",
                    round(c_val, 3),
                    round(o_val, 3),
                    "- perplexity -",
                    round(float(np.exp(o_val)), 3),
                )
    # Average over the batches actually seen. This handles a partial final batch
    # and avoids a zero divisor on an empty loader.
    n_batches = max(n_batches, 1)
    return running_c_loss / n_batches, running_o_loss / n_batches

##### Uncomment the code below to train the model

In [8]:
# num_epochs = 5
# print("Start training")

# for epoch in range(num_epochs):
#     print("\nEpoch", epoch+1, "/", num_epochs, ":\n")
#     train_c_loss, train_o_loss = train_one_epoch(train_dataloader, config.batch_size, encoder, decoder, classes_loss, outputs_loss, optimizer, train=True)
#     valid_c_loss, valid_o_loss = train_one_epoch(valid_dataloader, config.batch_size, encoder, decoder, classes_loss, outputs_loss, optimizer, train=False)
#     print("train_loss - ", round(train_c_loss,3),round(train_o_loss,3), "- perplexity -", round(np.exp(train_o_loss),3),
#           "- valid_loss - ", round(valid_c_loss,3),round(valid_o_loss,3), "- perplexity -", round(np.exp(valid_o_loss),3))
# print("Finished training!")

### Save / Load trained model

In [9]:
# To save freshly trained weights instead, uncomment:
# torch.save(encoder.state_dict(), 'save/encoder_word.pt')
# torch.save(decoder.state_dict(), 'save/decoder_word.pt')

# map_location remaps the stored tensors onto config.device, so a checkpoint written
# on one device (e.g. a GPU) still loads where that device is unavailable.
encoder.load_state_dict(torch.load('save/encoder_word.pt', map_location=config.device))
decoder.load_state_dict(torch.load('save/decoder_word.pt', map_location=config.device))

<All keys matched successfully>

# Average Precision over class logits

In [10]:
from evaluate import get_class_predictions
# Run the loaded encoder over the validation split. Presumably this returns
# ground-truth class labels and predicted class scores/logits, which the next cell
# consumes — confirm in evaluate.py.
y_true, y_pred = get_class_predictions(encoder, valid_dataset)

Running inference...


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=330.0), HTML(value='')), layout=Layout(di…




In [19]:
from evaluate import evaluate_encoder_predictions

recall, precision, AP, mAP = evaluate_encoder_predictions(y_true, y_pred)

# Print each metric under its own heading, rounded to 3 decimals.
for heading, metric in [("Recall", recall), ("Precision", precision), ("AP", AP), ("mean AP", mAP)]:
    print("\n\n" + heading)
    print(metric.round(3))

[  4.   0.   0.  19.   0.   0.  40.   2. 184.   6.  67.   0.   0.   0.
   0.   0.   0.   0.   0.   0.   0.   0.   0.  44.   0. 312.  29.   0.
   0. 284. 293.   0. 316.  75.   0. 182.  13.   0.  14.   0.   0.   0.
   0.   0.  13.   0.  58.  18.]


Recall
[0.    0.    0.    0.    0.    0.    0.075 0.    0.136 0.167 0.09  0.
 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.045
 0.    0.154 0.069 0.    0.    0.44  0.113 0.    0.123 0.08  0.    0.115
 0.077 0.    0.    0.    0.    0.    0.    0.    0.077 0.    0.069 0.167]


Precision
[0.    0.    0.    0.    0.    0.    0.158 0.    0.893 0.1   0.545 0.
 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.5
 0.    1.    0.2   0.    0.    0.962 1.    0.    1.    0.75  0.    0.913
 0.2   0.    0.    0.    0.    0.    0.    0.    0.111 0.    0.16  0.429]


AP
[0.042 0.03  0.009 0.03  0.006 0.067 0.06  0.018 0.13  0.044 0.064 0.006
 0.024 0.024 0.024 0.052 0.009 0.012 0.021 0.018 0.012 0.009 0.009 0.029
 0.00

  _warn_prf(average, modifier, msg_start, len(result))
