In [7]:
import json

import requests

from gremlin_python.process.anonymous_traversal import traversal
from gremlin_python.driver.driver_remote_connection import DriverRemoteConnection
from gremlin_python.driver.aiohttp.transport import AiohttpTransport 
from gremlin_python.process.graph_traversal import __
from gremlin_python.process.traversal import P
from gremlin_python.process.traversal import T
from gremlin_python.statics import long 

Инициализация бакетов.

In [None]:
import boto3


# S3 client for the local object storage used by this example
# (demo credentials; do not reuse them outside this environment).
session = boto3.session.Session()
s3 = session.client(
    service_name='s3',
    endpoint_url="http://s3:9000",
    aws_access_key_id='123',
    aws_secret_access_key='12345678'
)

# Create the buckets idempotently. On a re-run the storage reports that the bucket
# already exists; that case is expected and ignored. Any other failure (wrong
# credentials, unreachable endpoint, ...) is surfaced instead of silently swallowed.
for bucket_name in ('util-bucket', 'test-bucket'):
    try:
        s3.create_bucket(Bucket=bucket_name)
    except (s3.exceptions.BucketAlreadyOwnedByYou, s3.exceptions.BucketAlreadyExists):
        pass

# Пример использования 

В примере рассматривается задача классификации вершин. Данные представляют собой набор публикаций о войне в Сирии с 2011 по 2018 год (https://www.kaggle.com/datasets/mohamadalhasan/a-fake-news-dataset-around-the-syrian-war/). Вершины — это публикации: признаком служит текст публикации, целевым свойством — `real` (0 — фейковая новость). Рёбра связывают публикации, у которых совпадает хотя бы одно из свойств: дата публикации, упоминаемое место или издание.

Проверим, сколько вершин и рёбер в исследуемом графе.

In [None]:
# Count the "news" vertices and "link" edges in the graph. The remote connection
# is closed once the counts are fetched; later cells open their own connection.
graph_connection = DriverRemoteConnection(
    'ws://janusgraph:8182/gremlin', 'g',
    transport_factory=lambda: AiohttpTransport(call_from_event_loop=True),
)
g = traversal().with_remote(graph_connection)

try:
    graph_counts = (
        g.V().hasLabel("news").count().next(),
        g.E().hasLabel("link").count().next(),
    )
finally:
    graph_connection.close()

graph_counts

Эмбеддинг текстовых данных может быть представлен в виде Bag-of-Words, TF-IDF или хэш-вектора.

Первым этапом работы будет запрос на экспорт данных из JanusGraph в S3. 

Определим адреса сервисов 

In [2]:
# --- Storage layer -------------------------------------------------------
# Gremlin endpoint of JanusGraph (WebSocket) and the S3-compatible storage.
JANUS_GRAPH_ENDPOINT_URL = "ws://janusgraph:8182/gremlin"
S3_ENDPOINT_URL = "http://s3:9000"

# --- ML pipeline services ------------------------------------------------
# export -> processing -> training -> model endpoint
EXPORT_SERVICE_URL = "http://export_service:8081/export"
PROCESSING_SERVICE_URL = "http://processing_service:8082/processing"
TRAIN_SERVICE_URL = "http://train_service:8083/modeltraining"
ENDPOINT_SERVICE_URL = "http://endpoint_service:8084/call"

In [None]:
# Export request: dump the graph from JanusGraph into the S3 bucket and describe
# the training job ("news_job") built from it.
news_export = {
    "command": "export-pg",
    "output_s3_path": S3_ENDPOINT_URL,
    "bucket": "test-bucket",
    "params": {
        "endpoint": JANUS_GRAPH_ENDPOINT_URL
    },
    # Fractions of the data for train / val / test.
    "split": [0.8, 0.1, 0.1],
    "additional_params": {
        "jobs": [
            {
                "name": "news_job",
                # Classify "news" vertices by their "real" property.
                "target": {
                    "node": "news",
                    "property": "real",
                    "type": "classification"
                },
                # Features: TF-IDF encoding of the vertex text plus the news-link-news edges.
                "features": [
                    {
                        "node": "news",
                        "property": "text",
                        "type": "tfidf"
                    },
                    {
                        "edge": ["news", "link", "news"]
                    }
                ]
            }
        ]
    }
}

# `json=` serializes the payload and sets `Content-Type: application/json`;
# raise_for_status() stops the notebook on a failed export instead of continuing silently.
export_response = requests.post(EXPORT_SERVICE_URL, json=news_export)
export_response.raise_for_status()
export_response

* "output_s3_path" - адрес (endpoint) S3-хранилища, в которое помещаются экспортируемые данные.
* "bucket" - S3-бакет, в который экспортируются данные.
* "params" - параметры подключения к серверу JanusGraph.
* "endpoint" - URL базы данных.
* "split" - доли разбиения данных на обучающую, валидационную и тестовую выборки (train, val, test).
* "additional_params" - параметры экспорта.
* "jobs" - список задач экспорта.
* "name" - название задачи.
* "target" - целевое свойство для задачи обучения.
* "node" - лейбл целевой вершины.
* "property" - целевое (предсказываемое) свойство вершины.
* "type" - тип задачи обучения.
* "features" - признаки для обучения модели: свойства вершин и рёбра.
* "node/edge" - лейбл вершины или описание ребра в виде тройки [лейбл исходной вершины, лейбл ребра, лейбл конечной вершины].
* "type" - тип кодирования признака (например, tfidf).

Теперь нужно преобразовать экспортированные данные в DGL-граф.

In [None]:
# Processing request: turn the exported "news_job" data into a DGL graph and store
# the processed result under a separate prefix of the same bucket.
news_proc = {
    "s3_endpoint_url": S3_ENDPOINT_URL,
    "bucket": "test-bucket",
    "job_path": "janusgraph_ml/news_job",
    "processed_data_s3_location": "janusgraph_ml/news_processed",
    "config_file_name": "train_config.json"
}

# `json=` sets the JSON content type; raise_for_status() surfaces service errors.
processing_response = requests.post(PROCESSING_SERVICE_URL, json=news_proc)
processing_response.raise_for_status()
processing_response

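* "job_path" - путь (префикс ключей) в бакете к экспортированным данным задачи.
* "processed_data_s3_location" - путь (префикс ключей) в бакете, куда сохраняются обработанные данные.
* "config_file_name" - имя конфигурационного файла обучения в каталоге job_path.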

In [None]:
# Training request: points the training service at the job's train config and at the
# processing config produced by the previous step.
news_train = {
    "s3_params": {
        "bucket": "test-bucket",
        "s3_endpoint_url": S3_ENDPOINT_URL
    },
    "train_config_s3_key": "janusgraph_ml/news_job/train_config.json",
    "processing_config_s3_key": "janusgraph_ml/news_processed/processing_config.json"
}

# `json=` sets the JSON content type; raise_for_status() surfaces service errors.
train_response = requests.post(TRAIN_SERVICE_URL, json=news_train)
train_response.raise_for_status()
train_response

Обучается SAGEConv-модель. 

Гиперпараметры модели подбираются с использованием RandomSampling на основе параметров, заданных пользователем в файле по адресу model_hpo_config_s3_key. Если model_hpo_config_s3_key не был задан, то используются стандартные гиперпараметры (представлены ниже).

```jsonc
{
  // Количество эпох обучения
  "max_epochs": [10, 25, 50],
  // Скорость обучения
  "learning_rate": [0.1, 0.01, 0.001],
  // Доля dropout
  "dropout": [0.2, 0.35, 0.5],
  // Размер скрытого слоя h задаётся размером входа input, выхода output и парой [a, b]: h = a * input + b * output. Множители взяты из распространённых эвристик.
  "n_hidden_multiplier": [[0.5, 0.5], [0.6, 1], [2, 0]],
  // Количество слоёв модели
  "n_layers": [2, 3, 4]
}
```

Веса модели и конфигурация загружаются в отдельный бакет. Создаётся эндпоинт для этой модели.

Добавим в граф новую вершину с фейковой новостью и проверим, какое значение целевого свойства `real` предскажет модель при запросе с индуктивным интерфейсом.

In [None]:
# Add a fake news item (real = '0') to the graph, link it to existing news and request
# a prediction for the new vertex through the inductive interface.
graph_connection = DriverRemoteConnection(
    JANUS_GRAPH_ENDPOINT_URL, 'g',
    transport_factory=lambda: AiohttpTransport(call_from_event_loop=True),
)
g = traversal().with_remote(graph_connection)

# Implicit string concatenation: a single-quoted literal cannot span lines.
text = (
    'Syria attack symptoms consistent with nerve agent use WHO,"Wed 05 Apr 2017 '
    'Syria attack symptoms consistent with nerve agent use WHO. '
    'Victims of a suspected chemical attack in Syria appeared to show symptoms consistent with reaction to a nerve agent the World Health Organization said on Wednesday. '
    '""Some cases appear to show additional signs consistent with exposure to organophosphorus chemicals a category of chemicals that includes nerve agents"" WHO said in a statement putting the death toll at at least 70. '
    'The United States has said the deaths were caused by sarin nerve gas dropped by Syrian aircraft. '
    'Russia has said it believes poison gas had leaked from a rebel chemical weapons depot struck by Syrian bombs. '
    'Sarin is an organophosporus compound and a nerve agent. '
    'Chlorine and mustard gas which are also believed to have been used in the past in Syria are not. '
    'A Russian Defence Ministry spokesman did not say what agent was used in the attack but said the rebels had used the same chemical weapons in Aleppo last year. '
    'The WHO said it was likely that some kind of chemical was used in the attack because sufferers had no apparent external injuries and died from a rapid onset of similar symptoms including acute respiratory distress. '
    'It said its experts in Turkey were giving guidance to overwhelmed health workers in Idlib on the diagnosis and treatment of patients and medicines such as Atropine an antidote for some types of chemical exposure and steroids for symptomatic treatment had been sent. '
    'A U.N. Commission of Inquiry into human rights in Syria has previously said forces loyal to Syrian President Bashar al-Assad have used lethal chlorine gas on multiple occasions. '
    'Hundreds of civilians died in a sarin gas attack in Ghouta on the outskirts of Damascus in August 2013. '
    'Assads government has always denied responsibility for that attack. '
    'Syria agreed to destroy its chemical weapons in 2013 under a deal brokered by Moscow and Washington. \n'
    'But Russia a Syrian ally and China have repeatedly vetoed any United Nations move to sanction Assad or refer the situation in Syria to the International Criminal Court. '
    '""These types of weapons are banned by international law because they represent an intolerable barbarism"" Peter Salama Executive Director of the WHO Health Emergencies Programme said in the WHO statement. '
    '- REUTERS"'
)
news_source = 'ttr'
news_date = '4/5/2012'
news_location = 'homs'

try:
    news_vertex = (
        g.addV("news")
            .property("text", text)
            .property("source", news_source)
            .property("date", news_date)
            .property("location", news_location)
            .property("real", '0')
            # NOTE(review): 'temp' presumably marks vertices added for inference — confirm with the endpoint service.
            .property('temp', 'temp')
            .next()
    )
    news_id = long(news_vertex.id)

    # Link the new vertex in both directions to every *other* "news" vertex sharing its
    # source, date or location. Without where(P.neq("a")) the vertex matches itself and
    # two self-loop "link" edges are created.
    (
        g.V(news_vertex).as_("a")
            .V().hasLabel("news").where(P.neq("a")).as_("b")
            .or_(
                __.where("a", P.eq("b")).by("source"),
                __.where("a", P.eq("b")).by("date"),
                __.where("a", P.eq("b")).by("location"),
            )
            .V(news_id).addE("link").to("b")
            .addE("link").from_("b").to(__.V(news_id))
            .iterate()
    )

    prediction = (
        g.call("predict")
            .with_("endpoint_id", "news_job")
            .with_("predict_entity_idx", str(news_vertex.id))
            .with_("interface", "inductive")
            .next()
    )
finally:
    graph_connection.close()

prediction

Добавим другую вершину (на этот раз с `real` = 1) и отправим запрос с трансдуктивным интерфейсом.

In [None]:
# Add a real news item (real = '1') to the graph, link it to existing news and request
# a prediction for the new vertex through the transductive interface.
graph_connection = DriverRemoteConnection(
    JANUS_GRAPH_ENDPOINT_URL, 'g',
    transport_factory=lambda: AiohttpTransport(call_from_event_loop=True),
)
g = traversal().with_remote(graph_connection)

text = 'Sun 01 Feb 2015 Explosion rocks down town Damascus . An explosion inside a bus killed six people and injured another ten in down town Damascus according to preliminary reports on Sunday. Syrias government run TV said the explosion took place in Al Kalassa region of the Syrian capital and that 19 people were injured and an unspecified number of others killed.'
news_source = 'nna'
news_date = '2/1/2015'
news_location = 'damascus'

try:
    news_vertex = (
        g.addV("news")
            .property("text", text)
            .property("source", news_source)
            .property("date", news_date)
            .property("location", news_location)
            .property("real", '1')
            # NOTE(review): 'temp' presumably marks vertices added for inference — confirm with the endpoint service.
            .property('temp', 'temp')
            .next()
    )
    news_id = long(news_vertex.id)

    # Link the new vertex in both directions to every *other* "news" vertex sharing its
    # source, date or location. Without where(P.neq("a")) the vertex matches itself and
    # two self-loop "link" edges are created.
    (
        g.V(news_vertex).as_("a")
            .V().hasLabel("news").where(P.neq("a")).as_("b")
            .or_(
                __.where("a", P.eq("b")).by("source"),
                __.where("a", P.eq("b")).by("date"),
                __.where("a", P.eq("b")).by("location"),
            )
            .V(news_id).addE("link").to("b")
            .addE("link").from_("b").to(__.V(news_id))
            .iterate()
    )

    prediction = (
        g.call("predict")
            .with_("endpoint_id", "news_job")
            .with_("predict_entity_idx", str(news_vertex.id))
            .with_("interface", "transductive")
            .next()
    )
finally:
    graph_connection.close()

prediction