In [1]:
import os
import sys
import csv
from PIL import Image
from collections import OrderedDict
import torch
from torchvision import transforms

from efficientnet_pytorch import EfficientNet

In [3]:
class_list = {'Cloud': 1, 'Crystal': 2, 'Feather': 3, 'Twinning_wisp': 4}

oao_groups = {'1':['Cloud','Crystal'], '2':['Cloud','Feather'], '3':['Cloud','Twinning_wisp'],  
        '4':['Crystal','Feather'], '5':['Crystal','Twinning_wisp'],  '6':['Feather','Twinning_wisp']}

# Choose model
labels_map = oao_groups['1']
print(f'Infer based on OAO model {labels_map[0]}_{labels_map[1]}')

# Load checkpoint
model_name = 'efficientnet-b0'
image_size = EfficientNet.get_image_size(model_name) # 224
model = EfficientNet.from_name(model_name, num_classes=2)

model_path = '/media/hdd/diamond_result/cls_multi-class_EfficientNet/oao_strategy'
checkpoint_path = os.path.join(model_path,f'{labels_map[0]}_{labels_map[1]}','model_best.pth')

state_dict = torch.load(checkpoint_path)['state_dict']
# create new OrderedDict that does not contain 'module.'
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove 'module.'
    new_state_dict[name] = v
# load params
model.load_state_dict(new_state_dict)

Infer based on OAO model Cloud_Crystal


<All keys matched successfully>

In [4]:
# Infer set
testList = {}
data_path = '/media/hdd/diamond_data/cls_multi-class_EfficientNet'
for cls_name in class_list.keys():
    # if not cls_name.lower() == 'pinpoint':
    with open(os.path.join(data_path,cls_name,'test_ids.txt'), 'r') as file:
        tmpList = file.readlines()
    tmpList = [n.split('\n')[0] for n in tmpList]
    for tmp in tmpList:
        testList[tmp] = [tmp,class_list[cls_name],cls_name]

In [None]:
infer_info = [['img_name','cls_name','cls_id',f'prob_{labels_map[0]}',f'prob_{labels_map[1]}']]
num_idx = len(testList.keys())
for idx, img_n in enumerate(testList.keys()):

    # Open image
    cls_name = testList[img_n][2]
    img = Image.open(os.path.join(data_path,cls_name,img_n))

    # Preprocess image
    tfms = transforms.Compose([transforms.Resize((image_size,image_size)), 
                            transforms.CenterCrop(image_size), 
                            transforms.ToTensor(),
                            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),])
    img = tfms(img).unsqueeze(0)

    # Classify with EfficientNet
    model.eval()
    with torch.no_grad():
        logits = model(img)
        print(logits.size())

In [6]:
    class_list = {'Cloud': 1, 'Crystal': 2, 'Feather': 3, 'Twinning_wisp': 4}

    oao_groups = {'1':['Cloud','Crystal'], '2':['Cloud','Feather'], '3':['Cloud','Twinning_wisp'],  
            '4':['Crystal','Feather'], '5':['Crystal','Twinning_wisp'],  '6':['Feather','Twinning_wisp']}

In [11]:
a = class_list.keys()
a[0]

TypeError: 'dict_keys' object is not subscriptable