In [None]:
import torch
from src.utils.parameters import load_parameters, instanciate_cls
from src.models.classification import CLSModel
from src.utils.decoders import softmax_decoder
from src.utils.dataloaders import load_mnist_dataloader

# Load a trained MNIST classifier from a checkpoint and run it on one sample.

parameter_file = './params/mnist.yaml'
params = load_parameters(parameter_file)

xp_params = params['experiment']['parameters']
data_params = params['dataset']['parameters']
net_params = params['network']['parameters']
enc_params = params['encoder']['parameters']

# Passed to the dataloader helper only; the model itself is pinned to DEVICE below.
gpu = torch.cuda.is_available()

DEVICE = "cpu"
CHECKPOINT_PATH = "./logs/MNIST - 3b5e60ac/checkpoint.pth"
SAMPLE_INDEX = 90  # index into the *training* dataset of the sample to classify

input_encoder = instanciate_cls(
    params['encoder']['module'], params['encoder']['name'], enc_params)

net = instanciate_cls('src.networks.classification',
                      params['network']['name'], net_params)

model = CLSModel(
    encoder=input_encoder,
    snn=net,
    decoder=softmax_decoder
).to(DEVICE)

# map_location remaps any tensors saved on another device (e.g. CUDA) onto
# DEVICE, so a GPU-trained checkpoint still loads on a CPU-only machine.
model.load_state_dict(
    torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=True))
model.eval()

train_dl, test_dl, _ = load_mnist_dataloader(
    './data', image_size=data_params['resize'], batch_size=1, gpu=gpu)

# A single dataset item (not a batch); index 0 is the model input, index 1 the label.
sample = train_dl.dataset[SAMPLE_INDEX]
image, label = sample[0], sample[1]

print('Input given is : {}'.format(label))

# Inference only: skip building the autograd graph.
with torch.no_grad():
    output = model(image)

pred = output.argmax(dim=1, keepdim=True)
# tolist() turns a 0-d tensor into a plain Python number (prints "7" rather
# than "tensor(7)") and still works if extra dims remain.
print('Input predicted is : {}'.format(pred[0][0].tolist()))