Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug in upload file endpoint #14924

Merged
merged 13 commits into from Oct 5, 2022
6 changes: 4 additions & 2 deletions src/lightning_app/core/api.py
Expand Up @@ -26,7 +26,8 @@
from lightning_app.core.queues import RedisQueue
from lightning_app.storage import Drive
from lightning_app.utilities.app_helpers import InMemoryStateStore, Logger, StateStore
from lightning_app.utilities.enum import OpenAPITags
from lightning_app.utilities.component import _context
from lightning_app.utilities.enum import ComponentContext, OpenAPITags
from lightning_app.utilities.imports import _is_redis_available, _is_starsessions_available

if _is_starsessions_available():
Expand Down Expand Up @@ -255,7 +256,8 @@ async def upload_file(filename: str, uploaded_file: UploadFile = File(...)):
f.write(content)
done = content == b""

drive.put(filename)
with _context(ComponentContext.WORK):
drive.put(filename)
return f"Successfully uploaded '{filename}' to the Drive"


Expand Down
11 changes: 10 additions & 1 deletion tests/tests_app/core/test_lightning_api.py
Expand Up @@ -430,7 +430,7 @@ def target():


def test_configure_api():

# Setup
process = Process(target=target)
process.start()
time_left = 15
Expand All @@ -442,6 +442,13 @@ def test_configure_api():
sleep(0.1)
time_left -= 0.1

# Test Upload File
files = {"uploaded_file": open(__file__, "rb")}

response = requests.put(f"http://localhost:{APP_SERVER_PORT}/api/v1/upload_file/test", files=files)
assert response.json() == "Successfully uploaded 'test' to the Drive"

# Test Custom Request
response = requests.post(
f"http://localhost:{APP_SERVER_PORT}/api/v1/request", data=InputRequestModel(name="hello").json()
)
Expand All @@ -450,6 +457,8 @@ def test_configure_api():
f"http://localhost:{APP_SERVER_PORT}/api/v1/request", data=InputRequestModel(name="hello").json()
)
assert response.json() == {"name": "hello", "counter": 2}

# Teardown
time_left = 15
while time_left > 0:
if process.exitcode == 0:
Expand Down