In [None]:
%matplotlib inline
import torch # type: ignore
import pickle
import warnings
from pathlib import Path
from DPMoSt import DPMoSt
from utility import data_creation, plot_solution
import os

In [None]:
warnings.filterwarnings("ignore")

In [None]:
# Base seed used by the simulation cells below.
random_state = 42

# The notebook previously auto-detected CUDA and then unconditionally
# overwrote the result with 'cpu', leaving the detection as dead code.
# The override is now explicit: keep `force_cpu = True` to reproduce the
# original behaviour, or set it to False to use a GPU when one is available.
force_cpu = True
if force_cpu:
    device = 'cpu'
else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device used: {device}')

# Parameter settings

In [None]:
# Number of simulated datasets (and fitted solutions) per configuration;
# the loops below iterate over range(n_sim).
n_sim = 100
# Forwarded unchanged to `data_creation` as its `noise_std` argument.
noise_std = 0.5

# Grid of configurations: the cells below loop over every combination.
v_n_features = [2, 5, 10]
v_n_subjects = [50]
v_n_time_for_subject = [1, 2, 5, 7, 10, 20]
# Signal-to-noise labels; each one is mapped to a `max_dist` value when the
# data are created.
v_snr = ['low', 'medium', 'high']

# Data Creation

In [None]:
# Value forwarded to `data_creation` as `max_dist` for each SNR label.
# A lookup table raises KeyError on an unknown label instead of silently
# reusing the value left over from a previous iteration.
max_dist_by_snr = {'low': 0.5, 'medium': 1, 'high': 2}

# Generate `n_sim` datasets for every (features, subjects, time points, SNR)
# combination, one folder per combination.
for n_features in v_n_features:
    for n_subjects in v_n_subjects:
        for n_time_for_subject in v_n_time_for_subject:
            for snr in v_snr:
                max_dist = max_dist_by_snr[snr]

                name_folder=f'simulations/sim_features_{n_features}_subjects_{n_subjects}_time_points_{n_time_for_subject}_snr_{snr}'
                Path(name_folder).mkdir(parents=True, exist_ok=True)
                for idx in range(n_sim):
                    # Datasets already on disk are kept, so an interrupted
                    # run can be resumed without regenerating them.
                    if os.path.isfile(f'{name_folder}/data_{idx}.pkl'):
                        continue

                    # Seed per dataset rather than once per configuration:
                    # with a single seed, a resumed run would restart the
                    # random stream at the first missing file, so the data
                    # would depend on which files already existed (and,
                    # presumably, repeat the draws used for earlier indices).
                    torch.manual_seed(random_state + idx)

                    print(f'features: {n_features} -- subjects: {n_subjects} -- time points: {n_time_for_subject} -- snr: {snr} -- data: {idx}')
                    # The return value is not needed here; the check above
                    # relies on `data_creation` saving to `name_path`
                    # (presumably as `<name_path>.pkl` -- confirm in utility).
                    data_creation(n_subjects=n_subjects, n_time_points=n_time_for_subject, n_features=n_features,
                                  noise_std=noise_std, max_dist=max_dist, time_shifted=True,
                                  device=device, name_path=f'{name_folder}/data_{idx}')

# DP-MoSt

In [None]:
# Fit DP-MoSt on every simulated dataset, then save the solution and its
# figure next to the data.
for n_features in v_n_features:
    for n_subjects in v_n_subjects:
        for n_time_for_subject in v_n_time_for_subject:
            for snr in v_snr:
                name_folder=f'simulations/sim_features_{n_features}_subjects_{n_subjects}_time_points_{n_time_for_subject}_snr_{snr}'
                Path(name_folder).mkdir(parents=True, exist_ok=True)
                for idx in range(n_sim):
                    # Solutions already on disk are kept, so an interrupted
                    # run can be resumed.
                    if os.path.isfile(f'{name_folder}/sol_{idx}.pkl'):
                        continue

                    # Seed per fit rather than once per configuration: with
                    # a single seed, the random state seen by a given fit
                    # would depend on how many earlier fits were skipped, so
                    # a resumed run would not match an uninterrupted one.
                    torch.manual_seed(random_state + idx)

                    print(f'sim -> features: {n_features} -- subjects: {n_subjects} -- time points: {n_time_for_subject} -- SNR: {snr} -- idx: {idx}')
                    # NOTE: pickle.load can execute arbitrary code; only use
                    # it on files produced locally by the data-creation cell.
                    with open(f'{name_folder}/data_{idx}.pkl', 'rb') as f:
                        dict_data = pickle.load(f)
                    data = dict_data['data']

                    dpmost = DPMoSt(data=data, prior_time_shift='gaussian', stopping_criteria=True, verbose=False, name_path=name_folder, device=device)

                    dpmost.optimise(n_outer_iterations=30, n_inner_iterations=30, lr=1e-1)
                    dpmost.save(name_file=f'sol_{idx}')
                    plot_solution(dpmost, show=True, name_path=f'{name_folder}/fig_sol_{idx}')