In [1]:
import sys
import os
current_dir = os.getcwd()  # 获取当前工作目录
parent_dir = os.path.dirname(current_dir)  # 获取父目录
sys.path.append(parent_dir)

import pandas as pd
import pennylane as qml
import pickle

import logging 
from datetime import datetime
from tqdm import *
import argparse

import jax 
import jax.numpy as jnp  
import optax
from flax import nnx 


from datasets_utils import get_quantum_dataloaders
from model import DataReuploading
from train_utils import ClassificationTrainer
import wandb


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
jax.config.update('jax_platform_name', 'cpu')
jax.config.update("jax_enable_x64", True)


In [4]:
n_qubits = 5
n_layers = 1
n_samples = 5000
max_layers = 1
n_reps = 16
n_epochs = 50
n_repeats = 1
seed = 0
batch_size = 200
lr = 0.01
optax_optimizer = 'adam'
loss_fn = 'cross_entropy'
project_name = 'classification_linear'

In [5]:
config = {
    'n_qubits': n_qubits,
    'n_layers': n_layers,
    'max_layers':max_layers,
    'n_reps': n_reps,
    'optimizer': 'adam',
    'loss_fn': 'cross_entropy',
    'batch_size': batch_size,
    'learning_rate': lr,
    'n_epochs': n_epochs,
    'n_repeats': n_repeats,
    'seed':seed,
    'use_wandb': True,
    'save_epoch_metrics': False,
    'test_every_epoch': True,
    'save_best_model': True,
    'project_name': project_name,
    'group_name': f'qubits_{n_qubits}_layers_{n_layers}_reps_{n_reps}_samples_{n_samples}'
}

In [14]:
train_loader, test_loader = get_quantum_dataloaders(n_qubits=n_qubits, n_layers=n_layers, n_samples=n_samples, data_type="classification_linear",batch_size=batch_size)

In [15]:
qnet = DataReuploading(n_qubits=n_qubits, n_reps=n_reps, n_layers=n_layers,max_layers=max_layers,measurement_type="probs",measure_wires = [0],seed=seed,ansatz_type="zero_padding")

In [7]:
for batch in train_loader:
    break;

In [17]:
qnet.quantum_model.vn_entropy(x)

Array(0.68671433, dtype=float64)

In [18]:
qnet

DataReuploading( # Param: 384 (3.1 KB)
  quantum_model=ZeroPaddingCircuit( # Param: 384 (3.1 KB)
    n_qubits=8,
    interface='jax',
    device=<default.qubit device (wires=8) at 0x7ee9d2bc18d0>,
    measurement_type='probs',
    hamiltonian=None,
    measure_wires=[0],
    params=Param( # 384 (3.1 KB)
      value=Array(shape=(16, 1, 8, 3), dtype=dtype('float64'))
    ),
    n_reps=16,
    n_layers=1,
    max_layers=1,
    shape=(16, 1, 8, 3)
  )
)

In [19]:
trainer = ClassificationTrainer(config, qnet, train_loader, test_loader)

In [20]:
qnet,metrics = trainer.train()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


[34m[1mwandb[0m: Currently logged in as: [33mwangxiaojin12138[0m ([33mx-wang-tsinghua[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Training epochs: 100%|██████████| 50/50 [09:47<00:00, 11.74s/it]
2025-03-26 21:57:04,329 - INFO - Training completed, final metrics: {'loss': 0.6807901462037488, 'accuracy': 0.6648000000000001, 'pred_error': 0.4870752403671223}


0,1
test_accuracy,▂▃▁▁▃▂▂▃▂▃▄▄▅▇▇▆▇█▇█▇██████▇▇▆▅▆▅▅▅▅▅▄▄▅
test_loss,▆▇██▇▆▆▅▅▅▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂
test_pred_error,▆▇██▇▆▆▅▅▅▄▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂
train_accuracy,▁▃▅▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇██████████████████████
train_loss,█▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_pred_error,█▅▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
test_accuracy,0.5079
test_loss,0.69304
test_pred_error,0.4994
train_accuracy,0.6608
train_loss,0.6804
train_pred_error,0.48667


In [26]:
x = batch[0][2]

In [30]:
z = jnp.zeros_like(x)

In [31]:
qnet.quantum_model.vn_entropy(z)

Array(0.68672929, dtype=float64)

In [31]:
# metrics['final_metrics']

In [32]:
# metrics['epoch_metrics']['train']

In [33]:
from metric import Metrics

class ExperimentManager:
    def __init__(self, config):
        self.config = config
        self.results_df = pd.DataFrame()
        self.metrics = Metrics()
        self.model_list = []
        
        self.setup_config()
        self.setup_metrics()
        
       
    def setup_config(self):
        self.n_repeats = self.config['n_repeats']
     
    def setup_metrics(self):
        self.metrics.register_metric("loss",split="train",index_type="repeat")
        self.metrics.register_metric("accuracy",split="train",index_type="repeat")
        self.metrics.register_metric("pred_error",split="train",index_type="repeat")
        self.metrics.register_metric("loss",split="test",index_type="repeat")
        self.metrics.register_metric("accuracy",split="test",index_type="repeat")
        self.metrics.register_metric("pred_error",split="test",index_type="repeat")
        
        
        
    def run_experiments(self):
        """运行多次实验"""
        for i in range(self.n_repeats):
            print(f"\nRunning experiment {i+1}/{self.n_repeats}")
            # 每次实验使用不同的随机种子
            seed = i
            self.config['seed'] = seed
            qnet = DataReuploading(n_qubits=n_qubits, n_reps=n_reps, n_layers=n_layers,max_layers=max_layers,measurement_type="probs",measure_wires = [0],seed=seed,ansatz_type="zero_padding")
            trainer = ClassificationTrainer(self.config,qnet,train_loader,test_loader)
            qnet,train_metrics = trainer.train()
            _,test_metrics = trainer.test()
            self.metrics.update("loss", train_metrics['final_metrics']['loss'], split="train", index_type="repeat")
            self.metrics.update("accuracy", train_metrics['final_metrics']['accuracy'], split="train", index_type="repeat")
            self.metrics.update("pred_error", train_metrics['final_metrics']['pred_error'], split="train", index_type="repeat")
            self.metrics.update("loss", test_metrics['final_metrics']['loss'], split="test", index_type="repeat")
            self.metrics.update("accuracy", test_metrics['final_metrics']['accuracy'], split="test", index_type="repeat")
            self.metrics.update("pred_error", test_metrics['final_metrics']['pred_error'], split="test", index_type="repeat")
            
            model_params = qnet.quantum_model.get_params()
            self.model_list.append(model_params)
            
            
            
        train_results = self.metrics.get_metrics(split='train')
        train_values = train_results['values']
        train_stats = train_results['stats']

        test_results = self.metrics.get_metrics(split='test')
        test_values = test_results['values']
        test_stats = test_results['stats']

        # Create DataFrame with results
        data = {'Experiment': range(1, n_repeats+1)}

        # Add train metrics
        for key in train_values.keys():
            data[f'Train {key.capitalize()}'] = train_values[key]

        # Add test metrics  
        for key in test_values.keys():
            data[f'Test {key.capitalize()}'] = test_values[key]

        self.results_df = pd.DataFrame(data)

        
        print("\n" + "="*50)
        print("Statistics".center(50))
        print("="*50)
        for metric in train_stats.keys():
            print(f"\n{metric.capitalize()}:")
            print("-"*30)
            print("Train:")
            for stat_name, value in train_stats[metric].items():
                print(f"{stat_name.capitalize():>15}: {value:.4f}")
            print("-"*30)
            print("Test:")
            for stat_name, value in test_stats[metric].items():
                print(f"{stat_name.capitalize():>15}: {value:.4f}")
            print("-"*30)



        # 保存结果
        self.save_results()
        
    
    def save_results(self):
        """保存实验结果"""
        results_dir = '../results'
        os.makedirs(f'{results_dir}/{self.config["project_name"]}', exist_ok=True)
        # 保存DataFrame为CSV
        csv_path = f'{results_dir}/{self.config["project_name"]}/experiment_results_{self.config["group_name"]}.csv'
        self.results_df.to_csv(csv_path, index=False)
        
        # 保存所有数据（包括参数）到pickle文件
        full_results = {
            'config': self.config,
            'model_list': self.model_list
        }
        pickle_path = f'{results_dir}/{self.config["project_name"]}/full_results_{self.config["group_name"]}.pkl'
        with open(pickle_path, 'wb') as f:
            pickle.dump(full_results, f)
        
        print(f"Results saved to {csv_path} and {pickle_path}")   


In [34]:
experiment_manager = ExperimentManager(config)


In [None]:
experiment_manager.run_experiments()


Running experiment 1/2


Training epochs:   0%|          | 0/50 [00:00<?, ?it/s]Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7e1afa808910>>
Traceback (most recent call last):
  File "/home/xwang/miniconda3/envs/quantum/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 
Training epochs:  20%|██        | 10/50 [00:12<00:13,  2.96it/s]

In [16]:
# Load the saved results
results_dir = '../results'
project_name = config["project_name"]
group_name = config["group_name"]
pickle_path = f'{results_dir}/{project_name}/full_results_{group_name}.pkl'

with open(pickle_path, 'rb') as f:
    loaded_results = pickle.load(f)

print("Loaded results:")
print(f"Config: {loaded_results['config']}")
print(f"Number of models: {len(loaded_results['model_list'])}")


Loaded results:
Config: {'n_qubits': 1, 'n_layers': 8, 'max_layers': 8, 'n_reps': 8, 'optimizer': 'adam', 'loss_fn': 'cross_entropy', 'batch_size': 200, 'learning_rate': 0.01, 'n_epochs': 50, 'n_repeats': 2, 'seed': 1, 'use_wandb': True, 'save_epoch_metrics': False, 'test_every_epoch': False, 'save_best_model': True, 'project_name': 'classification_linear', 'group_name': 'qubits_1_layers_8_reps_8_samples_600'}
Number of models: 2


In [18]:
loaded_results['model_list'][0]

Array([[[[-0.30038128,  1.04383215,  1.57592131]],

        [[-0.63471539, -0.00962056,  0.06373682]],

        [[-0.88191157,  1.9311169 , -0.39301373]],

        [[-0.35502995, -1.21814589,  0.09401016]],

        [[-0.08605331, -1.19914295, -0.8620287 ]],

        [[ 0.50431685, -0.50552769,  1.03440312]],

        [[-0.75342598,  1.58501085,  0.24514733]],

        [[-1.10183755, -0.31509729, -0.58563505]]],


       [[[-0.47320803,  0.62100707, -0.37348003]],

        [[-1.16540027,  0.19600207, -0.38075203]],

        [[ 0.64297934,  0.46571483,  0.22733414]],

        [[-0.7151162 , -0.28432722,  0.72989529]],

        [[ 1.80395613,  0.30776894, -1.25404772]],

        [[-1.17583088,  1.24825963,  2.04078925]],

        [[ 0.35782603, -2.22282976,  0.30992756]],

        [[-0.56769994,  0.94222343,  0.19204637]]],


       [[[ 0.74798515, -0.16757212,  0.4267965 ]],

        [[ 0.72150718, -1.3996204 ,  0.14527948]],

        [[ 0.26071809,  0.63204436,  1.12476995]],

        

In [19]:
qnet.quantum_model.update_params(loaded_results['model_list'][0])

In [20]:
qnet.quantum_model.get_params()

Array([[[[-0.30038128,  1.04383215,  1.57592131]],

        [[-0.63471539, -0.00962056,  0.06373682]],

        [[-0.88191157,  1.9311169 , -0.39301373]],

        [[-0.35502995, -1.21814589,  0.09401016]],

        [[-0.08605331, -1.19914295, -0.8620287 ]],

        [[ 0.50431685, -0.50552769,  1.03440312]],

        [[-0.75342598,  1.58501085,  0.24514733]],

        [[-1.10183755, -0.31509729, -0.58563505]]],


       [[[-0.47320803,  0.62100707, -0.37348003]],

        [[-1.16540027,  0.19600207, -0.38075203]],

        [[ 0.64297934,  0.46571483,  0.22733414]],

        [[-0.7151162 , -0.28432722,  0.72989529]],

        [[ 1.80395613,  0.30776894, -1.25404772]],

        [[-1.17583088,  1.24825963,  2.04078925]],

        [[ 0.35782603, -2.22282976,  0.30992756]],

        [[-0.56769994,  0.94222343,  0.19204637]]],


       [[[ 0.74798515, -0.16757212,  0.4267965 ]],

        [[ 0.72150718, -1.3996204 ,  0.14527948]],

        [[ 0.26071809,  0.63204436,  1.12476995]],

        