In [33]:
import torch
import timm
from torchvision.datasets import Food101
from torch.utils.data import DataLoader
from tqdm import tqdm
import torchmetrics

In [9]:
# Deserialize the training checkpoint onto the CPU regardless of the device it
# was saved from, so this cell also works on machines without an accelerator
# (the rest of the notebook evaluates on CPU).
# NOTE: torch.load unpickles the file -- only open checkpoints you trust.
checkpoint = torch.load(
    "models/levit_256.fb_dist_in1k/checkpoints.ckpt",
    map_location="cpu",
)
print(checkpoint.keys())

dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'MixedPrecisionPlugin'])


In [15]:
# Rename the checkpoint's weight keys so they can be loaded into the bare
# network: drop a *leading* "model." from each key. Only the prefix is
# removed -- a plain str.replace would also delete "model." occurring in the
# middle of a parameter name and silently corrupt that key.
_PREFIX = "model."
state_dict = {
    (key[len(_PREFIX):] if key.startswith(_PREFIX) else key): tensor
    for key, tensor in checkpoint["state_dict"].items()
}
state_dict.keys()

dict_keys(['stem.conv1.linear.weight', 'stem.conv1.bn.weight', 'stem.conv1.bn.bias', 'stem.conv1.bn.running_mean', 'stem.conv1.bn.running_var', 'stem.conv1.bn.num_batches_tracked', 'stem.conv2.linear.weight', 'stem.conv2.bn.weight', 'stem.conv2.bn.bias', 'stem.conv2.bn.running_mean', 'stem.conv2.bn.running_var', 'stem.conv2.bn.num_batches_tracked', 'stem.conv3.linear.weight', 'stem.conv3.bn.weight', 'stem.conv3.bn.bias', 'stem.conv3.bn.running_mean', 'stem.conv3.bn.running_var', 'stem.conv3.bn.num_batches_tracked', 'stem.conv4.linear.weight', 'stem.conv4.bn.weight', 'stem.conv4.bn.bias', 'stem.conv4.bn.running_mean', 'stem.conv4.bn.running_var', 'stem.conv4.bn.num_batches_tracked', 'stages.0.blocks.0.attn.attention_biases', 'stages.0.blocks.0.attn.qkv.linear.weight', 'stages.0.blocks.0.attn.qkv.bn.weight', 'stages.0.blocks.0.attn.qkv.bn.bias', 'stages.0.blocks.0.attn.qkv.bn.running_mean', 'stages.0.blocks.0.attn.qkv.bn.running_var', 'stages.0.blocks.0.attn.qkv.bn.num_batches_tracked', 

In [18]:
# Build the network with a 101-class head and no pretrained weights, then
# fill in the weights from the remapped checkpoint `state_dict` and keep the
# module on the CPU. The trailing expression displays the module structure.
model = timm.create_model(
    "models/levit_256.fb_dist_in1k",
    pretrained=False,
    num_classes=101,
)
model.load_state_dict(state_dict)
model = model.cpu()
model

LevitDistilled(
  (stem): Stem16(
    (conv1): ConvNorm(
      (linear): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (act1): Hardswish()
    (conv2): ConvNorm(
      (linear): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (act2): Hardswish()
    (conv3): ConvNorm(
      (linear): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (act3): Hardswish()
    (conv4): ConvNorm(
      (linear): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (stages): Sequential(
  

In [21]:
# Derive the preprocessing settings from the model's pretrained configuration
# and build the matching image transform from them.
data_cfg = timm.data.resolve_data_config(args=model.pretrained_cfg)
transform = timm.data.create_transform(**data_cfg)

In [23]:
# Food-101 test split read from the local 'data' directory, preprocessed with
# the model's transform and served in batches of 128 by 4 worker processes.
test_data = Food101(root='data', split='test', transform=transform)
test_loader = DataLoader(
    dataset=test_data,
    batch_size=128,
    num_workers=4,
    pin_memory=True,
)

In [25]:
# Put the model in eval mode before compiling it to TorchScript so the
# exported module is inference-ready (e.g. BatchNorm uses its running
# statistics) even for consumers that never call .eval() themselves.
# Module.eval() returns the module itself, so `model` is what gets scripted.
scripted_model = torch.jit.script(model.eval())

In [37]:
# Persist the TorchScript module next to the original checkpoint.
torch.jit.save(scripted_model, "models/levit_256.fb_dist_in1k/scripted.pt")

In [34]:
def test_step(model: torch.nn.Module,
              data_loader: torch.utils.data.DataLoader,
              accuracy_fn: torchmetrics.Metric):
    """Evaluate ``model`` on every batch of ``data_loader`` and report accuracy.

    The model is switched to eval mode and run without gradient tracking.
    For each batch, ``accuracy_fn`` is called with the arg-max class
    predictions and the labels; the per-batch results are combined as a
    sample-weighted average, so a smaller final batch does not count as much
    as a full one. (This is exact when ``accuracy_fn`` returns the fraction
    of correct predictions in the batch.)

    Args:
        model: Network producing per-class scores of shape (batch, classes);
            ``argmax(dim=1)`` is taken over its output.
        data_loader: Iterable of ``(images, labels)`` batches that supports
            ``len()``.
        accuracy_fn: Callable ``(predicted_labels, labels) -> accuracy``
            returning a scalar tensor for the batch.

    Returns:
        The overall accuracy as a scalar tensor on the CPU.

    Raises:
        ValueError: If ``data_loader`` yields no samples, since accuracy is
            then undefined.
    """
    model.eval()

    weighted_acc_sum = 0  # sum over batches of (batch accuracy * batch size)
    num_samples = 0

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        # tqdm appends its own ": " after the description, so the description
        # itself carries no trailing colon.
        for images, labels in tqdm(data_loader,
                                   total=len(data_loader),
                                   desc='Making predictions'):
            # 1. Forward pass
            preds = model(images)

            # 2. Accumulate accuracy, weighted by how many samples it covers
            batch_size = len(labels)
            weighted_acc_sum += accuracy_fn(preds.argmax(dim=1), labels) * batch_size
            num_samples += batch_size

    if num_samples == 0:
        raise ValueError("data_loader yielded no samples; accuracy is undefined")

    test_acc = weighted_acc_sum / num_samples
    print(f"Test accuracy: {test_acc:.2f}")
    return test_acc.cpu()

In [38]:
# Read the exported TorchScript module back from disk so the evaluation below
# exercises the saved artifact rather than the in-memory model.
model_reloaded = torch.jit.load(f="models/levit_256.fb_dist_in1k/scripted.pt")

In [39]:
# Multiclass accuracy over the 101 classes, evaluated on the reloaded
# TorchScript model; the returned accuracy tensor is the cell's output.
accuracy_fn = torchmetrics.Accuracy(task='multiclass', num_classes=101)
test_step(
    model=model_reloaded,
    data_loader=test_loader,
    accuracy_fn=accuracy_fn,
)

 does not have profile information (Triggered internally at ../third_party/nvfuser/csrc/graph_fuser.cpp:104.)
  return forward_call(*args, **kwargs)
Making predictions:: 100%|██████████| 198/198 [02:41<00:00,  1.23it/s]

Test accuracy: 0.65





tensor(0.6523)