In [6]:
import numpy as np
import pandas as pd
import torchvision
import torchvision.transforms.v2 as v2
import torch
import torch.nn as nn
import torch.optim as optim
import os
from sklearn.metrics import roc_auc_score
import mlflow
from sklearn.metrics import RocCurveDisplay
import matplotlib.pyplot as plt

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
chk = torch.load("./checkpoint.pth", map_location=device)
print(chk)

{'epoch': 15, 'model_state_dict': OrderedDict([('backbone.features.0.0.weight', tensor([[[[-0.1624, -0.0062, -0.2342],
          [ 0.1050,  0.4118, -0.0385],
          [ 0.2082,  0.2498,  0.0724]]],


        [[[ 0.6261, -0.4116,  0.0470],
          [ 0.6484, -0.6640,  0.1541],
          [ 0.0712, -0.6306,  0.1259]]],


        [[[ 0.3995,  0.0627, -0.1441],
          [ 0.4336, -0.1410, -0.2059],
          [ 0.5984,  0.0536, -0.2916]]],


        [[[-0.4069, -0.3150,  0.2186],
          [-0.2972, -0.3328,  0.3142],
          [-0.2789, -0.1684,  0.2472]]],


        [[[ 0.8462,  0.0296, -0.5856],
          [ 0.4713,  0.2171, -0.2862],
          [-0.7185, -0.0772,  0.3369]]],


        [[[-0.4272, -0.7628, -0.3737],
          [-0.1510,  0.4384,  0.3357],
          [-0.1477,  0.8567,  0.3795]]],


        [[[ 0.3014,  0.3746,  0.0935],
          [ 0.4526,  0.0722,  0.2408],
          [ 0.4830,  0.2494,  0.2255]]],


        [[[ 0.2364, -0.4249, -0.9359],
          [ 0.0661, -1.0498, -1.01

In [8]:
from huggingface_hub import PyTorchModelHubMixin 
# EfficientNetB0 with 14 class output
class ChestXRayModel(nn.Module, PyTorchModelHubMixin,repo_url="chestxpert",
    license="mit",):
    def __init__(self, num_classes=14):
        super(ChestXRayModel, self).__init__()
        self.backbone = torchvision.models.efficientnet_b0(weights = torchvision.models.EfficientNet_B0_Weights.IMAGENET1K_V1)
        orig_conv = self.backbone.features[0][0]
        self.backbone.features[0][0] = nn.Conv2d(1, orig_conv.out_channels,
                        kernel_size=orig_conv.kernel_size,
                        stride=orig_conv.stride,
                        padding=orig_conv.padding,
                        bias=False)
        self.classifier = torch.nn.Linear(self.backbone.classifier[1].in_features, num_classes)
        self.backbone.classifier = self.backbone.classifier[0]
        
    def forward(self, x):
        x = self.backbone(x)
        x = self.classifier(x)
        return x

In [9]:
model = ChestXRayModel()
model.load_state_dict(chk['model_state_dict'])
model.save_pretrained("my-awesome-model")


In [10]:
# push to the hub
model.push_to_hub("Lait-au-pole/chestxpert")
# reload
model = ChestXRayModel.from_pretrained("Lait-au-pole/chestxpert")

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/16.3M [00:00<?, ?B/s]

In [None]:
from huggingface_hub import PyTorchModelHubMixin 

class HfModel(nn.Module,
    PyTorchModelHubMixin, 
    # optionally, you can add metadata which gets pushed to the model card
    repo_url="your-repo-url",
    pipeline_tag="text-to-image",
    license="mit",):
    def __init__(self, model: nn.Module):
        super().__init__()

In [None]:
# Random seeding for reproducibility
torch.manual_seed(42)
np.random.seed(42)

In [None]:
# Data Pre-Processing
df = pd.read_csv('/kaggle/input/data/Data_Entry_2017.csv')

# get unique "Finding Labels"
findings = df['Finding Labels'].unique()

for i, label in enumerate(findings):
    # split by '|'
    split_labels = label.split('|')
    for split_label in split_labels:
        findings = np.append(findings, split_label)
    # remove the original label
    if len(split_labels) > 1 or split_labels[0] == "No Finding":
        findings = np.delete(findings, np.where(findings == label)) 
findings = np.unique(findings)
print(f"Unique Findings: {findings}")

#data splits
#get unique patient id
patient_idf = df['Patient ID'].value_counts(normalize=True, sort=False)
splits={'train':0.7,'val':0.15,'test':0.15}
total_count = 0.0
train_ids, val_ids, test_ids = [], [], []
for id, count in patient_idf.items():
    if total_count < splits['train']:
        train_ids.append(id)
    elif total_count < splits['train'] + splits['val']:
        val_ids.append(id)
    else:
        test_ids.append(id)
    total_count += count

train_df = df[df['Patient ID'].isin(train_ids)]
val_df = df[df['Patient ID'].isin(val_ids)]
test_df = df[df['Patient ID'].isin(test_ids)]

def reduce_size(df, reduction_factor=0.1):
    return df.iloc[:int(df.shape[0]*reduction_factor)]
reduction_factor = 0.05
train_df = reduce_size(train_df, reduction_factor)
val_df = reduce_size(val_df, reduction_factor)
test_df = reduce_size(test_df, reduction_factor)

In [None]:
class ChestXRayDataset(torch.utils.data.Dataset):
    def __init__(self, dataframe, data_dir, transform=None):
        self.dataframe = dataframe
        self.data_dir = data_dir
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        img_path = ""
        for dir in os.listdir(self.data_dir):
            if os.path.isdir(os.path.join(self.data_dir, dir)):
                potential_path = os.path.join(self.data_dir, dir, "images", row['Image Index'])
                if os.path.exists(potential_path):
                    img_path = potential_path
                    break
        labels = row['Finding Labels'].split('|')
        # Convert labels to multi-hot encoding
        multi_hot = np.where(np.isin(findings, labels), 1, 0).astype(np.float32)
        multi_hot = torch.tensor(multi_hot, dtype=torch.float32)
        image = torchvision.io.decode_image(img_path, 'GRAY')
        image = v2.functional.to_dtype(image, torch.float32, scale=True)
        if self.transform: image = self.transform(image)
        return image, multi_hot

In [None]:
# transforms
resize = (128,128)#(224,224)
train_tf = v2.Compose([
    v2.Resize(resize),
    v2.RandomRotation(7),
    v2.ColorJitter(brightness=0.1, contrast=0.1),
    v2.Normalize([0.5], [0.5])
])

val_tf = v2.Compose([v2.Resize(resize), v2.Normalize([0.5],[0.5])])


In [None]:
def train_one_epoch(model, loader, opt, device, scaler, loss_fn):
    model.train()
    running_loss = 0.0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        with torch.amp.autocast(str(device)):
            logits = model(imgs)
            loss = loss_fn(logits, labels)
        opt.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        running_loss += loss.item() * imgs.size(0)
    return running_loss / len(loader.dataset)


def validate(model, loader, device, loss_fn):
    model.eval()
    ys, preds = [], []
    running_loss = 0.0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            logits = model(imgs)
            loss = loss_fn(logits, labels)
            running_loss += loss.item() * imgs.size(0)
            preds.append(torch.sigmoid(logits).cpu().numpy())
            ys.append(labels.cpu().numpy())
    preds = np.vstack(preds); ys = np.vstack(ys)
    per_class_auc = []
    for i in range(ys.shape[1]):
        try:
            per_class_auc.append(roc_auc_score(ys[:,i], preds[:,i]))
        except:
            per_class_auc.append(np.nan)
    return np.nanmean(per_class_auc), per_class_auc, running_loss/ len(loader.dataset)


def evaluate(model, loader, device):
    model.eval()
    ys, preds = [], []
    with torch.no_grad():
        for imgs, labels in loader:
            imgs = imgs.to(device)
            logits = model(imgs)
            preds.append(torch.sigmoid(logits).cpu().numpy())
            ys.append(labels.numpy())
    preds = np.vstack(preds); ys = np.vstack(ys)
    per_class_auc = []
    for i in range(ys.shape[1]):
        try:
            per_class_auc.append(roc_auc_score(ys[:,i], preds[:,i]))
        except:
            per_class_auc.append(np.nan)
    return np.nanmean(per_class_auc), per_class_auc, (preds, ys)

In [None]:
# Train/validation split
data_dir = '/kaggle/input/data'
train_dataset = ChestXRayDataset(train_df, data_dir, train_tf)
val_dataset = ChestXRayDataset(val_df, data_dir, val_tf)
test_dataset = ChestXRayDataset(test_df, data_dir, val_tf)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#Training
model = ChestXRayModel()
model.to(device)
# params

loss_fn = torch.nn.BCEWithLogitsLoss()
batch_size = 128 
warmup = 0
lr = 1e-2
epochs = 2
optimizer = optim.AdamW(model.parameters(), lr)
schedulers = [optim.lr_scheduler.LinearLR(optimizer, total_iters=warmup), optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)]
scheduler = optim.lr_scheduler.ChainedScheduler(schedulers, optimizer)
scaler = torch.amp.GradScaler(str(device))

params = {
    'train_transforms': train_tf,
    'validation_transforms': val_tf,
    'loss_fn': loss_fn,
    'batch_size': batch_size,
    'warmup': warmup,
    'lr': lr,
    'epochs': epochs,
    'optimizer': optimizer,
    'scheduler': scheduler,
    "findings": ",".join(findings)
}

train_dataset_mlfow = mlflow.data.from_pandas(train_df, source='kaggle/dataset/nih-chest-xrays/data', name='ChestXRay')


# Resume helpers (Kaggle dataset: "resume-training")
resume_dir = '/kaggle/input/resume-training'
mlflow_db_file = 'test_mlflow.db'
ckpt_path_input = os.path.join(resume_dir, 'checkpoint.pth')
mlflow_db_input = os.path.join(resume_dir, mlflow_db_file)

if os.path.exists(mlflow_db_input):
    import shutil
    print("Found mlflow db in resume dataset")
    shutil.copy(mlflow_db_input, mlflow_db_file)
else:
    print("no mlfow db found, new mlflow db will be created")

# Use local DB path for mlflow (will create if missing)
mlflow.set_tracking_uri(f"sqlite:///{mlflow_db_file}")
mlflow.set_experiment("test_model_log_register")
mlflow.enable_system_metrics_logging()

start_epoch = 0
best_auc = -float('inf')
epochs_no_improve = 0

# If a checkpoint exists in the resume dataset, load states
if os.path.exists(ckpt_path_input):
    print("Found checkpoint in resume dataset, loading...")
    ckpt = torch.load(ckpt_path_input, map_location=device)
    model.load_state_dict(ckpt['model_state_dict'])
    optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    if 'scaler_state_dict' in ckpt:
        scaler.load_state_dict(ckpt['scaler_state_dict'])
    if 'scheduler_state_dict' in ckpt:
        scheduler.load_state_dict(ckpt['scheduler_state_dict'])
    if 'epochs_no_improve' in ckpt and 'best_auc' in ckpt:
        epochs_no_improve = ckpt['epochs_no_improve']
        best_auc = ckpt['best_auc']
    start_epoch = int(ckpt.get('epoch', 0)) + 1
    print(f"Resuming from epoch {start_epoch}, best_auc={best_auc}, epochs_no_improve={epochs_no_improve}")

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,pin_memory=True, num_workers=0, persistent_workers=False)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=0, persistent_workers=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=0, persistent_workers=False)

#last_run = mlflow.last_active_run() #TODO:Think about cases where we want a new run
with mlflow.start_run() as run: 
    # Log parameters
    mlflow.log_params(params)
    for epoch in range(start_epoch, epochs):
        mlflow.log_metric("lr", scheduler.get_last_lr()[0], step=epoch)
        train_loss = train_one_epoch(model, train_loader, optimizer, device, scaler, loss_fn)
        scheduler.step()
        glob_auc, aucs, val_loss = validate(model, val_loader, device, loss_fn)
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val AUC: {glob_auc:.4f}")

        # Log metrics
        mlflow.log_metric("train_loss", train_loss, step=epoch)
        mlflow.log_metric("val_loss", val_loss, step=epoch)
        mlflow.log_metric("val_auc", glob_auc, step=epoch)
        for i, auc in enumerate(aucs):
            if np.isnan(auc):
                continue
            mlflow.log_metric(f"val_auc_class_{findings[i]}", float(auc), step=epoch)

        # Early stopping if val_auc does not improve for 5 epochs
        if glob_auc > best_auc:
            best_auc = glob_auc
            epochs_no_improve = 0
            model_signature = mlflow.models.signature.infer_signature(
                np.random.randn(1, 1, resize[0], resize[1]).astype(np.float32),
                model(torch.randn(1, 1, resize[0], resize[1]).to(device)).detach().cpu().numpy()
            )
            model_info = mlflow.pytorch.log_model(pytorch_model=model, name=f'ChestXRayModel_effnetb0_{epoch}', signature=model_signature, step=epoch, code_paths="./models/efficientnetb0.py")
            mlflow.log_metric("best_auc", best_auc, model_id=model_info.model_id, step=epoch, dataset=train_dataset_mlfow)
        else:
            epochs_no_improve += 1
        if epochs_no_improve == 5:
            print("Early stopping!")
            break

        # Save the training state in case of a timeout stopping
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'epochs_no_improve': epochs_no_improve,
            'best_auc': best_auc,
        }, 'checkpoint.pth')

    ranked_models = mlflow.search_logged_models(
        filter_string=f"source_run_id='{run.info.run_id}'",
        order_by=[{"field_name": "metrics.best_auc", "ascending": False}],
        output_format="list",
    )
    best_model = ranked_models[0]
    mlflow.register_model(best_model.model_uri, 'ChestXRayModel_effnet0_v1.0.0')
    model = mlflow.pytorch.load_model(best_model.model_uri)

    model_signature = mlflow.models.signature.infer_signature(
        np.random.randn(1, 1, resize[0], resize[1]).astype(np.float32),
        model(torch.randn(1, 1, resize[0], resize[1]).to(device)).detach().cpu().numpy()
    )
    # log metricts
    test_glob_auc, test_aucs, classes_pred_and_truth = evaluate(model, test_loader, device)
    mlflow.log_metric("test_auc", test_glob_auc, model_id=best_model.model_id)
    metrics_test = dict()
    for i, auc in enumerate(test_aucs):
        if np.isnan(auc):
            continue
        metrics_test[f"test_auc_class_{findings[i]}"] = float(auc)
    mlflow.log_metrics(metrics_test, model_id=best_model.model_id, dataset=train_dataset_mlfow)
    preds, ys = classes_pred_and_truth
    for i in range(ys.shape[1]):
        try:
            RocCurveDisplay.from_predictions(ys[:,i], preds[:,i])
            plt.title(f'ROC Curve for {findings[i]}')
            plt.savefig(f'roc_curve_class_{findings[i]}.png')
            mlflow.log_artifact(f'roc_curve_class_{findings[i]}.png', artifact_path='roc_curves', model_id=best_model.model_id)
            plt.close()
        except Exception as e:
            print(f"Could not plot ROC curve for class {findings[i]}: {e}")