# Evaluate trained models on the test set and the hold-out set

In [None]:
import os
import glob
import json
from os.path import join

import torchvision.datasets as datasets
import torchvision.transforms as transforms
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import random
import json

from ShellRec.data_utils.prepare_photos import get_img_graph, split_graph
from ShellRec.model import TurtleDiff 
from ShellRec.dataset import TurtlePair 
from ShellRec.inference import test_model

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Evaluation preprocessing: resize every image to 384x384, convert it to a
# tensor, then normalise each of the three channels with mean 0.5 / std 0.5.
transform_test = transforms.Compose(
    [
        transforms.Resize((384, 384)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# Regular test split and the BoxTurtle hold-out split share the same transform.
test_set = TurtlePair(data_file='../dataset/test.json', transform=transform_test)
holdout_set = TurtlePair(data_file='../dataset/BoxTurtle_holdout.json', transform=transform_test)

# shuffle=False keeps batches in dataset order, so evaluation is reproducible.
test_loader, holdout_loader = (
    DataLoader(ds, batch_size=32, shuffle=False) for ds in (test_set, holdout_set)
)

## ViT with head only (no last-attention pooling)

### Pooled head

In [None]:
from ShellRec.model import TurtleDiffPool

# Pooled-head ViT: rebuild the architecture on the chosen device, restore the
# trained weights (tensors remapped onto `device`), and switch to eval mode.
model = TurtleDiffPool('vit_base_patch16_384').to(device)
state = torch.load('vit_base_patch16_384-pool_turtle_identifier.pth', map_location=device)
model.load_state_dict(state)
model.eval()

# Score the test split first, then the hold-out split.
test_results = test_model(model, test_loader, device)
holdout_results = test_model(model, holdout_loader, device)

# Write each result set to its own JSON file, test first, then hold-out.
for payload, out_name in (
    (test_results, 'vit_base_patch16_384_pool_test.json'),
    (holdout_results, 'vit_base_patch16_384_pool_holdout.json'),
):
    with open(out_name, 'w') as fh:
        json.dump(payload, fh)


### Simple difference head

In [None]:
from ShellRec.model import TurtleDiff

# Simple-difference-head ViT: rebuild the architecture on the chosen device,
# restore the trained weights (tensors remapped onto `device`), and switch to
# eval mode.
model = TurtleDiff('vit_base_patch16_384').to(device)
state = torch.load('vit_base_patch16_384_turtle_identifier.pth', map_location=device)
model.load_state_dict(state)
model.eval()

# Score the test split first, then the hold-out split.
test_results = test_model(model, test_loader, device)
holdout_results = test_model(model, holdout_loader, device)

# Write each result set to its own JSON file, test first, then hold-out.
for payload, out_name in (
    (test_results, 'vit_base_patch16_384_test.json'),
    (holdout_results, 'vit_base_patch16_384_holdout.json'),
):
    with open(out_name, 'w') as fh:
        json.dump(payload, fh)

## ViT with last-attention pooling head

In [None]:
from ShellRec.model import TurtleDiffAttnPool

# Attention-pooled-head ViT: rebuild the architecture on the chosen device,
# restore the trained weights (tensors remapped onto `device`), and switch to
# eval mode.
model = TurtleDiffAttnPool('vit_base_patch16_384').to(device)
state = torch.load('vit_base_patch16_384-attn-pool_turtle_identifier.pth', map_location=device)
model.load_state_dict(state)
model.eval()

# Score the test split first, then the hold-out split.
test_results = test_model(model, test_loader, device)
holdout_results = test_model(model, holdout_loader, device)

# Write each result set to its own JSON file, test first, then hold-out.
for payload, out_name in (
    (test_results, 'vit_base_patch16_384_attn_pool_test.json'),
    (holdout_results, 'vit_base_patch16_384_attn_pool_holdout.json'),
):
    with open(out_name, 'w') as fh:
        json.dump(payload, fh)