# Test models

In [1]:
import sys
import time
from pathlib import Path

import pandas as pd
import torch
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader, dataset
from tqdm import tqdm

from datasets import PlacesAudio
from opts import get_arguments
from utils import util as u


In [4]:
def validate(val_loader, model, criterion, device, epoch, args):
    """Run one evaluation pass and report loss plus cross-modal recall.

    Every batch is pushed through ``model`` in ``mode='val'``; the per-batch
    ``u.infoNCE_loss`` is averaged (weighted by batch size) and the image and
    audio embeddings of all batches are concatenated so that recall is
    computed over the whole evaluated set rather than per batch.

    Args:
        val_loader: Sized iterable yielding ``(image, spec, audio, name, im)``
            tuples; only ``image`` and ``spec`` are used here.
        model: Module called as ``model(image, spec, args, mode='val')`` and
            returning ``(image_embeddings, audio_embeddings)``.
        criterion: Unused; kept so existing call sites keep working.
        device: Device the inputs are moved to before the forward pass.
        epoch: Only used to label the printed summary.
        args: Forwarded unchanged to ``model`` and ``u.infoNCE_loss``.

    Returns:
        ``(average_loss, recalls)`` where ``recalls`` is the dict returned by
        ``u.topk_accuracies`` (keys ``A_r1``, ``A_r5``, ``A_r10``, ``I_r1``,
        ``I_r5``, ``I_r10`` are read for the printed summary).

    Raises:
        RuntimeError: If ``val_loader`` yields no batches.
    """
    losses = u.AverageMeter()
    img_embs_all = []
    aud_embs_all = []

    tic = time.time()

    model.eval()

    with torch.no_grad():
        for image, spec, _audio, _name, _im in tqdm(val_loader, total=len(val_loader)):
            spec = spec.to(device, non_blocking=True)
            image = image.to(device, non_blocking=True)
            batch_size = image.size(0)

            imgs_out, auds_out = model(image.float(), spec.float(), args, mode='val')
            imgs_out = imgs_out.detach()
            auds_out = auds_out.detach()

            loss_cl = u.infoNCE_loss(imgs_out, auds_out, args)

            img_embs_all.append(imgs_out)
            aud_embs_all.append(auds_out)

            # Weight by batch size so a smaller final batch does not skew the mean.
            losses.update(loss_cl.item(), batch_size)

    if not img_embs_all:
        raise RuntimeError('validate() received an empty val_loader; nothing to evaluate')

    imgs_out_all = torch.cat(img_embs_all)
    auds_out_all = torch.cat(aud_embs_all)

    sims = u.similarity_matrix_bxb(imgs_out_all, auds_out_all)
    recalls = u.topk_accuracies(sims, [1, 5, 10])

    # Count the pairs that were actually embedded. The previous
    # ``len(val_loader) * B`` used the size of the *last* batch and was wrong
    # whenever that batch was smaller than the others.
    N_examples = imgs_out_all.size(0)

    print('Epoch: [{0}]\t Eval '
          'Loss: {loss.avg:.4f}  \t T-epoch: {t:.2f} \t'
          .format(epoch, loss=losses, t=time.time()-tic))

    for k in (10, 5, 1):
        print(' * Audio R@{k} {a:.3f} Image R@{k} {i:.3f} over {N:d} validation pairs'
              .format(k=k, a=recalls["A_r{}".format(k)], i=recalls["I_r{}".format(k)],
                      N=N_examples), flush=True)

    return losses.avg, recalls

In [5]:
# Simulate command-line arguments so get_arguments() can be reused from the notebook.
# NOTE(review): '$DATA' is passed as a literal string; Python does not expand shell
# variables placed in sys.argv, so presumably opts/PlacesAudio resolve it — confirm.
sys.argv = ['script_name', '--order_3_tensor', 
            '--simtype', 'MISA', 
            '--placesAudio', '$DATA/PlacesAudio_400k_distro/metadata/',
            '--batch_size', "32", 
            '--n_threads', '0',
            '--val_video_idx', '10']

args = get_arguments()

# Directory holding the checkpoints, named '<model_name>-epoch<epoch>.pth.tar'.
MODELS_DIR = '/home/asantos/models/to_test'

# Keys copied from validate()'s recall dict into the per-model summary, in column order.
RECALL_KEYS = ("A_r10", "A_r5", "A_r1", "I_r10", "I_r5", "I_r1")

models = [
    {"model_name": "SSL_TIE_PlacesAudio-lr1e-5-2ly-B128-SISA-1GPUS-wV", "epoch": 13},
    {"model_name": "SSL_TIE_PlacesAudio_lr1e-5-2ly-B32-MISA", "epoch": 100},
    {"model_name": "SSL_TIE_PlacesAudio_lr1e-5-2ly-B128-MISA", "epoch": 100},
    {"model_name": "SSL_TIE_PlacesAudio_lr1e-5-2ly-B256-MISA", "epoch": 100},
    {"model_name": "SSL_TIE_PlacesAudio_lr1e-5-2ly-B128-SISA", "epoch": 100},
    {"model_name": "SSL_TIE_PlacesAudio_lr1e-3-2ly-B128-SISA", "epoch": 39},
    {"model_name": "SSL_TIE_PlacesAudio_lr1e-4-2ly-B32-MISA", "epoch": 100},
]

val_dataset = PlacesAudio(args.placesAudio + 'val.json', args,mode='val')

# drop_last=True discards the final incomplete batch, so fewer pairs than the
# dataset size are evaluated (the recorded run shows 992 of 1000).
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                        num_workers=args.n_threads, drop_last=True, pin_memory=True)

# Evaluate every checkpoint on the same loader. Each model is loaded exactly
# once, inside the loop; the earlier extra load of models[0] was redundant.
results = {}
for model_info in models:
    model_path = f'{MODELS_DIR}/{model_info["model_name"]}-epoch{model_info["epoch"]}.pth.tar'
    model, device = u.load_model(model_path, args)
    print(f'\nValidating model: {model_info["model_name"]}, Epoch: {model_info["epoch"]}')
    loss, recalls = validate(val_loader, model, None, device, epoch=model_info["epoch"], args=args)

    summary = {"epoch": model_info["epoch"], "loss": loss}
    summary.update((key, recalls[key]) for key in RECALL_KEYS)
    results[model_info["model_name"]] = summary

PlacesAudio split: VAL dataset size: 1000

Validating model: SSL_TIE_PlacesAudio-lr1e-5-2ly-B128-SISA-1GPUS-wV, Epoch: 13


100%|██████████| 31/31 [01:07<00:00,  2.18s/it]


Epoch: [13]	 Eval Loss: 3.9289  	 T-epoch: 95.44 	
 * Audio R@10 0.019 Image R@10 0.030 over 992 validation pairs
 * Audio R@5 0.008 Image R@5 0.016 over 992 validation pairs
 * Audio R@1 0.002 Image R@1 0.004 over 992 validation pairs

Validating model: SSL_TIE_PlacesAudio_lr1e-5-2ly-B32-MISA, Epoch: 100


100%|██████████| 31/31 [00:55<00:00,  1.80s/it]


Epoch: [100]	 Eval Loss: 3.3159  	 T-epoch: 84.49 	
 * Audio R@10 0.075 Image R@10 0.085 over 992 validation pairs
 * Audio R@5 0.048 Image R@5 0.064 over 992 validation pairs
 * Audio R@1 0.009 Image R@1 0.014 over 992 validation pairs

Validating model: SSL_TIE_PlacesAudio_lr1e-5-2ly-B128-MISA, Epoch: 100


100%|██████████| 31/31 [01:06<00:00,  2.14s/it]


Epoch: [100]	 Eval Loss: 4.5286  	 T-epoch: 94.78 	
 * Audio R@10 0.014 Image R@10 0.035 over 992 validation pairs
 * Audio R@5 0.008 Image R@5 0.019 over 992 validation pairs
 * Audio R@1 0.001 Image R@1 0.003 over 992 validation pairs

Validating model: SSL_TIE_PlacesAudio_lr1e-5-2ly-B256-MISA, Epoch: 100


100%|██████████| 31/31 [00:57<00:00,  1.85s/it]


Epoch: [100]	 Eval Loss: 4.5203  	 T-epoch: 88.32 	
 * Audio R@10 0.017 Image R@10 0.025 over 992 validation pairs
 * Audio R@5 0.012 Image R@5 0.014 over 992 validation pairs
 * Audio R@1 0.004 Image R@1 0.003 over 992 validation pairs

Validating model: SSL_TIE_PlacesAudio_lr1e-5-2ly-B128-SISA, Epoch: 100


100%|██████████| 31/31 [00:54<00:00,  1.75s/it]


Epoch: [100]	 Eval Loss: 5.3987  	 T-epoch: 86.47 	
 * Audio R@10 0.011 Image R@10 0.030 over 992 validation pairs
 * Audio R@5 0.006 Image R@5 0.014 over 992 validation pairs
 * Audio R@1 0.001 Image R@1 0.001 over 992 validation pairs

Validating model: SSL_TIE_PlacesAudio_lr1e-3-2ly-B128-SISA, Epoch: 39


100%|██████████| 31/31 [00:55<00:00,  1.78s/it]


Epoch: [39]	 Eval Loss: 3.4657  	 T-epoch: 86.24 	
 * Audio R@10 0.010 Image R@10 0.010 over 992 validation pairs
 * Audio R@5 0.005 Image R@5 0.005 over 992 validation pairs
 * Audio R@1 0.001 Image R@1 0.001 over 992 validation pairs

Validating model: SSL_TIE_PlacesAudio_lr1e-4-2ly-B32-MISA, Epoch: 100


100%|██████████| 31/31 [00:53<00:00,  1.73s/it]


Epoch: [100]	 Eval Loss: 3.4657  	 T-epoch: 85.30 	
 * Audio R@10 0.010 Image R@10 0.010 over 992 validation pairs
 * Audio R@5 0.005 Image R@5 0.005 over 992 validation pairs
 * Audio R@1 0.001 Image R@1 0.001 over 992 validation pairs


In [8]:
# Where the summary table is written; the parent directory is created on demand
# so a fresh checkout does not fail on a missing 'garbage/' folder.
RESULTS_CSV = Path('garbage/results.csv')
RESULTS_CSV.parent.mkdir(parents=True, exist_ok=True)

# One row per model (results is keyed by model name), highest loss first.
# NOTE(review): a loss of 3.4657 equals ln(32) (the batch size) and the matching
# recalls equal k/992, which looks like chance level — those checkpoints may have
# collapsed or not loaded correctly; confirm before comparing runs.
df = pd.DataFrame(results).transpose().sort_values(by='loss', ascending=False)
df.to_csv(RESULTS_CSV, index=True)
df

Unnamed: 0,epoch,loss,A_r10,A_r5,A_r1,I_r10,I_r5,I_r1
SSL_TIE_PlacesAudio_lr1e-5-2ly-B128-SISA,100.0,5.398741,0.011089,0.006048,0.001008,0.030242,0.014113,0.001008
SSL_TIE_PlacesAudio_lr1e-5-2ly-B128-MISA,100.0,4.528646,0.014113,0.008065,0.001008,0.035282,0.019153,0.003024
SSL_TIE_PlacesAudio_lr1e-5-2ly-B256-MISA,100.0,4.520318,0.017137,0.012097,0.004032,0.025202,0.014113,0.003024
SSL_TIE_PlacesAudio-lr1e-5-2ly-B128-SISA-1GPUS-wV,13.0,3.928914,0.019153,0.008065,0.002016,0.030242,0.016129,0.004032
SSL_TIE_PlacesAudio_lr1e-3-2ly-B128-SISA,39.0,3.465736,0.010081,0.00504,0.001008,0.010081,0.00504,0.001008
SSL_TIE_PlacesAudio_lr1e-4-2ly-B32-MISA,100.0,3.465736,0.010081,0.00504,0.001008,0.010081,0.00504,0.001008
SSL_TIE_PlacesAudio_lr1e-5-2ly-B32-MISA,100.0,3.315941,0.074597,0.048387,0.009073,0.084677,0.063508,0.014113
