In [1]:
%load_ext autoreload

In [2]:
%autoreload 2
import sys
sys.path.append('..')
import torch
import collections
from torch.utils.data import DataLoader
from torchvision import transforms

import tqdm

import pytorch_lightning as pl
import json
import pandas as pd

from ecgnet.utils.transforms import ToTensor, ApplyGain, Resample

from EcgCaptionGenerator.utils.dataset import collate_fn, CaptionDataset
from EcgCaptionGenerator.utils.pycocoevalcap.eval import COCOEvalCap


from EcgCaptionGenerator.systems.top_down_attention_lstm import TopDownLSTM

from EcgCaptionGenerator.systems.topic_unchanged_decoder import TopicSimDecoder
from EcgCaptionGenerator.systems.topic_transformer import TopicTransformer

from EcgCaptionGenerator.systems.transformer import Transformer
from EcgCaptionGenerator.util import get_loaders
from EcgCaptionGenerator.network.utils_model import beam_search



In [3]:
use_topic_model = False
use_transformer = False
use_fully_transformer = False
use_topic_transform = False
use_topic = False
basedir = './training/captioning/models/'
# checkpoint_loc = basedir + 'top_down_lstm/CAP1-4/checkpoints/epoch=13-step=13495.ckpt'

# checkpoint_loc, param_file = basedir + 'top_down_lstm/CAP1-82/checkpoints/epoch=14-step=62174.ckpt', 'config_muse.json' # Muse 6.3 val
# checkpoint_loc, param_file = basedir + 'top_down_lstm/CAP1-75/checkpoints/epoch=17-step=8333.ckpt', 'config.json' # Consults 27 val

use_topic_model = True
# checkpoint_loc, use_topic, param_file = basedir + 'topic/TOP-68/checkpoints/epoch=10-step=45594.ckpt', True, 'config_topic_muse.json' # Muse 4.3
checkpoint_loc, use_topic, param_file = basedir + 'topic/TOP-74/checkpoints/epoch=12-step=10009.ckpt', True, 'config_topic_consult.json' # consults 28

# checkpoint_loc, param_file, use_transformer ="./training/transformer/models/transformer/TRAN-48/checkpoints/epoch=16-step=70464.ckpt", 'config_transformer_muse.json', True # 6.08
# checkpoint_loc, param_file, use_transformer ="./training/transformer/models/transformer/TRAN-52/checkpoints/epoch=17-step=13859.ckpt", 'config_transformer_consult.json', True # 6.08

# use_topic, use_topic_transform = True, True
# checkpoint_loc, param_file ="./training/transformertopic/models/tansformertopic/TRAN1-2/checkpoints/epoch=16-step=70464.ckpt", 'config_transformer_muse.json' # 6.08
# checkpoint_loc, param_file ="./training/transformertopic/models/tansformertopic/TRAN1-9/checkpoints/epoch=14-step=11549.ckpt", 'config_transformer_consult_topic.json' # 6.08


# checkpoint_loc, param_file ="./training/transformer/models/transformer/TRAN-48/checkpoints/epoch=16-step=70464.ckpt", 'config_transformer_muse.json', True # 6.08
# checkpoint_loc, param_file ="./training/transformer/models/transformer/TRAN-52/checkpoints/epoch=17-step=13859.ckpt", 'config_transformer_consult.json', True # 6.08


  and should_run_async(code)


In [4]:
pl.seed_everything(1234)
params = json.load(open(param_file, 'r'))

transform = transforms.Compose([Resample(500), ToTensor(), ApplyGain()])

if use_topic_model:
    model = TopicSimDecoder.load_from_checkpoint(checkpoint_path=checkpoint_loc).cuda()
elif use_transformer:
    model = Transformer.load_from_checkpoint(checkpoint_path=checkpoint_loc).cuda()
elif use_topic_transform:
#     model = TopicTransformer.load_from_checkpoint(checkpoint_path=checkpoint_loc)
#     vocab = model.vocab
    model = TopicTransformer.load_from_checkpoint(checkpoint_path=checkpoint_loc).cuda()
else:
    model = TopDownLSTM.load_from_checkpoint(checkpoint_path=checkpoint_loc).cuda()


threshold, is_train, vocab = 0, False, model.vocab

testset_df = pd.read_csv(params['test_labels_csv'], index_col=0)

testset = CaptionDataset(threshold, is_train, vocab,None, use_topic, 'umcu', params['data_dir'], testset_df, 
                                transform=transform)
gts = testset_df.apply(lambda x: {x['TestID']: [x['Label']]}, axis=1).to_list()
gts = {list(dict_item.keys())[0]: list(dict_item.values())[0][0] for dict_item in gts}
test_loader = DataLoader(testset, batch_size=64,
                            num_workers=4, collate_fn=collate_fn)
# max_length=50
# model.eval()

In [7]:
testset_df[testset_df["PseudoID"] in []]

Unnamed: 0,PseudoID,TestID,ecg_description,original_des,Diagnosis,OriginalDiagnosis,SampleBase,Gain,Label
1,2fa58392b5321aa,aa595c5ec6c6003,sinusritme vrijwel verticale stand van de gemi...,sinusritme vrijwel verticale stand van de gemi...,sinusritme vrijwel verticale stand van de gemi...,sinusritme,500,4.88,sinusritme vrijwel verticale stand van de gemi...
29,5da308a0da624e1,7cd7e2ac6bbd9bd,sinusritme <VENTRICULARRATE> /min pq <PRINTERV...,"sinusritme 70 /min pq 0,16 sec qrs 0,08 sec in...",sinusritme <VENTRICULARRATE> /min pq <PRINTERV...,sinusritme infero-posterior infarct onbepaalde...,500,4.88,sinusritme <VENTRICULARRATE> /min pq <PRINTERV...
33,f34378182544447,282e296d6d56aba,sinusritme met normale morfologie van de p-top...,sinusritme met normale morfologie van de p-top...,sinusritme met normale morfologie van de p-top...,sinusritme,500,4.88,sinusritme met normale morfologie van de p-top...
34,8e54203475ed507,32bdb461a47a28d,sinusbradycardie <VENTRICULARRATE> /min interm...,sinusbradycardie 55 /min intermediaire elektri...,sinusbradycardie <VENTRICULARRATE> /min interm...,sinusbradycardie linker bundeltakblock,500,4.88,sinusbradycardie <VENTRICULARRATE> /min interm...
41,a612bf0f93594bc,d9b5ffc7649bdd9,sinusritme intermediaire elektrische hartas vl...,sinusritme intermediaire elektrische hartas vl...,sinusritme intermediaire elektrische hartas vl...,sinusbradycardie linker ventrikelhypertrofie v...,500,4.88,sinusritme intermediaire elektrische hartas vl...
...,...,...,...,...,...,...,...,...,...
68146,43d3244c3179ff6,5456ee2e15d7b0b,sinusritme <VENTRICULARRATE> /min linker hart-...,sinusritme 64 /min linker hart-as lahb 1e graa...,sinusritme <VENTRICULARRATE> /min linker hart-...,sinusritme met 1e graads av-block linker anter...,500,4.88,sinusritme <VENTRICULARRATE> /min linker hart-...
68163,8a8269d2c7c1079,268c9ef8bdf2a85,afib vr <VENTRICULARRATE> /min kwam net binnen...,afib vr 107 /min kwam net binnenlopen,afib vr <VENTRICULARRATE> /min kwam net binnen...,atriumfibrilleren non-specifieke st-afwijking,500,4.88,afib vr <VENTRICULARRATE> /min kwam net binnen...
68193,210c047f3b589f7,066a9907ec37f65,sinusritme <VENTRICULARRATE> /min intermediair...,sinusritme 91 /min intermediaire hartas pq 140...,sinusritme <VENTRICULARRATE> /min intermediair...,sinusritme met atrium extrasystole,500,4.88,sinusritme <VENTRICULARRATE> /min intermediair...
68200,2db918615256455,f840608e373ddf9,sinusritme met ie graads av block en wisselend...,sinusritme met ie graads av block en wisselend...,sinusritme met ie graads av block en wisselend...,av dubbel-gepacet ritme frequent ventriculair-...,500,4.88,sinusritme met ie graads av block en wisselend...


In [5]:
# top-down attention lstm
settings = [
#     {'temp':None, 'k':None, 'p':None, 'greedy':None, 'm':None},
#     {'temp':0.9, 'k':None, 'p':None, 'greedy':None, 'm':None},
#     {'temp':None, 'k':640, 'p':None, 'greedy':None, 'm':None},
#     {'temp':0.7, 'k':40, 'p':None, 'greedy':None, 'm':None},
#     {'temp':None, 'k':None, 'p':0.95, 'greedy':None, 'm':None},
#     {'temp':None, 'k':None, 'p':None, 'greedy':None, 'm':0.2},
    {'temp':None, 'k':None, 'p':None, 'greedy':True, 'm':None},
]

for s in settings:
    gts = {}
    res = {}
    for batch_idx, batch in enumerate(tqdm.tqdm(test_loader)):
        waveforms, _, _, ids, targets, _, topic = batch
        tags, (words, props) = model.sample(waveforms.cuda(), ids, s)
        truth = model.vocab.decode(targets)
        for i in range(waveforms.shape[0]):
            gts[ids[i]] = [truth[i]]
#             print(words)
            res.update(words)
    
#     print(gts, res)

    gts = collections.OrderedDict(sorted(gts.items()))
    res = collections.OrderedDict(sorted(res.items()))
    
    pd.DataFrame(gts).to_csv(checkpoint_loc[:-5] + 'gts_.csv')
    pd.DataFrame(res).to_csv(checkpoint_loc[:-5] + 'res_.csv')

    COCOEval = COCOEvalCap()
    COCOEval.evaluate(gts, res)
    print(s, COCOEval.eval)

100%|██████████| 104/104 [00:16<00:00,  6.35it/s]


setting up scorers...
computing Bleu score...
{'testlen': 90130, 'reflen': 94489, 'guess': [90130, 84112, 78094, 72076], 'correct': [35493, 19039, 11535, 6081]}
ratio: 0.9538676459693619
Bleu_1: 0.375
Bleu_2: 0.284
Bleu_3: 0.225
Bleu_4: 0.174
computing METEOR score...
METEOR: 0.213
computing Rouge score...
ROUGE_L: 0.373
computing CIDEr score...
CIDEr: 0.618
{'temp': None, 'k': None, 'p': None, 'greedy': True, 'm': None} {'Bleu_1': 0.37520563113419675, 'Bleu_2': 0.2844629868440683, 'Bleu_3': 0.2249827170620265, 'Bleu_4': 0.17394299841303934, 'METEOR': 0.2134236076128138, 'ROUGE_L': 0.37305702853126854, 'CIDEr': 0.6181676585074518}


In [None]:
# For transformer topic
settings = [
#     {'temp':None, 'k':None, 'p':None, 'greedy':None, 'm':None},
#     {'temp':0.9, 'k':None, 'p':None, 'greedy':None, 'm':None},
# #     {'temp':None, 'k':640, 'p':None, 'greedy':None, 'm':None},
#     {'temp':0.7, 'k':40, 'p':None, 'greedy':None, 'm':None},
#     {'temp':None, 'k':None, 'p':0.95, 'greedy':None, 'm':None},
#     {'temp':None, 'k':None, 'p':None, 'greedy':None, 'm':0.2},
    {'temp':None, 'k':None, 'p':None, 'greedy':True, 'm':None},
]

for sample_method in settings:
    gts = {}
    res = {}
    for batch_idx, batch in enumerate(tqdm.tqdm(test_loader)):
        waveforms, _, _, ids, targets, _, topic = batch
        words = model.sample(waveforms, sample_method, max_length)

        generated = model.vocab.decode(words, skip_first=False)
        truth = model.vocab.decode(targets)
        for i in range(waveforms.shape[0]):
            res[ids[i]] = [generated[i]]
            gts[ids[i]] = [truth[i]]
#         print(res, gts)

#     pd.DataFrame(gts).to_csv(checkpoint_loc[:-5] + 'gts_.csv')
#     pd.DataFrame(res).to_csv(checkpoint_loc[:-5] + 'res_.csv')

    COCOEval = COCOEvalCap()
    COCOEval.evaluate(gts, res)
#     print(sample_method, COCOEval.eval)
    print(sample_method, COCOEval.eval)
# print(gts, res)

In [None]:
# not sure
settings = [
#     {'temp':None, 'k':None, 'p':None, 'greedy':None, 'm':None},
#     {'temp':0.9, 'k':None, 'p':None, 'greedy':None, 'm':None},
# #     {'temp':None, 'k':640, 'p':None, 'greedy':None, 'm':None},
#     {'temp':0.7, 'k':40, 'p':None, 'greedy':None, 'm':None},
#     {'temp':None, 'k':None, 'p':0.95, 'greedy':None, 'm':None},
# #     {'temp':None, 'k':None, 'p':None, 'greedy':None, 'm':0.2},
    {'temp':None, 'k':None, 'p':None, 'greedy':True, 'm':None},
]

for sample_method in settings:
    gts = {}
    res = {}
    for batch_idx, batch in enumerate(tqdm.tqdm(test_loader)):
        waveforms, _, _, ids, targets, _, topic = batch
        words = model.sample(waveforms.cuda(), sample_method, max_length)

        generated = model.vocab.decode(words, skip_first=False)
        truth = model.vocab.decode(targets)
        for i in range(waveforms.shape[0]):
            res[ids[i]] = [generated[i]]
            gts[ids[i]] = [truth[i]]
#         print(res, gts)

#     pd.DataFrame(gts).to_csv(checkpoint_loc[:-5] + 'gts_.csv')
#     pd.DataFrame(res).to_csv(checkpoint_loc[:-5] + 'res_.csv')

    COCOEval = COCOEvalCap()
    COCOEval.evaluate(gts, res)
    print(sample_method, COCOEval.eval)
#     print(s, COCOEval.eval)
# print(gts, res)

In [None]:
# Transformer
max_length = 50
settings = [
    {'temp':None, 'k':None, 'p':None, 'greedy':True, 'm':None},
]
# settings = [
#     {'temp':None, 'k':None, 'p':None, 'greedy':None, 'm':None},
#     {'temp':0.9, 'k':None, 'p':None, 'greedy':None, 'm':None},
#     {'temp':0.7, 'k':40, 'p':None, 'greedy':None, 'm':None},
#     {'temp':None, 'k':None, 'p':0.95, 'greedy':None, 'm':None},
#     {'temp':None, 'k':None, 'p':None, 'greedy':True, 'm':None},
# ]

for sample_method in settings:
    gts = {}
    res = {}
    for batch_idx, batch in enumerate(tqdm.tqdm(test_loader)):
        with torch.no_grad():
            waveforms, _, _, ids, targets, _ = batch
            words = model.sample(waveforms, sample_method, max_length)

            generated = model.vocab.decode(words, skip_first=False)
            truth = model.vocab.decode(targets)
            for i in range(waveforms.shape[0]):
                res[ids[i]] = [generated[i]]
                gts[ids[i]] = [truth[i]]
    print(res, gts)

#     pd.DataFrame(gts).to_csv(checkpoint_loc[:-5] + 'gts_.csv')
#     pd.DataFrame(res).to_csv(checkpoint_loc[:-5] + 'res_.csv')

    COCOEval = COCOEvalCap()
    COCOEval.evaluate(gts, res)
    print(sample_method, COCOEval.eval)
    
        

In [25]:
# Top-down attention lstm beam
gts = {}
res = {}
for batch_idx, batch in enumerate(tqdm.tqdm(test_loader)):
    waveforms, _, _, ids, targets, lengths = batch
    max_length = 50
    beam_size = 5
    vocab = model.vocab
    # im features:
    _, image_feats = model.model(waveforms.cuda())

    out, _ = beam_search(model.language_model, image_feats, max_length, vocab.word2idx['<end>'], beam_size, out_size=1)
    generated = model.vocab.decode(out.view(-1, max_length))
    truth = model.vocab.decode(targets)
    for i in range(waveforms.shape[0]):
        res[ids[i]] = [generated[i]]
        gts[ids[i]] = [truth[i]]
    
pd.DataFrame(gts).to_csv(checkpoint_loc[:-5] + 'gts_.csv')
pd.DataFrame(res).to_csv(checkpoint_loc[:-5] + 'res_.csv')

COCOEval = COCOEvalCap()
COCOEval.evaluate(gts, res)
print(s, COCOEval.eval)
    

100%|██████████| 104/104 [00:46<00:00,  2.24it/s]


setting up scorers...
computing Bleu score...
{'testlen': 86959, 'reflen': 94487, 'guess': [86959, 80941, 74923, 68905], 'correct': [26455, 11074, 4579, 2709]}
ratio: 0.9203276641230971
Bleu_1: 0.279
Bleu_2: 0.187
Bleu_3: 0.125
Bleu_4: 0.092
computing METEOR score...
METEOR: 0.163
computing Rouge score...
ROUGE_L: 0.279
computing CIDEr score...
CIDEr: 0.252
{'temp': None, 'k': None, 'p': None, 'greedy': True, 'm': None} {'Bleu_1': 0.278995088957361, 'Bleu_2': 0.18709747259687357, 'Bleu_3': 0.12518862960945387, 'Bleu_4': 0.09170944462619122, 'METEOR': 0.16289310871316756, 'ROUGE_L': 0.27869807041555844, 'CIDEr': 0.25231225498386384}


In [18]:
testset_df = pd.read_csv(params['test_labels_csv'], index_col=0)
testset_df.OriginalDiagnosis = testset_df.OriginalDiagnosis.fillna('')

baseline_diagnosis = testset_df.apply(lambda x: {x['TestID']: [x['OriginalDiagnosis']]}, axis=1).to_list()
truth_diagnosis = testset_df.apply(lambda x: {x['TestID']: [x['Label']]}, axis=1).to_list()

ref = {list(dict_item.keys())[0]: list(dict_item.values())[0] for dict_item in baseline_diagnosis}
# keys = list(res.keys())
# gts = {list(dict_item.keys())[0] : list(dict_item.values())[0] for dict_item in truth_diagnosis if list(dict_item.keys())[0] in keys}
gts = {list(dict_item.keys())[0]: list(dict_item.values())[0] for dict_item in truth_diagnosis}
gts = collections.OrderedDict(sorted(gts.items()))
ref = collections.OrderedDict(sorted(ref.items()))

COCOEval = COCOEvalCap()
COCOEval.evaluate(gts, ref)

setting up scorers...
computing Bleu score...
{'testlen': 38236, 'reflen': 95511, 'guess': [38236, 32218, 27593, 23264], 'correct': [3783, 299, 40, 4]}
ratio: 0.40033085194375095
Bleu_1: 0.022
Bleu_2: 0.007
Bleu_3: 0.002
Bleu_4: 0.001
computing METEOR score...
METEOR: 0.019
computing Rouge score...
ROUGE_L: 0.050
computing CIDEr score...
CIDEr: 0.045
