In [None]:
from transformers import (
    AutoFeatureExtractor, 
    AutoTokenizer, 
    VisionEncoderDecoderModel,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer, 
    default_data_collator,
)
import pandas as pd 
import numpy as np
from PIL import Image
import cv2
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split


In [None]:
# Indiana University chest X-ray dataset (Kaggle):
#   * df_image  - one row per image file, keyed by study `uid` (projections CSV)
#   * df_report - radiology report text per study `uid` (reports CSV)
df_image, df_report = (
    pd.read_csv(csv_path)
    for csv_path in (
        '/kaggle/input/chest-xrays-indiana-university/indiana_projections.csv',
        '/kaggle/input/chest-xrays-indiana-university/indiana_reports.csv',
    )
)

In [None]:
# Pair every image with the "findings" text of its study's report.
#
# Semantics match the original row-by-row loop:
#   * only the FIRST report row per uid is used (drop_duplicates keep='first'),
#   * images whose uid has no report, or whose findings are missing (NaN), are skipped,
#   * image order follows df_image.
# Building the frame in one shot replaces the quadratic concat-in-a-loop and
# yields a clean 0..n-1 index (the loop gave every row index 0).
findings_by_uid = (
    df_report
    .drop_duplicates(subset='uid', keep='first')
    .set_index('uid')['findings']
)
df = (
    pd.DataFrame({
        'imgs': df_image['filename'],
        'captions': df_image['uid'].map(findings_by_uid),
    })
    .dropna(subset=['captions'])
    .reset_index(drop=True)
)
df.head()

In [None]:
# Turn bare image filenames into full paths under the Kaggle image directory.
# Only entries that do not already carry the prefix are changed, so re-running
# this cell leaves the paths intact instead of prefixing the directory twice.
loc = '/kaggle/input/chest-xrays-indiana-university/images/images_normalized/'
already_prefixed = df['imgs'].str.startswith(loc, na=False)
df['imgs'] = df['imgs'].where(already_prefixed, loc + df['imgs'])
df.head()

In [None]:
# Pretrained halves of the captioning model: a ViT image encoder checkpoint
# and a GPT-2 text decoder checkpoint.
encoder_checkpoint, decoder_checkpoint = "google/vit-base-patch16-224-in21k", "gpt2"

# Image preprocessing for the encoder, tokenization for the decoder.
feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)

# Reuse the EOS token for padding so captions can be padded to a fixed length
# (GPT-2's tokenizer defines no pad token of its own).
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Walk a single example (row 99) through the preprocessing pipeline.
max_length = 128  # caption length in tokens after truncation/padding; also read by LoadDataset
sample = df.iloc[99]

caption = sample['captions']
image = Image.open(sample['imgs']).convert('RGB')

# pixel_values for the ViT encoder
inputs = feature_extractor(images=image, return_tensors='pt')
# token ids for the GPT-2 decoder, truncated/padded to exactly max_length
outputs = tokenizer(
    text=caption,
    return_tensors='pt',
    padding='max_length',
    truncation=True,
    max_length=max_length,
)

In [None]:
# Dump the feature-extractor result and the tokenizer result for the sample.
print(f"Inputs:\n{inputs}")
print(f"Outputs:\n{outputs}")

In [None]:
class LoadDataset(Dataset):
    """Torch dataset of (image, caption) pairs for VisionEncoderDecoder training.

    Each item is a dict with:
      * ``pixel_values`` - the image after ``feature_extractor`` with the
        size-1 batch dimension squeezed away;
      * ``labels`` - caption token ids truncated/padded to ``max_length``, where
        every padding position after the first one is replaced by -100.

    Relies on the notebook-level globals ``feature_extractor``, ``tokenizer`` and
    ``max_length``; they are looked up on each ``__getitem__`` call.
    """

    def __init__(self, df):
        # df needs an 'imgs' column (image paths) and a 'captions' column
        # (caption strings); items are addressed positionally.
        self.images = df['imgs'].values
        self.captions = df['captions'].values

    @staticmethod
    def _mask_padding(input_ids, attention_mask):
        """Return a copy of ``input_ids`` with trailing padding set to -100.

        The notebook sets ``tokenizer.pad_token = tokenizer.eos_token``, so the
        first padding position is kept as the end-of-caption target (the model
        must learn to stop). Only the positions after it are set to -100, the
        index the Hugging Face seq2seq loss ignores. Without this, padding
        dominates the loss for short captions.

        Assumes right padding (the GPT-2 tokenizer default), i.e. real tokens
        come first and ``attention_mask`` is 0 only on the trailing pads.
        ``input_ids`` is not modified.
        """
        labels = input_ids.clone()
        n_real = int(attention_mask.sum())
        labels[n_real + 1:] = -100
        return labels

    def __getitem__(self, idx):
        """Load, preprocess and tokenize the ``idx``-th (image, caption) pair."""
        inputs = dict()

        # Load the image (closing the file handle) and apply the feature extractor.
        image_path = str(self.images[idx])
        with Image.open(image_path) as img:
            rgb = img.convert("RGB")
        image = feature_extractor(images=rgb, return_tensors='pt')

        # Tokenize the caption to a fixed length, then mask the padding in the labels.
        caption = self.captions[idx]
        encoded = tokenizer(
            caption,
            max_length=max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt',
        )
        labels = self._mask_padding(encoded['input_ids'][0], encoded['attention_mask'][0])

        inputs['pixel_values'] = image['pixel_values'].squeeze()
        inputs['labels'] = labels
        return inputs

    def __len__(self):
        return len(self.images)

In [None]:
# Hold out 20% of the rows for test, then 10% of the remainder for validation
# (about 72% / 8% / 20% train / valid / test). random_state pins both shuffles.
train_df, test_df = train_test_split(df, test_size=0.2, shuffle=True, random_state=42)
train_df, valid_df = train_test_split(train_df, test_size=0.1, shuffle=True, random_state=42)

train_ds, test_ds, valid_ds = map(LoadDataset, (train_df, test_df, valid_df))

In [None]:
# Join the pretrained ViT encoder and GPT-2 decoder into one captioning model.
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_checkpoint,
    decoder_checkpoint,
)

# Special-token ids the composite model needs. decoder_start_token_id and
# pad_token_id are used to build decoder inputs from the labels. Setting
# eos_token_id explicitly lets generation stop at EOS instead of depending on
# whether the decoder's EOS id was carried over into the composite config.
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.eos_token_id = tokenizer.eos_token_id

# Default beam width for model.generate().
model.config.num_beams = 4

In [None]:
# Smoke test before training: run the first training example through the model
# as a batch of one (labels included, so a loss is computed) to check shapes.
batch = next(iter(train_ds))
model(
    labels=batch['labels'][None],
    pixel_values=batch['pixel_values'][None],
)

In [None]:
# Training schedule: 5 epochs, batch size 16 per device, lr 5e-5 with weight
# decay 0.01 (AdamW). Evaluate and checkpoint at the end of every epoch.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy`; confirm against the installed version.
training_args = Seq2SeqTrainingArguments(
    output_dir="image-caption-generator",  # checkpoints and other outputs go here
    num_train_epochs=5,
    learning_rate=5e-5,
    weight_decay=0.01,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    evaluation_strategy="epoch",
    save_strategy='epoch',
    report_to='none',  # keep wandb / mlflow / ... logging integrations off
)

# Datasets already yield fixed-size tensors, so the default collator just stacks them.
# The feature extractor is handed over through the `tokenizer` argument
# (presumably so it is saved alongside checkpoints).
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    data_collator=default_data_collator,
    tokenizer=feature_extractor,
)

In [None]:
# Fine-tune the model. Evaluation on valid_ds and checkpointing run once per
# epoch, as configured in training_args.
trainer.train()

In [None]:
import torch

# Persist the fine-tuned weights. torch.save writes a PyTorch (zip/pickle)
# archive, not HDF5, so use the conventional ".pt" extension instead of the
# misleading ".h5". Reload with: model.load_state_dict(torch.load("model.pt"))
torch.save(model.state_dict(), "model.pt")

In [None]:
import torch
import matplotlib.pyplot as plt

In [None]:
# Caption one held-out test image with the fine-tuned model, then show it.
# Named `pixel_values` (not `inputs`) so the sample `inputs` from the
# preprocessing cell is not silently overwritten.
pixel_values = test_ds[43]['pixel_values']  # already passed through feature_extractor

model.eval()
with torch.no_grad():
    out = model.generate(
        # Follow the model's device instead of hard-coding 'cuda', so the cell
        # also runs on CPU-only machines.
        pixel_values.unsqueeze(0).to(model.device),
        num_beams=4,
        # Without this, generate() falls back to GenerationConfig's short
        # default max_length (20) and cuts the findings off; allow as many
        # tokens as the training labels.
        max_length=max_length,
    )

# Convert token ids to text.
decoded_out = tokenizer.decode(out[0], skip_special_tokens=True)
print(decoded_out)

# pixel_values are normalized; undo it for display and clamp to [0, 1] so
# imshow does not clip out-of-range values.
# NOTE(review): assumes the extractor normalizes with image_mean / image_std
# (do_normalize=True) — confirm for other checkpoints.
img_mean = torch.tensor(feature_extractor.image_mean).view(-1, 1, 1)
img_std = torch.tensor(feature_extractor.image_std).view(-1, 1, 1)
display_image = (pixel_values * img_std + img_mean).clamp(0, 1)

fig, ax = plt.subplots()
ax.imshow(display_image.permute(1, 2, 0))
ax.axis('off')
plt.show()