In [1]:
import torch
import torch.nn as nn
from vit_pytorch import ViT
from graphs import Graph, prims
import os
import numpy as np
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.models import ResNet18_Weights
from tree_dataset import TreeDataset
import model as m
from torch.utils.data import DataLoader
from d2lvit import *
import copy
from collections import OrderedDict

In [2]:
model = torch.load(os.path.join('..', 'models', 'finetune_super3', 'digit-model.pt'))

In [3]:
train_set = TreeDataset(os.path.join('..', 'data', 'super_variety_4k'), transforms.Compose([
    transforms.RandAugment(),
    m.resnet_preprocess()
]))
val_set = TreeDataset(os.path.join('..', 'data', 'super_variety_1k'), m.resnet_preprocess()) 
print(f'Train size: {len(train_set)} Val size: {len(val_set)}')
train_loader = DataLoader(train_set, batch_size=32)
val_loader = DataLoader(val_set, batch_size=32)
device = m.get_device()
config = {'labels_key': 'digit_labels'}

Train size: 4000 Val size: 1000
Identified CUDA device: NVIDIA GeForce RTX 3060


In [4]:
train_acc = m.predict(model, train_loader, device, config, None)
print(train_acc)

0.85525


In [5]:
val_acc = m.predict(model, val_loader, device, config, None)
print(val_acc)

0.857


In [6]:
digits_model = copy.deepcopy(model)

In [7]:
model.heads = nn.Sequential(OrderedDict([('head', nn.Linear(768, 45))]))
model = model.to(device)

In [8]:
model(torch.unsqueeze(train_set[0]['image'], 0).to(device))

tensor([[ 0.5921, -0.3702,  0.0625,  0.0472, -0.0136,  0.9181, -0.1674,  0.1637,
         -0.3269, -0.0237, -0.2246,  0.4781, -0.1143,  0.3735,  0.7475, -0.5257,
          0.1460,  0.2912,  0.9398,  0.2868, -0.1353,  0.2743, -0.2618, -0.2068,
         -0.2620, -0.1315, -0.2012, -0.7385,  0.1504,  0.1457,  0.2571, -0.5673,
         -0.9958, -0.2312,  0.6248,  0.2130, -0.1281, -0.2305, -0.2011, -0.3228,
         -0.0362, -0.4088, -0.1373,  0.7346, -0.9438]], device='cuda:0',
       grad_fn=<AddmmBackward0>)

In [None]:
m.train(model, 0.0001, 0, 50, train_loader, val_loader, device, os.path.join('..', 'models', 'finetune_super3'), digits_model)

Epoch 10 done, train loss: 0.0489 val acc: 0.8190
Epoch 20 done, train loss: 0.0350 val acc: 0.8370


In [None]:
model = torch.load(os.path.join('..', 'models', 'finetune_super3', 'tree-model.pt'))
final_model = torch.load(os.path.join('..', 'models', 'finetune_super3', 'final-tree-model.pt'))

In [None]:
config = {'labels_key': 'tree_label'}
train_acc = m.predict(model, train_loader, device, config, digits_model)
print(train_acc)

In [None]:
val_acc = m.predict(model, val_loader, device, config, digits_model)
print(val_acc)

In [None]:
final_train_acc = m.predict(final_model, train_loader, device, config, digits_model)
print(final_train_acc)

In [None]:
final_val_acc = m.predict(final_model, val_loader, device, config, digits_model)
print(final_val_acc)

In [None]:
import matplotlib.pyplot as plt
train_loss = np.load(os.path.join('..', 'models', 'finetune_super3', 'tree_train_loss.npy'))
val_acc = np.load(os.path.join('..', 'models', 'finetune_super3', 'tree_val_acc.npy'))
epochs = np.arange(1, train_loss.shape[0]+1)
plt.plot(epochs, train_loss, label='loss')
plt.plot(epochs, val_acc, label='acc')
plt.legend()
plt.show()
print(train_loss[:10])