# Test models

In [1]:
from datasets import PlacesAudio
import sys
from opts import get_arguments
from utils import util as u
from torch.utils.data import Dataset, DataLoader, dataset
from torch.autograd import Variable
import time
import torch
from tqdm import tqdm


In [None]:
def validate(val_loader, model, criterion, device, epoch, args):
    """Evaluate ``model`` on ``val_loader`` and print loss and retrieval recalls.

    The model is put in eval mode and run without gradients. For each batch the
    InfoNCE loss is accumulated in an ``AverageMeter`` (passing the batch size
    as the second argument to ``update``) and the image/audio embeddings are
    stored. After the loop all embeddings are concatenated, a single similarity
    matrix is built with ``u.similarity_matrix_bxb`` and the recalls at 1, 5
    and 10 are obtained from ``u.topk_accuracies``.

    Args:
        val_loader: Sized iterable yielding ``(image, spec, audio, name, im)``
            batches; only ``image`` and ``spec`` are used here.
        model: Called as ``model(image, spec, args, mode='val')``; must return
            ``(image_embeddings, audio_embeddings)``.
        criterion: Unused. Kept so existing callers keep working.
        device: Device the input tensors are moved to.
        epoch: Value shown in the log line.
        args: Options object forwarded to the model and to ``u.infoNCE_loss``.

    Returns:
        ``(losses.avg, recalls)`` where ``recalls`` is the mapping returned by
        ``u.topk_accuracies``; this function reads its keys ``A_r1``, ``A_r5``,
        ``A_r10``, ``I_r1``, ``I_r5`` and ``I_r10``.
    """
    losses = u.AverageMeter()

    img_embs_all = []
    aud_embs_all = []
    # Number of examples actually evaluated. len(val_loader) is the number of
    # *batches*, so it must not be reported as the number of validation pairs.
    n_examples = 0

    tic = time.time()

    model.eval()

    with torch.no_grad():
        for image, spec, audio, name, im in tqdm(val_loader, total=len(val_loader)):

            # Tensors can be moved directly; the Variable wrapper is deprecated.
            spec = spec.to(device, non_blocking=True)
            image = image.to(device, non_blocking=True)
            B = image.size(0)

            # No .detach() needed afterwards: nothing is tracked under no_grad().
            imgs_out, auds_out = model(image.float(), spec.float(), args, mode='val')

            loss_cl = u.infoNCE_loss(imgs_out, auds_out, args)

            img_embs_all.append(imgs_out)
            aud_embs_all.append(auds_out)

            losses.update(loss_cl.item(), B)
            n_examples += B

    # Recalls are computed once over the whole set, not averaged per batch.
    imgs_out_all = torch.cat(img_embs_all)
    auds_out_all = torch.cat(aud_embs_all)

    sims = u.similarity_matrix_bxb(imgs_out_all, auds_out_all)

    recalls = u.topk_accuracies(sims, [1, 5, 10])

    print('Epoch: [{0}]\t Eval '
          'Loss: {loss.avg:.4f}  \t T-epoch: {t:.2f} \t'
          .format(epoch, loss=losses, t=time.time()-tic))

    for k in (10, 5, 1):
        print(' * Audio R@{k} {a:.3f} Image R@{k} {i:.3f} over {N:d} validation pairs'
              .format(k=k, a=recalls["A_r{}".format(k)], i=recalls["I_r{}".format(k)],
                      N=n_examples), flush=True)

    return losses.avg, recalls

In [None]:
# Simulate command-line arguments for loading the model
# NOTE(review): no shell is involved here, so '$DATA' reaches the argument
# parser as a literal string. Confirm that something downstream expands it.
sys.argv = ['script_name', '--order_3_tensor', 
            '--simtype', 'MISA', 
            '--placesAudio', '$DATA/PlacesAudio_400k_distro/metadata/',
            '--batch_size', "32", 
            '--n_threads', '0',
            '--val_video_idx', '10']

args = get_arguments()

# Directory holding the checkpoints; files are named
# '<model_name>-epoch<epoch>.pth.tar'.
MODELS_DIR = '/home/asantos/models/to_test'

# Checkpoints to evaluate, identified by run name and saved epoch.
models = [
    {"model_name": "SSL_TIE_PlacesAudio-lr1e-5-2ly-B128-SISA-1GPUS-wV", "epoch": 13},
    {"model_name": "SSL_TIE_PlacesAudio_lr1e-5-2ly-B32-MISA", "epoch": 100},
    {"model_name": "SSL_TIE_PlacesAudio_lr1e-5-2ly-B128-MISA", "epoch": 100},
    {"model_name": "SSL_TIE_PlacesAudio_lr1e-5-2ly-B256-MISA", "epoch": 100},
    {"model_name": "SSL_TIE_PlacesAudio_lr1e-5-2ly-B128-SISA", "epoch": 100},
    {"model_name": "SSL_TIE_PlacesAudio_lr1e-3-2ly-B128-SISA", "epoch": 39},
    {"model_name": "SSL_TIE_PlacesAudio_lr1e-4-2ly-B32-MISA", "epoch": 100},
]

# Keys copied from the recalls mapping into each result row, in column order.
RECALL_KEYS = ("A_r10", "A_r5", "A_r1", "I_r10", "I_r5", "I_r1")

# The validation data is shared by every checkpoint, so it is built once.
val_dataset = PlacesAudio(args.placesAudio + 'val.json', args, mode='val')

# NOTE(review): drop_last=True discards a final incomplete batch, so those
# examples are not evaluated. Confirm this is intended for validation.
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                        num_workers=args.n_threads, drop_last=True, pin_memory=True)

# Each checkpoint is loaded exactly once, inside the loop. Loading models[0]
# separately beforehand only duplicated the first iteration's work.
results = {}
for model_info in models:
    model_path = f'{MODELS_DIR}/{model_info["model_name"]}-epoch{model_info["epoch"]}.pth.tar'
    model, device = u.load_model(model_path, args)
    print(f'Validating model: {model_info["model_name"]}, Epoch: {model_info["epoch"]}')
    loss, recalls = validate(val_loader, model, None, device, epoch=model_info["epoch"], args=args)

    row = {"epoch": model_info["epoch"], "loss": loss}
    row.update((key, recalls[key]) for key in RECALL_KEYS)
    results[model_info["model_name"]] = row

NameError: name 'sys' is not defined

In [8]:
import pandas as pd

# One row per model (the outer keys of `results`), one column per metric,
# ordered from highest to lowest loss.
df = (
    pd.DataFrame(results)
    .T
    .sort_values(by='loss', ascending=False)
)
df

Unnamed: 0,epoch,loss,acc_a_v,acc_v_a
SSL_TIE_PlacesAudio_lr1e-5-2ly-B128-SISA,100.0,5.398741,0.275202,0.183468
SSL_TIE_PlacesAudio_lr1e-5-2ly-B128-MISA,100.0,4.528646,0.270161,0.232863
SSL_TIE_PlacesAudio_lr1e-5-2ly-B256-MISA,100.0,4.520318,0.257056,0.221774
SSL_TIE_PlacesAudio-lr1e-5-2ly-B128-SISA-1GPUS-wV,13.0,3.928914,0.287298,0.215726
SSL_TIE_PlacesAudio_lr1e-3-2ly-B128-SISA,39.0,3.465736,0.15625,0.15625
SSL_TIE_PlacesAudio_lr1e-4-2ly-B32-MISA,100.0,3.465736,0.15625,0.15625
SSL_TIE_PlacesAudio_lr1e-5-2ly-B32-MISA,100.0,3.315941,0.478831,0.439516
