In [1]:
import torch
import timm

In [2]:
# Tag vocabulary: the classifier below is built with one output per entry,
# and ModelWrapper applies an independent sigmoid to each output.
# NOTE(review): output column i presumably corresponds to image_classes[i];
# confirm against the training code before reordering.
# NOTE(review): 'artisinal_mine' is spelled exactly as in the original list;
# keep it unless the training labels change too.
image_classes = """
    agriculture artisinal_mine bare_ground blooming blow_down
    clear cloudy conventional_mine cultivation habitation
    haze partly_cloudy primary road selective_logging
    slash_burn water
""".split()

# Not referenced anywhere else in this notebook; the model is never moved to it.
DEVICE = 'cuda:0'

In [3]:
# map_location='cpu' lets the checkpoint load even on a machine without the
# GPU it may have been saved from.
state_dict = torch.load('../weights/model.best.pth', map_location='cpu')

# pretrained=False: every weight comes from the checkpoint loaded above.
# .eval() switches BatchNorm/Dropout to inference behaviour. Without it, each
# forward pass later in the notebook (tracing, sanity checks) would normalise
# with batch statistics and overwrite the running mean/var buffers that end up
# in the exported model.
model = timm.create_model(
    model_name="resnet18",
    pretrained=False,
    num_classes=len(image_classes),
).eval()
model.load_state_dict(state_dict)

<All keys matched successfully>

In [4]:
class ModelWrapper(torch.nn.Module):
    """Export wrapper that turns a classifier's logits into per-class probabilities.

    Besides the wrapped network, it stores inference metadata as plain module
    attributes so they travel with the TorchScript artifact (``classes`` is
    read back from the scripted and the reloaded module later on).

    Args:
        model: network mapping an input batch to one logit per class.
        classes: class names; presumably ``classes[i]`` labels output column ``i``.
        size: input size metadata, e.g. ``(224, 224)``; not used by ``forward``.
        thresholds: per-class decision threshold metadata; not used by ``forward``.
    """

    def __init__(self, model, classes, size, thresholds):
        super().__init__()
        self.model = model
        self.classes = classes
        self.size = size
        self.thresholds = thresholds

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Return ``sigmoid(model(image))``: an independent probability per class."""
        # Call the module itself rather than ``.forward`` so that any forward
        # hooks / pre-hooks registered on the wrapped model still run.
        return torch.sigmoid(self.model(image))

In [5]:
# Attach the label list, the expected input size and a uniform 0.5 threshold
# per class as metadata; forward() itself only applies the sigmoid.
wrapper = ModelWrapper(
    model,
    classes=image_classes,
    size=(224, 224),
    thresholds=tuple(0.5 for _ in image_classes),
)

In [6]:
# Export in inference mode: .eval() (which returns the wrapper itself) flips
# the wrapped network's BatchNorm layers to their running statistics, and the
# scripted module records training=False, so consumers get inference behaviour
# without having to remember to call .eval().
scripted_model = torch.jit.script(wrapper.eval())

In [7]:
# The label list must survive scripting as a module attribute.
assert scripted_model.classes == image_classes
scripted_model.classes

['agriculture',
 'artisinal_mine',
 'bare_ground',
 'blooming',
 'blow_down',
 'clear',
 'cloudy',
 'conventional_mine',
 'cultivation',
 'habitation',
 'haze',
 'partly_cloudy',
 'primary',
 'road',
 'selective_logging',
 'slash_burn',
 'water']

In [8]:
# Alternative export route via tracing; traced_model is not used again in
# this notebook (the scripted model is the one that gets saved).
# NOTE(review): tracing runs a real forward pass. If the wrapped network is
# still in training mode, this updates its BatchNorm running statistics.
traced_model = torch.jit.trace(wrapper, example_inputs=torch.rand(1, 3, 224, 224))

In [9]:
# Fixed-seed random batch (one 3x224x224 image) for the sanity checks below.
# A dedicated Generator makes re-runs reproducible without touching torch's
# global RNG state.
dummy_input = torch.rand(1, 3, 224, 224, generator=torch.Generator().manual_seed(0))

In [10]:
# Reference probabilities from the eager timm model (sigmoid applied by hand,
# matching what ModelWrapper.forward does).
with torch.no_grad():
    print(model(dummy_input).sigmoid())

tensor([[0.1616, 0.0019, 0.0059, 0.0026, 0.0017, 0.7342, 0.0027, 0.0016, 0.0342,
         0.0191, 0.0116, 0.0800, 0.9929, 0.0734, 0.0029, 0.0026, 0.1044]])


In [11]:
 with torch.no_grad():
    print(scripted_model(dummy_input))

tensor([[0.1616, 0.0019, 0.0059, 0.0026, 0.0017, 0.7342, 0.0027, 0.0016, 0.0342,
         0.0191, 0.0116, 0.0800, 0.9929, 0.0734, 0.0029, 0.0026, 0.1044]])


In [13]:
# Serialise the TorchScript module (code, weights and attributes such as
# `classes`) to a single file.
scripted_model.save('../weights/space_image_classification.pt')

In [14]:
# Reload the exported artifact onto CPU, as a downstream consumer would.
# NOTE(review): this rebinds `model`, shadowing the eager timm model from the
# earlier cells.
model = torch.jit.load(
    f='../weights/space_image_classification.pt',
    map_location='cpu',
)

In [15]:
# The label list must survive the save/load round trip.
assert model.classes == image_classes
model.classes

['agriculture',
 'artisinal_mine',
 'bare_ground',
 'blooming',
 'blow_down',
 'clear',
 'cloudy',
 'conventional_mine',
 'cultivation',
 'habitation',
 'haze',
 'partly_cloudy',
 'primary',
 'road',
 'selective_logging',
 'slash_burn',
 'water']