In [None]:
import timm
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile
from os.path import join
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import numpy as np
import torch.nn as nn
import torchvision.transforms as transforms

# Pipeline switches:
#   HF_MODEL      - True: models come from transformers' AutoModelForImageClassification; False: timm.create_model
#   CONVERT_FEATS - also export the head-less model built with num_classes / num_labels = 0
#   OPTIMIZE      - run optimize_for_mobile on the traced modules; left off because it
#                   degrades the exported models most of the time
HF_MODEL, CONVERT_FEATS, OPTIMIZE = True, True, False

In [None]:
# Model / export configuration.
model_name = 'mobilenet_v2_0.75_160'
organistation = 'google'  # Hub organisation for HF models (identifier spelling kept: later cells use it)
num_classes = 7
MODELS_PATH = 'saved_models'  # holds <model_name>_full.pth / _feats.pth and receives the exported *_mobile.pt files
# Side length of the square dummy input used for tracing. Presumably it should match the
# model's training resolution (the '160' suffix of model_name suggests 160) - confirm when
# switching models. Other values used with other models: 384, 256, 224.
IMAGE_SIZE = 160

In [None]:
class Helper(nn.Module):
    """Adapter that returns only the ``.logits`` of the wrapped model's output.

    Used around the transformers classifiers in this notebook so that the traced module
    produces a plain tensor instead of a structured output object.
    """

    def __init__(self, model):
        super(Helper, self).__init__()
        # Registered as a child module, so its parameters/buffers belong to this wrapper.
        self.model = model

    def forward(self, x):
        outputs = self.model(x)
        return outputs.logits

# Build the architectures, then load the fine-tuned weights saved under MODELS_PATH.
#   model_full  - classifier with a num_classes-way output
#   model_feats - same backbone built with num_classes / num_labels = 0, i.e. without a
#                 classification layer (presumably returning pooled features)
if not HF_MODEL:
    model_full = timm.create_model(model_name, pretrained = True, num_classes = num_classes, drop_rate = 0.1)
    # Train mode so dropout is active when the dropout variant is traced.
    # NOTE(review): this also puts BatchNorm layers in train mode.
    model_full.train()
    if CONVERT_FEATS:
        model_feats = timm.create_model(model_name, pretrained = True, num_classes = 0, drop_rate = 0.1)
        model_feats.eval()
    # Preprocessing settings (incl. mean/std, used below for normalization) timm associates with the model.
    timm_data_config = timm.data.resolve_data_config({}, model=model_full)
    print(timm_data_config)
else:
    # ignore_mismatched_sizes lets the hub checkpoint load although its head shape differs from num_labels.
    model_full = AutoModelForImageClassification.from_pretrained(organistation + '/' + model_name, num_labels = num_classes, ignore_mismatched_sizes = True)
    model_full.train()
    helper = Helper(model_full)
    if CONVERT_FEATS:
        model_feats = AutoModelForImageClassification.from_pretrained(organistation + '/' + model_name, num_labels = 0, ignore_mismatched_sizes = True)
        model_feats.eval()
        helper_feats = Helper(model_feats)

# map_location='cpu': everything here (models and the tracing input) lives on the CPU, and this
# lets checkpoints that were saved from GPU tensors load on a machine without CUDA.
model_full.load_state_dict(torch.load(join(MODELS_PATH, model_name + '_full.pth'), map_location = 'cpu'))
if CONVERT_FEATS:
    print('Loading features model')
    model_feats.load_state_dict(torch.load(join(MODELS_PATH, model_name + '_feats.pth'), map_location = 'cpu'))

# Dummy input for torch.jit.trace: one random image-shaped tensor with values in [0, 1),
# normalized with the mean/std for the chosen model family.
if HF_MODEL:
    # Hard-coded HF statistics - presumably matching the hub image processor; verify when
    # changing models. ImageNet statistics (mean=(0.485, 0.456, 0.406),
    # std=(0.229, 0.224, 0.225)) were the previously used alternative.
    processor = transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))
else:
    processor = transforms.Normalize(mean = timm_data_config['mean'], std = timm_data_config['std'])

# torch.rand draws exactly what Uniform(0, 1).sample would for this shape.
X = processor(torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE))

# Trace the TorchScript modules that get exported:
#   dropout_traced_script    - classifier traced with dropout active (stochastic outputs)
#   classifier_traced_script - deterministic classifier traced in eval mode
#   feature_traced_script    - head-less feature model (only when CONVERT_FEATS)
# Tracing records the ops executed for the module's train/eval state at trace time, so
# each graph keeps the mode it was traced in.


def _train_dropout_keep_bn_eval(module):
    """Put ``module`` in train mode, then switch every BatchNorm layer back to eval mode.

    Dropout driven by the ``training`` flag stays active, while BatchNorm normalizes with
    its stored running statistics and does not update them from the dummy input.
    Returns ``module`` for convenience.
    """
    module.train()
    for submodule in module.modules():
        # _BatchNorm is the common base of BatchNorm1d/2d/3d, SyncBatchNorm and subclasses.
        if isinstance(submodule, nn.modules.batchnorm._BatchNorm):
            submodule.eval()
    return module


# HF models are traced through their Helper wrappers so the graphs return plain logits.
if HF_MODEL:
    full_module = helper
    feats_module = helper_feats if CONVERT_FEATS else None
else:
    full_module = model_full
    feats_module = model_feats if CONVERT_FEATS else None

# Dropout variant. BatchNorm must stay in eval mode here: in full train mode it would normalize
# with per-batch statistics and also overwrite its running stats from the random input,
# corrupting the weights baked into the deterministic trace below.
_train_dropout_keep_bn_eval(model_full)
# check_trace=False: dropout makes repeated runs differ, so the default consistency check
# could only report spurious mismatches.
dropout_traced_script = torch.jit.trace(full_module, X, check_trace=False)

model_full.eval()
classifier_traced_script = torch.jit.trace(full_module, X)
if CONVERT_FEATS:
    model_feats.eval()
    feature_traced_script = torch.jit.trace(feats_module, X)

if OPTIMIZE:
    if CONVERT_FEATS:
        feature_traced_script = optimize_for_mobile(feature_traced_script)
    dropout_traced_script = optimize_for_mobile(dropout_traced_script)
    classifier_traced_script = optimize_for_mobile(classifier_traced_script)

if CONVERT_FEATS:
    feature_traced_script.save(join(MODELS_PATH, model_name + '_feats_mobile.pt'))
dropout_traced_script.save(join(MODELS_PATH, model_name + '_dropout_mobile.pt'))
classifier_traced_script.save(join(MODELS_PATH, model_name + '_full_mobile.pt'))