In [None]:
from transformers import (
    AutoFeatureExtractor, 
    AutoTokenizer, 
    VisionEncoderDecoderModel,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer, 
    default_data_collator,
)

from torch.utils.data import Dataset

import pandas as pd
from sklearn.model_selection import train_test_split

from pathlib import Path
from PIL import Image

In [None]:
# Reports CSV: later cells read its 'uid' and 'findings' columns.
df1 = pd.read_csv(
    '/kaggle/input/chest-xrays-indiana-university/indiana_reports.csv'
)
# Projections CSV: later cells read its 'uid' and 'filename' columns.
df2 = pd.read_csv(
    '/kaggle/input/chest-xrays-indiana-university/indiana_projections.csv'
)

In [None]:
# Pair every projection image (df2) with the 'findings' text of its report (df1).
#
# The table is built with a single merge instead of appending one row at a
# time with pd.concat, which copies the whole frame on every iteration.
# Row selection:
#   * only the first report row per uid is used,
#   * projections whose uid has no report are dropped (inner join),
#   * rows whose findings value is missing (NaN) are dropped.
# An inner merge keeps the order of the left-hand keys, so the rows come out
# in df2 order.
first_report_per_uid = df1.drop_duplicates(subset='uid', keep='first')[['uid', 'findings']]

paired = df2[['uid', 'filename']].merge(
    first_report_per_uid,
    on='uid',
    how='inner',
    validate='many_to_one',  # several projections may share one report, never the reverse
)

images_captions_df = (
    paired
    .dropna(subset=['findings'])
    .rename(columns={'filename': 'imgs', 'findings': 'captions'})
    [['imgs', 'captions']]
    .reset_index(drop=True)
)
images_captions_df.head()

In [None]:
# Hub checkpoint names: the feature extractor is loaded from the encoder
# checkpoint, the tokenizer from the decoder checkpoint.
encoder_checkpoint = "google/vit-base-patch16-224-in21k"
decoder_checkpoint = "Molkaatb/ChestX"

tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_checkpoint)

In [None]:
# Turn the bare file names in 'imgs' into full paths under the image folder.
p = '/kaggle/input/chest-xrays-indiana-university/images/images_normalized/'

# Only prefix entries that do not already start with the folder, so running
# this cell a second time does not prepend the folder twice.
needs_prefix = ~images_captions_df['imgs'].astype(str).str.startswith(p)
images_captions_df.loc[needs_prefix, 'imgs'] = p + images_captions_df.loc[needs_prefix, 'imgs']
images_captions_df.head()

In [None]:
# maximum length for the captions (also read by LoadDataset.__getitem__)
max_length = 1024
sample = images_captions_df.iloc[99]

# sample image
image = Image.open(sample['imgs']).convert('RGB')
# sample caption
caption = sample['captions']

# apply feature extractor on the sample image
inputs = feature_extractor(images=image, return_tensors='pt')
# apply tokenizer with the same settings LoadDataset.__getitem__ uses, so this
# preview reflects what the dataset will actually produce
outputs = tokenizer(
    caption,
    max_length=max_length,
    truncation=True,
    padding='max_length',
    return_tensors='pt',
)
# length of the padded/truncated token id sequence of the single sample
print(len(outputs['input_ids'][0]))

In [None]:
class LoadDataset(Dataset):
    """Dataset that pairs image files with their caption text.

    Each item is a dict with two entries:
      * 'pixel_values' - the image after the module-level `feature_extractor`
      * 'labels'       - the caption's token ids from the module-level
                         `tokenizer`, padded/truncated to `max_length`

    Args:
        df: frame with an 'imgs' column (image file paths) and a 'captions'
            column (caption strings).
    """

    def __init__(self, df):
        # Read from the frame that was passed in, not from the global
        # images_captions_df, so each split's dataset only holds its own rows.
        self.images = df['imgs'].values
        self.captions = df['captions'].values

    def __getitem__(self, idx):
        """Load, preprocess and return the sample at position `idx`."""
        # everything to return is stored inside this dict
        inputs = dict()

        # load the image and apply feature_extractor
        image_path = str(self.images[idx])
        image = Image.open(image_path).convert("RGB")
        image = feature_extractor(images=image, return_tensors='pt')

        # load the caption and apply tokenizer
        caption = self.captions[idx]
        labels = tokenizer(
            caption,
            max_length=max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt',
        )['input_ids'][0]

        # store the inputs and labels in the dict we created
        inputs['pixel_values'] = image['pixel_values'].squeeze()
        inputs['labels'] = labels
        return inputs

    def __len__(self):
        """Number of samples in this dataset."""
        return len(self.images)

In [None]:
# Hold out 10% of the image/caption pairs as the test split (fixed seed).
train_, test_df = train_test_split(
    images_captions_df,
    test_size=0.10,
    random_state=42,
    shuffle=True,
)


In [None]:
# Split the remaining rows again: 10% of them become the validation split.
train_df, val_df = train_test_split(
    train_,
    test_size=0.10,
    random_state=42,
    shuffle=True,
)


In [None]:
# Row counts of the train, validation and test splits, in that order.
for split_frame in (train_df, val_df, test_df):
    print(len(split_frame))

In [None]:
# Wrap each split in a LoadDataset (constructed in train, test, val order).
train_ds, test_ds, val_ds = (
    LoadDataset(split_frame) for split_frame in (train_df, test_df, val_df)
)

In [None]:
# Preview the first five rows of the test split.
test_df.head(n=5)


In [None]:
# Load the captioning checkpoint and move it to the GPU.
model = VisionEncoderDecoderModel.from_pretrained("Molkaatb/ChestX")
model = model.to('cuda')

# Config values set after loading: beam count, then the special token ids
# taken from the tokenizer.
model.config.num_beams = 4
model.config.pad_token_id = tokenizer.pad_token_id
model.config.decoder_start_token_id = tokenizer.bos_token_id
# model.config.vocab_size = model.config.decoder.vocab_size

In [None]:
import tqdm 
# Generate one caption per image path in the validation frame.
predicted_captions = []
for image_path in tqdm.tqdm(val_df['imgs']):
    rgb_image = Image.open(image_path).convert("RGB")
    encoded = feature_extractor(rgb_image, return_tensors="pt")
    pixel_batch = encoded.pixel_values.to("cuda")
    generated_ids = model.generate(pixel_batch, max_length=1024)
    # Only the first returned sequence is decoded.
    predicted_captions.append(
        tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    )
print(len(predicted_captions))

In [None]:
import nltk
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction

# Corpus-level BLEU of the generated captions against the validation captions.
#
# corpus_bleu expects:
#   * list_of_references: for every sample, a LIST of reference token lists
#   * hypotheses:         for every sample, ONE token list
# There is one reference per image, so each caption is wrapped in its own
# one-element list. Iterating over the caption string itself would split it
# into characters.
ground_truth_captions = [[caption.split()] for caption in val_df['captions'].values]
generated_captions = [caption.split() for caption in predicted_captions]

# Define the smoothing function to use
smoothie = SmoothingFunction().method4

# Compute the BLEU score with smoothing, scoring the tokenised hypotheses
# (not the raw caption strings)
weights = (0.25, 0.25, 0.25, 0.25)  # equal weights for 1-4 gram BLEU scores
score = corpus_bleu(
    ground_truth_captions,
    generated_captions,
    weights=weights,
    smoothing_function=smoothie,
)
print(f'The BELU Score Is: {score}')

In [None]:
import torch

In [None]:
import matplotlib.pyplot as plt
import torch

# Qualitative check: for a few dataset positions, print the generated caption
# and the reference caption and show the preprocessed image.
for idx in range(40, 50):
    # Fetch the sample once. Every val_ds[idx] access re-opens the image file
    # and re-tokenises the caption, so the result is reused below.
    sample_item = val_ds[idx]
    inputs = sample_item['pixel_values']
    labels_tensor = sample_item['labels']

    with torch.no_grad():
        # Model prediction
        out = model.generate(
            inputs.unsqueeze(0).to('cuda'),  # add a batch axis and move inputs to GPU
            num_beams=4,
            max_length=512
        )

    # Convert token ids to string format
    decoded_out = tokenizer.decode(out[0], skip_special_tokens=True)

    # Display the result
    print(f"Prediction for index {idx}: {decoded_out}")

    # Display the image (axes reordered so the first axis becomes the last)
    plt.figure()
    plt.axis('off')
    plt.imshow(torch.permute(inputs, (1, 2, 0)))
    plt.show()
    print("\n\nActual Image and Text\n\n")

    # Display the reference text decoded from the label ids
    reference_text = tokenizer.decode(labels_tensor, skip_special_tokens=True)
    print(reference_text)

    # Display the same preprocessed image again, this time as a NumPy array
    image_array = inputs.permute(1, 2, 0).numpy()
    plt.figure()
    plt.axis('off')
    plt.imshow(image_array)
    plt.show()
