In [3]:
import os
import sys
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import transformers
import datasets

## Training GPT-2 from scratch

In [6]:
class SparseTrainer(transformers.Trainer):
    '''
    Trainer for causal language modeling that adds an L1 penalty over every model parameter to
    push the weights towards sparsity.
    Inputs:
        sparsity_lambda: Coefficient multiplying the summed absolute value of all parameters.
        All other positional/keyword arguments are forwarded unchanged to transformers.Trainer.
    '''
    def __init__(self, *args, sparsity_lambda=0.01, **kwargs):
        super().__init__(*args, **kwargs)
        self.sparsity_lambda = sparsity_lambda

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        '''
        Next-token cross entropy plus sparsity_lambda * (L1 norm of all parameters of `model`).
        Inputs:
            model: Model called as model(**inputs_without_labels); its output must expose `.logits`.
            inputs: Batch mapping that contains a 'labels' entry. It is not modified.
            return_outputs: If True, return (loss, outputs) instead of just the loss.
            num_items_in_batch: Accepted only because newer transformers releases pass it from
                Trainer.training_step; it is ignored and the loss stays a mean over all positions.
        Raises:
            KeyError: if `inputs` has no 'labels' entry.
        '''
        # Split the labels off without mutating the batch dict owned by the Trainer.
        labels = inputs['labels']
        model_inputs = {key: value for key, value in inputs.items() if key != 'labels'}
        outputs = model(**model_inputs)

        # Position t predicts token t + 1: drop the last logit and the first label.
        lm_logits = outputs.logits
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Labels are used exactly as given; no positions are masked out here.
        loss = nn.CrossEntropyLoss()(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        # Add L1 regularization for sparsity
        l1_norm = sum(p.abs().sum() for p in model.parameters())
        loss = loss + self.sparsity_lambda * l1_norm

        return (loss, outputs) if return_outputs else loss

In [19]:
def train_lm(output_dir, num_train_epochs = 5, lr = 3e-5, per_device_train_batch_size = 6, save_steps = 10000, logging_dir = './logs', logging_steps = 500):
    '''
    Train a randomly initialized GPT-2 on wikitext-2 with the SparseTrainer defined in this notebook.
    Inputs:
        output_dir: Directory where checkpoints and the final model are written.
        num_train_epochs: Number of passes over the training split.
        lr: Learning rate.
        per_device_train_batch_size: Batch size per device.
        save_steps: Save a checkpoint every this many optimizer steps.
        logging_dir: Directory for training logs.
        logging_steps: Log every this many optimizer steps.
    '''
    dataset = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1', cache_dir = '/storage/vsub851/.cache')

    #NOTE: If GPT-2 is too big for your GPU in colab, feel free to go smaller to [Distil-GPT2](https://huggingface.co/distilbert/distilgpt2) or something.
    tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2', cache_dir = '/storage/vsub851/.cache')
    # GPT-2 ships without a pad token, so reuse EOS to allow fixed-length padding.
    tokenizer.pad_token = tokenizer.eos_token
    def prepare_data(examples):
        # Causal LM: the labels are a copy of the input ids (the trainer shifts them by one).
        tokenized_inputs = tokenizer(examples['text'], truncation=True, padding='max_length', max_length=128)
        tokenized_inputs['labels'] = tokenized_inputs['input_ids'].copy()
        return tokenized_inputs
    dataset = dataset.map(prepare_data, batched=True)

    # Only the architecture config is loaded; from_config gives randomly initialized weights.
    gpt2_config = transformers.GPT2Config.from_pretrained('gpt2', cache_dir = '/storage/vsub851/.cache')
    model = transformers.AutoModelForCausalLM.from_config(gpt2_config)

    # Previously `training_args` was never defined (NameError) and the hyperparameter
    # arguments of this function were silently ignored.
    training_args = transformers.TrainingArguments(
        output_dir = output_dir,
        overwrite_output_dir = True,
        num_train_epochs = num_train_epochs,
        per_device_train_batch_size = per_device_train_batch_size,
        learning_rate = lr,
        save_steps = save_steps,
        logging_dir = logging_dir,
        logging_steps = logging_steps
    )
    trainer = SparseTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset['train'],
        eval_dataset=dataset['validation'],
        tokenizer=tokenizer,
    )
    trainer.train()
    # Write the final weights to output_dir; otherwise a run shorter than save_steps saves nothing.
    trainer.save_model()

In [20]:
## TODO: run language modeling training here, e.g. by calling train_lm(output_dir, ...).
output_dir = None ## IMPORTANT: this argument tells the trainer where to save the resulting model. Keep track of this path (e.g. in Colab): lm_multimeter later loads the model from its pretrained_dir argument.
...

Ellipsis

## Your part: Get a linguistic feature

### Option 1: Get a pretrained model off the shelf 
Example: Hugging Face hosts several pretrained sentiment-analysis models; see [here](https://huggingface.co/blog/sentiment-analysis-python). 

Feel free to pick anything here. You can even ask how this would work with a visual feature (e.g., do language models understand color?). We can get pretty creative as long as we have a pretrained model.

### Option 2: Train your own. 
Find a dataset and train your own. If you want to do that, I'll give you some example steps for dependencies.

Check out the [syntax_parser](https://github.com/vsubramaniam851/multimeter-interp/tree/main/language_modeling/syntax_parser) to see the components of how to do this.

First look at `ud_data.py` to see how we have to make a word level dataset for language.

I make an example parser in `parser.py`.

I train the parser for POS tagging in `pos_train.py`.

## Multimeter Training

In [14]:
sys.path.append('../')

In [16]:
from hooks import *
from multimeter import CombinedNetwork

In [30]:
class ModificationHook:
    '''
    Attaches a forward hook to the leaf module at layer_index and replaces that module's output with a
    weighted average of the output and an externally supplied tensor (frozen_output).
    Inputs:
        layer_index: Index into the flattened list of leaf modules (depth-first order) of the model passed to attach().
        param: nn.Parameter whose sigmoid is the mixing weight given to frozen_output.
    Attributes:
        frozen_output: Tensor to mix in; must be set by the caller before the hooked module runs.
        handle: Handle of the currently registered hook, or None when detached.
        leaf_modules: Leaf modules found by the most recent attach() call.
    '''
    def __init__(self, layer_index, param):
        self.layer_index = layer_index
        self.frozen_output = None
        self.handle = None
        self.leaf_modules = []
        self.param = param

    def hook_fn(self, module, input, output):
        #Weighted sum based on the sigmoid of the passed in parameter.
        gate = torch.sigmoid(self.param)
        return (1 - gate) * output + gate * self.frozen_output

    def attach(self, model):
        '''
        Register the hook on the leaf module of `model` at self.layer_index.
        Safe to call repeatedly: the leaf list is rebuilt and any previous hook is removed first.
        Raises:
            IndexError: if layer_index is not smaller than the number of leaf modules.
        '''
        # Drop a hook left over from an earlier attach() so the module is never hooked twice.
        self.remove()
        # Rebuild from scratch; appending to the old list made it grow on every call.
        self.leaf_modules = []

        #Iterate over modules of pytorch. Some modules like nn.Sequential, nn.ModuleList have children so iterate and append those recursively.
        def get_leaf_modules(module):
            children = list(module.children())
            if not children:  # if leaf node
                self.leaf_modules.append(module)
            for child in children:
                get_leaf_modules(child)

        get_leaf_modules(model)
        #Register forward hook for the module that corresponds to the layer_index.
        if self.layer_index < len(self.leaf_modules):
            self.handle = self.leaf_modules[self.layer_index].register_forward_hook(self.hook_fn)
        else:
            raise IndexError(f'Layer index {self.layer_index} is out of range. Max index is {len(self.leaf_modules) - 1}')

    def remove(self):
        '''Detach the hook if one is registered; a no-op otherwise.'''
        if self.handle is not None:
            self.handle.remove()
            # Forget the stale handle so a later remove()/attach() does not reuse it.
            self.handle = None

In [31]:
class CombinedNetwork(nn.Module):
    '''
    Multimeter network. Takes an original network and trains with input from a frozen network. 
    Inputs:
        original_network: Optimized network
        frozen_network: Multimeter network that original network can steal weights from
        representation_layer_index: start layer index to connect multimeter to.
        modification_layer_index: end layer index to combine frozen output and original network output.
    Attributes:
        parameter: Scalar nn.Parameter (random normal init) handed to the ModificationHook as mixing weight.
    '''
    def __init__(self, original_network, frozen_network, representation_layer_index, modification_layer_index):
        super().__init__()
        self.original_network = original_network
        self.frozen_network = frozen_network
        self.parameter = nn.Parameter(torch.randn(()))
        self.representation_hook = RepresentationHook(representation_layer_index)
        self.modification_hook = ModificationHook(modification_layer_index, self.parameter)

    def forward(self, x):
        '''
        Run the original network twice: once without gradients to capture the representation at the
        start layer, then again with the modification hook mixing the frozen network's output in.
        Returns the output of the second pass. Both hooks are removed even if a step raises.
        '''
        #Attach representation hook to get output of start module
        self.representation_hook.attach(self.original_network)
        try:
            #Run original network all the way through; only the captured representation is needed.
            with torch.no_grad():
                self.original_network(x)
            #Get frozen output from multimeter network
            frozen_output = self.frozen_network(self.representation_hook.representation)
            #Set modification hook frozen output
            self.modification_hook.frozen_output = frozen_output
            #Attach modification hook to the network
            self.modification_hook.attach(self.original_network)
            #Run through original network again with hook to recombine the entire frozen output
            return self.original_network(x)
        finally:
            #Always remove hooks so a failed forward pass cannot leave them on the network.
            self.representation_hook.remove()
            self.modification_hook.remove()

In [27]:
class HFCombinedNetwork(CombinedNetwork):
    '''
    HuggingFace based multimeter network. Overloads the original multimeter network.
    '''
    def forward(self, input_ids, attention_mask = None, labels = None, **kwargs):
        '''
        Same two-pass scheme as CombinedNetwork.forward, with a HuggingFace-style signature.
        Inputs:
            input_ids: Token ids passed to the original network.
            attention_mask: Optional padding mask, forwarded to both passes of the original network.
            labels, **kwargs: Accepted so the Trainer can pass a full batch; not used here
                (the trainer in this notebook computes the loss from the returned logits).
        Returns the original network's output from the second (hooked) pass.
        '''
        #Attach representation hook to get output of start module
        self.representation_hook.attach(self.original_network)
        try:
            #Run original network all the way through; only the captured representation is needed.
            with torch.no_grad():
                self.original_network(input_ids, attention_mask = attention_mask)
            #Get frozen output from multimeter network
            frozen_output = self.frozen_network(self.representation_hook.representation)
            #Set modification hook frozen output
            self.modification_hook.frozen_output = frozen_output
            #Attach modification hook to the network
            self.modification_hook.attach(self.original_network)
            #Run through original network again with hook to recombine the entire frozen output
            return self.original_network(input_ids, attention_mask = attention_mask)
        finally:
            #Always remove hooks so a failed forward pass cannot leave them on the network.
            self.representation_hook.remove()
            self.modification_hook.remove()

In [22]:
## We need a sparse trainer for fancy saving. I'm too lazy to make the code clean for this...
class SparseTrainer(transformers.Trainer):
    '''
    Trainer for the multimeter setup: next-token cross entropy plus an L1 penalty on the original
    network's parameters, with custom saving of the combined network's parts.
    Because this definition shadows any earlier SparseTrainer in the notebook, it also works with
    a plain model that has no `original_network` attribute.
    Inputs:
        sparsity_lambda: Coefficient multiplying the summed absolute value of the penalized parameters.
        All other positional/keyword arguments are forwarded unchanged to transformers.Trainer.
    '''
    def __init__(self, *args, sparsity_lambda=0.01, **kwargs):
        super().__init__(*args, **kwargs)
        self.sparsity_lambda = sparsity_lambda

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        '''
        Inputs:
            model: Model called as model(**inputs_without_labels); its output must expose `.logits`.
            inputs: Batch mapping that contains a 'labels' entry. It is not modified.
            return_outputs: If True, return (loss, outputs) instead of just the loss.
            num_items_in_batch: Accepted only because newer transformers releases pass it from
                Trainer.training_step; it is ignored and the loss stays a mean over all positions.
        Raises:
            KeyError: if `inputs` has no 'labels' entry.
        '''
        # Split the labels off without mutating the batch dict owned by the Trainer.
        labels = inputs['labels']
        model_inputs = {key: value for key, value in inputs.items() if key != 'labels'}
        outputs = model(**model_inputs)

        # Position t predicts token t + 1: drop the last logit and the first label.
        lm_logits = outputs.logits
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss = nn.CrossEntropyLoss()(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        # Add L1 regularization for sparsity. Look through a DataParallel/DDP style wrapper
        # (`.module`) and penalize only the original network when the model has one.
        core_model = getattr(model, 'module', model)
        penalized = getattr(core_model, 'original_network', core_model)
        l1_norm = sum(p.abs().sum() for p in penalized.parameters())
        loss = loss + self.sparsity_lambda * l1_norm

        return (loss, outputs) if return_outputs else loss

    def _save(self, output_dir = None, state_dict = None):
        '''
        Save the combined network as separate pieces under output_dir:
            original_network/   (save_pretrained of the original network)
            frozen_network      (whole module via torch.save)
            parameter.pt        (mixing parameter)
            training_args.bin, optimizer.pt (when optimizer and scheduler exist)
        Models without `original_network` fall back to the stock Trainer._save.
        '''
        if output_dir is None:
            output_dir = self.args.output_dir

        # self.model is the unwrapped model; self.model_wrapped may be a parallel wrapper
        # that does not expose original_network / frozen_network / parameter.
        model = self.model
        if not hasattr(model, 'original_network'):
            super()._save(output_dir, state_dict)
            return

        os.makedirs(output_dir, exist_ok = True)
        model.original_network.save_pretrained(os.path.join(output_dir, 'original_network'))

        # NOTE: this pickles the whole module, so loading it back requires the same class definitions.
        torch.save(model.frozen_network, os.path.join(output_dir, 'frozen_network'))
        torch.save(model.parameter, os.path.join(output_dir, 'parameter.pt'))

        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
        if self.optimizer is not None and self.lr_scheduler is not None:
            torch.save(
                {
                    'optimizer': self.optimizer.state_dict(),
                    'lr_scheduler': self.lr_scheduler.state_dict(),
                },
                os.path.join(output_dir, 'optimizer.pt'),
            )

        self.save_state()

Below, I show you what an example Frozen/Multimeter network looks like. 
TODO: You need to replace this with your own. Most of the structure is the same. I will show you where to add your part.

In [23]:
class SyntaxFrozenNetwork(nn.Module):
    '''
    Wraps a frozen model between two trainable linear layers.
    Inputs:
        in_features: Last-dimension size of the incoming representation.
        model_in_features: Embedding size expected by `model` (output size of the first linear layer).
        model_out_features: Size of `model`'s first output once flattened over everything but the batch dimension.
        out_features: Output size of the second linear layer; must be a multiple of hidden_size.
        model: Frozen model, called as model(inputs = None, attention_mask = None, inputs_embeds = x)
            and expected to return a 3-tuple whose first element is used.
        hidden_size: Last-dimension size of the returned tensor. Defaults to 768.
    '''
    def __init__(self, in_features, model_in_features, model_out_features, out_features, model, hidden_size = 768):
        #CHANGE MODEL TO YOUR MODEL AND SET YOUR PARAMETERS
        super(SyntaxFrozenNetwork, self).__init__()
        self.linear1 = nn.Linear(in_features, model_in_features)
        self.linear2 = nn.Linear(model_out_features, out_features)
        self.relu = nn.ReLU()
        self.hidden_size = hidden_size

        # Freeze the wrapped model; only linear1 and linear2 stay trainable.
        self.model = model
        for param in self.model.parameters():
            param.requires_grad = False
        
    def forward(self, x):
        '''Returns a tensor of shape (batch, out_features // hidden_size, hidden_size).'''
        x = self.linear1(x)
        x, _, _ = self.model(inputs = None, attention_mask = None, inputs_embeds = x)
        #NOTE: A quick note to self; we choose to send logits instead of representations for the reason that 
        # representations may have added information/capacity when using BERT/GPT-2. It's super important to NOT
        # let this added capacity ablate more weights of the network. So we only use the logits instead.

        # Another quick note to self: Using BERT/GPT-2 isn't a foolproof idea so I need to think about this more.
        # I guess this was harder than I thought...
        # Flatten everything but the batch dimension; reshape also handles non-contiguous outputs.
        x = x.reshape(x.shape[0], -1)
        x = self.relu(self.linear2(x))
        # The width is configurable via hidden_size instead of a hard-coded 768.
        x = x.view(x.shape[0], -1, self.hidden_size)
        return x

In [25]:
def lm_multimeter(output_dir, pretrained_dir, mul_network, frozen_dir, index1, index2, use_pretrained = True, num_train_epochs = 5, lr = 3e-5, per_device_train_batch_size = 4, save_steps = 10000, logging_dir = './logs', logging_steps = 500, max_length = 128, cache_dir = '/storage/vsub851/.cache'):
    '''
    Train a GPT-2 model combined with a frozen multimeter network on wikitext-2.
    Inputs:
        output_dir: Directory where checkpoints and the final combined network are written.
        pretrained_dir: Directory holding the GPT-2 tokenizer and, when use_pretrained is True, the model.
        mul_network: Which frozen network to build: 'dep' or 'pos'.
        frozen_dir: Path to the state-dict checkpoint of the frozen model.
        index1: representation_layer_index, where the multimeter starts.
        index2: modification_layer_index, where the multimeter ends. Must be larger than index1.
        use_pretrained: If False, start from a randomly initialized GPT-2 (sanity check).
        num_train_epochs, lr, per_device_train_batch_size, save_steps, logging_dir, logging_steps:
            Forwarded to transformers.TrainingArguments.
        max_length: Fixed token length used for padding/truncation and for sizing the frozen network.
        cache_dir: Cache directory for the downloaded datasets/models.
    Raises:
        ValueError: if index1 is not smaller than index2.
        NotImplementedError: if mul_network is neither 'dep' nor 'pos'.
    '''
    # Fail fast on bad arguments, before any model is downloaded or loaded.
    if index1 >= index2:
        raise ValueError(f'index1 ({index1}) must be smaller than index2 ({index2})')
    if mul_network not in ('dep', 'pos'):
        raise NotImplementedError(f"Unknown mul_network '{mul_network}'; expected 'dep' or 'pos'")

    # Set up original network
    if use_pretrained:
        gpt2_model = transformers.AutoModelForCausalLM.from_pretrained(pretrained_dir)
    else:
        # Sanity Check: train from a randomly initialized network instead
        gpt2_config = transformers.AutoConfig.from_pretrained('gpt2', cache_dir = cache_dir)
        gpt2_model = transformers.AutoModelForCausalLM.from_config(gpt2_config)

    # Set up frozen network. For Aliya: THIS WILL CHANGE.
    bert_model = transformers.AutoModel.from_pretrained('bert-base-uncased', cache_dir = cache_dir)
    # Load the checkpoint once, onto CPU so checkpoints saved on a GPU also load on CPU-only machines.
    ckpt = torch.load(frozen_dir, map_location = 'cpu')
    if mul_network == 'dep':
        dep_model = PairwiseMLP(bert_model, in_features = 768, hidden_state = 5)
        dep_model.load_state_dict(ckpt)
        frozen_network = SyntaxFrozenNetwork(in_features = 768, model_in_features = 768, model_out_features = max_length * max_length, out_features = max_length * 768, model = dep_model)
    else:
        # The number of POS tags is read from the output layer stored in the checkpoint.
        num_pos = ckpt['linear2.weight'].shape[0]
        pos_model = MLP(bert_model, in_features = 768, hidden_state = 5, out_features = num_pos)
        pos_model.load_state_dict(ckpt)
        frozen_network = SyntaxFrozenNetwork(in_features = 768, model_in_features = 768, model_out_features = max_length * num_pos, out_features = max_length * 768, model = pos_model)

    # Set up combined network
    combined_network = HFCombinedNetwork(gpt2_model, frozen_network, representation_layer_index = index1, modification_layer_index = index2)

    # Set up dataset from wikitext again
    dataset = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1', cache_dir = cache_dir)
    tokenizer = transformers.GPT2Tokenizer.from_pretrained(pretrained_dir)
    # GPT-2 ships without a pad token, so reuse EOS to allow fixed-length padding.
    tokenizer.pad_token = tokenizer.eos_token
    # Function to add labels for language modeling
    def prepare_data(examples):
        # Every example is padded/truncated to max_length because the frozen network's linear layers are sized for it.
        tokenized_inputs = tokenizer(examples['text'], truncation = True, padding = 'max_length', max_length = max_length)
        tokenized_inputs['labels'] = tokenized_inputs['input_ids'].copy()
        return tokenized_inputs

    dataset = dataset.map(prepare_data, batched=True)

    training_args = transformers.TrainingArguments(
        output_dir = output_dir,
        overwrite_output_dir = True,
        num_train_epochs = num_train_epochs,
        per_device_train_batch_size = per_device_train_batch_size,
        learning_rate = lr,
        save_steps = save_steps,
        logging_dir = logging_dir,
        logging_steps = logging_steps
    )

    trainer = SparseTrainer(
        model=combined_network,
        args=training_args,
        train_dataset=dataset['train'],
        eval_dataset=dataset['validation'],
        tokenizer=tokenizer,
    )

    print('Beginning training...')
    trainer.train()
    # Write the final weights to output_dir; otherwise a run shorter than save_steps saves nothing.
    trainer.save_model()

Run everything here...

In [32]:
#lm_multimeter(...)
## IMPORTANT ARGUMENT: representation_layer_index = where the multimeter starts, modification_layer_index = where the multimeter ends. These must be different values and representation_layer_index < modification_layer_index

## Analysis Code

If you get everything to run, we can talk about this :p