In [1]:
import torch
import torchvision.transforms as T
import pytorch_lightning as pl
from PIL import Image
from components_datamodule import ComponentsDataModule
from regnet import rnn_regulated_block, RegNet

# Load the model (Lightning checkpoint and plain state dict), freeze it, and predict a sample image

In [2]:
# Restore the model from the Lightning checkpoint.
# map_location='cpu' lets a checkpoint whose tensors were saved on a GPU load on a
# CPU-only machine; the sample-image inference in this notebook runs on CPU tensors.
CHECKPOINT_PATH = 'pcb_components_val_acc_90.ckpt'
model = RegNet.load_from_checkpoint(CHECKPOINT_PATH, map_location='cpu')



In [4]:
# Preprocessing pipeline used both for the single sample image and for the
# test datamodule: resize to a fixed size, convert to a float tensor, then
# standardise each of the three channels.
# NOTE(review): mean/std look like per-channel statistics of the training
# images — confirm where they were computed.
INPUT_SIZE = (112, 112)
CHANNEL_MEAN = (0.2979, 0.2789, 0.2408)
CHANNEL_STD = (0.2960, 0.2848, 0.2620)

transforms = T.Compose(
    [
        T.Resize(INPUT_SIZE),
        T.ToTensor(),
        T.Normalize(mean=CHANNEL_MEAN, std=CHANNEL_STD),
    ]
)

In [8]:
# Load one sample image and turn it into a batch of size 1.
# convert('RGB') guards against PNGs stored as RGBA / palette / grayscale, which
# would otherwise yield a channel count that does not match the 3-channel
# Normalize step; for images that are already RGB it is a no-op copy.
# The context manager closes the underlying file handle.
SAMPLE_IMAGE_PATH = './diode_1002.png'
with Image.open(SAMPLE_IMAGE_PATH) as sample_img:
    image = transforms(sample_img.convert('RGB')).unsqueeze(0)

In [5]:
# Single-image prediction with the checkpoint model.
# freeze() is Lightning's inference helper: it disables parameter gradients;
# eval() is kept explicit for readability.
# NOTE(review): the saved output of this cell (tensor(5)) came from an
# out-of-order run — this cell is In[5] but the image cell is In[8] — so it was
# presumably computed on a previously loaded image. Re-run top to bottom.
model.eval()
model.freeze()
logits = model(image)
probs = torch.softmax(logits, dim=-1)
# Reduce over the class axis rather than the flattened tensor, so the result
# stays correct for batches larger than one.
torch.argmax(probs, dim=-1)

tensor(5)

In [6]:
# Rebuild the architecture by hand and load a plain PyTorch state dict.
# NOTE(review): these constructor arguments must match the ones used when the
# .pth file was saved; with the default strict loading, mismatched keys or shapes
# make load_state_dict raise. classes=6 matches the 6-entry class map printed by
# the datamodule in the test cell.
# map_location='cpu' lets weights saved from GPU tensors load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
STATE_DICT_PATH = 'pcb_components_val_acc_90.pth'
model = RegNet(
    rnn_regulated_block,
    in_dim=3,
    h_dim=64,
    intermediate_channels=32,
    classes=6,
    cell_type='lstm',
    layers=[1, 1, 3],
)
model.load_state_dict(torch.load(STATE_DICT_PATH, map_location='cpu'))
model.eval()
print('Model Loaded')

Model Loaded


In [9]:
# Single-image prediction with the state-dict model.
# This model was only switched to eval mode (not frozen), so its parameters still
# require gradients; no_grad() avoids recording an autograd graph for inference.
with torch.no_grad():
    logits = model(image)
probs = torch.softmax(logits, dim=-1)
# Reduce over the class axis; the index maps to the class dictionary printed by
# the datamodule (e.g. 0 -> 'diodes').
torch.argmax(probs, dim=-1)

tensor(0)

# Evaluate model accuracy on the test set

In [67]:
# Evaluate `model` (the state-dict model, as last bound above) on the test split,
# using the same preprocessing pipeline as the single-image prediction.
# NOTE(review): absolute, machine-specific path — consider moving it to a config
# cell or an environment variable.
DATA_ROOT = '/storage/PCB-Components-L1'
pcb_components_data_module = ComponentsDataModule(DATA_ROOT, batch_size=32, transforms=transforms)
# Fall back to CPU instead of failing when no GPU is visible.
trainer = pl.Trainer(gpus=1 if torch.cuda.is_available() else 0)
# Pass the datamodule by keyword: the second positional parameter of
# Trainer.test is the dataloaders argument, not the datamodule.
trainer.test(model, datamodule=pcb_components_data_module)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


{'diodes': 0, 'ICs': 1, 'capacitors': 2, 'transistors': 3, 'inductors': 4, 'resistors': 5}
Testing:  99%|█████████▉| 155/156 [00:45<00:00,  3.66it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_accuracy': 0.9204270839691162}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 156/156 [00:45<00:00,  3.40it/s]


[{'test_accuracy': 0.9204270839691162}]