In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from custom_datasets import BitStringDataset
from torch.utils.data import DataLoader

# Define the RNN model
class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleRNN, self).__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        out, hidden = self.rnn(x, hidden)
        out = self.fc(out[:, -1, :])  # Use the last time step's output
        return out, hidden

    def init_hidden(self, batch_size):
        return torch.zeros(1, batch_size, self.hidden_size)

# Training function
def train_rnn(config, trainloader, model, criterion, optimizer, device):
    model.train()
    for epoch in range(config['epochs']):
        for batch in trainloader:
            inputs = batch[:, :-1, :].to(device)
            targets = batch[:, -1, :].to(device)

            hidden = model.init_hidden(inputs.size(0)).to(device)
            optimizer.zero_grad()
            outputs, hidden = model(inputs, hidden)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            wandb.log({"loss": loss.item()})



In [2]:
# # Main script
# def main():
#     # Configuration
#     config = {
#         'epochs': 5,
#         'batch_size': 32,
#         'learning_rate': 0.001,
#         'input_size': 6,
#         'hidden_size': 1,
#         'output_size': 6,
#         'gamma_parity': 0.9,
#         'gamma_extra': 0.9,
#         'length': 100000,
#         'num_runs': 5
#     }

#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#     # Create directory for saving models
#     os.makedirs('RNNs', exist_ok=True)

#     for run in range(config['num_runs']):
#         wandb.init(project="RNN_BitStringDataset", config=config, name=f"run_{run+1}")

#         # Dataset and DataLoader
#         dataset = BitStringDataset(config['gamma_parity'], config['gamma_extra'], config['length'])
#         trainloader = DataLoader(dataset, batch_size=config['batch_size'], shuffle=True)

#         # Model, criterion, and optimizer
#         model = SimpleRNN(config['input_size'], config['hidden_size'], config['output_size']).to(device)
#         criterion = nn.MSELoss()
#         optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])

#         # Train the model
#         train_rnn(config, trainloader, model, criterion, optimizer, device)

#         # Save the model
#         torch.save(model.state_dict(), f"RNNs/rnn_model_run_{run+1}.pth")

#         wandb.finish()



# ###

# main()


In [3]:
import torch

def get_RNN_hiden_state_rep(input_batch):
    # Load the model
    model_path = 'RNNs/rnn_model_run_3.pth'
    input_size = 6
    hidden_size = 1
    output_size = 6
    model = SimpleRNN(input_size, hidden_size, output_size).to('cuda')
    model.load_state_dict(torch.load(model_path, map_location='cuda'))
    model.eval()

    # Initialize hidden state
    batch_size = input_batch.size(0)
    hidden = torch.zeros(1, batch_size, hidden_size).to('cuda')

    # Convert input batch to tensor if not already
    input_batch = torch.tensor(input_batch, dtype=torch.float32).to('cuda')

    # Forward pass through the model
    with torch.no_grad():
        _, hidden = model(input_batch.unsqueeze(1), hidden)

    output = hidden.squeeze(0).to('cuda')

    assert output.size() == (batch_size, hidden_size), output.size()

    return hidden.to('cuda')


In [4]:
# for batch in dataloader:
#     print(batch[:, 0, :].size())
#     print(get_RNN_hiden_state_rep(batch[:, 0, :]).squeeze(0).detach().size())
#     break

In [20]:
seed = 4

bits_config_test = {
    "gamma_parity": 0.99,
    "gamma_extra": 0.99,
    "dataset_length": 1000000,
    "torch_seed": seed,
    "dataset_type": "bits",
    "num_atoms": 6,
    "batch_size": 1000,
    "train_mode": False,
    "train_model_B": False,
    "adjust_Psi": False,
    "clip": 5,
    "feature_size": 1,
    "epochs": 2,
    "start_updating_f_after": 1000,
    "update_f_every_N_steps": 5,
    "minimize_neg_terms_until": 0,
    "downward_critics_config": {
        "hidden_sizes_v_critic": [256, 256, 256],
        "hidden_sizes_xi_critic": [256, 256, 256],
        "critic_output_size": 32,
        "lr": 1e-3,
        "bias": True,
        "weight_decay": 0,
    },
    
    "decoupled_critic_config": {
        "hidden_sizes_encoder_1": [256, 256],
        "hidden_sizes_encoder_2": [256, 256],
        "critic_output_size": 32,
        "lr": 1e-3,
        "bias": True,
        "weight_decay": 0,
    },
    "feature_network_config": {
        "hidden_sizes": [256, 256],
        "lr": 1e-4,
        "bias": True,
        "weight_decay": 0.00001,
    }
}

from custom_datasets import BitStringDataset

# Create a dataset
dataset = BitStringDataset(
    gamma_parity=bits_config_test["gamma_parity"],
    gamma_extra=bits_config_test["gamma_extra"],
    length=bits_config_test["dataset_length"]
)



dataloader = DataLoader(dataset, batch_size=bits_config_test["batch_size"], shuffle=False)

project_name_test = "NEURIPS-what-bits-do-hidden-state-encode-RNN"

from trainers import train_feature_network

out = train_feature_network(
    config=bits_config_test,
    trainloader=dataloader,
    feature_network_training=get_RNN_hiden_state_rep,
    project_name=project_name_test
)

  model.load_state_dict(torch.load(model_path, map_location='cuda'))
  input_batch = torch.tensor(input_batch, dtype=torch.float32).to('cuda')


feature size: torch.Size([1000, 1])
feature size: torch.Size([1000, 1])
torch.Size([1000, 1000])
 scpres size; torch.Size([1000, 1000])
torch.Size([1000, 1000])
 scpres size; torch.Size([1000, 1000])
torch.Size([1000, 1000])
 scpres size; torch.Size([1000, 1000])
torch.Size([1000, 1000])
 scpres size; torch.Size([1000, 1000])
torch.Size([1000, 1000])
 scpres size; torch.Size([1000, 1000])
torch.Size([1000, 1000])
 scpres size; torch.Size([1000, 1000])
torch.Size([1000, 1000])
 scpres size; torch.Size([1000, 1000])
torch.Size([1000, 1000])
 scpres size; torch.Size([1000, 1000])
torch.Size([1000, 1000])
 scpres size; torch.Size([1000, 1000])
torch.Size([1000, 1000])
 scpres size; torch.Size([1000, 1000])
torch.Size([1000, 1000])
 scpres size; torch.Size([1000, 1000])
torch.Size([1000, 1000])
 scpres size; torch.Size([1000, 1000])
torch.Size([1000, 1000])
 scpres size; torch.Size([1000, 1000])
torch.Size([1000, 1000])
 scpres size; torch.Size([1000, 1000])
torch.Size([1000, 1000])
 scpres

Training:  50%|█████     | 1/2 [00:44<00:44, 44.13s/it]

 scpres size; torch.Size([999, 999])
torch.Size([999, 999])
 scpres size; torch.Size([999, 999])
torch.Size([999, 999])
 scpres size; torch.Size([999, 999])
torch.Size([999, 999])
 scpres size; torch.Size([999, 999])
torch.Size([999, 999])
 scpres size; torch.Size([999, 999])
torch.Size([999, 999])
 scpres size; torch.Size([999, 999])
torch.Size([999, 999])
 scpres size; torch.Size([999, 999])
torch.Size([999, 999])
 scpres size; torch.Size([999, 999])
torch.Size([999, 999])
 scpres size; torch.Size([999, 999])
torch.Size([999, 999])
 scpres size; torch.Size([999, 999])
torch.Size([999, 999])
 scpres size; torch.Size([999, 999])
torch.Size([999, 999])
 scpres size; torch.Size([999, 999])
feature size: torch.Size([1000, 1])
feature size: torch.Size([1000, 1])
torch.Size([1000, 1000])
 scpres size; torch.Size([1000, 1000])
torch.Size([1000, 1000])
 scpres size; torch.Size([1000, 1000])
torch.Size([1000, 1000])
 scpres size; torch.Size([1000, 1000])
torch.Size([1000, 1000])
 scpres size; 

Training: 100%|██████████| 2/2 [01:28<00:00, 44.24s/it]

torch.Size([1000, 1000])
 scpres size; torch.Size([1000, 1000])
torch.Size([1000, 1000])
 scpres size; torch.Size([1000, 1000])
torch.Size([1000, 1000])
 scpres size; torch.Size([1000, 1000])
torch.Size([1000, 1000])
 scpres size; torch.Size([1000, 1000])
torch.Size([1000, 1000])
 scpres size; torch.Size([1000, 1000])
torch.Size([1000, 1000])
 scpres size; torch.Size([1000, 1000])
torch.Size([1000, 1000])
 scpres size; torch.Size([1000, 1000])
torch.Size([1000, 1000])
 scpres size; torch.Size([1000, 1000])
torch.Size([1000, 1000])
 scpres size; torch.Size([1000, 1000])
torch.Size([1000, 1000])
 scpres size; torch.Size([1000, 1000])
torch.Size([1000, 1000])
 scpres size; torch.Size([1000, 1000])
torch.Size([1000, 1000])
 scpres size; torch.Size([1000, 1000])
torch.Size([1000, 1000])
 scpres size; torch.Size([1000, 1000])
torch.Size([1000, 1000])
 scpres size; torch.Size([1000, 1000])
feature size: torch.Size([1000, 1])
feature size: torch.Size([1000, 1])
torch.Size([1000, 1000])
 scpres




0,1
Psi,▆▁▅█▄█▅▆▅▄▃▅▅▅▅▅▃▅▅▅▅▅▅▅▅▅▄▄▅▅▅▅▅▅▅▅▅▅▅▅
bonus_bit_MI,█▆▁▃▅▅▅▅▅▅▅▄▅▆▅▅▅▅▅▅▆▅▇▅▆▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅
decoupled_MI,▆▁▇▅▆▆▃▇▆▇▆██▄▅█▄▆▄█▇▇▆▆▇▇▅▅█▇▆▇▇▇██▁▆▇▅
downward_MI_0,▃▂█▂▆▁▂▂▂▂▂▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂▂▂▂
downward_MI_1,▁▂█▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
downward_MI_2,▂▅▁▅▅▅▅█▅▄▆▅▅▅▄▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▄▄▅▅
downward_MI_3,▄▅█▅▅▄▅▁▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▆▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅
downward_MI_4,▆▅▄▅█▃▁▄▄▅▆▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅
downward_MI_5,▇▃▇▄▆▅▃▇▆▇▆██▄▅█▄▆▅█▇▇▆▆▇▇▆▅█▆▆▇▇███▁▆█▅
extra_bit_MI,▁▃▇▆▇▆▅█▇█▇██▆▆█▅▆▅█▇█▇▇██▆▆█▇▆▇▇▇██▃▇█▆

0,1
Psi,-0.00097
bonus_bit_MI,-3e-05
decoupled_MI,0.88014
downward_MI_0,-0.0
downward_MI_1,-0.0
downward_MI_2,-0.0
downward_MI_3,0.0
downward_MI_4,-0.0
downward_MI_5,0.88351
extra_bit_MI,0.95366


# testing reps predictive power

to switch between which rep you are using change the two comments at wanbdb.inti and hidden state

you can choose which rep model to use from models folder

you can choose what RNN to use by editing abot get_hidden_... function

In [18]:

device = 'cuda'
import wandb

# Initialize wandb client
wandb.init(project='NEURIPS-predicting-next-time-step-from-emergent-rep')  #"NEURIPS-predicting-next-time-step-from-hidden-state-RNN")

import torch
from torch.utils.data import DataLoader
from custom_datasets import BitStringDataset

# Load bits dataset
dataset_test = BitStringDataset(
    gamma_parity=0.99,
    gamma_extra=0.99,
    length=1000000
)

dataloader = DataLoader(dataset_test, batch_size=1000, shuffle=False)


from models import SkipConnectionSupervenientFeatureNetwork

representation_netork = SkipConnectionSupervenientFeatureNetwork(
    num_atoms=6,
    feature_size=1,
    hidden_sizes=[256, 256]
).to(device)

representation_netork.load_state_dict(torch.load("models/emergent_feature_network-dark-frog-14.pth"))


# Define a simple MLP for prediction
class MLP(nn.Module):
    def __init__(self, input_size, hidden_sizes, output_size):
        super(MLP, self).__init__()
        layers = []
        current_size = input_size
        for hidden_size in hidden_sizes:
            layers.append(nn.Linear(current_size, hidden_size))
            layers.append(nn.ReLU())
            current_size = hidden_size
        layers.append(nn.Linear(current_size, output_size))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)

# Initialize the MLP
input_size = 1
hidden_sizes = [256, 512, 1024, 1024, 512, 256]
output_size = 6
mlp = MLP(input_size, hidden_sizes, output_size).to(device)

# Define optimizer and loss function
mlp_optimizer = torch.optim.Adam(
    mlp.parameters(),
    lr=1e-4,
    weight_decay=0.00001
)
loss_fn = nn.MSELoss()

# Training loop for the MLP
epochs = 10
for epoch in range(epochs):
    for batch_num, batch in enumerate(dataloader):
        x_t = batch[:, 0].to(device).float()
        x_t_plus_1 = batch[:, 1].to(device).float()
        
        # Get hidden state representation
        hidden_state =  representation_netork(x_t).detach() # get_RNN_hiden_state_rep(x_t).detach()
        
        # Predict x_t+1 using the MLP
        mlp_optimizer.zero_grad()
        x_t_plus_1_pred = mlp(hidden_state)
        
        # Compute loss and backpropagate
        loss = loss_fn(x_t_plus_1_pred, x_t_plus_1)
        loss.backward()
        mlp_optimizer.step()
        
        # Log the loss
        wandb.log({"mlp_loss": loss.item()}, step=epoch * len(dataloader) + batch_num)

    print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.item()}")

wandb.finish()


  representation_netork.load_state_dict(torch.load("models/emergent_feature_network-dark-frog-14.pth"))


Epoch 1/10, Loss: 0.2518981397151947
Epoch 2/10, Loss: 0.25208085775375366
Epoch 3/10, Loss: 0.25221505761146545
Epoch 4/10, Loss: 0.2522874176502228
Epoch 5/10, Loss: 0.2521645426750183
Epoch 6/10, Loss: 0.25196224451065063
Epoch 7/10, Loss: 0.25141844153404236
Epoch 8/10, Loss: 0.2503499686717987
Epoch 9/10, Loss: 0.24888698756694794
Epoch 10/10, Loss: 0.2483784258365631


0,1
mlp_loss,▆▂▅▄▃▄█▄▄▃▆▄▅▇▃▅▅▄▄▃▃▅▃▄▃▄▃▃▂▃▁▁▃▂▃▁▃▃▂▃

0,1
mlp_loss,0.24838


The below cell shows the hidden state and the output of the MLP found in above cell for visual inspection (not super important)

In [17]:
from models import SkipConnectionSupervenientFeatureNetwork

representation_netork = SkipConnectionSupervenientFeatureNetwork(
    num_atoms=6,
    feature_size=1,
    hidden_sizes=[256, 256]
).to(device)

# load models/emergent_feature_network-scarlet-plant-12.pth

representation_netork.load_state_dict(torch.load("models/emergent_feature_network-dark-frog-14.pth"))



# Get 3 datapoints from the dataloader
data_iter = iter(dataloader)
points = []
for _ in range(3):
    batch = next(data_iter)
    x_t = batch[:3, 0].to(device).float()
    x_t_plus_1 = batch[:3, 1].to(device).float()
    points.append((x_t, x_t_plus_1))

# Print the original datapoints and their predictions
for x_t, x_t_plus_1 in points:
    print("Original x_t:", x_t)
    print("Original x_t_plus_1:", x_t_plus_1)
    
    # Get hidden state representation using the representation network
    hidden_state = representation_netork(x_t).detach()
    print("Hidden state representation:", hidden_state)
    
    # Predict x_t+1 using the MLP
    x_t_plus_1_pred = mlp(hidden_state)
    print("MLP output (predicted x_t_plus_1):", x_t_plus_1_pred)
    print("\n" + "-"*50 + "\n")


Original x_t: tensor([[0., 0., 1., 1., 0., 1.],
        [1., 1., 0., 0., 0., 1.],
        [0., 1., 0., 0., 1., 1.]], device='cuda:0')
Original x_t_plus_1: tensor([[1., 1., 0., 0., 0., 1.],
        [0., 1., 0., 0., 1., 1.],
        [0., 0., 1., 0., 1., 1.]], device='cuda:0')
Hidden state representation: tensor([[ 0.6725],
        [-0.3880],
        [-0.4227]], device='cuda:0')
MLP output (predicted x_t_plus_1): tensor([[ 0.4797,  0.5010,  0.4824,  0.4996,  0.5017,  0.2511],
        [ 0.4673,  0.4508,  0.4954,  0.4315,  0.4772, -0.4513],
        [ 0.4679,  0.4503,  0.4964,  0.4313,  0.4777, -0.4601]],
       device='cuda:0', grad_fn=<AddmmBackward0>)

--------------------------------------------------

Original x_t: tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 1., 0., 0.],
        [0., 1., 1., 0., 1., 0.]], device='cuda:0')
Original x_t_plus_1: tensor([[1., 1., 0., 1., 0., 0.],
        [0., 1., 1., 0., 1., 0.],
        [0., 1., 0., 1., 1., 1.]], device='cuda:0')
Hidden state re

  representation_netork.load_state_dict(torch.load("models/emergent_feature_network-dark-frog-14.pth"))
