# Test models

In [None]:
from datasets import PlacesAudio
import sys
from opts import get_arguments
from utils import util as u
from torch.utils.data import Dataset, DataLoader, dataset
from torch.autograd import Variable
import time
import torch
from tqdm import tqdm


In [None]:
def validate(val_loader, model, criterion, device, epoch, args):
    """Evaluate ``model`` on ``val_loader`` and report loss and top-5 accuracy.

    The model is put in eval mode and run under ``torch.no_grad()``. For every
    batch, the contrastive loss and similarity matrix come from
    ``u.infoNCE_loss``. The two top-5 accuracies come from ``u.topk_accuracy``.

    Args:
        val_loader: iterable yielding ``(image, spec, audio, name, im)``
            batches. Only ``image`` and ``spec`` are used here.
        model: called as ``model(image, spec, args, mode='val')``. It must
            return ``(imgs_out, auds_out)``.
        criterion: unused; kept so existing callers keep working.
        device: target device the ``image``/``spec`` tensors are moved to.
        epoch: epoch number; used only in the printed summary line.
        args: forwarded to the model and to ``u.infoNCE_loss``.

    Returns:
        ``(mean_loss, mean_acc_a_v, mean_acc_v_a)``. The loss is averaged
        per sample. The two accuracies are averaged per batch.
    """
    losses = u.AverageMeter()
    acc_a_v_meter = u.AverageMeter()
    acc_v_a_meter = u.AverageMeter()

    tic = time.time()

    model.eval()

    with torch.no_grad():
        for image, spec, _audio, _name, _im in tqdm(val_loader, total=len(val_loader)):
            # ``Variable`` has been a no-op since PyTorch 0.4; move tensors directly.
            spec = spec.to(device, non_blocking=True)
            image = image.to(device, non_blocking=True)
            batch_size = image.size(0)

            imgs_out, auds_out = model(image.float(), spec.float(), args, mode='val')

            loss_cl, sims = u.infoNCE_loss(imgs_out, auds_out, args, return_S=True)
            acc_v_a, acc_a_v = u.topk_accuracy(sims, k=5)

            losses.update(loss_cl.item(), batch_size)
            # NOTE(review): the accuracies are not weighted by batch size.
            # That is exact here because the loader uses drop_last=True.
            # Confirm what topk_accuracy returns before weighting them.
            acc_a_v_meter.update(acc_a_v)
            acc_v_a_meter.update(acc_v_a)

    print('Epoch: [{0}]\t Eval '
          'Loss: {loss.avg:.4f}  \t acc_a_v@5: {a_v:.4f} \t acc_v_a@5: {v_a:.4f} \t'
          ' T-epoch: {t:.2f} \t'
          .format(epoch, loss=losses, a_v=acc_a_v_meter.avg,
                  v_a=acc_v_a_meter.avg, t=time.time() - tic))
    # Previously this returned (losses.avg, 0, 0) and discarded the accuracies.
    return losses.avg, acc_a_v_meter.avg, acc_v_a_meter.avg

In [None]:
# Simulate command-line arguments for loading the model.
# NOTE: setting sys.argv from Python bypasses the shell, so "$DATA" is not
# expanded automatically. Expand it explicitly. expandvars leaves the string
# unchanged if DATA is not set.
import os

sys.argv = ['script_name', '--order_3_tensor',
            '--simtype', 'MISA',
            '--placesAudio', os.path.expandvars('$DATA/PlacesAudio_400k_distro/metadata/'),
            '--batch_size', "32",
            '--n_threads', '0']

args = get_arguments()

# Candidate checkpoints: file name prefix plus the epoch the checkpoint was saved at.
models = [
    {"model_name": "SSL_TIE_PlacesAudio-lr1e-5-2ly-B128-SISA-1GPUS-wV", "epoch": 13},
    {"model_name": "SSL_TIE_PlacesAudio_lr1e-5-2ly-B32-MISA", "epoch": 100},
    {"model_name": "SSL_TIE_PlacesAudio_lr1e-5-2ly-B128-MISA", "epoch": 100},
    {"model_name": "SSL_TIE_PlacesAudio_lr1e-5-2ly-B256-MISA", "epoch": 100},
    {"model_name": "SSL_TIE_PlacesAudio_lr1e-5-2ly-B128-SISA", "epoch": 100},
    {"model_name": "SSL_TIE_PlacesAudio_lr1e-3-2ly-B128-SISA", "epoch": 39},
    {"model_name": "SSL_TIE_PlacesAudio_lr1e-4-2ly-B32-MISA", "epoch": 100},
]

# Pick the checkpoint to evaluate. Keep the dict separate from the path string.
selected = models[0]
model_path = f'/home/asantos/models/to_test/{selected["model_name"]}-epoch{selected["epoch"]}.pth.tar'
print(model_path)

model, device = u.load_model(model_path, args)

val_dataset = PlacesAudio(args.placesAudio + 'val.json', args, mode='val')

val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                        num_workers=args.n_threads, drop_last=True, pin_memory=True)

# Report the checkpoint's actual epoch instead of a hard-coded 100.
validate(val_loader, model, None, device, epoch=selected["epoch"], args=args)



'/home/asantos/models/to_test/SSL_TIE_PlacesAudio-lr1e-5-2ly-B128-SISA-1GPUS-wV-epoch13.pth.tar'