In [1]:
import torch
from models import (
    NoSkipConnectionSupervenientFeatureNetwork,
    SkipConnectionSupervenientFeatureNetwork,
    DownwardSmileMIEstimator
)
from datasets import BitStringDataset, ECoGDataset
import lovely_tensors as lt
import wandb
import tqdm
from trainers import train_feature_network





In [2]:
torch.cuda.empty_cache()

In [3]:
bits_dataset_config = {
    "num_data_points": int(1e6),
    "extra_bit_correlation": 0.99,
    "parity_bit_correlation": 0.99,
}

dataset = BitStringDataset(
    gamma_parity=bits_dataset_config['parity_bit_correlation'],
    gamma_extra=bits_dataset_config['extra_bit_correlation'],
    length=bits_dataset_config['num_data_points'],
)

trainloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=config['batch_size'],
    shuffle=True,
)

print("Dataset loaded")

NameError: name 'config' is not defined

In [None]:
for seed in range(1,6):
    torch.manual_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # clear memory
    torch.cuda.empty_cache()

    config = {
        "torch_seed": seed,
        "dataset_type": "bits",
        "num_atoms": 6,
        "batch_size": 1000,
        "train_mode": False,
        "train_model_B": False,
        "adjust_Psi": False,
        "clip": 5,
        "feature_size": 1,
        "epochs": 6,
        "start_updating_f_after": 0,
        "update_f_every_N_steps": 0,
        "minimize_neg_terms_until": 0,
        "downward_critics_config": {
            "hidden_sizes_v_critic": [512, 1024, 1024, 512],
            "hidden_sizes_xi_critic": [512, 512, 512],
            "critic_output_size": 32,
            "lr": 1e-3,
            "bias": True,
            "weight_decay": 0,
        },
        "decoupled_critic_config": {
            "hidden_sizes_encoder_1": [512, 512, 512],
            "hidden_sizes_encoder_2": [512, 512, 512],
            "critic_output_size": 32,
            "lr": 1e-3,
            "bias": True,
            "weight_decay": 0,
        },
        "feature_network_config": {
            "hidden_sizes": [256, 256, 256, 256, 256],
            "lr": 1e-4,
            "bias": True,
            "weight_decay": 1e-3,
        }
    }


    skip_model = SkipConnectionSupervenientFeatureNetwork(
        num_atoms=config['num_atoms'],
        feature_size=config['feature_size'],
        hidden_sizes=config['feature_network_config']['hidden_sizes'],
        include_bias=config['feature_network_config']['bias'],
    ).to(device)

    project_name = "NEURIPS-FINAL-skip-connection-MIs"

    skip_model = train_feature_network(
        config=config,
        trainloader=trainloader,
        feature_network_training=skip_model,
        project_name=project_name
    )

In [4]:
# clear cache
del skip_model

torch.cuda.empty_cache()

for seed in range(0,6):
    torch.manual_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # clear memory
    torch.cuda.empty_cache()

    config = {
        "torch_seed": seed,
        "dataset_type": "bits",
        "num_atoms": 6,
        "batch_size": 1000,
        "train_mode": False,
        "train_model_B": False,
        "adjust_Psi": False,
        "clip": 5,
        "feature_size": 1,
        "epochs": 6,
        "start_updating_f_after": 0,
        "update_f_every_N_steps": 0,
        "minimize_neg_terms_until": 0,
        "downward_critics_config": {
            "hidden_sizes_v_critic": [512, 1024, 1024, 512],
            "hidden_sizes_xi_critic": [512, 512, 512],
            "critic_output_size": 32,
            "lr": 1e-3,
            "bias": True,
            "weight_decay": 0,
        },
        "decoupled_critic_config": {
            "hidden_sizes_encoder_1": [512, 512, 512],
            "hidden_sizes_encoder_2": [512, 512, 512],
            "critic_output_size": 32,
            "lr": 1e-3,
            "bias": True,
            "weight_decay": 0,
        },
        "feature_network_config": {
            "hidden_sizes": [256, 256, 256, 256, 256],
            "lr": 1e-4,
            "bias": True,
            "weight_decay": 1e-3,
        }
    }


    no_skip_model = NoSkipConnectionSupervenientFeatureNetwork(
        num_atoms=config['num_atoms'],
        feature_size=config['feature_size'],
        hidden_sizes=config['feature_network_config']['hidden_sizes'],
        include_bias=config['feature_network_config']['bias'],
    ).to(device)

    project_name = "NEURIPS-FINAL-NO-skip-connection-MIs"

    no_skip_model = train_feature_network(
        config=config,
        trainloader=trainloader,
        feature_network_training=no_skip_model,
        project_name=project_name
    )

NameError: name 'skip_model' is not defined

In [None]:
# train some models on ecog data

from datasets import ECoGDataset
import torch
from models import SkipConnectionSupervenientFeatureNetwork
from trainers import train_feature_network

dataset = ECoGDataset()

device = 'cuda'

trainloader = torch.utils.data.DataLoader(dataset, batch_size=1000, shuffle=True)

project_name = "learning-emergent-ecog-features-with-infomin"

for seed in range(5,8):

    config = {
        "torch_seed": seed,
        "dataset_type": "ecog",
        "num_atoms": 64,
        "batch_size": 1000,
        "train_mode": True,
        "train_model_B": False,
        "adjust_Psi": False,
        "clip": 5,
        "feature_size": 3,
        "epochs": 70,
        "start_updating_f_after": 500,
        "update_f_every_N_steps": 5,
        "minimize_neg_terms_until": 9999999999,
        "downward_critics_config": {
            "hidden_sizes_v_critic": [512, 1024, 1024, 512],
            "hidden_sizes_xi_critic": [512, 512, 512],
            "critic_output_size": 32,
            "lr": 1e-3,
            "bias": True,
            "weight_decay": 0,
        },
        "decoupled_critic_config": {
            "hidden_sizes_encoder_1": [512, 512, 512],
            "hidden_sizes_encoder_2": [512, 512, 512],
            "critic_output_size": 32,
            "lr": 1e-3,
            "bias": True,
            "weight_decay": 0,
        },
        "feature_network_config": {
            "hidden_sizes": [256, 256, 256, 256, 256],
            "lr": 1e-4,
            "bias": True,
            "weight_decay": 1e-3,
        }
    }


    torch.manual_seed(seed)

    model = SkipConnectionSupervenientFeatureNetwork(
        num_atoms=config['num_atoms'],
        feature_size=config['feature_size'],
        hidden_sizes=config['feature_network_config']['hidden_sizes'],
        include_bias=config['feature_network_config']['bias'],
    ).to(device)

    model = train_feature_network(
        config=config,
        trainloader=trainloader,
        feature_network_training=model,
        project_name=project_name
    )





    
