In [1]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from transformers import BertTokenizer, BertModel
from sklearn.model_selection import train_test_split
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from torchmetrics import Accuracy, Precision, Recall, F1Score, AUROC
from torch import nn
from torchvision import models
from torch.optim.lr_scheduler import StepLR
from torchvision.models import resnet50, ResNet50_Weights
from sklearn.metrics import confusion_matrix
import torchmetrics
import seaborn as sns
import matplotlib.pyplot as plt

# WordPiece tokenizer that matches the BioBERT v1.1 checkpoint used by the text encoder
tokenizer = BertTokenizer.from_pretrained("dmis-lab/biobert-v1.1")

# Image preprocessing pipeline, applied in order:
#   1. resize every image to a fixed 256x256 grid
#   2. convert the PIL image to a float tensor
#   3. normalise the single (grayscale) channel with mean 0.5 and std 0.5
transformations = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)


class CustomDataset(Dataset):
    """Paired image / free-text dataset for binary classification.

    Each item couples one grayscale image (loaded from ``file_path``) with the
    tokenized free-text entry and integer label found at the same position.

    Args:
        image_paths: Sequence (list, array or pandas Series) of image paths,
            relative to ``file_path``.
        labels: Sequence of integer class labels, aligned with ``image_paths``.
        text_inputs: Sequence of raw text strings, aligned with ``image_paths``.
        tokenizer: Hugging Face style tokenizer; it is called with the text and
            must return ``'input_ids'`` and ``'attention_mask'`` tensors.
        transform: Optional callable applied to the PIL image.
        file_path: Root directory that ``image_paths`` are joined onto.
        max_length: Token length every text is padded / truncated to.
    """

    def __init__(self, image_paths, labels, text_inputs, tokenizer, transform=None, file_path='/scratch/pyc298/data', max_length=250):  # make sure to update the file_path
        self.image_paths = image_paths
        self.labels = labels
        self.text_inputs = text_inputs
        self.transform = transform
        self.file_path = file_path
        self.tokenizer = tokenizer
        self.max_length = max_length

    @staticmethod
    def _item_at(values, position):
        """Return the element at integer ``position``.

        pandas objects are indexed through ``.iloc`` so the lookup is always
        positional; plain ``series[idx]`` is label-based and silently depends
        on the caller having reset the index.
        """
        positional = getattr(values, 'iloc', None)
        if positional is None:
            return values[position]
        return positional[position]

    def __len__(self):
        """Number of samples, defined by the number of image paths."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, transform and tokenize the sample at position ``idx``.

        Returns:
            dict with keys ``'image'`` (transformed image, or the PIL image when
            no transform is set), ``'input_ids'``, ``'attention_mask'`` (both
            with the leading batch dimension squeezed out) and ``'labels'``
            (a ``torch.long`` scalar tensor).
        """
        # Resolve the image location relative to the dataset root
        image_path = self._item_at(self.image_paths, idx)
        full_path = os.path.join(self.file_path, image_path)

        # The context manager releases the file handle as soon as the pixels are
        # read; convert('L') returns an independent grayscale copy.
        with Image.open(full_path) as raw_image:
            image = raw_image.convert('L')  # Grayscale

        if self.transform:
            image = self.transform(image)

        text = self._item_at(self.text_inputs, idx)
        # Calling the tokenizer directly is the supported replacement for the
        # deprecated encode_plus(); the arguments are unchanged.
        encoded_dict = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        # return_tensors='pt' adds a batch dimension of 1; drop it so the
        # DataLoader can stack samples itself.
        input_ids = encoded_dict['input_ids'].squeeze(0)
        attention_mask = encoded_dict['attention_mask'].squeeze(0)

        return {
            'image': image,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': torch.tensor(self._item_at(self.labels, idx), dtype=torch.long)  # convert to tensor of long
        }

In [2]:
import numpy as np
import pandas as pd
# Seed NumPy and PyTorch (CPU and, when present, every CUDA device) for repeatable runs
np.random.seed(2024)
torch.manual_seed(2024)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(2024)

#make this a config
path_to_dir = "/scratch/pyc298/code"
csv_name = "final_cxr_free_text.csv"

dataset = pd.read_csv(os.path.join(path_to_dir, csv_name))

# 60 / 20 / 20 split: first hold out 40% of the rows, then halve that hold-out
# into validation and test. Both splits use random_state=42.
train_data, test_data = train_test_split(dataset, test_size=0.40, random_state=42)
valid_data, test_data = train_test_split(test_data, test_size=0.50, random_state=42)

# Give each split a fresh 0..n-1 index so the Dataset can look rows up by position
train_data, valid_data, test_data = (
    split.reset_index(drop=True) for split in (train_data, valid_data, test_data)
)



In [3]:
# Class balance of each split, printed in train / test / valid order
print(*(split["is_pneumonia"].value_counts() for split in (train_data, test_data, valid_data)))

is_pneumonia
0    13059
1     1408
Name: count, dtype: int64 is_pneumonia
0    4391
1     432
Name: count, dtype: int64 is_pneumonia
0    4340
1     483
Name: count, dtype: int64


In [4]:
# Wrap each split in a CustomDataset (same tokenizer and image transform for all three)
train_dataset, test_dataset, valid_dataset = (
    CustomDataset(
        image_paths=frame["image_path"],
        labels=frame["is_pneumonia"],
        text_inputs=frame["free_text"],
        tokenizer=tokenizer,
        transform=transformations,
    )
    for frame in (train_data, test_data, valid_data)
)
batch_size = 64
print(len(train_dataset), len(valid_dataset), len(test_dataset))

# Only the training loader shuffles; validation and test keep file order
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=batch_size,
    num_workers=2,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    num_workers=2,
)

14467 4823 4823


In [5]:
import torch
from torch import nn
import torchmetrics
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
import pytorch_lightning as pl
import seaborn as sns
import matplotlib.pyplot as plt


class BaseModel(pl.LightningModule):
    """Shared Lightning train / validation / test logic for binary classifiers.

    Subclasses provide ``forward(images, input_ids, attention_mask)`` returning
    raw logits; every step here calls the model that way and compares the
    result against ``batch['labels']``.

    Args:
        num_classes: Forwarded to the torchmetrics constructors (all of which
            are created with ``task='binary'``).
        learning_rate: Adam learning rate, read back from ``self.hparams``.
        loss_function: Callable ``loss(logits, float_targets)``.
    """

    def __init__(self, num_classes, learning_rate=1e-4, loss_function=nn.BCEWithLogitsLoss()):
        super().__init__()
        # The loss is an nn.Module kept as an attribute below; storing it in
        # hparams as well would pickle it into every checkpoint.
        self.save_hyperparameters(ignore=['loss_function'])

        # Loss function
        self.loss_function = loss_function

        # Metrics (one accuracy object per stage so their states never mix)
        self.train_accuracy = torchmetrics.Accuracy(num_classes=num_classes, task='binary')
        self.val_accuracy = torchmetrics.Accuracy(num_classes=num_classes, task='binary')
        self.test_accuracy = torchmetrics.Accuracy(num_classes=num_classes, task='binary')
        self.precision = torchmetrics.Precision(num_classes=num_classes, task='binary')
        self.recall = torchmetrics.Recall(num_classes=num_classes, task='binary')
        self.f1_score = torchmetrics.F1Score(num_classes=num_classes, task='binary')
        self.auroc = torchmetrics.AUROC(num_classes=num_classes, task='binary')
        self.confmat = torchmetrics.ConfusionMatrix(num_classes=num_classes, task='binary')

    def configure_optimizers(self):
        """Adam over all parameters, with the LR multiplied by 0.1 every 10 epochs."""
        optimizer = Adam(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
        return [optimizer], [scheduler]

    def _run_batch(self, batch):
        """Forward one batch and return ``(loss, probabilities, integer_targets)``.

        The loss receives raw logits and float targets. The metrics receive
        explicit sigmoid probabilities and integer targets, so torchmetrics
        never has to guess per batch whether its input is a logit, and so the
        targets have the integer dtype the binary metrics expect.
        """
        targets = batch['labels'].unsqueeze(1)
        logits = self(batch['image'], batch['input_ids'], batch['attention_mask'])
        loss = self.loss_function(logits, targets.float())
        probabilities = torch.sigmoid(logits.detach())
        return loss, probabilities, targets.long()

    def training_step(self, batch, batch_idx):
        """Compute the training loss and log loss / accuracy per step and per epoch."""
        loss, probabilities, targets = self._run_batch(batch)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        # forward() yields the batch value for on_step logging; logging the metric
        # object lets Lightning compute and reset the epoch value itself.
        self.train_accuracy(probabilities, targets)
        self.log('train_acc', self.train_accuracy, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and accumulate validation accuracy."""
        loss, probabilities, targets = self._run_batch(batch)
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.val_accuracy.update(probabilities, targets)
        self.log('val_acc', self.val_accuracy, on_epoch=True, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Accumulate the test metrics and confusion matrix for one batch."""
        loss, probabilities, targets = self._run_batch(batch)

        # Thresholded at 0.5 inside the metric
        self.confmat.update(probabilities, targets)

        self.log('test_loss', loss, prog_bar=True)

        # Precision / recall / F1 / AUROC are not averages of per-batch values,
        # so each metric accumulates state and is computed once per epoch.
        self.test_accuracy.update(probabilities, targets)
        self.log('test_acc', self.test_accuracy, on_step=False, on_epoch=True, prog_bar=True)
        self.precision.update(probabilities, targets)
        self.log('precision', self.precision, on_step=False, on_epoch=True)
        self.recall.update(probabilities, targets)
        self.log('recall', self.recall, on_step=False, on_epoch=True)
        self.f1_score.update(probabilities, targets)
        self.log('f1_score', self.f1_score, on_step=False, on_epoch=True)
        self.auroc.update(probabilities, targets)
        self.log('auroc', self.auroc, on_step=False, on_epoch=True)

    def log_confusion_matrix(self, cm):
        """Render ``cm`` as an annotated heatmap and send it to the logger.

        Args:
            cm: 2-D array of integer counts (formatted with ``'d'``).
        """
        fig, ax = plt.subplots(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
        ax.set_xlabel('Predicted Labels')
        ax.set_ylabel('True Labels')
        ax.set_title('Confusion Matrix')

        # Log to TensorBoard; skipped when no logger is attached or the logger's
        # experiment object cannot take figures.
        experiment = getattr(self.logger, 'experiment', None)
        if hasattr(experiment, 'add_figure'):
            experiment.add_figure("Confusion Matrix", fig, self.current_epoch)
        # Close only after the figure has been handed over
        plt.close(fig)

    def on_test_epoch_end(self):
        """Log the accumulated confusion matrix, then clear it for the next test run."""
        cm = self.confmat.compute().cpu().numpy()
        self.log_confusion_matrix(cm)
        # This metric is computed by hand rather than through self.log, so it
        # has to be reset by hand as well.
        self.confmat.reset()

    def set_parameter_requires_grad(self, freeze_it):
        """Freeze leading blocks of ``self.image_encoder``.

        Blocks are visited in the order conv1, layer1, layer2, layer3, layer4.
        A block is frozen while its name is found in ``freeze_it``; the walk
        stops at the first name that is not.

        Args:
            freeze_it: Container of block names, tested with ``in``.

        Note:
            ``self.image_encoder`` is not created in this class; the subclass
            must define it with the five attributes listed above.
        """
        # Freeze layers up to the specified block
        layers_to_freeze = {
            'conv1': self.image_encoder.conv1,
            'layer1': self.image_encoder.layer1,
            'layer2': self.image_encoder.layer2,
            'layer3': self.image_encoder.layer3,
            'layer4': self.image_encoder.layer4
        }

        for name, layer in layers_to_freeze.items():
            if name in freeze_it:
                for param in layer.parameters():
                    param.requires_grad = False
            else:
                break  # Stop at the first unfrozen layer
  


In [6]:
class TextModel(BaseModel):
    """Text-only classifier: BioBERT pooled output -> 256-d projection -> logits.

    Args:
        num_classes: Width of the output layer (also forwarded to ``BaseModel``).
        learning_rate: Forwarded to ``BaseModel``.
        loss_function: Forwarded to ``BaseModel``, which stores it as
            ``self.loss_function``.
    """

    def __init__(self, num_classes, learning_rate=1e-4, loss_function = nn.BCEWithLogitsLoss()):
        super().__init__(num_classes, learning_rate, loss_function)
        # The loss is an nn.Module already held as an attribute by the base
        # class; keep it out of the checkpointed hyper-parameters.
        self.save_hyperparameters(ignore=['loss_function'])

        # Text encoder setup
        self.text_encoder = BertModel.from_pretrained('dmis-lab/biobert-v1.1')
        # Read the encoder width from its config instead of hard-coding 768
        self.text_fc = nn.Linear(self.text_encoder.config.hidden_size, 256)

        # Combined layers
        # NOTE(review): there is no activation between text_fc and final_fc, so
        # the two layers compose into a single linear map - confirm this is intended.
        self.final_fc = nn.Linear(256, num_classes)

    def forward(self, images, input_ids, attention_mask):
        """Return logits computed from the text inputs only.

        ``images`` is accepted so the call signature matches what the
        ``BaseModel`` steps pass in, but it is not used.
        """
        text_outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        text_features = self.text_fc(text_outputs.pooler_output)
        return self.final_fc(text_features)

In [None]:
from pytorch_lightning import Callback
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

# Batches per training epoch (ceiling division of dataset size by batch size)
train_steps_per_epoch = len(train_dataset) // batch_size + (len(train_dataset) % batch_size > 0)

# ModelCheckpoint callback to save the best model based on the 'val_loss'.
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',  # Metric to monitor
    dirpath='model_checkpoints',  # Directory where the checkpoints will be saved
    filename='best_model-{epoch:02d}-{val_loss:.2f}',  # Filename with epoch number and validation loss
    save_top_k=1,  # Number of best models to save; 1 means save the best model only
    mode='min',  # 'min' mode means the 'val_loss' should be minimized
    verbose=True,
)


class ClearCacheCallback(Callback):
    """Release cached CUDA memory at the end of every training epoch."""

    # `on_epoch_end` is no longer a Callback hook in current Lightning releases,
    # so the stage-specific hook is used to make sure this actually runs.
    def on_train_epoch_end(self, trainer, pl_module):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            print("Cleared GPU cache")


logger = TensorBoardLogger('model_logs', name='bce_loss_Text_Full')
# Callbacks
clear_cache_callback = ClearCacheCallback()

early_stop_callback = EarlyStopping(
    monitor='val_loss',
    patience=3,  # stop after 3 validation checks without improvement
    verbose=True,
    mode='min'            # 'min' or 'max' (whether the monitored quantity should decrease or increase)
)
trainer = pl.Trainer(logger=logger,
                     log_every_n_steps=50,  # Write step-level logs every 50 training batches
                     callbacks=[early_stop_callback, clear_cache_callback, checkpoint_callback],
                     max_epochs=10,
                     devices=1,
                     accelerator="gpu" if torch.cuda.is_available() else "cpu")  # Use the GPU when one is available

model = TextModel(num_classes=1, loss_function = nn.BCEWithLogitsLoss())
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders =valid_loader)

# Evaluate the best checkpoint (lowest val_loss); the model is passed explicitly
# so the call does not depend on state left behind by fit().
trainer.test(model=model, ckpt_path=checkpoint_callback.best_model_path, dataloaders=test_loader)


# Text model on the 75-subsampled dataset (`final_cxr_free_text75.csv`)

## Getting the dataset

In [None]:
import numpy as np
import pandas as pd
# Re-seed NumPy and PyTorch so this experiment starts from the same RNG state
np.random.seed(2024)
torch.manual_seed(2024)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(2024)

#make this a config
path_to_dir = "/scratch/pyc298/code"
csv_name = "final_cxr_free_text75.csv"

dataset = pd.read_csv(os.path.join(path_to_dir, csv_name))

# 60 / 20 / 20 split: hold out 40% of the rows, then halve the hold-out into
# validation and test. Both splits use random_state=42.
train_data, test_data = train_test_split(dataset, test_size=0.40, random_state=42)
valid_data, test_data = train_test_split(test_data, test_size=0.50, random_state=42)

# Give each split a fresh 0..n-1 index so the Dataset can look rows up by position
train_data, valid_data, test_data = (
    split.reset_index(drop=True) for split in (train_data, valid_data, test_data)
)

# Class balance of each split, printed in train / test / valid order
print(*(split["is_pneumonia"].value_counts() for split in (train_data, test_data, valid_data)))

# Wrap each split in a CustomDataset (same tokenizer and image transform for all three)
train_dataset, test_dataset, valid_dataset = (
    CustomDataset(
        image_paths=frame["image_path"],
        labels=frame["is_pneumonia"],
        text_inputs=frame["free_text"],
        tokenizer=tokenizer,
        transform=transformations,
    )
    for frame in (train_data, test_data, valid_data)
)
batch_size = 64
print(len(train_dataset), len(valid_dataset), len(test_dataset))

# Only the training loader shuffles; validation and test keep file order
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=batch_size,
    num_workers=2,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    num_workers=2,
)

## Training loop

In [None]:
# ModelCheckpoint is imported here as well so this cell does not depend on an
# import made by an earlier training cell.
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning import Callback
from pytorch_lightning.loggers import TensorBoardLogger

# Batches per training epoch (ceiling division of dataset size by batch size)
train_steps_per_epoch = len(train_dataset) // batch_size + (len(train_dataset) % batch_size > 0)

# ModelCheckpoint callback to save the best model based on the 'val_loss'.
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',  # Metric to monitor
    dirpath='model_checkpoints',  # Directory where the checkpoints will be saved
    filename='best_75Text_bce_model-{epoch:02d}-{val_loss:.2f}',  # Filename with epoch number and validation loss
    save_top_k=1,  # Number of best models to save; 1 means save the best model only
    mode='min',  # 'min' mode means the 'val_loss' should be minimized
    verbose=True,
)


class ClearCacheCallback(Callback):
    """Release cached CUDA memory at the end of every training epoch."""

    # `on_epoch_end` is no longer a Callback hook in current Lightning releases,
    # so the stage-specific hook is used to make sure this actually runs.
    def on_train_epoch_end(self, trainer, pl_module):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            print("Cleared GPU cache")


logger = TensorBoardLogger('model_logs', name='bce_loss_Text_75')
# Callbacks
clear_cache_callback = ClearCacheCallback()

early_stop_callback = EarlyStopping(
    monitor='val_loss',
    patience=3,  # stop after 3 validation checks without improvement
    verbose=True,
    mode='min'            # 'min' or 'max' (whether the monitored quantity should decrease or increase)
)
trainer = pl.Trainer(logger=logger,
                     log_every_n_steps=50,  # Write step-level logs every 50 training batches
                     callbacks=[early_stop_callback, clear_cache_callback, checkpoint_callback],
                     max_epochs=10,
                     devices=1,
                     accelerator="gpu" if torch.cuda.is_available() else "cpu")  # Use the GPU when one is available

model = TextModel(num_classes=1, loss_function = nn.BCEWithLogitsLoss())
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders =valid_loader)
# Evaluate the best checkpoint (lowest val_loss); the model is passed explicitly
# so the call does not depend on state left behind by fit().
trainer.test(model=model, ckpt_path=checkpoint_callback.best_model_path, dataloaders=test_loader)
