<a href="https://colab.research.google.com/github/Quaser001/scribbles_ml/blob/main/TB_xray_audio.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# ================================
# MODULE 1: Kaggle Data Access + Setup
# ================================

import os
from google.colab import files

# --- Step 0: Upload Kaggle API key (kaggle.json) ---
print("📂 Please upload your kaggle.json file (from Kaggle > Account > Create API Token)")
uploaded = files.upload()

os.makedirs("/root/.kaggle", exist_ok=True)
os.rename("kaggle.json", "/root/.kaggle/kaggle.json")
os.chmod("/root/.kaggle/kaggle.json", 0o600)

# --- Step 1: Install Kaggle API ---
!pip install kaggle --quiet

# --- Step 2: Download & unzip datasets ---
DATA_DIR = "/content/data"
os.makedirs(DATA_DIR, exist_ok=True)

# TB Audio dataset
!kaggle datasets download -d ruchikashirsath/tb-audio -p $DATA_DIR --unzip
TB_AUDIO_ROOT = os.path.join(DATA_DIR, "tb-audio")

# Shenzhen X-ray dataset
!kaggle datasets download -d raddar/tuberculosis-chest-xrays-shenzhen -p $DATA_DIR --unzip
SHENZHEN_ROOT = os.path.join(DATA_DIR, "tuberculosis-chest-xrays-shenzhen")

print("✅ Datasets ready!")
print("TB Audio path:", TB_AUDIO_ROOT)
print("Shenzhen X-ray path:", SHENZHEN_ROOT)


📂 Please upload your kaggle.json file (from Kaggle > Account > Create API Token)


Saving kaggle.json to kaggle.json
Dataset URL: https://www.kaggle.com/datasets/ruchikashirsath/tb-audio
License(s): MIT
Downloading tb-audio.zip to /content/data
 97% 1.04G/1.08G [00:11<00:00, 94.1MB/s]
100% 1.08G/1.08G [00:11<00:00, 103MB/s] 
Dataset URL: https://www.kaggle.com/datasets/raddar/tuberculosis-chest-xrays-shenzhen
License(s): unknown
Downloading tuberculosis-chest-xrays-shenzhen.zip to /content/data
100% 3.51G/3.51G [00:38<00:00, 19.1MB/s]
100% 3.51G/3.51G [00:38<00:00, 99.2MB/s]
✅ Datasets ready!
TB Audio path: /content/data/tb-audio
Shenzhen X-ray path: /content/data/tuberculosis-chest-xrays-shenzhen


In [3]:
import zipfile
import os

DATA_DIR = "/content/data"

# Unzip TB Audio
tb_audio_zip = os.path.join(DATA_DIR, "tb-audio.zip")
tb_audio_dir = os.path.join(DATA_DIR, "tb-audio")
if os.path.exists(tb_audio_zip):
    print("🔓 Extracting TB Audio...")
    with zipfile.ZipFile(tb_audio_zip, 'r') as zip_ref:
        zip_ref.extractall(tb_audio_dir)

# Unzip Shenzhen X-ray
shenzhen_zip = os.path.join(DATA_DIR, "tuberculosis-chest-xrays-shenzhen.zip")
shenzhen_dir = os.path.join(DATA_DIR, "tuberculosis-chest-xrays-shenzhen")
if os.path.exists(shenzhen_zip):
    print("🔓 Extracting Shenzhen X-ray...")
    with zipfile.ZipFile(shenzhen_zip, 'r') as zip_ref:
        zip_ref.extractall(shenzhen_dir)

print("✅ Extraction complete!")

# Re-list structure after extraction
print("\n📂 TB Audio Dataset Structure")
for root, dirs, files in os.walk(tb_audio_dir):
    level = root.replace(tb_audio_dir, "").count(os.sep)
    indent = " " * 2 * level
    print(f"{indent}{os.path.basename(root)}/")
    subindent = " " * 2 * (level + 1)
    for f in files:
        print(f"{subindent}{f}")

print("\n" + "="*80 + "\n")

print("📂 Shenzhen X-ray Dataset Structure")
for root, dirs, files in os.walk(shenzhen_dir):
    level = root.replace(shenzhen_dir, "").count(os.sep)
    indent = " " * 2 * level
    print(f"{indent}{os.path.basename(root)}/")
    subindent = " " * 2 * (level + 1)
    for f in files:
        print(f"{subindent}{f}")


✅ Extraction complete!

📂 TB Audio Dataset Structure


📂 Shenzhen X-ray Dataset Structure


In [4]:
# ==========================================
# MODULE 2: Dataset Verification (Corrected Paths)
# ==========================================
import os
import pandas as pd
from collections import Counter

# -----------------------
# 1) TB AUDIO DATASET
# -----------------------
TB_AUDIO_ROOT = "/content/data/Tuberculosis/raw_data/solicited_data"
TB_AUDIO_META = "/content/data/Tuberculosis/metadata/CODA_TB_Clinical_Meta_Info.csv"

# Load metadata
audio_metadata = pd.read_csv(TB_AUDIO_META)

# List all audio files in dataset
audio_files = [f for f in os.listdir(TB_AUDIO_ROOT) if f.lower().endswith(".wav")]

print("\n🔍 TB Audio Verification")
print("="*50)
print("Metadata entries:", len(audio_metadata))
print("Total audio files found:", len(audio_files))

# Assuming 'filename' column contains audio file names
if "filename" in audio_metadata.columns:
    meta_audio_names = set(audio_metadata["filename"].astype(str))
    actual_audio_names = set(audio_files)

    missing_in_folder = meta_audio_names - actual_audio_names
    extra_in_folder = actual_audio_names - meta_audio_names

    print("Missing files (in metadata but not in folder):", len(missing_in_folder))
    print("Extra files (in folder but not in metadata):", len(extra_in_folder))
else:
    print("⚠️ No 'filename' column found in audio metadata")

# Label distribution (assuming TB column exists, 1 = TB positive, 0 = negative)
if "TB" in audio_metadata.columns:
    print("Label distribution:", Counter(audio_metadata["TB"]))
else:
    print("⚠️ No 'TB' column found in audio metadata")

# -----------------------
# 2) SHENZHEN X-RAY DATASET
# -----------------------
SHENZHEN_ROOT = "/content/data/images/images"
SHENZHEN_META = "/content/data/shenzhen_metadata.csv"

# List all X-ray files
xray_files = [f for f in os.listdir(SHENZHEN_ROOT) if f.lower().endswith((".png", ".jpg", ".jpeg"))]

print("\n🔍 Shenzhen X-ray Verification")
print("="*50)
print("Total X-ray images found:", len(xray_files))

# Load Shenzhen metadata
xray_metadata = pd.read_csv(SHENZHEN_META)
print("Shenzhen metadata shape:", xray_metadata.shape)
print(xray_metadata.head())

# Quick label summary if 'Finding' or 'Label' column exists
label_col = None
for col in ["Finding", "Label", "diagnosis"]:
    if col in xray_metadata.columns:
        label_col = col
        break

if label_col:
    print("X-ray label distribution:", Counter(xray_metadata[label_col]))
else:
    print("⚠️ No recognizable label column found in Shenzhen metadata")

# -----------------------
# ✅ Done
print("\n✅ Dataset verification complete!")



🔍 TB Audio Verification
Metadata entries: 1105
Total audio files found: 9772
⚠️ No 'filename' column found in audio metadata
⚠️ No 'TB' column found in audio metadata

🔍 Shenzhen X-ray Verification
Total X-ray images found: 662
Shenzhen metadata shape: (662, 4)
            study_id     sex  age findings
0  CHNCXR_0001_0.png    Male   45   normal
1  CHNCXR_0002_0.png    Male   63   normal
2  CHNCXR_0003_0.png  Female   48   normal
3  CHNCXR_0004_0.png    Male   58   normal
4  CHNCXR_0005_0.png    Male   28   normal
⚠️ No recognizable label column found in Shenzhen metadata

✅ Dataset verification complete!


In [5]:
import pandas as pd

# X-ray metadata
xray_meta = pd.read_csv("/content/data/shenzhen_metadata.csv")
print(xray_meta.head())
print(xray_meta.columns)

# Audio metadata
audio_meta = pd.read_csv("/content/data/Tuberculosis/metadata/manifest.csv")
print(audio_meta.head())
print(audio_meta.columns)


            study_id     sex  age findings
0  CHNCXR_0001_0.png    Male   45   normal
1  CHNCXR_0002_0.png    Male   63   normal
2  CHNCXR_0003_0.png  Female   48   normal
3  CHNCXR_0004_0.png    Male   58   normal
4  CHNCXR_0005_0.png    Male   28   normal
Index(['study_id', 'sex', 'age', 'findings'], dtype='object')
            ID                                    name  versionNumber  \
0  syn41743692  CODA TB Challenge data dictionary.xlsx              4   
1  syn52357041  CODA_TB_additional_variables_train.csv              1   
2  syn41604915          CODA_TB_Clinical_Meta_Info.csv              3   
3  syn41604939         CODA_TB_Solicited_Meta_Info.csv              2   
4  syn41604935       CODA_TB_Longitudnal_Meta_Info.csv              2   

                                         contentType  dataFileSizeBytes  \
0  application/vnd.openxmlformats-officedocument....              19702   
1                                           text/csv              74624   
2               

In [7]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="torchvision")
warnings.filterwarnings("ignore", category=UserWarning, module="torchaudio")


In [11]:
# ================================
# Multimodal Training Module (Sanity Check, 2 Epochs)
# ================================
import os
import time
import torch
import torchaudio
import torchvision
import pandas as pd
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# -----------------------------
# Audio Dataset
# -----------------------------
class AudioSpectrogramDataset(Dataset):
    def __init__(self, df, sr=16000, n_mels=64, max_len=256):
        self.df = df.reset_index(drop=True)
        self.sr = sr
        self.mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=sr, n_fft=400, hop_length=160, n_mels=n_mels
        )
        self.db = torchaudio.transforms.AmplitudeToDB()
        self.max_len = max_len

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        path, label = self.df.loc[idx, "filepath"], self.df.loc[idx, "label"]
        wav, sr = torchaudio.load(path)
        if sr != self.sr:
            wav = torchaudio.functional.resample(wav, sr, self.sr)
        spec = self.mel(wav)   # (1, n_mels, time)
        spec = self.db(spec)
        spec = (spec - spec.mean()) / (spec.std() + 1e-6)

        # pad or truncate to max_len
        if spec.shape[-1] < self.max_len:
            pad = self.max_len - spec.shape[-1]
            spec = torch.nn.functional.pad(spec, (0, pad))
        else:
            spec = spec[:, :, :self.max_len]

        return spec, torch.tensor(label).long()

# -----------------------------
# X-ray Dataset
# -----------------------------
class XrayDataset(Dataset):
    def __init__(self, df, root_dir, transform=None):
        self.df = df.reset_index(drop=True)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.loc[idx]
        img_path = os.path.join(self.root_dir, row['filename'])
        img = torchvision.io.read_image(img_path).float() / 255.0  # [C,H,W]

        # Convert 1 channel → 3 channels
        if img.shape[0] == 1:
            img = img.repeat(3, 1, 1)

        if self.transform:
            img = self.transform(img)
        label = torch.tensor(row['label']).long()
        return img, label


# -----------------------------
# Prepare Audio Dataset
# -----------------------------
def prepare_audio_dataset(meta_csvs, audio_root, test_size=0.15, val_size=0.15):
    dfs = [pd.read_csv(f) for f in meta_csvs]
    meta = pd.concat(dfs, ignore_index=True)

    if 'filename' not in meta.columns or 'tb_status' not in meta.columns:
        raise ValueError("Metadata missing required columns 'filename' and 'tb_status'!")

    # drop missing TB labels
    meta = meta.dropna(subset=['tb_status'])
    meta['filepath'] = meta['filename'].apply(lambda x: os.path.join(audio_root, x))
    meta['label'] = meta['tb_status'].astype(int)

    train_df, test_df = train_test_split(meta, test_size=test_size, stratify=meta['label'], random_state=42)
    train_df, val_df = train_test_split(train_df, test_size=val_size, stratify=train_df['label'], random_state=42)
    return train_df, val_df, test_df

# -----------------------------
# Prepare X-ray Dataset
# -----------------------------
def prepare_xray_dataset(meta_csv, xray_root, test_size=0.15, val_size=0.15):
    meta = pd.read_csv(meta_csv)

    if 'study_id' not in meta.columns or 'findings' not in meta.columns:
        raise ValueError("Metadata missing required columns 'study_id' and 'findings'!")

    # Map TB: normal → 0, anything else → 1
    meta['label'] = meta['findings'].apply(lambda x: 0 if str(x).lower() == 'normal' else 1)
    meta['filename'] = meta['study_id'].astype(str)  # match file naming

    train_df, test_df = train_test_split(meta, test_size=test_size, stratify=meta['label'], random_state=42)
    train_df, val_df = train_test_split(train_df, test_size=val_size, stratify=train_df['label'], random_state=42)
    return train_df, val_df, test_df

# -----------------------------
# Build Models
# -----------------------------
def build_resnet_audio(n_classes=2):
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    model.fc = nn.Linear(model.fc.in_features, n_classes)
    return model.to(device)

def build_resnet_xray(n_classes=2):
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    model.fc = nn.Linear(model.fc.in_features, n_classes)
    return model.to(device)

# -----------------------------
# Training Loop
# -----------------------------
def train_model(model, train_loader, val_loader, epochs=2, lr=1e-3, save_path="best_model.pt"):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    best_f1 = 0.0

    for epoch in range(1, epochs+1):
        start_time = time.time()
        model.train()
        train_loss = 0.0
        y_true, y_pred = [], []
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()*x.size(0)
            y_true.extend(y.cpu().numpy())
            y_pred.extend(out.argmax(dim=1).cpu().numpy())
        train_loss /= len(train_loader.dataset)
        train_acc = accuracy_score(y_true, y_pred)
        train_f1 = f1_score(y_true, y_pred, average='weighted')

        model.eval()
        val_loss = 0.0
        y_true, y_pred = [], []
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                out = model(x)
                loss = criterion(out, y)
                val_loss += loss.item()*x.size(0)
                y_true.extend(y.cpu().numpy())
                y_pred.extend(out.argmax(dim=1).cpu().numpy())
        val_loss /= len(val_loader.dataset)
        val_acc = accuracy_score(y_true, y_pred)
        val_f1 = f1_score(y_true, y_pred, average='weighted')

        epoch_time = time.time() - start_time
        print(f"Epoch {epoch}/{epochs} | Train Loss {train_loss:.4f} Acc {train_acc:.4f} F1 {train_f1:.4f} | "
              f"Val Loss {val_loss:.4f} Acc {val_acc:.4f} F1 {val_f1:.4f} | Time {epoch_time:.2f}s")

        if val_f1 > best_f1:
            best_f1 = val_f1
            torch.save(model.state_dict(), save_path)
            print("-> New best model saved!")

    return model

# ==============================
# Run Audio Pipeline
# ==============================
audio_meta_csvs = [
    "/content/data/Tuberculosis/metadata/X_train_Fold_1.csv",
    "/content/data/Tuberculosis/metadata/X_train_Fold_2.csv"
]  # only existing files
audio_root = "/content/data/Tuberculosis/raw_data/solicited_data"

train_df_a, val_df_a, test_df_a = prepare_audio_dataset(audio_meta_csvs, audio_root)
train_loader_a = DataLoader(AudioSpectrogramDataset(train_df_a), batch_size=16, shuffle=True)
val_loader_a = DataLoader(AudioSpectrogramDataset(val_df_a), batch_size=16, shuffle=False)

audio_model = build_resnet_audio()
print("\n--- Training Audio Model ---")
audio_model = train_model(audio_model, train_loader_a, val_loader_a, epochs=2, save_path="audio_best.pt")

# ==============================
# Run X-ray Pipeline
# ==============================
xray_csv = "/content/data/shenzhen_metadata.csv"
xray_root = "/content/data/images/images"

train_df_x, val_df_x, test_df_x = prepare_xray_dataset(xray_csv, xray_root)

transform = transforms.Compose([
    transforms.Resize((224,224))
])

train_loader_x = DataLoader(XrayDataset(train_df_x, xray_root, transform), batch_size=16, shuffle=True)
val_loader_x = DataLoader(XrayDataset(val_df_x, xray_root, transform), batch_size=16, shuffle=False)

xray_model = build_resnet_xray()
print("\n--- Training X-ray Model ---")
xray_model = train_model(xray_model, train_loader_x, val_loader_x, epochs=2, save_path="xray_best.pt")


Device: cuda

--- Training Audio Model ---
Epoch 1/2 | Train Loss 0.5935 Acc 0.6866 F1 0.6358 | Val Loss 0.5786 Acc 0.7020 F1 0.5791 | Time 75.37s
-> New best model saved!
Epoch 2/2 | Train Loss 0.5698 Acc 0.7011 F1 0.6588 | Val Loss 0.6063 Acc 0.6418 F1 0.6570 | Time 74.89s
-> New best model saved!

--- Training X-ray Model ---
Epoch 1/2 | Train Loss 0.5867 Acc 0.7841 F1 0.7841 | Val Loss 0.5697 Acc 0.7647 F1 0.7637 | Time 105.62s
-> New best model saved!
Epoch 2/2 | Train Loss 0.3277 Acc 0.8700 F1 0.8693 | Val Loss 0.9422 Acc 0.7882 F1 0.7856 | Time 103.90s
-> New best model saved!


In [None]:
# ==============================
# Multimodal Training Module (20 Epochs)
# ==============================

# -----------------------------
# Audio Training (20 epochs)
# -----------------------------
print("\n=== Training Audio Model for 20 Epochs ===")
audio_model_20 = build_resnet_audio()  # new model instance
audio_model_20 = train_model(
    audio_model_20,
    train_loader_a,
    val_loader_a,
    epochs=15,
    lr=1e-3,
    save_path="audio_best_20.pt"
)

# -----------------------------
# X-ray Training (20 epochs)
# -----------------------------
print("\n=== Training X-ray Model for 20 Epochs ===")
xray_model_20 = build_resnet_xray()  # new model instance
xray_model_20 = train_model(
    xray_model_20,
    train_loader_x,
    val_loader_x,
    epochs=15,
    lr=1e-3,
    save_path="xray_best_20.pt"
)



=== Training Audio Model for 20 Epochs ===
Epoch 1/20 | Train Loss 0.6075 Acc 0.6865 F1 0.6296 | Val Loss 0.5719 Acc 0.7178 F1 0.6747 | Time 75.25s
-> New best model saved!
Epoch 2/20 | Train Loss 0.5653 Acc 0.7061 F1 0.6598 | Val Loss 0.5598 Acc 0.7028 F1 0.5850 | Time 76.20s
Epoch 3/20 | Train Loss 0.5514 Acc 0.7131 F1 0.6755 | Val Loss 0.5574 Acc 0.7239 F1 0.7036 | Time 75.20s
-> New best model saved!
Epoch 4/20 | Train Loss 0.5355 Acc 0.7223 F1 0.6982 | Val Loss 0.5961 Acc 0.7126 F1 0.6905 | Time 74.92s
Epoch 5/20 | Train Loss 0.5088 Acc 0.7474 F1 0.7319 | Val Loss 0.5436 Acc 0.7494 F1 0.7208 | Time 76.90s
-> New best model saved!
Epoch 6/20 | Train Loss 0.4796 Acc 0.7659 F1 0.7553 | Val Loss 0.5125 Acc 0.7562 F1 0.7516 | Time 75.33s
-> New best model saved!
Epoch 7/20 | Train Loss 0.4274 Acc 0.8041 F1 0.7985 | Val Loss 0.5892 Acc 0.7126 F1 0.7241 | Time 75.38s
Epoch 8/20 | Train Loss 0.3879 Acc 0.8315 F1 0.8291 | Val Loss 0.4431 Acc 0.7976 F1 0.7913 | Time 76.35s
-> New best mode

In [10]:
import os

def list_folders(base_path, level=1, prefix=""):
    if not os.path.exists(base_path):
        print(f"Path not found: {base_path}")
        return
    for root, dirs, files in os.walk(base_path):
        depth = root.replace(base_path, "").count(os.sep)
        if depth < level:
            for d in dirs:
                print(f"[DIR] {os.path.join(root, d)}")
        else:
            # Don’t go deeper than the specified level
            dirs[:] = []

# 🔹 Change this to explore deeper if needed
list_folders("/content/data/Tuberculosis", level=3)


[DIR] /content/data/Tuberculosis/metadata
[DIR] /content/data/Tuberculosis/raw_data
[DIR] /content/data/Tuberculosis/.ipynb_checkpoints
[DIR] /content/data/Tuberculosis/raw_data/solicited_data
[DIR] /content/data/Tuberculosis/raw_data/longitudinal_data
[DIR] /content/data/Tuberculosis/raw_data/longitudinal_data/longitudinal_1
[DIR] /content/data/Tuberculosis/raw_data/longitudinal_data/longitudinal_2
