In [1]:
from tree_dataset import TreeDataset
import model as m
from torch.utils.data import DataLoader
import os
import torchvision.transforms as transforms
from d2lvit import ViT
import torch

In [2]:
preprocess = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor()
])
train_set = TreeDataset(os.path.join('..', 'data', 'extra_variety_4k'), preprocess) 
val_set = TreeDataset(os.path.join('..', 'data', 'extra_variety_2k'), preprocess)
print(f'Train size: {len(train_set)} Val size: {len(val_set)}')
train_loader = DataLoader(train_set, batch_size=32)
val_loader = DataLoader(val_set, batch_size=32)
device = m.get_device()
config = {'labels_key': 'digit_labels'}

Train size: 4000 Val size: 2000
Identified CUDA device: NVIDIA GeForce RTX 3060


In [3]:
img_size, patch_size = 224, 16
num_hiddens, mlp_num_hiddens, num_heads, num_blks = 512, 2048, 8, 2
emb_dropout, blk_dropout, lr = 0.1, 0.1, 0.1
model = ViT(img_size, patch_size, num_hiddens, mlp_num_hiddens, num_heads,
            num_blks, emb_dropout, blk_dropout, lr).to(device)



In [4]:
model(torch.unsqueeze(train_set[0]['image'], 0).to(device))

tensor([[-1.0491,  0.0803,  0.2877, -0.1619,  0.5497,  0.3042, -0.0856,  0.0765,
         -0.0812,  0.2780]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [5]:
m.train(model, 0.0001, 0, 100, train_loader, val_loader, device, os.path.join('..', 'models', 'd2lvit_3'), None)

Epoch 10 done, train loss: 0.0381 val acc: 0.9585
Epoch 20 done, train loss: 0.0171 val acc: 0.9765
Epoch 30 done, train loss: 0.0064 val acc: 0.9915
Epoch 40 done, train loss: 0.0085 val acc: 0.9945
Epoch 50 done, train loss: 0.0035 val acc: 0.9950
Epoch 60 done, train loss: 0.0017 val acc: 0.9955
Epoch 70 done, train loss: 0.0031 val acc: 0.9925
Epoch 80 done, train loss: 0.0021 val acc: 0.9975
Epoch 90 done, train loss: 0.0147 val acc: 0.9935
Epoch 100 done, train loss: 0.0008 val acc: 0.9955


In [6]:
model = torch.load(os.path.join('..', 'models', 'd2lvit_3', 'digit-model.pt'))

In [7]:
train_acc = m.predict(model, train_loader, device, config, None)
print(train_acc)

1.0


In [8]:
val_acc = m.predict(model, val_loader, device, config, None)
print(val_acc)

1.0
