Make the schema.yaml more generic
stuartmcalpine committed Jun 3, 2024
1 parent 1118b04 commit ca87303
Showing 5 changed files with 457 additions and 469 deletions.
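
For orientation: the refactor below reads the loaded `schema.yaml` as one nested dict, indexed as `schema_data[table]["column_definitions"][column]`, with optional `unique_constraints` and `index` blocks per table. A minimal sketch of that implied shape follows; only the key names are taken from the diff, and all column and constraint values are invented for illustration.

# Hypothetical illustration of load_schema()'s return value after this
# commit, reconstructed from the lookups in create_registry_schema.py.
schema_data = {
    "tables": {
        "dataset_alias": {
            "column_definitions": {
                "dataset_id": {               # values below are invented
                    "type": "int",            # key into _TYPE_TRANSLATE
                    "primary_key": False,
                    "nullable": False,
                    "modifiable": False,
                    "foreign_key": True,
                    "foreign_key_schema": "self",
                    "foreign_key_table": "dataset",
                    "foreign_key_column": "dataset_id",
                },
            },
            # Optional per-table blocks; _get_table_metadata checks for these.
            "unique_constraints": {
                "unique_list": ["alias", "dataset_id"],
                "name": "dataset_alias_unique",
            },
            "index": {"index_list": ["alias"]},
        },
    },
}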
263 changes: 118 additions & 145 deletions scripts/create_registry_schema.py
@@ -39,7 +39,8 @@
 }
 
 # Load the schema from the `schema.yaml` file
-schema_columns, schema_unique = load_schema()
+schema_data = load_schema()
+schema_data = schema_data["tables"]
 
 
 def _get_column_definitions(schema, table):
@@ -59,204 +60,146 @@ def _get_column_definitions(schema, table):
     """
 
     return_dict = {}
-    for column in schema_columns[table].keys():
+    for column in schema_data[table]["column_definitions"].keys():
         # Special case where column has a foreign key
-        if schema_columns[table][column]["foreign_key"]:
+        if schema_data[table]["column_definitions"][column]["foreign_key"]:
             fk_schema = schema
-            if schema_columns[table][column]["foreign_key_schema"] != "self":
-                fk_schema = schema_columns[table][column]["foreign_key_schema"]
+            if (
+                schema_data[table]["column_definitions"][column]["foreign_key_schema"]
+                != "self"
+            ):
+                fk_schema = schema_data[table]["column_definitions"][column][
+                    "foreign_key_schema"
+                ]
 
             return_dict[column] = Column(
                 column,
-                _TYPE_TRANSLATE[schema_columns[table][column]["type"]],
+                _TYPE_TRANSLATE[
+                    schema_data[table]["column_definitions"][column]["type"]
+                ],
                 ForeignKey(
                     _get_ForeignKey_str(
                         fk_schema,
-                        schema_columns[table][column]["foreign_key_table"],
-                        schema_columns[table][column]["foreign_key_column"],
+                        schema_data[table]["column_definitions"][column][
+                            "foreign_key_table"
+                        ],
+                        schema_data[table]["column_definitions"][column][
+                            "foreign_key_column"
+                        ],
                     )
                 ),
-                primary_key=schema_columns[table][column]["primary_key"],
-                nullable=schema_columns[table][column]["nullable"],
+                primary_key=schema_data[table]["column_definitions"][column][
+                    "primary_key"
+                ],
+                nullable=schema_data[table]["column_definitions"][column]["nullable"],
             )
 
         # Normal case
         else:
             return_dict[column] = Column(
                 column,
-                _TYPE_TRANSLATE[schema_columns[table][column]["type"]],
-                primary_key=schema_columns[table][column]["primary_key"],
-                nullable=schema_columns[table][column]["nullable"],
+                _TYPE_TRANSLATE[
+                    schema_data[table]["column_definitions"][column]["type"]
+                ],
+                primary_key=schema_data[table]["column_definitions"][column][
+                    "primary_key"
+                ],
+                nullable=schema_data[table]["column_definitions"][column]["nullable"],
             )
 
     return return_dict
 
 
-class Base(DeclarativeBase):
-    pass
-
-
-def _get_ForeignKey_str(schema, table, column):
-    """
-    Get the string reference to the "<shema>.<table>.<column>" a foreign key
-    will point to.
-
-    The schema address will only be included for postgres backends.
-
-    Parameters
-    ---------
-    schema : str
-    table : str
-    column : str
-
-    Returns
-    -------
-    - : str
-    """
-
-    if schema is None:
-        return f"{table}.{column}"
-    else:
-        return f"{schema}.{table}.{column}"
-
-
-def _Provenance(schema):
-    """Keeps track of database/schema versions."""
-
-    class_name = f"{schema}_provenance"
-
-    # Load columns from `schema.yaml` file
-    columns = _get_column_definitions(schema, "provenance")
-
-    # Table metadata
-    meta = {"__tablename__": "provenance", "__table_args__": {"schema": schema}}
-
-    Model = type(class_name, (Base,), {**columns, **meta})
-    return Model
-
-
-def _Execution(schema):
-    """Stores executions, which datasets can be linked to."""
-
-    class_name = f"{schema}_execution"
-
-    # Load columns from `schema.yaml` file
-    columns = _get_column_definitions(schema, "execution")
-
-    # Table metadata
-    meta = {"__tablename__": "execution", "__table_args__": {"schema": schema}}
-
-    Model = type(class_name, (Base,), {**columns, **meta})
-    return Model
-
-
-def _ExecutionAlias(schema):
-    """To asscociate an alias to an execution."""
-
-    class_name = f"{schema}_execution_alias"
-
-    # Load columns from `schema.yaml` file
-    columns = _get_column_definitions(schema, "execution_alias")
-
-    # Table metadata
-    meta = {
-        "__tablename__": "execution_alias",
-    }
-
-    # Add unique constraints
-    if "execution_alias" in schema_unique.keys():
-        meta["__table_args__"] = (
-            UniqueConstraint(
-                *schema_unique["execution_alias"]["unique_list"],
-                name=schema_unique["execution_alias"]["name"],
-            ),
-            {"schema": schema},
-        )
-    else:
-        meta["__table_args__"] = {"schema": schema}
-
-    Model = type(class_name, (Base,), {**columns, **meta})
-    return Model
-
-
-def _DatasetAlias(schema):
-    """To asscociate an alias to a dataset."""
-
-    class_name = f"{schema}_dataset_alias"
-
-    # Load columns from `schema.yaml` file
-    columns = _get_column_definitions(schema, "dataset_alias")
-
-    # Table metadata
-    meta = {
-        "__tablename__": "dataset_alias",
-    }
-
-    # Add unique constraints
-    if "dataset_alias" in schema_unique.keys():
-        meta["__table_args__"] = (
-            UniqueConstraint(
-                *schema_unique["dataset_alias"]["unique_list"],
-                name=schema_unique["dataset_alias"]["name"],
-            ),
-            {"schema": schema},
-        )
-    else:
-        meta["__table_args__"] = {"schema": schema}
-
-    Model = type(class_name, (Base,), {**columns, **meta})
-    return Model
-
-
-def _Dataset(schema):
-    """Primary table, stores dataset information."""
-
-    class_name = f"{schema}_dataset"
-
-    # Load columns from `schema.yaml` file
-    columns = _get_column_definitions(schema, "dataset")
-
-    # Table metadata
-    meta = {
-        "__tablename__": "dataset",
-    }
-
-    # Add unique constraints
-    if "dataset" in schema_unique.keys():
-        meta["__table_args__"] = (
-            UniqueConstraint(
-                *schema_unique["dataset"]["unique_list"],
-                name=schema_unique["dataset"]["name"],
-            ),
-            Index("relative_path", "owner", "owner_type"),
-            {"schema": schema},
-        )
-    else:
-        meta["__table_args__"] = (
-            Index("relative_path", "owner", "owner_type"),
-            {"schema": schema},
-        )
-
-    Model = type(class_name, (Base,), {**columns, **meta})
-    return Model
-
-
-def _Dependency(schema, has_production, production="production"):
-    """
-    Links datasets through "dependencies".
-
-    Parameters
-    ----------
-    schema          str     Name of schema we're writing to
-    has_production  boolean True if this schema refers to production schema
-    production      string  Name of production schema
-    """
-
-    class_name = f"{schema}_dependency"
-
-    # Load columns from `schema.yaml` file
-    columns = _get_column_definitions(schema, "dependency")
-
+def _get_table_metadata(schema, table):
+    """
+    Build the table meta data dict, e.g., the schema name and any unique
+    constraints, for this table.
+
+    Parameters
+    ----------
+    schema : str
+    table : str
+
+    Returns
+    -------
+    meta : dict
+        The table metadata
+    """
+
+    # Table metadata
+    meta = {
+        "__tablename__": table,
+    }
+
+    # Add unique constraints
+    if (
+        "index" in schema_data[table].keys()
+        and "unique_constraints" in schema_data[table].keys()
+    ):
+        meta["__table_args__"] = (
+            UniqueConstraint(
+                *schema_data[table]["unique_constraints"]["unique_list"],
+                name=schema_data[table]["unique_constraints"]["name"],
+            ),
+            Index(*schema_data[table]["index"]["index_list"]),
+            {"schema": schema},
+        )
+    elif "unique_constraints" in schema_data[table].keys():
+        meta["__table_args__"] = (
+            UniqueConstraint(
+                *schema_data[table]["unique_constraints"]["unique_list"],
+                name=schema_data[table]["unique_constraints"]["name"],
+            ),
+            {"schema": schema},
+        )
+    else:
+        meta["__table_args__"] = {"schema": schema}
+
+    return meta
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+def _get_ForeignKey_str(schema, table, column):
+    """
+    Get the string reference to the "<shema>.<table>.<column>" a foreign key
+    will point to.
+
+    The schema address will only be included for postgres backends.
+
+    Parameters
+    ---------
+    schema : str
+    table : str
+    column : str
+
+    Returns
+    -------
+    - : str
+    """
+
+    if schema is None:
+        return f"{table}.{column}"
+    else:
+        return f"{schema}.{table}.{column}"
+
+
+def _FixDependencyColumns(columns, has_production, production):
+    """
+    Special case for dependencies table where some column names need to be tweeked.
+
+    Columns dict is modified in place.
+
+    Parameters
+    ----------
+    columns : dict
+    has_production : bool
+        True if database has a production schema
+    production : str
+        Name of the production schema
+    """
+
     # Remove link to production schema if unneeded.
     if not has_production:
         del columns["input_production_id"]
@@ -270,8 +213,36 @@ def _Dependency(schema, has_production, production="production"):
         del columns["input_production_id"]
         columns["input_production_id"] = new_input_production_id
 
-    # Table metadata
-    meta = {"__tablename__": "dependency", "__table_args__": {"schema": schema}}
+
+def _BuildTable(schema, table_name, has_production, production):
+    """
+    Builds a generic schema table from the information in the `schema.yaml` file.
+
+    Parameters
+    ----------
+    schema : str
+    table_name : str
+    has_production : bool
+        True if database has a production schema
+    production : str
+        Name of the production schema
+
+    Returns
+    -------
+    Model : class object
+    """
+
+    class_name = f"{schema}_{table_name}"
+
+    # Column definitions (from `schema.yaml` file)
+    columns = _get_column_definitions(schema, table_name)
+
+    # Special case for dependencies table
+    if table_name == "dependency":
+        _FixDependencyColumns(columns, has_production, production)
+
+    # Table metadata (from `schema.yaml` file)
+    meta = _get_table_metadata(schema, table_name)
 
     Model = type(class_name, (Base,), {**columns, **meta})
     return Model
@@ -361,13 +332,15 @@ def _Dependency(schema, has_production, production="production"):
             print(f"Could not grant access to {acct} on schema {schema}")
 
     # Create the tables
-    # for SCHEMA in SCHEMA_LIST:
-    _Dataset(schema)
-    _DatasetAlias(schema)
-    _Dependency(schema, db_connection.dialect != "sqlite", production=prod_schema)
-    _Execution(schema)
-    _ExecutionAlias(schema)
-    _Provenance(schema)
+    for table_name in [
+        "dataset",
+        "dataset_alias",
+        "dependency",
+        "execution",
+        "execution_alias",
+        "provenance",
+    ]:
+        _BuildTable(schema, table_name, db_connection.dialect != "sqlite", prod_schema)
 
     # Generate the database
     if schema:
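The pattern `_BuildTable` relies on — creating each declarative model dynamically with `type()` — is plain SQLAlchemy, independent of this repository. A minimal self-contained sketch, assuming an invented `build_table` helper and invented schema, table, and column names:

from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import DeclarativeBase

class Base(DeclarativeBase):
    pass

def build_table(schema, table_name, columns):
    # type() creates a new declarative class on the fly; SQLAlchemy
    # registers it against Base's metadata just like a hand-written model.
    meta = {"__tablename__": table_name, "__table_args__": {"schema": schema}}
    return type(f"{schema}_{table_name}", (Base,), {**columns, **meta})

# Usage: a toy "provenance" table with two invented columns.
Provenance = build_table(
    "my_schema",
    "provenance",
    {
        "provenance_id": Column("provenance_id", Integer, primary_key=True),
        "code_version": Column("code_version", String, nullable=False),
    },
)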
10 changes: 5 additions & 5 deletions src/dataregistry/registrar/base_table_class.py
@@ -73,7 +73,7 @@ def __init__(self, db_connection, root_dir, owner, owner_type):
         self._DEFAULT_MAX_CONFIG = _DEFAULT_MAX_CONFIG
 
         # Load and store the schema yaml file
-        self.schema_yaml, _ = load_schema()
+        self.schema_yaml = load_schema()
 
     def _get_table_metadata(self, tbl):
         return self._metadata_getter.get(tbl)
@@ -117,11 +117,11 @@ def modify(self, entry_id, modify_fields):
         # Loop over each column to be modified
         for key, v in modify_fields.items():
             # Make sure the column is in the schema
-            if key not in self.schema_yaml[self.which_table].keys():
+            if key not in self.schema_yaml["tables"][self.which_table]["column_definitions"].keys():
                 raise ValueError(f"The column {key} does not exist in the schema")
 
             # Make sure the column is modifiable
-            if not self.schema_yaml[self.which_table][key]["modifiable"]:
+            if not self.schema_yaml["tables"][self.which_table]["column_definitions"][key]["modifiable"]:
                 raise ValueError(f"The column {key} is not modifiable")
 
             # Update the entries
@@ -183,8 +183,8 @@ def get_modifiable_columns(self):
         """
 
         mod_list = []
-        for att in self.schema_yaml[self.which_table]:
-            if self.schema_yaml[self.which_table][att]["modifiable"]:
+        for att in self.schema_yaml["tables"][self.which_table]["column_definitions"]:
+            if self.schema_yaml["tables"][self.which_table]["column_definitions"][att]["modifiable"]:
                 mod_list.append(att)
 
         return mod_list
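
The long subscript chains in the registrar changes above are simple nested-dict walks over the loaded YAML. A standalone sketch of the `get_modifiable_columns` logic, assuming a hypothetical file path and using `yaml.safe_load` directly (not necessarily how `load_schema` is implemented):

import yaml

def get_modifiable_columns(schema_yaml, table):
    # Mirror of the method above: collect every column in `table` whose
    # definition is flagged "modifiable" in the schema file.
    cols = schema_yaml["tables"][table]["column_definitions"]
    return [att for att, definition in cols.items() if definition["modifiable"]]

# Usage (path assumed for illustration):
with open("src/dataregistry/schema/schema.yaml") as f:
    schema_yaml = yaml.safe_load(f)
print(get_modifiable_columns(schema_yaml, "dataset"))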
