In [1]:
import torch
import torch.nn as nn
from vit_pytorch import ViT
from graphs import Graph, prims
import os
import numpy as np
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.models import ResNet18_Weights

In [2]:
model = models.vit_b_16(weights="ViT_B_16_Weights.IMAGENET1K_V1")

In [3]:
from tree_dataset import TreeDataset, show_img
import model as m
from torch.utils.data import DataLoader
import os
import torchvision.transforms as transforms
from collections import OrderedDict

In [4]:
train_set = TreeDataset(os.path.join('..', 'data', 'super_variety_4k'), transforms.Compose([
    transforms.RandAugment(),
    m.resnet_preprocess()
])) 
val_set = TreeDataset(os.path.join('..', 'data', 'super_variety_1k'), m.resnet_preprocess()) 
print(f'Train size: {len(train_set)} Val size: {len(val_set)}')
train_loader = DataLoader(train_set, batch_size=32)
val_loader = DataLoader(val_set, batch_size=32)
device = m.get_device()
config = {'labels_key': 'digit_labels'}
model.heads = nn.Sequential(OrderedDict([('head', nn.Linear(768, 10))]))
m.set_torch_vit_dropouts(model, 0.1)
model = model.to(device)

Train size: 4000 Val size: 1000
Identified CUDA device: NVIDIA GeForce RTX 3060


In [5]:
model(torch.unsqueeze(train_set[0]['image'], 0).to(device))

tensor([[ 0.3996, -0.3002,  0.7401, -0.2059, -0.3082,  0.4045,  0.1807,  0.0325,
          0.2929, -0.3336]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [6]:
m.train(model, 0.0001, 0, 100, train_loader, val_loader, device, os.path.join('..', 'models', 'finetune_super4'), None)

Epoch 10 done, train loss: 0.5159 val acc: 0.1750
Epoch 20 done, train loss: 0.4455 val acc: 0.2830
Epoch 30 done, train loss: 0.3140 val acc: 0.4940
Epoch 40 done, train loss: 0.2104 val acc: 0.6700
Epoch 50 done, train loss: 0.1646 val acc: 0.7410
Epoch 60 done, train loss: 0.1447 val acc: 0.7780
Epoch 70 done, train loss: 0.1189 val acc: 0.8040
Epoch 80 done, train loss: 0.1106 val acc: 0.8250
Epoch 90 done, train loss: 0.1008 val acc: 0.8450
Epoch 100 done, train loss: 0.0933 val acc: 0.8170


In [7]:
model = torch.load(os.path.join('..', 'models', 'finetune_super4', 'digit-model.pt'))
final_model = torch.load(os.path.join('..', 'models', 'finetune_super4', 'final-digit-model.pt'))

In [8]:
train_acc = m.predict(model, train_loader, device, config, None, True)
print(train_acc)

tree_4: tensor([1, 0, 0, 0, 0, 1, 1, 1, 1, 0], device='cuda:0')
tree_9: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1], device='cuda:0')
tree_21: tensor([0, 1, 1, 1, 0, 1, 0, 0, 0, 0], device='cuda:0')
tree_25: tensor([0, 0, 0, 0, 0, 0, 0, 1, 1, 0], device='cuda:0')
tree_34: tensor([0, 1, 0, 1, 1, 1, 0, 1, 1, 1], device='cuda:0')
tree_35: tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 0], device='cuda:0')
tree_36: tensor([1, 0, 0, 0, 1, 1, 1, 0, 0, 1], device='cuda:0')
tree_37: tensor([0, 0, 0, 1, 1, 1, 1, 1, 0, 0], device='cuda:0')
tree_43: tensor([1, 1, 1, 1, 0, 1, 0, 0, 0, 0], device='cuda:0')
tree_46: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1], device='cuda:0')
tree_64: tensor([0, 1, 0, 0, 0, 0, 0, 0, 1, 0], device='cuda:0')
tree_66: tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 0], device='cuda:0')
tree_71: tensor([0, 0, 0, 1, 1, 1, 1, 1, 0, 1], device='cuda:0')
tree_75: tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 0], device='cuda:0')
tree_81: tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 0], device='cuda:0')
tree_85: tensor([0, 1, 0, 1

tree_965: tensor([1, 0, 0, 0, 1, 1, 1, 0, 0, 0], device='cuda:0')
tree_967: tensor([0, 0, 1, 0, 0, 0, 0, 0, 1, 0], device='cuda:0')
tree_980: tensor([0, 0, 0, 1, 1, 1, 0, 0, 0, 0], device='cuda:0')
tree_981: tensor([0, 0, 1, 1, 1, 0, 0, 0, 0, 1], device='cuda:0')
tree_997: tensor([1, 1, 0, 0, 1, 1, 1, 1, 0, 0], device='cuda:0')
tree_1004: tensor([1, 0, 1, 1, 1, 0, 0, 0, 0, 0], device='cuda:0')
tree_1006: tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 0], device='cuda:0')
tree_1015: tensor([0, 0, 0, 0, 1, 1, 1, 1, 0, 0], device='cuda:0')
tree_1018: tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 0], device='cuda:0')
tree_1030: tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 0], device='cuda:0')
tree_1034: tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 0], device='cuda:0')
tree_1049: tensor([0, 0, 0, 0, 0, 0, 0, 1, 1, 0], device='cuda:0')
tree_1056: tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 0], device='cuda:0')
tree_1058: tensor([0, 0, 0, 1, 1, 1, 1, 0, 0, 0], device='cuda:0')
tree_1059: tensor([0, 0, 0, 1, 1, 0, 0, 0, 0, 1], device='cuda:0')


tree_2030: tensor([1, 1, 0, 0, 0, 1, 1, 1, 1, 0], device='cuda:0')
tree_2053: tensor([0, 0, 1, 0, 0, 0, 1, 0, 0, 1], device='cuda:0')
tree_2054: tensor([0, 1, 1, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tree_2064: tensor([0, 1, 1, 1, 1, 0, 0, 0, 0, 1], device='cuda:0')
tree_2073: tensor([0, 0, 0, 0, 0, 0, 0, 1, 1, 0], device='cuda:0')
tree_2081: tensor([0, 0, 0, 0, 1, 0, 1, 0, 1, 0], device='cuda:0')
tree_2085: tensor([1, 0, 1, 1, 1, 1, 0, 0, 0, 1], device='cuda:0')
tree_2100: tensor([0, 1, 0, 1, 0, 1, 1, 1, 0, 0], device='cuda:0')
tree_2114: tensor([0, 1, 0, 0, 0, 0, 0, 1, 1, 0], device='cuda:0')
tree_2115: tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 0], device='cuda:0')
tree_2130: tensor([0, 0, 1, 1, 1, 1, 1, 0, 0, 0], device='cuda:0')
tree_2135: tensor([0, 1, 1, 1, 0, 0, 0, 0, 0, 0], device='cuda:0')
tree_2136: tensor([1, 0, 0, 0, 0, 0, 0, 0, 1, 0], device='cuda:0')
tree_2137: tensor([0, 0, 0, 0, 0, 1, 1, 0, 0, 0], device='cuda:0')
tree_2141: tensor([0, 1, 0, 1, 1, 1, 1, 1, 0, 0], device='cuda

tree_3008: tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 0], device='cuda:0')
tree_3016: tensor([0, 1, 1, 1, 1, 0, 0, 0, 0, 0], device='cuda:0')
tree_3038: tensor([0, 1, 1, 1, 0, 0, 0, 0, 1, 0], device='cuda:0')
tree_3047: tensor([0, 1, 0, 0, 0, 0, 0, 0, 1, 0], device='cuda:0')
tree_3048: tensor([0, 0, 0, 0, 0, 0, 0, 1, 1, 0], device='cuda:0')
tree_3051: tensor([0, 0, 1, 1, 1, 1, 1, 1, 1, 0], device='cuda:0')
tree_3068: tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 0], device='cuda:0')
tree_3070: tensor([1, 0, 0, 1, 1, 1, 1, 0, 1, 0], device='cuda:0')
tree_3077: tensor([1, 0, 0, 0, 0, 0, 0, 0, 1, 0], device='cuda:0')
tree_3078: tensor([0, 1, 0, 1, 0, 0, 0, 0, 1, 0], device='cuda:0')
tree_3101: tensor([0, 1, 0, 0, 0, 0, 0, 0, 1, 0], device='cuda:0')
tree_3130: tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 0], device='cuda:0')
tree_3132: tensor([0, 0, 0, 0, 0, 1, 1, 1, 0, 0], device='cuda:0')
tree_3135: tensor([0, 1, 1, 1, 0, 0, 1, 1, 1, 0], device='cuda:0')
tree_3139: tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 0], device='cuda

tree_3963: tensor([1, 1, 1, 1, 0, 0, 0, 0, 1, 0], device='cuda:0')
tree_3967: tensor([0, 0, 0, 1, 0, 0, 0, 0, 1, 0], device='cuda:0')
tree_3972: tensor([0, 0, 0, 1, 1, 0, 0, 1, 0, 0], device='cuda:0')
tree_3973: tensor([0, 1, 1, 1, 1, 0, 0, 0, 0, 1], device='cuda:0')
tree_3977: tensor([0, 1, 0, 0, 0, 0, 1, 1, 1, 0], device='cuda:0')
tree_3979: tensor([0, 0, 0, 0, 0, 1, 1, 1, 0, 1], device='cuda:0')
tree_3993: tensor([0, 0, 0, 1, 1, 1, 1, 1, 1, 0], device='cuda:0')
tree_3999: tensor([1, 0, 0, 0, 0, 0, 0, 1, 0, 1], device='cuda:0')
0.87225


In [9]:
val_acc = m.predict(model, val_loader, device, config, None, True)
print(val_acc)

tree_1: tensor([1, 0, 0, 0, 0, 1, 1, 1, 1, 0], device='cuda:0')
tree_10: tensor([0, 0, 0, 1, 1, 0, 0, 1, 1, 0], device='cuda:0')
tree_26: tensor([0, 1, 1, 0, 1, 0, 1, 1, 1, 0], device='cuda:0')
tree_45: tensor([1, 1, 0, 0, 0, 0, 0, 0, 1, 0], device='cuda:0')
tree_50: tensor([0, 0, 1, 0, 1, 0, 1, 1, 0, 0], device='cuda:0')
tree_61: tensor([0, 1, 0, 0, 1, 1, 1, 1, 1, 0], device='cuda:0')
tree_82: tensor([0, 0, 0, 1, 0, 1, 1, 1, 1, 0], device='cuda:0')
tree_83: tensor([1, 0, 0, 0, 1, 1, 1, 1, 0, 0], device='cuda:0')
tree_96: tensor([0, 0, 0, 1, 1, 1, 1, 1, 0, 1], device='cuda:0')
tree_102: tensor([0, 1, 0, 0, 0, 0, 0, 0, 1, 1], device='cuda:0')
tree_103: tensor([0, 1, 0, 0, 1, 1, 1, 1, 0, 0], device='cuda:0')
tree_108: tensor([0, 0, 1, 1, 0, 0, 0, 0, 0, 1], device='cuda:0')
tree_132: tensor([0, 0, 0, 0, 0, 0, 1, 0, 1, 0], device='cuda:0')
tree_137: tensor([0, 0, 1, 1, 1, 0, 0, 1, 0, 0], device='cuda:0')
tree_147: tensor([0, 0, 1, 1, 1, 1, 1, 1, 0, 0], device='cuda:0')
tree_150: tensor([0,

tree_865: tensor([0, 0, 1, 1, 1, 0, 0, 0, 0, 1], device='cuda:0')
tree_869: tensor([0, 0, 0, 0, 0, 0, 1, 0, 0, 1], device='cuda:0')
tree_876: tensor([0, 0, 1, 1, 0, 0, 1, 1, 1, 0], device='cuda:0')
tree_884: tensor([0, 0, 0, 1, 0, 0, 1, 1, 0, 1], device='cuda:0')
tree_890: tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 0], device='cuda:0')
tree_893: tensor([0, 0, 1, 1, 1, 1, 0, 0, 0, 0], device='cuda:0')
tree_894: tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 0], device='cuda:0')
tree_901: tensor([1, 0, 1, 0, 0, 1, 0, 0, 1, 0], device='cuda:0')
tree_906: tensor([0, 0, 0, 1, 1, 1, 1, 0, 0, 1], device='cuda:0')
tree_912: tensor([0, 0, 1, 0, 0, 1, 0, 0, 0, 1], device='cuda:0')
tree_915: tensor([0, 0, 0, 0, 1, 0, 1, 1, 1, 0], device='cuda:0')
tree_922: tensor([0, 0, 0, 1, 1, 1, 0, 0, 0, 0], device='cuda:0')
tree_923: tensor([1, 1, 0, 0, 0, 0, 0, 0, 1, 0], device='cuda:0')
tree_926: tensor([0, 0, 1, 1, 1, 1, 1, 1, 1, 0], device='cuda:0')
tree_931: tensor([0, 0, 1, 0, 1, 1, 0, 0, 0, 0], device='cuda:0')
tree_940: 

In [10]:
idx = 29
elem = val_set[idx]
print(elem['digit_labels'])
show_img(elem['image'])

tensor([0, 1, 1, 0, 0, 1, 0, 0, 0, 0], dtype=torch.int32)


In [11]:
final_train_acc = m.predict(final_model, train_loader, device, config, None)
print(final_train_acc)

0.85175


In [12]:
final_val_acc = m.predict(final_model, val_loader, device, config, None)
print(final_val_acc)

0.817
