In [384]:
import numpy as np
from tqdm import tqdm
import multiprocessing as mp

In [195]:
def generate_sequence(start, end_pos):
    seq = [start]
    pos = start
    while pos!=0 and pos!=end_pos:
        if np.random.rand()<0.5:
            pos = pos-1
        else:
            pos = pos+1
        seq.append(pos)
    return seq

In [86]:
class random_walk():
    def __init__(self, pos_, n_state_=7):
        self.end_pos = n_state_-1
        self.pos = pos_
    
    def step(self):
        if self.pos!=0 and self.pos!=self.end_pos:
            if np.random.rand()<0.5:
                self.pos = self.pos-1
            else:
                self.pos = self.pos+1
            return None
        else:
            return 1 if self.pos==self.end_pos else 0

In [316]:
n_state = 7

def update(w, seq, alpha, lm):
    dw = np.zeros([n_state])
    et = np.zeros([n_state])
    #walk = random_walk(4, n_state)
    #t = 1
    for t in range(len(seq)):
        de = np.zeros([n_state])
        de[seq[t]] = 1
        et = de + lm * et
        
        P0 = w[seq[t]] 
        
        if t < len(seq)-1:
            P1 = w[seq[t+1]]
        else:
            P1 = 1 if seq[-1]==n_state-1 else 0
            
        #print(t, P1, P0, et)
        #print(P1-P0, sum(abs(et)))
        dw += alpha * (P1-P0) * et
        
    return dw
        
#     while True:
#         de = np.zeros([n_state])
#         de[walk.pos] = 1
#         et = de + lm * et
        
#         r = walk.step()
#         P1 = w[walk.pos] if r is None else r    
#         dw += alpha*(P1-P0)*et
#         #print(alpha*(P1-P0),sum(abs(et)),sum(abs(w)))
#         P0 = P1
#         if r==0 or r==1:
#             break
#         #t+=1
#     return dw

In [373]:
n_state = 7
w_true = np.array([0, 1/6.0, 1/3.0, 0.5, 2/3.0, 5/6.0, 1])
training_size = 10
n_training_set = 100
seqs = [[generate_sequence(3, n_state-1) for _ in range(training_size)] for _ in range(n_training_set)]

In [387]:
def rmse(w, w_true):
    return np.sqrt(np.mean((w-w_true)**2))

def training_1(alpha, lm, training_set):
    w = np.ones([n_state])*0.5
    max_iteration = 10000

    seqs = []

    for n in range(max_iteration):
        dw = np.zeros([n_state])
        for seq in training_set:
            dw += update(w, seq, alpha/len(training_set), lm)
        w += dw
        #print(sum(abs(dw)), w, "\n")
        if sum(abs(dw)) < 1e-4:
            #print(f"converge in {n} iteration")
            break
    #w, sum(abs(dw))
    error = rmse(w[1:-1], w_true[1:-1])
    
    return error

def training_2(alpha, lm, training_set):
    w = np.ones([n_state])*0.5
    for seq in training_set:
        w += update(w, seq, alpha, lm)
    error = rmse(w[1:-1], w_true[1:-1])
    
    return error

def experiment_1(alpha, lm):
    return np.mean([training_1(alpha, lm, training_set) for training_set in tqdm(seqs)])

def experiment_2(alpha, lm):
    return np.mean([training_2(alpha, lm, training_set) for training_set in tqdm(seqs)])

In [426]:
pool = mp.Pool(mp.cpu_count())

for lm in np.linspace(0,1,11):
    errors = pool.apply_async(experiment_1, args=(0.1, lm))
errors = [errors]

# error = {}
# for lm in tqdm(np.linspace(0,1,11)):
#     error[lm] = np.mean([training_1(alpha=0.1, lm=lm, seqs[i]) for i in range(100)])

100%|██████████| 100/100 [00:22<00:00,  4.41it/s]
100%|██████████| 100/100 [00:23<00:00,  4.56it/s]
100%|██████████| 100/100 [00:24<00:00,  4.70it/s]
 96%|█████████▌| 96/100 [00:25<00:01,  3.61it/s]]
 93%|█████████▎| 93/100 [00:26<00:01,  4.15it/s]]
100%|██████████| 100/100 [00:26<00:00,  6.48it/s]
100%|██████████| 100/100 [00:27<00:00,  6.70it/s]
100%|██████████| 100/100 [00:27<00:00,  6.70it/s]
100%|██████████| 100/100 [00:28<00:00,  7.06it/s]
100%|██████████| 100/100 [00:28<00:00,  6.90it/s]
100%|██████████| 100/100 [00:28<00:00,  6.57it/s]


In [424]:
pool = mp.Pool(mp.cpu_count())

for alpha, lm in itertools.product(np.linspace(0,1,2), np.linspace(0.05,0.6,2)):
    errors = pool.apply_async(experiment_2, args=(alpha, lm))
errors = [errors]

  0%|          | 0/100 [00:00<?, ?it/s]





[<multiprocessing.pool.ApplyResult at 0x1245bc5d0>]