In [9]:
import torch
from src.utils.parameters import load_parameters, instanciate_cls
from src.models.classification import CLSModel
from src.utils.decoders import softmax_decoder
from src.utils.dataloaders import load_mnist_dataloader

# Experiment configuration: every tunable value comes from this YAML file.
parameter_file = './params/mnist.yaml'
params = load_parameters(parameter_file)

# Each top-level section of the config carries its settings under a
# 'parameters' key; pull those out into one variable per section.
xp_params, data_params, net_params, enc_params = (
    params[section]['parameters']
    for section in ('experiment', 'dataset', 'network', 'encoder')
)

# CUDA availability flag, passed to load_mnist_dataloader later on.
gpu = torch.cuda.is_available()

In [10]:
# Device the model lives on. Kept on CPU deliberately; inputs fed to the model
# must be on the same device.
DEVICE = "cpu"

# Input encoder: module path and class name both come from the config.
input_encoder = instanciate_cls(
    params['encoder']['module'],
    params['encoder']['name'],
    enc_params,
)

# Spiking network: class looked up by name inside src.networks.classification.
net = instanciate_cls(
    'src.networks.classification',
    params['network']['name'],
    net_params,
)

# Assemble encoder -> SNN -> softmax decoder, then move everything to DEVICE.
model = CLSModel(encoder=input_encoder, snn=net, decoder=softmax_decoder)
model = model.to(DEVICE)


In [11]:
# Display the module tree to check the architecture built from the config.
model

CLSModel(
  (encoder): ConstantCurrentLIFEncoder()
  (snn): DoubleConvNet(
    (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (fc): Linear(in_features=1600, out_features=64, bias=True)
    (out): LILinearCell()
    (lif0): LIFCell(p=LIFParameters(tau_syn_inv=tensor(200.), tau_mem_inv=tensor(100.), v_leak=tensor(0.), v_th=tensor(1.), v_reset=tensor(0.), method='super', alpha=tensor(80)), dt=0.001)
    (lif1): LIFCell(p=LIFParameters(tau_syn_inv=tensor(200.), tau_mem_inv=tensor(100.), v_leak=tensor(0.), v_th=tensor(1.), v_reset=tensor(0.), method='super', alpha=tensor(80)), dt=0.001)
    (lif2): LIFCell(p=LIFParameters(tau_syn_inv=tensor(200.), tau_mem_inv=tensor(100.), v_leak=tensor(0.), v_th=tensor(1.), v_reset=tensor(0.), method='super', alpha=tensor(80)), dt=0.001)
  )
)

In [15]:
# Checkpoint produced by a previous training run, and the index of the
# training example used for the demo prediction in the next cells.
CHECKPOINT_PATH = "./logs/MNIST - 3b5e60ac/checkpoint.pth"
SAMPLE_INDEX = 90

# Restore trained weights. map_location remaps every tensor in the checkpoint
# onto DEVICE, so a checkpoint saved during GPU training still loads here
# (the model was placed on DEVICE when it was built).
model.load_state_dict(
    torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=True))
# Switch to evaluation mode before running inference.
model.eval()

train_dl, test_dl, _ = load_mnist_dataloader(
    './data', image_size=data_params['resize'], batch_size=1, gpu=gpu)

# NOTE: despite the name, `batch` is a single dataset item (the dataset is
# indexed directly, not iterated through the loader). It is used below as an
# (input, label) pair.
batch = train_dl.dataset[SAMPLE_INDEX]

In [16]:
# NOTE(review): this repeats the display from In [11]. The captured output is
# identical (loading the checkpoint did not change the module tree), so this
# cell can likely be removed.
model

CLSModel(
  (encoder): ConstantCurrentLIFEncoder()
  (snn): DoubleConvNet(
    (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (fc): Linear(in_features=1600, out_features=64, bias=True)
    (out): LILinearCell()
    (lif0): LIFCell(p=LIFParameters(tau_syn_inv=tensor(200.), tau_mem_inv=tensor(100.), v_leak=tensor(0.), v_th=tensor(1.), v_reset=tensor(0.), method='super', alpha=tensor(80)), dt=0.001)
    (lif1): LIFCell(p=LIFParameters(tau_syn_inv=tensor(200.), tau_mem_inv=tensor(100.), v_leak=tensor(0.), v_th=tensor(1.), v_reset=tensor(0.), method='super', alpha=tensor(80)), dt=0.001)
    (lif2): LIFCell(p=LIFParameters(tau_syn_inv=tensor(200.), tau_mem_inv=tensor(100.), v_leak=tensor(0.), v_th=tensor(1.), v_reset=tensor(0.), method='super', alpha=tensor(80)), dt=0.001)
  )
)

In [17]:
# Single-sample sanity check: compare the true label with the model's
# predicted class. Gradients are not needed for inference, so the forward
# pass runs under no_grad() and no autograd graph is built.
print('Input given is : {}'.format(batch[1]))
with torch.no_grad():
    # Move the input to the model's device so the cell keeps working if
    # DEVICE is changed.
    output = model(batch[0].to(DEVICE))
# Assumes output is (batch, classes) - TODO confirm against softmax_decoder.
# keepdim=True keeps a 2-D result, hence the [0][0] indexing below.
pred = output.argmax(dim=1, keepdim=True)
print('Input predicted is : {}'.format(pred[0][0].item()))

Input given is : 6
Input predicted is : 6
