In [1]:
import torch
import torch.nn as nn
from vit_pytorch import ViT
from graphs import Graph, prims
import inspect
import os
import numpy as np
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.models import ResNet18_Weights
from tree_dataset import TreeDataset
import model as m
from torch.utils.data import DataLoader
# NOTE(review): star import, presumably so torch.load can resolve the pickled d2l ViT
# classes. It can also shadow names imported above (e.g. ViT), so keep it after them.
from d2lvit import *
import copy
from collections import OrderedDict

In [2]:
# Root folder holding one sub-folder per training run (e.g. 'd2lvit_3', 'finetune_3').
MODELS_DIR = os.path.join('..', 'models')

# torch.load only accepts `weights_only` from torch 1.13 on. torch 2.6 changed its
# default to True, which rejects whole pickled objects such as the models loaded below.
_TORCH_LOAD_HAS_WEIGHTS_ONLY = 'weights_only' in inspect.signature(torch.load).parameters


def load_checkpoint(run_name: str, filename: str):
    """Load the pickled object stored at ``MODELS_DIR/run_name/filename``.

    Full unpickling is requested explicitly (``weights_only=False``) whenever the
    installed torch supports that flag. Unpickling can execute arbitrary code, so
    only use this on checkpoints produced by this project.

    Args:
        run_name: Sub-folder of ``MODELS_DIR`` for one training run.
        filename: Checkpoint file inside that folder, e.g. ``'tree-model.pt'``.

    Returns:
        Whatever ``torch.load`` returns for the file (used below as a model).
    """
    path = os.path.join(MODELS_DIR, run_name, filename)
    if _TORCH_LOAD_HAS_WEIGHTS_ONLY:
        return torch.load(path, weights_only=False)
    return torch.load(path)


# Digit models.
d2l_digit_model = load_checkpoint('d2lvit_3', 'digit-model.pt')
torch_digit_model = load_checkpoint('finetune_3', 'digit-model.pt')
untrained_digit_model = load_checkpoint('untrained_torchvit', 'digit-model.pt')
# Placeholder handed to m.predict for the scratch-tree runs, which are not paired
# with a trained digit model.
unused_digit_model = nn.Linear(1,1)

# Tree models from the same runs as the digit models above.
d2l_tree_model = load_checkpoint('d2lvit_3', 'tree-model.pt')
torch_tree_model = load_checkpoint('finetune_3', 'tree-model.pt')
untrained_tree_model = load_checkpoint('untrained_torchvit', 'tree-model.pt')

# Tree models from the '*_noprims' runs.
d2l_tree_noprims_model = load_checkpoint('d2lvit_noprims', 'tree-model.pt')
torch_tree_noprims_model = load_checkpoint('finetune_noprims', 'tree-model.pt')
untrained_tree_noprims_model = load_checkpoint('untrained_torchvit_noprims', 'tree-model.pt')

# Tree models from the '*_scratchtrees' runs.
d2l_scratchtree_model = load_checkpoint('d2lvit_scratchtrees', 'tree-model.pt')
torch_scratchtree_model = load_checkpoint('finetune_scratchtrees', 'tree-model.pt')
untrained_scratchtree_model = load_checkpoint('untrained_torchvit_noprims_scratchtrees', 'tree-model.pt')

In [15]:
# Evaluation data and loader settings shared by every model evaluated below.
TEST_DATA_DIR = os.path.join('..', 'data', 'super_variety_2k')
BATCH_SIZE = 32

# Preprocessing for the d2l ViT models: single-channel, 224x224 centre crop, converted
# to a tensor with no normalisation (assumes TreeDataset yields PIL images — TODO confirm).
# The torchvision-based models use m.resnet_preprocess() instead.
preprocess = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor()
])
torch_test_set = TreeDataset(TEST_DATA_DIR, m.resnet_preprocess())
d2l_test_set = TreeDataset(TEST_DATA_DIR, preprocess)
# Both views come from the same directory, so they must describe the same samples.
assert len(torch_test_set) == len(d2l_test_set)
print(f'Test size: {len(torch_test_set)}')
# shuffle stays at its default (False), so evaluation iterates in dataset index order.
torch_loader = DataLoader(torch_test_set, batch_size=BATCH_SIZE)
d2l_loader = DataLoader(d2l_test_set, batch_size=BATCH_SIZE)
device = m.get_device()
# Tree-model evaluations read their targets from this dataset key.
config = {'labels_key': 'tree_label'}

Test size: 2000
Identified CUDA device: NVIDIA GeForce RTX 3060


In [4]:
# d2lvit_3 tree model with its matching digit model, on d2l-preprocessed inputs.
# NOTE(review): the trailing positional flag is interpreted by m.predict; passing it
# by keyword would make its intent readable here.
d2l_acc = m.predict(
    d2l_tree_model,
    d2l_loader,
    device,
    config,
    d2l_digit_model,
    False,
)
print(d2l_acc)

0.1135


In [5]:
# finetune_3 tree model with its matching digit model, on resnet-preprocessed inputs.
# NOTE(review): unlike the d2l and untrained cells, m.predict's optional positional
# flags are left at their defaults here — confirm this is intentional.
torch_acc = m.predict(
    torch_tree_model,
    torch_loader,
    device,
    config,
    torch_digit_model,
)
print(torch_acc)

0.1135


In [6]:
# untrained_torchvit tree model with its matching digit model, on resnet-preprocessed inputs.
untrained_acc = m.predict(
    untrained_tree_model,
    torch_loader,
    device,
    config,
    untrained_digit_model,
    False,
)
print(untrained_acc)

0.113


In [7]:
# d2lvit_noprims tree model; the digit model is still the d2lvit_3 one.
# NOTE(review): both trailing positional flags are interpreted by m.predict.
d2l_noprims_acc = m.predict(
    d2l_tree_noprims_model,
    d2l_loader,
    device,
    config,
    d2l_digit_model,
    False,
    False,
)
print(d2l_noprims_acc)

0.113


In [8]:
# finetune_noprims tree model; the digit model is still the finetune_3 one.
torch_noprims_acc = m.predict(
    torch_tree_noprims_model,
    torch_loader,
    device,
    config,
    torch_digit_model,
    False,
    False,
)
print(torch_noprims_acc)

0.1135


In [9]:
# untrained_torchvit_noprims tree model; the digit model is still the untrained_torchvit one.
untrained_noprims_acc = m.predict(
    untrained_tree_noprims_model,
    torch_loader,
    device,
    config,
    untrained_digit_model,
    False,
    False,
)
print(untrained_noprims_acc)

0.1125


In [10]:
# d2lvit_scratchtrees tree model, paired with the placeholder nn.Linear(1,1)
# instead of a trained digit model.
d2l_scratchtree_acc = m.predict(
    d2l_scratchtree_model,
    d2l_loader,
    device,
    config,
    unused_digit_model,
    False,
    False,
)
print(d2l_scratchtree_acc)

0.1135


In [11]:
# finetune_scratchtrees tree model, paired with the placeholder nn.Linear(1,1)
# instead of a trained digit model.
torch_scratchtree_acc = m.predict(
    torch_scratchtree_model,
    torch_loader,
    device,
    config,
    unused_digit_model,
    False,
    False,
)
print(torch_scratchtree_acc)

0.1135


In [12]:
# untrained_torchvit_noprims_scratchtrees tree model, paired with the placeholder
# nn.Linear(1,1) instead of a trained digit model.
untrained_scratchtree_acc = m.predict(
    untrained_scratchtree_model,
    torch_loader,
    device,
    config,
    unused_digit_model,
    False,
    False,
)
print(untrained_scratchtree_acc)

0.1125


In [16]:
# Models from the finetune_super4 run: score the digit model against the digit labels,
# then the tree model (fed that digit model) against the tree labels in `config`.
#
# torch.load only accepts `weights_only` from torch 1.13 on, and torch 2.6 defaults it
# to True, which rejects whole pickled models. Request full unpickling where supported.
# Only load trusted, project-produced checkpoints this way.
super4_load_kwargs = (
    {'weights_only': False}
    if 'weights_only' in inspect.signature(torch.load).parameters
    else {}
)
super4_dir = os.path.join('..', 'models', 'finetune_super4')
trained_torch_digit_model = torch.load(os.path.join(super4_dir, 'digit-model.pt'), **super4_load_kwargs)
trained_torch_tree_model = torch.load(os.path.join(super4_dir, 'tree-model.pt'), **super4_load_kwargs)
print(m.predict(trained_torch_digit_model, torch_loader, device, {'labels_key': 'digit_labels'}, None))
print(m.predict(trained_torch_tree_model, torch_loader, device, config, trained_torch_digit_model))

0.8545
0.8485
