In [1]:
import os
import sys
sys.path.insert(0, os.path.abspath('../'))
import numpy as np
import random
from andi_datasets.models_phenom import models_phenom
from andi_datasets.datasets_phenom import datasets_phenom
from andi_datasets.utils_challenge import label_continuous_to_list

In [2]:
# --- Simulation settings -------------------------------------------------
N = 2      # trajectories per models_phenom().multi_state call (also the inner loop count)
T = 600    # trajectory length passed to multi_state in the main generation loop
L = None   # passed as `L` to multi_state; presumably None = no confining box (TODO confirm in andi_datasets docs)

# --- Feature-extraction settings -----------------------------------------
WINDOW_WIDTHS = np.arange(20, 100, 2)   # sliding-window widths used by make_signal (20, 22, ..., 98)
SHIFT_WIDTH = 40                        # width (frames) of each chopped training window
REG_JUMP = 2                            # stride between the shifted windows used as regression samples

# --- Reproducibility -----------------------------------------------------
# Data generation is fully stochastic; fix the global RNG states so a
# Restart & Run All regenerates the same training set.
SEED = 42
np.random.seed(SEED)
random.seed(SEED)

In [3]:
def uncumulate(xs:np.ndarray):
    """Undo a cumulative sum: return the increments between consecutive samples.

    The output always starts with 0.0 and has the same length as ``xs``
    (an empty input gives ``[0.0]``). The leading 0.0 promotes integer input
    to a floating-point result.
    """
    assert xs.ndim == 1
    increments = xs[1:] - xs[:-1]
    return np.concatenate(([0.], increments))

In [4]:
def _window_score(xs, ys, win_width):
    """Change score for one window: how differently its first and second halves move.

    Each half of ``xs`` / ``ys`` is re-based to start at 0. The score is the absolute
    difference between the halves' normalised cumulative |displacement| curves,
    multiplied by the max/min ratio of the halves' standard deviations in x and in y.
    """
    half_x = len(xs) // 2
    half_y = len(ys) // 2

    # Re-base each half so it starts at 0.
    xs1 = xs[:half_x] - float(xs[0])
    xs2 = xs[half_x:] - float(xs[half_x])
    ys1 = ys[:half_y] - float(ys[0])
    ys2 = ys[half_y:] - float(ys[half_y])

    # Spreads are taken on the re-based halves, before the cumulative transform.
    # NOTE(review): a zero std in a half divides by zero below (inf/nan + RuntimeWarning).
    std_x1, std_x2 = np.std(xs1), np.std(xs2)
    std_y1, std_y2 = np.std(ys1), np.std(ys2)

    cum_x1, cum_x2 = np.cumsum(np.abs(xs1)), np.cumsum(np.abs(xs2))
    cum_y1, cum_y2 = np.cumsum(np.abs(ys1)), np.cumsum(np.abs(ys2))

    # Cumulative sums of |.| are non-negative, so the max is the common scale.
    x_scale = max(np.max(cum_x1), np.max(cum_x2))
    y_scale = max(np.max(cum_y1), np.max(cum_y2))
    cum_x1 = cum_x1 / x_scale / win_width
    cum_x2 = cum_x2 / x_scale / win_width
    cum_y1 = cum_y1 / y_scale / win_width
    cum_y2 = cum_y2 / y_scale / win_width

    return (abs(np.sum(cum_x1 - cum_x2 + cum_y1 - cum_y2))
            * (max(std_x1, std_x2) / min(std_x1, std_x2))
            * (max(std_y1, std_y2) / min(std_y1, std_y2)))


def make_signal(x_pos, y_pos, win_widths):
    """Compute a multi-scale change-point signal from a 2-D trajectory.

    For every width ``w`` in ``win_widths`` with ``w < len(x_pos)``, a window of
    ``2 * int(w / 2)`` samples is slid over the trajectory. Each centre position is
    scored with ``_window_score``. Positions where the window does not fit are
    filled with 0, so every row has ``len(x_pos)`` entries.

    Args:
        x_pos, y_pos: 1-D coordinate sequences of the trajectory.
        win_widths: iterable of window widths; widths >= len(x_pos) are skipped.

    Returns:
        (all_vals, normalized_vals): ``all_vals`` has one row per kept width (+1e-7
        so no entry is exactly zero); ``normalized_vals`` is each row divided by its
        own maximum. Both are empty 1-D arrays if no width was kept.
    """
    rows = []
    for win_width in win_widths:
        if win_width >= len(x_pos):
            continue
        half = int(win_width / 2)
        scores = [
            _window_score(x_pos[centre - half:centre + half],
                          y_pos[centre - half:centre + half],
                          win_width)
            for centre in range(half, len(x_pos) - half)
        ]
        # Zero-pad both ends so the row is aligned with x_pos.
        pad = np.zeros(half)
        rows.append(np.concatenate((pad, scores, pad)))

    # Small offset keeps every value strictly positive (and avoids 0/0 when normalising).
    all_vals = np.array(rows) + 1e-7
    normalized_vals = all_vals.copy()
    for row in normalized_vals:
        row /= np.max(row)
    return all_vals, normalized_vals

In [5]:
def slice_data(signal_seq, jump_d, ext_width, shift_width):
    """Cut fixed-width column windows out of a 2-D signal.

    Windows are centred on every ``jump_d``-th column ``i`` in
    ``[ext_width, signal_seq.shape[1] - ext_width)`` and span columns
    ``[i - shift_width//2, i + shift_width//2)``. Columns outside the signal are
    filled with zeros: on the left when the window starts before column 0, and on
    the right otherwise. Every window therefore has exactly ``shift_width`` columns.

    Returns:
        (slices, indices): ``slices`` stacks the windows; ``indices`` are the window
        centres shifted by ``-ext_width`` (i.e. relative to the un-extended signal).
    """
    half = shift_width // 2
    n_rows = signal_seq.shape[0]
    slice_d = []
    indice = []
    for i in range(ext_width, signal_seq.shape[1] - ext_width, jump_d):
        start = i - half
        # Clamp at 0: a negative slice start would wrap around to the end of the array.
        crop = signal_seq[:, max(start, 0):i + half]
        if crop.shape[1] != shift_width:
            pad_left = max(-start, 0)
            pad_right = shift_width - pad_left - crop.shape[1]
            crop = np.hstack((np.zeros((n_rows, pad_left)),
                              crop,
                              np.zeros((n_rows, pad_right))))
        slice_d.append(crop)
        indice.append(i)
    return np.array(slice_d), np.array(indice) - ext_width

In [6]:
def _mirror_extend(data, ext_width):
    """Pad a 1-D sequence at both ends with point reflections of its edge samples.

    The first/last ``min(len(data), ext_width)`` samples are reflected through
    ``data[0]`` / ``data[-1]`` (built as cumulative sums of negated increments).

    Returns:
        (extended, n_prev, n_next): the padded sequence and the number of samples
        prepended / appended.
    """
    n_ext = min(data.shape[0], ext_width)

    delta_prev = -uncumulate(data[:n_ext])
    delta_prev[0] += float(data[0])
    prev_data = np.cumsum(delta_prev)[::-1]

    # Use `n - n_ext` rather than `-n_ext`: data[-0:] would select the whole array.
    delta_next = -uncumulate(data[data.shape[0] - n_ext:][::-1])
    delta_next[0] += float(data[-1])
    next_data = np.cumsum(delta_next)

    extended = np.concatenate((prev_data, data, next_data))
    return extended, prev_data.shape[0], next_data.shape[0]


def signal_from_extended_data(x, y, win_widths, ext_width, shift_width, slice_width=10):
    """Compute the change-point signal of a trajectory with mirrored edge padding.

    Both coordinates are mirror-extended by up to ``ext_width`` samples so that
    sliding windows near the trajectory ends see data. The extended trajectory is
    passed to ``make_signal``, and the padding is cropped off again afterwards.

    Args:
        x, y: 1-D coordinate arrays (assumed to have equal length).
        win_widths: window widths forwarded to ``make_signal``.
        ext_width: maximum number of samples mirrored onto each end.
        shift_width: currently unused; kept for interface compatibility.
        slice_width: width of the windows cut by ``slice_data`` (previously hard-coded to 10).

    Returns:
        (signal, norm_signal, sliced_signals, slice_indice, full_signal): the first two
        are cropped back to the original trajectory span; ``sliced_signals`` /
        ``slice_indice`` come from ``slice_data`` on the extended signal;
        ``full_signal`` is the uncropped extended signal.
    """
    ext_x, _, _ = _mirror_extend(x, ext_width)
    # The crop bounds are taken from y, matching the original implementation.
    ext_y, n_prev, n_next = _mirror_extend(y, ext_width)

    signal, norm_signal = make_signal(ext_x, ext_y, win_widths)
    sliced_signals, slice_indice = slice_data(signal, 1, min(y.shape[0], ext_width), slice_width)

    stop = signal.shape[1] - n_next
    return (signal[:, n_prev:stop],
            norm_signal[:, n_prev:stop],
            sliced_signals,
            slice_indice,
            signal)

In [7]:
def chop_with_shift(signal, norm_signal, slice_sum_norm, changepoints=None, count_0=None, count_1=None,
                    shift_width=None, reg_jump=None):
    """Cut fixed-width training windows out of a change-point signal.

    For every changepoint ``cp`` with ``shift_width//2 <= cp < signal.shape[1] - shift_width//2``:
      * one positive window (label 1) centred on ``cp``;
      * shifted windows centred at ``cp - 3*reg_jump, ..., cp + 2*reg_jump`` (step ``reg_jump``)
        as regression samples with target ``cp - centre``;
      * up to three negative windows (label 0) at random positions whose
        ``+-shift_width//4`` neighbourhood avoids every changepoint.
    With no changepoints, a single random negative window is produced.

    Args:
        signal, norm_signal: 2-D arrays indexed ``[row, frame]``; each window stacks the
            two crops side by side with ``np.hstack``.
        slice_sum_norm: 1-D array indexed by frame, cropped with the same bounds.
        changepoints: iterable of frame indices; ``None`` is treated as empty.
        count_0, count_1: running totals of negative / positive windows; ``None`` means 0.
        shift_width: window width in frames; defaults to the global ``SHIFT_WIDTH``.
        reg_jump: stride of the regression shifts; defaults to the global ``REG_JUMP``.

    Returns:
        (chopped_signals, chopped_labels, count_0, count_1,
         reg_chopped_signals, reg_chopped_labels, chopped_slice_sum)
    """
    if shift_width is None:
        shift_width = SHIFT_WIDTH
    if reg_jump is None:
        reg_jump = REG_JUMP
    if changepoints is None:
        changepoints = []
    if count_0 is None:
        count_0 = 0
    if count_1 is None:
        count_1 = 0

    half = shift_width // 2
    quarter = shift_width // 4
    # Sample random positions within the actual signal length (not a global T);
    # otherwise shorter trajectories yield truncated or empty windows.
    length = signal.shape[1]

    chopped_signals = []
    chopped_labels = []
    chopped_slice_sum = []
    reg_chopped_signals = []
    reg_chopped_labels = []

    def append_window(start, end, label):
        # Raw and normalised signal crops are concatenated along the frame axis.
        chopped_signals.append(np.hstack((signal[:, start:end], norm_signal[:, start:end])))
        chopped_slice_sum.append(slice_sum_norm[start:end])
        chopped_labels.append(label)

    def negative_bounds(center):
        # Centred window when it fits; otherwise flush against the nearer edge.
        if half <= center < length - half:
            return center - half, center + half
        if center < half:
            return center, center + shift_width
        return center - shift_width, center

    # Frames within +-shift_width//4 of any changepoint; negatives must avoid them.
    changepoints_reg = set()
    for cp in changepoints:
        changepoints_reg.update(range(cp - quarter, cp + quarter))

    if len(changepoints) != 0:
        # NOTE(review): the attempt budget is shared across all changepoints and never
        # reset, so later changepoints may receive fewer negatives. Kept as-is.
        attempts = 0
        for cp in changepoints:
            if not (half <= cp < length - half):
                continue
            append_window(cp - half, cp + half, 1)
            count_1 += 1

            for relative_x in range(cp - reg_jump * 3, cp + reg_jump * 3, reg_jump):
                start, end = relative_x - half, relative_x + half
                # A negative start would wrap around to the end of the array.
                if start < 0:
                    continue
                reg_signal_seq = signal[:, start:end]
                if reg_signal_seq.shape[1] == shift_width:
                    reg_chopped_signals.append(np.hstack((reg_signal_seq, norm_signal[:, start:end])))
                    reg_chopped_labels.append(cp - relative_x)

            same_c = 0
            while True:
                attempts += 1
                random_selec = np.random.randint(0, length)
                if changepoints_reg.isdisjoint(range(random_selec - quarter, random_selec + quarter)):
                    append_window(*negative_bounds(random_selec), 0)
                    count_0 += 1
                    same_c += 1
                    if same_c >= 3:
                        break
                if attempts >= 20:
                    break
    else:
        random_selec = np.random.randint(0, length)
        append_window(*negative_bounds(random_selec), 0)
        count_0 += 1

    return (np.array(chopped_signals), np.array(chopped_labels),
            count_0, count_1,
            np.array(reg_chopped_signals), np.array(reg_chopped_labels),
            np.array(chopped_slice_sum))

In [None]:
# Accumulators for the training set; every chopped window appends one entry to
# input_signals / input_labels / input_features / input_slice_sum / input_slice_snr.
input_signals = []
input_labels = []
input_reg_signals = []
input_reg_labels = []
input_features = []
input_slice_sum = []
input_slice_snr = []

K_bound = [1e-12, 1000000.0]   # NOTE(review): not referenced elsewhere in this notebook
alpha_bound = [0, 1.999]       # NOTE(review): not referenced elsewhere in this notebook
alphas1 = [0.01, 0.4]          # sampling range of the first state's alpha
alphas2 = [0.65, 1.9]          # sampling range of the second state's alpha
count_0 = 0
count_1 = 0


def _collect_samples(xs, ys, traj_labels, count_0, count_1, keep_reg=False):
    """Chop one trajectory into training windows and append them to the input_* lists.

    Args:
        xs, ys: coordinate arrays of a single trajectory.
        traj_labels: that trajectory's label array, passed to label_continuous_to_list.
        count_0, count_1: running negative / positive window counts.
        keep_reg: also store the regression windows (only done for multi-state data).

    Returns:
        The updated (count_0, count_1) reported by chop_with_shift.
    """
    sig, sig_norm, sliced, _, _ = signal_from_extended_data(xs, ys,
                                                            WINDOW_WIDTHS,
                                                            WINDOW_WIDTHS[-1]//2,
                                                            SHIFT_WIDTH)

    slice_sum = np.sum(sliced, axis=(1, 2))
    slice_sum /= np.max(slice_sum)
    slice_sum_snr = np.mean(slice_sum)**2 / np.std(slice_sum)**2

    changepoints, _, _, _ = label_continuous_to_list(traj_labels)
    # The last changepoint entry is dropped — presumably the trajectory end (TODO confirm).
    (chop_signal, chop_label, count_0, count_1,
     reg_signal, reg_label, chop_slice_sum) = chop_with_shift(sig, sig_norm, slice_sum,
                                                              changepoints[:-1],
                                                              count_0, count_1)

    input_signals.extend(chop_signal)
    input_labels.extend(chop_label)
    if keep_reg:
        input_reg_signals.extend(reg_signal)
        input_reg_labels.extend(reg_label)
    input_slice_sum.extend(chop_slice_sum)
    input_slice_snr.extend([slice_sum_snr] * len(chop_slice_sum))
    # Per-row mean^2 / std^2 of the whole signal, repeated once per chopped window.
    input_features.extend(np.array([np.mean(sig, axis=1)**2 / np.std(sig, axis=1)**2] * chop_signal.shape[0]))
    return count_0, count_1


# Two-state trajectories (with changepoints) paired with single-state ones (without).
for step in range(3000):
    if step % 100 == 0: print(step, count_0, count_1, end=' | ')
    alpha1 = np.random.uniform(alphas1[0], alphas1[1])
    alpha2 = np.random.uniform(alphas2[0], alphas2[1])
    single_alpha = np.random.choice([alpha1, alpha2])
    multi_trajs_model, multi_labels_model = models_phenom().multi_state(N=N,
                                                            L=L,
                                                            T=T,
                                                            alphas=[alpha1, alpha2],  # Fixed alpha for each state
                                                            Ds=[[0.04, 0.0], [0.1, 0.0]],  # Mean and variance of each state
                                                            M=[[0.99, 0.01], [0.01, 0.99]]
                                                           )

    single_trajs_model, single_labels_model = models_phenom().multi_state(N=N,
                                                            L=L,
                                                            T=T,
                                                            alphas=[single_alpha, single_alpha],  # Fixed alpha for each state
                                                            Ds=[[0.1, 0.0], [0.1, 0.0]],  # Mean and variance of each state
                                                            M=[[1.0, 0.0], [0.0, 1.0]]
                                                           )

    for i in range(N):
        count_0, count_1 = _collect_samples(multi_trajs_model[:, i, 0],
                                            multi_trajs_model[:, i, 1],
                                            multi_labels_model[:, i, :],
                                            count_0, count_1,
                                            keep_reg=True)
        count_0, count_1 = _collect_samples(single_trajs_model[:, i, 0],
                                            single_trajs_model[:, i, 1],
                                            single_labels_model[:, i, :],
                                            count_0, count_1)


# Extra short single-state trajectories (only the first of each simulated pair is used).
# NOTE(review): these are 200 frames long, shorter than T; chop_with_shift must sample
# window positions within the actual signal length for them to be valid.
s_alphas = [0.5, 1.5]
for _ in range(10000):
    s_alpha = np.random.uniform(s_alphas[0], s_alphas[1])
    single_trajs_model, single_labels_model = models_phenom().multi_state(N=2,
                                                            L=L,
                                                            T=200,
                                                            alphas=[s_alpha, s_alpha],  # Fixed alpha for each state
                                                            Ds=[[0.1, 0.01], [0.1, 0.01]],  # Mean and variance of each state
                                                            M=[[1.0, 0.0], [0.0, 1.0]]
                                                           )
    count_0, count_1 = _collect_samples(single_trajs_model[:, 0, 0],
                                        single_trajs_model[:, 0, 1],
                                        single_labels_model[:, 0, :],
                                        count_0, count_1)

0 0 0 | 100 2666 1121 | 200 5273 2229 | 300 7937 3358 | 400 10587 4482 | 500 13182 5607 | 600 15755 6723 | 700 18388 7812 | 

In [None]:
# Stack the accumulated per-window lists into arrays for saving.
input_signals = np.array(input_signals)
input_labels = np.array(input_labels)
input_reg_signals = np.array(input_reg_signals)
input_reg_labels = np.array(input_reg_labels)
input_features = np.array(input_features)
input_slice_sum = np.array(input_slice_sum)
input_slice_snr = np.array(input_slice_snr)

# Sanity checks: every chopped window contributes one entry to each per-window array,
# and windows of inconsistent width would produce a ragged (object) array.
assert (len(input_labels) == len(input_features) == len(input_slice_sum)
        == len(input_slice_snr) == len(input_signals)), "per-window arrays out of sync"
assert len(input_reg_signals) == len(input_reg_labels), "regression arrays out of sync"
assert input_signals.dtype != object and input_slice_sum.dtype != object, "ragged windows"

In [None]:
# One summary line per group: classification data, regression data, features,
# slice sums / SNR, and the final class counts.
summary_lines = (
    (input_signals.shape, input_labels.shape),
    (input_reg_signals.shape, input_reg_labels.shape),
    (input_features.shape,),
    (input_slice_sum.shape, input_slice_snr.shape),
    (count_0, count_1),
)
for fields in summary_lines:
    print(*fields)

In [None]:
# Persist every training array (plus the class counts) in one compressed archive
# whose name records the window width and regression stride.
training_arrays = {
    'input_signals': input_signals,
    'input_labels': input_labels,
    'input_reg_signals': input_reg_signals,
    'input_reg_labels': input_reg_labels,
    'count_0': count_0,
    'count_1': count_1,
    'input_features': input_features,
    'input_slice_sum': input_slice_sum,
    'input_slice_snr': input_slice_snr,
}
np.savez_compressed(f'./training_set_{SHIFT_WIDTH}_{REG_JUMP}.npz', **training_arrays)