In [None]:
from pathlib import Path
from VAE_testing.vae_generate_data import PADSDataSynthesizer

# 1. Initialize the synthesizer from the preprocessed dataset directory.
#    NOTE(review): this path is relative to the kernel's working directory,
#    so the cell only resolves correctly when the notebook is run from its own folder.
dataset_location = "../../../project_datasets/tremor/"
preprocessed_dir = Path(dataset_location)
synthesizer = PADSDataSynthesizer(preprocessed_dir)

In [None]:
# 2. Train the TVAE model.
#    Hyperparameters are named so they can be tuned in one place.
TVAE_EPOCHS = 300
TVAE_BATCH_SIZE = 500
synthesizer.train(epochs=TVAE_EPOCHS, batch_size=TVAE_BATCH_SIZE)

In [None]:
# 3. Draw synthetic samples from the trained model.
#    The result is used like a DataFrame below (.head, .shape, .to_csv).
NUM_SYNTHETIC_SAMPLES = 200
synthetic_data = synthesizer.generate_synthetic_data(num_samples=NUM_SYNTHETIC_SAMPLES)

In [None]:
# 4. Show a preview of the synthetic samples under a banner heading.
separator = "=" * 50
print("\n" + separator)
print("SYNTHETIC DATA SAMPLE")
print(separator)
print(synthetic_data.head(10))
print(f"\nShape: {synthetic_data.shape}")

# 5. Score the synthetic samples with the synthesizer's own quality check.
quality_report = synthesizer.evaluate_quality(synthetic_data)

In [None]:
# 6. Persist the trained TVAE model.
tvae_model_path = "pads_tvae_model.pkl"
synthesizer.save_model(tvae_model_path)

# 7. Persist the synthetic samples as CSV (the DataFrame index is not written).
#    One variable feeds both the write and the message, so they cannot disagree.
synthetic_csv_path = "synthetic_pads_metadata.csv"
synthetic_data.to_csv(synthetic_csv_path, index=False)
print(f"\nSynthetic data saved to '{synthetic_csv_path}'")

In [None]:
from Models.CNN_models.model_V9 import TremorNetV9

# Build one untrained TremorNetV9 per movement, keyed "<movement>_model".
# NOTE(review): `all_dataloaders` is not defined in any visible cell of this
# notebook; it must come from an earlier cell/session. Fail with an actionable
# message instead of a bare NameError so Restart & Run All points at the cause.
if "all_dataloaders" not in globals():
    raise NameError(
        "`all_dataloaders` is undefined: build the per-movement dataloaders "
        "(dict of movement -> {'train': ..., 'val': ...}) before creating the models."
    )

MODEL_SUFFIX = "_model"
models = {
    f"{movement_name}{MODEL_SUFFIX}": TremorNetV9()
    for movement_name in all_dataloaders
}

print(f"Created {len(models)} models:")
for name in models:
    print(f" - {name}")

In [None]:
from training.trainer import train

# Train each per-movement model on that movement's "train" and "val" dataloaders.
# Shared CNN training hyperparameters, kept in one place.
CNN_EPOCHS = 50
CNN_MAX_LR = 1e-4
MODEL_SUFFIX = "_model"

for model_name, model in models.items():
    # Recover the movement key by stripping only the trailing "_model".
    # str.replace would also delete "_model" appearing inside a movement name,
    # producing a key that is not in all_dataloaders.
    if model_name.endswith(MODEL_SUFFIX):
        movement_name = model_name[: -len(MODEL_SUFFIX)]
    else:
        movement_name = model_name

    train_loader = all_dataloaders[movement_name]["train"]
    val_loader = all_dataloaders[movement_name]["val"]  # the "val" split, used for validation

    # Run name and save name both derive from the model key;
    # save_name is presumably the checkpoint filename used by `train` — verify.
    run_name = model_name
    save_name = f"{model_name}.pth"

    print(f"\n[INFO] Training {model_name} ...")

    train(
        model=model,
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        model_name=save_name,
        run_name=run_name,
        epochs=CNN_EPOCHS,
        per_movement=True,
        max_lr=CNN_MAX_LR,
        debug_mode=False,
    )