In [None]:
from IPython.display import Image  # to visualize images within this notebook
import torch  # library used for implementing ML tools, great flexibility and ease of use, highly recommended
from PIL import Image as PILImage  # to load images
import open_clip  # open source implementation of CLIP, we use it to extract representations of queries and images
from torch.utils.data import Dataset, Subset  # useful class for managing datasets, greatly integrated within the pytorch environment
import matplotlib.pyplot as plt  # library for plotting
from matplotlib.pyplot import figure, subplots, imshow, axis  # library for plotting
import os  # contains utilities for reading directories on disk

In [None]:
# Record the installed PyTorch version next to the notebook outputs.
print(torch.__version__)

In [None]:
import pandas as pd
# Split name ('train' / 'val' / 'test') -> positions of the samples belonging to that split.
# NOTE(review): unpickling can execute arbitrary code; only load files you trust.
indices = pd.read_pickle("indices_museum_dataset.pkl")
# Preview the first ten entries of every split.
tuple(indices[split][:10] for split in ('train', 'val', 'test'))

In [None]:
from torch.utils.data import Dataset
import torch

class DescriptionSceneMuseum(Dataset):
    """Dataset pairing each museum's pre-extracted description, scene and artwork features.

    A sample is identified by the stem of a ``.pt`` file found in
    ``data_scene_path``; the same stem is used to look up the description
    tensor, the raw description pickle and the artwork tensor. Every file of
    the requested split is loaded into memory by the constructor.

    Args:
        data_description_path: directory holding one ``<name>.pt`` description tensor per sample.
        data_raw_description_path: directory holding one ``<name>.pkl`` per sample; its content
            is joined with single spaces into one string.
        data_scene_path: directory holding one ``<name>.pt`` scene tensor per sample. Its
            sorted listing defines the positions that ``indices`` refer to.
        data_art_path: directory holding one ``<name>.pt`` artwork tensor per sample.
        indices: mapping from split name to positions in the sorted listing; the selected
            entry must provide ``.tolist()``.
        split: key of ``indices`` to load. Only ``"train"`` samples return the raw description.
        customized_margin: accepted for backward compatibility; not used by this class.
    """

    def __init__(self, data_description_path, data_raw_description_path, data_scene_path, data_art_path, indices, split, customized_margin=False):
        self.description_path = data_description_path
        self.raw_description_path = data_raw_description_path
        self.data_pov_path = data_scene_path
        self.art_path = data_art_path
        self.indices = indices[split]
        self.split = split

        # Drop only a trailing ".pt". The previous `str.strip(".pt")` removed any run of the
        # characters '.', 'p', 't' from BOTH ends, corrupting names such as "top.pt".
        suffix = ".pt"
        available_data = sorted(
            fname[:-len(suffix)] if fname.endswith(suffix) else fname
            for fname in os.listdir(data_scene_path)
        )
        available_data = [available_data[ix] for ix in self.indices.tolist()]

        self.descs = [torch.load(os.path.join(data_description_path, f"{sm}.pt")) for sm in available_data]
        self.raw_descs = [" ".join(pd.read_pickle(os.path.join(data_raw_description_path, f"{sm}.pkl"))) for sm in available_data]
        self.pov_images = [torch.load(os.path.join(data_scene_path, f"{sm}.pt")) for sm in available_data]
        self.art_vectors = [torch.load(os.path.join(data_art_path, f"{sm}.pt")) for sm in available_data]
        self.names = available_data

        def mean_len(items):
            # An empty split reports 0.0 instead of raising ZeroDivisionError.
            return sum(len(x) for x in items) / len(items) if items else 0.0

        print(f"'{split.upper()}': {len(self.names)} names, "
              f"{len(self.descs)} sentences ({mean_len(self.descs)} avg), "
              f"{len(self.pov_images)} images ({mean_len(self.pov_images)} avg).")

    def __len__(self):
        """Number of samples in the loaded split."""
        return len(self.names)

    def __getitem__(self, index):
        """Return the features of sample ``index``.

        Returns:
            ``(desc, scene, art, raw_desc, name, index)`` for the ``"train"`` split and
            ``(desc, scene, art, name, index)`` for every other split.
        """
        desc_tensor = self.descs[index]
        scene_img_tensor = self.pov_images[index]
        scene_art_tensor = self.art_vectors[index]
        name = self.names[index]

        if self.split == "train":
            return desc_tensor, scene_img_tensor, scene_art_tensor, self.raw_descs[index], name, index
        return desc_tensor, scene_img_tensor, scene_art_tensor, name, index

In [None]:
# Feature size associated with each supported visual backbone.
visual_bb_ftsize_k = dict(rn18=512, rn34=512, rn50=2048, rn101=2048, vitb16=768, vitb32=768)

# Backbone whose pre-extracted artwork vectors are used in the rest of the notebook.
visual_backbone = "rn50"
visual_bb_ftsize = visual_bb_ftsize_k[visual_backbone]

In [None]:
# Sanity check: print the shape of one pre-extracted vector file per backbone.
for vn in visual_bb_ftsize_k:
    print(
        vn,
        torch.load(f"preextracted_vectors_wikiart_{vn}/Museum1554-7.unity.pt").shape,
    )

In [None]:
# Directories shared by the three splits, in the positional order expected by
# DescriptionSceneMuseum: descriptions, raw descriptions, scene images, artwork vectors.
_dataset_dirs = (
    "./tmp_museums/open_clip_features_museums3k/descriptions/sentences",
    "./tmp_museums/open_clip_features_museums3k/descriptions/tokens_strings",
    "./tmp_museums/open_clip_features_museums3k/images",
    f"./preextracted_vectors_wikiart_{visual_backbone}",
)

train_dataset = DescriptionSceneMuseum(*_dataset_dirs, indices, "train")
val_dataset = DescriptionSceneMuseum(*_dataset_dirs, indices, "val")
test_dataset = DescriptionSceneMuseum(*_dataset_dirs, indices, "test")

# Inspect one training sample (train items also carry the raw description string).
desc, scene, art, raw_desc, name, ix = train_dataset[1]
print(f"The sample #{ix} ({name}) has a description of {len(desc)} sentences (shaped as {desc.shape} matrix), whereas there are {len(scene)} images (shaped as {scene.shape} matrix)")
print(f"Example of raw description (capped at 100 characters): {raw_desc[:100]}")

In [None]:
# Gather the raw descriptions of the three splits into one corpus.
all_descs = list(train_dataset.raw_descs) + list(val_dataset.raw_descs) + list(test_dataset.raw_descs)
print("tot", len(all_descs))

# Whitespace-token statistics.
n_tokens = [len(rd.split()) for rd in all_descs]
print("avg tokens per museum", sum(n_tokens) / len(n_tokens))
print("num tokens", sum(n_tokens))

import re

# Number words expected in sentences of the form "In the <x> room , there are <n> paintings".
_mm = {'two': 2, 'three': 3, 'four': 4, 'five': 5, 'six': 6}

# For every description: keep the sentences mentioning "there are" and "painting", drop the
# room prefix, keep only the short "<n> paintings" leftovers and map <n> to an integer.
_tmp = [
    [
        _mm[phrase.replace(" paintings", "")]
        for phrase in (
            re.sub(r"In the \w+ room , there are", "", sentence).strip()
            for sentence in rd.split(".")
            if "there are" in sentence and "painting" in sentence
        )
        if len(phrase.split()) < 3
    ]
    for rd in all_descs
]

# Total paintings per description, then the average over all descriptions.
_x = [sum(counts) for counts in _tmp]
sum(_x) / len(_x)

In [None]:
def cosine_sim(im, s):
    """Pairwise cosine similarity between the rows of ``im`` and the rows of ``s``.

    Entry ``(i, j)`` of the result compares ``im[i]`` with ``s[j]``. A constant
    of 1e-18 is added under each square root so all-zero rows do not divide by zero.
    """
    dots = torch.mm(im, s.t())
    im_norm = (im.pow(2).sum(dim=1, keepdim=True) + 1e-18).sqrt()
    s_norm = (s.pow(2).sum(dim=1) + 1e-18).sqrt().unsqueeze(0)
    return dots / (im_norm * s_norm)


def create_rank(result, entire_descriptor, desired_output_index):
    """Rank the rows of ``entire_descriptor`` by cosine similarity to ``result``.

    Returns:
        ``(position, ordering)`` where ``ordering`` lists row indices from most to
        least similar and ``position`` is the 0-based place of
        ``desired_output_index`` inside that ordering.
    """
    scores = torch.nn.functional.cosine_similarity(entire_descriptor, result, dim=1).squeeze()
    ordering = scores.argsort(descending=True)
    hits = (ordering == desired_output_index).nonzero(as_tuple=True)
    return hits[0].item(), ordering


def evaluate(output_description, output_scene, section, out_values=False, excel_format=False):
    """Compute retrieval metrics between paired description and scene embeddings.

    Row ``i`` of ``output_description`` and row ``i`` of ``output_scene`` are
    treated as the matching pair. Ranks come from ``create_rank`` and are
    0-based, so R@k counts the queries whose match sits at a position ``< k``
    and the reported median / mean ranks are shifted by one.

    Args:
        output_description: description embeddings, one row per sample.
        output_scene: scene embeddings, one row per sample.
        section: label printed in front of the metrics (e.g. ``"val"``).
        out_values: when true, print a readable summary of both directions.
        excel_format: when true, also print a ``;``-separated line and append
            that string to the returned tuple.

    Returns:
        ``(ds_r1, ds_r5, ds_r10, sd_r1, sd_r5, sd_r10, ndcg_10, ndcg, ds_medr, sd_medr)``
        where ``ds_*`` uses descriptions as queries over scenes and ``sd_*`` uses
        scenes as queries over descriptions. The two nDCG slots are always ``-1``
        because no nDCG is computed here. With ``excel_format=True`` the
        formatted string is appended as an eleventh element.
    """
    # Imported locally: the notebook imports numpy only in a later cell, so this
    # function used to raise NameError when called before that cell had run.
    import numpy as np

    def recall_at(ranks, k):
        # Percentage of queries whose match is ranked within the first k results.
        return 100 * int(np.count_nonzero(ranks < k)) / len(ranks)

    # Scene -> description: each scene embedding queries the description bank.
    ranks_scene = np.array([create_rank(query, output_description, j)[0]
                            for j, query in enumerate(output_scene)])
    # Description -> scene: each description embedding queries the scene bank.
    ranks_description = np.array([create_rank(query, output_scene, j)[0]
                                  for j, query in enumerate(output_description)])

    sd_r1, sd_r5, sd_r10 = (recall_at(ranks_scene, k) for k in (1, 5, 10))
    sd_medr = np.median(ranks_scene) + 1
    sd_meanr = ranks_scene.mean() + 1

    ds_r1, ds_r5, ds_r10 = (recall_at(ranks_description, k) for k in (1, 5, 10))
    ds_medr = np.median(ranks_description) + 1
    ds_meanr = ranks_description.mean() + 1

    metric_names = ("R@1", "R@5", "R@10", "median rank", "mean rank")
    ds_out = "".join(f"{mn}: {mv:.4f}   "
                     for mn, mv in zip(metric_names, (ds_r1, ds_r5, ds_r10, ds_medr, ds_meanr)))
    sc_out = "".join(f"{mn}: {mv:.4f}   "
                     for mn, mv in zip(metric_names, (sd_r1, sd_r5, sd_r10, sd_medr, sd_meanr)))

    if out_values:
        print(section + " data: ")
        print("Scenes ranking: " + ds_out)
        print("Descriptions ranking: " + sc_out)

    # nDCG is not computed by this function; the slots are kept so the return
    # tuple keeps its historical length.
    avg_ndcg_10_entire = -1
    avg_ndcg_entire = -1

    if excel_format:
        formatted_string = f"{ds_r1};{ds_r5};{ds_r10};{sd_r1};{sd_r5};{sd_r10};{ds_medr};{sd_medr}"
        print("-"*5)
        # Printed verbatim (not an f-string): it serves as the column header for the next line.
        print("{ds_r1};{ds_r5};{ds_r10};{sd_r1};{sd_r5};{sd_r10};{ds_medr};{sd_medr}")
        print(formatted_string)
        print("-"*5)
        return ds_r1, ds_r5, ds_r10, sd_r1, sd_r5, sd_r10, avg_ndcg_10_entire, avg_ndcg_entire, ds_medr, sd_medr, formatted_string

    return ds_r1, ds_r5, ds_r10, sd_r1, sd_r5, sd_r10, avg_ndcg_10_entire, avg_ndcg_entire, ds_medr, sd_medr


In [None]:
class LossContrastive:
    """Hinge-based contrastive loss with loss-history tracking and early-stopping state.

    Args:
        name: identifier used for the file name written by ``save_plots``.
        patience: number of non-improving validation epochs tolerated by
            ``is_val_improving`` before it returns ``False``.
        delta: minimum improvement of the (negated) validation loss that resets the patience counter.
        verbose: when true, ``is_val_improving`` prints its internal state.
    """

    def __init__(self, name, patience=15, delta=.001, verbose=True):
        self.train_losses = []
        self.validation_losses = []
        self.name = name
        self.counter_patience = 0
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.verbose = verbose

    def on_epoch_end(self, loss, train=True):
        """Append ``loss`` to the training history (``train=True``) or the validation history."""
        if train:
            self.train_losses.append(loss)
        else:
            self.validation_losses.append(loss)

    def get_loss_trend(self):
        """Return ``(train_losses, validation_losses)``."""
        return self.train_losses, self.validation_losses

    def calculate_loss(self, pairwise_distances, margin=.25, margin_tensor=None):
        """Bidirectional hinge loss over a square score matrix.

        The diagonal holds the scores of the matching pairs. Every off-diagonal
        entry is penalised when ``margin + score - positive`` is above zero,
        once against the positive of its row and once against the positive of
        its column; each direction is averaged over the off-diagonal entries and
        the two directions are averaged.

        Args:
            pairwise_distances: square ``(batch, batch)`` matrix. Despite the
                name, larger values must mean "more similar" for the hinge to
                make sense (this notebook passes cosine similarities).
            margin: scalar margin, used when ``margin_tensor`` is ``None``.
            margin_tensor: optional tensor of margins broadcast against the matrix.

        Returns:
            A scalar tensor. Batches with fewer than two samples have no
            negatives and yield a zero that is still attached to the graph.
        """
        batch_size = pairwise_distances.shape[0]
        if batch_size < 2:
            # The normaliser batch_size * (batch_size - 1) would be 0 and turn the
            # loss (and its gradients) into NaN; return a differentiable zero instead.
            return pairwise_distances.sum() * 0.0

        diag = pairwise_distances.diag().view(batch_size, 1)
        pos_masks = torch.eye(batch_size).bool().to(pairwise_distances.device)
        if margin_tensor is not None:
            margin = margin_tensor.to(pairwise_distances.device)
        n_negatives = batch_size * (batch_size - 1)

        def directional_cost(positives):
            # Hinge on every entry, ignore the positives themselves, average over negatives.
            cost = (margin + pairwise_distances - positives).clamp(min=0)
            cost = cost.masked_fill(pos_masks, 0)
            return (cost / n_negatives).sum()

        cost_s = directional_cost(diag.expand_as(pairwise_distances))      # positive of each row
        cost_d = directional_cost(diag.t().expand_as(pairwise_distances))  # positive of each column

        return (cost_s + cost_d) / 2

    def is_val_improving(self):
        """Update the early-stopping state from the latest validation loss.

        Returns:
            ``False`` once the negated validation loss has failed to beat
            ``best_score + delta`` for ``patience`` calls, ``True`` otherwise
            (including when no validation loss has been recorded yet).
        """
        if not self.validation_losses:
            return True
        score = -self.validation_losses[-1]

        if self.best_score is not None and self.verbose:
            print('epoch:', len(self.validation_losses), ' score:', -score, ' best_score:', -self.best_score, ' counter:', self.counter_patience)

        if self.best_score is None:
            self.best_score = score
        elif score < self.best_score + self.delta:
            self.counter_patience += 1
            if self.counter_patience >= self.patience:
                return False
        else:
            self.best_score = score
            self.counter_patience = 0
        return True

    def save_plots(self):
        """Plot both loss histories on the current matplotlib figure and save it as a PNG."""
        save_path = f'models/{self.name}.png'
        # Make sure the target directory exists before matplotlib tries to write into it.
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        plt.plot(self.train_losses, label='Train Loss')
        plt.plot(self.validation_losses, label='Val Loss')

        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.title('Loss Trend')

        plt.legend()

        plt.savefig(save_path)

In [None]:
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
def collate_fn(data):  # data -> desc_tensor, scene_img_tensor, name, index
    """Collate dataset samples into one padded batch.

    Samples with six fields carry the raw description at position 3; samples
    with five fields do not. Descriptions are padded and packed, scene and
    artwork tensors are padded and transposed on their last two axes.

    Returns:
        ``(descs_pov, padded_pov, padded_art, raw_descs, names, indexes, len_pov)``
        for six-field samples, the same tuple without ``raw_descs`` otherwise.
        ``len_pov`` holds the unpadded length of every scene tensor.
    """
    has_raw = len(data[0]) == 6  # train -> raw descriptions
    name_pos = 4 if has_raw else 3

    # Descriptions: pad, then pack so a recurrent encoder can skip the padding.
    desc_seqs = [sample[0] for sample in data]
    desc_lengths = torch.tensor([len(seq) for seq in desc_seqs])
    descs_pov = pack_padded_sequence(pad_sequence(desc_seqs, batch_first=True),
                                     desc_lengths,
                                     batch_first=True,
                                     enforce_sorted=False)

    # Scene tensors: remember the true lengths, pad, swap the last two axes.
    pov_seqs = [sample[1] for sample in data]
    len_pov = torch.tensor([len(seq) for seq in pov_seqs])
    padded_pov = pad_sequence(pov_seqs, batch_first=True).transpose(1, 2)

    # Artwork tensors: same treatment (their lengths are computed but not returned).
    art_seqs = [sample[2] for sample in data]
    len_art = torch.tensor([len(seq) for seq in art_seqs])
    padded_art = pad_sequence(art_seqs, batch_first=True).transpose(1, 2)

    names = [sample[name_pos] for sample in data]
    indexes = [sample[name_pos + 1] for sample in data]

    if not has_raw:
        return descs_pov, padded_pov, padded_art, names, indexes, len_pov
    raw_descs = [sample[3] for sample in data]
    return descs_pov, padded_pov, padded_art, raw_descs, names, indexes, len_pov

In [None]:
def save_best_model(model_name, run_folder, *args):
    """Save the given objects to ``models/<run_folder>/<model_name>.pt``.

    The i-th extra positional argument is stored under the key
    ``best_model_<i>`` of a single dictionary written with ``torch.save``.
    Missing directories are created.
    """
    root_dir = "models"
    run_dir = os.path.join(root_dir, run_folder)
    for directory in (root_dir, run_dir):
        os.makedirs(directory, exist_ok=True)

    checkpoint = {f'best_model_{str(i)}': bm for i, bm in enumerate(args)}
    torch.save(checkpoint, os.path.join(root_dir, run_folder, model_name + '.pt'))


def load_best_model(model_name, run_folder):
    """Load the objects stored in ``models/<run_folder>`` for ``model_name``.

    The file ``<model_name>.pt`` is used when it exists. This avoids the
    ambiguity of a pure prefix match, where e.g. ``"net_1"`` also matches
    ``"net_10.pt"``. If there is no exact match, the function falls back to the
    single file whose name starts with ``model_name``.

    Returns:
        The values of the stored dictionary, in its key order.

    Raises:
        AssertionError: when the prefix fallback does not identify exactly one file.
    """
    model_dir = os.path.join("models", run_folder)
    candidates = os.listdir(model_dir)

    exact_name = model_name + '.pt'
    if exact_name in candidates:
        chosen = exact_name
    else:
        avail_models = [m for m in candidates if m.startswith(model_name)]
        assert len(avail_models) == 1, avail_models
        chosen = avail_models[0]

    check_point = torch.load(os.path.join(model_dir, chosen))
    return [check_point[key] for key in check_point]

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [None]:
import torch.nn as nn
    
class MyBaseline(nn.Module):
    """Mean-pooling baseline: project every image feature, average them, project again.

    Args:
        in_channels: size of each incoming image feature.
        out_channels: size of the per-image projection and of the final embedding.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.trf_photo = nn.Linear(in_channels, out_channels)
        self.trf_mean = nn.Linear(out_channels, out_channels)
        self.relu = nn.ReLU()

    def forward(self, x, list_length=None, clip_mask=None, imgs_per_room=None):
        """Embed a padded batch of image-feature sets.

        Args:
            x: tensor indexed as ``(batch, feature, image)``; axes 1 and 2 are
                swapped before the first linear layer.
            list_length: number of real (non-padded) images per sample, as a list
                or tensor. ``None`` treats every image slot as real.
            clip_mask: optional multiplicative mask applied to the projected features.
            imgs_per_room: accepted for interface compatibility; not used.

        Returns:
            One embedding of size ``out_channels`` per sample.
        """
        x = x.to(torch.float32)

        x1 = self.trf_photo(x.transpose(1, 2))

        if clip_mask is not None:
            x1 = x1 * clip_mask

        bsz, max_n_imgs, _ = x1.shape
        if list_length is None:
            # Previously this path crashed on `list_length.to(...)`; with no
            # lengths given, every slot counts as a real image.
            lengths = torch.full((bsz,), max_n_imgs, dtype=torch.long, device=x1.device)
        else:
            lengths = torch.as_tensor(list_length, device=x1.device)

        # Remove the effect of the padding: zero every slot at or beyond the true length,
        # so the bias of the projection does not leak into the sum below.
        is_real = torch.arange(max_n_imgs, device=x1.device).unsqueeze(0) < lengths.unsqueeze(1)
        x1 = x1.masked_fill(~is_real.unsqueeze(-1), 0)
        x1_img = self.relu(x1)

        # Mean over the real images only (padding contributes zeros to the sum).
        x1_mean = x1_img.sum(1) / lengths.unsqueeze(1)
        x1_museum = self.trf_mean(x1_mean)
        return x1_museum


class GRUNet(nn.Module):
    """Single-layer GRU encoder that returns its final hidden state.

    Args:
        hidden_size: size of the GRU hidden state (and of the returned embedding).
        num_features: size of each input step.
        is_bidirectional: when true, a bidirectional GRU is used and the final
            hidden states of the two directions are averaged.
    """

    def __init__(self, hidden_size, num_features, is_bidirectional=False):
        super().__init__()
        self.is_bidirectional = is_bidirectional
        self.gru = nn.GRU(num_features, hidden_size,
                          batch_first=True, bidirectional=is_bidirectional)

    def forward(self, x):
        """Encode ``x`` (cast to float32) and return one vector per sequence."""
        _, last_hidden = self.gru(x.to(torch.float32))
        # last_hidden has one leading slice per direction.
        return last_hidden.mean(0) if self.is_bidirectional else last_hidden.squeeze(0)

In [None]:
####### TEST WITH THE BASELINE

import random
import string
import time

import numpy as np
import wandb
from torch.nn.utils.rnn import pad_packed_sequence
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR
from torch.utils.data import DataLoader
from tqdm import tqdm

# Toggle Weights & Biases logging for the whole experiment.
use_wandb = True

# Random 5-character tag; checkpoints of this execution go under models/<run_folder>/.
run_folder = ''.join(random.choices(string.ascii_uppercase + string.digits, k=5))

batch_size = 64

# Only the training loader shuffles; all loaders share the batch size and collate function.
train_loader = DataLoader(train_dataset, shuffle=True, num_workers=8,
                          batch_size=batch_size, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, shuffle=False, num_workers=4,
                        batch_size=batch_size, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, shuffle=False, num_workers=4,
                         batch_size=batch_size, collate_fn=collate_fn)

# Experiment hyper-parameters.
num_epochs = 50
number_of_tries = 3  # independent repetitions of the full training procedure
kernel_size = 3
final_output_strings = []  # one formatted result line per repetition

output_feature_size = 256 # default: 256
is_bidirectional = True
approach_name = f"mean_pool_baseline"

def _embed_split(loader, desc_encoder, scene_encoder, criterion, feature_size, wandb_batch_key=None):
    """Embed every sample served by ``loader`` without tracking gradients.

    Args:
        loader: DataLoader built with ``collate_fn`` (names are the third-from-last
            batch field and the scene lengths the last one, for every split).
        desc_encoder: model applied to the packed descriptions.
        scene_encoder: model applied to ``(scene_batch, scene_lengths)``.
        criterion: object exposing ``calculate_loss(similarity_matrix)``.
        feature_size: size of the embeddings produced by both encoders.
        wandb_batch_key: when given, every batch loss is logged to W&B under this key.

    Returns:
        ``(desc_embeddings, scene_embeddings, mean_batch_loss, names)``. The two
        embedding tensors live on the CPU and are row-aligned with each other.
    """
    n_items = len(loader.dataset)
    desc_embeddings = torch.empty(n_items, feature_size)
    scene_embeddings = torch.empty(n_items, feature_size)
    names_seen = []
    loss_sum, n_batches, start = 0.0, 0, 0

    with torch.no_grad():
        for batch in loader:
            data_desc_pov = batch[0].to(device)
            data_pov = batch[1].to(device)
            names, len_pov = batch[-3], batch[-1]

            out_desc = desc_encoder(data_desc_pov)
            out_scene = scene_encoder(data_pov, len_pov)

            # A running offset keeps the rows contiguous whatever the size of each batch.
            stop = start + out_desc.shape[0]
            desc_embeddings[start:stop, :] = out_desc
            scene_embeddings[start:stop, :] = out_scene
            start = stop
            names_seen.extend(names)

            batch_loss = criterion.calculate_loss(cosine_sim(out_desc, out_scene)).item()
            loss_sum += batch_loss
            n_batches += 1
            if wandb_batch_key is not None:
                wandb.log({wandb_batch_key: batch_loss})

    mean_loss = loss_sum / n_batches if n_batches else float("nan")
    return desc_embeddings, scene_embeddings, mean_loss, names_seen


for n_try in range(number_of_tries):
    lr = 0.001  # default: 0.008

    loss_fn = LossContrastive(approach_name, patience=25, delta=0.0001)

    model_desc_pov = GRUNet(hidden_size=output_feature_size, num_features=512, is_bidirectional=is_bidirectional)
    # Both encoders must emit `output_feature_size` features to be comparable.
    model_pov = MyBaseline(in_channels=512, out_channels=output_feature_size)

    model_desc_pov.to(device)
    model_pov.to(device)

    params = list(model_desc_pov.parameters()) + list(model_pov.parameters())
    optimizer = torch.optim.Adam(params, lr=lr)

    step_size = 27
    gamma = 0.75
    scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)
    sched_name = StepLR

    if use_wandb:
        wandb.init(
            # set the wandb project where this run will be logged
            project="Museums",

            # track hyperparameters and run metadata
            config={**{
                "batch_size": batch_size,
                "kernel_size": kernel_size,
                "learning_rate": lr,
                "architecture": MyBaseline,
                "epochs": num_epochs,
                "approach_name": approach_name,
                "output_feature_size": output_feature_size,
                "scheduler/name": sched_name,
            }, **{f"scheduler/{n}": v for n, v in scheduler.__dict__.items()}}
        )

    # Start below any reachable R@10 so the first epoch always writes a checkpoint;
    # otherwise a run that never exceeds 0 would leave nothing to reload for testing.
    best_r10 = -1.0
    print(f"({n_try+1}/{number_of_tries}) Train procedure ...")
    for ep in tqdm(range(num_epochs)):

        model_desc_pov.train()
        model_pov.train()

        total_loss_train = 0
        num_batches_train = 0

        for i, (data_desc_pov, data_pov, data_art, raw_descs, names, indexes, len_pov) in enumerate(train_loader):
            data_desc_pov = data_desc_pov.to(device)
            data_pov = data_pov.to(device)

            optimizer.zero_grad()

            output_desc_pov = model_desc_pov(data_desc_pov)
            output_pov = model_pov(data_pov, len_pov)

            multiplication_dp = cosine_sim(output_desc_pov, output_pov)
            loss_contrastive = loss_fn.calculate_loss(multiplication_dp)
            loss_contrastive.backward()
            optimizer.step()

            total_loss_train += loss_contrastive.item()
            num_batches_train += 1

            if use_wandb:
                wandb.log({
                    "train/loss": loss_contrastive.item(),
                    "scheduler/lr": scheduler.get_last_lr()[0]
                })

        scheduler.step()
        print(scheduler.get_last_lr())
        epoch_loss_train = total_loss_train / num_batches_train

        model_desc_pov.eval()
        model_pov.eval()

        # Validation procedure
        output_description_val, output_pov_val, epoch_loss_val, _ = _embed_split(
            val_loader, model_desc_pov, model_pov, loss_fn, output_feature_size,
            wandb_batch_key="val/batch_loss" if use_wandb else None)
        if use_wandb:
            wandb.log({
                "val/epoch_loss": epoch_loss_val,
            })

        # The recorded validation loss comes from validation batches only. It used to be
        # averaged together with the train-evaluation batches below, which shared its counters.
        print('Loss Train', epoch_loss_train)
        loss_fn.on_epoch_end(epoch_loss_train, train=True)
        print('Loss Val', epoch_loss_val)
        loss_fn.on_epoch_end(epoch_loss_val, train=False)

        r1, r5, r10, _, _, _, _, _, _, _ = evaluate(output_description=output_description_val,
                                                    output_scene=output_pov_val, section='val',
                                                    out_values=ep % 5 == 4)

        if r10 > best_r10:
            best_r10 = r10
            save_best_model(f"{approach_name}_{n_try}", run_folder, model_pov.state_dict(), model_desc_pov.state_dict())

        if use_wandb:
            wandb.log({
                "val/T2S_R@1": r1,
                "val/T2S_R@5": r5,
                "val/T2S_R@10": r10,
            })

        # Validation ON TRAIN procedure (models stay in eval mode, no gradients).
        output_description_val_train, output_pov_val_train, _, _ = _embed_split(
            train_loader, model_desc_pov, model_pov, loss_fn, output_feature_size)

        r1, r5, r10, _, _, _, _, _, _, _ = evaluate(output_description=output_description_val_train,
                                                    output_scene=output_pov_val_train, section='TRAIN',
                                                    out_values=ep % 5 == 4)
        if use_wandb:
            wandb.log({
                "train/T2S_R@1": r1,
                "train/T2S_R@5": r5,
                "train/T2S_R@10": r10,
            })

    # Reload the checkpoint with the best validation R@10 before testing.
    bm_pov, bm_desc_pov = load_best_model(f"{approach_name}_{n_try}", run_folder)
    model_pov.load_state_dict(bm_pov)
    model_desc_pov.load_state_dict(bm_desc_pov)

    model_pov.eval()
    model_desc_pov.eval()

    # Evaluate test set
    output_description_test, output_pov_test, _, test_names = _embed_split(
        test_loader, model_desc_pov, model_pov, loss_fn, output_feature_size)

    ds1, ds5, ds10, sd1, sd5, sd10, ndgc_10, ndcg, ds_medr, sd_medr, formatted_string = evaluate(
        output_description=output_description_test,
        output_scene=output_pov_test,
        section="test",
        out_values=True,
        excel_format=True)

    if use_wandb:
        wandb.log({
            "test/T2S_R@1": ds1,
            "test/T2S_R@5": ds5,
            "test/T2S_R@10": ds10,
            "test/S2T_R@1": sd1,
            "test/S2T_R@5": sd5,
            "test/S2T_R@10": sd10,
        })
        # Close this repetition's run so the next `wandb.init` starts a separate one.
        wandb.finish()

    final_output_strings.append(formatted_string)

for out_str in final_output_strings:
    print(out_str)