In [1]:
import torch
import torch.nn as nn
from vit_pytorch import ViT
from graphs import Graph, prims
import os
import numpy as np
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.models import ResNet18_Weights
from tree_dataset import TreeDataset
import model as m
from torch.utils.data import DataLoader
from d2lvit import *
import copy
from collections import OrderedDict

In [2]:
model = torch.load(os.path.join('..', 'models', 'finetune_3', 'digit-model.pt'))

In [3]:
train_set = TreeDataset(os.path.join('..', 'data', 'extra_variety_4k'), m.resnet_preprocess()) 
val_set = TreeDataset(os.path.join('..', 'data', 'extra_variety_2k'), m.resnet_preprocess()) 
print(f'Train size: {len(train_set)} Val size: {len(val_set)}')
train_loader = DataLoader(train_set, batch_size=32)
val_loader = DataLoader(val_set, batch_size=32)
device = m.get_device()
config = {'labels_key': 'digit_labels'}

Train size: 4000 Val size: 2000
Identified CUDA device: NVIDIA GeForce RTX 3060


In [4]:
train_acc = m.predict(model, train_loader, device, config, None)
print(train_acc)

1.0


In [5]:
val_acc = m.predict(model, val_loader, device, config, None)
print(val_acc)

1.0


In [6]:
digits_model = copy.deepcopy(model)

In [7]:
model.heads = nn.Sequential(OrderedDict([('head', nn.Linear(768, 45))]))
model = model.to(device)

In [8]:
model(torch.unsqueeze(train_set[0]['image'], 0).to(device))

tensor([[-1.5605e-01, -8.4759e-01, -2.6016e-02,  1.5521e-01, -1.5912e-01,
         -3.0882e-01,  7.9547e-01,  2.1157e-01, -5.9635e-01,  1.5759e-01,
         -1.6936e-01, -2.6918e-01,  9.6948e-01,  5.0122e-01,  5.2601e-01,
         -1.7740e-01,  5.5402e-01,  4.2892e-01,  9.1623e-02, -1.6630e-01,
         -1.4784e-01,  3.7945e-04,  1.8073e-01,  4.5887e-01, -3.6593e-01,
          4.8867e-01, -1.1109e+00, -4.5716e-01, -7.9147e-02, -3.6450e-01,
          8.8896e-02,  1.6460e-01,  5.1687e-01, -8.3933e-01,  1.9863e-01,
          1.4351e-01, -3.7023e-01, -3.1495e-01,  3.2912e-01, -4.7073e-01,
          1.3962e+00, -7.0441e-02,  6.1156e-01, -1.3545e-01, -3.0385e-01]],
       device='cuda:0', grad_fn=<AddmmBackward0>)

In [9]:
m.train(model, 0.0001, 0, 100, train_loader, val_loader, device, os.path.join('..', 'models', 'finetune_noprims'), digits_model, False, False)

Epoch 10 done, train loss: 0.0007 val acc: 1.0000
Epoch 20 done, train loss: 0.0002 val acc: 0.9995
Epoch 30 done, train loss: 0.0001 val acc: 0.9995
Epoch 40 done, train loss: 0.0000 val acc: 0.9995
Epoch 50 done, train loss: 0.0000 val acc: 0.9995
Epoch 60 done, train loss: 0.0000 val acc: 0.9995
Epoch 70 done, train loss: 0.0000 val acc: 0.9995
Epoch 80 done, train loss: 0.0000 val acc: 0.9995
Epoch 90 done, train loss: 0.0000 val acc: 0.9995
Epoch 100 done, train loss: 0.0000 val acc: 0.9995


In [10]:
model = torch.load(os.path.join('..', 'models', 'finetune_noprims', 'tree-model.pt'))

In [11]:
config = {'labels_key': 'tree_label'}
train_acc = m.predict(model, train_loader, device, config, digits_model, False, False)
print(train_acc)

1.0


In [12]:
val_acc = m.predict(model, val_loader, device, config, digits_model, False, False)
print(val_acc)

1.0
