In [1]:
import torch
import numpy as np
import ecco
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [2]:
def break_attn_heads_by_layer(zero_type, model, share, layer):
    """
    Impair a share of the columns belonging to every attention head in one
    transformer block and return the (in-place) modified model.

    Only ``model.transformer.h[layer].attn.c_attn.weight`` is touched; the
    matching bias entries are left unchanged. For each of the 12 head slices
    (64 columns each, starting at column 1536) ``int(64 * share / 100)``
    columns are affected across all rows of the weight matrix.

    NOTE(review): column 1536 = 2 * 768 matches the start of the value part of
    the fused query/key/value projection in Hugging Face GPT-2 small; the
    offsets are hard-coded, so confirm before using another model size.

    :param zero_type: how the selected columns are impaired:
                      'random'  - zero randomly chosen columns of each head
                                  (NumPy and torch RNGs are re-seeded with 42
                                  on every call, so the choice is repeatable),
                      'first'   - zero the leading columns of each head,
                      'shuffle' - permute the leading columns of each head,
                                  independently for every row (uses the global
                                  NumPy RNG, not seeded here)
    :type zero_type: str
    :param model: the original GPT-2 model to be modified
    :type model: transformers.modeling_gpt2.GPT2LMHeadModel
    :param share: % of each head's columns to be affected, 0 to 100
                  (25, 50 and 100 are the intended values)
    :type share: int
    :param layer: the specific layer to be modified,
                  ranging from 0 to 11
    :type layer: int
    :return: the modified model (same object that was passed in)
    :rtype: transformers.modeling_gpt2.GPT2LMHeadModel
    :raises ValueError: if ``zero_type`` is not supported or ``share`` is
                        outside 0..100
    """
    head_width = 64
    first_head_col = 1536
    n_heads = 12
    head_offsets = range(first_head_col,
                         first_head_col + n_heads * head_width,
                         head_width)

    if zero_type not in ('random', 'first', 'shuffle'):
        raise ValueError("zeroing type is not supported!")
    if not 0 <= share <= 100:
        # A larger share would spill past the head's own 64 columns.
        raise ValueError("share must be between 0 and 100!")

    # Number of columns affected inside each head slice.
    n_cols = int(head_width * (share / 100))
    weight = model.transformer.h[layer].attn.c_attn.weight

    with torch.no_grad():
        if zero_type == 'random':
            np.random.seed(42)
            torch.manual_seed(42)
            for head in head_offsets:
                # Unique random column indices inside this head's slice.
                cols = np.random.choice(range(head, head + head_width),
                                        n_cols, replace=False)
                cols = torch.as_tensor(cols, dtype=torch.long,
                                       device=weight.device)
                # Zero the chosen columns for every row at once.
                weight[:, cols] = 0.0
        elif zero_type == 'first':
            for head in head_offsets:
                weight[:, head:head + n_cols] = 0.0
        else:  # 'shuffle'
            # np.random.shuffle must not be applied to a tensor view: its
            # element swap operates on views and duplicates values instead of
            # permuting them. Draw an index permutation and assign a permuted
            # copy instead.
            for head in head_offsets:
                for row in range(weight.size()[0]):
                    perm = torch.as_tensor(np.random.permutation(n_cols),
                                           dtype=torch.long,
                                           device=weight.device)
                    segment = weight[row, head:head + n_cols]
                    weight[row, head:head + n_cols] = segment[perm]
    return model

In [4]:
# Load two independent copies of the pretrained "gpt2" checkpoint.
# `model_dem` is impaired in the loop below; `model_con` is not modified in
# this cell (presumably the unimpaired control model - confirm).
# NOTE(review): a later cell rebinds `model_con` to the object returned by
# ecco.from_pretrained('gpt2', ...), so this load may be redundant.
model_dem = GPT2LMHeadModel.from_pretrained("gpt2")
model_con = GPT2LMHeadModel.from_pretrained("gpt2")
# NOTE(review): confirm that GPT2Tokenizer actually honours `do_lower_case`.
gpt_tokenizer = GPT2Tokenizer.from_pretrained("gpt2", do_lower_case=True)
# Indices of the transformer blocks whose attention weights are impaired.
layers = [0, 1, 2, 3, 4, 8, 10]
# Zero the first 50% of the columns in every head slice of each listed block.
# break_attn_heads_by_layer edits the weights in place and returns the model
# it was given, so `model_dem` accumulates the changes of all iterations.
for layer in layers:
    model_dem = break_attn_heads_by_layer("first", model_dem, 50, layer)

In [9]:
# Prompt that the model is asked to continue.
sent = "There are two children and their mother in the kitchen."
# Rebinds `model_con` (a GPT2LMHeadModel in the earlier cell) to an ecco
# language-model wrapper loaded from the pretrained 'gpt2' checkpoint.
# NOTE(review): `activations=True` is presumably what lets run_nmf work in a
# later cell - confirm. The impaired `model_dem` is not used here.
model_con = ecco.from_pretrained('gpt2', activations=True)
# Generate a continuation with top_p=0.9 and temperature=1.
# NOTE(review): no seed is set in this cell, so the generated text may differ
# between runs - confirm whether reproducibility matters here.
con_output = model_con.generate(sent, top_p=0.9, temperature=1)
# Render ecco's saliency view of the generation in the "detailed" style.
con_output.saliency(style="detailed")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [13]:
# Run ecco's NMF on the generation output from the previous cell, asking for
# 12 components, then open the interactive explorer for the result.
# Depends on `con_output` existing, i.e. the generation cell must run first.
nmf1 = con_output.run_nmf(n_components=12)
nmf1.explore()

<IPython.core.display.Javascript object>