Skip to content
This repository has been archived by the owner on Apr 12, 2023. It is now read-only.

Add support for Sweep Table #19

Merged
merged 5 commits into from
Aug 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 8 additions & 5 deletions lightning_hpo/app/sweeper.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,21 @@ def run(self):
self.db.run()
self.db_viz.run()
if self.file_server.alive() and self.db.alive():
trials = []
for sweep in self.sweeps.values():
sweep.run()
trials.extend(sweep.get_trials())
if trials:
for trial in trials:
resp = requests.post(self.db.url + "/trial", data=trial.json())
trials = sweep.get_trials()
if trials:
for trial in trials:
resp = requests.post(self.db.url + "/trial", data=trial.json())
assert resp.status_code == 200
resp = requests.put(self.db.url + f"/sweep?sweep_id={sweep.sweep_id}&num_trials={sweep.num_trials}")
assert resp.status_code == 200

def create_sweep(self, config: SweepConfig) -> str:
sweep_ids = list(self.sweeps.keys())
if config.sweep_id not in sweep_ids:
resp = requests.post(self.db.url + "/sweep", data=config.json())
assert resp.status_code == 200
self.sweeps[config.sweep_id] = Sweep.from_config(
config,
code={"drive": self.drive, "name": config.sweep_id},
Expand Down
116 changes: 100 additions & 16 deletions lightning_hpo/commands/sweep.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,125 @@
import json
import os
import re
import sys
from argparse import ArgumentParser
from getpass import getuser
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Dict, Generic, List, Optional, TypeVar
from uuid import uuid4

import requests
from fastapi.encoders import jsonable_encoder
from lightning.app.source_code import LocalSourceCodeDir
from lightning.app.source_code.uploader import FileUploader
from lightning.app.utilities.commands import ClientCommand
from pydantic import BaseModel
from pydantic import parse_obj_as
from pydantic.main import ModelMetaclass
from sqlalchemy import Column
from sqlmodel import Field, JSON, SQLModel
from sqlmodel import Field, JSON, SQLModel, TypeDecorator

T = TypeVar("T")


# Taken from https://github.com/tiangolo/sqlmodel/issues/63#issuecomment-1081555082
def pydantic_column_type(pydantic_type):
    """Create a JSON-backed column type that round-trips ``pydantic_type`` values.

    The returned class is a ``TypeDecorator`` whose implementation type is
    ``JSON``. On the way into the database, values are converted with
    ``jsonable_encoder``; on the way out, they are rebuilt with
    ``parse_obj_as(pydantic_type, ...)``.

    Args:
        pydantic_type: The type to (de)serialize, e.g. a pydantic model class
            or a typing construct such as ``Dict[str, str]``.

    Returns:
        The ``PydanticJSONType`` class (not an instance) bound to ``pydantic_type``.
    """
    # Decided once per generated type rather than on every bound value.
    is_model = isinstance(pydantic_type, ModelMetaclass)

    def _to_jsonable(value):
        """Convert ``value`` into JSON-compatible data, leaving ``None`` untouched."""
        if value is None:
            return None
        if is_model:
            # This allows assigning non-InDB models: if they are compatible,
            # they are parsed directly into the InDB representation, hiding
            # the implementation in the background. The InDB model will still
            # be returned when reading.
            value = pydantic_type.from_orm(value)
        return jsonable_encoder(value)

    class PydanticJSONType(TypeDecorator, Generic[T]):
        impl = JSON()

        def __init__(
            self,
            json_encoder=json,
        ):
            self.json_encoder = json_encoder
            super().__init__()

        def bind_processor(self, dialect):
            """Return the callable that converts a Python value for binding."""
            impl_processor = self.impl.bind_processor(dialect)
            dumps = self.json_encoder.dumps
            if impl_processor:

                def process(value: T):
                    # The underlying JSON type performs the final serialization.
                    return impl_processor(_to_jsonable(value))

            else:

                def process(value):
                    # No processor from the underlying type: serialize here.
                    # ``None`` skips the model conversion (which would raise)
                    # and is dumped as-is, matching the guarded branch above.
                    return dumps(_to_jsonable(value))

            return process

        def result_processor(self, dialect, coltype) -> T:
            """Return the callable that rebuilds ``pydantic_type`` from a row value."""
            impl_processor = self.impl.result_processor(dialect, coltype)
            if impl_processor:

                def process(value):
                    value = impl_processor(value)
                    if value is None:
                        return None
                    # Explicitly use the concrete type, not type(T).
                    return parse_obj_as(pydantic_type, value)

            else:

                def process(value):
                    if value is None:
                        return None
                    # Explicitly use the concrete type, not type(T).
                    return parse_obj_as(pydantic_type, value)

            return process

        def compare_values(self, x, y):
            """Compare two values with plain equality."""
            return x == y

    return PydanticJSONType


class Params(SQLModel, table=False):
    """Parameters of a distribution, held as a string-to-string mapping."""

    # Stored through the pydantic-aware JSON column type.
    params: Dict[str, str] = Field(sa_column=Column(pydantic_column_type(Dict[str, str])))


class Distributions(SQLModel, table=False):
    """A distribution name together with its parameters."""

    # Name of the distribution, e.g. "int_uniform".
    distribution: str
    # Parameters of the distribution, stored through the pydantic-aware JSON column type.
    params: Params = Field(sa_column=Column(pydantic_column_type(Params)))


# class SweepConfig(SQLModel, table=True):
class SweepConfig(BaseModel):
class SweepConfig(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
sweep_id: str
script_path: str
n_trials: int
simultaneous_trials: int
requirements: List[str]
script_args: List[str]
# distributions: Dict[str, Distributions] = Field(sa_column=Column(JSON))
distributions: Dict[str, Any]
num_trials: int = 0
requirements: List[str] = Field(..., sa_column=Column(pydantic_column_type(List[str])))
script_args: List[str] = Field(..., sa_column=Column(pydantic_column_type(List[str])))
distributions: Dict[str, Distributions] = Field(
..., sa_column=Column(pydantic_column_type(Dict[str, Distributions]))
)
framework: str
cloud_compute: str
num_nodes: int = 1
Expand Down Expand Up @@ -173,10 +258,10 @@ def run(self) -> None:
repo.package()
repo.upload(url=f"{url}/uploadfile/{sweep_id}")

# distributions = {
# k: Distributions(distribution={x['distribution']: Params(params=x['params'])})
# for k, x in distributions.items()
# }
distributions = {
k: Distributions(distribution=x["distribution"], params=Params(params=x["params"]))
for k, x in distributions.items()
}

config = SweepConfig(
sweep_id=sweep_id,
Expand All @@ -192,6 +277,5 @@ def run(self) -> None:
logger=hparams.logger,
direction=hparams.direction,
)
breakpoint()
response = self.invoke_handler(config=config)
print(response)
13 changes: 13 additions & 0 deletions lightning_hpo/components/servers/db/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,19 @@ async def insert_sweep(sweep: SweepConfig):
session.refresh(sweep)
return sweep

@app.put("/sweep/")
async def update_sweep(sweep_id: str, num_trials: int):
    """Set ``num_trials`` on the stored sweep identified by ``sweep_id``.

    Args:
        sweep_id: Value matched against ``SweepConfig.sweep_id``.
        num_trials: New value written to the sweep's ``num_trials`` field.

    Raises:
        HTTPException: 404 when no sweep matches ``sweep_id``; 500 when more
            than one row matches, since exactly one is expected.
    """
    # Imported locally so the handler is self-contained.
    from fastapi import HTTPException

    with Session(engine) as session:
        statement = select(SweepConfig).where(SweepConfig.sweep_id == sweep_id)
        sweeps = session.exec(statement).all()

        # Explicit errors instead of ``assert``: assertions are stripped under
        # ``python -O`` and would otherwise surface as an opaque server error.
        if not sweeps:
            raise HTTPException(status_code=404, detail=f"No sweep found for sweep_id={sweep_id!r}.")
        if len(sweeps) > 1:
            raise HTTPException(
                status_code=500,
                detail=f"Expected exactly one sweep for sweep_id={sweep_id!r}, found {len(sweeps)}.",
            )

        sweep = sweeps[0]
        sweep.num_trials = int(num_trials)
        session.add(sweep)
        session.commit()
        session.refresh(sweep)

@app.get("/trials/")
async def collect_trials() -> List[Trial]:
with Session(engine) as session:
Expand Down
3 changes: 2 additions & 1 deletion lightning_hpo/components/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from lightning_hpo.commands.sweep import SweepConfig
from lightning_hpo.components.servers.db.models import Trial
from lightning_hpo.distributions import Distribution
from lightning_hpo.distributions.distributions import parse_distributions
from lightning_hpo.framework.agnostic import BaseObjective
from lightning_hpo.loggers import LoggerType
from lightning_hpo.utilities.utils import (
Expand Down Expand Up @@ -172,7 +173,7 @@ def from_config(cls, config: SweepConfig, code: Optional[Code] = None):
simultaneous_trials=config.simultaneous_trials,
framework=config.framework,
script_args=config.script_args,
distributions=config.distributions,
distributions=parse_distributions(config.distributions),
cloud_compute=CloudCompute(config.cloud_compute, config.num_nodes),
sweep_id=config.sweep_id,
code=code,
Expand Down
9 changes: 9 additions & 0 deletions lightning_hpo/distributions/distributions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, TypedDict

from lightning_hpo.commands.sweep import Distributions


class DistributionDict(TypedDict):
distribution: str
Expand Down Expand Up @@ -47,3 +49,10 @@ def __init__(self, low: int, high: int, step: Optional[int] = 1) -> None:

def to_dict(self) -> DistributionDict:
return {"distribution": "int_uniform", "params": {"low": self.low, "high": self.high, "step": self.step}}


# Maps a serialized distribution name to the callable used to build it.
_DISTRIBUTION = {
    "uniform": Uniform,
    "int_uniform": IntUniform,
    "log_uniform": LogUniform,
    "categorical": Categorical,
}


def parse_distributions(distributions: Dict[str, Distributions]) -> Dict[str, Distribution]:
    """Build one distribution object for every entry of ``distributions``.

    Each value's ``distribution`` name is looked up in ``_DISTRIBUTION`` and the
    result is called with the entry's ``params.params`` as keyword arguments.
    A name missing from ``_DISTRIBUTION`` raises ``KeyError``.

    Args:
        distributions: Mapping from a key to its serialized description.

    Returns:
        A mapping with the same keys whose values are the constructed objects.
    """
    parsed = {}
    for key, spec in distributions.items():
        factory = _DISTRIBUTION[spec.distribution]
        parsed[key] = factory(**spec.params.params)
    return parsed