In [9]:
from transformers import BertModel, BertConfig
from copy import deepcopy
import torch
from typing import List, Dict
from transformers import BertTokenizer
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import os
import torch

In [14]:
class ClassificationDataBase:
    """ Data base class API: loads, preprocesses, and tokenizes datasets to 
        be ready for the BERT model.
        Required attributes:
         - num_classes: number of classes in a dataset
         - datasets: dictionary with 'train' and 'val' tokenized datasets;
    """
    def __init__(self, **kwargs):
        self.num_classes: int
        self.datasets: Dict[str, Dataset] = {}  
        
        
class IMDb(ClassificationDataBase):
    def __init__(self, device, backbone_name, **kwargs):
        super().__init__()
        loading_kwargs = {'path': 'imdb', 
                          'train_filename': 'train.tsv',
                          'dev_filename': 'dev.tsv',
                          'header': 0,
                          'index_col': 0}
        self.device = device
        self._prepare_datasets(loading_kwargs, device, backbone_name)
    
    def _prepare_datasets(self, loading_kwargs, device, backbone_name):
        dataframes = load_dataframes(loading_kwargs)
        self.num_classes = len(dataframes['train']['label'].unique())
        data_X, data_y = {}, {}
        for split in dataframes.keys():
            data_X[split] = dataframes[split].drop('label', axis=1, inplace=False)
            data_X[split] = data_X[split].values.tolist()
            data_y[split] = dataframes[split]['label'].values
            data_y[split] = torch.LongTensor(data_y[split]).to(device)
            data_X[split] = get_tokens(data_X[split], backbone_name=backbone_name)
            data_X[split] = {k: v.to(device) for k,v in data_X[split].items()}
            self.datasets[split] = TokenizedDataset(data_X[split], data_y[split])
        

class TokenizedDataset(Dataset):
    def __init__(self, X, y):
        super().__init__()
        self.X = X
        self.y = y
        
    def __len__(self):
        return len(self.y)
            
    def __getitem__(self,idx):
        return {key: self.X[key][idx] for key in self.X.keys()}, self.y[idx]


def get_all_subclasses(cls):
    all_subclasses = []
    for subclass in cls.__subclasses__():
        all_subclasses.append(subclass)
        all_subclasses.extend(get_all_subclasses(subclass))
    return all_subclasses


def set_seeds(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


def load_dataframes(loading_kwargs):
    path = loading_kwargs['path']
    train_filename = loading_kwargs['train_filename']
    dev_filename = loading_kwargs['dev_filename']
    header = loading_kwargs['header']
    index_col = loading_kwargs['index_col']
    assert '.' in train_filename, f"unrecognized file format {train_filename}"
    extension = train_filename.split('.')[-1]
    if extension == 'tsv':
        delimiter = '\t'
    elif extension == 'csv':
        delimiter = ','
    else:
        raise ValueError(f"unrecognized file format {extension}")
    dataframes = {}
    for split,split_filename in zip(['train','dev'], [train_filename, dev_filename]):
        filename = os.path.join(path, split_filename)
        dataframes[split] = pd.read_csv(
                filename,
                delimiter=delimiter,
                header=header,
                index_col=index_col,
                engine="python",
                error_bad_lines=False,
                warn_bad_lines=False)
        new_columns = [f"sentence_{i}" for i in range(len(dataframes[split].columns)-1)]+["label"]
        dataframes[split].columns = new_columns
    return dataframes


def get_tokens(data_X: List[List[str]], backbone_name: str) -> dict:
    tokenizer =  BertTokenizer.from_pretrained(backbone_name)
    if len(data_X[0])==1:
        data_X = [X[0] for X in data_X]
    data_X = tokenizer.batch_encode_plus(
                data_X,
                add_special_tokens=True,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt")
    return data_X
    

def stabilized_forward(linearized_encoder, min_, max_, scaler=1e3, max_attempts=10):
    def forward(x, attention_mask=None):
        for i in range(len(linearized_encoder.layer)):
            attempts = 0
            prelim_x = linearized_encoder.layer[i](x, attention_mask=attention_mask)[0]
            while torch.mean(torch.abs(prelim_x))>max_ and attempts<max_attempts:
                for m in linearized_encoder.layer[i].modules():
                    if isinstance(m, torch.nn.Linear):
                        m.weight.data = m.weight.data/scaler
                query = linearized_encoder.layer[i].attention.self.query.weight.data
                linearized_encoder.layer[i].attention.self.query.weight.data = scaler*query
                key = linearized_encoder.layer[i].attention.self.key.weight.data
                linearized_encoder.layer[i].attention.self.key.weight.data = scaler*key
                prelim_x = linearized_encoder.layer[i](x, attention_mask=attention_mask)[0]
                attempts+=1
            while torch.mean(torch.abs(prelim_x))<min_ and attempts<max_attempts:
                for m in linearized_encoder.layer[i].modules():
                    if isinstance(m, torch.nn.Linear):
                        m.weight.data = scaler*m.weight.data
                query = linearized_encoder.layer[i].attention.self.query.weight.data
                linearized_encoder.layer[i].attention.self.query.weight.data = query/scaler
                key = linearized_encoder.layer[i].attention.self.key.weight.data
                linearized_encoder.layer[i].attention.self.key.weight.data = key/scaler
                prelim_x = linearized_encoder.layer[i](x, attention_mask=attention_mask)[0]
                attempts+=1
            assert attempts<max_attempts
            x = prelim_x
        return x
    return forward


def max_pooling(token_embeddings, attention_mask):
    input_mask_expanded=(attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float())
    token_embeddings[input_mask_expanded == 0] = -1e9
    return torch.max(token_embeddings, -2).values


def mean_pooling(token_embeddings, attention_mask):
    input_mask_expanded=(attention_mask.unsqueeze(-1).float())
    sum_embeddings=torch.sum(token_embeddings * input_mask_expanded, -2)
    sum_mask=torch.clamp(input_mask_expanded.sum(-2), min=1e-9)
    return sum_embeddings / sum_mask


def pool(x, encoded_input, pool_type):
    if pool_type in ['avg', 'mean']:
        x = mean_pooling(x['last_hidden_state'], encoded_input['attention_mask'])
    elif pool_type == 'max':
        x = max_pooling(x['last_hidden_state'], encoded_input['attention_mask'])
    elif pool_type in ['cls', 'first']:
        x = x['last_hidden_state'][:,0]
    elif pool_type == 'pooler_output':
        x = x['pooler_output']
    else:
        raise ValueError(f'unknown pool type {pool_type}.')
    return x


def effective_masks_dense(masks):
    for i,mask in enumerate(masks):
        assert len(mask.size())==2, f"found mask of shape {mask.size()}"
        masks[i] = mask.T
    units=[mask.shape[-2] for mask in masks]+[masks[-1].shape[-1]]
    next_layer=torch.ones((units[-1],))
    way_out=[next_layer]
    for mask in masks[::-1]:
        curr_mask=torch.matmul(mask,next_layer.view(len(next_layer),1))
        next_layer=torch.sum(curr_mask,dim=1)>0
        way_out.append(next_layer)
    way_out=way_out[::-1]
    prev_layer=torch.ones((units[0],))
    way_in=[prev_layer]
    for mask in masks:
        curr_mask=torch.matmul(prev_layer.view(1,len(prev_layer)),mask)
        prev_layer=torch.sum(curr_mask,dim=0)>0
        way_in.append(prev_layer)
    activity=[w_in*w_out for w_in,w_out in zip(way_in,way_out)]
    effective_masks = []
    for i,mask in enumerate(masks):
        activity_prev = activity[i].view(len(activity[i]),1)
        activity_next = activity[i+1].view(1,len(activity[i+1]))
        effective_mask = mask*torch.matmul(activity_prev, activity_next)
        effective_masks.append(effective_mask.T)
    return effective_masks


In [15]:
class ClassificationModelBase(torch.nn.Module):
    """ The base model API
        Required attributes:
         - masks: dictionary storing lists of binary masks for
           each stage ('embeddings', 'encoder', 'classifier')
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.masks: Dict[str, List[torch.Tensor]]

    def forward(self, X: Dict) -> torch.Tensor:
        """ Forward pass of the model
            Args:
             - X: dictionary of batched tokenized documents
            Returns: class logits
        """
        raise NotImplementedError(f"override 'forward' method")

    def effective_masks(self, **kwargs) -> Dict[str, List[torch.Tensor]]:
        """ Return effective masks in a dictionary by module:
            'embeddings', 'encoder', 'classifier'
        """
        raise NotImplementedError(f"override 'linearize' method")

In [50]:
class LinearizedEmbeddings(torch.nn.Module):
    def __init__(self, bert_embeddings):
        super().__init__()
        self.shapes = {1: bert_embeddings.word_embeddings.weight.data.size(),
                2: bert_embeddings.position_embeddings.weight.data.size(),
                3: bert_embeddings.token_type_embeddings.weight.data.size()}
        self.total_length = sum(shape[0] for shape in self.shapes.values())
        bert_word_embeddings = bert_embeddings.word_embeddings.weight.data.detach().clone().t()
        self.word_embeddings = torch.nn.Linear(self.shapes[1][0],
                self.shapes[1][1], bias=False)
        self.word_embeddings.weight.data = bert_word_embeddings.requires_grad_(True)
        bert_position_embeddings = bert_embeddings.position_embeddings.weight.data.detach().clone().t()
        self.position_embeddings = torch.nn.Linear(self.shapes[2][0],
                self.shapes[2][1], bias=False)
        self.position_embeddings.weight.data = bert_position_embeddings.requires_grad_(True)
        bert_token_type_embeddings = bert_embeddings.token_type_embeddings.weight.data.detach().clone().t()
        self.token_type_embeddings = torch.nn.Linear(self.shapes[3][0],
                self.shapes[3][1], bias=False)
        self.token_type_embeddings.weight.data = bert_token_type_embeddings.requires_grad_(True)
        
    def forward(self, X):
        X = torch.squeeze(X)
        X1 = X[:, sum(self.shapes[i][0] for i in range(1,1)):sum(self.shapes[i][0] for i in range(1,2))]
        X2 = X[:, sum(self.shapes[i][0] for i in range(1,2)):sum(self.shapes[i][0] for i in range(1,3))]
        X3 = X[:, sum(self.shapes[i][0] for i in range(1,3)):sum(self.shapes[i][0] for i in range(1,4))]
        word_out = self.word_embeddings(X1)
        position_out = self.position_embeddings(X2)
        token_type_out = self.token_type_embeddings(X3)
        return word_out + position_out + token_type_out


class LinearizedBERTClassifier(ClassificationModelBase):
    def __init__(self, reference_model, stabilize=False):
        super().__init__()
        self.device = reference_model.device
        self.stabilize = stabilize
        configuration = BertConfig(hidden_act=lambda x: x,
                hidden_dropout_prob=0,
                attention_probs_dropout_prob=0)
        linearized_encoder = BertModel(configuration).encoder
        if stabilize:
            linearized_encoder.forward = stabilized_forward(
                linearized_encoder,
                min_=1e-2,
                max_=1e5)
        identity = torch.nn.Identity()
        weights = []
        for m in reference_model.modules_dict.encoder.modules():
            if isinstance(m, torch.nn.Linear):
                weights.append(m.weight.data.detach().clone())
        idx = 0
        for m in linearized_encoder.modules():
            if isinstance(m, torch.nn.Linear):
                m.weight.data = weights[idx].requires_grad_(True)
                idx+=1
            if isinstance(m, torch.nn.LayerNorm):
                m.forward = identity.forward
        linearized_embeddings = LinearizedEmbeddings(reference_model.modules_dict.embeddings)
        self.total_length = linearized_embeddings.total_length
        self.modules_dict = torch.nn.ModuleDict({
                'embeddings': linearized_embeddings.to(self.device),
                'encoder': linearized_encoder.to(self.device),
                #'classifier': deepcopy(reference_model.modules_dict.classifier).to(self.device)
        })
        for m in self.modules():
            if hasattr(m, 'weight'):
                m.weight.data = torch.abs(m.weight.data)
                m.weight.grad = None
        
    def forward(self, inp):
        attention_mask = torch.ones(512)[None, None, None, :]
        x = self.modules_dict['embeddings'](inp)[None, :, :]
        x = self.modules_dict['encoder'](x, attention_mask=attention_mask)
        mask = {'attention_mask': torch.squeeze(attention_mask)}
        if self.stabilize:
            x = {'last_hidden_state': x}
        x = pool(x, mask, 'mean')
        #x = self.modules_dict['classifier'](x)
        return x

    def get_effective_masks(self):
        effective_masks = {}
        full_length = self.total_length
        X = torch.ones((512,full_length))
        output = torch.sum(self(X))
        output.backward()
        for stage in self.modules_dict.keys():
            effective_masks[stage] = []
            for m in self.modules_dict[stage].modules():
                if isinstance(m, torch.nn.Linear) or isinstance(m, torch.nn.Embedding):
                    print(m.weight.grad.size(), m.weight.grad.reshape(-1)[:5], torch.min(m.weight.grad), torch.max(m.weight.grad))
                    effective_masks[stage].append((torch.abs(m.weight.grad)>0).int())
                    m.weight.grad = None
        return effective_masks


class BERTClassifier(ClassificationModelBase):
    def __init__(self, backbone_name,
                    pool_type,
                    num_classes,
                    device,
                    **kwargs):
        super().__init__()
        self.pool_type = pool_type
        bert = BertModel.from_pretrained(backbone_name)
        classifier = torch.nn.Linear(768, num_classes, bias=True)
        self.modules_dict = torch.nn.ModuleDict({
                'embeddings': bert.embeddings.to(device),
                'encoder': bert.encoder.to(device),
                #'classifier': classifier.to(device)
        })
        self.prunable_modules = {torch.nn.Embedding, torch.nn.Linear}
        self.create_masks()
        self.device = device
        self.effective_masks
        
    def create_masks(self):
        self.masks = {}
        for stage in self.modules_dict.keys():
            self.masks[stage] = []
            for m in self.modules_dict[stage].modules():
                if type(m) in self.prunable_modules:
                    self.masks[stage].append(torch.ones(m.weight.data.size()))

    def apply_masks(self):
        for stage in self.masks.keys():
            current_idx = 0
            for m in self.modules_dict[stage].modules():
                if type(m) in self.prunable_modules:
                    m.weight.data = torch.mul(m.weight.data, self.masks[stage][current_idx])
                    current_idx+=1
    
    def forward(self, inp):
        self.apply_masks()
        attention_mask = inp['attention_mask'][:, None, None, :]
        x = self.modules_dict['embeddings'](inp['input_ids'])
        x = self.modules_dict['encoder'](x, attention_mask=attention_mask)
        x = pool(x, inp, self.pool_type)
        #x = self.modules_dict['classifier'](x)
        return x

    def linearize(self):
        self.apply_masks()
        return LinearizedBERTClassifier(self, stabilize=True)

In [56]:
a = torch.nn.Linear(10,10)
list(iter(a.modules()))

[Linear(in_features=10, out_features=10, bias=True)]

In [7]:
imdb_dataset = IMDb(torch.device('cpu'), 'bert-base-uncased')

In [51]:
bert = BertModel.from_pretrained('bert-base-uncased')
model = BERTClassifier('bert-base-uncased', 'mean', 2, torch.device('cpu'))
linearized = model.linearize()
dataloader = DataLoader(imdb_dataset.datasets['train'], batch_size=4, shuffle=True)
effective_masks = linearized.get_effective_masks()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predic

torch.Size([768, 30522]) tensor([0.0332, 0.0332, 0.0332, 0.0332, 0.0332]) tensor(-0.4529) tensor(0.5298)
torch.Size([768, 512]) tensor([0.0332, 0.0332, 0.0332, 0.0332, 0.0332]) tensor(-0.4529) tensor(0.5298)
torch.Size([768, 2]) tensor([0.0332, 0.0332, 0.0658, 0.0658, 0.0956]) tensor(-0.4529) tensor(0.5298)
torch.Size([768, 768]) tensor([42.2640, 52.9883, 51.1211, 56.7959, 38.8760]) tensor(-238.6008) tensor(192.4649)
torch.Size([768, 768]) tensor([48.2956, 60.5500, 58.4161, 64.9012, 44.4238]) tensor(-239.9757) tensor(238.0393)
torch.Size([768, 768]) tensor([14.6375, 18.3516, 17.7049, 19.6704, 13.4641]) tensor(7.3030) tensor(67.6220)
torch.Size([768, 768]) tensor([15.0646, 13.8326, 13.8340, 13.7263, 14.4319]) tensor(5.4796) tensor(36.2494)
torch.Size([3072, 768]) tensor([14.2495, 17.8647, 17.2354, 19.1481, 13.1073]) tensor(8.5118) tensor(124.5001)
torch.Size([768, 3072]) tensor([18.3313, 18.8038, 19.9627, 20.0192, 17.1135]) tensor(5.6353) tensor(81.2385)
torch.Size([768, 768]) tensor([ 

In [201]:
configuration = BertConfig(hidden_act=lambda x: x,hidden_dropout_prob=0,attention_probs_dropout_prob=0)
bert = BertModel(configuration)
bert.encoder.layer[0].attention.self.query

Linear(in_features=768, out_features=768, bias=True)

In [53]:
for mask in effective_masks['encoder']:
    print(mask.shape, (mask.numpy()>0).all())

torch.Size([768, 768]) False
torch.Size([768, 768]) False
torch.Size([768, 768]) True
torch.Size([768, 768]) True
torch.Size([3072, 768]) True
torch.Size([768, 3072]) True
torch.Size([768, 768]) False
torch.Size([768, 768]) False
torch.Size([768, 768]) True
torch.Size([768, 768]) True
torch.Size([3072, 768]) True
torch.Size([768, 3072]) True
torch.Size([768, 768]) True
torch.Size([768, 768]) True
torch.Size([768, 768]) True
torch.Size([768, 768]) True
torch.Size([3072, 768]) True
torch.Size([768, 3072]) True
torch.Size([768, 768]) True
torch.Size([768, 768]) True
torch.Size([768, 768]) True
torch.Size([768, 768]) True
torch.Size([3072, 768]) True
torch.Size([768, 3072]) True
torch.Size([768, 768]) False
torch.Size([768, 768]) False
torch.Size([768, 768]) True
torch.Size([768, 768]) True
torch.Size([3072, 768]) True
torch.Size([768, 3072]) True
torch.Size([768, 768]) True
torch.Size([768, 768]) True
torch.Size([768, 768]) True
torch.Size([768, 768]) True
torch.Size([3072, 768]) True
tor