Skip to content
This repository has been archived by the owner on Dec 21, 2023. It is now read-only.

Commit

Permalink
Adding MLCompute compute context to OD
Browse files Browse the repository at this point in the history
  • Loading branch information
shreyajn authored and Shreya Jain committed Jul 31, 2020
1 parent 6716bbe commit 304cda6
Show file tree
Hide file tree
Showing 7 changed files with 561 additions and 10 deletions.
7 changes: 2 additions & 5 deletions cmake/SetupCompiler.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ if(APPLE)
set(HAS_CORE_ML TRUE)
endif()

# MLCompute is only present on macOS 11.0 or higher.
if(NOT TC_BASE_SDK_VERSION VERSION_LESS 10.16)
# MLCompute is only present on macOS 10.16 or higher.
if(TC_BASE_SDK_VERSION VERSION_GREATER_EQUAL 10.16)
add_definitions(-DHAS_ML_COMPUTE)
set(HAS_ML_COMPUTE TRUE)
endif()
Expand All @@ -135,9 +135,6 @@ if(APPLE)
add_definitions(-DHAS_MACOS_10_15)
endif()

if(NOT TC_BASE_SDK_VERSION VERSION_LESS 10.16)
add_definitions(-DHAS_MACOS_10_16)
endif()
endif()

endmacro()
2 changes: 2 additions & 0 deletions src/ml/neural_net/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,10 @@ if(APPLE AND HAS_ML_COMPUTE AND NOT TC_BUILD_IOS)
SOURCES
mlc_compute_context.mm
mlc_layer_weights.mm
mlc_dc_backend.mm
mlc_od_backend.mm
mlc_utils.mm
TCMLComputeDrawingClassifierDescriptor.m
TCMLComputeObjectDetectorDescriptor.m
TCMLComputeUtil.m
TCModelTrainerBackendGraphs.m
Expand Down
41 changes: 41 additions & 0 deletions src/ml/neural_net/TCMLComputeDrawingClassifierDescriptor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/* Copyright © 2020 Apple Inc. All rights reserved.
*
* Use of this source code is governed by a BSD-3-clause license that can
* be found in the LICENSE.txt file or at
* https://opensource.org/licenses/BSD-3-Clause
*/

#import <MLCompute/MLCompute.h>

NS_ASSUME_NONNULL_BEGIN

// Defines the parameters for the MLCompute-based implementation of the
// Drawing Classifier model.
API_AVAILABLE(macos(10.16))
@interface TCMLComputeDrawingClassifierDescriptor : NSObject

// Defines the shape of the graph's input.
@property(nonatomic) MLCTensor *inputTensor;

// Defines the shape of the graph's ouput.
@property(nonatomic) MLCTensor *outputTensor;

// Controls the number of features in the output tensor, which should be equal
// to the number of classes.
@property(nonatomic) NSUInteger outputChannels;

// Dictionary mapping layer names to weights.
@property(nonatomic) NSDictionary<NSString *, MLCTensor *> *weights;

@end

API_AVAILABLE(macos(10.16))
@interface MLCGraph (TCMLComputeDrawingClassifier)

+ (instancetype)tc_graphForDrawingClassifierDescriptor:
(TCMLComputeDrawingClassifierDescriptor *)descriptor
batchSize:(NSUInteger)batchSize;

@end

NS_ASSUME_NONNULL_END
132 changes: 132 additions & 0 deletions src/ml/neural_net/TCMLComputeDrawingClassifierDescriptor.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/* Copyright © 2020 Apple Inc. All rights reserved.
*
* Use of this source code is governed by a BSD-3-clause license that can
* be found in the LICENSE.txt file or at
* https://opensource.org/licenses/BSD-3-Clause
*/

#import <ml/neural_net/TCMLComputeDrawingClassifierDescriptor.h>

#import <ml/neural_net/TCMLComputeUtil.h>

@implementation TCMLComputeDrawingClassifierDescriptor

- (BOOL)isComplete
{
if (self.inputTensor == nil) return NO;
if (self.outputChannels == 0) return NO;

return YES;
}

- (MLCTensor *)addConvLayer:(NSUInteger)index
outputChannels:(NSUInteger)outputChannels
source:(MLCTensor *)source
graph:(MLCGraph *)graph
{
// Find the weights for this conv layer in our dictionary of parameters.
NSString *biasKey = [NSString stringWithFormat:@"drawing_conv%lu_bias", (unsigned long)index];
NSString *weightKey = [NSString stringWithFormat:@"drawing_conv%lu_weight", (unsigned long)index];
MLCTensor *bias = self.weights[biasKey];
MLCTensor *weights = self.weights[weightKey];

// Configure the convolution descriptor.
NSUInteger inputChannels =
weights.descriptor.shape[TCMLComputeTensorSizeChannels].unsignedIntegerValue / outputChannels;
NSUInteger kernelHeight =
weights.descriptor.shape[TCMLComputeTensorSizeHeight].unsignedIntegerValue;
NSUInteger kernelWidth =
weights.descriptor.shape[TCMLComputeTensorSizeWidth].unsignedIntegerValue;
// Configure the convolution descriptor.
MLCConvolutionDescriptor *conv_desc =
[MLCConvolutionDescriptor descriptorWithKernelSizes:@[ @(kernelHeight), @(kernelWidth) ]
inputFeatureChannelCount:inputChannels
outputFeatureChannelCount:outputChannels
strides:@[ @1, @1 ]
paddingPolicy:MLCPaddingPolicySame
paddingSizes:nil];

MLCConvolutionLayer *conv = [MLCConvolutionLayer layerWithWeights:weights
biases:bias
descriptor:conv_desc];
MLCTensor *convTensor = [graph nodeWithLayer:conv source:source];

MLCLayer *relu1 = [MLCActivationLayer
layerWithDescriptor:[MLCActivationDescriptor descriptorWithType:MLCActivationTypeReLU
a:0.0f]];
MLCTensor *reluTensor = [graph nodeWithLayer:relu1 source:convTensor];

MLCPoolingDescriptor *poolDesc =
[MLCPoolingDescriptor maxPoolingDescriptorWithKernelSizes:@[ @2, @2 ]
strides:@[ @2, @2 ]
paddingPolicy:MLCPaddingPolicyValid
paddingSizes:nil];
MLCLayer *pool = [MLCPoolingLayer layerWithDescriptor:poolDesc];

MLCTensor *poolTensor = [graph nodeWithLayer:pool source:reluTensor];

return poolTensor;
}

- (MLCTensor *)addDenseLayer:(NSUInteger)index
outputChannels:(NSUInteger)outputChannels
source:(MLCTensor *)source
graph:(MLCGraph *)graph
{
// Find the weights for this conv layer in our dictionary of parameters.
NSString *biasKey = [NSString stringWithFormat:@"drawing_dense%lu_bias", (unsigned long)index];
NSString *weightKey =
[NSString stringWithFormat:@"drawing_dense%lu_weight", (unsigned long)index];
MLCTensor *bias = self.weights[biasKey];
MLCTensor *weights = self.weights[weightKey];

// Configure the convolution descriptor.
NSUInteger inputChannels =
weights.descriptor.shape[TCMLComputeTensorSizeChannels].unsignedIntegerValue / outputChannels;
MLCConvolutionDescriptor *dense_desc =
[MLCConvolutionDescriptor descriptorWithKernelSizes:@[ @1, @1 ]
inputFeatureChannelCount:inputChannels
outputFeatureChannelCount:outputChannels
strides:@[ @1, @1 ]
paddingPolicy:MLCPaddingPolicySame
paddingSizes:nil];

MLCFullyConnectedLayer *dense = [MLCFullyConnectedLayer layerWithWeights:weights
biases:bias
descriptor:dense_desc];
MLCTensor *denseTensor = [graph nodeWithLayer:dense source:source];
return denseTensor;
}
@end

@implementation MLCGraph (TCMLComputeDrawingClassifier)

+ (instancetype)tc_graphForDrawingClassifierDescriptor:
(TCMLComputeDrawingClassifierDescriptor *)descriptor
batchSize:(NSUInteger)batchSize
{
if (![descriptor isComplete]) return nil;

MLCGraph *graph = [[self alloc] init];

NSUInteger channelCounts[] = {16, 32, 64};
MLCTensor *tensor = descriptor.inputTensor;
for (NSUInteger i = 0; i < 3; ++i) {
tensor = [descriptor addConvLayer:i outputChannels:channelCounts[i] source:tensor graph:graph];
}
MLCReshapeLayer *flatten_layer =
[MLCReshapeLayer layerWithShape:@[ @(batchSize), @(576), @(1), @(1) ]];
tensor = [graph nodeWithLayer:flatten_layer source:tensor];
tensor = [descriptor addDenseLayer:0 outputChannels:128 source:tensor graph:graph];
MLCLayer *relu1 = [MLCActivationLayer
layerWithDescriptor:[MLCActivationDescriptor descriptorWithType:MLCActivationTypeReLU
a:0.0f]];
MLCTensor *reluTensor = [graph nodeWithLayer:relu1 source:tensor];
descriptor.outputTensor = [descriptor addDenseLayer:1
outputChannels:descriptor.outputChannels
source:reluTensor
graph:graph];
return graph;
}

@end
12 changes: 7 additions & 5 deletions src/ml/neural_net/mlc_compute_context.mm
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

#include <core/logging/logger.hpp>
#include <core/util/std/make_unique.hpp>

#include <ml/neural_net/mlc_dc_backend.hpp>
#include <ml/neural_net/mlc_od_backend.hpp>
#include <ml/neural_net/mps_compute_context.hpp>
#include <ml/neural_net/mps_image_augmentation.hpp>
Expand All @@ -34,8 +36,7 @@
// At static-init time, register create_mps_compute_context().
// TODO: Codify priority levels
static auto* mlc_registration = new compute_context::registration(
/* priority */ 0, nullptr, nullptr, &create_mlc_compute_context);

/* priority */ 0, nullptr, &create_mlc_compute_context, &create_mlc_compute_context);
}

mlc_compute_context::mlc_compute_context(MLCDevice* device)
Expand Down Expand Up @@ -99,7 +100,6 @@
const float_array_map& config, const float_array_map& weights)
{
if (@available(macOS 10.16, *)) {
std::cout << "MLC Compute context is here" << std::endl;
return std::make_unique<mlc_object_detector_backend>(device_, n, c_in, h_in, w_in, c_out, h_out,
w_out, config, weights);
}
Expand All @@ -109,13 +109,15 @@
std::unique_ptr<model_backend> mlc_compute_context::create_activity_classifier(
const ac_parameters& ac_params)
{
std::unique_ptr<compute_context> mps_compute_context = create_mps_compute_context();
return mps_compute_context->create_activity_classifier(ac_params);
return nullptr;
}

std::unique_ptr<model_backend> mlc_compute_context::create_drawing_classifier(
const float_array_map& weights, size_t batch_size, size_t num_classes)
{
if (@available(macOS 10.16, *)) {
return std::make_unique<mlc_drawing_classifier_backend>(device_, weights, batch_size, num_classes);
}
return nullptr;
}

Expand Down
41 changes: 41 additions & 0 deletions src/ml/neural_net/mlc_dc_backend.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/* Copyright © 2020 Apple Inc. All rights reserved.
*
* Use of this source code is governed by a BSD-3-clause license that can
* be found in the LICENSE.txt file or at
* https://opensource.org/licenses/BSD-3-Clause
*/

#pragma once

#import <MLCompute/MLCompute.h>

#include <ml/neural_net/model_backend.hpp>
#include <ml/neural_net/mlc_layer_weights.hpp>

namespace turi {
namespace neural_net {

class API_AVAILABLE(macos(10.16)) mlc_drawing_classifier_backend : public model_backend {
public:
mlc_drawing_classifier_backend(MLCDevice *device, const float_array_map &weights,
size_t batch_size, size_t num_classes);

// model_backend interface
float_array_map export_weights() const override;
void set_learning_rate(float lr) override;
float_array_map train(const float_array_map &inputs) override;
float_array_map predict(const turi::neural_net::float_array_map &inputs) const override;

private:
MLCTrainingGraph *training_graph_ = nil;
MLCInferenceGraph *inference_graph_ = nil;
MLCTensor *input_ = nil;
MLCTensor *weights_ = nil;
MLCTensor *labels_ = nil;

mlc_layer_weights layer_weights_;
size_t num_classes_;
};

} // namespace neural_net
} // namespace turi
Loading

0 comments on commit 304cda6

Please sign in to comment.