In [None]:
import os
import torch
import pandas as pd

root = "/home/free4ky/projects/chest-diseases/data/mosmed/mosmed_embeddings/embeds_not_normalized_base"

# load your dataframes (update with actual paths)
df_covid = pd.read_excel("/home/free4ky/projects/chest-diseases/data/mosmed/dataset_registry_COVID19 type I-v 4.xlsx")    # must have ['study_instance_anon', 'pathology']
df_lungcr = pd.read_excel("/home/free4ky/projects/chest-diseases/data/mosmed/dataset_registry_LDCT.xlsx")  # must have ['study_instance_anon', 'pathology']

# build lookup maps
covid_map = df_covid.set_index("study_instance_anon")["pathology"].to_dict()
lungcr_map = df_lungcr.set_index("study_instance_anon")["pathology"].replace({1: 2, 0: 0}).to_dict()

all_tensors, all_labels, all_filenames = [], [], []

# === 1) COVID19_1110 ===
covid19_root = os.path.join(root, "COVID19_1110", "studies")
for subfolder in os.listdir(covid19_root):
    subpath = os.path.join(covid19_root, subfolder)
    if not os.path.isdir(subpath):
        continue

    if subfolder == "CT-0":
        label = 0   # no pathology
    elif subfolder.startswith("CT-") and subfolder != "CT-0":
        label = 1   # covid
    else:
        continue 

    for dirpath, _, files in os.walk(subpath):
        for f in files:
            if f.endswith(".pt"):
                tensor = torch.load(os.path.join(dirpath, f)).view(-1)
                all_tensors.append(tensor)
                all_labels.append(label)
                all_filenames.append(f[:-3])

# === 2) MosMedData-CT-COVID19-type I-v 4 ===
covid_mos_root = os.path.join(root, "MosMedData-CT-COVID19-type I-v 4", "studies")
for dirpath, _, files in os.walk(covid_mos_root):
    for f in files:
        if f.endswith(".pt"):
            key = f[:-3]
            if key in covid_map:
                label = int(covid_map[key])  # 0=no, 1=covid
                tensor = torch.load(os.path.join(dirpath, f)).view(-1)
                all_tensors.append(tensor)
                all_labels.append(label)
                all_filenames.append(key)

# === 3) MosMedData-CT-COVID19-type VII-v 1 ===
covid7_root = os.path.join(root, "MosMedData-CT-COVID19-type VII-v 1", "studies")
for dirpath, _, files in os.walk(covid7_root):
    for f in files:
        if f.endswith(".pt"):
            tensor = torch.load(os.path.join(dirpath, f)).view(-1)
            all_tensors.append(tensor)
            all_labels.append(1)  # covid
            all_filenames.append(f[:-3])

# === 4) MosMedData-LDCT-LUNGCR-type I-v 1 ===
lungcr_root = os.path.join(root, "MosMedData-LDCT-LUNGCR-type I-v 1", "studies")
for dirpath, _, files in os.walk(lungcr_root):
    for f in files:
        if f.endswith(".pt"):
            key = f[:-3]
            if key in lungcr_map:
                label = int(lungcr_map[key])  # 0=no, 2=cancer
                tensor = torch.load(os.path.join(dirpath, f)).view(-1)
                all_tensors.append(tensor)
                all_labels.append(label)
                all_filenames.append(key)

# === 5) CT_LUNGCANCER_500 ===
lung500_root = os.path.join(root, "CT_LUNGCANCER_500")
for dirpath, _, files in os.walk(lung500_root):
    for f in files:
        if f.endswith(".pt"):
            tensor = torch.load(os.path.join(dirpath, f)).view(-1)
            all_tensors.append(tensor)
            all_labels.append(2)  # cancer
            all_filenames.append(f[:-3])

# === Final tensors ===
X = torch.stack(all_tensors)
y = torch.tensor(all_labels)

print("Final dataset shapes:")
print("X:", X.shape)  # [N, dim]
print("y:", y.shape)  # [N]
print("Label counts:", torch.bincount(y))
print("Example:", all_filenames[:5], all_labels[:5])


Final dataset shapes:
X: torch.Size([1866, 512])
y: torch.Size([1866])
Label counts: tensor([354, 926, 586])
Example: ['study_0567.nii.gz', 'study_0619.nii.gz', 'study_0357.nii.gz', 'study_0444.nii.gz', 'study_0618.nii.gz'] [1, 1, 1, 1, 1]


['study_0567.nii.gz',
 'study_0619.nii.gz',
 'study_0357.nii.gz',
 'study_0444.nii.gz',
 'study_0618.nii.gz',
 'study_0343.nii.gz',
 'study_0874.nii.gz',
 'study_0531.nii.gz',
 'study_0549.nii.gz',
 'study_0601.nii.gz',
 'study_0836.nii.gz',
 'study_0693.nii.gz',
 'study_0451.nii.gz',
 'study_0832.nii.gz',
 'study_0743.nii.gz',
 'study_0338.nii.gz',
 'study_0279.nii.gz',
 'study_0318.nii.gz',
 'study_0609.nii.gz',
 'study_0803.nii.gz',
 'study_0580.nii.gz',
 'study_0397.nii.gz',
 'study_0650.nii.gz',
 'study_0911.nii.gz',
 'study_0694.nii.gz',
 'study_0659.nii.gz',
 'study_0460.nii.gz',
 'study_0606.nii.gz',
 'study_0692.nii.gz',
 'study_0550.nii.gz',
 'study_0747.nii.gz',
 'study_0719.nii.gz',
 'study_0867.nii.gz',
 'study_0791.nii.gz',
 'study_0403.nii.gz',
 'study_0427.nii.gz',
 'study_0597.nii.gz',
 'study_0620.nii.gz',
 'study_0848.nii.gz',
 'study_0513.nii.gz',
 'study_0268.nii.gz',
 'study_0603.nii.gz',
 'study_0805.nii.gz',
 'study_0314.nii.gz',
 'study_0856.nii.gz',
 'study_04

In [2]:
covid_map

{'1.2.643.5.1.13.13.12.2.77.8252.07020008000803040006150512120214': 1,
 '1.2.643.5.1.13.13.12.2.77.8252.10060504080402140013031304000505': 0,
 '1.2.643.5.1.13.13.12.2.77.8252.05061401141109080004150810060610': 0,
 '1.2.643.5.1.13.13.12.2.77.8252.14091504110604010913110114080102': 1,
 '1.2.643.5.1.13.13.12.2.77.8252.03111400040804030306140202121209': 0,
 '1.2.643.5.1.13.13.12.2.77.8252.04030602081204060315100109150603': 0,
 '1.2.643.5.1.13.13.12.2.77.8252.14090201080411110508140614131403': 0,
 '1.2.643.5.1.13.13.12.2.77.8252.09000804020815001413000304081115': 0,
 '1.2.643.5.1.13.13.12.2.77.8252.03010502100705040412060811051407': 1,
 '1.2.643.5.1.13.13.12.2.77.8252.04130304041314081106100501141208': 0,
 '1.2.643.5.1.13.13.12.2.77.8252.14120102080510140012040710030701': 1,
 '1.2.643.5.1.13.13.12.2.77.8252.08110109061302020409130606150100': 1,
 '1.2.643.5.1.13.13.12.2.77.8252.04111404110704041103151006140914': 0,
 '1.2.643.5.1.13.13.12.2.77.8252.09060712070101090006110808010001': 0,
 '1.2.

In [3]:
lungcr_map

{'1.2.643.5.1.13.13.12.2.77.8252.14020912071413071311010704150801': 2,
 '1.2.643.5.1.13.13.12.2.77.8252.01020109051302151412030112071414': 0,
 '1.2.643.5.1.13.13.12.2.77.8252.00020801120805130807110203100907': 0,
 '1.2.643.5.1.13.13.12.2.77.8252.02121208000007081411091503110907': 2,
 '1.2.643.5.1.13.13.12.2.77.8252.01130702041001000405061205120502': 0,
 '1.2.643.5.1.13.13.12.2.77.8252.07070004020514121411011509080313': 0,
 '1.2.643.5.1.13.13.12.2.77.8252.01140703011513140706080707021214': 0,
 '1.2.643.5.1.13.13.12.2.77.8252.13040303121011030314040011130801': 0,
 '1.2.643.5.1.13.13.12.2.77.8252.01110706001204091100140512150014': 2,
 '1.2.643.5.1.13.13.12.2.77.8252.03051515111505151515010812140415': 2,
 '1.2.643.5.1.13.13.12.2.77.8252.12121200010312140710130509130104': 2,
 '1.2.643.5.1.13.13.12.2.77.8252.04021505121008030001041100120407': 2,
 '1.2.643.5.1.13.13.12.2.77.8252.06021506101003111002091006100101': 0,
 '1.2.643.5.1.13.13.12.2.77.8252.00070114131113031501120012040609': 0,
 '1.2.

In [28]:
import torch

num_classes = 20  # total vector length

def encode_label(label: int) -> torch.Tensor:
    """
    Custom one-hot encoding:
      0 -> vector of 20 zeros
      1 -> one-hot with '1' at index 18 (COVID)
      2 -> one-hot with '1' at index 19 (Cancer)
    """
    v = torch.zeros(num_classes, dtype=torch.long)
    if label == 1:
        v[18] = 1
    elif label == 2:
        v[19] = 1
    return v

# after collecting all_labels
y_encoded = torch.stack([encode_label(lbl) for lbl in all_labels])

print("X:", X.shape)                   # [N, dim]
print("y_encoded:", y_encoded.shape)   # [N, 20]
print("Examples:")
for lbl in [0, 1, 2]:
    print(lbl, "->", encode_label(lbl).tolist())


X: torch.Size([1866, 512])
y_encoded: torch.Size([1866, 20])
Examples:
0 -> [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
1 -> [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]
2 -> [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]


In [29]:
y_encoded.shape

torch.Size([1866, 20])

In [35]:
X.shape

torch.Size([1866, 512])

In [31]:
from sklearn.model_selection import train_test_split
import os
save_dir = '/home/free4ky/projects/chest-diseases/data/preprocessed_mosmed_base'
os.makedirs(save_dir,exist_ok=True)
# --- Train/Val/Test split (80/10/10) ---
train_idx, temp_idx = train_test_split(range(len(X)), test_size=0.1, random_state=42, shuffle=True)
val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=42, shuffle=True)

splits = {
    "train": train_idx,
    "val": val_idx,
    "test": test_idx
}

for split_name, indices in splits.items():
    split_X = X[indices]
    split_Y = y_encoded[indices]
    torch.save(split_X, os.path.join(save_dir, f"{split_name}_data.pt"))
    torch.save(split_Y, os.path.join(save_dir, f"{split_name}_labels.pt"))
    print(f"Saved {split_name}: {split_X.shape}, {split_Y.shape}")

Saved train: torch.Size([1679, 512]), torch.Size([1679, 20])
Saved val: torch.Size([93, 512]), torch.Size([93, 20])
Saved test: torch.Size([94, 512]), torch.Size([94, 20])


In [21]:
len(y_encoded) - y_encoded.any(dim=-1).sum() # correct 254 + 50 + 50

tensor(354)

# Update old labels for 2 new classes

In [23]:
import torch
old_y_train = torch.load('/home/free4ky/projects/chest-diseases/data/preprocessed_val/validation_labels.pt')

Y = [old_y_train]

In [24]:
# update old labels for two new classes

import torch

new_ys = []
for old_y in Y:
# suppose old_y is [num_samples, 18]
    num_samples = old_y.shape[0]

    # make zeros of shape [num_samples, 2]
    zeros = torch.zeros((num_samples, 2), dtype=old_y.dtype, device=old_y.device)

    # concatenate along the last dimension
    new_y = torch.cat([old_y, zeros], dim=1)

    print(old_y.shape)  # [N, 18]
    print(new_y.shape)  # [N, 20]
    new_ys.append(new_y)


torch.Size([3039, 18])
torch.Size([3039, 20])


In [9]:
new_ys[0].shape

torch.Size([42431, 20])

In [None]:
new_ys[0][:,-2:].shape

tensor(0.)

In [26]:
import os
save_dir = '/home/free4ky/projects/chest-diseases/data/preprocessed_val_20'
os.makedirs(save_dir, exist_ok=True)
splits = {
    # "train": new_ys[0],
    "validation": new_ys[0],
}

for split_name, labels in splits.items():
    torch.save(labels, os.path.join(save_dir, f"{split_name}_labels.pt"))
    print(f"Saved {split_name}: {labels.shape}")

Saved validation: torch.Size([3039, 20])
