In [2]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torch
import numpy as np
import torch.nn as nn
import hydra
from torchmetrics import PrecisionRecallCurve, F1Score, ConfusionMatrix
from typing import List, Any
from torcheval.metrics.functional import binary_auprc, binary_auroc
from collections import defaultdict
from torch.utils.data import TensorDataset, DataLoader
from loguru import logger
from tqdm import tqdm
from hydra import initialize, compose
from omegaconf import OmegaConf

In [3]:
def _load_concat_dataset(embed_paths, label_paths):
    """
    Load multiple .pt files and concatenate along dim 0
    """
    X_list = [torch.load(p) for p in embed_paths]
    y_list = [torch.load(p) for p in label_paths]
    
    X = torch.cat(X_list, dim=0)
    y = torch.cat(y_list, dim=0)
    
    return TensorDataset(X, y)

In [8]:
# Compose the multilabel test config and print it as YAML for inspection.
with initialize(
    version_base=None,
    config_path="../training/classification/configs",
):
    cfg = compose(config_name="config_multilabel_test.yaml")
    print(OmegaConf.to_yaml(cfg))

optimizer:
  _target_: torch.optim.Adam
  lr: 0.03
scheduler:
  body:
    _target_: torch.optim.lr_scheduler.CyclicLR
    base_lr: 1.0e-06
    max_lr: 0.03
    cycle_momentum: false
    mode: triangular2
  pl_cfg:
    interval: step
dataset:
  train:
    embeds:
    - /home/free4ky/projects/chest-diseases/data/preprocessed_train_20/train_data.pt
    - /home/free4ky/projects/chest-diseases/data/preprocessed_train_20/val_data.pt
    - /home/free4ky/projects/chest-diseases/data/preprocessed_mosmed/train_data.pt
    - /home/free4ky/projects/chest-diseases/data/preprocessed_mosmed/val_data.pt
    labels:
    - /home/free4ky/projects/chest-diseases/data/preprocessed_train_20/train_labels.pt
    - /home/free4ky/projects/chest-diseases/data/preprocessed_train_20/val_labels.pt
    - /home/free4ky/projects/chest-diseases/data/preprocessed_mosmed/train_labels.pt
    - /home/free4ky/projects/chest-diseases/data/preprocessed_mosmed/val_labels.pt
  val:
    embeds:
    - /home/free4ky/projects/chest

In [16]:
# Merge the train and val file lists from the config into one dataset.
ds = _load_concat_dataset(
    embed_paths=cfg.dataset.train.embeds + cfg.dataset.val.embeds,
    label_paths=cfg.dataset.train.labels + cfg.dataset.val.labels,
)

In [18]:
# Shape of the concatenated embeddings tensor (first tensor in the dataset).
ds.tensors[0].size()

torch.Size([52051, 512])

In [22]:
# Reduce the label tensor along its last dim: True where any entry is non-zero.
binary_labels = torch.any(ds.tensors[1], dim=-1)
# Distinct values of binary_labels together with how often each one occurs.
unique_elements, count = torch.unique(binary_labels, return_counts=True)

In [24]:
# Totals of the label tensor summed over dim 0 (one value per label column).
multilabel_labels = torch.sum(ds.tensors[1], dim=0)
multilabel_labels

tensor([ 6131., 14244.,  5633.,  3638., 12790.,  7168., 13009.,  9722., 12976.,
        22742., 18604., 13420.,  6081.,  3891.,  5328.,  8899.,  5062.,  3994.,
          926.,   586.])

In [26]:
# Human-readable class names, paired by position with the per-column label
# totals when building name2count.
# NOTE(review): presumably this order matches the column order of the label
# tensor -- confirm against the preprocessing code.
# (The variable name keeps its original spelling because other cells use it.)
deseases_names = [
    # indices 0-4
    "Medical material",
    "Arterial wall calcification",
    "Cardiomegaly",
    "Pericardial effusion",
    "Coronary artery wall calcification",
    # indices 5-9
    "Hiatal hernia",
    "Lymphadenopathy",
    "Emphysema",
    "Atelectasis",
    "Lung nodule",
    # indices 10-14
    "Lung opacity",
    "Pulmonary fibrotic sequela",
    "Pleural effusion",
    "Mosaic attenuation pattern",
    "Peribronchial thickening",
    # indices 15-19
    "Consolidation",
    "Bronchiectasis",
    "Interlobular septal thickening",
    "COVID-19",
    "Cancer",
]

In [29]:
# Pair each class name with its summed label total, by position.
name2count = {
    name: total
    for name, total in zip(deseases_names, multilabel_labels.tolist())
}
name2count

{'Medical material': 6131.0,
 'Arterial wall calcification': 14244.0,
 'Cardiomegaly': 5633.0,
 'Pericardial effusion': 3638.0,
 'Coronary artery wall calcification': 12790.0,
 'Hiatal hernia': 7168.0,
 'Lymphadenopathy': 13009.0,
 'Emphysema': 9722.0,
 'Atelectasis': 12976.0,
 'Lung nodule': 22742.0,
 'Lung opacity': 18604.0,
 'Pulmonary fibrotic sequela': 13420.0,
 'Pleural effusion': 6081.0,
 'Mosaic attenuation pattern': 3891.0,
 'Peribronchial thickening': 5328.0,
 'Consolidation': 8899.0,
 'Bronchiectasis': 5062.0,
 'Interlobular septal thickening': 3994.0,
 'COVID-19': 926.0,
 'Cancer': 586.0}

In [36]:
# Two-column table ('class', 'count') built from name2count, then written to
# an Excel file without the row index.
df = (
    pd.DataFrame.from_dict(name2count, orient='index', columns=['count'])
    .rename_axis('class')
    .reset_index()
)
df.to_excel('multilabel_counts.xlsx', index=False)

In [None]:
# Removed a stray scratch expression: pandas exposes no top-level `to_xl`
# attribute, so evaluating it raised AttributeError and broke
# "Restart Kernel -> Run All". The Excel export in this notebook is done with
# DataFrame.to_excel.

In [23]:
# Occurrence counts returned by binary_labels.unique(return_counts=True);
# entries pair positionally with unique_elements.
count

tensor([ 5924, 46127])

In [7]:
# Collect the train + val label paths WITHOUT mutating cfg.
# The previous in-place `extend` permanently appended the val label paths to
# cfg.dataset.train.labels. That left the train labels out of step with the
# train embeds and made every later `train.labels + val.labels` expression
# include the val paths twice (and again on each re-run of the cell).
all_label_paths = cfg.dataset.train.labels + cfg.dataset.val.labels

In [12]:
# Total number of embedding files across the train and val splits.
len(cfg.dataset.train.embeds) + len(cfg.dataset.val.embeds)

6

In [None]:
# Total number of label files across the train and val splits.
len(cfg.dataset.train.labels) + len(cfg.dataset.val.labels)

6