In [None]:
# Install the packages listed in requirements.txt; the %pip magic targets the
# environment of the running kernel.
%pip install -r requirements.txt

In [None]:
import os

import torch
import torch.optim as optim
from torch.utils.data import random_split

import dataset # MiniFlickrDataset, get_loader
from captioner import CaptioningModel
from train_caption import Trainer
from lr_warmup import LRWarmup
import utils
from data_loader import get_loader
from utils import Ranker
from retriever import Model, Criterion, train, val, save_ckp_rt


In [None]:
# Unpack the image archive into images/data/, the directory that
# Args_retrieval uses as its data_root.
!unzip images.zip -d images/data/

In [None]:

class Args_caption:
    """Plain settings container for the captioning experiment.

    Every option is a regular instance attribute, so it can be overridden on
    the instance (``config.epochs = 20``) before the setup cell runs.
    """

    def __init__(self):
        # --- data and pretrained backbones -------------------------------
        self.data_path = 'dataset.pkl'
        self.clip_model = 'openai/clip-vit-base-patch32'
        self.text_model = 'gpt2'

        # --- runtime -----------------------------------------------------
        self.seed = 100          # handed to utils.init_env
        self.num_workers = 0     # worker count for the captioning data loaders

        # --- dataset split -----------------------------------------------
        self.train_size = 0.84   # fraction: multiplied by len(dataset) in the setup cell
        self.val_size = 0.13     # the setup cell recomputes this as the remainder
        self.test_size = 100     # absolute number of samples, not a fraction

        # --- optimisation ------------------------------------------------
        self.epochs = 10
        self.lr = 3e-3           # Adam learning rate, also LRWarmup's max_lr
        self.k = 0.33            # forwarded to LRWarmup as `k`
        self.batch_size_exp = 6  # forwarded to get_loader as `bs_exp`

        # --- captioning network (forwarded to CaptioningModel) -----------
        self.ep_len = 4
        self.num_layers = 6
        self.n_heads = 16
        self.forward_expansion = 4
        self.max_len = 40
        self.dropout = 0.1


# Shared settings instance read by the captioning cells below.
config = Args_caption()

In [None]:

class Args_retrieval:
    """Plain settings container for the retrieval experiment.

    ``caption_path`` and ``split_path`` are templates with two ``{}`` slots;
    the retrieval cells fill them with ``(data_set, split_name)``.
    """

    def __init__(self):
        root = "./images/data/"

        # Dataset locations, all resolved relative to the unpacked archive.
        self.data_root = root
        self.data_set = "dress"
        self.image_root = os.path.join(root, 'resized_images/')
        self.caption_path = os.path.join(
            root, 'images/data/captions/captions/cap.{}.{}.json'
        )
        self.split_path = os.path.join(
            root, 'images/data/image_splits/image_splits/split.{}.{}.json'
        )

        # Model dimensions (passed to retriever.Model).
        self.embed_dim = 512
        self.vision_feature_dim = 512
        self.text_feature_dim = 512

        # Training loop settings.
        self.log_step = 15
        self.batch_size = 64
        self.learning_rate = 0.001
        self.num_workers = 4
        self.epochs = 3

        # Prefer the GPU whenever torch can see one.
        gpu_available = torch.cuda.is_available()
        self.device = torch.device('cuda' if gpu_available else 'cpu')


# Shared settings instance read by the retrieval cells below.
args = Args_retrieval()


In [None]:
# Set up training for the captioning model.

device = utils.init_env(config.seed)

# Build the dataset under its own name. Rebinding `dataset` here would shadow
# the imported `dataset` module and break the `dataset.get_loader(...)` calls
# further down.
flickr_dataset = dataset.MiniFlickrDataset(config.data_path)

# Derive the split sizes into locals instead of writing them back to `config`,
# so that re-running this cell does not multiply an already-converted count by
# the dataset length again.
n_total = len(flickr_dataset)
n_test = config.test_size                      # absolute sample count
n_train = int(config.train_size * n_total)     # train_size is a fraction
n_val = n_total - n_train - n_test             # whatever is left over
if n_val < 0:
    raise ValueError(
        f"Dataset has {n_total} samples, too few for {n_train} train "
        f"and {n_test} test samples."
    )

train_dataset, val_dataset, test_dataset = random_split(
    flickr_dataset, [n_train, n_val, n_test]
)

# Create data loaders (the `dataset` module's factory, not the instance).
train_loader = dataset.get_loader(
    train_dataset,
    bs_exp=config.batch_size_exp,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=True,
)
test_loader = dataset.get_loader(
    test_dataset,
    bs_exp=0,
    shuffle=False,
    num_workers=config.num_workers,
    pin_memory=True,
    train=False,
)

# Create model
capt_model = CaptioningModel(
    clip_model=config.clip_model,
    text_model=config.text_model,
    ep_len=config.ep_len,
    num_layers=config.num_layers,
    n_heads=config.n_heads,
    forward_expansion=config.forward_expansion,
    dropout=config.dropout,
    max_len=config.max_len,
    device=device
)

# Create optimizer and learning-rate scheduler; the schedule itself comes from
# LRWarmup.lr_warmup, which LambdaLR calls to scale the base learning rate.
optimizer = optim.Adam(capt_model.parameters(), lr=config.lr)
warmup = LRWarmup(epochs=config.epochs, max_lr=config.lr, k=config.k)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, warmup.lr_warmup)

# Create trainer
trainer = Trainer(
    model=capt_model,
    optimizer=optimizer,
    scaler=torch.amp.GradScaler('cuda'),
    scheduler=scheduler,
    train_loader=train_loader,
    test_loader=test_loader,
    device=device
)

# To resume training, load the weights of a saved checkpoint into the trainer
# with its `_load_ckp` method before running the training cell, e.g.:

#trainer._load_ckp("path to .pt file")


In [None]:
# Train the captioning model, starting from the epoch stored on the trainer
# (non-zero after a checkpoint has been loaded).
ckpt_dir = "checkpoints"
os.makedirs(ckpt_dir, exist_ok=True)  # loop-invariant, so created once up front

for epoch in range(trainer.epoch, config.epochs):
    trainer.train_epoch()

    score = trainer.test_epoch()
    print("Score: {:.4f}".format(score))

    # Checkpoint every third epoch, and always after the final epoch so the
    # fully trained weights are kept even when `config.epochs` is not a
    # multiple of three.
    finished_epochs = epoch + 1
    if finished_epochs % 3 == 0 or finished_epochs == config.epochs:
        trainer.save_ckp(os.path.join(ckpt_dir, f'epoch_{finished_epochs}.pt'))


In [None]:
# Run resize_images.py over the unpacked images (--image_size 256) and write the
# results to images/data/resized_images, the directory Args_retrieval.image_root
# points at.
!python resize_images.py --image_dir images/data/images/data/ --output_dir images/data/resized_images --image_size 256

In [None]:
# Set up training for the retrieval model.


def _build_retrieval_loader(split, shuffle):
    """Return a loader over the `split` caption file of `args.data_set`."""
    return get_loader(
        args.image_root.format(args.data_set),
        args.caption_path.format(args.data_set, split),
        args.batch_size,
        shuffle=shuffle,
        return_target=True,
        num_workers=args.num_workers,
    )


# Data loaders: shuffled for training, fixed order for validation.
# With the default settings the training captions resolve to
# ./images/data/images/data/captions/captions/cap.dress.train.json
data_loader = _build_retrieval_loader('train', shuffle=True)
data_loader_dev = _build_retrieval_loader('val', shuffle=False)

# Model, placed on the configured device and switched to training mode.
ret_model = Model(args.vision_feature_dim, args.text_feature_dim, args.embed_dim)
ret_model.to(args.device)
ret_model.train()

# Loss and optimizer. Only parameters that require gradients are optimised.
# Note: this rebinds the name `optimizer` used by the captioning setup cell.
criterion = Criterion()
current_lr = args.learning_rate
optimizer = torch.optim.Adam(
    (p for p in ret_model.parameters() if p.requires_grad),
    lr=current_lr,
)

# Ranker built over the validation image split.
ranker = Ranker(
    root=args.image_root.format(args.data_set),
    image_split_file=args.split_path.format(args.data_set, 'val'),
    transform=None,
    num_workers=args.num_workers,
)

In [None]:
# Train the retrieval model; `val` receives the best score so far and returns
# the updated value after each epoch.
best_score = 0
for epoch in range(args.epochs):
    train(data_loader, ret_model, criterion, optimizer, args.log_step)
    best_score = val(data_loader_dev, ret_model, ranker, best_score)

# Make sure the target directory exists before saving, as the captioning
# training cell does for its own checkpoints.
ret_ckpt_dir = "checkpoint"
os.makedirs(ret_ckpt_dir, exist_ok=True)

# NOTE(review): f'epoch_{-1}.pt' renders as 'epoch_-1.pt'; the name is kept
# unchanged so anything that loads this file keeps working.
save_ckp_rt(ret_model, os.path.join(ret_ckpt_dir, f'epoch_{-1}.pt'))
print(best_score)

In [None]:
#generate a caption from an image using the trained captioning model
import PIL.Image as Image
import matplotlib.pyplot as plt
import numpy as np
import io
def generate_caption(capt_model, images):
    """Caption every image and plot each image with its caption as the title.

    Args:
        capt_model: Either the captioning network itself or an object that
            exposes the network as ``.model``. The network is called as
            ``network(img)`` and must return a ``(caption, extra)`` pair.
        images: Iterable of images accepted both by the network and by
            ``Axes.imshow``.

    Returns:
        list: One caption per input image, in input order. An empty input
        yields an empty list and draws nothing.
    """
    images = list(images)
    if not images:
        # Nothing to caption, and plt.subplots cannot build a zero-row grid.
        return []

    # Accept a wrapper exposing `.model` as well as the bare network.
    network = getattr(capt_model, "model", capt_model)
    network.eval()

    # squeeze=False keeps `axs` two-dimensional even for a single image, so
    # indexing works the same for one image as for many.
    fig, axs = plt.subplots(len(images), 1, figsize=(20, 12), squeeze=False)
    captions = []
    try:
        with torch.no_grad():
            for ax, img in zip(axs[:, 0], images):
                caption, _ = network(img)
                ax.imshow(img)
                ax.set_title(caption)
                captions.append(caption)
        plt.show()
    finally:
        # Release the figure even if captioning or plotting failed.
        fig.clear()
        plt.close(fig)
    return captions



In [None]:
# Retrieval helper: caption -> image.
def get_image_from_caption(ranker, caption, model=None):
    """Retrieve the image that the ranker matches to ``caption``.

    Args:
        ranker: Object providing ``retrieve_image(model, caption)``.
        caption: Query text.
        model: Retrieval model to use. Defaults to the notebook-level
            ``ret_model`` so existing two-argument calls behave as before.

    Returns:
        Whatever ``ranker.retrieve_image`` returns for the query.
    """
    if model is None:
        model = ret_model
    model.eval()
    with torch.no_grad():
        return ranker.retrieve_image(model, caption)

def get_caption_from_image(ranker, image_path, model=None):
    """Retrieve the caption that the ranker matches to the image at ``image_path``.

    Args:
        ranker: Object providing ``retrieve_caption(model, image_path)``.
        image_path: Path of the query image, passed through to the ranker.
        model: Retrieval model to use. Defaults to the notebook-level
            ``ret_model`` so existing two-argument calls behave as before.

    Returns:
        Whatever ``ranker.retrieve_caption`` returns for the query.
    """
    if model is None:
        model = ret_model
    model.eval()
    with torch.no_grad():
        return ranker.retrieve_caption(model, image_path)

In [None]:
# Image -> text demo: retrieve the stored caption that best matches an input
# image, and also let the captioning model describe that same input image.
input_image = "path/to/your/image.jpg"  # placeholder: point this at a real file

# The captioning model gets the decoded RGB image; the ranker gets the path.
img = Image.open(input_image)
img = img.convert("RGB")

retrieved_caption = get_caption_from_image(ranker, input_image)
caption = generate_caption(capt_model, images=[img])
print("captions: ", retrieved_caption, caption)

In [None]:
# Text -> image demo: retrieve the image that best matches an input caption,
# display it, then let the captioning model describe the retrieved image.
input_caption = "A blue dress with white polka dots"
retrieved_image = get_image_from_caption(ranker, caption=input_caption)

# Show the raw retrieved image before the captioned figure is drawn.
plt.imshow(np.array(retrieved_image))
caption = generate_caption(capt_model, images=[retrieved_image])