In [1]:
import numpy as np
from tqdm import tqdm
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def generate_weights(ntasks, alpha):
    # Calculate the unnormalized weights for each task
    unnormalized_weights = [i ** (-alpha) for i in range(1, ntasks + 1)]
    
    # Calculate the normalization factor (sum of unnormalized weights)
    normalization_factor = sum(unnormalized_weights)
    
    # Normalize the weights so that they sum to 1
    normalized_weights = [weight / normalization_factor for weight in unnormalized_weights]
    
    return normalized_weights

def generate_multitask_sparse_parity(n, k, ntasks, num_samples, control_bit_probs=None):
    # Generate ntasks random subsets Si of k indices from {1,2,...,n}
    subsets = [np.random.choice(range(n), k, replace=False) for _ in range(ntasks)]

    # Set uniform distribution if no control_bit_probs provided
    if control_bit_probs is None:
        control_bit_probs = [1/ntasks] * ntasks

    # Generate random dataset
    task_bits = np.random.randint(2, size=(num_samples, n))
    task_nums = np.random.choice(ntasks, size=num_samples, p=control_bit_probs)
    control_bits = np.zeros((num_samples, ntasks), dtype=int)
    control_bits[np.arange(num_samples), task_nums] = 1

    # Compute the sparse parity for the active task
    sparse_parities = np.array([np.sum(task_bits[i, subsets[task_num]]) % 2 for i, task_num in enumerate(task_nums)])

    # Combine control bits, task bits, and sparse parity as input-output pairs
    input_bits = np.hstack((control_bits, task_bits))

    return input_bits, sparse_parities


In [3]:
n = 100
ntasks = 500
k = 3
alpha = 1.4
probs = generate_weights(ntasks,alpha)

In [4]:
from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn
import torch.optim as optim
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss()

In [5]:
# Define the neural network model
class ReLUMLP(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim):
        super(ReLUMLP, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.layers(x)

In [6]:
def eval_subtasks(model, task_subsets, loss_func, samples_per_task = 50):
  n_tasks = task_subsets.shape[0]
  n = task_subsets.shape[1]

  task_bits = np.random.randint(2, size = (samples_per_task * n_tasks, n))
  output_bits = np.zeros(samples_per_task * n_tasks, dtype = int)
  # create outputs for every subtask
  for i in range(n_tasks):
    input_slice = task_bits[i * samples_per_task : (i + 1) * samples_per_task]
    outputs = np.sum(input_slice * task_subsets[i], axis = 1) % 2
    output_bits[i * samples_per_task:(i + 1) * samples_per_task] = outputs
  input_bits = np.eye(n_tasks)[np.repeat(np.arange(n_tasks), samples_per_task)]
  # print(input_bits.shape)

  total_input = np.concatenate([input_bits, task_bits], axis = 1)
  # calculate the loss for each (in bits)
  mod_outputs = model(torch.from_numpy(total_input).float().to(device))
  loss = loss_func(mod_outputs, torch.from_numpy(output_bits).long().to(device))
  # print(loss.shape)
  # return loss in bits for every subtask
  loss = loss.cpu()

  losses = loss.reshape(-1, samples_per_task).mean(axis = 1) / np.log(2)
  # print(losses.shape)
  return losses.detach().cpu().numpy()

In [7]:
def generate_task_subsets(num_tasks, num_bits, task_sizes, random_state = 0):
  np.random.seed(random_state)
  task_subsets = np.zeros((num_tasks, num_bits))
  for i in range(num_tasks):
    task_subsets[i][np.random.choice(num_bits, size = task_sizes[i], replace = False)] = 1
  return task_subsets

In [8]:
num_iterations = 10000
batch_size = 20000
import gc

task_subsets = generate_task_subsets(ntasks, n, np.ones(ntasks, int) * 3)
loss_func = nn.CrossEntropyLoss(reduction = 'none')

def train_online_model(hidden_dim):
    # Set up the model, optimizer
    input_dim = n + ntasks
    output_dim = 2
    model = ReLUMLP(input_dim, output_dim, hidden_dim).to(device)
    optimizer = optim.Adam(model.parameters(),lr=1e-3)

    losses = np.zeros(num_iterations)
    subtask_losses_across_training = np.zeros((num_iterations // 1000, ntasks))


    # Train the model for num_iterations iterations
    for i in tqdm(range(num_iterations), desc="Training"):
        # Generate a new dataset for the current iteration
        input_data, output_data = generate_multitask_sparse_parity(n, k, ntasks, batch_size, probs)

        # Convert to PyTorch tensors
        input_data_tensor = torch.tensor(input_data, dtype=torch.float32).to(device)
        output_data_tensor = torch.tensor(output_data, dtype=torch.long).to(device)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(input_data_tensor)

        # Calculate the loss
        loss = criterion(outputs, output_data_tensor)

        # Backward pass
        loss.backward()

        # Optimize
        optimizer.step()

        # Store the loss
        losses[i] = loss.item()

        if i % 1000 == 0:
            print("Iteration {}: Loss = {}".format(i, loss.item()))
            # eval on subtasks
            subtask_losses = eval_subtasks(model, task_subsets, loss_func)
            subtask_losses_across_training[i // 1000] = subtask_losses

    del model
    torch.cuda.empty_cache()
    gc.collect()
    return losses, subtask_losses_across_training


In [10]:
losses_10 = train_online_model(10)

Training:   0%|          | 1/10000 [00:00<33:03,  5.04it/s]

Iteration 0: Loss = 0.6965048313140869


Training:  10%|█         | 1001/10000 [02:23<24:03,  6.23it/s]

Iteration 1000: Loss = 0.6931250691413879


Training:  20%|██        | 2001/10000 [04:49<21:27,  6.21it/s]

Iteration 2000: Loss = 0.6931466460227966


Training:  30%|███       | 3001/10000 [07:36<19:16,  6.05it/s]

Iteration 3000: Loss = 0.6931455135345459


Training:  40%|████      | 4001/10000 [10:03<16:34,  6.03it/s]

Iteration 4000: Loss = 0.6931328773498535


Training:  50%|█████     | 5001/10000 [12:42<14:01,  5.94it/s]

Iteration 5000: Loss = 0.6931437253952026


Training:  60%|██████    | 6001/10000 [15:29<15:42,  4.24it/s]

Iteration 6000: Loss = 0.6931556463241577


Training:  70%|███████   | 7001/10000 [18:31<14:22,  3.48it/s]

Iteration 7000: Loss = 0.6931427717208862


Training:  80%|████████  | 8001/10000 [21:11<05:36,  5.93it/s]

Iteration 8000: Loss = 0.693150520324707


Training:  90%|█████████ | 9001/10000 [24:01<03:00,  5.54it/s]

Iteration 9000: Loss = 0.6931452751159668


Training: 100%|██████████| 10000/10000 [26:36<00:00,  6.27it/s]
