In [1]:
import os
import json
import contextlib

import sqlalchemy as sa
from sqlalchemy.orm import sessionmaker

import pandas as pd
import numpy as np

import urllib.parse

In [2]:
# JSON file holding the database connection settings; a template is written
# to this path by get_config() when the file does not exist yet.
CONFIG_PATH = "config.json"
# Destination of the final DataFrame, written in parquet format.
OUT_FILE = "traintest.pq"
# Forwarded as `echo` to every SQLAlchemy engine created by get_engine().
VERBOSE = False

In [3]:
# Parsed contents of CONFIG_PATH; filled lazily by get_config().
CONFIG = None
# dbname -> (engine, MetaData) tuple; filled by get_engine().
ENGINES = {}
# (dbname, tablename) -> reflected sa.Table; filled by get_table().
TABLES = {}
# sa.Table -> engine it was reflected from; handed to sessionmaker as `binds`.
BINDS = {}
# Cached sessionmaker; reset to None by get_table() whenever a new table is
# reflected so the next get_session() rebuilds it.
SESSION = None


def config_template():
    """Return a placeholder configuration for all four databases.

    The result has a single "dbs" key mapping each of "login", "sm", "exp"
    and "ap" to its own independent copy of the default connection settings.
    Fields set to "INVALID" are expected to be filled in by the user.
    """
    db_names = ("login", "sm", "exp", "ap")
    placeholder = dict(
        dialect="postgresql",
        host="localhost",
        port=5432,
        dbname="INVALID",
        schema="public",
        user="INVALID",
        passwd="INVALID",
    )
    # dict(placeholder) gives every database a separate dict object
    return {"dbs": {name: dict(placeholder) for name in db_names}}


def get_config():
    """Return the parsed configuration, loading it from CONFIG_PATH once.

    The result is cached in the module-level CONFIG. If the file does not
    exist, a template is written to CONFIG_PATH and a ValueError is raised
    so the user can fill in real values before running again.
    """
    global CONFIG

    if CONFIG is None:
        if not os.path.exists(CONFIG_PATH):
            with open(CONFIG_PATH, "w") as fout:
                template_text = json.dumps(
                    config_template(), indent=4, sort_keys=True)
                fout.write(template_text + "\n")
            raise ValueError(
                f"config file missing. new file was created at '{CONFIG_PATH}'. "
                "please correct values in file and run again")
        with open(CONFIG_PATH, "r") as fin:
            CONFIG = json.load(fin)
    return CONFIG


def get_engine(dbname):
    """Return the cached ``(engine, metadata)`` pair for a configured database.

    Args:
        dbname: key into the "dbs" section of the configuration.

    Returns:
        A tuple of the SQLAlchemy engine and a fresh ``sa.MetaData`` that was
        created together with it. Repeated calls with the same ``dbname``
        return the same tuple.

    Raises:
        KeyError: if ``dbname`` or one of the connection fields is missing
            from the configuration.
        ValueError: propagated from get_config() when the config file is
            missing.
    """
    cached = ENGINES.get(dbname)
    if cached is not None:
        return cached
    db = get_config()["dbs"][dbname]
    # Build the URL from its components instead of formatting a string, so
    # that no field (user, password, host, database name) needs manual
    # escaping of characters such as '@', '/' or ':'.
    url = sa.engine.URL.create(
        drivername=db["dialect"],
        username=db["user"],
        password=db["passwd"],
        host=db["host"],
        port=db["port"],
        database=db["dbname"],
    )
    engine = sa.create_engine(url, echo=VERBOSE)
    # Tables are reflected without an explicit schema; map the "no schema"
    # case onto the schema configured for this database.
    engine = engine.execution_options(
        schema_translate_map={None: db["schema"]})
    res = engine, sa.MetaData()
    ENGINES[dbname] = res
    return res


def get_table(dbname, tablename):
    """Return the reflected table ``tablename`` of database ``dbname``.

    Tables are cached per (dbname, tablename). Reflecting a new table
    registers its engine in BINDS and clears the cached SESSION so that the
    next get_session() call picks up the new binding.
    """
    global SESSION

    cache_key = (dbname, tablename)
    table = TABLES.get(cache_key)
    if table is None:
        # a new table changes BINDS, so the old session factory is stale
        SESSION = None
        engine, metadata = get_engine(dbname)
        table = sa.Table(tablename, metadata, autoload_with=engine)
        TABLES[cache_key] = table
        BINDS[table] = engine
    return table


@contextlib.contextmanager
def get_session():
    """Context manager yielding a session bound to all reflected tables.

    The session factory is created on first use with BINDS as its table to
    engine mapping and cached in SESSION; get_table() resets that cache.
    """
    global SESSION

    factory = SESSION
    if factory is None:
        factory = sessionmaker(binds=BINDS)
        SESSION = factory
    with factory() as active_session:
        yield active_session

In [4]:
# tags table, read from the "login" database
t_tags = get_table("login", "tags")

# solution mapping tables ("sm" database)
t_sm_pads, t_sm_tagging = (
    get_table("sm", name) for name in ("pads", "tagging"))

# action plan tables ("ap" database)
t_ap_pads, t_ap_tagging = (
    get_table("ap", name) for name in ("pads", "tagging"))

# experiments tables ("exp" database)
t_exp_pads, t_exp_tagging = (
    get_table("exp", name) for name in ("pads", "tagging"))

OperationalError: (psycopg2.OperationalError) connection to server at "acclabs.postgres.database.azure.com" (13.69.105.208), port 5432 failed: FATAL:  database "login" does not exist

(Background on this error at: https://sqlalche.me/e/20/e3q8)

In [None]:
# Number of solution mapping pads whose status is at least 2.
with get_session() as session:
    stmt = (
        sa.select(sa.func.count(t_sm_pads.c.id))
        .where(t_sm_pads.c.status >= 2)
    )
    sm_pad_count = int(session.execute(stmt).scalar_one())
    print(sm_pad_count)

In [None]:
# Number of action plan pads whose status is at least 2.
with get_session() as session:
    stmt = (
        sa.select(sa.func.count(t_ap_pads.c.id))
        .where(t_ap_pads.c.status >= 2)
    )
    ap_pad_count = int(session.execute(stmt).scalar_one())
    print(ap_pad_count)

In [None]:
# Number of experiment pads whose status is at least 2.
with get_session() as session:
    stmt = (
        sa.select(sa.func.count(t_exp_pads.c.id))
        .where(t_exp_pads.c.status >= 2)
    )
    exp_pad_count = int(session.execute(stmt).scalar_one())
    print(exp_pad_count)

In [None]:
# Lookup of tag id -> (name, type) for every row of the tags table.
tags = {}
with get_session() as session:
    stmt = sa.select(t_tags.c.id, t_tags.c.name, t_tags.c.type)
    tags.update(
        (tag_id, (tag_name, tag_type))
        for tag_id, tag_name, tag_type in session.execute(stmt))

In [None]:
# Print title, text and thematic-area tags of up to 10 solution mapping pads
# with status >= 2 as a quick sanity check of the data.
with get_session() as session:
    stmt = (
        sa.select(
            t_sm_pads.c.id,
            t_sm_pads.c.title,
            t_sm_pads.c.sections,
            t_sm_pads.c.full_text)
        .where(t_sm_pads.c.status >= 2)
        .limit(10)
    )
    for pad_id, pad_title, _pad_sections, pad_text in session.execute(stmt):
        print("=TITLE=================")
        print(pad_title)
        print("=TEXT==================")
        print(pad_text)
        print("=TAGS==================")
        # tags attached to this pad; only thematic areas are shown
        tstmt = sa.select(t_sm_tagging.c.tag_id).where(
            t_sm_tagging.c.pad == pad_id)
        for (tag_id,) in session.execute(tstmt):
            t_name, t_type = tags[tag_id]
            if t_type == "thematic_areas":
                print(t_name)
        print()

In [None]:
# Usage count per tag in the solution mapping tagging table, most used first.
# Thematic areas are printed first, then all other tag types. The percentage
# is relative to sm_pad_count (pads with status >= 2).
sm_tag_counts = {}
with get_session() as session:
    stmt = sa.select(t_sm_tagging.c.tag_id, sa.func.count(t_sm_tagging.c.tag_id))
    stmt = stmt.group_by(t_sm_tagging.c.tag_id).order_by(sa.func.count(t_sm_tagging.c.tag_id).desc())
    # Run the aggregate once and reuse the rows for both passes instead of
    # sending the identical query to the database a second time.
    rows = session.execute(stmt).all()
    for theme in [True, False]:
        for row in rows:
            t_name, t_type = tags[row[0]]
            if (t_type == "thematic_areas") != theme:
                continue
            print(f"{row[1] / sm_pad_count * 100.0:.2f}% {row[1]} {t_name} ({t_type})")
            sm_tag_counts[t_name] = int(row[1])
        print()

In [None]:
# Usage count per tag in the action plan tagging table, most used first.
# Thematic areas are printed first, then all other tag types. The percentage
# is relative to ap_pad_count (pads with status >= 2).
ap_tag_counts = {}
with get_session() as session:
    stmt = sa.select(t_ap_tagging.c.tag_id, sa.func.count(t_ap_tagging.c.tag_id))
    stmt = stmt.group_by(t_ap_tagging.c.tag_id).order_by(sa.func.count(t_ap_tagging.c.tag_id).desc())
    # Run the aggregate once and reuse the rows for both passes instead of
    # sending the identical query to the database a second time.
    rows = session.execute(stmt).all()
    for theme in [True, False]:
        for row in rows:
            t_name, t_type = tags[row[0]]
            if (t_type == "thematic_areas") != theme:
                continue
            print(f"{row[1] / ap_pad_count * 100.0:.2f}% {row[1]} {t_name} ({t_type})")
            ap_tag_counts[t_name] = int(row[1])
        print()

In [None]:
# Usage count per tag in the experiments tagging table, most used first.
# Thematic areas are printed first, then all other tag types. The percentage
# is relative to exp_pad_count (pads with status >= 2).
exp_tag_counts = {}
with get_session() as session:
    stmt = sa.select(t_exp_tagging.c.tag_id, sa.func.count(t_exp_tagging.c.tag_id))
    stmt = stmt.group_by(t_exp_tagging.c.tag_id).order_by(sa.func.count(t_exp_tagging.c.tag_id).desc())
    # Run the aggregate once and reuse the rows for both passes instead of
    # sending the identical query to the database a second time.
    rows = session.execute(stmt).all()
    for theme in [True, False]:
        for row in rows:
            t_name, t_type = tags[row[0]]
            if (t_type == "thematic_areas") != theme:
                continue
            print(f"{row[1] / exp_pad_count * 100.0:.2f}% {row[1]} {t_name} ({t_type})")
            exp_tag_counts[t_name] = int(row[1])
        print()

In [None]:
# (tag id, tag name) pairs of every tag whose type is "thematic_areas"
all_tags = [
    (tag_id, tag_info[0])
    for tag_id, tag_info in tags.items()
    if tag_info[1] == "thematic_areas"
]
all_tags

In [None]:
# Seeded generator used for sampling the train/test indices.
rng = np.random.default_rng(seed=42)

In [None]:
# Draw train_size + test_size distinct 0-based row positions out of the
# sm_pad_count solution mapping pads; the first train_size of them form the
# train set and the remainder the test set.
train_size = 1000
test_size = 1000
train_test_ixs = list(
    rng.choice(sm_pad_count, size=train_size + test_size, replace=False))
is_train = set(train_test_ixs[:train_size])
is_test = set(train_test_ixs[train_size:])

In [None]:
# Assemble the output frame: one row per pad with status >= 2 from each of
# the three pad databases, plus one boolean column per thematic-area tag.
#
# Column-oriented buffers: one list per output column.
content = {
    "stage": [],
    "db": [],
    "title": [],
    "text": [],
}
ctags = {}      # tag id -> name of its boolean column
ctagnames = []  # boolean tag column names, in all_tags order
for (tag_id, tag_name) in all_tags:
    col_name = f"tag_{tag_name}"
    ctagnames.append(col_name)
    ctags[tag_id] = col_name
    content[col_name] = []
tables = [
    ("sm", t_sm_pads, t_sm_tagging),
    ("ap", t_ap_pads, t_ap_tagging),
    ("exp", t_exp_pads, t_exp_tagging),
]
with get_session() as session:
    for (t_name, t_pads, t_tagging) in tables:
        # Fetch all tag assignments of this database in one query instead of
        # issuing a separate query for every pad.
        pad_tag_ids = {}
        tstmt = sa.select(t_tagging.c.pad, t_tagging.c.tag_id)
        for (pad_id, pad_tag_id) in session.execute(tstmt):
            pad_tag_ids.setdefault(pad_id, []).append(pad_tag_id)

        stmt = sa.select(t_pads.c.id, t_pads.c.title, t_pads.c.full_text)
        stmt = stmt.where(t_pads.c.status >= 2)
        # Without an explicit order the database may return rows in any
        # order, which would make the seeded train/test split unstable.
        stmt = stmt.order_by(t_pads.c.id)
        # cur_ix is 0-based so that it lines up with the indices sampled from
        # range(sm_pad_count); a 1-based counter could never hit index 0 and
        # would make the last row unselectable.
        for cur_ix, row in enumerate(session.execute(stmt)):
            # Only "sm" pads are split into train/test; everything else,
            # including unsampled "sm" pads, is used for validation.
            if t_name == "sm" and cur_ix in is_train:
                stage = "train"
            elif t_name == "sm" and cur_ix in is_test:
                stage = "test"
            else:
                stage = "validation"
            content["stage"].append(stage)
            content["db"].append(t_name)
            content["title"].append(row[1])
            content["text"].append(row[2])
            # Tag ids that are not thematic areas have no column; skip them.
            assigned = {
                ctags[tid]
                for tid in pad_tag_ids.get(row[0], ())
                if tid in ctags
            }
            for cname in ctagnames:
                content[cname].append(cname in assigned)
df = pd.DataFrame(content, columns=["stage", "db", "title", "text"] + sorted(ctagnames))
# Columns that are constant (all True or all False) carry no information for
# a classifier; drop them and remember the remaining tag columns.
final_tags = []
for cname in ctagnames:
    if df[cname].all() or not df[cname].any():
        print(f"drop {cname}")
        del df[cname]
    else:
        final_tags.append(cname)

In [None]:
# Show the assembled frame (stage, db, title, text and tag columns) as the
# cell output.
df

In [None]:
# For every remaining tag column print how many train / test / validation
# rows carry the tag; tags with no test rows are underlined with carets.
for cname in final_tags:
    n_train, n_test, n_validation = (
        df.loc[df["stage"] == stage_name, cname].sum()
        for stage_name in ("train", "test", "validation")
    )
    text = f"{n_train} {n_test} {n_validation} {cname}"
    print(text)
    if not n_test:
        print("^" * len(text))

In [None]:
# Persist the assembled frame to OUT_FILE in parquet format.
df.to_parquet(OUT_FILE)