Skip to content
Permalink
Browse files
Reformat code using google yapf standard
  • Loading branch information
Nasser Kaze committed Aug 6, 2021
1 parent 16dfa3d commit 70078f27426e08e6ae794eb26defc366ebc4be1e
Showing 12 changed files with 265 additions and 265 deletions.
@@ -25,7 +25,8 @@
class Dataset(models.Model):
name = models.CharField(max_length=128, unique=True)
region = models.CharField(max_length=128, unique=True)



class Algorithm(models.Model):
'''
The MLAlgorithm represent the ML algorithm object.
@@ -44,8 +45,10 @@ class Algorithm(models.Model):
description: str = models.TextField(blank=True, null=True)
version: str = models.CharField(max_length=128)
status: str = models.CharField(max_length=128)
dataset: Dataset = models.ForeignKey(Dataset, to_field="name",
null=True, on_delete=models.SET_NULL)
dataset: Dataset = models.ForeignKey(Dataset,
to_field="name",
null=True,
on_delete=models.SET_NULL)
created_at: date = models.DateTimeField(auto_now_add=True, blank=True)
created_by: str = models.CharField(max_length=128)

@@ -59,10 +62,11 @@ def __str__(self):
Created By: {self.created_by}
Created At: {self.created_at}
"""

class Meta:
ordering = ['created_at']


class PredictionRequest(models.Model):
'''
The MLRequest will keep information about all requests to ML algorithms.
@@ -82,7 +86,9 @@ class PredictionRequest(models.Model):
prediction = models.CharField(max_length=128)
feedback = models.CharField(max_length=128, blank=True, null=True)
notes = models.TextField(blank=True, null=True)
algorithm: Algorithm = models.ForeignKey(Algorithm, on_delete=models.DO_NOTHING, blank=True)
algorithm: Algorithm = models.ForeignKey(Algorithm,
on_delete=models.DO_NOTHING,
blank=True)
created_at: date = models.DateTimeField(auto_now_add=True, blank=True)
created_by = models.CharField(max_length=128)

@@ -96,6 +102,6 @@ def __str__(self):
Created By: {self.created_by}
Created At: {self.created_at}
"""

class Meta:
ordering = ['created_at']
@@ -27,12 +27,13 @@ class Meta:
model = Dataset
fields = '__all__'

class AlgorithmSerializer(serializers.ModelSerializer):

class AlgorithmSerializer(serializers.ModelSerializer):
class Meta:
model = Algorithm
fields = '__all__'



class PredictionRequestSerializer(serializers.ModelSerializer):
class Meta:
model = PredictionRequest
@@ -18,7 +18,6 @@
from django.test import TestCase
from rest_framework.test import APIClient


test_data = {
"age": 22,
"sex": "female",
@@ -31,11 +30,11 @@

expected_output = 'bad'

class AlgorithmTests(TestCase):

class AlgorithmTests(TestCase):
def test_predict_view(self):
client = APIClient()

classifier_url = "/api/v1/algorithms/predict?classifier=RandomForestClassifier&version=0.0.1"
response = client.post(classifier_url, test_data, format='json')
self.assertEqual(response.status_code, 200)
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""
Definition of urls for api resource.
"""
@@ -26,22 +25,27 @@

from api.views import AlgorithmViewSet, DatasetViewSet, PredictionRequestViewSet


router = routers.DefaultRouter(trailing_slash=False)
router.register(r"algorithms", AlgorithmViewSet, basename="algorithms")
router.register(r"datasets", DatasetViewSet, basename="datasets")
router.register(r"requests", PredictionRequestViewSet, basename="prediction_requests")
router.register(r"requests",
PredictionRequestViewSet,
basename="prediction_requests")

urlpatterns = [
# API docs
path('api-docs/', SpectacularAPIView.as_view(), name='api-docs'),

# Optional UI:
path('api-docs/swagger-ui', SpectacularSwaggerView.as_view(url_name='api-docs'), name='swagger-ui'),
path('api-docs/redoc', SpectacularRedocView.as_view(url_name='api-docs'), name='redoc'),
path('api-docs/swagger-ui',
SpectacularSwaggerView.as_view(url_name='api-docs'),
name='swagger-ui'),
path('api-docs/redoc',
SpectacularRedocView.as_view(url_name='api-docs'),
name='redoc'),

# API Views
path('api-auth/', include('rest_framework.urls', namespace='rest_framework')),

path('api-auth/', include('rest_framework.urls',
namespace='rest_framework')),
path('api/v1/', include(router.urls)),
]

@@ -28,9 +28,6 @@
from rest_framework.fields import CharField, FloatField, IntegerField
from rest_framework.response import Response

# from rest_framework import permissions
# from rest_framework_api_key.permissions import HasAPIKey

from api.models import Algorithm, Dataset, PredictionRequest
from api.serializers import AlgorithmSerializer, PredictionRequestSerializer, DatasetSerializer

@@ -51,37 +48,45 @@ class AlgorithmViewSet(viewsets.ModelViewSet):
@extend_schema(
description='Predict credit risk for a loan',
parameters=[
OpenApiParameter(name='classifier',
description='The algorithm/classifier to use',
required=True,
examples=[OpenApiExample('Example 1',
value=RandomForestClassifier().__class__.__name__)]),
OpenApiParameter(name='dataset',
description='The name of the dataset',
examples=[OpenApiExample('Example 1', value='german')]),
OpenApiParameter(name='status',
description='The status of the algorithm',
deprecated=True,
examples=[OpenApiExample('Example 1', value='production')]),
OpenApiParameter(name='version',
description='Algorithm version',
required=True,
default='0.0.1',
examples=[OpenApiExample('Example 1', value='0.0.1')]),
OpenApiParameter(
name='classifier',
description='The algorithm/classifier to use',
required=True,
examples=[
OpenApiExample(
'Example 1',
value=RandomForestClassifier().__class__.__name__)
]),
OpenApiParameter(
name='dataset',
description='The name of the dataset',
examples=[OpenApiExample('Example 1', value='german')]),
OpenApiParameter(
name='status',
description='The status of the algorithm',
deprecated=True,
examples=[OpenApiExample('Example 1', value='production')]),
OpenApiParameter(
name='version',
description='Algorithm version',
required=True,
default='0.0.1',
examples=[OpenApiExample('Example 1', value='0.0.1')]),
],
operation_id='algorithms_predict',
request=Dict[str, Any],
responses=inline_serializer(name="PredictionResponse",
fields={"probability": FloatField(),
"label": CharField(),
"method": CharField(),
"color": CharField(),
"wilkis_lambda": FloatField(),
"pillais_trace": FloatField(),
"hotelling_tawley": FloatField(),
"roys_reatest_roots": FloatField(),
"request_id": IntegerField()})
)
fields={
"probability": FloatField(),
"label": CharField(),
"method": CharField(),
"color": CharField(),
"wilkis_lambda": FloatField(),
"pillais_trace": FloatField(),
"hotelling_tawley": FloatField(),
"roys_reatest_roots": FloatField(),
"request_id": IntegerField()
}))
@action(detail=False, methods=['post'])
def predict(self, request, format=None):

@@ -90,30 +95,39 @@ def predict(self, request, format=None):
region = self.request.query_params.get("dataset", "german")
version = self.request.query_params.get("version", "0.0.1")
status = self.request.query_params.get("status", "production")

print(request)

if version is None:
raise bad_request(request=request,
data={"error": "Missing required query parameter: version"})
raise bad_request(
request=request,
data={
"error": "Missing required query parameter: version"
})
if classifier is None:
raise bad_request(request=request,
data={"error": "Missing required query parameter: classifier"})


if classifier in ['manova', 'linearRegression', 'polynomialRegression']:
raise bad_request(
request=request,
data={
"error": "Missing required query parameter: classifier"
})

if classifier in [
'manova', 'linearRegression', 'polynomialRegression'
]:
prediction = stat_score(request.data, classifier)
algorithm = None

else:
algorithm: Algorithm = Algorithm.objects.filter(classifier=classifier,
status=status,
version=version,
dataset__name=region)[0]
algorithm: Algorithm = Algorithm.objects.filter(
classifier=classifier,
status=status,
version=version,
dataset__name=region)[0]

if algorithm is None:
raise bad_request(request=request,
data={"error": "ML algorithm is not available"})
raise bad_request(
request=request,
data={"error": "ML algorithm is not available"})
classifier = registry.classifiers[algorithm.id]
prediction = classifier.compute_prediction(request.data)

@@ -122,7 +136,8 @@ def predict(self, request, format=None):
else:
label = prediction['method']

prediction_request = PredictionRequest(input=json.dumps(request.data),
prediction_request = PredictionRequest(input=json.dumps(
request.data),
response=prediction,
prediction=label,
feedback="",

0 comments on commit 70078f2

Please sign in to comment.