In [35]:
onColab = False

if onColab:
    ! pip install kaggle
    ! mkdir ~/.kaggle
    ! cp kaggle.json ~/.kaggle/
    ! chmod 600 ~/.kaggle/kaggle.json

In [36]:
if onColab:
    ! kaggle datasets download raddar/chest-xrays-indiana-university

In [37]:
import zipfile
import os

if onColab:
    file_name = "chest-xrays-indiana-university.zip"
    
    # extract the file from the zip
    with zipfile.ZipFile(file_name, 'r') as zip_ref:
        zip_ref.extractall("chest_xrays_data")

In [38]:
if onColab:
    !ls chest_xrays_data

In [39]:
if onColab: 
    img_dir = 'chest_xrays_data/images/images_normalized/'
    reports_dir = 'chest_xrays_data/indiana_reports.csv'
    projections_dir = 'chest_xrays_data/indiana_projections.csv'
else:
    img_dir = '/kaggle/input/chest-xrays-indiana-university/images/images_normalized/'
    reports_dir = '/kaggle/input/chest-xrays-indiana-university/indiana_reports.csv'
    projections_dir = '/kaggle/input/chest-xrays-indiana-university/indiana_projections.csv'

In [40]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

In [41]:
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

In [42]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2LMHeadModel, BioGptTokenizer, BioGptForCausalLM

from tqdm import tqdm
from tqdm.auto import trange

import torchvision
from torchvision import transforms as T

In [43]:
# for BioGPT tokenizer
!pip install sacremoses



In [44]:
torch.backends.cudnn.benchmark = True

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

print(f"Using device: {torch.cuda.get_device_name(0)}" if torch.cuda.is_available() else "Using CPU")

Using device: Tesla P100-PCIE-16GB


#### **Preprocessing**

In [45]:
reports_df = pd.read_csv(reports_dir)
reports_df.head()

Unnamed: 0,uid,MeSH,Problems,image,indication,comparison,findings,impression
0,1,normal,normal,Xray Chest PA and Lateral,Positive TB test,None.,The cardiac silhouette and mediastinum size ar...,Normal chest x-XXXX.
1,2,Cardiomegaly/borderline;Pulmonary Artery/enlarged,Cardiomegaly;Pulmonary Artery,"Chest, 2 views, frontal and lateral",Preop bariatric surgery.,None.,Borderline cardiomegaly. Midline sternotomy XX...,No acute pulmonary findings.
2,3,normal,normal,Xray Chest PA and Lateral,"rib pain after a XXXX, XXXX XXXX steps this XX...",,,"No displaced rib fractures, pneumothorax, or p..."
3,4,"Pulmonary Disease, Chronic Obstructive;Bullous...","Pulmonary Disease, Chronic Obstructive;Bullous...","PA and lateral views of the chest XXXX, XXXX a...",XXXX-year-old XXXX with XXXX.,None available,There are diffuse bilateral interstitial and a...,1. Bullous emphysema and interstitial fibrosis...
4,5,Osteophyte/thoracic vertebrae/multiple/small;T...,Osteophyte;Thickening;Lung,Xray Chest PA and Lateral,Chest and nasal congestion.,,The cardiomediastinal silhouette and pulmonary...,No acute cardiopulmonary abnormality.


In [46]:
projections_df = pd.read_csv(projections_dir)
projections_df.head()

Unnamed: 0,uid,filename,projection
0,1,1_IM-0001-4001.dcm.png,Frontal
1,1,1_IM-0001-3001.dcm.png,Lateral
2,2,2_IM-0652-1001.dcm.png,Frontal
3,2,2_IM-0652-2001.dcm.png,Lateral
4,3,3_IM-1384-1001.dcm.png,Frontal


In [47]:
# filter the rows with null findings
reports_filtered = reports_df.dropna(subset=["findings"])

# keep only entries in projections that have a filtered report associated (association through uid)
projections_filtered = projections_df[projections_df["uid"].isin(reports_filtered["uid"])]
reports_filtered.shape, projections_filtered.shape

((3337, 8), (6469, 3))

In [48]:
VAL_SIZE = 0.1

uids = reports_filtered.uid.unique()

train_ds, val_ds = train_test_split(
    uids,
    test_size=VAL_SIZE,
    random_state=42
)

len(train_ds), len(val_ds)

(3003, 334)

#### **Load the pre-trained transformer and build the dataset**

In [49]:
def load_model_and_tokenizer(model_name="gpt2"): 
    if model_name == "BioGPT":
        tokenizer = BioGptTokenizer.from_pretrained("microsoft/biogpt") 
        tokenizer.pad_token = tokenizer.eos_token
        model = BioGptForCausalLM.from_pretrained("microsoft/biogpt").to(device)
        hidden_size = model.config.hidden_size
    else:
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        tokenizer.pad_token = tokenizer.eos_token
        model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
        hidden_size = model.config.n_embd
        
    for param in model.parameters():
        param.requires_grad = True  # Freezes all transformer parameters

    transformer_parameters= sum(p.numel() for p in model.parameters())
    print(f"Number of transformer parameters: {transformer_parameters}")

    return tokenizer, model, hidden_size

# you need to change only this variable to change the transformer!!
model_name = "BioGPT"
tokenizer, transformerModel, hidden_size = load_model_and_tokenizer(model_name)

Number of transformer parameters: 346763264


In [50]:
# adjusted dataset
class ChestXRayDataset(Dataset):
    def __init__(self, reports_df, projections_df, image_folder, tokenizer, uids, transforms):
        self.reports_df = reports_df[reports_df["uid"].isin(uids)].reset_index(drop=True)
        self.projections_df = projections_df
        self.image_folder = image_folder
        self.tokenizer = tokenizer
        # a series of transformations to be applied to images before feeding them into a model
        self.transform = transforms

    def __len__(self):
        return len(self.reports_df)

    def __getitem__(self, idx):
        row = self.reports_df.iloc[idx]
        uid = row["uid"]
        text = row["findings"]

        # tokenize findings column
        encoded_text = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=144,
            return_tensors="pt"
        )

        # find the path and filename of the associated image
        image_filename = self.projections_df[self.projections_df["uid"] == uid]["filename"].values[0]
        image_path = f"{self.image_folder}/{image_filename}"

        # load and trasform the image
        image = Image.open(image_path).convert("L")  # conversion to grayscale
        image = self.transform(image)

        # return the image, label (finding)
        return image, encoded_text["input_ids"].squeeze(0), encoded_text["attention_mask"].squeeze(0)

tf = T.Compose([
    T.Resize((224, 224)),  # resizing for pre-trained models
    T.ToTensor(),
])

train_dataset = ChestXRayDataset(reports_filtered, projections_filtered, img_dir, tokenizer, train_ds, tf)
val_dataset = ChestXRayDataset(reports_filtered, projections_filtered, img_dir, tokenizer, val_ds, tf)

In [51]:
BATCH_SIZE = 16

# create the DataLoader to generate batches of the dataset and iterate over them
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=4, pin_memory=True)

In [52]:
def conv_layer(n_input, n_output, kernel_size, stride=1):
    return nn.Sequential(
        nn.Conv2d(n_input, n_output, kernel_size, stride),
        nn.ReLU(),
        nn.BatchNorm2d(n_output),
        nn.MaxPool2d(2)
    )

In [53]:
encoder = nn.Sequential(
            conv_layer(1, 64, 3),
            conv_layer(64, 128, 3),
            conv_layer(128, 256, 3),
            conv_layer(256, 512, 3)
        )

encoder.load_state_dict(torch.load("/kaggle/input/encodercnn/pytorch/default/1/encoder.pth"))
encoder.to(device)

  encoder.load_state_dict(torch.load("/kaggle/input/encodercnn/pytorch/default/1/encoder.pth"))


Sequential(
  (0): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (1): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (2): Sequential(
    (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (3): Sequential(
    (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_runni

#### **Visualize the latent space generated**

In [54]:
data_iter = iter(val_loader)
inputs, _, _ = next(data_iter)

inputs = inputs.to(device)

with torch.no_grad():
    latent_space = encoder(inputs)

inputs = inputs.cpu().numpy()
latent_space = latent_space.cpu().numpy()

for idx in range(2):
    reconstructed = latent_space[idx, 0]

    print(f"{idx+1}) Latent Space (dim={len(latent_space[idx, 0])}) -> {reconstructed}")

1) Latent Space (dim=12) -> [[ 2.3342481e+00  7.4157872e+00  6.0161729e+00  4.6807127e+00
  -2.1490312e-01 -2.1490312e-01 -2.1490312e-01 -2.1490312e-01
  -2.1490312e-01  2.8534083e+00  6.5627036e+00  7.1844749e+00]
 [ 3.2660177e+00  2.3519361e+00  1.9366018e+00  7.5438648e-01
  -2.1490312e-01 -2.1490312e-01 -2.1490312e-01  3.1695980e-01
  -2.1490312e-01 -2.1490312e-01 -2.1490312e-01  5.0425382e+00]
 [-2.2234943e-02 -2.1490312e-01 -2.1490312e-01 -2.1490312e-01
  -2.1490312e-01 -2.1490312e-01  2.7784839e+00 -1.7517197e-01
  -2.1490312e-01 -2.1490312e-01  1.4653642e+00  1.7443779e+00]
 [-2.1490312e-01 -2.1490312e-01 -1.7824513e-01  1.9602889e+00
  -2.1490312e-01  6.8956339e-03  4.6908006e-01 -1.3318159e-01
  -2.1490312e-01 -2.1490312e-01 -8.7809145e-02  4.9835712e-01]
 [-2.1490312e-01 -2.1490312e-01  3.5161749e-01 -2.1490312e-01
  -2.1490312e-01  1.2977669e+00  3.8835564e+00  1.8848686e+00
  -2.1490312e-01 -2.1490312e-01 -6.6855192e-02 -2.1490312e-01]
 [-2.1490312e-01  9.4445586e-01  2.72

### **Build and train the FF mapper model**

In [55]:
def linear_layer(dim_input, dim_output, drop_p=0.1, last=False):
    layers = [nn.Linear(dim_input, dim_output)]
    if not last:
        layers.append(nn.ReLU())
        layers.append(nn.Dropout(p=drop_p))
    return nn.Sequential(*layers)

In [56]:
class FF_mapper(nn.Module):

    def __init__(self, dim_input, dim_output):
        super().__init__()
        self.ff = nn.Sequential(
            linear_layer(dim_input, 640),
            linear_layer(640, 896),
            #linear_layer(896, 1024),
            linear_layer(896, dim_output, last=True),
            nn.LayerNorm(dim_output)
        )
        

    def forward(self, latent_space):
        # flatter, permute and stuff
        batch_size, C, H, W = latent_space.shape
        latent_space = latent_space.permute(0, 2, 3, 1)  # (1, 12, 12, 512)
        latent_space = latent_space.view(batch_size, H * W, C)  # (1, 144, 512)
        return self.ff(latent_space)

In [57]:
def soft_generate(inputs_embeds, attention_mask, labels, temperature=1.0):
    outputs = transformerModel(
        inputs_embeds=inputs_embeds, 
        attention_mask=attention_mask,
        labels=labels,
        return_dict=True
    )
    logits = outputs.logits  # [batch, seq_len, vocab_size]
    # Apply softmax with temperature to get differentiable probabilities
    soft_tokens = nn.functional.softmax(logits / temperature, dim=-1)
    return soft_tokens  # This is a differentiable approximation

In [58]:
def train(train_x, val_x, model, epochs=10):
    criterion = mse_cos_sim_loss
    alpha = 0.0    # used in mixed loss
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    history = []
    log_template = "\nEpoch {ep:03d} train_loss: {t_loss:0.4f} val_loss: {val_loss:0.4f}"

    with tqdm(desc="epoch", total=epochs) as pbar_outer:
        for epoch in range(epochs):
            #if epoch/epochs >= 0.5:
            #    alpha = 0.2
            
            train_loss = fit_epoch(model, train_x, criterion, optimizer, alpha)
            val_loss = eval_epoch(model, val_x, criterion, alpha)
            print("loss: ", train_loss)

            history.append((train_loss,val_loss))

            pbar_outer.update(1)
            tqdm.write(log_template.format(ep=epoch+1, t_loss=train_loss, val_loss=val_loss))            
        
    return history

def cosine_similarity_loss(predicted_embedding, target_embedding):
    return 1 - nn.functional.cosine_similarity(predicted_embedding, target_embedding, dim=-1).mean()

def mse_cos_sim_loss(pred, true, alpha=0.05, rescale=10):
    # try with a rescale factor of 10
    return alpha * nn.functional.mse_loss(pred, true) + (1-alpha) * rescale * cosine_similarity_loss(pred, true)

def fit_epoch(model, train_x, criterion, optimizer, alpha):
    running_loss = 0.0
    processed_data = 0

    # for epoch progress
    old_progress = -0.1
    new_progress = 0

    for idx, (images, text, attention) in enumerate(train_x):

        new_progress = idx/len(train_x)
        if (new_progress-old_progress >= 0.1):
            print(f"Epoch progress: {new_progress*100}%")
            old_progress = new_progress

        images = images.to(device)
        text = text.to(device)
        attention = attention.to(device)

        optimizer.zero_grad()

        # Get latent space representation
        with torch.no_grad():
            latent_space = encoder(images).to(device)

        # FFNN generates transformer input embeddings
        pred_embeds = model(latent_space)

        # len(pred_embeds), out => 32
        # len(pred_embeds[0]), out => 144
        # len(pred_embeds[0][0]), out => 768
        # pred_embeds.shape, out => [32, 144, 768] => [batch_size, text_dim, emb_dim]

        # generating the logits
        generated_tokens = soft_generate(pred_embeds, attention, text)

        emb_layer = transformerModel.get_input_embeddings()
        vocab_embeddings = emb_layer.weight
        true_y = emb_layer(text)
        pred_y = torch.matmul(generated_tokens, vocab_embeddings)

        #if (idx == 5):
        #    print(f"real tokens:\n{text}")
        #    print(f"pred tokens:\n{generated_text}")

        # print(f"Real Text:\n{tokenizer.decode(text[0], skip_special_tokens=True)}")
        # print(f"Generated Text:\n{tokenizer.decode(generated_text[0], skip_special_tokens=True)}")

        # Compute loss
        loss = criterion(pred_y, true_y, alpha)

        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.shape[0]
        processed_data += images.shape[0]
        
    return running_loss / processed_data

def eval_epoch(model, val_x, criterion, alpha):
    running_loss = 0.0
    processed_data = 0
    model.eval()

    with torch.no_grad():
        for images, text, attention in val_x:
            
            images = images.to(device)
            text = text.to(device)
            attention = attention.to(device)

            # using imported models to create the data we need
            latent_space = encoder(images).to(device)

            pred_embeds = model(latent_space)
        
            # generating the logits
            generated_tokens = soft_generate(pred_embeds, attention, text)

            emb_layer = transformerModel.get_input_embeddings()
            vocab_embeddings = emb_layer.weight
            true_y = emb_layer(text)
            pred_y = torch.matmul(generated_tokens, vocab_embeddings)
            
            # Compute loss
            loss = criterion(pred_y, true_y, alpha)
            
            running_loss += loss.item() * images.shape[0]
            processed_data += images.shape[0]
    
    return running_loss / processed_data

In [59]:
mapper = FF_mapper(512, hidden_size).to(device)    # hidden_size = 768 for GPT2 and 1024 for BioGPT

mapper_parameters= sum(p.numel() for p in mapper.parameters())
print(f"number of mapper parameters: {mapper_parameters}")

number of mapper parameters: 1823232


In [60]:
import time

start = time.time()
history = train(train_loader, val_loader, mapper, epochs=20)
print(f"Training duration: {(time.time() - start) / 60} (min)")

epoch:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch progress: 0.0%
Epoch progress: 10.106382978723403%


epoch:   0%|          | 0/20 [00:26<?, ?it/s]


KeyboardInterrupt: 

In [None]:
train_loss, val_loss = zip(*history)
plt.figure(figsize=(15,10))
plt.plot(train_loss, label='Train loss')
plt.plot(val_loss, label='Val loss')
plt.legend(loc='best')
plt.xlabel("epochs")
plt.ylabel("loss")
plt.plot();

In [None]:
def generate_text(inputs_embeds, attention_mask):
    return transformerModel.generate(
        inputs_embeds=inputs_embeds, 
        max_length=288,
        attention_mask=attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        no_repeat_ngram_size=2,   # avoid repetitions
        #top_k=50,   # considers only the 50 most probable words
        eos_token_id=None,
        do_sample=False
    )

In [64]:
data_iter = iter(train_loader)
image, text, attention = next(data_iter)

print(f"Real Text:\n{tokenizer.decode(text[0], skip_special_tokens=True)}\n\n")

image = image.to(device)
text = text.to(device)
attention = attention.to(device)

with torch.no_grad():
    latent_space = encoder(image).to(device)
    predicted_embedding = mapper(latent_space).to(device)    

predicted_text = generate_text(predicted_embedding, attention)

print(f"Predicted Text:\n{tokenizer.decode(predicted_text[0], skip_special_tokens=True)}")

Real Text:
Apparent scarring within the lingula. Lungs are otherwise clear. No pleural effusions or pneumothoraces. Heart and mediastinum of normal size and contour.


Predicted Text:
The Effect of Different Doses of Vitamin C on the Growth of A study was carried out to evaluate the effect of different doses of vitamin C (1, 2, 4, 8, 16, 32, 64, 128, 256, and 256 mg / kg body weight) on growth of rats. It was observed that the growth was significantly (How to Avoid the Risk of Inappropriate Use of Antibiotics in the Intensive Care Unit. A Systematic Review. Part I: The Use and Timing of Antibiotic Therapy. The Role of the Microbiome. An Overview of Current Evidence. What Works. How To Avoids the Risks of Overuse of Antimicrobial Therapy in Intensive care Unit Patients. "Antibiotic


In [65]:
torch.save(mapper.state_dict(), f"ff_mapper_{model_name}.pth")