In [1]:
import torch
import torch.nn as nn
from vit_pytorch import ViT
from graphs import Graph, prims
import os
import numpy as np
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.models import ResNet18_Weights
from tree_dataset import TreeDataset
import model as m
from torch.utils.data import DataLoader
from d2lvit import *
import copy
from collections import OrderedDict

In [2]:
unused_digit_model = nn.Linear(1,1)

In [3]:
model = models.vit_b_16(weights=None)

In [4]:
train_set = TreeDataset(os.path.join('..', 'data', 'extra_variety_4k'), m.resnet_preprocess()) 
val_set = TreeDataset(os.path.join('..', 'data', 'extra_variety_2k'), m.resnet_preprocess()) 
print(f'Train size: {len(train_set)} Val size: {len(val_set)}')
train_loader = DataLoader(train_set, batch_size=32)
val_loader = DataLoader(val_set, batch_size=32)
device = m.get_device()

Train size: 4000 Val size: 2000
Identified CUDA device: NVIDIA GeForce RTX 3060


In [5]:
model.heads = nn.Sequential(OrderedDict([('head', nn.Linear(768, 45))]))
model = model.to(device)

In [6]:
model(torch.unsqueeze(train_set[0]['image'], 0).to(device))

tensor([[ 0.9964, -0.2394,  0.6590, -0.7964,  0.5954, -0.8024, -0.5738,  0.1781,
         -0.5073, -1.0062,  0.2076, -1.6032, -0.9090, -0.4220,  0.6519,  0.3125,
         -0.2106,  0.9760, -0.4351, -0.0734, -0.1038, -0.2923,  0.1310, -0.0565,
         -0.0155,  1.0305,  0.1852, -0.0431,  0.7668,  0.7625,  0.3871, -0.9237,
         -0.4097, -0.7336,  0.1342,  0.0116,  0.7653, -0.1571,  0.7928,  0.2072,
          0.0121,  0.7058, -0.3350, -0.2559,  0.7048]], device='cuda:0',
       grad_fn=<AddmmBackward0>)

In [7]:
m.train(model, 0.0001, 0, 100, train_loader, val_loader, device, os.path.join('..', 'models', 'untrained_torchvit_noprims_scratchtrees'), unused_digit_model, False, False)

Epoch 10 done, train loss: 0.0054 val acc: 0.7670
Epoch 20 done, train loss: 0.0014 val acc: 0.8810
Epoch 30 done, train loss: 0.0014 val acc: 0.8895
Epoch 40 done, train loss: 0.0009 val acc: 0.9180
Epoch 50 done, train loss: 0.0006 val acc: 0.9445
Epoch 60 done, train loss: 0.0000 val acc: 0.9610
Epoch 70 done, train loss: 0.0000 val acc: 0.9625
Epoch 80 done, train loss: 0.0000 val acc: 0.9595
Epoch 90 done, train loss: 0.0000 val acc: 0.9595
Epoch 100 done, train loss: 0.0000 val acc: 0.9595


In [8]:
model = torch.load(os.path.join('..', 'models', 'untrained_torchvit_noprims_scratchtrees', 'tree-model.pt'))

In [9]:
config = {'labels_key': 'tree_label'}
train_acc = m.predict(model, train_loader, device, config, unused_digit_model, True, False)
print(train_acc)

1.0


In [10]:
val_acc = m.predict(model, val_loader, device, config, unused_digit_model, True, False)
print(val_acc)

tree_10: tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
        0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       device='cuda:0')
tree_21: tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
        0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       device='cuda:0')
tree_31: tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
        0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       device='cuda:0')
tree_39: tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       device='cuda:0')
tree_90: tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
        0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       device='cuda:0')
tree_113: tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 

tree_1334: tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       device='cuda:0')
tree_1346: tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
        0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       device='cuda:0')
tree_1359: tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
        0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       device='cuda:0')
tree_1418: tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
        0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       device='cuda:0')
tree_1432: tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
        0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       device='cuda:0')
tree_1436: tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0