Imports

In [None]:
!pip install timm
!pip install transformers

In [None]:
from IPython import display as ipythondisplay
from torch import nn
from tqdm.autonotebook import tqdm
from transformers import AutoTokenizer, AutoModel
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer
import albumentations as A
import cv2
import gc
import itertools
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import time
import timm
import torch
import torch.nn.functional as F

  


## Some pre-processing

In [None]:
from google.colab import drive
# Mount Google Drive so the shared dataset folder is reachable, then load the
# raw metadata table that the next cell turns into (image, caption) pairs.
drive.mount("/content/drive")
df = pd.read_csv("/content/drive/Shareddrives/DeepLearning/datav2/dataset.csv")

Mounted at /content/drive


In [None]:

# Consolidate the raw metadata table into the two-column (image, caption)
# layout that the rest of the notebook expects.

# Build one free-text caption per row from the metadata columns; remove or add
# terms in the concatenation below for simpler or richer captions. Template
# (each <Column> is that column's value rendered with str()):
#   "a <Plane> plane <Core_Modality> scan of <Location> <Category>
#    on a <Patient_Age> year old <Patient_Gender> depicting <Case_Diagnosis>"
# Column-wise (vectorized) concatenation instead of a row-by-row df.apply.
df["grouped_diag"] = (
    "a " + df["Plane"].astype(str) + " plane "
    + df["Core_Modality"].astype(str) + " scan"
    + " of " + df["Location"].astype(str) + " " + df["Category"].astype(str)
    + " on a " + df["Patient_Age"].astype(str) + " year old " + df["Patient_Gender"].astype(str)
    + " depicting " + df["Case_Diagnosis"].astype(str)
)

# Reshape to the target schema, e.g.:
#    image              caption
# 0  synpic100377.jpg   a Multiple or Montage plane CT scan of ...
df.rename(columns={"filename": "image", "grouped_diag": "caption"}, inplace=True)
needed_cols = ["image", "caption"]
# Select by list so a missing column raises KeyError here, instead of being
# silently dropped and only surfacing later when the CSV is consumed.
df = df[needed_cols]

image_path = "/content/drive/Shareddrives/DeepLearning/datav2/output"
captions_path = "/content/drive/Shareddrives/DeepLearning/datav2"

# Persist the consolidated table to disk: make_train_valid_dfs() re-reads it
# from f"{CFG.captions_path}/captions.csv".
df.to_csv(captions_path + "/captions.csv", index=False)

df.head()

Unnamed: 0,image,caption
0,synpic100377.jpg,a Multiple or Montage plane CT scan of Musculo...
1,synpic100378.jpg,a Multiple or Montage plane CT scan of Musculo...
2,synpic100379.jpg,a Multiple or Montage plane CT scan of Musculo...
3,synpic100380.jpg,a Oblique plane XR scan of Musculoskeletal Tra...
4,synpic100381.jpg,a Multiple or Montage plane MR scan of Musculo...


## Config
Global configuration for the whole project.

In [None]:
class CFG:
    """Global configuration for the notebook, read as class attributes (``CFG.x``)."""
    debug = False
    # Data locations; both are copied from the module-level variables of the same name.
    image_path = image_path
    captions_path = captions_path
    batch_size = 12
    num_workers = 2
    # Learning rates; by name, presumably for the projection heads, the image
    # encoder and the text encoder respectively -- confirm against the optimizer setup.
    head_lr = 1e-3
    image_encoder_lr = 1e-4
    text_encoder_lr = 1e-5
    weight_decay = 1e-3
    # NOTE(review): presumably LR-scheduler settings (e.g. plateau patience and
    # decay factor) -- confirm where they are consumed.
    patience = 1
    factor = 0.8
    epochs = 2
    # Checkpoint paths; how each one is used is not visible in this section.
    saved_model_clinical = '/content/drive/Shareddrives/DeepLearning/datav2/withDiagnostics2.pt'
    trained_model = 'clinical_bert_weights.pt'
    # Use the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Default timm backbone name for ImageEncoder.
    model_name = 'resnet50'
    # Size of the backbone's pooled feature vector (2048 for resnet50).
    image_embedding = 2048
    # Default `model_name` of TextEncoder -- note TextEncoder actually loads
    # clinical_encoder_model whenever pretrained is True.
    text_encoder_model = "distilbert-base-uncased"
    clinical_encoder_model = "emilyalsentzer/Bio_ClinicalBERT"
    # Expected hidden size of the text encoder (768 for BERT-base-sized models).
    text_embedding = 768
    # NOTE(review): this names the DistilBERT (uncased) tokenizer while the text
    # encoder loads Bio_ClinicalBERT, which uses a different vocabulary. If the
    # tokenizer is built from this setting, token ids will not match the model's
    # embedding table -- confirm, and consider pointing it at clinical_encoder_model.
    text_tokenizer = "distilbert-base-uncased"
    # Token limit passed to the tokenizer (with truncation) in CLIPDataset.
    max_length = 200

    pretrained = True # for both image encoder and text encoder
    trainable = True # for both image encoder and text encoder
    # Presumably the temperature of the contrastive loss -- confirm.
    temperature = 1.0

    # Images are resized to size x size by get_transforms.
    size = 224

    # for projection head; used for both image and text encoders
    num_projection_layers = 1
    projection_dim = 256
    dropout = 0.1

Create globally available training and test splits (90% / 10%).

In [None]:
from sklearn.model_selection import train_test_split

def make_train_valid_dfs(test_size=0.1, random_state=77):
    """Load the consolidated captions CSV and split it into train/test frames.

    Reads ``f"{CFG.captions_path}/captions.csv"`` and performs a shuffled,
    non-stratified split.

    Args:
        test_size: fraction of rows placed in the test split; every remaining
            row goes to the training split. Defaults to 0.1.
        random_state: seed for the shuffle so the split is reproducible.
            Defaults to 77.

    Returns:
        ``(train, test)`` -- in that order, the larger training frame first.
    """
    dataframe = pd.read_csv(f"{CFG.captions_path}/captions.csv")
    # train_size is left at sklearn's default (the complement of test_size),
    # so no row can be dropped by rounding of the two fractions.
    train, test = train_test_split(
        dataframe,
        test_size=test_size,
        shuffle=True,
        random_state=random_state,
        stratify=None,
    )
    return train, test

# make_train_valid_dfs() returns (train, test): the 90% split is for training
# and the 10% split for testing, so unpack in that order.
training_df, testing_df = make_train_valid_dfs()

## Utils

In [None]:
class AvgMeter:
    """Track the running, count-weighted mean of a scalar metric."""

    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        """Clear the accumulated sum, sample count and average."""
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, count=1):
        """Record ``val`` as observed ``count`` times and refresh the average."""
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count

    def __repr__(self):
        return "{}: {:.4f}".format(self.name, self.avg)

def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Returns ``None`` when the optimizer has no parameter groups.
    """
    first_group = next(iter(optimizer.param_groups), None)
    return None if first_group is None else first_group["lr"]


## Dataset

In [None]:
# Custom dataset object. Will tokenize text and apply transforms to images before yielding them.

class CLIPDataset(torch.utils.data.Dataset):
    """Dataset of (image, caption) pairs for CLIP-style training.

    Captions are tokenized once in ``__init__``; images are read from
    ``CFG.image_path``, converted to RGB and transformed in ``__getitem__``.
    """

    def __init__(self, image_filenames, captions, tokenizer, transforms):
        """
        Args:
            image_filenames: file names relative to ``CFG.image_path``, aligned
                position-by-position with ``captions``. Both must have the same
                length, so an image with several captions must have its file
                name repeated once per caption.
            captions: caption strings.
            tokenizer: callable such as a Hugging Face tokenizer; it is called
                with the caption list and must return a mapping of field name
                to per-caption values.
            transforms: albumentations-style callable, used as
                ``transforms(image=img)["image"]``.

        Raises:
            ValueError: if ``image_filenames`` and ``captions`` differ in length.
        """
        # Store both as plain lists so [idx] is always positional. A pandas
        # Series (e.g. a column of a train_test_split output) keeps its shuffled
        # integer labels, and Series[idx] would look up by label instead,
        # pairing captions with the wrong images or raising KeyError.
        self.image_filenames = list(image_filenames)
        self.captions = list(captions)
        if len(self.image_filenames) != len(self.captions):
            raise ValueError(
                f"image_filenames ({len(self.image_filenames)}) and captions "
                f"({len(self.captions)}) must have the same length"
            )
        # Count of unreadable images that were replaced by another sample.
        # NOTE(review): with DataLoader(num_workers > 0) each worker process
        # likely updates its own copy of the dataset, so this value may stay 0
        # in the main process.
        self.skippedImgCount = 0
        self.encoded_captions = tokenizer(
            self.captions, padding=True, truncation=True, max_length=CFG.max_length
        )
        self.transforms = transforms

    def __getitem__(self, idx):
        """Return the tokenized caption fields plus ``image`` and ``caption``.

        GDrive workaround: reading an image from the mounted drive sometimes
        fails (``cv2.imread`` returns ``None``). Instead of crashing the run,
        the sample is replaced by the next readable one, wrapping around at the
        end of the dataset. The search is bounded by the dataset size, so if no
        image can be read at all a ``RuntimeError`` is raised.

        Raises:
            IndexError: if ``idx`` is out of range (keeps plain iteration over
                the dataset terminating).
            RuntimeError: if no image in the dataset could be read.
        """
        n = len(self.captions)
        if not -n <= idx < n:
            raise IndexError(f"index {idx} out of range for dataset of size {n}")

        for offset in range(n):
            pos = (idx + offset) % n
            image = cv2.imread(f"{CFG.image_path}/{self.image_filenames[pos]}")
            if image is not None:
                break
            self.skippedImgCount += 1
        else:
            raise RuntimeError(
                f"Could not read any image from {CFG.image_path} (starting at index {idx})"
            )

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transforms(image=image)["image"]

        # Build every field from `pos` (not `idx`) so tokens, caption and image
        # stay paired when a substitute sample was used.
        item = {
            key: torch.tensor(values[pos])
            for key, values in self.encoded_captions.items()
        }
        # HWC -> CHW, as float.
        item["image"] = torch.tensor(image).permute(2, 0, 1).float()
        item["caption"] = self.captions[pos]
        return item

    def __len__(self):
        return len(self.captions)



def get_transforms(mode="train"):
    """Build the albumentations pipeline applied to every image.

    Resizes to ``CFG.size`` x ``CFG.size``, then normalizes with
    ``max_pixel_value=255.0`` and albumentations' default mean/std.

    Args:
        mode: ``"train"`` or any other value for evaluation. Both modes return
            the same deterministic pipeline (no augmentation). The parameter is
            kept so train-only augmentation can be added later.

    Returns:
        An ``A.Compose`` callable, used as ``transforms(image=img)["image"]``.
    """
    # p=1.0 always applies each transform. It is equivalent to
    # always_apply=True, which newer albumentations releases deprecate.
    steps = [
        A.Resize(CFG.size, CFG.size, p=1.0),
        A.Normalize(max_pixel_value=255.0, p=1.0),
    ]
    return A.Compose(steps)



## Image Encoder

In [None]:
class ImageEncoder(nn.Module):
    """
    Encode a batch of images into fixed-size feature vectors.

    Wraps a timm backbone created with ``num_classes=0`` (no classifier head)
    and ``global_pool="avg"``, so ``forward`` returns the pooled features.
    """

    def __init__(
        self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable
    ):
        super().__init__()
        self.model = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,
            global_pool="avg",
        )
        # Freeze or unfreeze every backbone parameter in one call.
        self.model.requires_grad_(trainable)

    def forward(self, x):
        features = self.model(x)
        return features

## Text Encoder

In [None]:
class TextEncoder(nn.Module):
    """Transformer text encoder returning one embedding per sequence.

    The sequence embedding is the last-layer hidden state at position
    ``target_token_idx`` (0, the CLS token).
    """

    def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):
        super().__init__()
        # NOTE(review): `model_name` is not used. With pretrained=True the
        # Bio-ClinicalBERT checkpoint named by CFG.clinical_encoder_model is
        # loaded; with pretrained=False a randomly initialised DistilBERT
        # (default DistilBertConfig) is built instead, which is a different
        # architecture and vocabulary from the pretrained path.
        self.model = (
            AutoModel.from_pretrained(CFG.clinical_encoder_model)
            if pretrained
            else DistilBertModel(config=DistilBertConfig())
        )
        # Freeze or unfreeze every encoder parameter in one call.
        self.model.requires_grad_(trainable)

        # Position of the token whose hidden state serves as the sentence
        # embedding (the CLS token).
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        """Encode a tokenized batch; return the hidden state at ``target_token_idx``."""
        hidden_states = self.model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        return hidden_states[:, self.target_token_idx, :]