In [None]:
import copy
import os
import sys

# Add the src directory to Python path so model.py can find ssn and net modules
sys.path.append(os.path.abspath('../src'))

import numpy as np
import torch
from loguru import logger

from src.model import model
from src.PDAP import retrain

In [2]:
# Load the raw dataset from disk.
# Per the output recorded for this cell, the file holds a structured array of
# shape (961,) with fields 'x' (2 floats), 'dv' (2 floats) and 'v' (1 float);
# presumably a 31 x 31 grid, as the file name suggests.
path = '../data_result/raw_data/gauss_cos_31x31.npy'
data = np.load(path)
logger.info(f"Loaded data with shape: {data.shape}, dtype: {data.dtype}")

[32m2025-09-22 12:19:58.421[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1mLoaded data with shape: (961,), dtype: [('x', '<f8', (2,)), ('dv', '<f8', (2,)), ('v', '<f8')][0m


In [3]:
# Experiment hyperparameters. Everything a reader might tune lives in this cell.
power = 2.1  # forwarded to both models as `power`
M = 50  # size of the initial hidden layer and the count passed to each greedy insertion
num_iterations = 10  # number of train / prune / insert passes
loss_weights = (1.0, 0.0)  # logged by the model as (value weight, gradient weight)
pruning_threshold = 1e-15  # output weights with |w| below this are counted as small and pruned

gamma = 5.0
alpha = 1e-5  # NOTE(review): rebound to model_1.alpha in the training cell
# NOTE(review): lr_adam is never passed to a model in the visible cells, and the
# recorded training log reports "Adam optimizer with lr=0.01" - confirm intent.
lr_adam = 1e-5
regularization = (gamma, alpha)
th = 0.0  # forwarded to both models as `th`

# Seed the global RNGs so that Restart & Run All reproduces the same run.
# NOTE(review): this only helps if the project code draws from the global numpy /
# torch generators - confirm against src.greedy_insertion and src.model.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

In [4]:
# Build the two models used by the training loop. They share activation, power,
# regularization, loss weights and threshold; they differ in the optimizer
# ('Adam' vs 'SSN'), and only model_2 sets train_outerweights=True.
# NOTE(review): no learning rate is passed to model_1, so `lr_adam` from the
# parameter cell is unused here (the training log reports lr=0.01) - confirm.
model_1 = model(
    activation=torch.relu,
    power=power,
    regularization=regularization,
    optimizer='Adam',
    loss_weights=loss_weights,
    th=th,
)
model_2 = model(
    activation=torch.relu,
    power=power,
    regularization=regularization,
    optimizer='SSN',
    loss_weights=loss_weights,
    th=th,
    train_outerweights=True,
)

[32m2025-09-22 12:20:00[0m | [1mINFO    [0m | [36msrc.model[0m:[36m_configure_logger[0m:[36m105[0m - [1mModel initialized[0m
[32m2025-09-22 12:20:00[0m | [1mINFO    [0m | [36msrc.model[0m:[36m_configure_logger[0m:[36m105[0m - [1mModel initialized[0m


In [5]:
# Split the loaded array into training and validation sets.
# The log recorded for this cell reports 864 training and 97 validation samples.
# NOTE(review): `_prepare_data` is a private method reached through model_1;
# whether the split depends on model state or on randomness is not visible here.
data_train, data_valid = model_1._prepare_data(data)


[32m2025-09-22 12:20:00[0m | [1mINFO    [0m | [36msrc.model[0m:[36m_prepare_data[0m:[36m150[0m - [1mTraining set: 864 samples, Validation set: 97 samples[0m
[32m2025-09-22 12:20:00[0m | [1mINFO    [0m | [36msrc.model[0m:[36m_prepare_data[0m:[36m153[0m - [1mData ranges - x: [-1.00, 1.00], v: [-0.72, 1.00], dv: [-5.86, 5.86][0m


In [6]:
from src.greedy_insertion import _sample_uniform_sphere_points, insertion
from src.PDAP import prune_small_weights

In [9]:
# Alternating train / prune / insert loop. Each pass:
#   1. trains model_1 (Adam) from the current hidden neurons,
#   2. passes model_1's hidden and output weights to model_2 (SSN) for a second stage,
#   3. prunes neurons whose model_2 output weight is below pruning_threshold,
#   4. records a snapshot of model_2's config and tracks the best validation loss,
#   5. appends the neurons returned by `insertion` for the next pass.
history = []
# Take alpha from the model; this rebinds the notebook-level `alpha` that was
# set in the parameter cell.
alpha = model_1.alpha

# Track best model across all iterations
best_iteration = 0
best_val_loss = float('inf')

# Initial hidden layer: inner weights and biases for M neurons.
W_hidden, b_hidden = _sample_uniform_sphere_points(M)

# Training loop
for i in range(num_iterations):
    logger.info(f"Iteration {i} - Starting...")

    # Stage 1: train model_1 starting from the current hidden neurons.
    model_1.train(
        data_train,
        data_valid,
        inner_weights=W_hidden,
        inner_bias=b_hidden,
        iterations=1000,
        display_every=100,
    )
    state_1 = model_1.net.state_dict()
    W_hidden = state_1['hidden.weight'].detach().cpu().numpy()
    b_hidden = state_1['hidden.bias'].detach().cpu().numpy()
    W_out = state_1['output.weight'].detach().cpu().numpy()

    # Stage 2: continue with model_2, initialised from model_1's weights.
    model_2.train(
        data_train,
        data_valid,
        inner_weights=W_hidden,
        inner_bias=b_hidden,
        outer_weights=W_out,
        iterations=1000,
        display_every=100,
    )

    # Count and prune small weights
    state_2 = model_2.net.state_dict()
    # Use torch-native ops to count small weights, then convert to int for logging
    small_mask = state_2['output.weight'].abs().flatten() < pruning_threshold
    small_count = int(small_mask.sum().item())
    logger.info(f"Small weights count: {small_count}, Pruning...")
    # Prune neurons based on the trained outer weights from model_2. The hidden
    # weights passed in are the ones extracted from model_1 above; the pruned
    # outer weights returned as the third value are discarded.
    W_hidden, b_hidden, _ = prune_small_weights(
        W_hidden,
        b_hidden,
        state_2['output.weight'].detach().cpu().numpy(),
        pruning_threshold,
    )

    logger.info("Recording...")
    # Store a snapshot rather than a reference to model_2.config: if train()
    # updates that object in place (not visible here), aliased records would
    # all end up showing the last iteration.
    record = {
        'iteration': i,
        'artifact': copy.copy(model_2.config),
        'num_neurons': W_hidden.shape[0],
    }
    history.append(record)

    val_loss = model_2.config['best_val_loss']
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_iteration = i
        logger.info(f"New best model found at iteration {i} with validation loss: {best_val_loss:.6f}")

    # Insert neurons; they are trained at the start of the next pass.
    # NOTE(review): insertion reads model_1, not model_2 or the pruned network -
    # confirm this is intended. Neurons inserted on the final pass are never trained.
    W_to_insert, b_to_insert = insertion(data_train, model_1, M, alpha)
    W_hidden = np.concatenate((W_hidden, W_to_insert), axis=0)
    b_hidden = np.concatenate((b_hidden, b_to_insert), axis=0)

[32m2025-09-22 12:32:21[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1mIteration 0 - Starting...[0m
[32m2025-09-22 12:32:21[0m | [1mINFO    [0m | [36msrc.model[0m:[36mtrain[0m:[36m307[0m - [1mStarting network training session[0m
[32m2025-09-22 12:32:21[0m | [1mINFO    [0m | [36msrc.model[0m:[36m_create_network[0m:[36m194[0m - [1mCreating network with 50 neurons[0m
[32m2025-09-22 12:32:21[0m | [1mINFO    [0m | [36msrc.model[0m:[36m_setup_optimizer[0m:[36m244[0m - [1mUsing Adam optimizer with lr=0.01[0m
[32m2025-09-22 12:32:21[0m | [1mINFO    [0m | [36msrc.model[0m:[36mtrain[0m:[36m322[0m - [1mTraining hyperparameters: iterations=1000, batch_size=1620, display_every=100[0m
[32m2025-09-22 12:32:21[0m | [1mINFO    [0m | [36msrc.model[0m:[36mtrain[0m:[36m323[0m - [1mLoss weights: value=1.0, gradient=0.0[0m
[32m2025-09-22 12:32:21[0m | [1mINFO    [0m | [36msrc.model[0m:[36mtrain[0m:[36m362

In [None]:
# NOTE(review): disabled draft of a `save` method. It refers to `self`,
# `datetime` and `pickle`, none of which are defined or imported in the visible
# cells of this notebook, so it cannot run here as written. Either move it to
# the owning class in the project sources or delete this cell.
# def save(self):
#     """Save training history."""
#     timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
#     filename = f"training_history_{timestamp}.pkl"
#     filepath = os.path.join(self.stats_dir, filename)
    
#     with open(filepath, 'wb') as f:
#         pickle.dump(self.history, f)
        
#     logger.info(f"Saved training history to {filepath}")
#     return filepath