In [1]:
from tree_dataset import TreeDataset
import model as m
from torch.utils.data import DataLoader
import os
import torchvision.transforms as transforms
from d2lvit import *
import copy
import torch.nn as nn

In [2]:
preprocess = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor()
])
train_set = TreeDataset(os.path.join('..', 'data', 'extra_variety_4k'), preprocess) 
val_set = TreeDataset(os.path.join('..', 'data', 'extra_variety_2k'), preprocess)
print(f'Train size: {len(train_set)} Val size: {len(val_set)}')
train_loader = DataLoader(train_set, batch_size=32)
val_loader = DataLoader(val_set, batch_size=32)
device = m.get_device()

Train size: 4000 Val size: 2000
Identified CUDA device: NVIDIA GeForce RTX 3060


In [3]:
img_size, patch_size = 224, 16
num_hiddens, mlp_num_hiddens, num_heads, num_blks = 512, 2048, 8, 2
emb_dropout, blk_dropout, lr = 0.1, 0.1, 0.1
model = ViT(img_size, patch_size, num_hiddens, mlp_num_hiddens, num_heads,
            num_blks, emb_dropout, blk_dropout, lr).to(device)



In [4]:
unused_digit_model = nn.Linear(1,1)

In [5]:
model.head = nn.Sequential(
    nn.LayerNorm((512,), eps=1e-05, elementwise_affine=True),
    nn.Linear(512, 45)
)
model = model.to(device)

In [6]:
model(torch.unsqueeze(train_set[0]['image'], 0).to(device))

tensor([[ 0.4194,  0.0340, -0.6822,  0.3429,  0.1966,  0.0919,  1.5820,  0.8176,
         -0.9807, -0.1770,  0.6108,  0.3901,  0.1147,  0.2809,  0.4554,  1.2832,
          0.3074,  0.5045, -0.0970,  1.5171, -0.1003, -0.1453, -0.5183,  0.6655,
          0.6319, -0.5300,  0.4046,  0.4099, -0.2683, -0.6036, -0.7103, -0.4599,
          0.1655,  1.0063, -0.3506,  1.4950,  0.0377, -0.0619,  1.0871,  0.5886,
         -0.3558, -0.3134,  0.4576,  0.0811,  1.1043]], device='cuda:0',
       grad_fn=<AddmmBackward0>)

In [7]:
m.train(model, 0.0001, 0, 100, train_loader, val_loader, device, os.path.join('..', 'models', 'd2lvit_scratchtrees'), unused_digit_model, False, False)

Epoch 10 done, train loss: 0.0184 val acc: 0.8510
Epoch 20 done, train loss: 0.0041 val acc: 0.9865
Epoch 30 done, train loss: 0.0015 val acc: 0.9925
Epoch 40 done, train loss: 0.0012 val acc: 0.9950
Epoch 50 done, train loss: 0.0009 val acc: 0.9980
Epoch 60 done, train loss: 0.0003 val acc: 0.9970
Epoch 70 done, train loss: 0.0007 val acc: 0.9955
Epoch 80 done, train loss: 0.0003 val acc: 0.9960
Epoch 90 done, train loss: 0.0001 val acc: 0.9955
Epoch 100 done, train loss: 0.0003 val acc: 0.9945


In [8]:
model = torch.load(os.path.join('..', 'models', 'd2lvit_scratchtrees', 'tree-model.pt'))

In [9]:
config = {'labels_key': 'tree_label'}
train_acc = m.predict(model, train_loader, device, config, unused_digit_model, True, False)
print(train_acc)

tree_138: tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
        0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       device='cuda:0')
tree_501: tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
        0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       device='cuda:0')
tree_1579: tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       device='cuda:0')
tree_2905: tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
        0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       device='cuda:0')
0.999


In [10]:
val_acc = m.predict(model, val_loader, device, config, unused_digit_model, True, False)
print(val_acc)

tree_10: tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
        0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       device='cuda:0')
tree_213: tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
        0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       device='cuda:0')
tree_1306: tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
        0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       device='cuda:0')
tree_1852: tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       device='cuda:0')
0.998
