In [None]:
import random

import lovely_tensors as lt
import numpy as np
import torch
import tqdm
import wandb

from datasets import BitStringDataset, ECoGDataset
from models import (
    NoSkipConnectionSupervenientFeatureNetwork,
    SkipConnectionSupervenientFeatureNetwork,
    DownwardSmileMIEstimator
)
from trainers import train_feature_network

# --- Reproducibility -------------------------------------------------------
# Seed every RNG the notebook may draw from. Previously only torch was seeded.
# NumPy's legacy global RNG and Python's `random` were left unseeded, so any
# project code (datasets / models / trainers) using them would make runs
# non-reproducible.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Prefer the GPU when one is visible; every model below is moved to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Run configuration -----------------------------------------------------
# Passed as-is to `train_feature_network` in the training cell. The nested
# dicts hold per-network settings (critics and the feature network).
# NOTE(review): the meaning of flags such as `train_mode`, `adjust_Psi` and
# `clip` is defined inside the trainer, which is not visible here; confirm
# there before changing them.
config = {
    "torch_seed": seed,
    "dataset_type": "bits",
    "num_atoms": 6,
    "batch_size": 1000,
    "train_mode": False,
    "train_model_B": False,
    "adjust_Psi": False,
    "clip": 5,
    "feature_size": 1,
    "epochs": 5,
    "start_updating_f_after": 0,
    "update_f_every_N_steps": 0,
    "minimize_neg_terms_until": 0,
    "downward_critics_config": {
        "hidden_sizes_v_critic": [512, 1024, 1024, 512],
        "hidden_sizes_xi_critic": [512, 512, 512],
        "critic_output_size": 32,
        "lr": 1e-3,
        "bias": True,
        "weight_decay": 0,
    },
    "decoupled_critic_config": {
        "hidden_sizes_encoder_1": [512, 512, 512],
        "hidden_sizes_encoder_2": [512, 512, 512],
        "critic_output_size": 32,
        "lr": 1e-3,
        "bias": True,
        "weight_decay": 0,
    },
    "feature_network_config": {
        "hidden_sizes": [256, 256, 256, 256, 256],
        "lr": 1e-4,
        "bias": True,
        "weight_decay": 1e-3,
    }
}




In [None]:
import os
import tqdm

In [3]:
import torch

In [None]:
import numpy as np

In [None]:
# Progress banner printed before the dataset-construction cell runs.
print("Loading dataset", end="\n")

In [None]:
# Parameters of the synthetic bit-string dataset. Each key is forwarded to the
# matching BitStringDataset constructor argument below:
#   num_data_points        -> length
#   extra_bit_correlation  -> gamma_extra
#   parity_bit_correlation -> gamma_parity
bits_dataset_config = dict(
    num_data_points=int(1e6),
    extra_bit_correlation=0.99,
    parity_bit_correlation=0.99,
)

dataset = BitStringDataset(
    length=bits_dataset_config["num_data_points"],
    gamma_extra=bits_dataset_config["extra_bit_correlation"],
    gamma_parity=bits_dataset_config["parity_bit_correlation"],
)

# Shuffled mini-batches of `config["batch_size"]` samples for the trainer.
trainloader = torch.utils.data.DataLoader(
    dataset,
    shuffle=True,
    batch_size=config["batch_size"],
)

print("Dataset loaded")

In [None]:
# The two feature-network variants are built from identical arguments and
# differ only in class. The construction order (skip first, then no-skip) is
# kept, so any random initialisation happens in the same sequence as before.
# NOTE(review): only `no_skip_model` is trained in the cells shown here.
_feature_net_cfg = config["feature_network_config"]
_shared_feature_net_kwargs = dict(
    num_atoms=config["num_atoms"],
    feature_size=config["feature_size"],
    hidden_sizes=_feature_net_cfg["hidden_sizes"],
    include_bias=_feature_net_cfg["bias"],
)

skip_model = SkipConnectionSupervenientFeatureNetwork(**_shared_feature_net_kwargs).to(device)
no_skip_model = NoSkipConnectionSupervenientFeatureNetwork(**_shared_feature_net_kwargs).to(device)




In [None]:



# Train the no-skip-connection feature network and rebind `no_skip_model` to
# whatever `train_feature_network` returns.
# NOTE(review): re-running this cell alone passes the already-returned network
# back into the trainer. Re-run the model-construction cell first for a clean
# run.
project_name = "no-skip-connection-MIs"

no_skip_model = train_feature_network(
    project_name=project_name,
    config=config,
    feature_network_training=no_skip_model,
    trainloader=trainloader,
)