<a href="https://colab.research.google.com/github/Kentaro-Kamaishi/AICompositionSupport/blob/main/Pitman_Yor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 自己流 

In [1]:
# params
HPYLM_INITIAL_D = 0.5
HPYLM_INITIAL_THETA = 2.0
HPYLM_INITIAL_A = 1.0
HPYLM_INITIAL_B = 1.0
HPYLM_INITIAL_ALPHA = 1.0
HPYLM_INITIAL_BETA = 1.0

VPYLM_BETA_STOP = 4
VPYLM_BETA_PASS = 1

ID_BOS = 0
ID_EOS = 1 

In [95]:
import numpy as np

# PSTのノード：文脈uに対応する
# word:customer, 
class Node:
    def __init__(self, id):
        self.parent = None
        self.depth = 0
        self.stop_count = 0
        self.pass_count = 0
        self.children = {}
        self.arrangement = {} #key:id, value:list #cus per table
        self.num_customers = 0 #nodeが表すhistoryがgivenのときの総customer数
        self.num_tables = 0 #nodeが表すhistoryがgivenのときの総table数
        self.token_id = id #このノードが表す単語のid

    """
    d_m: d per depth, theta_m: theta per depth 
    g0: 基底分布=語彙の一様分布
    """
    def Pw_given_h(self, w_id, g0, d_m, theta_m):
        self.init_hyperparams_at_depth_if_needed(self.depth, d_m, theta_m)
        d_u = d_m[self.depth]
        theta_u = theta_m[self.depth]
        t_u = self.num_tables
        c_u = self.num_customers
        # 写経メモ：c++における関数findはkeyのポインタを返す keyがない場合はendが返ってくる 
        if w_id not in self.arrangement.keys():
            coeff = (theta_u + d_u*t_u)/(theta_u+c_u)
            if self.parent != None: # depth > 1の場合
                return self.parent.Pw_given_h(w_id, g0, d_m, theta_m) * coeff
            return g0*coeff # depth == 1
        
        # calculate P(w|h')
        # w_idに対応するアトムから再びw_idが生成される確率と基底分布からw_idが生成される確率を足す
        parent_Pw = g0
        if self.parent != None:
            parent_Pw = self.parent.Pw_given_h(w_id, g0, d_m, theta_m)
        # 写経メモ: c++におけるsecondは辞書のvalueを表す
        num_customers_at_table = self.arrangement[w_id]
        #print(num_customers_at_table)
        c_uw = sum(num_customers_at_table)
        t_uw = len(num_customers_at_table)
        already = max(0, c_uw-d_u*t_uw)/(theta_u+c_u)
        new = (theta_u+d_u*t_u)/(theta_u+c_u)
        return already + new * parent_Pw
    
    # 疑問: Pw_given_hと統合しないのはなぜ？ オプション引数を追加すればできそう
    def Pw_given_h_with_parent_Pw(self, token_id, parent_Pw, d_m, theta_m):
        self.init_hyperparams_at_depth_if_needed(self.depth, d_m, theta_m)
        d_u = d_m[self.depth]
        theta_u = theta_m[self.depth]
        t_u = self.num_tables
        c_u = self.num_customers

        if token_id not in self.arrangement.keys():
            coeff = (theta_u + d_u * t_u)/(theta_u+c_u)
            return parent_Pw * coeff
        
        num_customers_at_table = self.arrangement[token_id]
        c_uw = sum(num_customers_at_table)
        t_uw = len(num_customers_at_table)
        already = max(0, c_uw-d_u*t_uw)/(theta_u+c_u)
        new = (theta_u+d_u*t_u)/(theta_u+c_u)
        return already + new * parent_Pw
    
    ### ?: what does the return value mean?
    def add_customer(self, w_id, g0, d_m, theta_m, update_beta_count=True):
        self.init_hyperparams_at_depth_if_needed(self.depth, d_m, theta_m)
        d_u = d_m[self.depth]
        theta_u = theta_m[self.depth]
        parent_Pw = g0
        if self.parent != None:
            parent_Pw = self.parent.Pw_given_h(w_id, g0, d_m, theta_m)
        
        if w_id not in self.arrangement.keys():
            self.add_customer_to_new_table(w_id, g0, d_m, theta_m)
            if update_beta_count:
                self.increment_stop_count()
            return True
        
        num_customers_at_table = self.arrangement[w_id]
        # sum: 正規化項 sum(c-d)+(theta+dt)の計算
        sum = 0
        for k in range(len(num_customers_at_table)):
            sum += max(0, num_customers_at_table[k] - d_u)
        t_u = self.num_tables
        sum += (theta_u + d_u * t_u) * parent_Pw

        normalizer = 1/sum
        bernoulli = np.random.uniform(0,1)
        stack = 0
        # max(0, c-d)の比率で既存のテーブルに追加, theta+dtの比率で新規テーブルに追加
        # ※既存のテーブルとは、w_idにとっての既存のテーブルであり、全ての単語に対して初出とは限らない
        for k in range(len(num_customers_at_table)):
            # t[k] = max(0, c[k-1]-d)とすると、
            # t[k]<= bernoulli <= t[k]+max(0, c[k]-d)のときに追加する
            stack += max(0, num_customers_at_table[k] - d_u) * normalizer
            if bernoulli <= stack:
                self.add_customer_to_table(w_id, k, g0, d_m, theta_m)
                if update_beta_count:
                    self.increment_stop_count()
                return True
 
        self.add_customer_to_new_table(w_id, g0, d_m, theta_m)
        if update_beta_count:
            self.increment_stop_count()
        return True

    def remove_customer(self, w_id, update_beta_count=True):
        if w_id not in self.arrangement.keys():
            return True
        c_uw = self.arrangement[w_id]

        bernoulli = np.random.uniform(0, 1) #なぜ一様分布なのにbernoulliなのか
        normalizer = 1/sum(c_uw)
        stack = 0
        for k in range(len(c_uw)):
            stack += c_uw[k]*normalizer
            if bernoulli <= stack:
                self.remove_customer_from_table(w_id, k)
                if update_beta_count:
                    self.decrement_stop_count()
                return True
        self.remove_customer_from_table(w_id, len(c_uw)-1)
        if update_beta_count:
            self.decrement_stop_count()
        return True

    def add_customer_to_table(self, w_id, k, g0, d_m, theta_m):
        # 疑問: この関数を呼んでいる時点でtable kは既存なのでは？
        if w_id not in self.arrangement.keys():
            return self.add_customer_to_new_table(w_id, g0, d_m, theta_m)
        self.arrangement[w_id][k] += 1
        self.num_customers += 1
        return True

    def add_customer_to_new_table(self, w_id, g0, d_m, theta_m):
        if w_id not in self.arrangement.keys():
            self.arrangement[w_id] = [1]
        else:
            self.arrangement[w_id].append(1)
        self.num_tables += 1
        self.num_customers += 1
        if self.parent != None:
            self.parent.add_customer(w_id, g0, d_m, theta_m, False)
        return True
    
    def remove_customer_from_table(self, w_id, k):
        c_uw = self.arrangement[w_id]
        c_uw[k] -= 1
        self.num_customers -= 1
        if c_uw[k] == 0:
            if self.parent != None:
                self.parent.remove_customer(w_id, False)
            c_uw.pop(k)
            self.num_tables -= 1
            if len(c_uw) == 0:
                self.arrangement.pop(w_id)
        return True

    # 自分のstop countを1増やした後、先祖のpass countも1増やす
    # 再帰は重いので、とりあえず先祖のpass++はループで
    def increment_stop_count(self):
        self.stop_count += 1
        parent = self.parent
        while parent != None:
            parent.pass_count += 1
            parent = parent.parent
        return True
    
    # 自分のstop countを1減らした後、先祖のpass countも1減らす
    def decrement_stop_count(self):
        self.stop_count -= 1
        parent = self.parent
        while parent != None:
            parent.pass_count -= 1
            parent = parent.parent
        return True
    
    def find_child_node(self, token_id, generate_if_not_exist=False):
        if token_id in self.children: # 子リストにidがある
            return self.children[token_id]
        # 以下は子リストにidがない場合
        if not generate_if_not_exist: # 子を新規作成しない場合
            return None
        # 新規作成する場合
        child = Node(token_id)
        child.parent = self # これで書き方いいのかな？ thisのつもり
        child.depth = self.depth+1
        self.children[token_id] = child
        return child

    # 深さdepthまでのhpを初期化
    def init_hyperparams_at_depth_if_needed(self, depth, d_m, theta_m):
        if depth >= len(d_m):
            while depth >= len(d_m):
                d_m.append(HPYLM_INITIAL_D)
            while depth >= len(theta_m): # dと分ける必要ある？
                theta_m.append(HPYLM_INITIAL_THETA)
    
    def need_to_remove_from_parent(self):
        if self.parent is None:
            return False
        if len(self.children) == 0 and len(self.arrangement) == 0: # 子もテーブルもない
            return True
        return False
    
    def remove_from_parent(self):
        if self.parent is None:
            return False
        self.parent.delete_child_node(self.token_id)
        return True
    
    def delete_child_node(self, token_id):
        child = self.find_child_node(token_id)
        if child: 
            del self.children[token_id]
        #if self.need_to_remove_from_parent(): # シンプルなif文から変更
        if len(self.children) == 0 and len(self.arrangement) == 0:
            self.remove_from_parent()
    
    def stop_probability(self, beta_stop, beta_pass, recursive=True):
        p = (self.stop_count + beta_stop) / (self.stop_count + self.pass_count + beta_stop + beta_pass)
        if not recursive:
            return p
        # このノードでbeta_stop回止まった＝先祖でbeta_stop回通った
        if not self.parent is None:
            p *= self.parent.pass_probability(beta_stop, beta_pass)
        return p
    
    def pass_probability(self, beta_stop, beta_pass, recursive=True):
        p = (self.pass_count + beta_pass) / (self.stop_count + self.pass_count + beta_stop + beta_pass)
        if not recursive:
            return p
        # このノードでbeta_stop回止まった＝先祖でbeta_stop回通った
        if not self.parent is None:
            p *= self.parent.pass_probability(beta_stop, beta_pass)
        return p
    
    def auxiliary_log_x_u(self, theta_u):
        if self.num_customers >= 2:
            x_u = np.random.beta(theta_u+1, self.num_customers-1)
            return np.log(x_u+1e-8)
        return 0
    
    def auxiliary_y_ui(self, theta_u, d_u):
        sum_y_ui = 0
        if self.num_tables >= 2:
            for i in range(1, self.num_tables):
                y_ui = np.random.binomial(1, theta_u/(theta_u+d_u*i))
                sum_y_ui += y_ui
        return sum_y_ui
        
    def auxiliary_1_y_ui(self, theta_u, d_u): #1-y_ui
        sum_1_y_ui = 0
        if self.num_tables >= 2:
            for i in range(1, self.num_tables):
                y_ui = np.random.binomial(1, theta_u/(theta_u+d_u*i))
                sum_1_y_ui += 1-y_ui
        return sum_1_y_ui
    
    def auxiliary_1_z_uwkj(self, d_u):
        sum_1_z_uwkj = 0
        for c_uw in self.arrangement.values():
            for c_uwk in c_uw:
                if c_uwk >= 2:
                    for j in range(1, c_uwk):
                        z_uwkj = np.random.binomial(1, (j-1)/(j-d_u))
                        sum_1_z_uwkj += 1-z_uwkj
        return sum_1_z_uwkj
    
    ### モニター用
    def get_num_nodes(self):
        # 自分は数えない
        num = len(self.children)
        for child in self.children.values():
            num += child.get_num_nodes()
        return num
    
    def get_num_customers(self):
        num = 0
        for table in self.arrangement.values():
            num += sum(table)
        for child in self.children.values():
            num += child.get_num_customers()
        return num

In [68]:
# PST
import math
class VPYLM:
    def __init__(self):
        # w_id=0が空文字/空語？
        self.root = Node(0)
        self.depth = 0 # 最深部の深さ
        self.d_m = []
        self.theta_m = []
        self.g0 = 0

        self.a_m = [] # ベータ分布のパラメータ dの推定用
        self.b_m = [] # ベータ分布のパラメータ dの推定用
        self.alpha_m = [] # ガンマ分布のパラメータ thetaの推定用
        self.beta_m = [] # ガンマ分布のパラメータ thetaの推定用
        self.beta_stop = VPYLM_BETA_STOP
        self.beta_pass = VPYLM_BETA_PASS
        self.max_depth = 999
        self.sampling_table = [0]*self.max_depth
    
    # _VPYLM

    def delete_node(self, node):
        for child in node.children.values():
            self.delete_node(child)
        del node

    # idがtoken_ids[token_t_index]のcustomerを追加する
    def add_customer_at_timestep(self, token_ids, token_t_index, depth_t):
        node = self.find_node_by_tracing_back_context(token_ids, token_t_index, 
                                                      depth_t, True)
        token_t = token_ids[token_t_index]
        return node.add_customer(token_t, self.g0, self.d_m, self.theta_m)
    
    def remove_customer_at_timestep(self, token_ids, token_t_index, depth_t):
        node = self.find_node_by_tracing_back_context(token_ids, token_t_index,
                                                      depth_t, True)
        token_t = token_ids[token_t_index]
        node.remove_customer(token_t)
        if node.need_to_remove_from_parent():
            node.remove_from_parent()
        return True

    def sample_depth_at_timestep(self, context_token_ids, token_t_index):
        if token_t_index == 0:
            return 0
        token_t = context_token_ids[token_t_index]

        eps = 1e-24
        sum = 0
        p_pass = 1
        pw = 0
        sampling_table_size = 0
        node = self.root

        # sampling_table[l]=p(n_t=l | w, n-, s-) 
        for n in range(0, token_t_index+1):
            if node:
                pw = node.Pw_given_h(token_t, self.g0, self.d_m, self.theta_m)
                p_stop = node.stop_probability(self.beta_stop, self.beta_pass, False) * p_pass
                p = pw * p_stop # p(n_t|w, n-, s-)=p(w_t|w-, n, s-)p(n_t|w-, n-, s-)
                p_pass *= node.pass_probability(self.beta_stop, self.beta_pass, False)
                self.sampling_table[n] = p
                sampling_table_size += 1
                sum += p
                if p_stop < eps:
                    break
                if n < token_t_index: # token_t_index以外のときは、履歴内の一つ前の語に対応するノードにする
                    context_token_id = context_token_ids[token_t_index - n - 1]
                    node = node.find_child_node(context_token_id)
            else: # 後続の子がない = token_t_indexの履歴としてここまでしか考えられない（?）
                p_stop = p_pass * self.beta_stop/(self.beta_stop + self.beta_pass)
                p = pw * p_stop
                self.sampling_table[n] = p
                sampling_table_size += 1
                sum += p
                p_pass *= self.beta_pass / (self.beta_stop + self.beta_pass)
                if p_stop < eps:
                    break

        normalizer = 1 / sum
        bernoulli = np.random.uniform(0, 1)
        stack = 0
        for n in range(sampling_table_size):
            stack += self.sampling_table[n] * normalizer
            if bernoulli < stack:
                return n
        return self.sampling_table[sampling_table_size-1]

    # token_ids: 学習データの部分トークンid列
    # token_t_index: 探索対象のトークンid
    # order_t: 文脈長
    def find_node_by_tracing_back_context(self, token_ids, token_t_index, order_t, 
                        generate_node_if_needed=False, return_middle_node=False):
        if token_t_index - order_t < 0:
            return None
        
        node = self.root
        # token_t_indexに近い方がルートに近くなることに注意
        for depth in range(1, order_t+1):
            context_token_id = token_ids[token_t_index-depth]
            child = node.find_child_node(context_token_id, generate_node_if_needed)
            if child is None:
                if return_middle_node:
                    return node
                return None
            node = child
        return node
    
    def Pw_given_h(self, token_id, context_token_ids):
        node = self.root
        eps = 1e-24
        parent_Pw = self.g0
        pw_h = 0
        p_pass = 1
        p_stop = 1
        depth = 0

        while p_stop > eps:
            if node:
                # pwはp(w|n,h)
                pw = node.Pw_given_h_with_parent_Pw(token_id, parent_Pw, self.d_m, self.theta_m)
                p_stop = p_pass * node.stop_probability(self.beta_stop, self.beta_pass, False)
                p_pass *= node.pass_probability(self.beta_stop, self.beta_pass, False)
                pw_h += pw * p_stop
                parent_Pw = pw
                if depth < len(context_token_ids):
                    context_token_id = context_token_ids[len(context_token_ids)-depth-1]
                    node = node.find_child_node(context_token_id)
                else: # 文脈がcontext_idsにはこれ以上ない
                    node = None
            else: # nodeがNoneの場合
                p_stop = p_pass * self.beta_stop/(self.beta_stop + self.beta_pass)
                pw_h += parent_Pw * p_stop # 疑問: 基底分布はそのまま使い続けていいのか？ よさそうだけど絶対これじゃなきゃダメなのかな？
                p_pass *= self.beta_pass / (self.beta_stop + self.beta_pass)
            depth += 1 
        
        return pw_h
    
    # 深さ優先探索で深さ1以上の全てのノードを探索 auxiliary_termsを計算
    def sum_auxiliary_terms(self, node, sum_log_x_u_m, sum_y_ui_m, sum_1_y_ui_m, sum_1_z_uwkj_m):
        for child in node.children.values():
            depth = child.depth
            d = self.d_m[depth]
            theta = self.theta_m[depth]

            #print(f'depth={depth}')
            sum_log_x_u_m[depth] += child.auxiliary_log_x_u(theta)
            sum_y_ui_m[depth] += child.auxiliary_y_ui(d, theta)
            sum_1_y_ui_m[depth] += child.auxiliary_1_y_ui(d, theta)
            sum_1_z_uwkj_m[depth] += child.auxiliary_1_z_uwkj(d)

            self.sum_auxiliary_terms(child, sum_log_x_u_m, sum_y_ui_m, sum_1_y_ui_m, sum_1_z_uwkj_m)

    def sample_hyperparams(self):
        max_depth = self.get_depth()
        #print(f'max_depth={max_depth}')
        sum_log_x_u_m = [0]*(1+max_depth) #+1はルートの分
        sum_y_ui_m = [0]*(1+max_depth)
        sum_1_y_ui_m = [0]*(1+max_depth)
        sum_1_z_uwkj_m = [0]*(1+max_depth)

        # root
        sum_log_x_u_m[0] = self.root.auxiliary_log_x_u(self.theta_m[0])
        sum_y_ui_m[0] = self.root.auxiliary_y_ui(self.d_m[0], self.theta_m[0])
        sum_1_y_ui_m[0] = self.root.auxiliary_1_y_ui(self.d_m[0], self.theta_m[0])
        sum_1_z_uwkj_m[0] = self.root.auxiliary_1_z_uwkj(self.d_m[0])
        
        # 深さ1以上
        self.init_hyperparameters_at_depth_if_needed(max_depth)
        self.sum_auxiliary_terms(self.root, sum_log_x_u_m, sum_y_ui_m, \
                                 sum_1_y_ui_m, sum_1_z_uwkj_m)
        
        for m in range(max_depth+1):
            self.d_m[m] = np.random.beta(self.a_m[m]+sum_1_y_ui_m[m], self.b_m[m]+sum_1_z_uwkj_m[m])
            # Teh論文はパラメータを(shape, rate)とする方のGamma distributionの定義を採用していることに注意
            # np.random.gammaはパラメータを(shape, scale)とする方の定義を採用している
            # scale = 1/rateなので、rateをscaleに変換してから np.random.gamma に渡す
            self.theta_m[m] = np.random.gamma(self.alpha_m[m]+sum_y_ui_m[m], 1/(self.b_m[m]-sum_log_x_u_m[m]))
        print(f'd_m:{self.d_m}, theta_m:{self.theta_m}')
    
    # depth怪しい
    def get_depth(self):
        return self.update_max_depth(self.root, 0)
    
    def update_max_depth(self, node, max_depth):
        max_depth = max(max_depth, node.depth)
        #print(f'node:{node.token_id}, depth:{node.depth}, children:{node.children}')
        for child in node.children.values():
            max_depth = max(max_depth, self.update_max_depth(child, max_depth))
        return max_depth
    
    def init_hyperparameters_at_depth_if_needed(self, depth):
        if depth >= len(self.d_m):
            while depth >= len(self.d_m):
                self.d_m.append(HPYLM_INITIAL_D)
        if depth >= len(self.theta_m):
            while depth >= len(self.theta_m):
                self.theta_m.append(HPYLM_INITIAL_THETA)
        if depth >= len(self.a_m):
            while depth >= len(self.a_m):
                self.a_m.append(HPYLM_INITIAL_A)
        if depth >= len(self.b_m):
            while depth >= len(self.b_m):
                self.b_m.append(HPYLM_INITIAL_B)
        if depth >= len(self.alpha_m):
            while depth >= len(self.alpha_m):
                self.alpha_m.append(HPYLM_INITIAL_ALPHA)
        if depth >= len(self.beta_m):
            while depth >= len(self.beta_m):
                self.beta_m.append(HPYLM_INITIAL_BETA)
    
    ### モニター用
    def get_num_nodes(self):
        return self.root.get_num_nodes()
    
    def get_num_customers(self):
        return self.root.get_num_customers()
    
    def Pw_log2(self, token_ids):
        sum_pw_h = 0
        for t in range(1, len(token_ids)):
            token_id = token_ids[t]
            pw_h = self.Pw_given_h(token_id, token_ids[:t])
            sum_pw_h += math.log2(pw_h)
        return sum_pw_h
    
    def log_Pw(self, token_ids):
        sum_pw_h = 0
        context_token_ids = token_ids[0:1]
        for token_id in token_ids:
            pw_h = self.Pw_given_h(token_id, context_token_ids)
            sum_pw_h += math.log(pw_h)
            context_token_ids.append(token_id)
        return sum_pw_h

In [69]:
class Vocab:
    def __init__(self):
        self.token_ids = {ID_BOS, ID_EOS}
        self.string_by_token_id = {ID_BOS:'<bos>', ID_EOS:'<eos>'}
    
    def add_string(self, s):
        token_id = hash(s)
        self.string_by_token_id[token_id] = s
        self.token_ids.add(token_id)
        return token_id

In [70]:
import random, pickle, datetime
class Model:
    def __init__(self):
        self.vpylm = VPYLM()
        self.vocab = Vocab()
        self.dataset_train = []
        self.dataset_test = []
        self.prev_depths_for_data = []
        self.rand_indices = []

        self.word_count = {}
        self.sum_word_count = 0
        self.gibbs_first_addition = True
        self.vpylm_loaded = False
        self.is_ready = False
        self.time = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=9)))
    
    # テキストファイルをデータセットに追加
    def load_textfile(self, filename, train_split_ratio):
        with open(filename, 'r') as f:
            lines = f.read().splitlines()
        train_split = int(len(lines)*train_split_ratio)
        self.rand_indices = list(range(len(lines)))
        random.shuffle(self.rand_indices)
        for i in range(len(lines)):
            sentence = lines[self.rand_indices[i]]
            if i<train_split:
                self.add_data_to(sentence, self.dataset_train)
            else:
                self.add_data_to(sentence, self.dataset_test)

    # 文をデータセットに追加
    def add_data_to(self, sentence, dataset):
        words = sentence.split(' ')
        if len(words)>0:
            token_ids = []
            token_ids.append(ID_BOS)
            for word in words:
                if len(word) == 0:
                    continue
                token_id = self.vocab.add_string(word)
                token_ids.append(token_id)
                if token_id in self.word_count.keys():
                    self.word_count[token_id] += 1
                else:
                    self.word_count[token_id] = 1
                self.sum_word_count += 1
            token_ids.append(ID_EOS)
            dataset.append(token_ids)

    def compile(self):
        for data_index in range(len(self.dataset_train)):
            token_ids = self.dataset_train[data_index]
            prev_depths = [-1]*(len(token_ids))
            self.prev_depths_for_data.append(prev_depths)
        self.is_ready = True

    def get_num_vocab(self):
        return len(self.word_count)

    def set_g0(self, g0):
        self.vpylm.g0 = g0

    def perform_gibbs_sampling(self):
        if len(self.rand_indices) != len(self.dataset_train):
            self.rand_indices = list(range(len(self.dataset_train)))
        random.shuffle(self.rand_indices)

        # dataset: sentenceの集合?
        for data_index in self.rand_indices:
            token_ids = self.dataset_train[data_index]
            prev_depths = self.prev_depths_for_data[data_index]

            for token_t_index in range(1, len(token_ids)):
                if not self.gibbs_first_addition:
                    prev_depth = prev_depths[token_t_index]
                    self.vpylm.remove_customer_at_timestep(token_ids, token_t_index, prev_depth)
                new_depth = self.vpylm.sample_depth_at_timestep(token_ids, token_t_index)
                self.vpylm.add_customer_at_timestep(token_ids, token_t_index, new_depth)
                prev_depths[token_t_index] = new_depth

        self.gibbs_first_addition = False
    
    def sample_hyperparams(self):
        self.vpylm.sample_hyperparams()
    
    def get_time(self):
        return self.time

    ### 以下はモニター用の関数
    def get_num_nodes(self):
        return self.vpylm.get_num_nodes()
    
    def get_num_customers(self):
        return self.vpylm.get_num_customers()
    
    def get_vpylm_depth(self):
        return self.vpylm.get_depth()
    
    def perplexity_train(self):
        return self.perplexity(self.dataset_train)
    
    def perplexity_test(self):
        return self.perplexity(self.dataset_test)
    
    def perplexity(self, dataset):
        log_P = 0
        for token_ids in dataset:
            log_P += self.vpylm.Pw_log2(token_ids) / len(token_ids) # なぜ割るのか？
        return 2.0 ** (-log_P/len(dataset))
    
    def log_P_train(self):
        return self.log_P(self.dataset_train)

    def log_P_test(self):
        return self.log_P(self.dataset_test)    

    def log_P(self, dataset):
        log_P = 0
        for token_ids in dataset:
            log_P += self.vpylm.log_Pw(token_ids)
        return log_P

    def save(self, out_dir):
        with open(f'{out_dir}/vocab_{self.time}.pickle', mode='wb') as f:
            pickle.dump(self.vocab, f)
        with open(f'{out_dir}/vpylm_{self.time}.pickle', mode='wb') as f:
            pickle.dump(self.vpylm, f)

In [16]:
# ドライブのマウント
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [7]:
%cd '/content/drive/MyDrive/KentaroKamaishi/Pitman-Yor' 

/content/drive/MyDrive/KentaroKamaishi/Pitman-Yor


In [8]:
filename = './PY_dataset/input0.txt'
out_dir = 'out_input0'

In [96]:
# 学習
model = Model()
model.load_textfile(filename, 0.9)
model.set_g0(1.0/model.get_num_vocab())
model.compile()
model.get_num_vocab()

9

In [56]:
import os
try:
    os.mkdir(out_dir)
except:
    pass

In [97]:
import time
import pandas as pd

num_epoch = 500
csv_perplexity = []
csv = {'epoch':[], 'elapsed_time':[], 'max_depth':[], 'num_nodes':[], 'num_customers':[]}

for epoch in range(num_epoch):
    start = time.time()
    model.perform_gibbs_sampling()
    model.sample_hyperparams()

    elapsed_time =time.time() - start
    depth = model.get_vpylm_depth()
    num_nodes = model.get_num_nodes()
    num_cs = model.get_num_customers()

    print(f'epoch:{epoch}/{num_epoch}, elapsed_time:{elapsed_time}, \
    depth:{depth}, num_nodes:{num_nodes}, num_cs:{num_cs}')

    csv["epoch"].append(epoch)
    csv["elapsed_time"].append(elapsed_time)
    csv["max_depth"].append(depth)
    csv["num_nodes"].append(num_nodes)
    csv["num_customers"].append(num_cs)

    if epoch % 10 == 0:
        perp_train = model.perplexity_train()
        perp_test = model.perplexity_test()
        print(f'log_likelihood:{model.log_P_train()}, {model.log_P_test()}')
        print(f'perplexity_train:{perp_train}, perplexity_test:{perp_test}')
        model.save(out_dir)
        csv_perplexity.append([epoch, perp_train, perp_test])

        data = pd.DataFrame(csv_perplexity)
        data.columns = ["epoch", "perplexity_train", "perplexity_test"]
        data.to_csv(f'{out_dir}/perplexity_{model.get_time()}.csv')
        data = pd.DataFrame(csv)
        data.to_csv(f'{out_dir}/basic_{model.get_time()}.csv')

d_m:[0.6318794971437558, 0.9000050462231192, 0.3060359353811652], theta_m:[0.3274652233883894, 0.8189433492791626, 0.3646441083943356]
epoch:0/500, elapsed_time:0.01435399055480957,     depth:2, num_nodes:15, num_cs:172
log_likelihood:-363.7169798871716, -52.86550804620016
perplexity_train:4.69443579805813, perplexity_test:4.875025881291437
d_m:[0.713401902133892, 0.7915741562526684, 0.5393791993410609, 0.6834341127679481, 0.4154731552387504], theta_m:[0.9179673223826831, 0.8569660305344691, 0.6543801190984198, 1.32453882699216, 1.9975571878537481]
epoch:1/500, elapsed_time:0.009035587310791016,     depth:4, num_nodes:15, num_cs:184
d_m:[0.5879124189015824, 0.6475607151539037, 0.4729983666164141, 0.6834341127679481, 0.4154731552387504], theta_m:[0.8061288083448773, 0.5159278906142785, 0.007910461998078261, 1.32453882699216, 1.9975571878537481]
epoch:2/500, elapsed_time:0.009076595306396484,     depth:2, num_nodes:15, num_cs:197
d_m:[0.6323139997951098, 0.509354593251587, 0.189252421194

In [None]:
import pickle
with open('vocab.pickle', mode='wb') as f:
    pickle.dump(model.vocab, f)
with open('vpylm.pickle', mode='wb') as f:
    pickle.dump(model.vpylm, f)

In [73]:
def sample_next_token(context_token_ids, all_token_ids):
    word_ids = []
    probs = []
    sum = 0
    max_p = 0
    max_id = ID_BOS

    for token_id in all_token_ids:
        if token_id == ID_BOS:
            continue
        pw_h = model.vpylm.Pw_given_h(token_id, context_token_ids)
        if pw_h > 0:
            word_ids.append(token_id)
            probs.append(pw_h)
            sum += pw_h
            if pw_h > max_p:
                max_p = pw_h
                max_id = token_id

    if len(word_ids) == 0:
        return ID_EOS
    if sum == 0:
        return ID_EOS
    
    return max_id

    """
    normalizer = 1/sum
    bernoulli = np.random.uniform(0,1)
    stack = 0
    for i in range(len(word_ids)):
        stack += probs[i] * normalizer
        if stack >= bernoulli:
            return word_ids[i]
    return word_ids[-1]
    """

In [74]:
def token_ids_to_sentence(token_ids):
    out = ""
    for token_id in token_ids:
        out += model.vocab.string_by_token_id[token_id] + " "
    return out

def generate_sentence(x):
    start_i = 1
    if x is None:
        context_token_ids = [ID_BOS]
    else:
        context_token_ids = x
        start_i = len(x)
    for n in range(100):
        next_id = sample_next_token(context_token_ids, model.vocab.token_ids)
        if next_id == ID_EOS:
            return token_ids_to_sentence(context_token_ids[start_i:])
        context_token_ids.append(next_id)
    return token_ids_to_sentence(context_token_ids[start_i:])

In [75]:
model.vocab.string_by_token_id

{0: '<bos>',
 1: '<eos>',
 6661234110526308594: '000',
 -7219619411728825083: '001',
 -486668966899472649: '.',
 5260776809832962493: '101',
 6914546132773419944: '110',
 -2887652700584036013: '010',
 7191366621530416125: '111',
 7125477652977117189: '100',
 7105728146654439195: '011'}

In [80]:
token_id_by_string = {v:k for k,v in model.vocab.string_by_token_id.items()}

In [81]:
token_id_by_string

{'<bos>': 0,
 '<eos>': 1,
 '000': 6661234110526308594,
 '001': -7219619411728825083,
 '.': -486668966899472649,
 '101': 5260776809832962493,
 '110': 6914546132773419944,
 '010': -2887652700584036013,
 '111': 7191366621530416125,
 '100': 7125477652977117189,
 '011': 7105728146654439195}

In [94]:
print("文章を生成します")
input = '000 100 100 100'.split()
context_token_ids = [token_id_by_string[s] for s in input]
next_id = sample_next_token(context_token_ids, model.vocab.token_ids)
print(token_ids_to_sentence([next_id]))

文章を生成します
101 
