In [6]:
import json
import os
from torch.utils.data import Dataset
from sklearn.feature_extraction.text import CountVectorizer

In [3]:
data_path = r"/home/alex/data/nlp/agmir"
img_path = os.path.join(data_path, r"img")
captions_path = os.path.join(data_path, r"txt/ecgen-radiology")

In [4]:
with open(os.path.join(data_path,'cases.json'),'r') as handle:
    cases = json.load(handle)

In [23]:
def clean_merge_tags(tags):
    return ' '.join(
        [t.lower().replace(' ','_').replace(',','_').replace('-','_') for t in tags])

In [37]:
class IUXRay(Dataset):
    def __init__(self, caption_json, tags_vocab=None):
        self.cases = self.load_json(caption_json)
        self.case_list = [case['id'] for case in self.cases]
        if tags_vocab:
            self.tags_vocab = tags_vocab
        else:
            self.tags_vocab = self.create_tags_voc(self.cases)
        
    def __len__(self):
        return len(self.case_list)
        
    def __getitem__(self, index):
        case = self.cases[index]
        return clean_merge_tags(case['tags_mti']), cases[0]['impression']+' '+cases[0]['findings'], case['id'], index
        
    def create_tags_voc(self, cases):
        tags = []
        for case in cases:
            tags += case['tags_mti']
        tags_voc = [t.lower().replace(' ','_').replace(',','_').replace('-','_') for t in tags]
        return list(set(tags_voc))
        
    def load_json(self, caption_json):
        with open(caption_json,'r') as handle:
            cases = json.load(handle)
        return cases

In [38]:
ds = IUXRay(os.path.join(data_path,'cases.json'))

In [22]:
ds.tags_vocab

['coronary_vessels',
 'carcinoma',
 'skin_fold',
 'bilateral_breast_implants',
 'pulmonary_artery',
 'elevated_diaphragm',
 'renal_dialysis',
 'picc',
 'chronic_fibrosis',
 'calcifications_of_the_aorta',
 'cardiac_monitor',
 'alveolar_edema',
 'cervical_spine_surgeries',
 'joint_prosthesis',
 'exudates_and_transudates',
 'drainage',
 'bone_density',
 'empyema',
 'chest_tubes',
 'pneumomediastinum',
 'pulmonary_sarcoidosis',
 'rib_fracture',
 'calcific_tendinitis',
 'calcified_lymph_nodes',
 'bronchiolitis__viral',
 'exostoses',
 'pleural_fluid',
 'hepatic_cyst',
 'cabg',
 'cardiac_pacing',
 'pleural_effusion',
 'granulomatous_disease',
 'atheroscleroses',
 'pulmonary_disease__obstructive',
 'pericardial_effusion',
 'displaced_fractures',
 'osteoarthritis__knee',
 'sarcoidosis__pulmonary',
 'patchy_atelectasis',
 'pneumonia_right_lower_lobe',
 'contusion',
 'fourth_rib_fracture',
 'azygos_vein',
 'granuloma',
 'right_upper_lobe_pneumonia',
 'aneurysm__dissecting',
 'bullous_emphysema',


In [40]:
ds.__getitem__(0)

('degenerative_change diaphragm',
 '1. Mildly elevated right hemidiaphragm. Otherwise no acute cardiopulmonary abnormality seen. Heart size and mediastinal contours are within normal limits. Pulmonary vasculature is unremarkable. No focal airspace consolidation. There is mild elevation right hemidiaphragm. No visible pleural effusion or pneumothorax. There are mild degenerative changes along the thoracic spine.',
 'CXR2216',
 0)

### get semantic features (BOTags) from dataset 

In [24]:
from sklearn.feature_extraction.text import CountVectorizer

In [27]:
countvec = CountVectorizer(vocabulary=ds.tags_vocab)

In [32]:
countvec.transform([clean_merge_tags(cases[0]['tags_mti'])]).toarray()

array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 