In [43]:
import torch
import numpy as np
import torchvision.transforms.functional as F
import matplotlib.pyplot as plt
from PIL import Image

In [44]:
def resnet50(num_classes):
    model = torchvision.models.resnet50(pretrained=True)
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    model.eval()
    return model

def vgg16(num_classes=39):
    model = torchvision.models.vgg16(pretrained=True)
    model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, num_classes)
    model.eval()
    return model
    
class_maps = {
    'label1': ['detritus', 'zooplankton'],
    'label2': ['copepod','detritus', 'noncopepod'],
    'label3': ['annelida_polychaeta', 'appendicularia', 'bivalvia-larvae', 
               'byrozoa-larvae', 'chaetognatha', 'cirripedia_barnacle-nauplii', 
               'cladocera', 'cladocera_evadne-spp', 'cnidaria', 'copepod_calanoida', 
               'copepod_calanoida_acartia-spp', 'copepod_calanoida_calanus-spp', 
               'copepod_calanoida_candacia-spp', 'copepod_calanoida_centropages-spp', 
               'copepod_calanoida_para-pseudocalanus-spp', 'copepod_calanoida_temora-spp', 
               'copepod_cyclopoida', 'copepod_cyclopoida_corycaeus-spp', 
               'copepod_cyclopoida_oithona-spp', 'copepod_cyclopoida_oncaea-spp', 
               'copepod_harpacticoida', 'copepod_nauplii', 'copepod_unknown', 
               'decapoda-larvae_brachyura', 'detritus', 'echniodermata-larvae', 
               'euphausiid', 'euphausiid_nauplii', 'fish-eggs', 'gastropoda-larva', 
               'mysideacea', 'nt-bubbles', 'nt-phyto_ceratium-spp', 
               'nt-phyto_rhizosolenia-spp', 'nt_phyto_chains', 'ostracoda', 
               'radiolaria', 'tintinnida', 'tunicata_doliolida']
}

In [45]:
# define model architecture
model = resnet50(num_classes=39)
# Loading saved model weights
model_state_dict = torch.load(f'/output/models/resnet50/resnet50_label3_005.pth', map_location='cpu')
model.load_state_dict(model_state_dict)

<All keys matched successfully>

In [55]:
device = torch.device("cpu")
# check if gpu is available
# if torch.cuda.is_available():
#     device = torch.device("cuda")

In [56]:
# Load image
image = Image.open('/data/images/Pia1.2016-08-02.1631+N38_hc._fx.tif')
# Convert Image to tensor and resize it
t = F.to_tensor(image)
t = F.resize(t, (256, 256))
# model expects a batch of images so lets convert this image tensor to batch of 1 image
t = t.unsqueeze(dim=0)
print((type(t),t.shape))

(<class 'torch.Tensor'>, torch.Size([1, 3, 256, 256]))


In [59]:
model = model.to(device)
t = t.to(device)

with torch.set_grad_enabled(False):
    outputs = model(t)
    # select top 1 from outputs
    _, preds = torch.max(outputs, 1)

In [62]:
print(class_maps['label3'][preds[0]],outputs[0][[preds[0]]])

appendicularia tensor([[ 3.8014, 16.6294, -4.9266,  1.3700,  5.5444, -8.0880, -0.0420,  2.0174,
          1.1218,  2.8133, -2.7241, -2.3848, -2.7246, -2.3331, -2.1492, -1.0956,
          0.4141, -0.7841, -1.5687, -0.1150, -0.8236, -0.1830,  3.7441,  4.9344,
          1.7407,  1.4253,  2.6904,  2.5774,  5.2806, -0.6292, -1.9488, -2.0405,
         -0.9698, -1.2764, -4.3951, -3.1551, -5.7602,  0.3590,  0.2814]])
