In [None]:
!pip install --upgrade transformers
!pip install --upgrade nltk
!pip install -q evaluate
!pip install bert-score
!pip install -q rouge_score
!pip install -q git+https://github.com/salaniz/pycocoevalcap.git

In [None]:
import nltk
nltk.__version__

In [None]:
import torch

# device will determine whether to run the training on GPU or CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

In [None]:
torch.cuda.device_count()

In [None]:
def count_params(model):
  """Return the number of trainable scalar parameters (``requires_grad=True``) in ``model``."""
  total = 0
  for param in model.parameters():
    if param.requires_grad:
      total += param.numel()
  return total

In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration, BeitImageProcessor, AutoModel

# Vision
vision_module = 'microsoft/beit-base-patch16-224-pt22k-ft22k'
feature_extractor = BeitImageProcessor.from_pretrained(vision_module)

# Language
language_module = 'luqh/ClinicalT5-base'
tokenizer = T5Tokenizer.from_pretrained(language_module)

In [None]:
import glob
import torch
from torch.nn.utils.rnn import pad_sequence
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from PIL import Image
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
import torchvision.transforms as transforms
import pickle

class ImageCLEF(Dataset):
  """ImageCLEF captioning dataset pairing image files with CSV captions and concept labels.

  Each sample is a dict with keys ``image_id``, ``path``, ``captions``, ``label_id``,
  ``label_one_hot`` and ``label_name``. Images are only opened lazily in ``collate_fn``.
  """

  # Keys exposed by __getitem__, in output order.
  _SAMPLE_KEYS = ('image_id', 'path', 'captions', 'label_id', 'label_one_hot', 'label_name')

  def __init__(self, tokenizer = None,
               feature_extractor = None,
               image_folder = None,
               data_csv_path = None,
               one_hot_path = None,
               cuis_mapping = None,
               data=None):
    """Build the dataset from disk, or wrap an already-built list of samples.

    Args:
      tokenizer: text tokenizer used by collate_fn for captions and label names.
      feature_extractor: image processor used by collate_fn to produce ``pixel_values``.
      image_folder: directory whose file stems are matched against the CSV ``ID`` column.
      data_csv_path: CSV with columns ``ID``, ``non_stopwords``, ``labels``, ``CUIs``,
        ``Canonical_name``.
      one_hot_path: pickle of an object exposing ``classes_`` (presumably a fitted
        label binarizer — confirm).
      cuis_mapping: currently unused.
      data: pre-built list of sample dicts; when given, nothing is read from disk.
    """
    self.tokenizer = tokenizer
    self.feature_extractor = feature_extractor
    # Filled by load_data, or copied from the parent dataset by split_data.
    self.one_hot = None

    # When data is None, every argument needed to load from disk must be provided.
    assert (tokenizer is not None
            and feature_extractor is not None
            and image_folder is not None
            and data_csv_path is not None
            and one_hot_path is not None
           ) or data is not None, "All other arguments must be passed if data is None!"

    if data is None:
      self.load_data(image_folder, data_csv_path, one_hot_path)
    else:
      self.data = data

  def load_data(self, image_folder, data_csv_path, one_hot_path):
    """Populate ``self.one_hot`` and ``self.data`` from the pickle, the CSV and the image folder.

    Only images whose file stem (text before the first '.') appears in the CSV ``ID``
    column are kept, in ``glob`` order.

    Raises:
      ValueError: if a matched ID occurs on more than one CSV row.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(one_hot_path, 'rb') as file:
      self.one_hot = pickle.load(file)

    self.data = []
    data_csv = pd.read_csv(data_csv_path)

    # ID -> integer row positions: one hash lookup per image instead of a list scan
    # plus four full-column comparisons per image.
    rows_by_id = data_csv.groupby('ID', sort=False).indices

    for path in tqdm(glob.glob(image_folder + '/*')):
      image_id = path.split('/')[-1].split('.')[0]
      positions = rows_by_id.get(image_id)
      if positions is None:
        continue

      matched = data_csv.iloc[positions]
      # Series.item() raises ValueError unless exactly one row carries this ID.
      self.data.append({
          'image_id': image_id,
          'path' : path,
          'captions': matched['non_stopwords'].item(),
          'label_id': matched['CUIs'].item(),
          'label_one_hot' : matched['labels'].item(),
          'label_name' : matched['Canonical_name'].item()
      })

  def num_classes(self):
    """Number of classes known to the one-hot encoder."""
    return len(self.one_hot.classes_)

  def get_classes(self):
    """Class labels known to the one-hot encoder."""
    return self.one_hot.classes_

  def __getitem__(self, idx):
    """Return a shallow copy of sample ``idx`` restricted to the known sample keys."""
    sample = self.data[idx]
    return {key: sample[key] for key in self._SAMPLE_KEYS}

  def split_data(self, validation_size, random_state=42):
    """Randomly split into ``(train, validation)`` datasets.

    Both halves share this dataset's tokenizer, feature extractor and one-hot encoder,
    so ``collate_fn``, ``num_classes`` and ``get_classes`` work on them.

    Args:
      validation_size: forwarded to ``train_test_split`` as ``test_size``.
      random_state: seed for a reproducible split.
    """
    train_data, val_data = train_test_split(self.data,
                                            test_size=validation_size,
                                            random_state=random_state)
    return self._subset(train_data), self._subset(val_data)

  def _subset(self, data):
    """Wrap ``data`` in a new ImageCLEF sharing this dataset's processors and encoder."""
    subset = ImageCLEF(tokenizer=self.tokenizer,
                       feature_extractor=self.feature_extractor,
                       data=data)
    subset.one_hot = self.one_hot
    return subset

  def __len__(self):
    return len(self.data)

  def collate_fn(self, batch):
    """Turn a list of samples into a model batch.

    Opens and RGB-converts each image, runs the feature extractor, and tokenizes
    captions and label names (padded, truncated to 128 tokens).

    Returns a dict with:
      ids, raw_captions, label_one_hot, label_id, label_name: Python lists, one per sample;
      pixel_values: image tensor from the feature extractor;
      labels / attention_mask: tokenized captions;
      tokenized_label_name / tokenized_label_mask: tokenized label names.
    """
    label_one_hot = [each['label_one_hot'] for each in batch]
    label_id = [each['label_id'] for each in batch]
    # label_name appears to be the string repr of a list (e.g. "['a', 'b']");
    # strip the brackets and quotes to get "a, b".
    label_name = [each['label_name'].split("['")[-1].split("']")[0].replace("'", "") for each in batch]

    images = []
    for each in batch:
      # Close the file handle promptly; convert() returns a new image that stays usable.
      with Image.open(each['path']) as img:
        images.append(img.convert('RGB'))

    raw_captions = [each['captions'] for each in batch]
    image_ids = [each['image_id'] for each in batch]

    extracted_images = self.feature_extractor(images = images, return_tensors = 'pt')
    tokenized_captions = self.tokenizer(raw_captions, padding = True, truncation = True, max_length = 128, return_tensors = 'pt')
    tokenized_label_name = self.tokenizer(label_name, padding = True, truncation = True, max_length = 128, return_tensors = 'pt')

    return {
      'ids' : image_ids,
      'raw_captions' : raw_captions,
      'pixel_values' : extracted_images.pixel_values,
      'labels' : tokenized_captions.input_ids,
      'attention_mask' : tokenized_captions.attention_mask,
      'label_one_hot' : label_one_hot,
      'label_id' : label_id,
      'label_name' : label_name,
      'tokenized_label_name' : tokenized_label_name.input_ids,
      'tokenized_label_mask' : tokenized_label_name.attention_mask
    }

In [None]:
data = ImageCLEF(
    tokenizer = tokenizer,
    feature_extractor = feature_extractor,
    image_folder = '/kaggle/input/imageclef2024/train_images/train',
    data_csv_path = '/kaggle/input/imageclef2024/merged_captions_concepts_.csv',
#     data_csv_path = '/kaggle/input/imageclef2024/5000_merged_captions_concepts_.csv',
    one_hot_path = '/kaggle/input/imageclef2024/one_hot.pkl',
)

In [None]:
train_data, val_data = data.split_data(validation_size=0.0375)

In [None]:
from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_data, batch_size=16, shuffle=True, collate_fn=data.collate_fn)
val_dataloader = DataLoader(val_data, batch_size=16, shuffle=False, collate_fn=data.collate_fn)

In [None]:
from Captioning_Module import BeIT_T5_Concepts

In [None]:
# nn must be imported here: otherwise it is first imported in a later cell, which
# makes this cell fail with NameError on a fresh kernel.
import torch.nn as nn

# process_data saves `model.module.state_dict()`, which relies on this DataParallel wrapper.
model = nn.DataParallel(BeIT_T5_Concepts(vision_module = vision_module,
                                language_module = language_module,
                                device= device))

In [None]:
import evaluate
import numpy as np

meteor = evaluate.load('meteor')
rouge = evaluate.load('rouge')
bleu = evaluate.load("bleu")
bertscore = evaluate.load("bertscore")

In [None]:
def ignore_padding(outputs, labels, padding_values = 1):
  """Drop padded positions from each sample's outputs and labels.

  Args:
    outputs: per-sample indexable outputs (presumably logits shaped (batch, seq, vocab)
      — confirm); row r is indexed with the boolean mask of labels row r.
    labels: (batch, seq) tensor of token ids.
    padding_values: id treated as padding. NOTE(review): the default 1 is not T5's pad
      id (0); callers in this notebook always pass tokenizer.pad_token_id.

  Returns:
    (outputs_list, labels_list): per-sample tensors containing only non-padding positions.
  """
  keep = labels != padding_values
  rows = range(len(keep))
  return ([outputs[r][keep[r]] for r in rows],
          [labels[r][keep[r]] for r in rows])

def convert_prediction(outputs):
  """Reduce each per-sample logit tensor to token ids via argmax over its last dimension."""
  return [logits.argmax(dim = -1) for logits in outputs]

def convert_ids_to_string(ids, tokenizer, skip_special_tokens=False):
  """Decode a batch of per-sample token-id tensors (or logits) into strings.

  Args:
    ids: sequence (list or tuple) of per-sample tensors. If the first element is 2-D,
      every element is treated as logits and reduced to ids with convert_prediction;
      otherwise the elements are decoded as token ids directly.
    tokenizer: object with a ``decode`` method.
    skip_special_tokens: when True, passes ``skip_special_tokens=True`` to
      ``tokenizer.decode``; the default False keeps the historical call ``decode(ids)``.

  Returns:
    list of decoded strings, one per element of ``ids``; [] for empty input.
  """
  sequences = list(ids)
  if not sequences:
    return []

  if sequences[0].dim() == 2:  # logits -> token ids
    sequences = convert_prediction(sequences)

  decode_kwargs = {'skip_special_tokens': True} if skip_special_tokens else {}
  return [tokenizer.decode(seq, **decode_kwargs) for seq in sequences]

def compute_accuracy(outputs_, labels_, padding_idx):
  """Token-level accuracy of argmax predictions against reference ids.

  Args:
    outputs_: list of per-sample logit tensors whose last dim is the vocabulary
      (padding positions are expected to be removed already, e.g. by ignore_padding).
    labels_: list of per-sample token-id tensors aligned with ``outputs_``.
    padding_idx: unused; kept for interface compatibility.

  Returns:
    float in [0, 1]: matching positions / total positions, or 0.0 when there is
    nothing to score.
  """
  pairs = list(zip(convert_prediction(outputs_), labels_))
  if not pairs:
    return 0.0

  # Concatenate once per side instead of building Python lists of 0-d tensors,
  # which forces one element-wise conversion (and GPU sync) per token.
  flat_labels = torch.cat([label.reshape(-1) for _, label in pairs])
  flat_preds = torch.cat([pred.reshape(-1) for pred, _ in pairs]).to(flat_labels.device)
  if flat_labels.numel() == 0:
    return 0.0

  return (flat_preds == flat_labels).sum().item() / flat_labels.numel()

def calculate_mean(numbers):
    """Arithmetic mean of a sized sequence (raises ZeroDivisionError when empty)."""
    return sum(numbers) / len(numbers)

def compute_metrics(outputs, label_ids, tokenizer):
  """Score one batch of predictions against the reference captions.

  Positions where ``label_ids == tokenizer.pad_token_id`` are dropped from both the
  logits and the labels. The rest is decoded to text and scored with token accuracy,
  BLEU-1..4, ROUGE-L, METEOR and mean BERTScore F1. Decoded predictions and references
  are printed for inspection.

  NOTE(review): decoding uses ``tokenizer.decode`` without ``skip_special_tokens``, so
  special tokens remain in the scored text — confirm this is intended.

  Returns:
    np.ndarray ``[bert_f1_mean, accuracy, bleu1, bleu2, bleu3, bleu4, rougeL, meteor]``.
  """
  pad_id = tokenizer.pad_token_id
  kept_logits, kept_labels = ignore_padding(outputs, label_ids, pad_id)

  predictions = convert_ids_to_string(kept_logits, tokenizer)
  references = convert_ids_to_string(kept_labels, tokenizer)

  print(f'Prediction : {predictions}')
  print(f'Ground_truth : {references}')
  print('       ')

  accuracy_score = compute_accuracy(kept_logits, kept_labels, pad_id)

  text_pair = dict(predictions=predictions, references=references)
  bleu_scores = [bleu.compute(**text_pair, max_order=order)['bleu'] for order in (1, 2, 3, 4)]
  rouge_score = rouge.compute(**text_pair)['rougeL']
  meteor_score = meteor.compute(**text_pair)['meteor']

  bert_f1 = bertscore.compute(**text_pair, model_type = 'microsoft/deberta-xlarge-mnli', device = device)['f1']

  return np.array([calculate_mean(bert_f1), accuracy_score, *bleu_scores, rouge_score, meteor_score])

In [None]:
from tqdm.auto import tqdm
import os

def _optimizer_step(model, loss, optimizer, scaler=None, max_norm=1.0):
    """Backpropagate ``loss`` and update ``model`` with gradient-norm clipping.

    With a GradScaler, gradients are unscaled before clipping so the threshold
    applies to their true magnitudes.
    """
    optimizer.zero_grad()
    if scaler is not None:
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
        optimizer.step()


def _replace_checkpoint(obj, path, previous_path=None):
    """torch.save ``obj`` to ``path``, then delete the superseded ``previous_path`` if it exists."""
    torch.save(obj, path)
    if previous_path is not None and os.path.exists(previous_path):
        os.remove(previous_path)


def process_data(model, dataloader, criterion, optimizer=None, compute_metrics=None, saved_metrics = None, scaler=None, device='cpu', padding_values = 0, epoch = 0, train_mode=True, save_part = 50):
    """Run one training or evaluation pass over ``dataloader``.

    Every ``save_part`` batches, the running averages are written to
    ``{Train|Val}_Scores_BeIT5_Epoch_{epoch}_{i}.pth``. In train mode, the model weights
    are also written to ``State_dict_BeIT5_Epoch_{epoch}_{i}.pth``. The previous
    checkpoint of the same kind for this epoch is then deleted.

    Args:
        model: module called as ``model(samples)``; its outputs are unpacked into
            ``criterion`` and ``compute_metrics``.
        dataloader: iterable of batches.
        criterion: loss callable, called as ``criterion(*outputs)``.
        optimizer: required when ``train_mode`` is True.
        compute_metrics: required; called as ``compute_metrics(*outputs)`` (on detached
            tensors) and expected to return an array of scores.
        saved_metrics: optional dict metric-name -> list used as the template and
            history prefix for the score checkpoints. The key ``'loss'`` receives the
            loss; the key at position ``idx`` receives score ``idx - 1``, so ``'loss'``
            must come first. The caller's dict is NOT modified.
        scaler: optional GradScaler for the backward pass.
        device: device the model is moved to.
        padding_values: unused; kept for interface compatibility.
        epoch: epoch index, used only in checkpoint file names.
        train_mode: True to compute gradients and step the optimizer, False to evaluate.
        save_part: checkpoint interval in batches.

    Returns:
        tuple ``(mean_loss, *mean_scores)`` averaged over all batches.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    flag = 'Train' if train_mode else 'Val'
    total_loss = 0.0
    total_scores = 0
    num_batches = 0

    # Work on a copy: appending intermediate entries to the caller's lists would shift
    # the per-epoch indices that the caller reads afterwards.
    checkpoint_metrics = None
    if saved_metrics is not None:
        checkpoint_metrics = {k: list(v) for k, v in saved_metrics.items()}

    # Save the unwrapped weights under nn.DataParallel (.module), or the model itself otherwise.
    raw_model = getattr(model, 'module', model)

    model.to(device)
    model.train(train_mode)

    with torch.set_grad_enabled(train_mode):
        for i, samples in enumerate(tqdm(dataloader)):
            outputs = model(samples)
            loss = criterion(*outputs)

            if train_mode:
                _optimizer_step(model, loss, optimizer, scaler)

            num_batches += 1
            total_loss += loss.item()
            # Metrics only need values, not the autograd graph.
            detached = [o.detach() if torch.is_tensor(o) else o for o in outputs]
            total_scores += compute_metrics(*detached)

            if i % save_part == 0 and i != 0:
                # i is 0-based, so num_batches (= i + 1) batches contributed to the totals.
                running_loss = total_loss / num_batches
                running_scores = total_scores / num_batches
                previous_i = i - save_part if i > save_part else None

                if train_mode:
                    _replace_checkpoint(
                        raw_model.state_dict(),
                        f'State_dict_BeIT5_Epoch_{epoch}_{i}.pth',
                        None if previous_i is None else f'State_dict_BeIT5_Epoch_{epoch}_{previous_i}.pth')

                if checkpoint_metrics is not None:
                    for idx, k in enumerate(checkpoint_metrics):
                        checkpoint_metrics[k].append(running_loss if k == 'loss' else running_scores[idx - 1])
                    _replace_checkpoint(
                        checkpoint_metrics,
                        f'{flag}_Scores_BeIT5_Epoch_{epoch}_{i}.pth',
                        None if previous_i is None else f'{flag}_Scores_BeIT5_Epoch_{epoch}_{previous_i}.pth')

    if num_batches == 0:
        raise ValueError('process_data: dataloader yielded no batches')

    mean_loss = total_loss / num_batches
    # atleast_1d so a single (or scalar) score still unpacks into a flat tuple.
    mean_scores = np.atleast_1d(total_scores / num_batches)
    return (mean_loss, ) + tuple(mean_scores)

def print_metrics(metrics_dict, epoch, prefix=''):
    """Print one line of ``Name: value`` pairs, taking each metric's entry at index ``epoch``."""
    parts = []
    for name, history in metrics_dict.items():
        parts.append("{}: {:.4f}".format(name.capitalize(), history[epoch]))
    joined = ', '.join(parts)
    print(f"{prefix} -> {joined}")


def train_and_test(model, train_dataloader, test_dataloader, criterion, optimizer, compute_metrics, scaler, scheduler, epochs, padding_values, device):
  """Alternate one training pass and one validation pass per epoch, collecting metrics.

  After each epoch the scheduler is stepped with the validation loss, and both
  metric rows are printed.

  Args:
    model, criterion, optimizer, compute_metrics, padding_values, device: forwarded
      to process_data.
    train_dataloader / test_dataloader: batches for the training and validation passes.
    scaler: accepted for interface compatibility but NOT forwarded; both passes run
      with ``scaler=None`` (mixed precision is not enabled in this notebook).
    scheduler: stepped as ``scheduler.step(val_loss)`` (e.g. ReduceLROnPlateau).
    epochs: number of epochs.

  Returns:
    (train_metrics, test_metrics): dicts mapping metric name -> list with one value
    per epoch.
  """
  metric_names = ('loss', 'bert_score', 'accuracy', 'bleu1', 'bleu2', 'bleu3', 'bleu4', 'rouge', 'meteor')
  train_metrics = {name: [] for name in metric_names}
  test_metrics = {name: [] for name in metric_names}

  def run_phase(dataloader, history, epoch, train_mode):
    """Run one pass and append its ``(loss, *scores)`` to ``history``, one value per key."""
    scores = process_data(model=model,
                          dataloader=dataloader,
                          criterion=criterion,
                          optimizer=optimizer,
                          compute_metrics=compute_metrics,
                          # Pass a copy so intermediate checkpoint entries cannot shift
                          # history[k][epoch], which print_metrics reads below.
                          saved_metrics={k: list(v) for k, v in history.items()},
                          scaler=None,
                          device=device,
                          padding_values=padding_values,
                          epoch=epoch,
                          train_mode=train_mode)
    for i, k in enumerate(history):
      history[k].append(scores[i])

  for epoch in range(epochs):
    run_phase(train_dataloader, train_metrics, epoch, train_mode=True)
    run_phase(test_dataloader, test_metrics, epoch, train_mode=False)

    scheduler.step(test_metrics['loss'][-1])

    print(f'Epoch {epoch}:')
    print_metrics(train_metrics, epoch, prefix='Train')
    print_metrics(test_metrics, epoch, prefix='Test')

  return train_metrics, test_metrics

In [None]:
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau

criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = torch.optim.AdamW(model.parameters(), weight_decay=0.01, lr = 2e-05)
cm = lambda outputs, labels: compute_metrics(outputs.mT, labels, tokenizer)
scaler = torch.cuda.amp.GradScaler()
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, verbose=True)

In [None]:
epochs = 2
train_metrics, test_metrics = train_and_test(model = model,
                                             train_dataloader = train_dataloader,
                                             test_dataloader = val_dataloader,
                                             criterion = criterion,
                                             optimizer = optimizer,
                                             compute_metrics = cm,
                                             scaler = scaler,
                                             scheduler = scheduler,
                                             epochs = epochs,
                                             padding_values = tokenizer.pad_token_id,
                                             device = device)