In [2]:
import torch
import sys
import os
import numpy as np
import time
sys.path.append(os.path.abspath("../..")) 
from datetime import datetime

if torch.cuda.is_available():
    device = torch.device('cuda')
    print("GPU name:", torch.cuda.get_device_name(0))
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

from dataloader.dataloader import load_MNIST_data,load_CINIC10_data,load_CIFAR10_data,load_Fmnist_data,load_Covtype_data
from train.training import train_for_DEQ, train_nomal
from train.evaluate import plot_loss_curve,plot_errorbar_losscurve,plot_confusion_matrix,plot_histograms,create_table,convergence_verify, convergence_verify_tabular
from result_management.data_manager import save_csv,auto_git_push,save_experiment_report,create_result_pdf
now = datetime.now()
formatted_time = now.strftime("%m%d%H%M")
formatted_time = int(formatted_time)
print(f'-----Formatted time: {formatted_time} -----')

GPU name: NVIDIA GeForce GTX 1660
Using device: cuda
-----Formatted time: 8191641 -----


In [4]:
experiment_type = "DEQ"
experiment_name = f"{experiment_type}{formatted_time}"

variable_param = "leverage" #ここで設定した項目は配列にすること(none,leverage,alpha)
save = True

params = {
    'none':[0], #variable_param=noneの際は1回だけ繰り返す
    #data---------------------------------------------
    'dataset': 'covtype', # 'mnist', 'cifar-10', 'cinic-10' , 'fashion-mnist'
    'batch_size': 512, #64 MNIST, 100 CIFAR10, 100 CINIC10

    #Encoder_Model--------------------------------
    'enc_type': 'PM', # 'none', 'MZM', 'LI'
    'alpha': np.pi/2, 
    #位相変調機の感度[np.pi*2,np.pi, np.pi/2, np.pi/4, np.pi/8, np.pi/16],pi:-π~π
    #class_model--------------------------------------
    'cls_type': 'MLP', # 'MLP' or 'CNN'
    'num_layer': 2,
    'fc': 'relu',
    'dropout': 0.0,

    #learning-----------------------------------------
    'loss_func': 'cross_entropy',
    'optimizer': 'adam',
    'lr': 0.001,

    #param--------------------------------------------
    'num_try': 3,
    'max_epochs': 10,
    'leverage': [2,6,9,18,27,54], #mnist:[1,2,4,8,16],cinic:[1,2,3,4,6,8,12,16,24,48] enc is not none
    'kernel_size': 4,

    #anderson param-----------------------------------
    'm': 5,
    'lam': 1e-5, 
    'num_iter': 50,
    'tol': 1e-4,  #早期終了条件
    'beta': 1.0,
    'gamma' : 0 #SNLinearRelaxのgamma値(使わない)
}
#save---------------------------------------------
folder_params = {k: params[k] for k in ['dataset', 'enc_type', 'cls_type']}
if save:
    save_experiment_report(variable_param, params,experiment_name=experiment_name)

data_loaders = {
    'covtype': load_Covtype_data,
    'cifar-10': load_CIFAR10_data,
    'cinic-10': load_CINIC10_data,
    'mnist': load_MNIST_data,
    'fashion-mnist':load_Fmnist_data
}

data_train,data_test = data_loaders[params["dataset"]]()

if params["enc_type"] == 'none':
    params["leverage"] = 1
results = []
All_last_LOSSs_ = []
All_last_ACCs_ = []
All_TIMEs_ = []

for variable in params[variable_param]: #variable:leverage,alpha
    print(f'----------------------Running with {variable_param}: {variable}----------------------')
#-----------------------------------------------------
    Relres_ = []
    Unresovable = 0
    k = 1000
    Show_rel = False
    for i in range(k):
        relres = convergence_verify_tabular(params,gamma=params['gamma'],data_train=data_train,data_test=data_test,device=device,Show=Show_rel)
        Relres_.append(len(relres))
        if len(relres) > 40:
            Unresovable += 1
        sys.stderr.write(f"\rIteration {i+1}/{k} completed. Current length: {len(relres)}")
        sys.stdout.flush()
    time.sleep(1)
    print(f"Average number of iterations: {np.mean(Relres_)}")
    print(f"Unresolvable cases: {Unresovable}")
#-----------------------------------------------------
    All_last_loss = []
    All_loss_test = []
    All_pro_time = []
    All_test_acc = []

    for num_times in range(params['num_try']):

        params_for_train = {k: v for k,v in params.items() if k not in ('none',variable_param)}#配列を除外

        if variable_param != 'none': #leverageやalpha可変のとき
            params_for_train.update({'num_times': num_times, variable_param: variable,'device': device})
        else: #パラメータ不変のとき
            params_for_train.update({'num_times': num_times,'device': device})
        
        #-----------training-----------
        loss_train_,loss_test_,pro_time_,Last_loss_test,Test_acc,all_labels,all_preds = train_for_DEQ(**params_for_train,data_train=data_train,data_test=data_test)

        All_loss_test.append(loss_test_)
        All_pro_time.append(sum(pro_time_))
        All_last_loss.append(Last_loss_test)
        All_test_acc.append(Test_acc)
        if save:
            datas = [loss_train_,loss_test_,all_labels,all_preds,Test_acc]
            save_csv(datas,variable_param,variable,num_times,**folder_params,save_type='trial',experiment_name=experiment_name)
        print(f'Test Accuracy: {Test_acc}')
        # plot_loss_curve(loss_train_,loss_test_)
        # plot_confusion_matrix(all_labels,all_preds,params["dataset"],Test_acc)

    if save:
        datas = [All_loss_test,All_test_acc,All_last_loss,All_pro_time]
        save_csv(datas,variable_param,variable,num_times,**folder_params,save_type='mid',experiment_name=experiment_name)

        datas = [Relres_,np.mean(Relres_),Unresovable]
        save_csv(datas,variable_param,variable,num_times,**folder_params,save_type='relres',experiment_name=experiment_name)
    print(f"Test Accuracy:{Test_acc:.2f}")
    # plot_errorbar_losscurve(All_loss_test)
    # create_table(All_test_acc,All_last_loss,All_pro_time)

    All_last_ACCs_.append(All_test_acc)
    All_last_LOSSs_.append(All_last_loss)
    All_TIMEs_.append(All_pro_time)

if variable_param != 'none'and save:
    datas = [All_last_ACCs_,All_last_LOSSs_,All_TIMEs_]
    save_csv(datas,variable_param,variable,num_times,**folder_params,save_type='final',experiment_name=experiment_name) #最終保存

if save:
    create_result_pdf(variable_param, params, experiment_name=experiment_name)

実験パラメータ報告書を保存しました: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\experiment_report.txt
----------------------Running with leverage: 2----------------------


Iteration 1000/1000 completed. Current length: 9

Average number of iterations: 9.02
Unresolvable cases: 0


1/3th Epoch:10/10(100.00%) 

Saved at: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\2leverage_1th_.csv
Test Accuracy: 79.85163894219599


3/3th Epoch:1/10(0.11%) %) 

Saved at: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\2leverage_2th_.csv
Test Accuracy: 81.88687039060953


3/3th Epoch:10/10(100.00%) 

Saved at: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\2leverage_3th_.csv
Test Accuracy: 79.98416564116245
Saved at: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\2leverage_mid.csv
Saved at: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\2leverage_relres.csv
Test Accuracy:79.98
----------------------Running with leverage: 6----------------------


1/3th Epoch:1/10(0.11%) leted. Current length: 7

Average number of iterations: 9.061
Unresolvable cases: 0


2/3th Epoch:1/10(0.11%) %) 

Saved at: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\6leverage_1th_.csv
Test Accuracy: 67.56021789454661


3/3th Epoch:1/10(0.11%) %) 

Saved at: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\6leverage_2th_.csv
Test Accuracy: 68.98789187886716


3/3th Epoch:10/10(100.00%) 

Saved at: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\6leverage_3th_.csv
Test Accuracy: 64.09989415075343
Saved at: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\6leverage_mid.csv
Saved at: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\6leverage_relres.csv
Test Accuracy:64.10
----------------------Running with leverage: 9----------------------


1/3th Epoch:1/10(0.11%) leted. Current length: 9

Average number of iterations: 9.052
Unresolvable cases: 0


2/3th Epoch:1/10(0.22%) %) 

Saved at: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\9leverage_1th_.csv
Test Accuracy: 62.764300405325166


3/3th Epoch:1/10(0.11%) %) 

Saved at: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\9leverage_2th_.csv
Test Accuracy: 57.457208505804495


3/3th Epoch:10/10(100.00%) 

Saved at: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\9leverage_3th_.csv
Test Accuracy: 55.97962186862645
Saved at: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\9leverage_mid.csv
Saved at: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\9leverage_relres.csv
Test Accuracy:55.98
----------------------Running with leverage: 18----------------------


1/3th Epoch:1/10(0.22%) leted. Current length: 13

Average number of iterations: 9.032
Unresolvable cases: 0


1/3th Epoch:10/10(100.00%) 

Saved at: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\18leverage_1th_.csv
Test Accuracy: 54.33680713923049


3/3th Epoch:1/10(0.33%) %) 

Saved at: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\18leverage_2th_.csv
Test Accuracy: 52.977117630353774


3/3th Epoch:10/10(100.00%) 

Saved at: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\18leverage_3th_.csv
Test Accuracy: 49.67685860089671
Saved at: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\18leverage_mid.csv
Saved at: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\18leverage_relres.csv
Test Accuracy:49.68
----------------------Running with leverage: 27----------------------


1/3th Epoch:1/10(0.44%) leted. Current length: 9

Average number of iterations: 8.866
Unresolvable cases: 0


2/3th Epoch:1/10(0.22%) %) 

Saved at: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\27leverage_1th_.csv
Test Accuracy: 48.78359422734353


2/3th Epoch:10/10(100.00%) 

Saved at: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\27leverage_2th_.csv
Test Accuracy: 51.69229710076332


3/3th Epoch:10/10(100.00%) 

Saved at: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\27leverage_3th_.csv
Test Accuracy: 48.9212843041918
Saved at: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\27leverage_mid.csv
Saved at: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\27leverage_relres.csv
Test Accuracy:48.92
----------------------Running with leverage: 54----------------------


1/3th Epoch:1/10(0.77%) leted. Current length: 10

Average number of iterations: 8.896
Unresolvable cases: 0


2/3th Epoch:1/10(0.55%) %) 

Saved at: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\54leverage_1th_.csv
Test Accuracy: 48.78359422734353


3/3th Epoch:1/10(0.44%) %) 

Saved at: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\54leverage_2th_.csv
Test Accuracy: 51.62775487724069


3/3th Epoch:10/10(100.00%) 

Saved at: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\54leverage_3th_.csv
Test Accuracy: 49.24313485882465
Saved at: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\54leverage_mid.csv
Saved at: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\54leverage_relres.csv
Test Accuracy:49.24
Saved at: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\Final_results.csv
Loading CSV file: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\2leverage_1th_.csv
Loading CSV file: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\2leverage_2th_.csv
Loading CSV file: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\MLP\DEQ8191641\2leverage_3th_.csv
Loading CSV file: C:\Users\Scent\OneDrive\PhotonicEncoder_data\covtype\leverage_variable\PM\ML