In [1]:
import os
import traceback
from typing import Union

import pandas as pd
from sqlalchemy import create_engine, inspect, func, select, Table, MetaData

In [2]:
import time
import jwt
import requests
from numpy import dot
from numpy.linalg import norm


# Zhipu AI API key in "<id>.<secret>" form (the format generate_token expects).
# It is read from the environment instead of being hardcoded: a literal key in a
# notebook leaks through every copy/commit. The key previously written here has
# been exposed and should be revoked/rotated.
# Set it before starting the kernel, e.g. `export ZHIPUAI_API_KEY=<id>.<secret>`.
KEY = os.environ.get('ZHIPUAI_API_KEY', '')


def generate_token(apikey: str, exp_seconds: int):
    """Build an HS256-signed JWT used as the Authorization header for the GLM API.

    Args:
        apikey: API key of the form "<id>.<secret>"; the id goes into the
            payload and the secret is the signing key.
        exp_seconds: Token lifetime in seconds.

    Returns:
        The encoded token, exactly as returned by `jwt.encode`.

    Raises:
        Exception: ('invalid apikey', original_error) when `apikey` does not
            split into exactly two parts on '.'.
    """
    try:
        api_key_id, secret = apikey.split('.')
    except Exception as e:
        raise Exception('invalid apikey', e) from e

    # Read the clock once so `exp` is exactly `exp_seconds` after `timestamp`;
    # both fields are expressed in milliseconds.
    now_ms = int(round(time.time() * 1000))
    payload = {
        'api_key': api_key_id,
        'exp': now_ms + exp_seconds * 1000,
        'timestamp': now_ms,
    }
    return jwt.encode(
        payload,
        secret,
        algorithm='HS256',
        headers={'alg': 'HS256', 'sign_type': 'SIGN'},
    )

In [3]:
def ask_glm(content, model="glm-3-turbo", timeout=60, token_ttl=1000):
    """Send one user message to the Zhipu AI chat-completions endpoint.

    Args:
        content: Text of the single user message.
        model: Model name placed in the request body.
        timeout: Seconds passed to `requests.post`. Without a timeout a
            stalled connection blocks the notebook kernel indefinitely.
        token_ttl: Lifetime in seconds of the JWT built by `generate_token`.

    Returns:
        The decoded JSON response body. The HTTP status is not checked, so an
        error payload from the API is returned as-is.
    """
    url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
    headers = {
        'Content-Type': 'application/json',
        'Authorization': generate_token(KEY, token_ttl),
    }
    data = {
        "model": model,
        "messages": [{"role": "user", "content": content}],
    }

    response = requests.post(url, headers=headers, json=data, timeout=timeout)
    return response.json()

In [23]:
from collections import defaultdict
from typing import Tuple
import pandas as pd

class DBParser:
    """Profile every table of a database and provide helpers to inspect/run SQL.

    On construction, for each table it records the inspector's column
    metadata plus per-column distinct/NULL counts, the table's foreign keys,
    and a random 5-row sample.
    """

    def __init__(self, url: str) -> None:
        self.db_type = 'sqlite'

        # Connect to the database; this connection is reused for the
        # profiling queries issued below.
        self.engine = create_engine(url, echo=False)
        self.connect = self.engine.connect()
        self.url = url

        self.inspector = inspect(self.engine)
        self.table_names = self.inspector.get_table_names()

        self._table_fields = defaultdict(dict)  # table -> column name -> column info + stats
        self._table_sample = dict()  # table -> random sample DataFrame
        self._foreign_keys = dict()  # table -> list of foreign-key dicts from the inspector

        for table_name in self.table_names:
            print('Table -> ', table_name)
            self._profile_table(table_name)

    def _profile_table(self, table_name: str) -> None:
        """Record column stats, foreign keys and a 5-row random sample for one table."""
        table_instance = Table(table_name, MetaData(), autoload_with=self.engine)
        table_columns = self.inspector.get_columns(table_name)

        for column in table_columns:
            distinct_count, nan_count = self._column_stats(table_instance, column['name'])
            self._table_fields[table_name][column['name']] = column
            # The 'dictinct' key (sic) is kept as-is: it is the row label that
            # callers of get_table_fields() already see.
            column['dictinct'] = distinct_count
            column['nan_count'] = nan_count

        self._foreign_keys[table_name] = self.inspector.get_foreign_keys(table_name)

        # func.random() renders as RANDOM(), which SQLite supports.
        query = select(table_instance).order_by(func.random()).limit(5)
        self._table_sample[table_name] = pd.DataFrame(
            self.connect.execute(query).fetchall(),
            columns=[column['name'] for column in table_columns],
        )

    def _column_stats(self, table_instance, column_name: str):
        """Return (number of distinct non-NULL values, number of NULL rows) for a column."""
        # Index the collection instead of getattr(): a column named e.g. 'keys'
        # would otherwise resolve to a ColumnCollection method.
        column_data = table_instance.c[column_name]
        # COUNT(*) - COUNT(col) counts NULL rows. Note that a Python-level
        # `column_data is None` is an identity test (always False), not SQL IS NULL.
        query = select(
            func.count(func.distinct(column_data)),
            func.count() - func.count(column_data),
        ).select_from(table_instance)
        distinct_count, nan_count = self.connect.execute(query).fetchone()
        return distinct_count, nan_count

    def get_table_fields(self, table_name: str) -> pd.DataFrame:
        """Return the recorded field info: one column per field, one row per attribute."""
        return pd.DataFrame.from_dict(self._table_fields[table_name])

    def get_data_relations(self, table_name: str) -> pd.DataFrame:
        """Return the table's foreign keys, one row each (empty frame if none)."""
        return pd.DataFrame(self._foreign_keys.get(table_name, []))

    def get_table_sample(self, table_name: str) -> pd.DataFrame:
        """Return the random sample rows captured at construction time."""
        return self._table_sample[table_name]

    @staticmethod
    def _run(conn, sql):
        """Execute `sql` on `conn`: raw strings go to the DBAPI verbatim, other statements via execute()."""
        if isinstance(sql, str):
            return conn.exec_driver_sql(sql)
        return conn.execute(sql)

    def check_sql(self, sql) -> Tuple[bool, str]:
        """Try to execute `sql` and report whether it succeeded.

        The statement runs inside an explicit transaction that is always rolled
        back, so checking a statement never persists its side effects.

        Returns:
            (True, 'ok') on success, otherwise (False, formatted traceback).
        """
        try:
            with self.engine.connect() as conn:
                trans = conn.begin()
                try:
                    self._run(conn, sql)
                finally:
                    trans.rollback()
            return True, 'ok'
        except Exception:
            return False, traceback.format_exc()

    def execute_sql(self, sql):
        """Execute `sql` in a transaction that commits on success; return all result rows as a list."""
        with self.engine.begin() as conn:
            result = self._run(conn, sql)
            # Non-row-returning statements (DML/DDL) cannot be iterated.
            return list(result) if result.returns_rows else []

                



In [24]:
# Location of the competition SQLite database. Set BS_CHALLENGE_DB_PATH to
# point elsewhere instead of editing this machine-specific absolute path.
path = os.environ.get(
    'BS_CHALLENGE_DB_PATH',
    'D:/yyk/competition/bs_challenge_financial_14b_dataset/dataset/博金杯比赛数据.db',
)
parser = DBParser(f'sqlite:///{path}')

Table ->  A股公司行业划分表
Table ->  A股票日行情表
Table ->  基金份额持有人结构
Table ->  基金债券持仓明细
Table ->  基金可转债持仓明细
Table ->  基金基本信息
Table ->  基金日行情表
Table ->  基金股票持仓明细
Table ->  基金规模变动表
Table ->  港股票日行情表


In [25]:
# Preview a few random rows of the A-share daily quotes table.
parser.get_table_sample(table_name='A股票日行情表')

Unnamed: 0,股票代码,交易日,昨收盘(元),今开盘(元),最高价(元),最低价(元),收盘价(元),成交量(股),成交金额(元)
0,27,20210506,8.47,8.58,8.94,8.5,8.79,127112422.0,1111651000.0
1,603879,20190701,8.02,8.08,8.38,8.08,8.25,2939315.0,24168350.0
2,601101,20200226,4.11,4.07,4.25,4.05,4.19,13437030.0,56178200.0
3,688228,20210517,34.77,34.55,34.62,34.0,34.24,238386.0,8163225.0
4,23,20201022,17.17,17.2,17.51,16.89,17.29,647600.0,11149770.0


In [26]:
# Show per-field metadata and stats for the fund size change table.
parser.get_table_fields(table_name='基金规模变动表')

Unnamed: 0,基金代码,基金简称,公告日期,截止日期,报告期期初基金总份额,报告期基金总申购份额,报告期基金总赎回份额,报告期期末基金总份额,定期报告所属年度,报告类型
name,基金代码,基金简称,公告日期,截止日期,报告期期初基金总份额,报告期基金总申购份额,报告期基金总赎回份额,报告期期末基金总份额,定期报告所属年度,报告类型
type,TEXT,TEXT,TIMESTAMP,TIMESTAMP,REAL,REAL,REAL,REAL,INTEGER,TEXT
nullable,True,True,True,True,True,True,True,True,True,True
default,,,,,,,,,,
primary_key,0,0,0,0,0,0,0,0,0,0
dictinct,4340,4340,104,12,29077,27003,25739,29898,3,1
nan_count,0,0,0,0,0,0,0,0,0,0
