In [1]:
from medvqa.utils.files import load_pickle, load_json_file, save_to_pickle
from medvqa.datasets.iuxray import IUXRAY_CACHE_DIR
from medvqa.datasets.mimiccxr import MIMICCXR_CACHE_DIR
from medvqa.utils.common import CACHE_DIR
from medvqa.datasets.medical_tags_extractor import MedicalTagsExtractor
import medvqa
from PIL import Image
import random
import os
from importlib import reload

In [59]:
reload(medvqa.datasets.medical_tags_extractor)

<module 'medvqa.datasets.medical_tags_extractor' from '/home/pamessina/medvqa/medvqa/datasets/medical_tags_extractor.py'>

In [2]:
iuxray_qa_adapted_reports_path = os.path.join(IUXRAY_CACHE_DIR, "qa_adapted_reports__20220904_091601.json")
iuxray_qa_adapted_reports = load_json_file(iuxray_qa_adapted_reports_path)

In [3]:
mimiccxr_qa_adapted_reports_path = os.path.join(MIMICCXR_CACHE_DIR, "qa_adapted_reports__20220904_095810.json")
mimiccxr_qa_adapted_reports = load_json_file(mimiccxr_qa_adapted_reports_path)

In [4]:
iuxray_chexpert_labels = load_pickle(os.path.join(IUXRAY_CACHE_DIR, "chexpert_labels_per_report__20220904_113427.pkl"))

In [5]:
mimiccxr_chexpert_labels = load_pickle(os.path.join(MIMICCXR_CACHE_DIR, "chexpert_labels_per_report__20220904_113605.pkl"))

In [6]:
assert len(iuxray_chexpert_labels) == len(iuxray_qa_adapted_reports['reports'])

In [7]:
assert len(mimiccxr_chexpert_labels) == len(mimiccxr_qa_adapted_reports['reports'])

In [8]:
for i, labels in enumerate(iuxray_chexpert_labels):
    if labels[0] == 1:
        assert all(x == 0 for x in labels[1:-1]), (labels, iuxray_qa_adapted_reports['reports'][i])

In [9]:
for i, labels in enumerate(mimiccxr_chexpert_labels):
    if labels[0] == 1:
        assert all(x == 0 for x in labels[1:-1]), (labels, mimiccxr_qa_adapted_reports['reports'][i])

In [10]:
def healthy_and_unhealthy_reports(chexpert_labels):
    healthy_ids = []
    unhealthy_ids = []
    for i, labels in enumerate(chexpert_labels):
        if labels[0] == 1:
            healthy_ids.append(i)
        else:
            unhealthy_ids.append(i)
    return healthy_ids, unhealthy_ids

In [11]:
iu_h_ids, iu_unh_ids = healthy_and_unhealthy_reports(iuxray_chexpert_labels)
len(iu_h_ids), len(iu_unh_ids)

(1438, 2489)

In [12]:
mi_h_ids, mi_unh_ids = healthy_and_unhealthy_reports(mimiccxr_chexpert_labels)
len(mi_h_ids), len(mi_unh_ids)

(44472, 183363)

In [13]:
iuxray_chexpert_labels[random.choice(iu_h_ids)]

array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int8)

In [14]:
iuxray_chexpert_labels[random.choice(iu_unh_ids)]

array([0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int8)

In [15]:
mimiccxr_chexpert_labels[random.choice(mi_h_ids)]

array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int8)

In [16]:
mimiccxr_chexpert_labels[random.choice(mi_unh_ids)]

array([0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0], dtype=int8)

In [17]:
med_tags_extractor = MedicalTagsExtractor('medical_terms_frequency__20220918_184255.pkl')

In [18]:
def show_iuxray_example(idx, reports):    
    report = reports[idx]
    report = '. '.join(report['sentences'][i] for i in report['matched'])
    print('Report:\n')
    print(report)
    print('\nTags:\n')
    print(med_tags_extractor.extract_tags_sequence(report))

In [20]:
x = random.choice(iu_h_ids)
x, show_iuxray_example(x, iuxray_qa_adapted_reports['reports'])

Report:

Frontal and lateral views of the chest with overlying external cardiac monitor leads show normal size and configuration of the cardiac silhouette. Normal pulmonary vasculature and central airways. No focal airspace consolidation or pleural effusion. No acute or active cardiac , pulmonary or pleural disease

Tags:

['frontal', 'lateral', 'overlying', 'external', 'cardiac', 'monitoring', 'leads', 'normal', 'size', 'configuration', 'cardiac', 'silhouette', 'normal', 'pulmonary', 'vasculature', 'central', 'airway', 'no', 'focal', 'airspace', 'consolidation', 'pleural', 'effusion', 'no', 'cardiac', 'pulmonary', 'pleural', 'disease']


(935, None)

In [21]:
x = random.choice(iu_unh_ids)
x, show_iuxray_example(x, iuxray_qa_adapted_reports['reports'])

Report:

The cardiac silhouette is moderately enlarged with a cardiothoracic ratio of NUMBER/NUMBER. Diffuse coarse interstitial opacity seen throughout the lungs with perihilar and lower lobe predominance. There is right greater than left bibasilar consolidation. There are small pleural effusions , right larger than left. No evidence of pneumothorax. Dense atherosclerotic calcification seen involving the thoracic and upper abdominal aorta. Enlarged cardiac silhouette with coarse perihilar and lower lobe interstitial opacities may be due to diffuse infection or heart failure. Small pleural effusions

Tags:

['cardiac', 'silhouette', 'moderate', 'enlarged', 'cardiothoracic', 'diffuse', 'coarse', 'interstitial', 'opacity', 'lungs', 'perihilar', 'lower', 'lobe', 'right', 'greater', 'left', 'bibasilar', 'consolidation', 'small', 'pleural', 'effusion', 'right', 'large', 'left', 'no', 'pneumothorax', 'density', 'atherosclerotic', 'calcified', 'thoracic', 'upper', 'abdominal', 'aorta', 'enlar

(826, None)

In [22]:
x = random.choice(mi_h_ids)
x, show_iuxray_example(x, mimiccxr_qa_adapted_reports['reports'])

Report:

The lungs are clear. There is no effusion or pneumothorax. The cardiomediastinal silhouette is normal. There is no radiopaque density to suggest a missing tooth fragment. No visualized tooth fragment identified

Tags:

['lungs', 'clear', 'no', 'effusion', 'pneumothorax', 'cardiomediastinal', 'silhouette', 'normal', 'no', 'radiopaque', 'density', 'missing', 'tooth', 'fragment', 'no', 'tooth', 'fragment']


(139751, None)

In [23]:
x = random.choice(mi_unh_ids)
x, show_iuxray_example(x, mimiccxr_qa_adapted_reports['reports'])

Report:

Low lung volumes. Left pectoral Port-A-Cath in situ. On the right , the patient has received a Swan-Ganz catheter. The catheter shows a normal course , the tip is positioned relatively distal in the right lower lobe artery. The tip should be pulled back by approximately 3 cm. No pleural effusions. No pneumothorax. Moderate pulmonary edema

Tags:

['lower', 'lung', 'volumes', 'left', 'pectoral', 'port', 'catheter', 'situ', 'right', 'swan', 'ganz', 'catheter', 'catheter', 'normal', 'course', 'tip', 'position', 'distal', 'right', 'lower', 'lobe', 'artery', 'tip', 'pulled', 'back', 'no', 'pleural', 'effusion', 'no', 'pneumothorax', 'moderate', 'pulmonary', 'edema']


(182306, None)

In [24]:
def _get_top_k_ngrams(reports, q_id, used_vocab, k=10, n=3):
    freq = dict()
    for report in reports:
        if q_id not in report['qa']:
            continue
        for i in report['qa'][q_id]:
            tags = med_tags_extractor.extract_tags_sequence(report['sentences'][i])
            for j in range(len(tags)-n+1):
                key = tuple(tags[j:j+n])
                freq[key] = freq.get(key, 0) + 1
    pairs = [(f,k) for k,f in freq.items()]
    pairs.sort(reverse=True)
    output = []
    for p in pairs:
        valid = True
        for w in p[1]:
            if w in used_vocab:
                valid = False
                break
        if valid:
            output.append(p)
            used_vocab.update(p[1])
            if len(output) == k:
                break
    return output

def get_top_k_ngrams(reports, q_id, ks, ns):
    used_vocab = set()
    output = []
    for k, n in zip(ks, ns):
        pairs = _get_top_k_ngrams(reports, q_id, used_vocab, k, n)
        output.extend(pairs)
    output.sort()
    return output

In [25]:
q_id = str(iuxray_qa_adapted_reports['questions'].index('support devices and foreign bodies?'))
top_ngrams = get_top_k_ngrams(iuxray_qa_adapted_reports['reports'], q_id, ks=[50, 50], ns=[2, 1])

In [26]:
top_ngrams

[(2, ('hilar', 'contours')),
 (2, ('humeral', 'prosthetic')),
 (2, ('lung', 'volumes')),
 (2, ('osseous', 'anchors')),
 (2, ('radiopaque', 'density')),
 (2, ('screws', 'fixation')),
 (2, ('terminates', 'lower')),
 (2, ('thoracic', 'tips')),
 (2, ('top', 'normal')),
 (2, ('tunneled', 'dialysis')),
 (3, ('base', 'atelectasis')),
 (3, ('generator', 'projecting')),
 (3, ('intact', 'place')),
 (3, ('vp', 'shunt')),
 (4, ('anterior', 'wall')),
 (4, ('bilateral', 'breast')),
 (4, ('bullet', 'fragment')),
 (4, ('closure', 'devices')),
 (4, ('enteric',)),
 (4, ('head',)),
 (4, ('large',)),
 (4, ('lobe',)),
 (4, ('lungs',)),
 (4, ('mediport',)),
 (4, ('midabdomen',)),
 (4, ('moderate',)),
 (4, ('nerve', 'stimulator')),
 (4, ('obscured',)),
 (4, ('pleural',)),
 (4, ('small',)),
 (4, ('vertebral',)),
 (5, ('abdominal',)),
 (5, ('below', 'diaphragm')),
 (5, ('configuration',)),
 (5, ('course',)),
 (5, ('fracture',)),
 (5, ('hemithorax',)),
 (5, ('no', 'pneumothorax')),
 (5, ('out',)),
 (5, ('rib',)

In [27]:
def classify_sentence(s, top_ngrams, max_n):
    tags = med_tags_extractor.extract_tags_sequence(s)
    ngram_sets = [set() for _ in range(max_n)]
    for n in range(1, max_n+1):
        for i in range(len(tags)-n+1):
            ngram = tuple(tags[i:i+n])
            ngram_sets[n-1].add(ngram)
    for i, ngram in enumerate(top_ngrams):
        if ngram[1] in ngram_sets[len(ngram[1])-1]:
            return i
    return len(top_ngrams)

In [28]:
def classify_sentences(reports, q_id, top_ngrams, max_n):
    freqs = dict()
    for report in reports:
        if q_id not in report['qa']:
            continue
        for i in report['qa'][q_id]:
            x = classify_sentence(report['sentences'][i], top_ngrams, max_n)
            freqs[x] = freqs.get(x, 0) + 1
    return freqs

In [29]:
classify_sentence('The cardiac silhouette and mediastinum size are within normal limits', top_ngrams, 2)

60

In [30]:
classify_sentences(iuxray_qa_adapted_reports['reports'], q_id, top_ngrams, 2)

{12: 3,
 100: 62,
 51: 4,
 21: 3,
 30: 2,
 66: 6,
 69: 5,
 37: 5,
 63: 7,
 33: 5,
 55: 5,
 93: 12,
 86: 12,
 73: 9,
 25: 4,
 45: 3,
 36: 2,
 19: 3,
 82: 15,
 97: 9,
 94: 8,
 90: 9,
 84: 13,
 41: 2,
 89: 9,
 88: 16,
 47: 6,
 16: 4,
 43: 4,
 58: 4,
 91: 10,
 44: 4,
 71: 4,
 64: 9,
 77: 7,
 80: 2,
 65: 6,
 50: 5,
 61: 7,
 85: 10,
 92: 14,
 17: 4,
 22: 3,
 31: 5,
 2: 2,
 54: 7,
 79: 6,
 87: 13,
 59: 7,
 62: 4,
 75: 5,
 53: 3,
 13: 3,
 83: 4,
 15: 4,
 49: 3,
 3: 2,
 78: 7,
 95: 8,
 98: 2,
 48: 5,
 7: 2,
 96: 8,
 46: 5,
 28: 2,
 81: 3,
 11: 3,
 14: 4,
 24: 4,
 99: 12,
 29: 3,
 52: 3,
 20: 4,
 67: 6,
 8: 2,
 57: 3,
 60: 6,
 74: 7,
 76: 4,
 68: 2,
 6: 2,
 1: 2,
 4: 2,
 27: 3,
 9: 2,
 18: 4,
 26: 4,
 5: 2,
 0: 2,
 23: 4,
 10: 2,
 40: 3,
 35: 5,
 34: 1,
 39: 3,
 38: 3,
 42: 2,
 32: 2}

In [31]:
!python ../../scripts/precompute_balanced_dataloading_metadata.py \
        --iuxray-qa-dataset-filename "qa_adapted_reports__20220904_091601.json" \
        --mimiccxr-qa-dataset-filename "qa_adapted_reports__20220904_095810.json" \
        --chexpert-labels-cache-filename "precomputed_chexpert_labels_20220904_105948.pkl" \
        --medical-terms-frequency-filename "medical_terms_frequency__20220918_184255.pkl"

Loading files ...
Loading /home/pamessina/medvqa-workspace/cache/vocab__min_freq=5__mode=report__from(qa_adapted_reports__20220904_091601.json;qa_adapted_reports__20220904_095810.json).pkl ...
1427203it [00:07, 181729.44it/s]
Vocabulary saved to /home/pamessina/medvqa-workspace/cache/vocab__min_freq=5__mode=report__from(qa_adapted_reports__20220904_091601.json;qa_adapted_reports__20220904_095810.json).pkl
Precomputing metadata ...
3927it [00:01, 2573.44it/s]
100%|██████████████████████████████████████████| 97/97 [00:00<00:00, 205.95it/s]
3927it [00:00, 7556.48it/s]
Balanced dataloading metadata saved to /home/pamessina/medvqa-workspace/cache/iuxray/balanced_dataloading_metadata__20220918_210029.pkl
227835it [02:13, 1705.12it/s]
100%|███████████████████████████████████████████| 97/97 [00:42<00:00,  2.26it/s]
227835it [00:48, 4660.74it/s]
Balanced dataloading metadata saved to /home/pamessina/medvqa-workspace/cache/mimiccxr/balanced_dataloading_metadata__20220918_210415.pkl


In [32]:
from medvqa.utils.files import load_pickle
from collections import Counter

In [33]:
iuxray_balanced_metadata = load_pickle('/home/pamessina/medvqa-workspace/cache/iuxray/balanced_dataloading_metadata__20220629_123626.pkl')

In [34]:
mimiccxr_balanced_metadata = load_pickle('/home/pamessina/medvqa-workspace/cache/mimiccxr/balanced_dataloading_metadata__20220629_123956.pkl')

In [35]:
def get_answers(q_id, metadata, reports):
    output = []
    for ri, report in enumerate(reports):
        if q_id in report['qa']:
            answer = '. '.join(report['sentences'][i] for i in report['qa'][q_id])
            output.append((metadata['healthy'][ri][q_id],
                           metadata['tags_based_class'][ri][q_id],
                           answer))
    return output

In [36]:
iuxray_qid2answers = {}
for qid in range(len(iuxray_qa_adapted_reports['questions'])):
    iuxray_qid2answers[qid] = get_answers(str(qid), iuxray_balanced_metadata, iuxray_qa_adapted_reports['reports'])

In [37]:
mimiccxr_qid2answers = {}
for qid in range(len(mimiccxr_qa_adapted_reports['questions'])):
    mimiccxr_qid2answers[qid] = get_answers(str(qid), mimiccxr_balanced_metadata, mimiccxr_qa_adapted_reports['reports'])

In [38]:
def print_health_statistics(qid2answers, questions):
    for qid, answers in qid2answers.items():
        print("------------------")
        print(qid, questions[qid])
        print(Counter(x[0] for x in answers))

In [39]:
# print_health_statistics(mimiccxr_qid2answers, mimiccxr_qa_adapted_reports['questions'])

In [40]:
print_health_statistics(iuxray_qid2answers, iuxray_qa_adapted_reports['questions'])

------------------
0 ARDS?
Counter()
------------------
1 COPD?
Counter({0: 44, 1: 5})
------------------
2 abscess and cavitation?
Counter({1: 3})
------------------
3 adenopathy?
Counter({0: 67, 1: 33})
------------------
4 air collections?
Counter({1: 55, 0: 15})
------------------
5 air space disease?
Counter({1: 405, 0: 147})
------------------
6 air-fluid level?
Counter({0: 4, 1: 4})
------------------
7 airways?
Counter({1: 27, 0: 4})
------------------
8 apical zone?
Counter({0: 83, 1: 29})
------------------
9 ascites?
Counter()
------------------
10 aspiration?
Counter({0: 10, 1: 3})
------------------
11 atelectasis?
Counter({0: 387, 1: 8})
------------------
12 azygos lobe?
Counter({1: 2})
------------------
13 azygos vein?
Counter({1: 1})
------------------
14 bleeding?
Counter({0: 1})
------------------
15 blurring?
Counter()
------------------
16 bones?
Counter({1: 1197, 0: 763})
------------------
17 bowel obstruction and loops?
Counter({1: 7, 0: 1})
------------------


In [41]:
_q = 'stomach?'
# _q = iuxray_qa_adapted_reports['questions'][74]
_qid = iuxray_qa_adapted_reports['questions'].index(_q)
answers = get_answers(str(_qid), iuxray_balanced_metadata, iuxray_qa_adapted_reports['reports'])

In [42]:
_q, len(answers)

('stomach?', 18)

In [43]:
iuxray_balanced_metadata['top_ngrams'][(_qid, 1)]

[(1, ('band', 'procedure')),
 (1, ('below', 'diaphragm')),
 (1, ('catheter',)),
 (1, ('contours', 'normal')),
 (1, ('course', 'inferiorly')),
 (1, ('esophagogastric', 'distal')),
 (1, ('fluid',)),
 (1, ('large',)),
 (1, ('lateral',)),
 (1, ('loops',)),
 (1, ('mild',)),
 (1, ('multiple', 'distended')),
 (1, ('nasogastric',)),
 (1, ('not',)),
 (1, ('overlying',)),
 (1, ('postoperative', 'esophagectomy')),
 (1, ('projecting', 'body')),
 (1, ('pull',)),
 (1, ('rectal',)),
 (1, ('removed', 'gastric')),
 (1, ('small', 'bowel')),
 (1, ('suction',)),
 (1, ('upper', 'quadrant')),
 (2, ('enteric',)),
 (2, ('tube', 'tip')),
 (3, ('aorta',)),
 (3, ('cardiac', 'apex')),
 (3, ('sided',)),
 (3, ('stomach', 'left'))]

In [44]:
[x for x in answers if x[0] == 0]

[(0, 1, 'To the stomach contours appear grossly clear'),
 (0,
  0,
  'Large hiatal hernia is identified containing stomach and colon. Stable appearance of large hiatal hernia containing stomach and large bowel as well as possible small bowel loops'),
 (0,
  2,
  'The stomach is distended with an air-fluid level. Large hiatal hernia with dilated intrathoracic stomach. CT findings suggestive organoaxial gastric volvulus')]

In [45]:
Counter([x[1] for x in answers if x[0] == 1])

Counter({4: 1,
         9: 1,
         0: 1,
         10: 1,
         7: 1,
         15: 1,
         25: 3,
         12: 1,
         19: 1,
         1: 1,
         23: 1,
         6: 1,
         3: 1})

In [24]:
from torch.utils.data import Dataset
from collections import Counter

In [3]:
class AtomicDataset(Dataset):
    def __init__(self, label, k):
        self.data = [f'{label}_{i}' for i in range(k)]
        self._length = int(1e12)
    
    def __len__(self):
        return self._length

    def __getitem__(self, i):
        data = self.data
        return data[i % len(data)]

In [10]:
class CompositeDataset(Dataset):
    def __init__(self, datasets, weights):
        self.datasets = datasets
        self._init_indices(datasets, weights)
        self._length = int(1e12)
    
    def _init_indices(self, datasets, weights):
        tot_w = sum(weights)
        freqs = [int(len(datasets) * 200 * w/tot_w) for w in weights]
        count = sum(freqs)
        indices = [None] * count
        dataset_ids = list(range(len(datasets)))
        dataset_ids.sort(key = lambda i : freqs[i], reverse=True)
        available_slots = list(range(count))
        for i in dataset_ids:
            assert len(available_slots) >= freqs[i]
            step = len(available_slots) / freqs[i]
            for j in range(freqs[i]):
                jj = int(j * step)
                indices[available_slots[jj]] = i
            available_slots = [s for s in available_slots if indices[s] is None]
        indices = [i for i in indices if i is not None]
        
        dataset_counts = [[0] * len(indices) for _ in range(len(datasets))]
        for i in range(len(datasets)):
            for j in range(len(indices)):
                dataset_counts[i][j] = (indices[j] == i) + (dataset_counts[i][j-1] if j > 0 else 0)
            assert dataset_counts[i][-1] > 0, (i, dataset_counts[i], indices)

        self.indices = indices
        self.counts = dataset_counts
    
    def __len__(self):
        return self._length
    
    def __getitem__(self, i):
        indices = self.indices        
        ii = i % len(indices)
        idx = indices[ii]
        assert idx < len(self.datasets)
        counts = self.counts[idx]
        j = (i // len(indices)) * counts[-1] + (counts[ii - 1] if ii > 0 else 0)
        assert j < len(self.datasets[idx])
        return self.datasets[idx][j]

In [37]:
import numpy as np

INFINITE_DATASET_LENGTH = int(1e18)

def _get_balancedly_distributed_class_indices(class_weights):
    w_sum = sum(class_weights)
    ws = [w / w_sum for w in class_weights]
    w_min = min(ws)
    assert w_min > 0
    freqs = [int(20 * w/w_min) for w in ws]
    count = sum(freqs)
    indices = [None] * count
    class_ids = list(range(len(class_weights)))
    class_ids.sort(key = lambda i : freqs[i], reverse=True)
    available_slots = list(range(count))
    for i in class_ids:
        assert len(available_slots) >= freqs[i]
        step = len(available_slots) / freqs[i]
        for j in range(freqs[i]):
            jj = int(j * step)
            indices[available_slots[jj]] = i
        available_slots = [s for s in available_slots if indices[s] is None]
    indices = [i for i in indices if i is not None]
    return np.array(indices, dtype=int)

class BatchedCompositeInfiniteDataset(Dataset):
    def __init__(self, datasets, weights, batch_size):
        self.datasets = datasets
        self._init_indices(datasets, weights)
        self.batch_size = batch_size
    
    def _init_indices(self, datasets, weights):
        assert len(datasets) == len(weights)
        
        dataset_indices = _get_balancedly_distributed_class_indices(weights)
        
        dataset_counts = np.zeros((len(datasets), len(dataset_indices)), dtype=int)
        for i in range(len(datasets)):
            for j in range(len(dataset_indices)):
                dataset_counts[i][j] = (dataset_indices[j] == i) + (dataset_counts[i][j-1] if j > 0 else 0)
            assert dataset_counts[i][-1] > 0, (i, dataset_counts[i], dataset_indices)

        self.indices = dataset_indices
        self.counts = dataset_counts
    
    def __len__(self):
        return INFINITE_DATASET_LENGTH
    
    def __getitem__(self, i):
        indices = self.indices
        batch_size = self.batch_size
        batch_i = i // batch_size
        dataset_i = batch_i % len(indices)
        dataset_id = indices[dataset_i]
        assert dataset_id < len(self.datasets)
        counts = self.counts[dataset_id]
        j = (counts[-1] * (batch_i // len(indices)) +
            (counts[dataset_i - 1] if dataset_i > 0 else 0) - batch_i) * batch_size + i
        assert j < len(self.datasets[dataset_id])
        return self.datasets[dataset_id][j]

In [6]:
datasetA = AtomicDataset('A', 5123)

In [7]:
datasetB = AtomicDataset('B', 1022)

In [8]:
datasetC = AtomicDataset('C', 3033)

In [9]:
datasetD = AtomicDataset('D', 232)

In [38]:
batched_dataset = BatchedCompositeInfiniteDataset([datasetA, datasetB, datasetC], [1,2,3], 2)

In [39]:
Counter(batched_dataset.indices)

Counter({2: 60, 1: 40, 0: 20})

In [36]:
batched_dataset.indices[:50]

array([2, 2, 2, 2, 1, 1, 2, 1, 1, 1, 2, 0, 0, 2, 1, 2, 2, 1, 0, 2, 0, 1,
       1, 0, 2, 0, 2, 1, 2, 2, 2, 1, 2, 2, 2, 1, 1, 2, 1, 2, 2, 1, 2, 2,
       0, 1, 2, 2, 1, 2])

In [31]:
[batched_dataset[i] for i in range(50)]

['C_0',
 'C_1',
 'C_2',
 'C_3',
 'B_0',
 'B_1',
 'C_4',
 'C_5',
 'B_2',
 'B_3',
 'C_6',
 'C_7',
 'B_4',
 'B_5',
 'C_8',
 'C_9',
 'A_0',
 'A_1',
 'C_10',
 'C_11',
 'B_6',
 'B_7',
 'C_12',
 'C_13',
 'B_8',
 'B_9',
 'C_14',
 'C_15',
 'A_2',
 'A_3',
 'C_16',
 'C_17',
 'B_10',
 'B_11',
 'C_18',
 'C_19',
 'B_12',
 'B_13',
 'C_20',
 'C_21',
 'A_4',
 'A_5',
 'C_22',
 'C_23',
 'B_14',
 'B_15',
 'C_24',
 'C_25',
 'B_16',
 'B_17']

In [508]:
complex_dataset = ComposedDataset([datasetA, datasetB, datasetC], [5, 7, 2000])

freqs =  [1, 2, 596]


In [506]:
complex_dataset_2 = ComposedDataset([complex_dataset, datasetD], [1, 1])

freqs =  [200, 200]


In [507]:
[complex_dataset_2[i] for i in range(50)]

['C_0',
 'D_0',
 'C_1',
 'D_1',
 'B_0',
 'D_2',
 'C_2',
 'D_3',
 'C_3',
 'D_4',
 'B_1',
 'D_5',
 'C_4',
 'D_6',
 'C_5',
 'D_7',
 'A_0',
 'D_8',
 'C_6',
 'D_9',
 'B_2',
 'D_10',
 'C_7',
 'D_11',
 'C_8',
 'D_12',
 'A_1',
 'D_13',
 'C_9',
 'D_14',
 'C_10',
 'D_15',
 'B_3',
 'D_16',
 'C_11',
 'D_17',
 'B_4',
 'D_18',
 'C_12',
 'D_19',
 'C_13',
 'D_20',
 'A_2',
 'D_21',
 'C_14',
 'D_22',
 'C_15',
 'D_23',
 'B_5',
 'D_24']

In [510]:
import math

In [576]:
def normalize_weights(ws):
    w_sum = sum(ws)
    ws = [w/w_sum for w in ws]
    ws = [math.log(1 + w * 2e3) for w in ws]
    w_sum = sum(ws)
    ws = [w/w_sum for w in ws]
    return ws

In [580]:
normalize_weights([5000, 100])

[0.672374888615261, 0.32762511138473915]

In [554]:
x = [math.log(1 + x * 0.15) for x in [4e4, 5, 100]]
x_sum = sum(x)
x, [y/x_sum for y in x]

([8.699681400989514, 0.5596157879354227, 2.772588722239781],
 [0.7230521852702111, 0.04651106169616683, 0.2304367530336221])