In [None]:
import torch
from models import (
    NoSkipConnectionSupervenientFeatureNetwork,
    SkipConnectionSupervenientFeatureNetwork,
    DownwardSmileMIEstimator
)
from datasets import BitStringDataset, ECoGDataset
import lovely_tensors as lt
import wandb
import tqdm
from trainers import train_feature_network

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")



In [None]:

# Train a no-skip-connection supervenient feature network on the bit-string
# data once per seed, logging each run to the "no-skip-connection-MIs" project.

# FIXME(review): this cell never builds its own DataLoader. `trainloader` is only
# created in the ECoG cell further down, and that data has 64 atoms rather than
# the 6 configured here. A fresh "Restart & Run All" therefore fails, and an
# out-of-order run would silently train on the wrong data. Build a DataLoader
# over BitStringDataset (imported in the first cell) before running this cell.
_bits_trainloader = globals().get("trainloader")
if _bits_trainloader is None or not isinstance(
    getattr(_bits_trainloader, "dataset", None), BitStringDataset
):
    raise RuntimeError(
        "Define `trainloader` as a DataLoader over BitStringDataset before "
        "running this cell."
    )

for seed in range(2):

    # Seed torch so each run is reproducible and really corresponds to the
    # `torch_seed` value recorded in its config (mirrors the ECoG cell).
    torch.manual_seed(seed)

    config = {
        "torch_seed": seed,
        "dataset_type": "bits",
        "num_atoms": 6,
        "batch_size": 1000,
        "train_mode": False,
        "train_model_B": False,
        "adjust_Psi": False,
        "clip": 5,
        "feature_size": 1,
        "epochs": 4,
        "start_updating_f_after": 0,
        # NOTE(review): confirm train_feature_network treats 0 safely here
        # (e.g. it is not used as a modulus, which would divide by zero).
        "update_f_every_N_steps": 0,
        "minimize_neg_terms_until": 0,
        "downward_critics_config": {
            "hidden_sizes_v_critic": [512, 1024, 1024, 512],
            "hidden_sizes_xi_critic": [512, 512, 512],
            "critic_output_size": 32,
            "lr": 1e-3,
            "bias": True,
            "weight_decay": 0,
        },
        "decoupled_critic_config": {
            "hidden_sizes_encoder_1": [512, 512, 512],
            "hidden_sizes_encoder_2": [512, 512, 512],
            "critic_output_size": 32,
            "lr": 1e-3,
            "bias": True,
            "weight_decay": 0,
        },
        "feature_network_config": {
            "hidden_sizes": [256, 256, 256, 256, 256],
            "lr": 1e-4,
            "bias": True,
            "weight_decay": 1e-3,
        }
    }

    no_skip_model = NoSkipConnectionSupervenientFeatureNetwork(
        num_atoms=config['num_atoms'],
        feature_size=config['feature_size'],
        hidden_sizes=config['feature_network_config']['hidden_sizes'],
        include_bias=config['feature_network_config']['bias'],
    ).to(device)

    project_name = "no-skip-connection-MIs"

    no_skip_model = train_feature_network(
        config=config,
        trainloader=trainloader,
        feature_network_training=no_skip_model,
        project_name=project_name
    )

In [None]:
import torch
from models import SkipConnectionSupervenientFeatureNetwork
from datasets import ECoGDataset
from trainers import train_feature_network

# Train the skip-connection supervenient feature network on ECoG data,
# warm-starting from a previously saved checkpoint, once per seed.

# Resolve the device here so this cell also runs on CPU-only machines. A
# hard-coded 'cuda' would override the availability check in the first cell.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Single source of truth for the batch size, shared by the DataLoader and the
# logged config so the two cannot drift apart.
BATCH_SIZE = 1000

# NOTE(review): absolute, machine-specific path; consider making it configurable.
model_A_path = "/vol/bitbucket/dm2223/info-theory-experiments/models/feature_network_earthy-sun-6.pth"
project_name = "verifying-EFs-ecog"

dataset = ECoGDataset()

trainloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

for seed in range(1):

    torch.manual_seed(seed)

    config = {
        "torch_seed": seed,
        "dataset_type": "ecog",
        "num_atoms": 64,
        "batch_size": BATCH_SIZE,
        "train_mode": False,
        "train_model_B": False,
        "adjust_Psi": True,
        "clip": 5,
        "feature_size": 3,
        "epochs": 5,
        "start_updating_f_after": 500,
        "update_f_every_N_steps": 5,
        "minimize_neg_terms_until": 9999999999,
        "downward_critics_config": {
            "hidden_sizes_v_critic": [512, 1024, 1024, 512],
            "hidden_sizes_xi_critic": [512, 512, 512],
            "critic_output_size": 32,
            "lr": 1e-3,
            "bias": True,
            "weight_decay": 0,
        },
        "decoupled_critic_config": {
            "hidden_sizes_encoder_1": [512, 512, 512],
            "hidden_sizes_encoder_2": [512, 512, 512],
            "critic_output_size": 32,
            "lr": 1e-3,
            "bias": True,
            "weight_decay": 0,
        },
        "feature_network_config": {
            "hidden_sizes": [256, 256, 256, 256, 256],
            "lr": 1e-4,
            "bias": True,
            "weight_decay": 1e-3,
        }
    }

    skip_model_A = SkipConnectionSupervenientFeatureNetwork(
        num_atoms=config['num_atoms'],
        feature_size=config['feature_size'],
        hidden_sizes=config['feature_network_config']['hidden_sizes'],
        include_bias=config['feature_network_config']['bias'],
    ).to(device)

    # map_location remaps tensors saved on a GPU onto the current device, so
    # the checkpoint also loads on machines without CUDA.
    skip_model_A.load_state_dict(torch.load(model_A_path, map_location=device))

    skip_model = train_feature_network(
        config=config,
        trainloader=trainloader,
        feature_network_training=skip_model_A,
        project_name=project_name
    )
