In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, f1_score, average_precision_score, accuracy_score
import pandas as pd
from torch.utils.tensorboard import SummaryWriter
import os
from collections import OrderedDict

2025-04-16 15:28:54.076767: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-16 15:28:54.162269: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-04-16 15:28:54.162330: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-04-16 15:28:54.164381: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-16 15:28:54.182365: I tensorflow/core/platform/cpu_feature_guar

# data

In [2]:
import os
import os.path
import random

import numpy as np
import pandas as pd
from typing import Dict, List
import torch
from torch.utils.data import Dataset
import tensorflow as tf

In [None]:
def load_embedding(embedding_path):
    """Read the embedding vector stored in a TFRecord file.

    Only the first record of the file is parsed; the float list stored under
    its ``'embedding'`` feature is returned as a 1-D tensor.

    Args:
        embedding_path: Path to the ``.tfrecord`` file to read.

    Returns:
        torch.Tensor: the values of the ``'embedding'`` float-list feature.

    Raises:
        ValueError: if the file contains no record at all (previously this
            surfaced as an ``UnboundLocalError`` on the return statement).
    """
    raw_dataset = tf.data.TFRecordDataset([embedding_path])

    # take(1) yields at most one record, so returning inside the loop is safe.
    for raw_record in raw_dataset.take(1):
        example = tf.train.Example()
        example.ParseFromString(raw_record.numpy())
        embedding_values = example.features.feature['embedding'].float_list.value
        return torch.tensor(embedding_values)

    raise ValueError(f"No record found in TFRecord file: {embedding_path}")

In [None]:
class MIMIC_Embed_Dataset(Dataset):
    """MIMIC-CXR dataset serving precomputed image embeddings with CheXpert labels.

    The label CSV and the metadata CSV are joined on ``(subject_id, study_id)``,
    filtered by view position, optionally reduced to the first row per patient,
    and sliced by row position into a train/valid/test partition according to
    ``split_ratio``.

    Each item is a dict with keys ``"x"`` (the embedding returned by
    ``load_embedding``) and ``"lab"`` (float32 vector with one entry per
    pathology, in sorted pathology order; NaN marks an unknown label).
    """

    pathologies = [
        "Enlarged Cardiomediastinum",
        "Cardiomegaly",
        "Lung Opacity",
        "Lung Lesion",
        "Edema",
        "Consolidation",
        "Pneumonia",
        "Atelectasis",
        "Pneumothorax",
        "Pleural Effusion",
        "Pleural Other",
        "Fracture",
        "Support Devices",
    ]

    # Fractions of rows assigned to train / valid / test (in that order).
    split_ratio = [0.8, 0.1, 0.1]

    def __init__(
        self,
        embedpath,
        csvpath,
        metacsvpath,
        views=["PA"],
        data_aug=None,
        seed=0,
        unique_patients=True,
        mode=["train", "valid", "test"][0],
    ):
        """Load both CSVs, select the requested split and build the label matrix.

        Args:
            embedpath: Root directory of the per-image ``.tfrecord`` embeddings.
            csvpath: Path to the CheXpert label CSV.
            metacsvpath: Path to the metadata CSV (must provide ``ViewPosition``,
                ``StudyDate`` and ``dicom_id``).
            views: View position(s) to keep; ``"*"`` keeps every view.
            data_aug: Stored on the instance; not used by this class.
            seed: Seed passed to ``np.random.seed``.
            unique_patients: Keep only the first row of every ``subject_id``.
            mode: One of ``"train"``, ``"valid"`` or ``"test"``.

        Raises:
            ValueError: if ``mode`` is not one of the three split names.
        """
        super().__init__()
        np.random.seed(seed)  # Reset the seed so all runs are the same.

        self.pathologies = sorted(self.pathologies)

        self.mode = mode
        self.embedpath = embedpath
        self.data_aug = data_aug
        self.csvpath = csvpath
        self.csv = pd.read_csv(self.csvpath)
        self.metacsvpath = metacsvpath
        self.metacsv = pd.read_csv(self.metacsvpath)

        self.csv = self.csv.set_index(["subject_id", "study_id"])
        self.metacsv = self.metacsv.set_index(["subject_id", "study_id"])

        self.csv = self.csv.join(self.metacsv).reset_index()

        # Keep only the desired view
        self.csv["view"] = self.csv["ViewPosition"]
        self.limit_to_selected_views(views)

        if unique_patients:
            self.csv = self.csv.groupby("subject_id").first().reset_index()

        n_row = self.csv.shape[0]
        n_train = int(n_row * self.split_ratio[0])
        n_train_valid = int(n_row * (self.split_ratio[0] + self.split_ratio[1]))
        n_test = int(n_row * self.split_ratio[-1])

        # split data into one of train / valid / test
        if self.mode == "train":
            self.csv = self.csv[:n_train]
        elif self.mode == "valid":
            self.csv = self.csv[n_train:n_train_valid]
        elif self.mode == "test":
            # Slice from an explicit start: ``[-0:]`` would select every row
            # when the test size rounds down to zero.
            self.csv = self.csv[n_row - n_test:]
        else:
            raise ValueError(
                f"attr:mode has to be one of [train, valid, test] but your input is {self.mode}"
            )

        # Own the selected rows so the label clean-up below never writes
        # through to (or warns about) a view of the full frame.
        self.csv = self.csv.copy()

        # Get our classes.
        healthy = self.csv["No Finding"] == 1
        labels = []
        for pathology in self.pathologies:
            if pathology in self.csv.columns:
                # A "No Finding" study is negative for every pathology.
                self.csv.loc[healthy, pathology] = 0
                column = self.csv[pathology].values
            else:
                # Pathology absent from the CSV: every label is unknown.
                column = np.full(len(self.csv), np.nan)
            labels.append(column)
        self.labels = np.asarray(labels).T
        self.labels = self.labels.astype(np.float32)

        # Make all the -1 values into nans to keep things simple
        self.labels[self.labels == -1] = np.nan

        # Rename pathologies
        self.pathologies = list(
            np.char.replace(self.pathologies, "Pleural Effusion", "Effusion")
        )

        # add consistent csv values

        # offset_day_int
        self.csv["offset_day_int"] = self.csv["StudyDate"]

        # patientid
        self.csv["patientid"] = self.csv["subject_id"].astype(str)

    def string(self):
        """Return a short description: class name, sample count and views."""
        return self.__class__.__name__ + " num_samples={} views={}".format(
            len(self), self.views,
        )

    def limit_to_selected_views(self, views):
        """Filter ``self.csv`` by the values in its ``view`` column.

        Missing views are replaced by ``"UNKNOWN"``. The normalised list of
        requested views is stored in ``self.views``; the wildcard ``"*"``
        disables filtering.
        """
        if type(views) is not list:
            views = [views]
        if '*' in views:
            # if you have the wildcard, the rest are irrelevant
            views = ["*"]
        self.views = views

        # missing data is unknown (assign back instead of a chained inplace
        # fillna, which pandas warns about and which stops working in 3.0)
        self.csv["view"] = self.csv["view"].fillna("UNKNOWN")

        if "*" not in views:
            self.csv = self.csv[self.csv["view"].isin(self.views)]  # Select the view

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``{"x": embedding, "lab": labels}`` for the row at position ``idx``.

        The embedding file is looked up at
        ``<embedpath>/p<first two chars of subject_id>/p<subject_id>/s<study_id>/<dicom_id>.tfrecord``.
        """
        row = self.csv.iloc[idx]
        subjectid = str(row["subject_id"])
        studyid = str(row["study_id"])
        dicom_id = str(row["dicom_id"])

        embed_file = os.path.join(
            self.embedpath,
            "p" + subjectid[:2],
            "p" + subjectid,
            "s" + studyid,
            dicom_id + ".tfrecord",
        )
        embedding = load_embedding(embed_file)
        label = self.labels[idx]  # one entry per pathology

        return {
            "x": embedding,
            "lab": label,
        }


In [5]:
embedpath = "/d/hd04/armstrong/MIMIC/data/generalized-image-embeddings-for-the-mimic-chest-x-ray-dataset-1.0/files"
csvpath = "/d/hd04/armstrong/MIMIC/data/mimic-cxr-2.0.0-chexpert.csv"
metacsvpath = "/d/hd04/armstrong/MIMIC/data/mimic-cxr-2.0.0-metadata.csv"

dataset = MIMIC_Embed_Dataset(embedpath,csvpath,metacsvpath,mode = "train")

The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  self.csv.view.fillna("UNKNOWN", inplace=True)


In [6]:
sample = dataset[1000]
sample

2025-04-16 11:19:44.132466: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2256] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


{'x': tensor([-0.6009, -2.6448,  1.1589,  ...,  0.3574,  1.4271, -2.1083]),
 'lab': array([nan,  1., nan, nan, nan,  1., nan,  1., nan, nan, nan,  1.,  0.],
       dtype=float32)}

In [None]:
from torch.utils.data import DataLoader

# Eagerly materialise the first N samples in memory, then cut them 80/20
# into a training part and a validation part.
N = 36000
subset = [dataset[sample_idx] for sample_idx in range(N)]
train_len = int(N * 0.8)
train_subset, val_subset = subset[:train_len], subset[train_len:]


def collate_fn(batch):
    """Stack a list of sample dicts into one batch dict with keys "x" and "y"."""
    features = []
    targets = []
    for item in batch:
        features.append(item['x'])
        targets.append(torch.tensor(item['lab']))
    return {
        "x": torch.stack(features),
        "y": torch.stack(targets),
    }


train_loader = DataLoader(
    train_subset, batch_size=128, shuffle=True, collate_fn=collate_fn
)
val_loader = DataLoader(
    val_subset, batch_size=128, shuffle=False, collate_fn=collate_fn
)

# The 600 samples that follow the first N serve as the test set.
test_subset = [dataset[sample_idx] for sample_idx in range(N, N + 600)]
test_loader = DataLoader(
    test_subset, batch_size=128, shuffle=False, collate_fn=collate_fn
)

In [None]:
from torch.utils.data import DataLoader, Subset

# Lazily indexed splits over `dataset`.
# The test rows (the last TEST_SIZE samples) are reserved first and excluded
# from validation. Previously the validation range (last 20%) contained every
# test index, so "test" metrics were computed on data already used for early
# stopping / model selection.
TEST_SIZE = 600

full_len = len(dataset)
indices = list(range(full_len))

test_start = max(full_len - TEST_SIZE, 0)
# Training keeps the first 80% of the data, but never reaches into the test rows.
train_len = min(int(full_len * 0.8), test_start)

train_indices = indices[:train_len]
val_indices = indices[train_len:test_start]
test_indices = indices[test_start:]

train_subset = Subset(dataset, train_indices)
val_subset = Subset(dataset, val_indices)
test_subset = Subset(dataset, test_indices)

def collate_fn(batch):
    """Stack a list of sample dicts into one batch dict with keys "x" and "y"."""
    x = torch.stack([sample['x'] for sample in batch])
    y = torch.stack([torch.tensor(sample['lab']) for sample in batch])
    return {"x": x, "y": y}

train_loader = DataLoader(train_subset, batch_size=128, shuffle=True, collate_fn=collate_fn, num_workers=8, pin_memory=True)
val_loader = DataLoader(val_subset, batch_size=128, shuffle=False, collate_fn=collate_fn, num_workers=8, pin_memory=True)
test_loader = DataLoader(test_subset, batch_size=128, shuffle=False, collate_fn=collate_fn, num_workers=8, pin_memory=True)


# models

In [None]:
class FeatureReformer(nn.Module):
    """Gated linear projection of a flat feature vector, optionally cut into patches.

    The input is projected twice (``fc`` followed by ReLU, ``gate`` followed
    by sigmoid) and the two results are multiplied element-wise. With
    ``patchify=True`` the ``proj_dim`` outputs are reshaped into ``patch_num``
    patches of ``proj_dim // patch_num`` features. The result is normalised
    over the last axis with LayerNorm or BatchNorm1d.
    """

    def __init__(self, input_dim=1376, proj_dim=512, norm_type='layer', patchify=False, patch_num=16):
        super().__init__()
        self.patchify = patchify
        self.patch_num = patch_num
        self.proj_dim = proj_dim

        # Width of the axis that gets normalised.
        if not patchify:
            self.patch_dim = proj_dim
        else:
            assert proj_dim % patch_num == 0, "proj_dim must be divisible by patch_num"
            self.patch_dim = proj_dim // patch_num

        if norm_type == 'batch':
            self.norm = nn.BatchNorm1d(self.patch_dim)
        elif norm_type == 'layer':
            self.norm = nn.LayerNorm(self.patch_dim)
        else:
            raise ValueError(f"Unsupported norm_type: {norm_type}")

        self.fc = nn.Linear(input_dim, proj_dim)
        self.gate = nn.Linear(input_dim, proj_dim)
        self.activation = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):  # x: (B, D)
        """Return (B, proj_dim), or (B, patch_num, patch_dim) when patchified."""
        gated = self.activation(self.fc(x)) * self.sigmoid(self.gate(x))

        if self.patchify:
            gated = gated.view(gated.size(0), self.patch_num, self.patch_dim)

        batchnorm_over_patches = self.patchify and isinstance(self.norm, nn.BatchNorm1d)
        if not batchnorm_over_patches:
            return self.norm(gated)

        # BatchNorm1d is applied with patches folded into the batch axis.
        batch, patches, width = gated.shape
        flat = gated.view(batch * patches, width)
        return self.norm(flat).view(batch, patches, width)


In [None]:
import torch
import torch.nn as nn

class VanillaTransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention and an MLP, each with a residual.

    A 2-D input (B, D) is treated as a one-token sequence. The output has any
    size-1 axis at position 1 squeezed away, so one-token inputs come back
    as (B, D).
    """

    def __init__(self, embed_dim, num_heads=4, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply attention and MLP sub-layers; accepts (B, D) or (B, N, D)."""
        tokens = x.unsqueeze(1) if x.dim() == 2 else x  # [B, 1, D] for flat input

        # Self-attention sub-layer.
        normed = self.norm1(tokens)
        attended, _ = self.attn(normed, normed, normed)
        tokens = attended + tokens

        # Feed-forward sub-layer.
        tokens = self.mlp(self.norm2(tokens)) + tokens
        return tokens.squeeze(1)

class gMLPBlock(nn.Module):
    """Token-mixing MLP block operating on [B, L, dim] input.

    The input is layer-normalised, expanded to ``hidden_dim`` channels, mixed
    along the sequence axis by a ``Linear(seq_len, seq_len)``, projected back
    to ``dim`` and added to the *normalised* input (the residual is taken
    after the LayerNorm).
    """

    def __init__(self, dim, hidden_dim, seq_len):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.sgu = nn.Linear(seq_len, seq_len)  # mixes along the sequence axis
        self.fc2 = nn.Linear(hidden_dim, dim)

    def forward(self, x):
        """x: [B, seq_len, dim] -> [B, seq_len, dim]."""
        x = self.norm(x)
        residual = x
        x = self.act(self.fc1(x))
        x = x.transpose(1, 2)        # [B, C, L]
        x = self.sgu(x)
        x = x.transpose(1, 2)        # [B, L, C]
        return self.fc2(x) + residual

class SwinMLPBlock(nn.Module):
    """Windowed MLP block over a flat feature vector.

    The (optionally re-projected) features are passed through a linear layer,
    reshaped into ``window_size`` tokens of ``patch_dim`` features, optionally
    rolled by one token, layer-normalised and mixed by a ``gMLPBlock``.

    Args:
        input_dim: Width of the incoming features. When ``reform_input`` is
            True this is the raw width handed to ``FeatureReformer``; otherwise
            it is the width the block itself operates on.
        proj_dim: Output width of the reformer (only used with ``reform_input``).
        window_size: Number of tokens the feature vector is split into.
        mlp_ratio: Hidden expansion factor of the inner ``gMLPBlock``.
        shift: Roll the tokens by one position before mixing and back after.
        debug: Print intermediate shapes.
        reform_input: Project the input with a patchifying ``FeatureReformer``.
        flatten_output: Return [B, window_size * patch_dim] instead of tokens.
    """

    def __init__(self, input_dim, proj_dim=512, window_size=16,
                 mlp_ratio=2.0, shift=False, debug=False,
                 reform_input=True, flatten_output=True):
        super().__init__()
        self.reform_input = reform_input
        self.flatten_output = flatten_output
        self.shift = shift
        self.debug = debug
        self.window_size = window_size

        if self.reform_input:
            self.reformer = FeatureReformer(
                input_dim=input_dim,
                proj_dim=proj_dim,
                norm_type='layer',
                patchify=True,
                patch_num=window_size  # patch_num = window_size
            )
            input_dim = proj_dim

        assert input_dim % window_size == 0, "input_dim must be divisible by window_size"
        self.patch_dim = input_dim // window_size

        self.norm = nn.LayerNorm(self.patch_dim)
        self.proj = nn.Linear(input_dim, input_dim)
        self.gmlp = gMLPBlock(self.patch_dim, int(self.patch_dim * mlp_ratio), seq_len=window_size)

    def forward(self, x):  # x: [B, input_dim] or [B, N, D]
        """Return [B, window_size * patch_dim], or [B, window_size, patch_dim] tokens."""
        if self.reform_input:
            x = self.reformer(x)  # [B, N, D]

        # Tokens coming from the reformer or from a previous non-flattening
        # layer are flattened so the linear projection sees the whole vector.
        if x.dim() == 3:
            x = x.reshape(x.size(0), -1)
        B, D = x.shape

        if self.debug:
            print(f"[SwinMLPBlock] Input shape: {x.shape}")

        x = self.proj(x)
        # Tokens on axis 1, features last: LayerNorm(patch_dim) and
        # gMLPBlock(dim=patch_dim, seq_len=window_size) both require this
        # layout. The former (B, patch_dim, window_size) view only worked when
        # patch_dim == window_size, where both layouts coincide.
        x = x.view(B, self.window_size, D // self.window_size)  # [B, N_win, patch_dim]

        if self.debug:
            print(f"  → Windowed shape: {x.shape}")

        if self.shift:
            x = torch.roll(x, shifts=1, dims=1)

        x = self.norm(x)
        x = self.gmlp(x)

        if self.shift:
            x = torch.roll(x, shifts=-1, dims=1)

        if self.flatten_output:
            x = x.reshape(B, -1)  # [B, proj_dim]

        if self.debug:
            print(f"[SwinMLPBlock] Output shape: {x.shape}")

        return x

class WindowAttention1D(nn.Module):
    """Pre-norm multi-head self-attention over a token sequence with a residual."""

    def __init__(self, embed_dim, num_heads=4, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):  # x: [B, N, D]
        """Return ``x + dropout(self_attention(norm(x)))``; shape is preserved."""
        normed = self.norm(x)
        attended, _weights = self.attn(normed, normed, normed)
        return x + self.dropout(attended)  # [B, N, D]


class CrossWindowAttention1D(nn.Module):
    """Pre-norm multi-head self-attention with a residual connection.

    Queries, keys and values are all taken from the same normalised input,
    so the computation is identical in form to ``WindowAttention1D``.
    """

    def __init__(self, embed_dim, num_heads=4, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):  # x: [B, N, D]
        """Return ``x + dropout(attention(norm(x)))``; shape is preserved."""
        skip = x
        normed = self.norm(x)
        mixed = self.attn(normed, normed, normed)[0]
        return skip + self.dropout(mixed)  # [B, N, D]


class MlTrBlock1D(nn.Module):
    """Two attention stages followed by a token-wise feed-forward network.

    With ``reform_input=True`` a flat (B, input_dim) vector is first turned
    into ``patch_num`` tokens of ``proj_dim // patch_num`` features by a
    ``FeatureReformer``; otherwise the input must already be such tokens.
    With ``flatten_output=True`` the tokens are flattened to (B, N * D) on
    the way out.
    """

    def __init__(self, input_dim=1376, proj_dim=512, patch_num=16,
                 num_heads=4, dropout=0.1, ffn_ratio=2.0,
                 reform_input=True, flatten_output=True):
        super().__init__()
        self.reform_input = reform_input
        self.flatten_output = flatten_output

        if self.reform_input:
            self.reformer = FeatureReformer(
                input_dim=input_dim,
                proj_dim=proj_dim,
                norm_type='layer',
                patchify=True,
                patch_num=patch_num
            )

        self.patch_num = patch_num
        self.patch_dim = proj_dim // patch_num
        ffn_hidden = int(self.patch_dim * ffn_ratio)

        self.window_attn = WindowAttention1D(self.patch_dim, num_heads, dropout)
        self.cross_attn = CrossWindowAttention1D(self.patch_dim, num_heads, dropout)

        self.ffn = nn.Sequential(
            nn.LayerNorm(self.patch_dim),
            nn.Linear(self.patch_dim, ffn_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ffn_hidden, self.patch_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):  # x: [B, input_dim] or [B, N_win, patch_dim]
        """Return (B, N * D) when ``flatten_output`` is set, else (B, N, D)."""
        if self.reform_input:
            expected_dim = self.reformer.fc.in_features
            assert x.shape[1] == expected_dim, \
                f"Input dim mismatch: got {x.shape[1]}, expected {expected_dim}"
            x = self.reformer(x)  # [B, patch_num, patch_dim]

        tokens = self.cross_attn(self.window_attn(x))

        # Feed-forward with residual, applied with tokens folded into the batch.
        B, N, D = tokens.shape
        tokens = self.ffn(tokens.view(B * N, D)).view(B, N, D) + tokens

        if self.flatten_output:
            return tokens.view(tokens.size(0), -1)
        return tokens


In [None]:
from functools import partial
import torch.nn as nn

class TransformerEncoderBlock(nn.Module):
    """Stack of ``num_layers`` encoder blocks of a single selectable type.

    ``encoder_type`` (case-insensitive) chooses between "vanilla", "swinmlp"
    and "mltr". For "mltr" and "swinmlp" only the first layer is built with
    ``reform_input=True`` and only the last with ``flatten_output=True``.
    ``block_kwargs`` are forwarded unchanged to every layer's constructor.
    """

    def __init__(self,
                 encoder_type="vanilla",
                 num_layers=2,
                 return_all_layers=False,
                 **block_kwargs):
        super().__init__()
        self.return_all_layers = return_all_layers

        available_blocks = {
            "vanilla": VanillaTransformerBlock,
            "swinmlp": SwinMLPBlock,
            "mltr": MlTrBlock1D,
        }

        encoder_type = encoder_type.lower()
        block_class = available_blocks.get(encoder_type)
        if block_class is None:
            raise ValueError(f"Unsupported encoder_type: {encoder_type}")

        self.layers = nn.ModuleList()
        patch_based = encoder_type in ("mltr", "swinmlp")
        last_index = num_layers - 1

        for layer_index in range(num_layers):
            if patch_based:
                layer = block_class(
                    reform_input=(layer_index == 0),
                    flatten_output=(layer_index == last_index),
                    **block_kwargs
                )
            else:
                layer = block_class(**block_kwargs)
            self.layers.append(layer)

    def forward(self, x):
        """Run the layers in order; return the last output or all of them."""
        if not self.return_all_layers:
            for layer in self.layers:
                x = layer(x)
            return x

        per_layer = []
        for layer in self.layers:
            x = layer(x)
            per_layer.append(x)
        return per_layer

In [None]:
import numpy as np

def _iter_label_vectors(dataset):
    """Yield one label vector per sample, avoiding sample loading when possible.

    Fast paths (label-only, no ``__getitem__`` call, hence no embedding I/O):
      * ``dataset.labels`` exists -> its rows are used directly;
      * ``dataset`` looks like a ``torch.utils.data.Subset`` (has ``dataset``
        and ``indices``) whose base has ``labels`` -> the selected rows.
    Both assume ``labels[i]`` is what ``dataset[i]['lab']`` would return.
    Anything else is iterated sample by sample.
    """
    if hasattr(dataset, "labels"):
        yield from dataset.labels
        return

    base = getattr(dataset, "dataset", None)
    indices = getattr(dataset, "indices", None)
    if base is not None and indices is not None and hasattr(base, "labels"):
        for idx in indices:
            yield base.labels[idx]
        return

    for sample in dataset:
        yield sample['lab']


def build_label_cooccurrence(dataset, num_labels):
    """Count how often each pair of distinct labels is positive together.

    Args:
        dataset: Iterable of sample dicts with a ``'lab'`` vector, or an object
            exposing a ``labels`` matrix (directly or through a Subset).
        num_labels: Number of labels (size of the returned square matrix).

    Returns:
        np.ndarray of shape (num_labels, num_labels); entry (i, j), i != j, is
        the number of samples where both labels equal 1. The diagonal is 0.
        Samples whose label vector is ``None`` are skipped; NaN entries never
        count as positive.
    """
    co_matrix = np.zeros((num_labels, num_labels))

    for labels in _iter_label_vectors(dataset):
        if labels is None:
            continue
        present_labels = np.where(np.asarray(labels) == 1)[0]
        if present_labels.size:
            # Increment every (i, j) pair of positive labels in one step.
            co_matrix[np.ix_(present_labels, present_labels)] += 1

    # Self co-occurrence is not counted.
    np.fill_diagonal(co_matrix, 0)
    return co_matrix

def normalize_adjacency(A):
    """Return the symmetrically normalised matrix D^{-1/2} A D^{-1/2}.

    D is the diagonal matrix of row sums of ``A``, each offset by 1e-5 so
    that rows summing to zero do not cause a division by zero.
    """
    row_degree = A.sum(axis=1) + 1e-5
    inv_sqrt_degree = np.diag(row_degree ** -0.5)
    return np.matmul(np.matmul(inv_sqrt_degree, A), inv_sqrt_degree)

def build_label_graph(dataset, num_labels):
    """Build the normalised label co-occurrence graph as a float32 tensor.

    Counts pairwise label co-occurrences over ``dataset`` and applies the
    symmetric degree normalisation of ``normalize_adjacency``.
    """
    counts = build_label_cooccurrence(dataset, num_labels)
    adjacency = normalize_adjacency(counts)
    return torch.tensor(adjacency, dtype=torch.float32)

class LabelGCN(nn.Module):
    """Two-layer graph convolution over label embeddings.

    Each layer multiplies the embeddings by the adjacency matrix ``A`` and
    applies a linear map; a ReLU sits between the two layers.

    Args:
        in_dim: Width of the incoming label embeddings.
        hidden_dim: Width after the first layer.
        out_dim: Width of the returned label embeddings.
        A: Adjacency matrix of shape [num_labels, num_labels].
    """

    def __init__(self, in_dim, hidden_dim, out_dim, A):
        super().__init__()
        # Registered as a buffer so ``.to(device)`` / ``.cuda()`` move it with
        # the module. Non-persistent keeps the state_dict keys unchanged, so
        # previously saved checkpoints still load.
        adjacency = None if A is None else torch.as_tensor(A)
        self.register_buffer("A", adjacency, persistent=False)  # [num_labels, num_labels]
        self.gcn1 = nn.Linear(in_dim, hidden_dim)
        self.gcn2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, label_embed):
        """label_embed: [num_labels, in_dim] -> [num_labels, out_dim]."""
        x = torch.matmul(self.A, label_embed)
        # torch.relu avoids depending on an ``F`` alias imported in a later cell.
        x = torch.relu(self.gcn1(x))
        x = torch.matmul(self.A, x)
        x = self.gcn2(x)
        return x

In [None]:
class Query2LabelDecoder(nn.Module):
    """Transformer decoder that scores one query per label against encoder features.

    Label queries attend to the encoder features through a stack of
    ``nn.TransformerDecoderLayer``; each decoded query is reduced to a single
    score by a shared linear classifier followed by a sigmoid.

    Args:
        embed_dim: Width of queries and encoder features.
        num_labels: Number of learned label queries.
        num_heads: Attention heads per decoder layer.
        num_layers: Number of decoder layers.
        dropout: Dropout used in the decoder layers and before the classifier.
        use_norm: Apply a LayerNorm to the decoder output.
    """

    def __init__(self, embed_dim=128, num_labels=13, num_heads=4, num_layers=2, dropout=0.1, use_norm=True):
        super().__init__()
        self.num_labels = num_labels
        self.embed_dim = embed_dim
        self.use_norm = use_norm

        # Learned queries, used when forward() receives no label_embed.
        self.label_queries = nn.Parameter(torch.randn(num_labels, embed_dim))

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dropout=dropout,
            batch_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        self.norm = nn.LayerNorm(embed_dim) if use_norm else nn.Identity()

        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(embed_dim, 1)

    def forward(self, encoder_feats, label_embed=None):
        """Return per-label probabilities of shape [B, L].

        Args:
            encoder_feats: [B, d] or [B, T, d] encoder output (used as memory).
            label_embed: Optional [L, d] label queries. When omitted, the
                module's own learned ``label_queries`` are used (they were
                previously created but never read).
        """
        B = encoder_feats.size(0)

        if encoder_feats.dim() == 2:
            encoder_feats = encoder_feats.unsqueeze(1)  # [B, 1, d]

        if label_embed is None:
            label_embed = self.label_queries

        queries = label_embed.unsqueeze(0).expand(B, -1, -1)  # [B, L, d]

        out = self.decoder(tgt=queries, memory=encoder_feats)  # [B, L, d]
        out = self.norm(out)
        out = self.dropout(out)
        logits = self.classifier(out).squeeze(-1)  # [B, L]
        probs = torch.sigmoid(logits)
        return probs

In [None]:
class MultiLabelClassifier(nn.Module):
    """Shared one-output linear head applied to every label embedding.

    ``num_labels`` is accepted by the constructor but not used.
    """

    def __init__(self, embed_dim, num_labels):
        super().__init__()
        self.linear = nn.Linear(in_features=embed_dim, out_features=1)

    def forward(self, x):
        """Map (..., embed_dim) to sigmoid scores with the last axis removed."""
        scores = self.linear(x)
        scores = scores.squeeze(-1)  # (B, L)
        return scores.sigmoid()

def masked_bce_loss(pred, target, eps=1e-8):
    """Binary cross-entropy averaged over the non-NaN targets only.

    Args:
        pred: Predicted probabilities, same shape as ``target``.
        target: Targets; NaN entries are excluded from the loss.
        eps: Added to the count of valid entries to avoid division by zero.

    Returns:
        Scalar tensor: sum of element-wise BCE over valid entries divided by
        (number of valid entries + eps).
    """
    observed = torch.isnan(target).logical_not()
    # NaNs are replaced by 0 only so BCE can be evaluated; they are masked out below.
    filled_target = torch.nan_to_num(target, nan=0.0)
    elementwise = nn.functional.binary_cross_entropy(pred, filled_target, reduction='none')
    masked_total = (elementwise * observed).sum()
    return masked_total / (observed.sum() + eps)

In [None]:
import torch
import torch.nn as nn

class EndToEndModel(nn.Module):
    """Embedding -> encoder -> (optional label GCN) -> query-to-label decoder.

    Args:
        input_dim: Width of the raw input embedding.
        proj_dim: Model width used by the encoder output and the decoder.
        num_labels: Number of labels handled by the decoder.
        A: Label adjacency matrix; required when ``use_label_gcn`` is True.
        norm_type: Normalisation used by the input ``FeatureReformer``.
        patchify: Split the projected features into ``patch_num`` tokens.
            Must be True for the "mltr" and "swinmlp" encoders.
        patch_num: Number of tokens / windows.
        encoder_type: "vanilla", "swinmlp" or "mltr" (case-insensitive).
        encoder_layers: Number of encoder blocks.
        decoder_layers: Number of decoder layers.
        num_heads: Attention heads in encoder and decoder.
        use_label_gcn: Pass the label embeddings through ``LabelGCN`` first.
        output_mode: "logits" returns the decoder output only; "features"
            returns ``(encoder_features, decoder_output)``. Note that the
            decoder output is sigmoid probabilities, despite the mode name.

    Raises:
        ValueError: for an unsupported ``encoder_type`` or when a patch-based
            encoder is requested with ``patchify=False``.
    """

    def __init__(self,
                 input_dim=1376,
                 proj_dim=512,
                 num_labels=13,
                 A=None,
                 norm_type='layer',
                 patchify=False,
                 patch_num=16,
                 encoder_type="vanilla",
                 encoder_layers=3,
                 decoder_layers=3,
                 num_heads=4,
                 use_label_gcn=True,
                 output_mode="logits"  # or "features"
                 ):
        super().__init__()

        self.patchify = patchify
        self.encoder_type = encoder_type.lower()
        self.output_mode = output_mode
        self.use_label_gcn = use_label_gcn

        incompatible_patch_structures = ["mltr", "swinmlp"]
        if not patchify and self.encoder_type in incompatible_patch_structures:
            raise ValueError(f"Encoder '{encoder_type}' requires patchify=True, but you set patchify=False.")

        if self.encoder_type == "vanilla":
            self.reformer = FeatureReformer(
                input_dim=input_dim,
                proj_dim=proj_dim,
                norm_type=norm_type,
                patchify=patchify,
                patch_num=patch_num
            )
            # The encoder must match the token width the reformer emits:
            # proj_dim without patchify, proj_dim // patch_num with it.
            # (Using proj_dim unconditionally crashed for patchify=True.)
            encoder_kwargs = {
                "embed_dim": self.reformer.patch_dim,
                "num_heads": num_heads
            }

        elif self.encoder_type == "swinmlp":
            # Every SwinMLP layer is built with input_dim == proj_dim, so the
            # raw embedding has to be projected to proj_dim (flat) before the
            # encoder. proj_dim is forwarded explicitly so the blocks do not
            # silently fall back to their own default width.
            self.reformer = FeatureReformer(
                input_dim=input_dim,
                proj_dim=proj_dim,
                norm_type=norm_type,
                patchify=False
            )
            encoder_kwargs = {
                "input_dim": proj_dim,
                "proj_dim": proj_dim,
                "window_size": patch_num
            }

        elif self.encoder_type == "mltr":
            encoder_kwargs = {
                "input_dim": input_dim,
                "proj_dim": proj_dim,
                "patch_num": patch_num,
                "num_heads": num_heads
            }

        else:
            raise ValueError(f"Unsupported encoder_type: {encoder_type}")

        self.encoder = TransformerEncoderBlock(
            encoder_type=self.encoder_type,
            num_layers=encoder_layers,
            return_all_layers=False,
            **encoder_kwargs
        )

        if self.use_label_gcn:
            assert A is not None, "Adjacency matrix A must be provided for LabelGCN"
            self.label_gcn = LabelGCN(
                in_dim=proj_dim,
                hidden_dim=proj_dim,
                out_dim=proj_dim,
                A=A
            )

        self.decoder = Query2LabelDecoder(
            embed_dim=proj_dim,
            num_labels=num_labels,
            num_heads=num_heads,
            num_layers=decoder_layers
        )

    def forward(self, x, label_embed=None):
        """Run the full model.

        Args:
            x: Batch of raw embeddings, shape (B, input_dim).
            label_embed: Label embeddings of shape (num_labels, proj_dim);
                required.

        Returns:
            The decoder output (B, num_labels), or ``(features, output)`` when
            ``output_mode == "features"``.
        """
        assert label_embed is not None, "label_embed must be provided"

        # "mltr" blocks reform the raw input themselves.
        if self.encoder_type in ("vanilla", "swinmlp"):
            x = self.reformer(x)

        x = self.encoder(x)

        # A patchified vanilla encoder returns tokens [B, N, patch_dim];
        # flatten them to [B, proj_dim] so they match the decoder width.
        if x.dim() == 3:
            x = x.reshape(x.size(0), -1)

        if self.use_label_gcn:
            label_embed = self.label_gcn(label_embed)

        probs = self.decoder(x, label_embed)

        if self.output_mode == "logits":
            return probs
        elif self.output_mode == "features":
            return x, probs
        else:
            raise ValueError(f"Unsupported output_mode: {self.output_mode}")


# train

In [None]:
import pandas as pd
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, average_precision_score

class Trainer:
    """Training loop with per-label evaluation and early stopping on macro AUC.

    Args:
        model: Module called as ``model(x, label_embed=...)`` and returning
            per-label probabilities.
        optimizer: Optimizer over the model parameters.
        patience: Number of consecutive non-improving epochs before stopping.
        device: Device the model and batches are moved to.
    """

    def __init__(self, model, optimizer, patience=5, device='cuda'):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.criterion = masked_bce_loss
        self.device = device
        self.best_model = None
        self.best_val_auc = -np.inf
        self.patience = patience
        self.counter = 0
        self.best_scores = None

    def evaluate(self, dataloader, label_embed, pathologies=None):
        """Compute per-label AUC, F1, average precision and accuracy.

        Labels with fewer than 10 non-NaN targets get ``None`` for every
        metric. A label whose valid targets contain a single class gets
        ``None`` for AUC only (ROC AUC is undefined there and scikit-learn
        raises).

        Args:
            dataloader: Iterable of batches with keys ``'x'`` and ``'y'``.
            label_embed: Label embeddings passed to the model.
            pathologies: Optional row names for the metrics table.

        Returns:
            ``(macro_auc, metrics_df)``; ``macro_auc`` is the mean of the
            defined AUCs, or NaN when no label has a defined AUC.
        """
        self.model.eval()
        if torch.is_tensor(label_embed):
            # Keep the embeddings on the same device as the model and batches.
            label_embed = label_embed.to(self.device)
        all_preds, all_targets = [], []

        with torch.no_grad():
            for batch in dataloader:
                x = batch['x'].to(self.device)
                y = batch['y'].to(self.device)

                preds = self.model(x, label_embed=label_embed)
                all_preds.append(preds.detach().cpu())
                all_targets.append(y.detach().cpu())

        all_preds = torch.cat(all_preds)
        all_targets = torch.cat(all_targets)

        num_labels = all_targets.shape[1]
        aucs, f1s, maps, accs = [], [], [], []

        for i in range(num_labels):
            mask = ~torch.isnan(all_targets[:, i])
            if mask.sum() < 10:
                aucs.append(None)
                f1s.append(None)
                maps.append(None)
                accs.append(None)
                continue

            y_true = all_targets[:, i][mask].numpy()
            y_prob = all_preds[:, i][mask].numpy()
            y_pred = (y_prob > 0.5).astype(int)

            if np.unique(y_true).size < 2:
                # roc_auc_score raises ValueError when only one class is present.
                aucs.append(None)
            else:
                aucs.append(roc_auc_score(y_true, y_prob))
            f1s.append(f1_score(y_true, y_pred))
            maps.append(average_precision_score(y_true, y_prob))
            accs.append(accuracy_score(y_true, y_pred))

        valid_aucs = [a for a in aucs if a is not None]
        # No label with a defined AUC: report NaN instead of dividing by zero.
        macro_auc = sum(valid_aucs) / len(valid_aucs) if valid_aucs else float("nan")

        index = pathologies if pathologies is not None else [f"Label_{i}" for i in range(num_labels)]
        metrics_df = pd.DataFrame({
            'AUC': aucs,
            'F1': f1s,
            'mAP': maps,
            'Accuracy': accs
        }, index=index).round(4)

        return macro_auc, metrics_df

    def train(self, train_loader, val_loader, label_embed_init, epochs=50, pathologies=None):
        """Train with early stopping and restore the best weights.

        After every epoch the model is evaluated on ``val_loader``; the state
        dict with the highest macro AUC is kept and loaded back at the end.

        Returns:
            ``(model, best_scores, loss_list, auc_list)`` where ``best_scores``
            is the metrics table of the best epoch and the lists hold the mean
            training loss and validation macro AUC per epoch.
        """
        # Local import so the class does not depend on `copy` being imported
        # by a later notebook cell.
        import copy

        self.model.train()
        label_embed_init = label_embed_init.to(self.device)

        loss_list, auc_list = [], []

        for epoch in range(epochs):
            epoch_loss = 0
            self.model.train()

            for batch in train_loader:
                x = batch['x'].to(self.device)
                y = batch['y'].to(self.device)

                self.optimizer.zero_grad()
                preds = self.model(x, label_embed=label_embed_init)
                loss = self.criterion(preds, y)
                loss.backward()
                self.optimizer.step()
                epoch_loss += loss.item()

            # max(..., 1) guards against an empty training loader.
            loss_list.append(epoch_loss / max(len(train_loader), 1))

            val_auc, val_df = self.evaluate(val_loader, label_embed_init, pathologies)
            auc_list.append(val_auc)

            print(f"Epoch {epoch+1} | Loss: {loss_list[-1]:.4f} | AUC: {val_auc:.4f}")
            display(val_df)

            if val_auc > self.best_val_auc:
                self.best_val_auc = val_auc
                self.best_model = copy.deepcopy(self.model.state_dict())
                self.best_scores = val_df
                self.counter = 0
            else:
                self.counter += 1
                if self.counter >= self.patience:
                    print("Early stopping triggered.")
                    break

        if self.best_model is not None:
            self.model.load_state_dict(self.best_model)

        return self.model, self.best_scores, loss_list, auc_list



In [None]:
import torch
import copy
import torch.nn as nn
from torch.optim import Adam
import torch.nn.functional as F


# ---- Experiment configuration ----
input_dim = 1376
proj_dim = 512
num_labels = 13
patchify = True # mltr
# patchify = False # vanilla 
patch_num = 16
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# ---- Label graph and (fixed, random) initial label embeddings ----
A = build_label_graph(train_loader.dataset, num_labels).to(device)
label_embed_init = torch.randn(num_labels, proj_dim).to(device)

# ---- Model ----
model = EndToEndModel(
    encoder_type="mltr",          # vanilla, mltr
    encoder_layers=2,
    decoder_layers=2,
    num_heads=4,
    input_dim=input_dim,
    proj_dim=proj_dim,
    patchify=patchify,
    patch_num=patch_num,
    norm_type='layer',
    num_labels=num_labels,
    use_label_gcn=True,
    A=A,
    output_mode="logits",
).to(device)

# ---- Training with early stopping ----
optimizer = Adam(model.parameters(), lr=1e-3)
trainer = Trainer(model, optimizer, patience=10, device=device)

model, metrics_df, loss_list, auc_list = trainer.train(
    train_loader=train_loader,
    val_loader=val_loader,
    label_embed_init=label_embed_init,
    epochs=100,
    pathologies=dataset.pathologies,
)

# Validation metrics of the best epoch.
display(metrics_df)


# torch.save(model.state_dict(), "end2end_model.pt")

import matplotlib.pyplot as plt

# ---- Learning curves: training loss (left) and validation AUC (right) ----
plt.figure(figsize=(12, 5))

curve_panels = [
    (loss_list, "Training Loss", "Loss"),
    (auc_list, "Validation AUC", "AUC"),
]
for panel_position, (curve, panel_title, y_label) in enumerate(curve_panels, start=1):
    plt.subplot(1, 2, panel_position)
    plt.plot(curve, label="EndToEndModel")
    plt.title(panel_title)
    plt.xlabel("Epoch")
    plt.ylabel(y_label)
    plt.grid(True)
    plt.legend()

plt.tight_layout()
plt.show()


In [None]:
# Evaluate on the held-out test loader. `pathologies` is passed (as in the
# training calls) so the rows of the metrics table are named by pathology
# instead of the generic "Label_<i>".
# `metrics` is the tuple (macro_auc, per-label metrics DataFrame).
metrics = trainer.evaluate(test_loader, label_embed_init, pathologies=dataset.pathologies)
print("Test Set Metrics:", metrics)

In [None]:
import matplotlib.pyplot as plt

# Learning curves of the last training run: training loss (left panel) and
# validation AUC (right panel), one point per epoch.
plt.figure(figsize=(12, 5))

curve_panels = [
    (loss_list, "Training Loss", "Loss"),
    (auc_list, "Validation AUC", "AUC"),
]
for panel_position, (curve, panel_title, y_label) in enumerate(curve_panels, start=1):
    plt.subplot(1, 2, panel_position)
    plt.plot(curve, label="EndToEndModel")
    plt.title(panel_title)
    plt.xlabel("Epoch")
    plt.ylabel(y_label)
    plt.grid(True)
    plt.legend()

plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
from torch.optim import Adam

# ---- Shared configuration for the encoder comparison ----
results = {}  
input_dim = 1376
proj_dim = 512
num_labels = 13
patch_num = 16
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# The label graph and the initial label embeddings are shared by all runs.
A = build_label_graph(train_loader.dataset, num_labels).to(device)
label_embed_init = torch.randn(num_labels, proj_dim).to(device)

# ---- Train one model per encoder type ----
for encoder_type in ["vanilla", "mltr"]:
    print(f"\nTraining model with encoder: {encoder_type}")

    # Only the "mltr" encoder works on patchified features.
    patchify = encoder_type == "mltr"

    model = EndToEndModel(
        encoder_type=encoder_type,
        encoder_layers=2,
        decoder_layers=2,
        num_heads=4,
        input_dim=input_dim,
        proj_dim=proj_dim,
        patchify=patchify,
        patch_num=patch_num,
        norm_type='layer',
        num_labels=num_labels,
        use_label_gcn=True,
        A=A,
        output_mode="logits",
    ).to(device)

    optimizer = Adam(model.parameters(), lr=1e-3)
    trainer = Trainer(model, optimizer, patience=5, device=device)

    model, metrics_df, loss_list, auc_list = trainer.train(
        train_loader=train_loader,
        val_loader=val_loader,
        label_embed_init=label_embed_init,
        epochs=30,
        pathologies=dataset.pathologies,
    )

    results[encoder_type] = {"loss": loss_list, "auc": auc_list, "df": metrics_df}

    print(f"Finished training: {encoder_type}")
    display(metrics_df)

# ---- Compare the learning curves of all runs ----
plt.figure(figsize=(12, 5))

comparison_panels = [
    ("loss", "Training Loss", "Loss"),
    ("auc", "Validation AUC", "AUC"),
]
for panel_position, (result_key, panel_title, y_label) in enumerate(comparison_panels, start=1):
    plt.subplot(1, 2, panel_position)
    for name in results:
        plt.plot(results[name][result_key], label=f"{name}")
    plt.title(panel_title)
    plt.xlabel("Epoch")
    plt.ylabel(y_label)
    plt.grid(True)
    plt.legend()

plt.tight_layout()
plt.show()