In [1]:
import os 
import json 
from PIL import Image
import matplotlib.pyplot as plt

In [2]:
# Relative locations used by the helpers in this notebook.
DATA_PATHS = dict(
    # Root that create_metadata scans: <category>/images and <category>/prompts.
    raw_dataset_path="dataset/",
    # Folder that receives the generated metadata.jsonl.
    image_set_path="complete_image_set/train",
    # NOTE(review): the three entries below are not read by any code visible
    # here; presumably training checkpoints, the saved model, and images used
    # for inference -- confirm against the cells that consume them.
    checkpoints_dir="checkpoints/",
    save_model_dir="checkpoints/model/",
    inference_set_dir="complete_image_set/inference/",
)

In [3]:
def create_metadata(DATA_PATHS):
    """Build ``metadata.jsonl`` pairing every image with each of its prompts.

    The raw dataset is expected to look like::

        <raw_dataset_path>/<category>/images/<name>.<ext>
        <raw_dataset_path>/<category>/prompts/<name>.txt   (one prompt per line)

    One JSON record ``{"file_name": "<category>/<file>", "prompts": "<line>"}``
    is written per non-blank prompt line.

    Args:
        DATA_PATHS: Mapping providing ``"raw_dataset_path"`` (dataset root) and
            ``"image_set_path"`` (directory that receives ``metadata.jsonl``;
            created if it does not exist).

    Raises:
        FileNotFoundError: If an image has no matching prompt file.
    """
    raw_root = DATA_PATHS["raw_dataset_path"]
    output_dir = DATA_PATHS["image_set_path"]

    captions = []
    # os.listdir order is arbitrary; sorting keeps the output reproducible.
    for category in sorted(os.listdir(raw_root)):
        images_dir = os.path.join(raw_root, category, "images")
        prompts_dir = os.path.join(raw_root, category, "prompts")

        # Skip stray entries (plain files, .ipynb_checkpoints, ...) that are
        # not category folders holding an images/ directory.
        if not os.path.isdir(images_dir):
            continue

        for file in sorted(os.listdir(images_dir)):
            # Ignore hidden files and sub-directories inside images/.
            if file.startswith(".") or not os.path.isfile(os.path.join(images_dir, file)):
                continue

            # splitext strips only the final extension, so a name such as
            # "img.v2.png" maps to "img.v2.txt" rather than "img.txt".
            stem = os.path.splitext(file)[0]
            with open(os.path.join(prompts_dir, f"{stem}.txt"), "r") as f:
                for prompt in f:
                    caption = prompt.rstrip("\n")
                    # A blank line would otherwise become an empty caption.
                    if not caption.strip():
                        continue
                    captions.append({"file_name": f"{category}/{file}", "prompts": caption})

    # add metadata.jsonl file to the image-set folder, creating it on first use
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, "metadata.jsonl"), "w") as f:
        for item in captions:
            f.write(json.dumps(item) + "\n")

In [4]:
def plot_loss_curve(ax, train_losses, plot_title):
    """Draw one loss curve on the given axes.

    Args:
        ax: Axes-like object to draw on (the visible caller passes entries of
            the array returned by ``plt.subplots``).
        train_losses: Sequence of values, plotted against their index.
        plot_title: Text used both as the axes title and as the legend label.
    """
    # Static decorations first; none of them depend on the plotted data.
    ax.set_title(plot_title)
    ax.set_xlabel('Steps')
    ax.set_ylabel('Loss')

    ax.plot(train_losses, label=plot_title, color='blue')
    ax.grid(True)
    # The legend is built last so it picks up the labelled line drawn above.
    ax.legend()

In [5]:
def track_losses(trainer):
    """Plot the logged training/evaluation metrics side by side.

    Reads ``trainer.state.log_history`` (an iterable of dict records) and, for
    each of "loss", "eval_loss" and "eval_wer_score", collects the values from
    the records that contain that key. Each series is drawn in its own panel
    of a 1x3 figure, which is then shown.

    Args:
        trainer: Object exposing ``state.log_history``.
            NOTE(review): presumably a Hugging Face ``Trainer`` -- confirm.
    """
    metric_names = ["loss", "eval_loss", "eval_wer_score"]

    # Not every log record carries every metric, so filter per key.
    history = {}
    for name in metric_names:
        history[name] = [entry[name] for entry in trainer.state.log_history if name in entry]

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))

    # One panel per metric, in the order listed above.
    for panel, name in zip(axs, metric_names):
        plot_loss_curve(panel, history[name], name)

    plt.tight_layout()
    plt.show()

In [6]:
def plot_image_with_captions(image_path, captions):
    """Show an image on the left and its captions as a text list on the right.

    Args:
        image_path: Path handed to ``Image.open``.
        captions: Iterable of strings; the first is placed at y=0.9 in axes
            coordinates and each following one 0.1 lower.
    """
    picture = Image.open(image_path)

    # Two side-by-side panels: picture | caption list.
    _, (image_ax, text_ax) = plt.subplots(1, 2, figsize=(10, 8))

    image_ax.imshow(picture)
    image_ax.axis('off')
    image_ax.set_title('Image')

    text_ax.axis('off')
    text_ax.set_title('Captions')
    for row, line in enumerate(captions):
        text_ax.text(0, 0.9 - 0.1*row, line, fontsize=8, wrap=True)

    plt.show()