In [1]:
from data_gen import generate_data
from train import train
import numpy as np
import pandas as pd

In [2]:
# Default arguments.
# Both dicts are starting points only: later cells overwrite individual
# entries and hand a copy to generate_data / train.
datagen_args = {
    'T': 1000,
    'corr': 1.0,
    'A': 2,
    'N': 10,
    'M': 1,
    'K': 8,
    'num_seeds': 2,
    'action_selection_method': 'greedy',
    'ensemble': 'sum',
    'ground_model_name': 'bitpop',
    'output': 'output/',
}

train_args = {
    'model_name': 'stomp',
    'P': 1e6,
    'M': 1,
    'L': 100,
    'n_hidden_layers': 2,
    'n_features': 2,
    'num_codebooks': 10,
    'enc2dec_ratio': 1,
    'epochs': 20,
    'learning_rate': 5e-5,
    'batch_size': 8,
    'outdir': 'output/',
    'data_dir': '',
    'seed': 0,
    'data_seed': 0,
    'use_lr_scheduler': False,
    'step_LR': 30,
    'gamma': 0.1,
    'checkpoint_interval': 100,
    # wandb identifiers are left unset by default
    'wandb_entity_name': None,
    'wandb_group_name': None,
    'wandb_job_type_name': None,
}

In [3]:
# Parameters varied across runs
corrvec = [0, 1]  # [0, 0.5, 1.0]   # agent pairwise action correlation
Nvec = np.array([10, 100])  # [1e1,1e2,1e3,1e4]   # number of agents
single_agent_capacity = 256 * 100
# number of parameters: one entry per entry of Nvec
Pvec = Nvec * single_agent_capacity

# Data-generation settings
K = 8            # state space dimension
M_sys = 1        # number of agent groups
T = 10_000       # number of samples to learn from
data_seed = 0    # seed of data generation

# Training settings
M_train = 1      # assumed number of agent groups
epochs = 1

# Push the values above into the default argument dicts (existing keys are
# overwritten in place).
datagen_args.update({
    'K': K,
    'M': M_sys,
    'T': T,
    'ground_model_name': 'bitpop',
})

train_args.update({
    'model_name': 'single',
    'M': M_train,
    'data_seed': data_seed,
    'epochs': epochs,
})

In [4]:
# Generate one bitpop dataset for every (corr, N) combination and keep the
# value returned by generate_data in the 'hash' column of df.
hashtype = 'bitpop_data'
hash_data_list = []
for corr in corrvec:
    datagen_args['corr'] = corr
    for N in Nvec:
        datagen_args['N'] = N
        # pass a shallow copy of the current settings to generate_data
        bitpop_data_hash = generate_data(dict(datagen_args))
        # the P column is filled with -1 for these data-generation rows
        hash_data_list += [(corr, N, -1, hashtype, bitpop_data_hash)]
hash_columns = ['corr', 'N', 'P', 'hashtype', 'hash']
df = pd.DataFrame(data=hash_data_list, columns=hash_columns)

bitpop
running seed 0 of 2
running seed 1 of 2
saving data_d1bac3a730
bitpop
running seed 0 of 2
running seed 1 of 2
saving data_4cace0bd77
bitpop
running seed 0 of 2
running seed 1 of 2
saving data_104df71b36
bitpop
running seed 0 of 2
running seed 1 of 2
saving data_ab133d9f57


In [5]:
# train simple
# For every (corr, N) sweep point: look up the bitpop dataset generated for
# that point, train on it with P = Pvec[nit], and record the value returned
# by train() in the 'hash' column of df.
hashtype='train_simple'
hash_data_list=[]
for corr in corrvec:
    for nit,N in enumerate(Nvec):
        # hash of the bitpop dataset that belongs to this (corr, N) pair
        train_args['data_dir']='data_' + df.loc[
            (df['corr']==corr) & (df['N']==N) & (df['hashtype']=='bitpop_data'),'hash'].values[0]
        train_args['P']=Pvec[nit]
        train_simple_hash=train(train_args.copy())
        # Record this run's own hash. bitpop_data_hash is a leftover from the
        # data-generation cell and would tag every row with the same,
        # unrelated dataset hash.
        hash_data_list.append((corr,N,Pvec[nit],hashtype,train_simple_hash))
dftmp=pd.DataFrame(hash_data_list,columns=['corr','N','P','hashtype','hash'])
# append the training rows below the existing data-generation rows
df=pd.concat((df,dftmp))

Using cpu device
seed 0 training of single model with modelsize 10 for 1 epochs using batchsize 8 and LR 5e-05
wandb run name: 4c842ab635_20240523-162941
using data:output/data_d1bac3a730/data.h5
{'file_attrs': {'A': 2, 'K': 8, 'M': 1, 'N': 10, 'T': 10000, 'action_selection_method': 'greedy', 'corr': 0, 'ensemble': 'sum', 'ground_model_name': 'bitpop', 'hash': 'd1bac3a730', 'num_seeds': 2, 'output': 'output/', 'timestamp': '20240523_162928'}}
state_dim: (10000, 8)
10
hidden_dim=0
number of parameters: 20
gap between P and num_parameters:  -10




pre training loss: 0.08664339780807495, acc: 0.48552


KeyboardInterrupt: 

In [None]:
# generate data from trained simple
# For every (corr, N) sweep point: look up the bitpop dataset hash and the
# train_simple hash recorded in df, use "<data_hash>/<train_hash>" as the
# ground model name, generate data, and record the returned hash.
hashtype = 'simple_data'
hash_data_list=[]
for corr in corrvec:
    # Keep corr / N in step with the sweep point; otherwise every run would
    # inherit whatever values the bitpop data-generation loop left behind.
    # NOTE(review): confirm generate_data expects these when the ground model
    # is a trained model.
    datagen_args['corr']=corr
    for nit,N in enumerate(Nvec):
        datagen_args['N']=N
        data_hash=df.loc[
            (df['corr']==corr) & (df['N']==N) & (df['hashtype']=='bitpop_data'),'hash'].values[0]
        train_hash=df.loc[
            (df['corr']==corr) & (df['N']==N) & (df['hashtype']=='train_simple'),'hash'].values[0]
        datagen_args['ground_model_name'] = data_hash+'/'+train_hash
        simple_data_hash=generate_data(datagen_args.copy())
        # Record the hash of the data just generated, not the leftover
        # bitpop_data_hash from the first data-generation cell.
        hash_data_list.append((corr,N,Pvec[nit],hashtype,simple_data_hash))
dftmp=pd.DataFrame(hash_data_list,columns=['corr','N','P','hashtype','hash'])
# append the new rows below the rows already collected in df
df=pd.concat((df,dftmp))

In [None]:
# store
# Bundle the argument dicts (in their current, post-sweep state) together
# with the hash table and write them to disk under a name built from the
# run settings.
data_store={}
data_store['datagen_args']=datagen_args
data_store['train_args']=train_args
data_store['hashes']=df
data_filename = f"hashlist_K_{K}_Msys_{M_sys}_T_{T}_Mtrain_{M_train}_Ep_{epochs}_dataseed_{data_seed}"
# data_store is a plain dict holding a DataFrame, so np.save stores it as a
# pickled object; reading it back requires np.load(..., allow_pickle=True).
np.save(data_filename+".npy",data_store)
# Give the CSV an explicit extension so it is not written as an
# extension-less file next to the .npy.
df.to_csv(data_filename+".csv", index=False)

In [None]:

# #train match
# train_args['model_name']='match'
# match_train_hashes = []
# for data_hash in simple_data_hashes:
#   train_args['data_dir']='data_'+data_hash
#   match_train_hashes.append(train(train_args.copy()))
# # write_hashes(match_train_hashes,'match_train')
# hashlist_dict['match_train']=match_train_hashes
# print(''.join(['\n']*10))
# np.save(hashlist_filename,hashlist_dict)

# #train mlp  (NOTE: this block is a verbatim copy of "train match" above and still sets model_name='match')
# train_args['model_name']='match'
# match_train_hashes = []
# for data_hash in simple_data_hashes:
#   train_args['data_dir']='data_'+data_hash
#   match_train_hashes.append(train(train_args.copy()))
# # write_hashes(match_train_hashes,'match_train')
# hashlist_dict['match_train']=match_train_hashes
# print(''.join(['\n']*10))
# np.save(hashlist_filename,hashlist_dict)

# #generate data from trained match
# match_data_hashes=[]
# for datahash,trainhash in zip(simple_data_hashes,match_train_hashes):
#   datagen_args['ground_model_name'] = datahash+'/'+trainhash
#   match_data_hashes.append(generate_data(datagen_args.copy()))
# write_hashes(match_data_hashes,'match_data')
# hashlist_dict['match_data']=match_data_hashes
# print(''.join(['\n']*10))

# np.save(hashlist_filename,hashlist_dict)


# def write_hashes(hash_list,hash_name,file_name=hashlist_filename):
#   with open(file_name,'a') as f:
#       f.write(hash_name)
#       for ha in hash_list:
#           f.write(ha)
# write_hashes(bitpop_data_hashes,'bitpop_data')
# write_hashes(simple_train_hashes,'simple_train')
# write_hashes(simple_data_hashes,'single_data')
