In [1]:
import os

import pandas as pd
import torch
from torch.utils.data import DataLoader

import config
from utils.dataset import chest_xray_with_mask_datasplit, ChestXrayDatasetWithMask
from utils.model import ChestXrayDenseNet121WithMask
from utils.train import train, validate, compute_pos_weight

In [2]:
# Label CSVs for the three splits, all located under config.DATASET_DIR
train_path = os.path.join(config.DATASET_DIR, 'miccai2023_nih-cxr-lt_labels_train.csv')
val_path = os.path.join(config.DATASET_DIR, 'miccai2023_nih-cxr-lt_labels_val.csv')
test_path = os.path.join(config.DATASET_DIR, 'miccai2023_nih-cxr-lt_labels_test.csv')

# Read each split into its own DataFrame (train, then val, then test)
df_train, df_val, df_test = (
    pd.read_csv(csv_path) for csv_path in (train_path, val_path, test_path)
)

# Stack the splits in train/val/test order; ignore_index gives the result a fresh 0..N-1 index
split_frames = [df_train, df_val, df_test]
full_df = pd.concat(split_frames, ignore_index=True)

In [3]:
# Values shared by every dataset / loader built in this cell
IMAGE_DIR = os.path.join(config.DATASET_DIR, 'cxr', 'images')
BATCH_SIZE = 32

# Use the GPU when one is visible to torch, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset over the combined label frame; the split helper then returns the
# train / val / test datasets (how it splits is defined in utils.dataset)
full_dataset = ChestXrayDatasetWithMask(dataframe=full_df, img_dir=IMAGE_DIR, device=device)
train_dataset, val_dataset, test_dataset = chest_xray_with_mask_datasplit(
    full_df, full_dataset, dataset_dir=IMAGE_DIR, device=device
)

# Only the training loader shuffles; validation and test iterate in dataset order
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [4]:
# Initialize model
# NOTE(review): 19 presumably equals the number of label columns left in the
# label frame once 'id', 'No Finding' and 'subj_id' are removed — TODO confirm.
NUM_CLASSES = 19
model = ChestXrayDenseNet121WithMask(num_classes=NUM_CLASSES)

In [5]:
def compute_pos_weight_from_df(full_df, label_cols):
    """Return one negative-to-positive ratio per label column.

    For every column in ``label_cols`` the column sum is taken as the number
    of positive rows and ``len(full_df) - sum`` as the number of negative
    rows; the returned weight is ``negatives / (positives + 1e-6)``.

    Args:
        full_df: DataFrame holding the label columns.
            Assumes each label column is 0/1 encoded so that its sum is the
            positive count — TODO confirm against the CSV schema.
        label_cols: list of column names to compute weights for.

    Returns:
        1-D ``torch.float32`` tensor with one entry per name in
        ``label_cols``, in the same order.
    """
    total_rows = len(full_df)

    # Column-wise sums -> array of shape [num_classes]
    positives = full_df[label_cols].sum().values
    negatives = total_rows - positives

    # The small epsilon keeps the division finite for a class with no positives
    ratio = negatives / (positives + 1e-6)
    return torch.tensor(ratio, dtype=torch.float32)

# Everything except these columns is treated as a class label
non_label_columns = ['id', 'No Finding', 'subj_id']
classes_df = full_df.drop(columns=non_label_columns)

# NOTE(review): the weights are derived from the combined train+val+test frame,
# not from the training split alone — confirm this is intended.
pos_weight = compute_pos_weight_from_df(classes_df, classes_df.columns.tolist())

In [6]:
# cuDNN settings: allow non-deterministic kernels and let cuDNN benchmark
# convolution algorithms to pick the fastest one (speed over exact repeatability)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

# Train
train(
    model,
    train_loader,
    val_loader,
    device,
    epochs=10,
    lr=1e-4,
    save_path=config.MODEL_FOLDER,
    file_name="PulmoScanX_v1.1",
    pos_weight=pos_weight,
)

Epoch 1/10:   2%|▏         | 51/2453 [01:07<52:56,  1.32s/it, loss=1.21] 


KeyboardInterrupt: 