In [2]:
import os
import sys
import pandas as pd
import numpy as np
import torch
from dynaconf import Dynaconf
from vit_pytorch import ViT
import pytorch_lightning as pl

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class VanillaViT(pl.LightningModule):
    """Lightning wrapper around a ``vit_pytorch.ViT`` configured from settings.

    Every ViT hyper-parameter is read from a ``VVIT_*`` attribute of
    ``settings`` (a Dynaconf object in this notebook).

    Args:
        settings: Configuration object exposing the ``VVIT_*`` attributes used
            below. ``VVIT_image_size`` and ``VVIT_patch_size`` must be iterables;
            they are converted to tuples before being passed to ``ViT``.
    """

    def __init__(self, settings):
        super().__init__()
        self.settings = settings
        # Lightning captures the __init__ arguments (here: ``settings``) into
        # ``self.hparams``.
        self.save_hyperparameters()

        self.vit = ViT(
            image_size=tuple(settings.VVIT_image_size),
            patch_size=tuple(settings.VVIT_patch_size),
            num_classes=settings.VVIT_num_classes,
            dim=settings.VVIT_dim,
            depth=settings.VVIT_depth,
            heads=settings.VVIT_heads,
            mlp_dim=settings.VVIT_mlp_dim,
            pool=settings.VVIT_pool,
            channels=settings.VVIT_channels,
            dim_head=settings.VVIT_dim_head,
            dropout=settings.VVIT_dropout,
            emb_dropout=settings.VVIT_emb_dropout
        )

    def forward(self, X):
        """Run the ViT on a batch of images.

        Args:
            X: Image batch; presumably ``(batch, channels, height, width)`` as
                ``vit_pytorch.ViT`` expects -- TODO confirm against the data loader.

        Returns:
            The ViT output with only a trailing size-1 dimension removed: e.g.
            ``(batch,)`` when ``VVIT_num_classes == 1``, otherwise unchanged.
        """
        # squeeze(-1) rather than a bare squeeze(): squeeze() would also drop the
        # batch dimension when the batch contains a single sample.
        return self.vit(X).squeeze(-1)

In [4]:
# Read configuration from settings.toml; per Dynaconf's `envvar_prefix`,
# DYNACONF_-prefixed environment variables can override file values.
settings = Dynaconf(
    settings_files=['./settings.toml'],
    envvar_prefix="DYNACONF",
)

model = VanillaViT(settings)


In [5]:
# SECURITY: torch.load unpickles the file (a Lightning checkpoint also pickles the
# saved hyper-parameters, e.g. the settings object) -- only load trusted files.
# NOTE(review): since PyTorch 2.6 torch.load defaults to weights_only=True, which
# may reject such a checkpoint; pass weights_only=False explicitly if that happens.
checkpoint = torch.load("./pretrained.ckpt", map_location='cpu')

# Checkpoint keys carry one extra leading component (presumably the attribute
# name this model had inside the training-time LightningModule); drop it so the
# keys line up with `model.state_dict()`.
raw_state_dict = checkpoint['state_dict']
key_prefixes = {key.partition(".")[0] for key in raw_state_dict}
# A single shared prefix guarantees stripping cannot make two keys collide.
assert len(key_prefixes) == 1, f"expected one key prefix, found {sorted(key_prefixes)}"
new_state_dict = {key.partition(".")[2]: value for key, value in raw_state_dict.items()}


In [6]:
# strict=True is the default, spelled out: any missing or unexpected key raises.
model.load_state_dict(new_state_dict, strict=True)

<All keys matched successfully>