In [14]:
import torch
from bit_dataset import BitStringDataset
import lovely_tensors as lt
import wandb
import tqdm
from einops import rearrange, reduce, repeat
from models import (SupervenientFeatureNetwork,
                    CLUB,
                    DecoupledSmileMIEstimator,
                    DownwardSmileMIEstimator,
                    GeneralSmileMIEstimator,
                    SkipConnectionSupervenientFeatureNetwork
                    )
from trainers import train_feature_network

# Enable lovely_tensors' compact tensor reprs for every tensor displayed below.
lt.monkey_patch()

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Release unoccupied memory held by PyTorch's CUDA caching allocator (helpful
# when re-running the notebook in the same kernel); no-op if CUDA is not initialised.
torch.cuda.empty_cache()

# Global seed for reproducibility; torch.manual_seed seeds the CPU and all CUDA
# devices. The trailing semicolon stops the cell from echoing the returned
# Generator object as output.
SEED = 42
torch.manual_seed(SEED);



<torch._C.Generator at 0x74f53c5aedd0>

In [None]:
# NOTE(review): superseded scratch configuration (flat schema). The nested
# `config` dict in the following cell is the one the rest of the notebook uses
# (e.g. `config['batch_size']` when building the DataLoader). This dict was
# previously also bound to `config`, so running this cell out of order would
# silently replace the nested config. It now lives under its own name.
# Several values differ from the nested version, e.g.
# update_f_every_N_steps=10 (vs 5) and feature_hidden_sizes=[256] (vs [256, 256]).
legacy_flat_config = {
    "epochs": 20,
    "batch_size": 1000,
    "num_atoms": 6,
    "feature_size": 1,
    "clip": 5,
    "critic_output_size": 32,
    "downward_hidden_sizes_v_critic": [512, 512, 512, 256],
    "downward_hidden_sizes_xi_critic": [512, 512, 512, 256],
    "feature_hidden_sizes": [256],
    "decoupled_critic_hidden_sizes_1": [512, 512, 512],
    "decoupled_critic_hidden_sizes_2": [512, 512, 512],
    "feature_lr": 1e-5,
    "decoupled_critic_lr": 1e-4,
    "downward_lr": 1e-4,
    "bias": True,
    "update_f_every_N_steps": 10,
    "weight_decay": 0,
    "start_updating_f_after": 1000,
}





In [15]:
# Parameters for the synthetic BitStringDataset built further down:
# `num_data_points` becomes its `length`, and the two correlation values are
# passed as `gamma_extra` / `gamma_parity` (presumably probabilities in [0, 1];
# TODO confirm against bit_dataset).
bits_dataset_config = dict(
    num_data_points=1_000_000,
    extra_bit_correlation=0.99,
    parity_bit_correlation=0.99,
)


# Experiment configuration. Top-level entries are run-wide settings
# (`batch_size` sizes the DataLoader batches); each nested dict holds the
# architecture and optimiser settings for one component: the downward critics,
# the decoupled critic, and the supervenient feature network.
# NOTE(review): the consumers of `clip`, `training_mode`, `train_model_B` and
# `adjust_Psi` are not visible here (presumably train_feature_network) - confirm.
config = dict(
    dataset_type="bits",
    batch_size=1000,
    training_mode=True,
    train_model_B=False,
    adjust_Psi=False,
    clip=5,
    epochs=20,
    downward_critics_config=dict(
        hidden_sizes_v_critic=[512, 512, 512, 256],
        hidden_sizes_xi_critic=[512, 512, 512, 256],
        critic_output_size=32,
        lr=1e-4,
        include_bias=True,
        weight_decay=0,
    ),
    decoupled_critic_config=dict(
        hidden_sizes_encoder_1=[512, 512, 512],
        hidden_sizes_encoder_2=[512, 512, 512],
        critic_output_size=32,
        lr=1e-4,
        include_bias=True,
        weight_decay=0,
    ),
    feature_network_config=dict(
        num_atoms=6,
        feature_size=1,
        hidden_sizes=[256, 256],
        lr=1e-5,
        include_bias=True,
        weight_decay=0,
        update_f_every_N_steps=5,
        start_updating_f_after=1000,
    ),
)



In [16]:


# Synthetic bit-string dataset built from the parameters in `bits_dataset_config`.
dataset = BitStringDataset(
    length=bits_dataset_config["num_data_points"],
    gamma_extra=bits_dataset_config["extra_bit_correlation"],
    gamma_parity=bits_dataset_config["parity_bit_correlation"],
)

# Mini-batches of `config["batch_size"]` samples, reshuffled every epoch. The
# shuffle order is derived from torch's global RNG, which the setup cell seeds.
trainloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=config["batch_size"],
    shuffle=True,
)

In [None]:


SkipConnectionSupervenientFeatureNetwork(
    num_atoms=config['feature_netork_config']['num_atoms'],
    feature_size=config['feature_netork_config']['feature_size'],
    hidden_sizes=config['feature_netork_config']['hidden_sizes'],
    include_bias=config[