In [1]:
import os

# Work from the parent directory, which holds the configs/, models/ and src/
# paths used below. The notebook's own directory is recorded only on the first
# run, so re-running this cell does not climb one more level each time.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, ".."))

In [2]:
import os
import sys

# Make the project root importable (for the `src.SLR...` imports) using its
# absolute path: a relative './' entry would silently follow any later chdir,
# and appending unconditionally would stack duplicates on every re-run.
PROJECT_ROOT = os.getcwd()
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [3]:
from box import Box
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
import numpy as np

from src.SLR.data import VideoDataModule
from src.SLR.models import SLR_Lightning

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Experiment configuration (dataset / model / optimizer / trainer / evaluation
# sections). The context manager closes the file handle deterministically
# instead of leaving it open until garbage collection.
CFG_PATH = 'configs/SLR/cfg.yaml'
with open(CFG_PATH, "r", encoding="utf-8") as cfg_file:
    cfg = Box.from_yaml(cfg_file.read())

In [5]:
# Split the top-level config into the per-component sections that are passed
# to the data module, the model, and the trainer below.
dataset_cfg, model_cfg, optimizer_cfg, trainer_cfg, eval_cfg = (
    cfg.dataset_args,
    cfg.model_args,
    cfg.optimizer_args,
    cfg.trainer_args,
    cfg.evaluate_args,
)

In [6]:
# gloss_dict.npy holds a pickled Python object, hence allow_pickle=True and
# .item() to unwrap it from the array np.load returns (presumably a dict of
# gloss -> id info — TODO confirm). Only load trusted files: unpickling can
# execute arbitrary code.
gloss_dict = np.load(
    os.path.join(dataset_cfg.info_dir, "gloss_dict.npy"),
    allow_pickle=True,
).item()

In [7]:
# The Lightning module receives the model, optimizer and evaluation settings;
# the data module receives the dataset settings. Both get the same gloss_dict.
slr_model = SLR_Lightning(
    gloss_dict=gloss_dict,
    **model_cfg,
    **optimizer_cfg,
    **eval_cfg,
)
dm = VideoDataModule(gloss_dict=gloss_dict, **dataset_cfg)

In [8]:
import torch

# Pretrained weights (ResNet18 + VAC + SMKD variant, judging by the file name).
# An older checkpoint, models/SLR/dev_19.82_epoch35_model.pt, was previously
# loaded here and indexed without the ["model_state_dict"] lookup.
# map_location="cpu" lets the file load even on a machine without the device
# its tensors were saved from; load_state_dict copies them into the model's own
# parameters anyway. torch.load unpickles: only load trusted checkpoints.
CKPT_PATH = 'models/SLR/resnet18_vac_smkd_dev_19.80_epoch35_model.pt'
sd = torch.load(CKPT_PATH, map_location="cpu")["model_state_dict"]

In [9]:
import re
from collections import OrderedDict

# A "module" path segment that is not at the start of a key, e.g. the ".module"
# in "conv2d.module.weight". The lookahead keeps names like "conv.modules.0" or
# "head.module_proj" intact, which a plain str.replace(".module", "") mangles.
_MODULE_SEGMENT = re.compile(r"\.module(?=\.|$)")


def modified_weights(state_dict, modified=False):
    """Rename checkpoint keys so they match the unwrapped model's parameter names.

    Every non-leading ``module`` path segment is dropped from each key
    (presumably inserted by a ``DataParallel``-style wrapper when the
    checkpoint was saved — TODO confirm).

    Args:
        state_dict: Mapping of parameter name -> tensor, e.g. the
            ``"model_state_dict"`` entry of a checkpoint.
        modified: Reserved for an additional key/weight remapping that is not
            implemented.

    Returns:
        A new ``OrderedDict`` with the renamed keys in their original order.
        If the input carries ``_metadata`` (read by ``load_state_dict`` for
        per-module version info), it is copied over with the same renaming.

    Raises:
        NotImplementedError: if ``modified`` is True, instead of returning an
            empty dict that would silently discard every weight.
    """
    if modified:
        raise NotImplementedError(
            "modified_weights(modified=True) has no remapping implemented yet"
        )

    renamed = OrderedDict(
        (_MODULE_SEGMENT.sub("", key), value) for key, value in state_dict.items()
    )

    # Building a new dict drops this attribute, so carry it over explicitly
    # with prefixes renamed the same way as the parameter keys.
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        renamed._metadata = OrderedDict(
            (_MODULE_SEGMENT.sub("", prefix), info) for prefix, info in metadata.items()
        )
    return renamed

In [10]:
# Normalise the checkpoint key names, then require an exact key match so that a
# mismatched checkpoint fails loudly instead of partially loading.
sd = modified_weights(state_dict=sd, modified=False)
slr_model.model.load_state_dict(state_dict=sd, strict=True)

<All keys matched successfully>

In [11]:
# Checkpoints are written to trainer_cfg.ckpt_dir, named after the epoch and the
# logged val_loss metric.
# NOTE(review): no `monitor` is passed, so per the Lightning docs only the
# latest checkpoint is kept rather than the best by val_loss; also confirm that
# SLR_Lightning actually logs "val_loss", otherwise the name will not reflect it.
checkpoint_callback = ModelCheckpoint(
    filename="Phoenix2014-SLR-{epoch:02d}-{val_loss:.2f}",
    dirpath=trainer_cfg.ckpt_dir,
)

# Hardware and run length come straight from the trainer section of the config.
trainer = Trainer(
    max_epochs=trainer_cfg.max_epochs,
    accelerator=trainer_cfg.accelerator,
    devices=trainer_cfg.devices,
    callbacks=[checkpoint_callback],
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
# Evaluate the loaded pretrained weights on the validation data before training.
trainer.validate(slr_model, datamodule=dm)

In [None]:
# Evaluate the same weights on the test data provided by the data module.
trainer.test(slr_model, datamodule=dm)

In [None]:
# Continue training from the loaded weights; the data module supplies the loaders.
trainer.fit(model=slr_model, datamodule=dm)