In [1]:
import os
import pandas as pd
from config import DATASET_DIR
from utils.dataset import ChestXrayDataset 
from utils.transform import train_transform, val_test_transform
from utils.model import ChestXrayDenseNet121
from utils.train import train, validate

In [2]:
# Locations of the three label CSVs shipped under DATASET_DIR.
train_path = os.path.join(DATASET_DIR, 'miccai2023_nih-cxr-lt_labels_train.csv')
val_path = os.path.join(DATASET_DIR, 'miccai2023_nih-cxr-lt_labels_val.csv')
test_path = os.path.join(DATASET_DIR, 'miccai2023_nih-cxr-lt_labels_test.csv')

# Read each CSV into its own frame (train, then val, then test).
df_train, df_val, df_test = (
    pd.read_csv(csv_path) for csv_path in (train_path, val_path, test_path)
)

# Stack the three frames into a single table with a fresh 0..N-1 index.
# NOTE(review): this discards the train/val/test partition encoded by the
# separate files — confirm that re-splitting the merged table is intended.
full_df = pd.concat((df_train, df_val, df_test), ignore_index=True)

In [3]:
from torch.utils.data import random_split, DataLoader, Subset
import torch

# Directory holding the image files referenced by the label table.
img_dir = os.path.join(DATASET_DIR, 'cxr', 'images')

# Transform-free dataset; in this cell it is only used to get the total
# length and to draw the split indices.
full_dataset = ChestXrayDataset(dataframe=full_df, img_dir=img_dir, transform=None)

# 70% / 15% / 15% split; the test split absorbs the rounding remainder so the
# three sizes always add up to the full length.
total_size = len(full_dataset)
train_size = int(0.7 * total_size)
val_size = int(0.15 * total_size)
test_size = total_size - train_size - val_size

# Fixed seed so the same partition is drawn on every run.
# NOTE(review): the split is drawn per dataset index (one per row of full_df,
# presumably). If several rows can belong to the same patient, they may end up
# in different splits — confirm whether a patient-level split is required.
split_rng = torch.Generator().manual_seed(42)
train_subset, val_subset, test_subset = random_split(
    full_dataset, [train_size, val_size, test_size], generator=split_rng
)


def _subset_with_transform(indices, transform):
    """Return a Subset over `indices` backed by a new dataset using `transform`."""
    backing = ChestXrayDataset(dataframe=full_df, img_dir=img_dir, transform=transform)
    return Subset(backing, indices)


# Each split gets its own underlying dataset instance so that it can carry its
# own transform while reusing the indices drawn above.
train_dataset = _subset_with_transform(train_subset.indices, train_transform)
val_dataset = _subset_with_transform(val_subset.indices, val_test_transform)
test_dataset = _subset_with_transform(test_subset.indices, val_test_transform)

# Only the training loader shuffles; validation and test keep a fixed order.
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [4]:
# Number of outputs requested from the classifier.
# NOTE(review): presumably matches the number of label columns in the CSVs —
# confirm against the dataset before changing.
NUM_CLASSES = 19

# Build the network that the training cell below will fit.
model = ChestXrayDenseNet121(num_classes=NUM_CLASSES)

In [None]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Run training for 10 epochs at a learning rate of 1e-4; the validation
# loader is handed to `train` alongside the training loader.
train(model, train_loader, val_loader, device, epochs=10, lr=1e-4)

Epoch 1/10:  49%|████▉     | 1196/2453 [13:19<13:42,  1.53it/s, loss=0.101] 