In [1]:
import torch
import torch.nn as nn
from vit_pytorch import ViT
from graphs import Graph, prims
import os
import numpy as np
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.models import ResNet18_Weights

In [2]:
model = models.vit_b_16(weights="ViT_B_16_Weights.IMAGENET1K_V1")

In [16]:
from tree_dataset import TreeDataset, show_img
import model as m
from torch.utils.data import DataLoader
import os
import torchvision.transforms as transforms
from collections import OrderedDict

In [4]:
train_set = TreeDataset(os.path.join('..', 'data', 'super_variety_4k'), transforms.Compose([
    transforms.RandAugment(),
    m.resnet_preprocess()
])) 
val_set = TreeDataset(os.path.join('..', 'data', 'super_variety_1k'), m.resnet_preprocess()) 
print(f'Train size: {len(train_set)} Val size: {len(val_set)}')
train_loader = DataLoader(train_set, batch_size=32)
val_loader = DataLoader(val_set, batch_size=32)
device = m.get_device()
config = {'labels_key': 'digit_labels'}
model.heads = nn.Sequential(OrderedDict([('head', nn.Linear(768, 10))]))
m.set_torch_vit_dropouts(model, 0.1)
model = model.to(device)

Train size: 4000 Val size: 1000
Identified CUDA device: NVIDIA GeForce RTX 3060


In [5]:
model(torch.unsqueeze(train_set[0]['image'], 0).to(device))

tensor([[ 0.0977, -0.5913,  0.3210, -0.2130, -0.0865,  0.1422, -0.2270,  0.0059,
          0.3273, -0.2642]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [6]:
m.train(model, 0.0001, 0, 60, train_loader, val_loader, device, os.path.join('..', 'models', 'finetune_super3'), None)

Epoch 10 done, train loss: 0.4201 val acc: 0.3260
Epoch 20 done, train loss: 0.2277 val acc: 0.6650
Epoch 30 done, train loss: 0.1625 val acc: 0.7530
Epoch 40 done, train loss: 0.1313 val acc: 0.8150
Epoch 50 done, train loss: 0.1106 val acc: 0.8420
Epoch 60 done, train loss: 0.1038 val acc: 0.8290


In [7]:
model = torch.load(os.path.join('..', 'models', 'finetune_super3', 'digit-model.pt'))
final_model = torch.load(os.path.join('..', 'models', 'finetune_super3', 'final-digit-model.pt'))

In [8]:
train_acc = m.predict(model, train_loader, device, config, None, True)
print(train_acc)

0.8565


In [12]:
val_acc = m.predict(model, val_loader, device, config, None, True)
print(val_acc)

tree_3: tensor([1, 0, 1, 0, 0, 0, 1, 1, 1, 0], device='cuda:0')
tree_26: tensor([0, 1, 1, 0, 0, 0, 1, 1, 1, 0], device='cuda:0')
tree_29: tensor([0, 1, 1, 0, 0, 0, 0, 0, 1, 0], device='cuda:0')
tree_34: tensor([1, 0, 0, 1, 1, 1, 0, 0, 1, 0], device='cuda:0')
tree_45: tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tree_49: tensor([0, 0, 0, 0, 1, 0, 1, 1, 0, 0], device='cuda:0')
tree_56: tensor([1, 0, 1, 1, 0, 0, 0, 0, 0, 0], device='cuda:0')
tree_74: tensor([0, 0, 0, 1, 0, 0, 1, 1, 0, 0], device='cuda:0')
tree_83: tensor([1, 0, 0, 0, 1, 1, 1, 1, 0, 0], device='cuda:0')
tree_88: tensor([0, 0, 0, 0, 1, 1, 0, 1, 0, 0], device='cuda:0')
tree_96: tensor([0, 0, 0, 1, 1, 1, 1, 1, 0, 0], device='cuda:0')
tree_99: tensor([1, 1, 1, 1, 0, 1, 1, 1, 1, 0], device='cuda:0')
tree_102: tensor([0, 1, 0, 0, 0, 0, 0, 0, 1, 0], device='cuda:0')
tree_108: tensor([0, 0, 1, 1, 0, 0, 0, 0, 0, 1], device='cuda:0')
tree_114: tensor([1, 0, 1, 1, 0, 0, 1, 0, 0, 0], device='cuda:0')
tree_118: tensor([1, 1,

tree_898: tensor([0, 1, 1, 1, 1, 0, 1, 1, 1, 0], device='cuda:0')
tree_901: tensor([0, 0, 1, 0, 0, 0, 1, 0, 1, 0], device='cuda:0')
tree_903: tensor([0, 0, 0, 1, 1, 1, 1, 1, 0, 0], device='cuda:0')
tree_912: tensor([0, 0, 1, 0, 0, 1, 0, 0, 0, 1], device='cuda:0')
tree_922: tensor([0, 0, 0, 1, 1, 1, 0, 0, 0, 0], device='cuda:0')
tree_923: tensor([1, 0, 0, 0, 0, 0, 0, 0, 1, 0], device='cuda:0')
tree_930: tensor([0, 1, 1, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tree_940: tensor([0, 0, 1, 0, 1, 0, 0, 1, 1, 0], device='cuda:0')
tree_942: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tree_944: tensor([0, 0, 0, 0, 0, 1, 1, 1, 0, 1], device='cuda:0')
tree_955: tensor([1, 0, 1, 0, 0, 1, 1, 1, 1, 0], device='cuda:0')
tree_958: tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 0], device='cuda:0')
tree_966: tensor([0, 0, 0, 1, 1, 0, 1, 1, 1, 0], device='cuda:0')
tree_968: tensor([1, 0, 0, 0, 0, 0, 0, 0, 0, 1], device='cuda:0')
tree_970: tensor([1, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tree_971: 

In [17]:
idx = 3
elem = val_set[3]
print(elem['digit_labels'])
show_img(elem['image'])

tensor([1, 0, 0, 0, 0, 1, 1, 0, 1, 0], dtype=torch.int32)


In [10]:
final_train_acc = m.predict(final_model, train_loader, device, config, None)
print(final_train_acc)

0.8475


In [11]:
final_val_acc = m.predict(final_model, val_loader, device, config, None)
print(final_val_acc)

0.829
