Skip to content

Commit

Permalink
Add trade models (#9)
Browse files Browse the repository at this point in the history
add basic stock trade model and route
add tests and fix endpoints
fix unrelated linting issues
  • Loading branch information
DemetreJou committed Nov 24, 2021
1 parent 9a69d89 commit 96d680e
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 30 deletions.
34 changes: 34 additions & 0 deletions backend/tradingbot/migrations/0001_initial.py
@@ -0,0 +1,34 @@
# Generated by Django 3.2.8 on 2021-11-19 21:49

from django.db import migrations, models
import django.db.models.deletion


class Migration(migrations.Migration):

initial = True

dependencies = [
]

operations = [
migrations.CreateModel(
name='Company',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('name', models.TextField()),
('ticker', models.CharField(max_length=5, unique=True)),
],
),
migrations.CreateModel(
name='StockTrade',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('price', models.FloatField()),
('amount', models.IntegerField()),
('bought_timestamp', models.DateTimeField(auto_now_add=True)),
('sold_timestamp', models.DateTimeField(null=True)),
('company', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='tradingbot.company')),
],
),
]
29 changes: 27 additions & 2 deletions backend/tradingbot/models.py
@@ -1,3 +1,28 @@
# from django.db import models
from django.db import models
from rest_framework import serializers

# Create your models here.

class Company(models.Model):
name = models.TextField()
ticker = models.CharField(max_length=5, unique=True)

def __str__(self):
return f"{self.name}:{self.ticker}"


class StockTrade(models.Model):
# TODO: this is an overly simplistic model.
# need to add things like bought_price, sold_price, etc.
# or add transaction type (buy, sell, etc.) which is probably preferable
# should probably change to represent a single exchange instance instead of trying to show an entire buy/sell operation
company = models.ForeignKey(Company, on_delete=models.CASCADE)
price = models.FloatField()
amount = models.IntegerField()
bought_timestamp = models.DateTimeField(auto_now_add=True)
sold_timestamp = models.DateTimeField(null=True)


class StockTradeSerializer(serializers.ModelSerializer):
class Meta:
model = StockTrade
fields = ('company_id', 'price', 'amount', 'bought_timestamp', 'sold_timestamp')
49 changes: 48 additions & 1 deletion backend/tradingbot/tests.py
@@ -1,11 +1,58 @@
from django.urls import reverse
from rest_framework.test import APITestCase
from rest_framework import status
from rest_framework.test import APITestCase

from .models import Company, StockTrade


class TradingbotTests(APITestCase):

def setUp(self) -> None:
company = Company(
name="Apple",
ticker="AAPL",
)
company.save()

StockTrade.objects.create(
company=company,
price=100,
amount=100,
)

def test_chatbot_welcome(self):
url = reverse("tradingbot_welcome")
response = self.client.get(url)
self.assertEqual(response.status_code, status.HTTP_200_OK)

def test_invalid_trade_argument(self):
url = reverse("stock_trade")
response = self.client.post(url, {"trade": "option"})
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

def test_buy_trade(self):
url = reverse("stock_trade")
response = self.client.post(
url,
{"transaction_type": "buy", "ticker": "AAPL", "amount": "1", "price": "100"}
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(StockTrade.objects.all().count(), 2)

def test_invalid_buy(self):
url = reverse("stock_trade")
response = self.client.post(
url,
{"transaction_type": "invalid_options", "ticker": "AAPL", "amount": "1", "price": "100"}
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertNotIsInstance(response.json()["data"], type(None))

def test_trade_get(self):
url = reverse("stock_trade")
random_trade = StockTrade.objects.all()[0]
response = self.client.get(url, {"id": random_trade.id})
json_response = response.json()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(json_response["price"], random_trade.price)
self.assertEqual(json_response["company_id"], random_trade.company.id)
2 changes: 2 additions & 0 deletions backend/tradingbot/urls.py
Expand Up @@ -2,6 +2,8 @@

from . import views

# TODO: add route for creating, patching company
urlpatterns = [
path('', views.index, name='tradingbot_welcome'),
path('stock_trade', views.StockTradeView.as_view(), name='stock_trade'),
]
39 changes: 38 additions & 1 deletion backend/tradingbot/views.py
@@ -1,5 +1,42 @@
from django.http import HttpResponse
from django.http import HttpResponse, JsonResponse
from django.views import View
from rest_framework import status

from .models import StockTrade, StockTradeSerializer, Company


def index(request):
return HttpResponse("Hello World, welcome to tradingbot!")


class StockTradeView(View):
# TODO: Add alpaca integration
model = StockTrade

def get(self, request):
id = request.GET.get("id")
stock_trade = self.model.objects.all().filter(id=id).first()
return JsonResponse(StockTradeSerializer(stock_trade).data, safe=False)

def post(self, request):
transaction_type = request.POST.get("transaction_type")
if transaction_type == "sell":
return HttpResponse(status=status.HTTP_501_NOT_IMPLEMENTED)

if transaction_type == "buy":
ticker = request.POST.get("ticker")
price = float(request.POST.get("price"))
amount = int(request.POST.get("amount"))
company = Company.objects.filter(ticker=ticker).first()
if not company:
return JsonResponse(
{"data": f"ticker: {ticker} not valid"},
status=status.HTTP_400_BAD_REQUEST,
)
self.model.objects.create(company=company, price=price, amount=amount)
return HttpResponse(status=status.HTTP_201_CREATED)

return JsonResponse(
{"data": "the only supported transactions are 'buy' or 'sell'"},
status=status.HTTP_400_BAD_REQUEST
)
20 changes: 0 additions & 20 deletions backend/urls.py
@@ -1,26 +1,6 @@
"""backend URL Configuration
The `urlpatterns` list routes URLs to views. For more information please see:
https://docs.djangoproject.com/en/3.2/topics/http/urls/
Examples:
Function views
1. Add an import: from my_app import views
2. Add a URL to urlpatterns: path('', views.home, name='home')
Class-based views
1. Add an import: from other_app.views import Home
2. Add a URL to urlpatterns: path('', Home.as_view(), name='home')
Including another URLconf
1. Import the include() function: from django.urls import include, path
2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
"""
from django.contrib import admin
from django.urls import path, include

# from django.urls import re_path
# from django.shortcuts import render
# import os
# from .settings import BASE_DIR

BASE_API_URL = "api"


Expand Down
10 changes: 4 additions & 6 deletions ml/data_collection/get_news.py
Expand Up @@ -11,13 +11,15 @@ def get_raw_data(ticker, limit, date_from=None, date_to=None):
if date_from is None:
url = "https://eodhistoricaldata.com/api/news?api_token={}&s={}&offset=0&limit={}".format(API_TOKEN, ticker, limit)
else:
url = "https://eodhistoricaldata.com/api/news?api_token={}&s={}&from={}&to={}&offset=0&limit={}".format(API_TOKEN, ticker, date_from, date_to, limit)
url = "https://eodhistoricaldata.com/api/news?api_token={}&s={}&from={}&to={}&offset=0&limit={}".\
format(API_TOKEN, ticker, date_from, date_to, limit)
response = requests.get(url)
content = response.content
parsed = json.loads(content)

return parsed


def to_dataframe(parsed):
d = {"date": [], "title": [], "content": [], "symbols": [], "tags": []}
for i in range(len(parsed)):
Expand All @@ -30,6 +32,7 @@ def to_dataframe(parsed):

return df


def main():
parsed = get_raw_data("AAPL.US", 10)
df = to_dataframe(parsed)
Expand All @@ -38,8 +41,3 @@ def main():

if __name__ == "__main__":
main()





0 comments on commit 96d680e

Please sign in to comment.