In [None]:
import pandas as pd
from PIL import Image
from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk
import os
import zipfile
import requests
from transformers import ViltConfig
import torch
import io

In [None]:
# Label vocabulary (id2label / label2id) used throughout the notebook;
# "vilt-Med_PMC" is presumably a local config directory — confirm.
config_source = "vilt-Med_PMC"
config = ViltConfig.from_pretrained(config_source)

In [None]:
# Inspect the label -> class-id mapping carried by the loaded config
config.label2id

In [None]:
# Load the preprocessed train/validation splits from disk and group them
# into a single DatasetDict keyed by split name.
train = load_from_disk('./PreprocessedData/train')
validation = load_from_disk('./PreprocessedData/validation')

splits = {'train': train, 'validation': validation}
dataset_dict = DatasetDict(splits)
dataset_dict

In [None]:
# Peek at the first raw training row (before processor encoding)
dataset_dict['train'][0]

In [None]:
class VQADataset(torch.utils.data.Dataset):
    """VQA dataset that turns (image, question, label) rows into ViLT model inputs.

    Each row of the wrapped ``dataset`` must provide ``'image_path'``,
    ``'question'`` and ``'label'`` keys. ``label`` is used directly as an index
    into the target vector, so it is presumably an integer class id consistent
    with ``id2label`` — confirm against the preprocessing step.

    Args:
        dataset: Indexable collection of rows (e.g. a Hugging Face ``Dataset``).
        processor: Callable ``processor(image, question, **kwargs)`` returning a
            mapping of tensors with a leading batch dimension of size 1
            (e.g. ``ViltProcessor`` with ``return_tensors="pt"``).
        id2label: Label vocabulary; only its length is used, to size the
            one-hot target vector.
    """

    def __init__(self, dataset, processor, id2label):
        self.dataset = dataset
        self.processor = processor
        self.id2label = id2label

    def __len__(self):
        return len(self.dataset)

    def _load_image(self, image_path):
        """Open ``image_path`` as an RGB image and release the underlying file handle."""
        # `convert` returns a new, fully loaded image, so the source file can be
        # closed right away instead of staying open until garbage collection.
        with Image.open(image_path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        """Return encoded inputs for row ``idx`` plus a one-hot ``labels`` tensor."""
        data = self.dataset[idx]
        # Normalise Windows-style separators stored in the preprocessed data.
        image_path = data['image_path'].replace('\\', '/')
        question = data['question']
        label = data['label']

        image = self._load_image(image_path)

        # Encode image + question; pad to max_length so examples stack in collate_fn.
        encoding = self.processor(image, question, padding="max_length", truncation=True, return_tensors="pt")

        # Drop only the leading batch dimension added by return_tensors="pt".
        # A bare squeeze() would also collapse any other size-1 dimension.
        for key in list(encoding.keys()):
            encoding[key] = encoding[key].squeeze(0)

        # One-hot target over the full label vocabulary.
        targets = torch.zeros(len(self.id2label))
        targets[label] = 1

        encoding["labels"] = targets
        return encoding

In [None]:

from transformers import ViltProcessor

# Tokenizer + image processor matching the base ViLT checkpoint
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")

# Wrap both splits so items come back as model-ready tensors with one-hot labels
train_dataset, validation_dataset = (
    VQADataset(dataset=dataset_dict[split], processor=processor, id2label=config.id2label)
    for split in ('train', 'validation')
)


In [None]:
# Number of examples in each wrapped split
len(train_dataset), len(validation_dataset)

In [None]:
# Keys of one encoded example: processor outputs plus the added "labels"
train_dataset[0].keys()

In [None]:
# Decode the first example's token ids back to text (special and padding tokens are kept)
processor.decode(train_dataset[0]['input_ids'])

In [None]:
# Recover the active class index from the first example's one-hot target
labels = train_dataset[0]['labels'].nonzero().squeeze().tolist()

config.id2label[labels]

In [None]:
# Inspect the normalised image tensor produced for the first example
train_dataset[0]['pixel_values']

In [None]:
from transformers import ViltForQuestionAnswering

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Optionally resume from a previously fine-tuned model. This directory is only
# written by `model.save_pretrained` after training, so on a fresh checkout it
# does not exist yet; skip the load instead of breaking Restart & Run All.
# NOTE: the next cell re-initialises `model` from the base checkpoint and
# overrides whatever is loaded here — run only the one you need.
finetuned_dir = "./vilt-PMC_VQA"
if os.path.isdir(finetuned_dir):
    model = ViltForQuestionAnswering.from_pretrained(finetuned_dir,
                                                     id2label=config.id2label,
                                                     label2id=config.label2id)
    model.to(device)
else:
    print(f"{finetuned_dir} not found; skipping load of the fine-tuned model.")

In [None]:
from transformers import ViltForQuestionAnswering

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the QA model from the base ViLT checkpoint with the label mapping from
# `config`. The QA head presumably starts untrained, since the source is the MLM
# checkpoint — confirm. This replaces any `model` created by the previous cell.
model = ViltForQuestionAnswering.from_pretrained(
    "dandelin/vilt-b32-mlm",
    id2label=config.id2label,
    label2id=config.label2id,
).to(device)
model

In [None]:
from torch.utils.data import DataLoader

def collate_fn(batch):
    """Combine per-example encodings from VQADataset into one model batch.

    Text tensors and one-hot labels are stacked along a new leading dimension.
    Images go through the processor's image padding, which also returns the
    matching ``pixel_mask``.
    """
    # Pad images to a common size and get the mask marking real pixels
    padded = processor.image_processor.pad(
        [example['pixel_values'] for example in batch], return_tensors="pt"
    )

    stacked = {
        key: torch.stack([example[key] for example in batch])
        for key in ('input_ids', 'attention_mask', 'token_type_ids')
    }
    stacked['pixel_values'] = padded['pixel_values']
    stacked['pixel_mask'] = padded['pixel_mask']
    stacked['labels'] = torch.stack([example['labels'] for example in batch])
    return stacked

BATCH_SIZE = 4
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
validation_dataloader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

In [None]:
# Number of batches per epoch in each split
len(train_dataloader), len(validation_dataloader)

In [None]:
# Draw one (shuffled) training batch to sanity-check collate_fn's output
batch = next(iter(train_dataloader))

In [None]:
from PIL import Image
import numpy as np

image_mean = processor.image_processor.image_mean
image_std = processor.image_processor.image_std

batch_idx = 0

# Undo per-channel normalisation x_norm = (x - mean) / std, i.e.
# x = x_norm * std + mean, broadcasting mean/std over (C, H, W).
pixels = batch["pixel_values"][batch_idx].numpy()
unnormalized_image = pixels * np.array(image_std)[:, None, None] + np.array(image_mean)[:, None, None]
# Assumes pixels were rescaled to [0, 1] before normalisation — TODO confirm.
# Clip so that out-of-range values saturate instead of wrapping in the uint8 cast.
unnormalized_image = (np.clip(unnormalized_image, 0.0, 1.0) * 255).astype(np.uint8)
# Channels-first (C, H, W) -> channels-last (H, W, C), as PIL expects
unnormalized_image = np.moveaxis(unnormalized_image, 0, -1)
Image.fromarray(unnormalized_image)

In [None]:
# Decode the question of the batch element visualised above (special/padding tokens kept)
processor.decode(batch["input_ids"][batch_idx])

In [None]:
# Class index (or indices) set in this batch element's target vector
labels = batch['labels'][batch_idx].nonzero().squeeze().tolist()

In [None]:
# Map the active class index back to its label; only works when exactly one
# entry is set (with several, `labels` is a list and cannot be used as a key)
config.id2label[labels] 

In [None]:
import torch
from tqdm.notebook import tqdm
import os

# Define the optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Directory for checkpoint files
checkpoint_dir = "./checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

# Write an intermediate checkpoint every `save_steps` optimizer steps
save_steps = 10000
global_step = 0

# Early stopping: stop after `patience` consecutive epochs without a lower validation loss
patience = 3
best_val_loss = float('inf')
epochs_without_improvement = 0
best_checkpoint_path = None  # file holding the best-validation weights so far

# Epochs are numbered 1..num_epochs in logs and checkpoint names
num_epochs = 49


def save_checkpoint(model, optimizer, epoch, step, checkpoint_dir):
    """Save model and optimizer state to ``checkpoint_dir``.

    Args:
        model: Model whose ``state_dict`` is stored.
        optimizer: Optimizer whose ``state_dict`` is stored.
        epoch: Epoch number recorded in the file and its name.
        step: Global step recorded in the file and its name.
        checkpoint_dir: Existing directory to write into.

    Returns:
        str: Path of the written ``.pt`` file.
    """
    checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint-epoch-{epoch}-step-{step}.pt")
    torch.save({
        'epoch': epoch,
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, checkpoint_path)
    print(f"Checkpoint saved at {checkpoint_path}")
    return checkpoint_path


def load_checkpoint(checkpoint_path, model, optimizer):
    """Restore model/optimizer state from a file written by ``save_checkpoint``.

    Not called by the loop below. To resume, call it before training and start
    the epoch range after the returned epoch.

    Returns:
        tuple: ``(epoch, step)`` stored in the checkpoint, or ``(0, 0)`` when
        ``checkpoint_path`` does not exist.
    """
    if os.path.exists(checkpoint_path):
        print(f"Loading checkpoint from {checkpoint_path}")
        # Map tensors onto the model's current device so a GPU-saved checkpoint
        # also loads on a CPU-only machine. torch.load unpickles: trusted files only.
        checkpoint = torch.load(checkpoint_path, map_location=next(model.parameters()).device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        step = checkpoint['step']
        print(f"Resumed from epoch {epoch}, step {step}")
        return epoch, step
    else:
        print("No checkpoint found, starting from scratch.")
        return 0, 0  # Starting from scratch if no checkpoint


def evaluate(model, validation_dataloader):
    """Return the mean per-batch loss of ``model`` over ``validation_dataloader``.

    Runs in eval mode without gradients, and always switches the model back to
    training mode before returning (even if a batch raises).
    """
    model.eval()
    total_loss = 0.0
    try:
        with torch.no_grad():
            for batch in validation_dataloader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                total_loss += outputs.loss.item()
    finally:
        model.train()
    return total_loss / len(validation_dataloader)


# Training loop; loss_list keeps the training loss of every step
loss_list = []
model.train()

for epoch in range(1, num_epochs + 1):
    print(f"Epoch: {epoch}")
    for batch in tqdm(train_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        optimizer.zero_grad()
        outputs = model(**batch)
        loss = outputs.loss
        loss_list.append(loss.item())
        loss.backward()
        optimizer.step()

        global_step += 1

        # Periodic checkpoint, independent of validation performance
        if global_step % save_steps == 0:
            save_checkpoint(model, optimizer, epoch, global_step, checkpoint_dir)

    # Validate at the end of each epoch and keep a checkpoint of the best model
    val_loss = evaluate(model, validation_dataloader)
    print(f"Validation Loss: {val_loss}")
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_without_improvement = 0
        best_checkpoint_path = save_checkpoint(model, optimizer, epoch, global_step, checkpoint_dir)
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print("Early stopping triggered")
            break

# The loop leaves the weights of the *last* epoch in `model`, which after early
# stopping are `patience` epochs past the best one. Reload the best-validation
# weights so later cells (inference, save_pretrained) use them. Load to CPU first
# to avoid staging optimizer state on the GPU; load_state_dict copies into place.
if best_checkpoint_path is not None:
    best_state = torch.load(best_checkpoint_path, map_location="cpu")
    model.load_state_dict(best_state['model_state_dict'])
    print(f"Restored best model (validation loss {best_val_loss}) from {best_checkpoint_path}")



In [None]:
import matplotlib.pyplot as plt

# Write loss_list (one training-step loss per line) to a file
with open('loss_list.csv', 'w') as f:
    for item in loss_list:
        f.write("%s\n" % item)

# loss_list has one entry per training batch, so one epoch spans
# len(train_dataloader) entries. Plot the first epoch downsampled to ~100 points.
steps_per_epoch = len(train_dataloader)
step = max(1, steps_per_epoch // 100)  # a slice step of 0 would raise ValueError
loss_epoch = loss_list[0:steps_per_epoch]
# Use the real batch index on the x-axis rather than the downsampled position
step_numbers = range(0, len(loss_epoch), step)

fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(step_numbers, loss_epoch[::step], color='blue', linewidth=1.0)
ax.set_title("Epoch 1 (downsampled)", fontsize=16)  # training epochs are numbered from 1
ax.set_xlabel("Steps", fontsize=14)
ax.set_ylabel("Loss", fontsize=14)
ax.grid(True, which='both', linestyle='--', linewidth=0.7)
ax.tick_params(axis='both', labelsize=12)
fig.tight_layout()
plt.show()

In [None]:
# Persist the fine-tuned model (weights + config) to a local directory
output_dir = "./vilt-PMC_VQA"
model.save_pretrained(output_dir)

In [None]:
import numpy as np

In [None]:
# Test on a sample picked from the validation dataset

import random

# Use a dedicated seeded RNG so the same example is drawn on every run without
# touching the global `random` state. Set SAMPLE_SEED = None to draw a different
# example each time (random.Random(None) seeds from system randomness).
SAMPLE_SEED = 0
sample_rng = random.Random(SAMPLE_SEED)
sample = sample_rng.choice(validation_dataset)
sample

In [None]:
# Run the model on the single sample. The batched copy lives in `sample_batch`
# so `sample` itself is left untouched; re-running this cell would otherwise add
# a second batch dimension.
sample_batch = {k: v.unsqueeze(0).to(device) for k, v in sample.items()}

# Training/evaluate() leave the model in train mode (dropout active); switch to
# eval mode and disable autograd for a deterministic inference pass.
model.eval()
with torch.no_grad():
    outputs = model(**sample_batch)

# Predicted class = highest-scoring logit of the single example in the batch
predicted_label = outputs.logits[0].argmax().item()
predicted_answer = config.id2label[predicted_label]
predicted_answer


In [None]:
# Get the ground-truth question as text, dropping special and padding tokens
# (inputs were padded to max_length, so they would otherwise clutter the output)
question = processor.decode(sample['input_ids'].squeeze(), skip_special_tokens=True)
question

In [None]:
# Ground-truth answer: index of the single set entry in the one-hot label vector
answer_index = torch.nonzero(sample['labels'].squeeze()).item()
answer = config.id2label[answer_index]
answer

In [None]:
# Undo per-channel normalisation x_norm = (x - mean) / std  ->  x = x_norm * std + mean
pixels = sample['pixel_values'].squeeze().cpu().numpy()
unnormalized_image = pixels * np.array(image_std)[:, None, None] + np.array(image_mean)[:, None, None]
# Assumes pixels were rescaled to [0, 1] before normalisation — TODO confirm.
# Clip so that out-of-range values saturate instead of wrapping in the uint8 cast.
unnormalized_image = (np.clip(unnormalized_image, 0.0, 1.0) * 255).astype(np.uint8)
# Channels-first (C, H, W) -> channels-last (H, W, C), as PIL expects
unnormalized_image = np.moveaxis(unnormalized_image, 0, -1)
Image.fromarray(unnormalized_image)

In [None]:
""" from tqdm.notebook import tqdm


# Define the optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Define the directory to save checkpoints
checkpoint_dir = "./checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

# Define the save step (e.g., save every 200 steps)
save_steps = 200
global_step = 0

# Define the function to save the checkpoint
def save_checkpoint(model, optimizer, epoch, step, checkpoint_dir):
    checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint-epoch-{epoch}-step-{step}.pt")
    torch.save({
        'epoch': epoch,
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, checkpoint_path)
    print(f"Checkpoint saved at {checkpoint_path}")

# Training loop
model.train()
for epoch in range(50):  # loop over the dataset multiple times
    print(f"Epoch: {epoch}")
    for batch in tqdm(train_dataloader):
        # get the inputs;
        batch = {k:v.to(device) for k,v in batch.items()}

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(**batch)
        loss = outputs.loss
        print("Loss:", loss.item())
        loss.backward()
        optimizer.step()

        # Increment global step
        global_step += 1

        # Save the model at save_steps intervals
        if global_step % save_steps == 0:
            save_checkpoint(model, optimizer, epoch, global_step, checkpoint_dir)

    # Save the model at the end of each epoch
    save_checkpoint(model, optimizer, epoch, global_step, checkpoint_dir)
 """

In [None]:
""" from tqdm.notebook import tqdm
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

model.train()
for epoch in range(50):  # loop over the dataset multiple times
      print(f"Epoch: {epoch}")
      for batch in tqdm(train_dataloader):
            # get the inputs;
            batch = {k:v.to(device) for k,v in batch.items()}

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(**batch)
            loss = outputs.loss
            print("Loss:", loss.item())
            loss.backward()
            optimizer.step() """