In [1]:
from tree_dataset import TreeDataset
import model as m
from torch.utils.data import DataLoader
import os
import torchvision.transforms as transforms
from d2lvit import *
import copy
import torch.nn as nn

In [2]:
# Preprocessing applied to every image: convert to grayscale, resize, centre-crop
# to 224x224 and turn the result into a tensor.
preprocess = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

# Tunable settings, named once instead of being repeated as literals below.
BATCH_SIZE = 32
DATA_DIR = os.path.join('..', 'data')

train_set = TreeDataset(os.path.join(DATA_DIR, 'trainset4'), preprocess)
# NOTE(review): the validation split is read from a directory named 'trainset1' —
# confirm this is intentional and that it does not overlap with 'trainset4'.
val_set = TreeDataset(os.path.join(DATA_DIR, 'trainset1'), preprocess)
print(f'Train size: {len(train_set)} Val size: {len(val_set)}')

# Shuffle the training data so each epoch sees the samples in a different order;
# without it every epoch replays identical batches. Validation order is left
# fixed because it has no effect on the reported accuracy.
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE)

device = m.get_device()
# Passed to m.predict below; selects which label field of a sample is used
# (presumably — confirm against the model module).
config = {'labels_key': 'digit_labels'}

Train size: 2000 Val size: 1000
Identified CUDA device: NVIDIA GeForce RTX 3060


In [3]:
# Obtain the starting model from the project helper. Judging by its name it is a
# ViT already trained from scratch on digit labels — TODO confirm whether the
# helper loads saved weights or trains here.
model = m.scratch_trained_d2l_vit_digits_model()

In [4]:
# Score the model on the training loader; the result is kept as `train_acc`.
# The trailing None is an optional argument of m.predict whose meaning is not
# visible here — TODO confirm.
train_acc = m.predict(
    model,
    train_loader,
    device,
    config,
    None,
)
print(train_acc)

1.0


In [5]:
# Same check as for the training loader, this time on the validation loader.
# The trailing None is an optional argument of m.predict whose meaning is not
# visible here — TODO confirm.
val_acc = m.predict(
    model,
    val_loader,
    device,
    config,
    None,
)
print(val_acc)

1.0


In [6]:
# Keep an independent deep copy of the model before its head is replaced in the
# next cell; the copy is later handed to m.train as its last argument.
digits_model = copy.deepcopy(model)

In [7]:
# Replace the existing head with a freshly initialised one that maps the
# 512-dimensional features to 45 outputs, then move the whole model to `device`
# so the new layers live alongside the rest of the network.
embed_dim, num_outputs = 512, 45
new_head = nn.Sequential(
    nn.LayerNorm(embed_dim),  # defaults: eps=1e-05, elementwise_affine=True
    nn.Linear(embed_dim, num_outputs),
)
model.head = new_head
model = model.to(device)

In [8]:
# Smoke-test the new head: push the first training image through the model as a
# batch of one and display the raw outputs.
sample_image = train_set[0]['image']
model(sample_image.unsqueeze(0).to(device))

tensor([[ 0.6323, -0.0355, -0.7307, -0.7880, -0.0793,  0.1641,  0.1942,  0.0765,
         -0.8921, -0.4477,  0.3037,  0.3669,  0.0016, -0.4484, -0.6748,  0.5939,
          0.3098, -1.0040,  0.6282,  0.0548,  0.5700,  0.5419, -0.4336,  0.6089,
         -0.5477,  0.3620, -0.5139, -1.0191,  0.9796,  0.1361,  0.8076,  0.3364,
         -0.6024, -0.3169, -0.3417,  0.1919,  0.2471, -0.7637,  1.2581,  0.0854,
          1.2756, -0.8350,  0.8885, -0.9039, -0.1991]], device='cuda:0',
       grad_fn=<AddmmBackward0>)

In [10]:
# Train the model with the new head. Arguments are passed positionally because
# the signature of m.train is not visible here; the inline notes are guesses to
# be confirmed against the model module.
checkpoint_path = os.path.join('..', 'models', 'd2lvit')
m.train(
    model,
    0.0001,           # presumably the learning rate — TODO confirm
    0,                # meaning not visible here — TODO confirm
    10,               # presumably the number of epochs — TODO confirm
    train_loader,
    val_loader,
    device,
    checkpoint_path,
    digits_model,     # the pre-replacement copy of the model
)

Epoch 10 done, train loss: 0.0013 val acc: 0.9010
