# CXR - FMKD

### Foundation Model Knowledge Distillation (FMKD) for Chest X-Rays (CXR)

The aim is to develop a robust, distilled reconstruction of foundation models (FMs) that maintains performance on par with the original models and provides a practical means to explore and potentially mitigate inherent bias issues.

The focus of this work is Google’s proprietary CXR Foundation Model (CXR-FM), which was trained on 821,544 labelled, mostly private chest X-rays (CXRs). Disease detection on CXRs from the publicly available CheXpert dataset is explored first.

***First Goal:*** Leverage the embeddings from CXR-FM, the teacher model, to train a student model using Knowledge Distillation (KD). This process utilises a subset of the CheXpert dataset as a transfer set to align the features of the student model with those of the teacher model.


In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import numpy as np
import torchvision
import torchvision.transforms as T
from torchvision import models
import pytorch_lightning as pl

from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from skimage.io import imread, imsave
from tqdm import tqdm
from argparse import ArgumentParser

In [16]:
# Instantiate each DenseNet variant with randomly initialised weights.
d121 = models.densenet121()
d161 = models.densenet161()
d169 = models.densenet169()
d201 = models.densenet201()

# Report the classifier's input width (ni) and output width (no) for each variant.
for name, model in [("d121", d121), ("d161", d161), ("d169", d169), ("d201", d201)]:
    ni = model.classifier.in_features
    no = model.classifier.out_features
    print(f"{name}: ni", ni, "; no", no)

print(d121)

d121: ni 1024 ; no 1000
d161: ni 2208 ; no 1000
d169: ni 1664 ; no 1000
d201: ni 1920 ; no 1000
DenseNet(
  (features): Sequential(
    (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (denseblock1): _DenseBlock(
      (denselayer1): _DenseLayer(
        (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer2): _DenseLayer(
        (norm1