From 92de224eaa73a2f7fa48910fa7b8224769e7fa95 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 3 Aug 2022 12:59:42 +0200 Subject: [PATCH 1/5] update --- lightning_hpo/app/sweeper.py | 3 + lightning_hpo/commands/sweep.py | 109 ++++++++++++++++++++++++++++---- 2 files changed, 98 insertions(+), 14 deletions(-) diff --git a/lightning_hpo/app/sweeper.py b/lightning_hpo/app/sweeper.py index b43d6526..00959571 100644 --- a/lightning_hpo/app/sweeper.py +++ b/lightning_hpo/app/sweeper.py @@ -36,12 +36,15 @@ def run(self): trials.extend(sweep.get_trials()) if trials: for trial in trials: + breakpoint() resp = requests.post(self.db.url + "/trial", data=trial.json()) assert resp.status_code == 200 def create_sweep(self, config: SweepConfig) -> str: sweep_ids = list(self.sweeps.keys()) if config.sweep_id not in sweep_ids: + breakpoint() + resp = requests.post(self.db.url + "/sweep", data=config.json()) self.sweeps[config.sweep_id] = Sweep.from_config( config, code={"drive": self.drive, "name": config.sweep_id}, diff --git a/lightning_hpo/commands/sweep.py b/lightning_hpo/commands/sweep.py index 313d3954..484a9ac4 100644 --- a/lightning_hpo/commands/sweep.py +++ b/lightning_hpo/commands/sweep.py @@ -4,16 +4,97 @@ from argparse import ArgumentParser from getpass import getuser from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Generic, TypeVar from uuid import uuid4 import requests from lightning.app.source_code import LocalSourceCodeDir from lightning.app.source_code.uploader import FileUploader from lightning.app.utilities.commands import ClientCommand -from pydantic import BaseModel from sqlalchemy import Column -from sqlmodel import Field, JSON, SQLModel +from pydantic import BaseModel +from sqlmodel import Field, JSON, SQLModel, TypeDecorator +import json +from fastapi.encoders import jsonable_encoder + +T = TypeVar("T") + +# Taken from https://github.com/tiangolo/sqlmodel/issues/63#issuecomment-1081555082 +def pydantic_column_type(pydantic_type): + class PydanticJSONType(TypeDecorator, Generic[T]): + impl = JSON() + + def __init__( + self, + json_encoder=json, + ): + self.json_encoder = json_encoder + super(PydanticJSONType, self).__init__() + + def bind_processor(self, dialect): + impl_processor = self.impl.bind_processor(dialect) + dumps = self.json_encoder.dumps + if impl_processor: + + def process(value: T): + if value is not None: + if isinstance(pydantic_type, BaseModel): + # This allows to assign non-InDB models and if they're + # compatible, they're directly parsed into the InDB + # representation, thus hiding the implementation in the + # background. However, the InDB model will still be returned + value_to_dump = pydantic_type.from_orm(value) + else: + value_to_dump = value + value = jsonable_encoder(value_to_dump) + return impl_processor(value) + + else: + + def process(value): + if isinstance(pydantic_type, BaseModel): + # This allows to assign non-InDB models and if they're + # compatible, they're directly parsed into the InDB + # representation, thus hiding the implementation in the + # background. However, the InDB model will still be returned + value_to_dump = pydantic_type.from_orm(value) + else: + value_to_dump = value + value = dumps(jsonable_encoder(value_to_dump)) + return value + + return process + + def result_processor(self, dialect, coltype) -> T: + impl_processor = self.impl.result_processor(dialect, coltype) + if impl_processor: + + def process(value): + value = impl_processor(value) + if value is None: + return None + + data = value + # Explicitly use the generic directly, not type(T) + full_obj = parse_obj_as(pydantic_type, data) + return full_obj + + else: + + def process(value): + if value is None: + return None + + # Explicitly use the generic directly, not type(T) + full_obj = parse_obj_as(pydantic_type, value) + return full_obj + + return process + + def compare_values(self, x, y): + return x == y + + return PydanticJSONType class Params(SQLModel, table=False): @@ -21,11 +102,12 @@ class Params(SQLModel, table=False): class Distributions(SQLModel, table=False): - distribution: Dict[str, Params] = Field(sa_column=Column(JSON)) + name: str + distribution: str + params: Params = Field(sa_column=Column(pydantic_column_type(Params))) -# class SweepConfig(SQLModel, table=True): -class SweepConfig(BaseModel): +class SweepConfig(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) sweep_id: str script_path: str @@ -33,8 +115,8 @@ class SweepConfig(BaseModel): simultaneous_trials: int requirements: List[str] script_args: List[str] - # distributions: Dict[str, Distributions] = Field(sa_column=Column(JSON)) - distributions: Dict[str, Any] + # TODO: How to support a List[Optional[Distributions]] here ? + distributions: Optional[Distributions] = Field(sa_column=Column(pydantic_column_type(Distributions))) framework: str cloud_compute: str num_nodes: int = 1 @@ -173,10 +255,10 @@ def run(self) -> None: repo.package() repo.upload(url=f"{url}/uploadfile/{sweep_id}") - # distributions = { - # k: Distributions(distribution={x['distribution']: Params(params=x['params'])}) - # for k, x in distributions.items() - # } + distributions = [ + Distributions(name=k, distribution=x['distribution'], params=Params(params=x['params'])) + for k, x in distributions.items() + ] config = SweepConfig( sweep_id=sweep_id, @@ -185,13 +267,12 @@ def run(self) -> None: simultaneous_trials=hparams.simultaneous_trials, requirements=hparams.requirements, script_args=script_args, - distributions=distributions, + distributions=distributions[0], framework=hparams.framework, cloud_compute=hparams.cloud_compute, num_nodes=hparams.num_nodes, logger=hparams.logger, direction=hparams.direction, ) - breakpoint() response = self.invoke_handler(config=config) print(response) From daf32599a50bdf016bdeddc925046decd9bc043d Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 3 Aug 2022 14:38:41 +0200 Subject: [PATCH 2/5] update --- lightning_hpo/commands/sweep.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/lightning_hpo/commands/sweep.py b/lightning_hpo/commands/sweep.py index 484a9ac4..60dd6038 100644 --- a/lightning_hpo/commands/sweep.py +++ b/lightning_hpo/commands/sweep.py @@ -1,24 +1,26 @@ +import json import os import re import sys from argparse import ArgumentParser from getpass import getuser from pathlib import Path -from typing import Any, Dict, List, Optional, Generic, TypeVar +from typing import Dict, Generic, List, Optional, TypeVar from uuid import uuid4 import requests +from fastapi.encoders import jsonable_encoder from lightning.app.source_code import LocalSourceCodeDir from lightning.app.source_code.uploader import FileUploader from lightning.app.utilities.commands import ClientCommand +from pydantic import parse_obj_as +from pydantic.main import ModelMetaclass from sqlalchemy import Column -from pydantic import BaseModel from sqlmodel import Field, JSON, SQLModel, TypeDecorator -import json -from fastapi.encoders import jsonable_encoder T = TypeVar("T") + # Taken from https://github.com/tiangolo/sqlmodel/issues/63#issuecomment-1081555082 def pydantic_column_type(pydantic_type): class PydanticJSONType(TypeDecorator, Generic[T]): @@ -38,7 +40,7 @@ def bind_processor(self, dialect): def process(value: T): if value is not None: - if isinstance(pydantic_type, BaseModel): + if isinstance(pydantic_type, ModelMetaclass): # This allows to assign non-InDB models and if they're # compatible, they're directly parsed into the InDB # representation, thus hiding the implementation in the @@ -52,7 +54,7 @@ def process(value: T): else: def process(value): - if isinstance(pydantic_type, BaseModel): + if isinstance(pydantic_type, ModelMetaclass): # This allows to assign non-InDB models and if they're # compatible, they're directly parsed into the InDB # representation, thus hiding the implementation in the @@ -256,7 +258,7 @@ def run(self) -> None: repo.upload(url=f"{url}/uploadfile/{sweep_id}") distributions = [ - Distributions(name=k, distribution=x['distribution'], params=Params(params=x['params'])) + Distributions(name=k, distribution=x["distribution"], params=Params(params=x["params"])) for k, x in distributions.items() ] From b74817b2e5fbf2d15ed643dc3a7c5543991ddc77 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 3 Aug 2022 16:47:58 +0200 Subject: [PATCH 3/5] update --- examples/scripts/test.db | 0 lightning_hpo/app/sweeper.py | 14 ++++++------- lightning_hpo/commands/sweep.py | 21 ++++++++++--------- lightning_hpo/components/servers/db/server.py | 13 ++++++++++++ lightning_hpo/components/sweep.py | 3 ++- lightning_hpo/distributions/distributions.py | 9 ++++++++ requirements.txt | 6 +----- 7 files changed, 43 insertions(+), 23 deletions(-) create mode 100644 examples/scripts/test.db diff --git a/examples/scripts/test.db b/examples/scripts/test.db new file mode 100644 index 00000000..e69de29b diff --git a/lightning_hpo/app/sweeper.py b/lightning_hpo/app/sweeper.py index 00959571..fb5dff7b 100644 --- a/lightning_hpo/app/sweeper.py +++ b/lightning_hpo/app/sweeper.py @@ -30,21 +30,21 @@ def run(self): self.db.run() self.db_viz.run() if self.file_server.alive() and self.db.alive(): - trials = [] for sweep in self.sweeps.values(): sweep.run() - trials.extend(sweep.get_trials()) - if trials: - for trial in trials: - breakpoint() - resp = requests.post(self.db.url + "/trial", data=trial.json()) + trials = sweep.get_trials() + if trials: + for trial in trials: + resp = requests.post(self.db.url + "/trial", data=trial.json()) + assert resp.status_code == 200 + resp = requests.put(self.db.url + f"/sweep?sweep_id={sweep.sweep_id}&num_trials={sweep.num_trials}") assert resp.status_code == 200 def create_sweep(self, config: SweepConfig) -> str: sweep_ids = list(self.sweeps.keys()) if config.sweep_id not in sweep_ids: - breakpoint() resp = requests.post(self.db.url + "/sweep", data=config.json()) + assert resp.status_code == 200 self.sweeps[config.sweep_id] = Sweep.from_config( config, code={"drive": self.drive, "name": config.sweep_id}, diff --git a/lightning_hpo/commands/sweep.py b/lightning_hpo/commands/sweep.py index 60dd6038..fc3aa4ee 100644 --- a/lightning_hpo/commands/sweep.py +++ b/lightning_hpo/commands/sweep.py @@ -100,11 +100,10 @@ def compare_values(self, x, y): class Params(SQLModel, table=False): - params: Dict[str, str] + params: Dict[str, str] = Field(sa_column=Column(pydantic_column_type(Dict[str, str]))) class Distributions(SQLModel, table=False): - name: str distribution: str params: Params = Field(sa_column=Column(pydantic_column_type(Params))) @@ -115,10 +114,12 @@ class SweepConfig(SQLModel, table=True): script_path: str n_trials: int simultaneous_trials: int - requirements: List[str] - script_args: List[str] - # TODO: How to support a List[Optional[Distributions]] here ? - distributions: Optional[Distributions] = Field(sa_column=Column(pydantic_column_type(Distributions))) + num_trials: int = 0 + requirements: List[str] = Field(..., sa_column=Column(pydantic_column_type(List[str]))) + script_args: List[str] = Field(..., sa_column=Column(pydantic_column_type(List[str]))) + distributions: Dict[str, Distributions] = Field( + ..., sa_column=Column(pydantic_column_type(Dict[str, Distributions])) + ) framework: str cloud_compute: str num_nodes: int = 1 @@ -257,10 +258,10 @@ def run(self) -> None: repo.package() repo.upload(url=f"{url}/uploadfile/{sweep_id}") - distributions = [ - Distributions(name=k, distribution=x["distribution"], params=Params(params=x["params"])) + distributions = { + k: Distributions(distribution=x["distribution"], params=Params(params=x["params"])) for k, x in distributions.items() - ] + } config = SweepConfig( sweep_id=sweep_id, @@ -269,7 +270,7 @@ def run(self) -> None: simultaneous_trials=hparams.simultaneous_trials, requirements=hparams.requirements, script_args=script_args, - distributions=distributions[0], + distributions=distributions, framework=hparams.framework, cloud_compute=hparams.cloud_compute, num_nodes=hparams.num_nodes, diff --git a/lightning_hpo/components/servers/db/server.py b/lightning_hpo/components/servers/db/server.py index e848c3c8..20fbe82f 100644 --- a/lightning_hpo/components/servers/db/server.py +++ b/lightning_hpo/components/servers/db/server.py @@ -43,6 +43,19 @@ async def insert_sweep(sweep: SweepConfig): session.refresh(sweep) return sweep + @app.put("/sweep/") + async def update_sweep(sweep_id: str, num_trials: int): + with Session(engine) as session: + statement = select(SweepConfig).where(SweepConfig.sweep_id == sweep_id) + results = session.exec(statement) + sweeps = results.all() + assert len(sweeps) == 1 + sweep = sweeps[0] + sweep.num_trials = int(num_trials) + session.add(sweep) + session.commit() + session.refresh(sweep) + @app.get("/trials/") async def collect_trials() -> List[Trial]: with Session(engine) as session: diff --git a/lightning_hpo/components/sweep.py b/lightning_hpo/components/sweep.py index 867f0063..1dc79fbb 100644 --- a/lightning_hpo/components/sweep.py +++ b/lightning_hpo/components/sweep.py @@ -11,6 +11,7 @@ from lightning_hpo.commands.sweep import SweepConfig from lightning_hpo.components.servers.db.models import Trial from lightning_hpo.distributions import Distribution +from lightning_hpo.distributions.distributions import parse_distributions from lightning_hpo.framework.agnostic import BaseObjective from lightning_hpo.loggers import LoggerType from lightning_hpo.utilities.utils import ( @@ -172,7 +173,7 @@ def from_config(cls, config: SweepConfig, code: Optional[Code] = None): simultaneous_trials=config.simultaneous_trials, framework=config.framework, script_args=config.script_args, - distributions=config.distributions, + distributions=parse_distributions(config.distributions), cloud_compute=CloudCompute(config.cloud_compute, config.num_nodes), sweep_id=config.sweep_id, code=code, diff --git a/lightning_hpo/distributions/distributions.py b/lightning_hpo/distributions/distributions.py index 4680a5bc..ad299b27 100644 --- a/lightning_hpo/distributions/distributions.py +++ b/lightning_hpo/distributions/distributions.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, TypedDict +from lightning_hpo.commands.sweep import Distributions + class DistributionDict(TypedDict): distribution: str @@ -47,3 +49,10 @@ def __init__(self, low: int, high: int, step: Optional[int] = 1) -> None: def to_dict(self) -> DistributionDict: return {"distribution": "int_uniform", "params": {"low": self.low, "high": self.high, "step": self.step}} + + +_DISTRIBUTION = {"uniform": Uniform, "int_uniform": IntUniform, "log_uniform": LogUniform, "categorical": Categorical} + + +def parse_distributions(distributions: Dict[str, Distributions]) -> Dict[str, Distribution]: + return {k: _DISTRIBUTION[v.distribution](**v.params.params) for k, v in distributions.items()} diff --git a/requirements.txt b/requirements.txt index 32bd1766..1a6781bf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1 @@ -lightning -sqlmodel -streamlit -wandb -optuna +pytorch-lightning From 6761c3042ef32c8badedd68710b839accc88944e Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 3 Aug 2022 16:48:20 +0200 Subject: [PATCH 4/5] update --- examples/scripts/test.db | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 examples/scripts/test.db diff --git a/examples/scripts/test.db b/examples/scripts/test.db deleted file mode 100644 index e69de29b..00000000 From 6af0582a545de29ba8c0ea96439d54c4a72f15cf Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 3 Aug 2022 16:55:18 +0200 Subject: [PATCH 5/5] update --- requirements.txt | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 1a6781bf..32bd1766 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,5 @@ -pytorch-lightning +lightning +sqlmodel +streamlit +wandb +optuna