In [None]:
import pandas as pd
import numpy as np
from numpy.random import randint  # np.random.randint() [ )
from sampling import sample
import math
import time
import pickle

def _load_pickle(path):
    """Return the object stored in the pickle file at *path*."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Input data produced elsewhere.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
# The main loop calls .values() on the first object and indexes the other two
# as [group][sequence] -- exact schema is not visible here, confirm upstream.
sequences_with_motifs_100 = _load_pickle('./sequences_with_motifs_100.pkl')
label_starting_pos_100 = _load_pickle('./label_starting_pos_100.pkl')
label_motif_100 = _load_pickle('./label_motif_100.pkl')


# Sampler hyper-parameters

group_nums = 3    # Number of groups processed by the main loop
K = 400           # Number of sequences in each group
N = 200           # Length of each sequence
alphabet = ['A', 'T', 'C', 'G']
# Pseudocount added to every count in compute_model. A value of 1 would also
# work; there was no time left for tuning, otherwise 1 could be tried as well.
b = 0.5
MAX_ITER = 100    # Number of sampling sweeps run for every motif length


def compute_model(sequences, pos, alphabet, w, b):
    """Estimate motif and background residue probabilities from an alignment.

    Args:
        sequences: Indexable collection of sequences (strings over *alphabet*).
        pos: Motif start position for each sequence; ``pos[i]`` belongs to
            ``sequences[i]``.
        alphabet: Iterable of the allowed residue symbols.
        w: Motif width.
        b: Pseudocount added to every residue count (motif and background).

    Returns:
        A tuple ``(q, p)``. ``q[c][j]`` is the probability of residue ``c`` at
        motif column ``j``; every column sums to 1. ``p[c]`` is the background
        probability of residue ``c``; the values sum to 1.

    Raises:
        KeyError: If a sequence contains a symbol that is not in *alphabet*.
        IndexError: If a motif window runs past the end of its sequence.
        ZeroDivisionError: If ``b == 0`` and there is nothing to count.
    """
    q = {x: [b] * w for x in alphabet}  # motif counts, seeded with pseudocounts
    p = {x: b for x in alphabet}        # background counts, seeded likewise

    for i in range(len(sequences)):
        seq = sequences[i]
        motif_start = pos[i]
        motif_end = motif_start + w  # exclusive end of the motif window

        # Residues inside the window feed the per-column motif counts.
        for j in range(w):
            q[seq[motif_start + j]][j] += 1

        # Everything outside [motif_start, motif_end) is background. The
        # residue at index motif_end is the first one after the motif, so it
        # must be included (hence >= rather than >).
        for j in range(len(seq)):
            if j < motif_start or j >= motif_end:
                p[seq[j]] += 1

    # Each column received exactly one count per sequence plus one pseudocount
    # per symbol, so this denominator makes every column sum to 1 no matter
    # how many sequences were passed in or which pseudocount is used.
    column_total = float(len(sequences) + b * len(alphabet))
    for c in alphabet:
        for j in range(w):
            q[c][j] = q[c][j] / column_total

    background_total = float(sum(p.values()))
    for c in alphabet:
        p[c] = p[c] / background_total
    return q, p

def compute_F(sequences, pos, b, w,motif_counts_matrix,num_sequences,background_probs):  # after iteration in each group
    """Score an alignment as a count-weighted sum of log-odds.

    The motif model ``q`` and background model ``p`` are rebuilt from
    *sequences* and *pos* via ``compute_model`` (using the module-level
    ``alphabet``). The result is the sum over motif columns ``i`` and
    residues ``x`` of ``motif_counts_matrix[x][i] * log(q[x][i] / p[x])``.

    Args:
        sequences: Sequences of the group being scored.
        pos: Motif start position for each sequence.
        b: Pseudocount forwarded to ``compute_model``.
        w: Length of the motif.
        motif_counts_matrix: Residue counts per motif column for the
            complete alignment, indexed as ``[residue][column]``.
        num_sequences: Not used by this function.
        background_probs: Not used by this function.

    Returns:
        The accumulated score F.
    """
    motif_model, background_model = compute_model(sequences, pos, alphabet, w, b)
    score = 0
    for column in range(w):
        for residue in alphabet:  # A T C G
            odds = motif_model[residue][column] / background_model[residue]
            score += motif_counts_matrix[residue][column] * math.log(odds)
    return score
            
# incomplete_data log_prob_ratio:G
def compute_log_prob_ratio(F, K, possible_positions, weights):
    """Turn the alignment score F into the length-comparison score G.

    Starting from ``F``, for each of the first *K* sequences this subtracts
    ``log(L')`` where ``L'`` is that sequence's number of candidate start
    positions, and subtracts ``Y * log(Y)`` for every strictly positive
    weight ``Y`` among the first ``L'`` weights of that sequence.

    Args:
        F: Score of the alignment.
        K: Number of sequences to process.
        possible_positions: ``possible_positions[i]`` is the number of
            candidate start positions of sequence ``i``.
        weights: ``weights[i][j]`` is the normalized weight of start
            position ``j`` in sequence ``i``.

    Returns:
        The adjusted score G.
    """
    ratio = F
    for seq_idx in range(K):
        n_starts = possible_positions[seq_idx]
        ratio -= math.log(n_starts)

        for start in range(n_starts):
            weight = weights[seq_idx][start]
            if not (weight > 0):
                # log is undefined here; such weights contribute nothing.
                continue
            ratio -= weight * math.log(weight)
    return ratio


def compute_overlap_percentage(start1, len1, start2, len2):
    """Return the fraction of interval 1 that is covered by interval 2.

    Both intervals are half-open: interval ``k`` covers the indices
    ``start_k .. start_k + len_k - 1``.

    Args:
        start1: Start index of the first (reference) interval.
        len1: Length of the first interval; the result is relative to it.
        start2: Start index of the second interval.
        len2: Length of the second interval.

    Returns:
        A value in ``[0, 1]``: the number of shared positions divided by
        ``len1``. Returns ``0.0`` when ``len1`` is not positive, because an
        empty reference interval cannot be covered.
    """
    if len1 <= 0:
        return 0.0

    start_overlap = max(start1, start2)
    # Exclusive end, so the shared length is simply end - start (no "+ 1");
    # adding one would report 6/5 for two identical intervals of length 5.
    end_overlap = min(start1 + len1, start2 + len2)

    overlap = max(end_overlap - start_overlap, 0)
    return overlap / len1

def shift_positions(pos, N, motif_len, max_shift=3):
    """List the start positions reachable by shifting *pos* a few steps.

    Args:
        pos: Current motif start position.
        N: Length of the sequence.
        motif_len: Length of the motif.
        max_shift: Largest shift, in positions, tried in each direction.
            Defaults to 3.

    Returns:
        A list that starts with *pos* itself, followed by the left-shifted
        positions (nearest first) that are ``>= 0``, then the right-shifted
        positions (nearest first) that still leave room for the whole motif,
        i.e. are ``<= N - motif_len``.
    """
    last_valid_start = N - motif_len
    shifts = range(1, max_shift + 1)

    positions = [pos]  # the unshifted position always comes first
    positions.extend(pos - s for s in shifts if pos - s >= 0)
    positions.extend(pos + s for s in shifts if pos + s <= last_valid_start)
    return positions



def _start_position_weights(sequence, q, p, motif_len):
    """Return the normalized motif/background likelihood ratio of every start.

    For each of the ``N - motif_len + 1`` start positions in *sequence*, the
    window's likelihood under the motif model *q* is divided by its
    likelihood under the background model *p*; the ratios are then scaled so
    that they sum to 1.
    """
    ratios = []
    for j in range(N - motif_len + 1):
        qx = 1  # likelihood of the window under the motif model
        px = 1  # likelihood of the window under the background model
        for k in range(motif_len):
            c = sequence[j + k]
            qx *= q[c][k]
            px *= p[c]
        ratios.append(qx / px)
    norm_c = sum(ratios)
    return [x / norm_c for x in ratios]


start_time = time.time()

values = list(sequences_with_motifs_100.values())  # one collection of K sequences per group
F_ContainAllGroups_list = []  # per group -> per motif length -> F after every iteration
StartingPos_ContainAllGroups_list = []  # per group: {motif_len: best start positions}
MotifLen_ContainAllGroups_list = []  # per group: selected motif length
Overlap_ContainAllGroups_list = []  # per group: overlap with the label for every sequence

for group_num in range(group_nums):
    print('group:', group_num)
    sequences = values[group_num]
    G = -np.inf  # best length-comparison score seen for this group
    # Reset per group so that a stale value from the previous group can never
    # be reported when no motif length improves G.
    final_motif_len = None
    final_starting_pos_dict = {}
    F_ContainAllMotifLens_list = []

    for motif_len in range(5, 23):  # motif lengths 5..22
        print('motif length: ', motif_len)
        num_starts = N - motif_len + 1
        possible_positions = [num_starts] * K  # same for every iteration
        # Random initial start positions; randint's upper bound is exclusive.
        pos = [randint(0, num_starts) for x in range(K)]
        pos_final = pos[:]
        F = -np.inf
        F_list = []

        for it in range(MAX_ITER):
            print('iter_num:', it)

            for i in range(K):  # resample the start position of every sequence
                # Build the model from all sequences except the current one.
                seq_minus = sequences[:]
                del seq_minus[i]
                pos_minus = pos[:]
                del pos_minus[i]
                q, p = compute_model(seq_minus, pos_minus, alphabet, motif_len, b)

                # Weight of every candidate start position of sequence i.
                Ai = _start_position_weights(sequences[i], q, p, motif_len)
                # NOTE(review): sample() comes from the project-local
                # `sampling` module; presumably it draws one element of the
                # range according to the weights -- confirm.
                pos[i] = sample(range(num_starts), Ai)

                # Every 10th iteration, look at the sampled position and its
                # neighbours within 3 steps and keep the most probable one.
                # max() returns the first maximum, so the sampled position
                # (listed first) wins ties.
                if it % 10 == 0:
                    candidate_positions = shift_positions(pos[i], N, motif_len)
                    pos[i] = max(candidate_positions, key=Ai.__getitem__)

            # Residue counts of the complete alignment: per motif column and
            # for the background (raw counts, despite the variable name).
            motif_count_matrix = {x: [0] * motif_len for x in alphabet}
            background_probs = {x: 0 for x in alphabet}
            for idx in range(len(sequences)):
                start_pos = pos[idx]
                for a in range(motif_len):
                    c = sequences[idx][start_pos + a]
                    motif_count_matrix[c][a] += 1
                for j in range(len(sequences[idx])):
                    # The motif covers [start_pos, start_pos + motif_len), so
                    # index start_pos + motif_len is already background.
                    if j < start_pos or j >= start_pos + motif_len:
                        background_probs[sequences[idx][j]] += 1

            F_new = compute_F(sequences, pos, b, motif_len, motif_count_matrix, K, background_probs)
            if F_new > F:
                F = F_new
                # Snapshot the list: pos is mutated in place by later
                # iterations, so storing a reference would lose the best state.
                pos_final = pos[:]
            F_list.append(F_new)

        final_starting_pos_dict[motif_len] = pos_final
        # Keep the F trace of every motif length, not just the last one.
        F_ContainAllMotifLens_list.append(F_list)

        # Compute G for this motif length from the best alignment (the one F
        # belongs to), using normalized weights of the complete alignment.
        q_all, p_all = compute_model(sequences, pos_final, alphabet, motif_len, b)
        weights = [
            _start_position_weights(sequences[i], q_all, p_all, motif_len)
            for i in range(K)
        ]  # K x (N - motif_len + 1)
        G_new = compute_log_prob_ratio(F, K, possible_positions, weights) / (3*motif_len)
        if G_new > G:
            G = G_new
            final_motif_len = motif_len

    if final_motif_len is None:
        raise RuntimeError('no motif length produced a finite score for group %d' % group_num)
    MotifLen_ContainAllGroups_list.append(final_motif_len)

    # Overlap between the labelled motif and the predicted motif, using the
    # start positions that belong to the selected motif length.
    best_starts = final_starting_pos_dict[final_motif_len]
    overlap_percentage = []
    for i in range(K):
        overlap = compute_overlap_percentage(
            label_starting_pos_100[group_num][i],
            len(label_motif_100[group_num][i]),
            best_starts[i],
            final_motif_len,
        )
        overlap_percentage.append(overlap)
    F_ContainAllGroups_list.append(F_ContainAllMotifLens_list)
    StartingPos_ContainAllGroups_list.append(final_starting_pos_dict)
    Overlap_ContainAllGroups_list.append(overlap_percentage)


end_time = time.time()
elapsed_time = end_time - start_time
print('operation time: ', elapsed_time)

# Persist every result collection to its own pickle file, in a fixed order.
for out_path, payload in (
    ('./F.pkl', F_ContainAllGroups_list),
    ('./StartingPos.pkl', StartingPos_ContainAllGroups_list),
    ('./MotifLen.pkl', MotifLen_ContainAllGroups_list),
    ('./Overlap.pkl', Overlap_ContainAllGroups_list),
):
    with open(out_path, 'wb') as file:
        pickle.dump(payload, file)