In [2]:
import torch
import timm

In [3]:
 genres = ['agriculture',
           'artisinal_mine',
           'bare_ground',
           'blooming',
           'blow_down',
           'clear',
           'cloudy',
           'conventional_mine',
           'cultivation',
           'habitation',
           'haze',
           'partly_cloudy',
           'primary',
           'road',
           'selective_logging',
           'slash_burn',
           'water']
DEVICE = 'cuda:0'

In [5]:
state_dict = torch.load('../weights/model.best.pth')

model = timm.create_model(model_name="resnet18", pretrained=False, num_classes=len(genres))
model.load_state_dict(state_dict)

<All keys matched successfully>

In [6]:
 class ModelWrapper(torch.nn.Module):
    def __init__(self, model, classes, size, thresholds):
        super().__init__()
        self.model = model
        self.classes = classes
        self.size = size
        self.thresholds = thresholds
    
    def forward(self, image):
        return torch.sigmoid(self.model.forward(image))

In [7]:
 wrapper = ModelWrapper(model, classes=genres, size=(224, 224), thresholds=(0.5,) * len(genres))

In [8]:
scripted_model = torch.jit.script(wrapper)

In [10]:
 scripted_model.classes

['agriculture',
 'artisinal_mine',
 'bare_ground',
 'blooming',
 'blow_down',
 'clear',
 'cloudy',
 'conventional_mine',
 'cultivation',
 'habitation',
 'haze',
 'partly_cloudy',
 'primary',
 'road',
 'selective_logging',
 'slash_burn',
 'water']

In [11]:
 traced_model = torch.jit.trace(wrapper, torch.rand(1, 3, 224, 224))

In [14]:
 dummy_input = torch.rand(1, 3, 224, 224)

In [15]:
with torch.no_grad():
    print(torch.sigmoid(model(dummy_input)))

tensor([[0.1534, 0.0017, 0.0054, 0.0027, 0.0017, 0.7507, 0.0025, 0.0016, 0.0324,
         0.0172, 0.0104, 0.0807, 0.9936, 0.0766, 0.0029, 0.0025, 0.0985]])


In [16]:
 with torch.no_grad():
    print(scripted_model(dummy_input))

tensor([[0.1534, 0.0017, 0.0054, 0.0027, 0.0017, 0.7507, 0.0025, 0.0016, 0.0324,
         0.0172, 0.0104, 0.0807, 0.9936, 0.0766, 0.0029, 0.0025, 0.0985]])


In [17]:
 torch.jit.save(scripted_model, '../weights/space_image_classification.pt')

In [19]:
 model = torch.jit.load('../weights/space_image_classification.pt', map_location='cpu')

In [20]:
model.classes

['agriculture',
 'artisinal_mine',
 'bare_ground',
 'blooming',
 'blow_down',
 'clear',
 'cloudy',
 'conventional_mine',
 'cultivation',
 'habitation',
 'haze',
 'partly_cloudy',
 'primary',
 'road',
 'selective_logging',
 'slash_burn',
 'water']