# In this experiment we train an RNN on the bits dataset and show it does not learn the emergent bit

In [17]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from info_theory_experiments.custom_datasets import BitStringDataset
from torch.utils.data import DataLoader

# Define the RNN model
class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleRNN, self).__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        out, hidden = self.rnn(x, hidden)
        out = self.fc(out[:, -1, :])  # Use the last time step's output
        return out, hidden

    def init_hidden(self, batch_size):
        return torch.zeros(1, batch_size, self.hidden_size)

# Training function
def train_rnn(config, trainloader, model, criterion, optimizer, device):
    model.train()
    for epoch in range(config['epochs']):
        for batch in trainloader:
            inputs = batch[:, :-1, :].to(device)
            targets = batch[:, -1, :].to(device)

            hidden = model.init_hidden(inputs.size(0)).to(device)
            optimizer.zero_grad()
            outputs, hidden = model(inputs, hidden)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            wandb.log({"loss": loss.item()})



# train the RNN

In [22]:
# Main script
def main():
    # Configuration
    config = {
        'epochs': 5,
        'batch_size': 32,
        'learning_rate': 0.001,
        'input_size': 6,
        'hidden_size': 1,
        'output_size': 6,
        'gamma_parity': 0.9,
        'gamma_extra': 0.9,
        'length': 100000,
        'num_runs': 1
    }

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create directory for saving models
    os.makedirs('RNNs', exist_ok=True)

    for run in range(config['num_runs']):
        wandb.init(project="RNN_BitStringDataset", config=config, name=f"run_{run+1}")

        # Dataset and DataLoader
        dataset = BitStringDataset(config['gamma_parity'], config['gamma_extra'], config['length'])
        trainloader = DataLoader(dataset, batch_size=config['batch_size'], shuffle=True)

        # Model, criterion, and optimizer
        model = SimpleRNN(config['input_size'], config['hidden_size'], config['output_size']).to(device)
        criterion = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])

        # Train the model
        train_rnn(config, trainloader, model, criterion, optimizer, device)

        # Save the model
        torch.save(model.state_dict(), f"RNNs/rnn_model_run_{run+1}.pth")

        wandb.finish()


###

main()


0,1
loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
loss,0.24474


In [23]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# change this to your own RNN model path

def get_RNN_hiden_state_rep(input_batch):
    """
    Function to get the hidden state representation of an RNN model with path below
    """
    model_path = 'RNNs/rnn_model_run_1.pth' # model path should probably be an input, edit this for a different RNN
    # Load the model
    input_size = 6
    hidden_size = 1
    output_size = 6
    model = SimpleRNN(input_size, hidden_size, output_size).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # Initialize hidden state
    batch_size = input_batch.size(0)
    hidden = torch.zeros(1, batch_size, hidden_size).to(device)

    # Convert input batch to tensor if not already
    input_batch = torch.tensor(input_batch, dtype=torch.float32).to(device)

    # Forward pass through the model
    with torch.no_grad():
        _, hidden = model(input_batch.unsqueeze(1), hidden)

    output = hidden.squeeze(0).to(device)

    assert output.size() == (batch_size, hidden_size), output.size()

    return hidden.to(device)


# testing reps predictive power

to switch between which rep you are using change the two comments at wanbdb.inti and hidden state

you can choose which rep model to use from models folder

In [24]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import wandb

# Initialize wandb client
wandb.init(project='NEURIPS-predicting-next-time-step-from-emergent-rep')  #"NEURIPS-predicting-next-time-step-from-hidden-state-RNN")

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from info_theory_experiments.custom_datasets import BitStringDataset

# Load bits dataset
dataset_test = BitStringDataset(
    gamma_parity=0.99,
    gamma_extra=0.99,
    length=1000000
)

dataloader = DataLoader(dataset_test, batch_size=1000, shuffle=False)

from info_theory_experiments.models import SkipConnectionSupervenientFeatureNetwork

representation_netork = SkipConnectionSupervenientFeatureNetwork(
    num_atoms=6,
    feature_size=1,
    hidden_sizes=[256, 256]
).to(device)

# Define a simple MLP for prediction
class MLP(nn.Module):
    def __init__(self, input_size, hidden_sizes, output_size):
        super(MLP, self).__init__()
        layers = []
        current_size = input_size
        for hidden_size in hidden_sizes:
            layers.append(nn.Linear(current_size, hidden_size))
            layers.append(nn.ReLU())
            current_size = hidden_size
        layers.append(nn.Linear(current_size, output_size))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)

# Initialize the MLP
input_size = 1
hidden_sizes = [256, 512, 1024, 1024, 512, 256]
output_size = 6
mlp = MLP(input_size, hidden_sizes, output_size).to(device)

# Define optimizer and loss function
mlp_optimizer = torch.optim.Adam(
    mlp.parameters(),
    lr=1e-4,
    weight_decay=0.00001
)
loss_fn = nn.MSELoss()

# Training loop for the MLP
epochs = 10
for epoch in range(epochs):
    for batch_num, batch in enumerate(dataloader):
        x_t = batch[:, 0].to(device).float()
        x_t_plus_1 = batch[:, 1].to(device).float()
        
        # Get hidden state representation
        hidden_state = get_RNN_hiden_state_rep(x_t).detach()
        
        # Predict x_t+1 using the MLP
        mlp_optimizer.zero_grad()
        x_t_plus_1_pred = mlp(hidden_state)
        
        # Compute loss and backpropagate
        loss = loss_fn(x_t_plus_1_pred, x_t_plus_1)
        loss.backward()
        mlp_optimizer.step()
        
        # Log the loss
        wandb.log({"mlp_loss": loss.item()}, step=epoch * len(dataloader) + batch_num)

    print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.item()}")

wandb.finish()


  model.load_state_dict(torch.load(model_path, map_location=device))
  input_batch = torch.tensor(input_batch, dtype=torch.float32).to(device)
  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)


Epoch 1/10, Loss: 0.21030686795711517
Epoch 2/10, Loss: 0.21046586334705353
Epoch 3/10, Loss: 0.21048860251903534
Epoch 4/10, Loss: 0.21041585505008698
Epoch 5/10, Loss: 0.21041055023670197
Epoch 6/10, Loss: 0.2103809118270874
Epoch 7/10, Loss: 0.21038936078548431
Epoch 8/10, Loss: 0.21040891110897064
Epoch 9/10, Loss: 0.2104358822107315
Epoch 10/10, Loss: 0.2104797661304474


0,1
mlp_loss,██▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
mlp_loss,0.21048


In [6]:
from models import GeneralSmileMIEstimator
from custom_datasets import BitStringDataset
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
from models import SkipConnectionSupervenientFeatureNetwork

class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleRNN, self).__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        out, hidden = self.rnn(x, hidden)
        out = self.fc(out[:, -1, :])  # Use the last time step's output
        return out, hidden

    def init_hidden(self, batch_size):
        return torch.zeros(1, batch_size, self.hidden_size).to('cuda')

device = 'cuda'

def get_RNN_hiden_state_rep(input_batch):
    # Load the model
    model_path = 'RNNs/rnn_model_run_4.pth'
    input_size = 6
    hidden_size = 1
    output_size = 6
    model = SimpleRNN(input_size, hidden_size, output_size).to('cuda')
    model.load_state_dict(torch.load(model_path, map_location='cuda'))
    model.eval()

    # Initialize hidden state
    batch_size = input_batch.size(0)
    hidden = torch.zeros(1, batch_size, hidden_size).to('cuda')

    # Convert input batch to tensor if not already
    input_batch = torch.tensor(input_batch, dtype=torch.float32).to('cuda')

    # Forward pass through the model
    with torch.no_grad():
        _, hidden = model(input_batch.unsqueeze(1).to('cuda'), hidden)

    output = hidden.squeeze(0).to('cuda')

    assert output.size() == (batch_size, hidden_size), output.size()

    return hidden.to('cuda')


model = SkipConnectionSupervenientFeatureNetwork(
    num_atoms=6,
    feature_size=1,
    hidden_sizes=[256, 256],
    include_bias=True
).to(device)

model.load_state_dict(torch.load("models/emergent_feature_network-jolly-sea-16.pth"))


bits_dataset = BitStringDataset(
    gamma_parity=0.99,
    gamma_extra=0.99,
    length=100000
)

dataloader = DataLoader(bits_dataset, batch_size=1000, shuffle=True)

print(bits_dataset.data.size())

mi_estimator = GeneralSmileMIEstimator(
    x_dim=2,
    y_dim=6,
    critic_output_size=8,
    x_critics_hidden_sizes=[256,512, 512, 256],
    y_critics_hidden_sizes=[256,512, 512, 256],
    clip=5,
    include_bias=True,
).to('cuda')

optim = torch.optim.Adam(mi_estimator.parameters(), lr=1e-3)




  model.load_state_dict(torch.load("models/emergent_feature_network-jolly-sea-16.pth"))


torch.Size([99999, 2, 6])
MI: 0.8779662847518921
MI: 1.8352651596069336
MI: 1.8177181482315063
MI: 1.846044898033142
MI: 1.7928447723388672


In [11]:
for seed in range(7,12):
    for epoch in range(6):
        for batch in dataloader:
            torch.manual_seed(seed)
            x_t = batch[:, 0].to(device).float()
            x_t_plus_1 = batch[:, 1].to(device).float()

            x_t_rep1 = get_RNN_hiden_state_rep(x_t).detach().to('cuda').squeeze(0)
            x_t_rep2 = model(x_t).detach().to('cuda').squeeze(0)

            # x_t_rep1 = x_t[:, :5].sum(dim=1) % 2
            # x_t_rep2 = x_t[:, -1]

            rep = torch.cat([x_t_rep1.unsqueeze(1), x_t_rep2.unsqueeze(1)], dim=1).squeeze(-1)
            optim.zero_grad()

            MI = mi_estimator(rep, x_t_plus_1).to('cuda')

            loss = - MI

            loss.backward()
            optim.step()

    print("MI:", MI.item())


  model.load_state_dict(torch.load(model_path, map_location='cuda'))
  input_batch = torch.tensor(input_batch, dtype=torch.float32).to('cuda')


MI: 1.802267074584961
MI: 1.8062381744384766
MI: 1.8205976486206055
MI: 1.8456114530563354
MI: 1.8380610942840576
