In [1]:
import json
import math
import pickle
import time

import numpy as np
import torch
from sklearn.cluster import KMeans

from junk_trainer import TrainerClusterwise
from src.dataset.random_seq import RandomGeneratedSequences
from src.networks.lstm_pp import LSTMMultiplePointProcesses
from src.utils.metrics import purity, info_score
from src.utils.os_utils import create_folder

In [2]:
# Experiment configuration: everything a reader might tune lives here.
path = 'src/dataset/sin_K2_C5/'  # dataset directory handed to RandomGeneratedSequences
n_steps = 8                      # passed as num_of_steps to the dataset and n_steps to the model
n_clusters = 10                  # number of clusters the model / trainer fit
n_runs = 10                      # number of *successful* training runs to collect below
n_events = 7                     # number of event types (model input_size is n_events + 1)
gpu_index = 6                    # preferred CUDA device ordinal
seed = 42                        # global RNG seed so Restart & Run All is reproducible

# Seed the global RNGs used by model initialisation and (presumably) by any
# numpy-based clustering inside the trainer -- TODO confirm the trainer does not
# use its own unseeded generator.
torch.manual_seed(seed)
np.random.seed(seed)

dataset = RandomGeneratedSequences(path, num_of_event_types=n_events, num_of_steps=n_steps)

# torch.device('cuda:N') is not validated on creation; a missing ordinal only
# fails later at .to(device). Check the device count up front and fall back.
if torch.cuda.is_available() and gpu_index < torch.cuda.device_count():
    device = torch.device('cuda:{}'.format(gpu_index))
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [3]:
# Materialise the whole dataset as two stacked tensors: `data` stacks every x
# and `target` stacks every y yielded by iterating `dataset`.
# NOTE(review): presumably x is one event sequence and y its ground-truth
# cluster label (target is passed to the trainer as `target=`) -- confirm
# against RandomGeneratedSequences.
data, target = [], []
for x, y in dataset:
    data.append(x)
    target.append(y)
data = torch.stack(data)
target = torch.stack(target)

In [None]:
# preparing folders: experiments/results/exp_<i>/ holds the artefacts of run i
create_folder('experiments')
exp_folder = 'experiments/'
create_folder('experiments/results')
path_to_results = 'experiments/results'

# model hyperparameters (also recorded in each run's args.json)
hidden_size = 20
num_layers = 3

# Configuration written to args.json for every run. A notebook has no argparse
# `args` namespace, so the notebook-level settings are recorded explicitly.
exp_args = {
    'path': path,
    'n_steps': n_steps,
    'n_clusters': n_clusters,
    'n_runs': n_runs,
    'n_events': n_events,
    'hidden_size': hidden_size,
    'num_layers': num_layers,
    'device': str(device),
}

# Failed runs are retried under the same index; cap the total number of
# attempts so a configuration that never succeeds cannot hang the kernel.
max_attempts = 3 * n_runs
n_attempts = 0

# iterations over runs: i counts successful runs only
i = 0
while i < n_runs and n_attempts < max_attempts:
    n_attempts += 1
    exp_folder = path_to_results + '/exp_{}'.format(i)
    create_folder(exp_folder)
    # Not consumed in this cell; kept available for later cells -- TODO confirm still needed.
    best_model_path = exp_folder + '/best_model.pt'

    # fresh model and optimizer for every attempt
    model = LSTMMultiplePointProcesses(input_size=n_events + 1, hidden_size=hidden_size,
                                       num_layers=num_layers, num_classes=n_events,
                                       num_clusters=n_clusters, n_steps=n_steps).to(device)
    optimizer = torch.optim.Adam(model.parameters())
    trainer = TrainerClusterwise(model, optimizer, device, data, n_clusters, target=target)
    losses, results, cluster_part, stats = trainer.train()

    # results check: a None cluster partition is treated as a failed solution and retried
    if cluster_part is None:
        print('Solution failed')
        continue

    # saving results
    with open(exp_folder + '/losses.pkl', 'wb') as f:
        pickle.dump(losses, f)
    with open(exp_folder + '/results.pkl', 'wb') as f:
        pickle.dump(results, f)
    with open(exp_folder + '/stats.pkl', 'wb') as f:
        pickle.dump(stats, f)
    with open(exp_folder + '/args.json', 'w') as f:
        json.dump(exp_args, f)
    torch.save(trainer.model, exp_folder + '/last_model.pt')
    i += 1

if i < n_runs:
    print('Only {} of {} runs succeeded after {} attempts'.format(i, n_runs, n_attempts))