In [21]:
import torch
import coremltools as ct

In [22]:
class AirQualityNet(torch.nn.Module):
    def __init__(self):
        super(AirQualityNet, self).__init__()
        self.fc1 = torch.nn.Linear(5, 10)
        self.fc2 = torch.nn.Linear(10, 10)
        self.fc3 = torch.nn.Linear(10, 10)
        self.fc4 = torch.nn.Linear(10, 4)
        
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = self.fc4(x)
        return x
    
    def predict(self, x):
        x = self.forward(x)
        return torch.argmax(x, dim=1)

In [23]:
version = 3
acc = 96
model_file = f'air_quality_v{version}_acc_{acc}.pt'
model = AirQualityNet()
model.load_state_dict(torch.load(model_file))
model.eval()

AirQualityNet(
  (fc1): Linear(in_features=5, out_features=10, bias=True)
  (fc2): Linear(in_features=10, out_features=10, bias=True)
  (fc3): Linear(in_features=10, out_features=10, bias=True)
  (fc4): Linear(in_features=10, out_features=4, bias=True)
)

In [24]:
example_input = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])  
traced_model = torch.jit.trace(model, example_input)

In [25]:
model_from_torch = ct.convert(traced_model,
							  inputs=[ct.TensorType(name="input", 
													shape=example_input.shape)])

# mlmodel = ct.convert(traced_model, inputs=[ct.TensorType(shape=example_input.shape)])

# Save the model
model_from_torch.save(f"air_quality_v{version}_acc_{acc}.mlpackage")

When both 'convert_to' and 'minimum_deployment_target' not specified, 'convert_to' is set to "mlprogram" and 'minimum_deployment_target' is set to ct.target.iOS15 (which is same as ct.target.macOS12). Note: the model will not run on systems older than iOS15/macOS12/watchOS8/tvOS15. In order to make your model run on older system, please set the 'minimum_deployment_target' to iOS14/iOS13. Details please see the link: https://apple.github.io/coremltools/docs-guides/source/target-conversion-formats.html
Converting PyTorch Frontend ==> MIL Ops:  86%|████████▌ | 6/7 [00:00<00:00, 5142.18 ops/s]
Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00, 4207.77 passes/s]
Running MIL default pipeline: 100%|██████████| 89/89 [00:00<00:00, 3630.16 passes/s]
Running MIL backend_mlprogram pipeline: 100%|██████████| 12/12 [00:00<00:00, 8192.00 passes/s]
