In [1]:
# SDK imports
from rai_gnns_experimental.core.types import ColumnDType
from rai_gnns_experimental.core.metrics import EvaluationMetric
from rai_gnns_experimental.core.gnn_table import GNNTable, ForeignKey
from rai_gnns_experimental.core.task import NodeTask, TaskType
from rai_gnns_experimental.core.trainer import Trainer
from rai_gnns_experimental.core.config_trainer import TrainerConfig
from rai_gnns_experimental.core.dataset import Dataset
from rai_gnns_experimental.core.connector import LocalConnector
from rai_gnns_experimental.core.job_manager import JobManager
from rai_gnns_experimental.common.export import OutputConfig

from IPython.display import Image, display

# Jupyter magic commands
%load_ext autoreload
%load_ext jupyter_black
%autoreload 2

## Node Classification Example (Synthetic Data)

In [2]:
# Local parquet inputs for the synthetic academic-ranking benchmark.
# Every path lives under one dataset root, so only DATASET_ROOT needs to change
# to point the notebook at another copy of the data.
connector_name = "parquet"
DATASET_ROOT = "/data/benchmark_datasets/synthetic_academic_ranking"
DATA_DIR = f"{DATASET_ROOT}/data"  # source tables used to build the graph
TASKS_DIR = f"{DATASET_ROOT}/tasks"  # train / validation / test task tables

class_table_pth = f"{DATA_DIR}/classes.pqt"
student_takes_class_pth = f"{DATA_DIR}/student_takes_class.pqt"
students_pth = f"{DATA_DIR}/students.pqt"
train_tbl_pth = f"{TASKS_DIR}/student_rank_train.pqt"
val_tbl_pth = f"{TASKS_DIR}/student_rank_validation.pqt"
test_tbl_pth = f"{TASKS_DIR}/student_rank_test.pqt"

In [3]:
# Local connector to the service at http://localhost:8024; the connector type is
# the `connector_name` chosen in the paths cell.
connector = LocalConnector(
    connector_type=connector_name,
    url="http://localhost",
    port=8024,
)

### 1. Create data tables

In [None]:
# Way 1: give the constructor everything up front, including the primary key.
student_table = GNNTable(
    connector=connector,
    source=students_pth,
    name="Students",
    primary_key="studentId",
)

# Way 2: build the table first, then attach the primary key.
class_table = GNNTable(connector=connector, source=class_table_pth, name="Classes")
# Validating at this point raises an error (presumably because no primary key
# has been set yet):
# class_table.validate_table()
# Setting a primary key on a column that does not exist raises an AssertionError:
# class_table.set_primary_key("foo")
class_table.set_primary_key("classId")

# Foreign keys can be passed to the constructor and/or added one at a time later.
# Each ForeignKey links a local column to "<Table>.<column>" in another table.
student_takes_class_table = GNNTable(
    connector=connector,
    source=student_takes_class_pth,
    name="StudentsTakeClass",
    foreign_keys=[
        ForeignKey(column_name="studentId", link_to="Students.studentId"),
    ],
)
student_takes_class_table.set_foreign_key(
    ForeignKey(link_to="Classes.classId", column_name="classId")
)

# Describe the table, drop a column from its metadata, describe it again,
# then re-add the column with an explicit dtype and describe it once more.
class_table.show_table()
class_table.remove_column(col_name="credits")
class_table.show_table()
# NOTE(review): "credits" is re-added as float_t; confirm that matches its original dtype.
class_table.add_column(dtype=ColumnDType.float_t, col_name="credits")
class_table.show_table()

### 2. Create a node classification task

In [None]:
# No primary key is set for the task tables, so `source_entity_column` is used to
# key them instead. The task tables must therefore name that column exactly as
# the source entity table ("Students") does.
node_task = NodeTask(
    connector=connector,
    name="my_node_task",
    task_type=TaskType.BINARY_CLASSIFICATION,
    source_entity_table="Students",
    source_entity_column="studentId",
    target_column="label",
    task_data_source={
        "train": train_tbl_pth,
        "test": test_tbl_pth,
        "validation": val_tbl_pth,
    },
    current_time=True,
)
node_task.set_evaluation_metric(EvaluationMetric(name="accuracy"))
node_task.show_task()

### 3. Dataset setup

In [6]:
# Bundle the three tables and the node task into a single named dataset.
dataset = Dataset(
    dataset_name="toy_dataset_v2",
    connector=connector,
    task_description=node_task,
    tables=[student_table, student_takes_class_table, class_table],
)

In [None]:
# Render the graph returned by visualize_dataset() as a PNG image.
# The image is deliberately not named `plt`: that name conventionally refers to
# matplotlib.pyplot, and shadowing it would break later `plt.` plotting calls.
graph = dataset.visualize_dataset()

graph_png = Image(graph.create_png())
display(graph_png)

In [None]:
# Experiment name associated with this dataset. It is presumably the same string passed
# as `experiment_name` in the inference cells below; verify before relying on that.
dataset.experiment_name

### 4. Trainer setup

In [9]:
# Two epochs on the CUDA device keep this example run short.
trainer_config = TrainerConfig(
    connector=connector,
    n_epochs=2,
    device="cuda",
)

In [None]:
# The trainer launches the fit / predict / fit_predict jobs used in the rest of the notebook.
trainer = Trainer(config=trainer_config, connector=connector)

### 5. Train the model

In [12]:
# Launch training; the returned job handle is inspected in the cells below.
train_job = trainer.fit(dataset=dataset)

In [None]:
# Current status of the training job.
train_job.get_status()

In [None]:
# Stream the job's logs into the cell output.
train_job.stream_logs()

In [None]:
# Re-check the status after streaming the logs.
train_job.get_status()

In [None]:
# Run id of the trained model (presumably an MLflow run id); passed to trainer.predict below.
train_job.model_run_id

In [None]:
# Experiment this training run belongs to.
train_job.experiment_name

In [None]:
# Register the trained model under a name so inference can later select it via
# `registered_model_name` / `version`.
train_job.register_model("test_model_mlflow")

In [None]:
# Models registered from this job.
train_job.registered_models

### 6. Inference

In [None]:
# Inference from an explicit training run: pass the dataset plus the run id
# produced by train_job, and also extract embeddings.
# (Alternatively, `experiment_name=...` and `test_table=...` can be passed instead
# of `dataset`, as the inference_job2 / inference_job3 cells do.)
inference_job1 = trainer.predict(
    dataset=dataset,
    model_run_id=train_job.model_run_id,
    output_config=OutputConfig.local(
        artifacts_dir="/data/dafni/artifacts/",
        extension="parquet",
    ),
    output_alias="foo",
    extract_embeddings=True,
)

In [None]:
# Run id used for this inference; presumably equal to train_job.model_run_id, which was passed in.
inference_job1.model_run_id

In [None]:
# Current status of the inference job.
inference_job1.get_status()

In [None]:
# Inference without a Dataset object: name the experiment, score `test_table`,
# and let the SDK select the best run in that experiment by accuracy
# (presumably what select_best_model + evaluation_metric do).
# NOTE(review): the experiment name is hard-coded and appears to follow
# "<dataset_name>/<task_type>/<task name>"; confirm it matches dataset.experiment_name.
inference_job2 = trainer.predict(
    experiment_name="toy_dataset_v2/binary_classification/my_node_task",
    test_table=test_tbl_pth,
    select_best_model=True,
    evaluation_metric=EvaluationMetric(name="accuracy"),
    output_config=OutputConfig.local(
        artifacts_dir="/data/dafni/artifacts/",
        extension="parquet",
    ),
    output_alias="test",
    extract_embeddings=True,
)

In [None]:
# Current status of the best-model inference job.
inference_job2.get_status()

In [None]:
# Run id of the model that was selected as best.
inference_job2.model_run_id

In [None]:
# Inference from a registered model: name the experiment, score `test_table`,
# and load a specific version of the model registered as "test_model_mlflow"
# (see the register_model cells).
# NOTE(review): version="10" presumably only exists once that name has been
# registered at least ten times; check `train_job.registered_models` and pick an
# existing version.
# NOTE(review): output_alias="foo" and artifacts_dir are the same as inference_job1's,
# so this output presumably overwrites that job's predictions; use a distinct alias
# if both outputs are needed.
inference_job3 = trainer.predict(
    experiment_name="toy_dataset_v2/binary_classification/my_node_task",
    test_table=test_tbl_pth,
    registered_model_name="test_model_mlflow",
    version="10",
    output_config=OutputConfig.local(
        artifacts_dir="/data/dafni/artifacts/",
        extension="parquet",
    ),
    output_alias="foo",
    extract_embeddings=True,
)

In [None]:
# Current status of the registered-model inference job.
inference_job3.get_status()

In [None]:
# Run id behind the registered model version that was loaded.
inference_job3.model_run_id

### 7. Train and Inference

In [None]:
# Snowflake output
output_config = OutputConfig.snowflake(
    database_name="SYNTHETIC_ACADEMIC_RANKING_DB", schema_name="PUBLIC"
)

train_inf_job = trainer.fit_predict(
    dataset=dataset,
    output_config=output_config,
    output_alias="foo",
    extract_embeddings=True,
)

In [None]:
# Local file output
output_config = OutputConfig.local(
    artifacts_dir="/data/haythem/gnn-learning-engine/",
    extension="parquet",
)

train_inf_job = trainer.fit_predict(
    dataset=dataset,
    output_config=output_config,
    output_alias="foo",
    extract_embeddings=True,
)

In [None]:
# Current status of the train+predict job.
train_inf_job.get_status()

In [None]:
# Run id of the model trained by the train+predict job.
train_inf_job.model_run_id

In [None]:
# Registers under the same name as train_job in section 5 (presumably this adds
# another version of "test_model_mlflow").
train_inf_job.register_model("test_model_mlflow")

### 8. Job Manager

In [None]:
job_manager = JobManager(connector=connector)

In [None]:
train_job = trainer.fit(dataset=dataset)

In [None]:
inference_job1 = trainer.predict(
    dataset=dataset,
    model_run_id=train_job.model_run_id,
    output_config=OutputConfig.local(
        artifacts_dir="/data/dafni/artifacts/",
        extension="parquet",
    ),
    output_alias="foo",
)

In [None]:
# Launch four more training jobs so there is something to list and cancel below.
train_job_2 = trainer.fit(dataset=dataset)
train_job_3 = trainer.fit(dataset=dataset)
train_job_4 = trainer.fit(dataset=dataset)
train_job_5 = trainer.fit(dataset=dataset)

In [None]:
# Show the jobs tracked by the job manager.
job_manager.show_jobs()

In [None]:
finished_job = job_manager.fetch_job(job_id="e3950b85-2f9d-44c3-9b33-050f69f7ee01")

In [None]:
finished_job.model_run_id

In [None]:
# Status of train_job_5 before any cancellation.
train_job_5.get_status()

In [None]:
# Cancel method 1: through the job handle itself.
train_job_4.cancel()

In [None]:
job_manager.show_jobs()

In [None]:
# Cancel method 2: through the job manager, by job id.
job_manager.cancel_job(train_job_3.job_id)
job_manager.show_jobs()

In [None]:
# A second handle to train_job_5, fetched from the job manager by id.
copy_of_train_job_5 = job_manager.fetch_job(train_job_5.job_id)

In [None]:
# Cancel method 3: through the fetched copy ...
copy_of_train_job_5.cancel()

In [None]:
# ... then check the original handle: both refer to the same job id, so this
# presumably reports the cancellation.
train_job_5.get_status()