# Progressive distillation on sparse parity
This is a code for the paper "Progressive distillation induces an implicit curriculum". Full code is available here: https://github.com/abhishekpanigrahi1996/ProgressiveDistillation

In [7]:
# get the relevant packages

import os
import numpy as np
import torch 
from torch.utils.data import Dataset
import random
from tqdm import tqdm


DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [8]:
# Necessary Utils for the rest of the code


def set_seed(seed: int = 42) -> None:
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")
    
"""
Model related
"""

def load_torch_ckpt(model, fckpt):
  ckpt = torch.load(fckpt, map_location=DEVICE)
  model.load_state_dict(ckpt['model_state_dict'])
  print("Loaded ckpt", fckpt)
  return model

def get_mlp(n_input, n_hidden, n_output, n_layers):
  layers = []
  for i in range(n_layers):
      if i == 0:
          layers.append(torch.nn.Linear(n_input, n_hidden))
      else:
          layers.append(torch.nn.Linear(n_hidden, n_hidden))
      layers.append(torch.nn.ReLU())
  layers.append(torch.nn.Linear(n_hidden, n_output))
  return torch.nn.Sequential(*layers)


relu = torch.nn.ReLU()

def get_logits(input_, layers, return_hidden=0):
    num_layers = len(layers) - 1
    out = input_
    hiddens = {}
    for li in range(num_layers):
        out = relu(layers[li](out))
        if return_hidden:
            hiddens[li] = out.detach().cpu().np()
    out = layers[num_layers](out)
    if return_hidden:
        return out, hiddens
    return out

def get_logits_gpt(input_, model):
    output = model(input_ids=input_)
    logits = output.logits
    return logits[:, -1, :]


In [9]:

"""
Hierarchical tree data
"""

def get_features(num_labels, d, feature_complexity, random=False, feature_coordinates=None):
    if random:
        all_features = np.random.choice(d, size=(num_labels-1, feature_complexity))
    else:
        assert (num_labels-1) * feature_complexity <= d, "Number of available components should be more"
        if feature_coordinates is None:
            all_features = [range(i, i+feature_complexity) for i in range(0, (num_labels-1)*feature_complexity, feature_complexity)]
        else:
            all_features = [feature_coordinates]
    return all_features


def boolean_data(n, d, num_labels, all_features):
    def score(train_data, all_features, label):
        all_scores = np.zeros((len(train_data),))
        
        while(label > 1):
            prod_features = all_features [label//2-1]
            score_ = (1 - 2*(label%2)) * np.prod(train_data[:, prod_features], axis=-1)
            all_scores += score_
            label = label // 2
            
        return all_scores
    
    train_x = 2*np.random.choice(2, size=(n, d))-1
    
    train_y = np.zeros((n, num_labels))
    for i in range(num_labels):
        train_y[:, i] = score(train_x, all_features, i+num_labels)
    
    return train_x, np.argmax(train_y, axis=-1).astype(np.int32)

 
class HierarchicalData(Dataset):
  def __init__(self, data, labels):
      self.inputs = data
      self.labels = labels
      self.n_examples = len(data)
  
  def __getitem__(self, idx):
      return self.inputs[idx], self.labels[idx]

  def __len__(self):
      return self.n_examples

In [10]:


"""
Data related
"""

def get_data_loaders(cfg, seed):
    set_seed(seed)

    data_type = cfg['data']['data_type']
    num_labels = cfg['data']['num_labels']
    num_workers = cfg['data']['num_workers']
    model_type = cfg['model']['type']

    if data_type == 'hierarchical':
        data_dimension = cfg['data']['data_dimension']
        feature_complexity = cfg['data']['feature_complexity']
        randomize_features = cfg['data']['randomize_features']
        n_examples = cfg['training']['n_examples']
        batch_size = cfg['training']['batch_size']

        all_features = get_features(num_labels, data_dimension, feature_complexity, random=randomize_features)
        # here, we set the seed to make deterministic runs
        all_data, all_y = boolean_data(n_examples, data_dimension, num_labels, all_features)
        

        eval_split = num_labels * min(2048, len(all_data)//4)
        train_split = len(all_data) - eval_split
        eval_batch_size = min(1000, eval_split)
        
        train_data, train_y = all_data[:train_split], all_y[:train_split]
        train_dataset = HierarchicalData(train_data, train_y)
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                                  shuffle=True, num_workers=num_workers)
        eval_data, eval_y = all_data[train_split:], all_y[train_split:]
        eval_dataset = HierarchicalData(eval_data, eval_y)
        eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=eval_batch_size,
                                                  shuffle=False, num_workers=num_workers)
    return train_loader, eval_loader


In [11]:


"""
Training related
"""
def loss_fn(pred, target): 
    loss_ = torch.nn.CrossEntropyLoss()
    return loss_ (pred, target)

def accuracy(pred, target):
    return (torch.argmax(pred, axis=-1) == target).type(torch.float32).mean()

## Step 1: Run SFT experiments to compare models of different sizes

In [12]:
# define the training function

def train(hidden_size, 
          model_type, 
          data_dimension, 
          num_labels, 
          learning_rate, 
          weight_decay, 
          output_path, 
          seed, 
          num_layers, 
          feature_complexity, 
          randomize_features, 
          feature_coordinates, 
          n_examples, 
          n_steps, 
          n_epochs, 
          batch_size, 
          eval_batch_size, 
          log_intvl, 
          save_intvl,
          warmup_ratio,
          anneal_type
        ):


    save_suffix = f'_{model_type}_hid{hidden_size}' + f'_n{data_dimension}_k{feature_complexity}' + f'_num_labels{num_labels}' + f'_lr{learning_rate}' + f'warm{warmup_ratio}' + f'_wd{weight_decay}' + f'_seed{seed}' + f'_num_layers{num_layers}' + f'_e{n_epochs}'
    output_dir = output_path + save_suffix


    """
    Get data
    """
    set_seed(seed)
    all_features = get_features(num_labels, data_dimension, feature_complexity,
                                random=randomize_features,
                                feature_coordinates=feature_coordinates,)
    all_data, all_y = boolean_data(n_examples, data_dimension, num_labels, all_features)
  
    eval_split = num_labels * 125
    train_split = len(all_data) - eval_split
    eval_batch_size = min(1000, eval_split)
    
    train_data, train_y = all_data[:train_split], all_y[:train_split]
    train_dataset = HierarchicalData(train_data, train_y)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                               shuffle=True, num_workers=1)
    eval_data, eval_y = all_data[train_split:], all_y[train_split:]
    eval_dataset = HierarchicalData(eval_data, eval_y)
    eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=eval_batch_size,
                                              shuffle=False, num_workers=1)

    dtype = torch.float32 if model_type == 'mlp' else torch.long


    """
    Get model
    """
    all_layers, all_parameters = get_mlp(data_dimension, num_labels, hidden_size, num_layers)
    optimizer = torch.optim.SGD(all_parameters, lr=learning_rate, weight_decay=weight_decay)


    global_step_cnt = 0
    eval_accs = []
    for epoch in tqdm(range(n_epochs)):
        print("Current Epoch:", epoch)
        for bt, batch in tqdm(enumerate(train_loader)):
            batch_data, batch_y = batch

            cuda_batch_data, cuda_batch_y = batch_data.to(DEVICE).type(dtype), batch_y.to(DEVICE).long()
            
            predicted_y = all_layers(cuda_batch_data)
            
            loss = loss_fn(predicted_y, cuda_batch_y)
            loss.backward()

            # set learning rate
            curr_learning_rate = 0
            if global_step_cnt < n_steps * warmup_ratio:
                # linear warmup
                curr_learning_rate = learning_rate * (global_step_cnt+1) / (n_steps * warmup_ratio)
            else:
                if anneal_type == 'cosine':
                    curr_learning_rate = learning_rate * 0.5 * (1 + np.cos(np.pi * (global_step_cnt - n_steps * warmup_ratio) / (n_steps * (1-warmup_ratio))))
                elif anneal_type == 'linear':
                    curr_learning_rate = learning_rate * (1 - (global_step_cnt - n_steps * warmup_ratio) / (n_steps * (1-warmup_ratio)))
                elif anneal_type == 'constant':
                    curr_learning_rate = learning_rate
            for param_group in optimizer.param_groups:
                param_group['lr'] = curr_learning_rate 


            if global_step_cnt % 5000 == 0:
               print(f"Loss at step {global_step_cnt}: {loss.item()}")

                                
            optimizer.step()
            optimizer.zero_grad()

            if global_step_cnt % min(log_intvl, len(train_loader)) == 0:
                eval_loss = 0.
                predicted_y_gap = 0
                for bt, batch in tqdm(enumerate(eval_loader)):
                    batch_data, batch_y = batch
                    cuda_batch_data, cuda_batch_y = batch_data.to(DEVICE).type(dtype), batch_y.to(DEVICE).long()
                    with torch.no_grad():
                        eval_logits = all_layers(cuda_batch_data)
                        predicted_y = torch.nn.functional.softmax(eval_logits, dim=-1)
                        predicted_y_gap += (predicted_y[:, 0] -  predicted_y[:, 1]).mean().item()
                        loss = accuracy(predicted_y, cuda_batch_y)
                        eval_loss += loss.item()
                eval_acc = eval_loss / len(eval_loader)
                eval_accs += eval_acc,
                predicted_y_gap /= len(eval_loader)
                
                # check for early stopping
                if len(eval_accs) > 10 and sum(eval_accs[-10:])/10 > 0.9999 and 0:
                    print("Early Stopping at:", global_step_cnt)
                    return
                
            if global_step_cnt % save_intvl == 0:
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)
                print("Save at:", global_step_cnt)
                fout = output_dir + '/checkpoint-'+str(global_step_cnt)
                
                if type(all_layers) == list:
                  model_state_dict = [layer.state_dict() for layer in all_layers]
                else:
                  model_state_dict = all_layers.state_dict()
                try:
                  torch.save(
                      {
                      'epoch': epoch, 
                      'step': global_step_cnt,
                      'model_state_dict': model_state_dict,
                      'optimizer_state_dict': optimizer.state_dict()
                    },
                    fout)
                except:
                    print("Failed to save model")
            global_step_cnt += 1


In [13]:
# get the training arguments
config = {
    'hidden_size': 10_000,
    'model_type': 'mlp',
    'data_dimension': 50,
    'num_labels': 2,
    'learning_rate': 1e-2,
    'weight_decay': 0.05,
    'output_path': 'result',
    'seed': 0,
    'num_layers': 2,
    'feature_complexity': 3,
    'randomize_features': False,
    'feature_coordinates': '',
    'n_examples': 100_000,
    'n_steps': 100_000,
    'n_epochs': 1,
    'batch_size': 1,
    'eval_batch_size': 128,
    'log_intvl': 10_000,
    'save_intvl': 10_000,
    'warmup_ratio': 0.05,
    'anneal_type': 'cosine',
}

train(**config)

Random seed set as 0


IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices