In [1]:
%cd ../..

/home/shapkin/effective-inference


## Import libs

In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="3"

import yaml
import h5py
import numpy as np
import torch

from utils.dataset_cache import cache_embeddings, get_dataset_for_regression, build_dataset_from_cached, load_cached_dataset
from utils.dataset_cache import build_dict_dataset_from_cached
from utils.prepare_dataset import load_datasets, cut_datasets
from utils.config import ConfigWrapper
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModel
from typing import Tuple, List, Dict, Optional, Union
from numpy.random import shuffle

In [3]:
from torch import nn
from copy import deepcopy

def hidden_to_heads(x, config):
    num_attention_heads = config.attention_config.num_heads
    attention_head_size =  config.attention_config.d_model // config.attention_config.num_heads
    new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
    x = x.view(new_x_shape)
    return x

class LinearAttention(nn.Module):
    def __init__(self, config):
        super(LinearAttention, self).__init__()
        self.config = config
        self.features = config['features']
        self.device = config['device']
        self.batch_size = config['batch_size']

        self.dim_size = config['d_model']
        if config.split_heads:
            self.dim_size = config['d_model'] // config['num_heads']
                        
        for k in self.features:                
            if 'hidden' in k:
                learnable_parameters = f'torch.nn.Linear(in_features={self.dim_size}, out_features=1)'
                exec(f"self.{k} = {learnable_parameters}")
            else:
                learnable_parameters = f'nn.Parameter(torch.randn(1), requires_grad=True)'
                exec(f"self.{k} = {learnable_parameters}")
                    
            
    def forward(self, seq_len=None, **kwargs):
        if seq_len is not None:
            result = torch.zeros((self.batch_size, seq_len, seq_len), device=self.device)
            
            for arg_name, arg_value in kwargs.items():
                namespace = {'cur_result': None, 'self': self, 'arg_name': arg_name, 'arg_value': arg_value}
                if 'hidden' in arg_name:
                    exec(f"cur_result = self.{arg_name}(arg_value)", namespace)
                else:
                    exec(f"cur_result = self.{arg_name} * arg_value", namespace)
                if 'from' in arg_name:
                    result += namespace['cur_result'].T
                else:
                    result += namespace['cur_result']
        else:
            result = torch.zeros((self.batch_size, 1), device=self.device)
            
            for arg_name, arg_value in kwargs.items():
                namespace = {'cur_result': None, 'self': self, 'arg_name': arg_name, 'arg_value': arg_value}
                if 'hidden' in arg_name:
                    exec(f"cur_result = self.{arg_name}(arg_value)", namespace)
                else:
                    exec(f"cur_result = self.{arg_name} * arg_value", namespace)
                if 'from' in arg_name:
                    result += namespace['cur_result'].T
                else:
                    result += namespace['cur_result']
            
        return result


from transformers.models.bert.modeling_bert import BertSelfAttention, BertModel, \
    BaseModelOutputWithPastAndCrossAttentions

class LinearClassifierBertAttention(BertSelfAttention):
    """
    Idea: attention weights are predicted by Linear Classifier
    """
    def __init__(self, bert_config, config):
        super(LinearClassifierBertAttention, self).__init__(bert_config)
        self.config = config
        self.linear_config = config.attention_config

        if self.linear_config.split_heads:
            for head_num in range(self.linear_config['num_heads']):
                learnable_parameters = f'LinearAttention(self.config.attention_config)'
                namespace = {'head_num': head_num, 'self': self, 'LinearAttention': LinearAttention}
                exec(f"self.linear_model_{head_num} = {learnable_parameters}", namespace)
        else:
            self.linear_model = LinearAttention(config.attention_config)

    def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.FloatTensor] = None,
            head_mask: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
            output_attentions: Optional[bool] = False,
            # special_tokens_idxs: Optional[List[int]] = [0]
    ) -> Tuple[torch.Tensor]:
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
            # special_tokens_idxs = (encoder_hidden_states[0] < 103).nonzero().squeeze()
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
            # special_tokens_idxs = (encoder_hidden_states[0] < 103).nonzero().squeeze()
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        if self.config.attention_config.split_heads:
            hidden_states = hidden_to_heads(hidden_states, self.config)

            attentions = []
            for head_num in range(self.linear_config['num_heads']):
                seq_len = hidden_states.shape[1]
                positions = torch.arange(seq_len).view(-1, 1)
                
                full_data_to_linear = {
                    'hidden_from': hidden_states[:, :, head_num, :], 
                    'hidden_to': hidden_states[:, :, head_num, :], 
                    'pos_from': positions,
                    'pos_to': positions,
                    'relev_pos_from': seq_len - positions,
                    'relev_pos_to': seq_len - positions,
                    'inv_pos_from': (positions / seq_len),
                    'inv_pos_to': (positions / seq_len),
                    'inv_relev_pos_from': ((seq_len - positions) / seq_len),
                    'inv_relev_pos_to': ((seq_len - positions) / seq_len),
                    'seq_len': seq_len, 
                    'inv_seq_len': (1 / seq_len)
                }
                
                data_to_linear = {k:full_data_to_linear[k].to(self.linear_config['device']) for k in self.linear_config['features'] if k != 'num_heads'}

                namespace = {'predicted_attention': None, 'self': self, 'data_to_linear': data_to_linear, 'seq_len': seq_len}
                exec(f"predicted_attention = self.linear_model_{head_num}(seq_len, **data_to_linear)", namespace)
                # print(namespace['predicted_attention'])
                attention_probs = nn.Sigmoid()(torch.exp(namespace['predicted_attention']))  # torch.nn.functional.softmax(predicted_attention, dim=-1)
                attentions.append(attention_probs)

            attention_probs = torch.stack(attentions, dim=1)
            context_layer = torch.matmul(attention_probs, value_layer)
    
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
            new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
            context_layer = context_layer.view(new_context_layer_shape)
    
            outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
    
            if self.is_decoder:
                outputs = outputs + (past_key_value,)
            return outputs

        else:
            seq_len = hidden_states.shape[1]
            positions = torch.arange(seq_len).view(-1, 1)
            full_data_to_linear = {
                'hidden_from': hidden_states, 
                'hidden_to': hidden_states, 
                'pos_from': positions,
                'pos_to': positions,
                'relev_pos_from': seq_len - positions,
                'relev_pos_to': seq_len - positions,
                'inv_pos_from': (positions / seq_len),
                'inv_pos_to': (positions / seq_len),
                'inv_relev_pos_from': ((seq_len - positions) / seq_len),
                'inv_relev_pos_to': ((seq_len - positions) / seq_len),
                'seq_len': seq_len, 
                'inv_seq_len': (1 / seq_len),
            }
            if 'head_num' in self.linear_config['features']:
                attentions = []
                for head_num in range(self.linear_config['num_heads']):
                    full_data_to_linear['head_num'] = torch.tensor([head_num])

                    data_to_linear = {k:full_data_to_linear[k].to(self.linear_config['device']) for k in self.linear_config['features']}
                    predicted_attention= self.linear_model(seq_len, **data_to_linear)
                
                    attention_probs = nn.Sigmoid()(torch.exp(predicted_attention))  # torch.nn.functional.softmax(predicted_attention, dim=-1)
                    attentions.append(attention_probs)

                    attention_probs = torch.stack(attentions, dim=1)
                    context_layer = torch.matmul(attention_probs, value_layer.squeeze(1))
                    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
                    new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
                    context_layer = context_layer.view(new_context_layer_shape)
            
                    outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
            
                    if self.is_decoder:
                        outputs = outputs + (past_key_value,)
                    return outputs
                
            else:
                data_to_linear = {k:full_data_to_linear[k].to(self.linear_config['device']) for k in self.linear_config['features']}
                predicted_attention= self.linear_model(seq_len, **data_to_linear)
            
                attention_probs = nn.Sigmoid()(torch.exp(predicted_attention))  # torch.nn.functional.softmax(predicted_attention, dim=-1)
                context_layer = torch.matmul(attention_probs, value_layer.squeeze(1))
        
                context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
                new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
                context_layer = context_layer.view(new_context_layer_shape)
        
                outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
        
                if self.is_decoder:
                    outputs = outputs + (past_key_value,)
        return outputs

class BertWrapperLin(nn.Module):
    def __init__(self, model, new_attention_class, linear_config, layer_nums=None, window_size=2):
        super().__init__()

        self.bert_model = deepcopy(model)
        self.layer_nums = layer_nums

        # Create a list of modules to modify
        modules_to_modify = []
        for i in range(len(self.bert_model.encoder.layer)):
            if (layer_nums is not None and i in layer_nums) or (layer_nums is None):
                mean_attention = new_attention_class(self.bert_model.config, linear_config) # self.bert_model.config, 
                #mean_attention.set_window_size(window_size)
                mean_attention.load_state_dict(self.bert_model.encoder.layer[i].attention.self.state_dict(), strict=False)

                self.bert_model.encoder.layer[i].attention.self = mean_attention

    def forward(self, *args, **kwargs):
        return self.bert_model(*args, **kwargs)

## Project configuration

In [4]:
config_path = 'config.yaml'

with open(config_path, "r") as f:
    config = ConfigWrapper(yaml.load(f, Loader=yaml.FullLoader))

In [5]:
tokenizer = AutoTokenizer.from_pretrained(config.model.model_name, max_length=config.general.max_len)
initial_model = AutoModel.from_pretrained(config.model.model_name).to(config.general.device)

## Load data

In [6]:
#X_train, y_train, X_test, y_test = load_cached_dataset(config, layer=0)
#X_train.shape, X_test.shape

In [7]:
train_datasets = load_datasets(config.data.train_datasets, config.data.cut_size)
train_datasets

{'imdb': DatasetDict({
     train: Dataset({
         features: ['text', 'label'],
         num_rows: 25000
     })
     test: Dataset({
         features: ['text', 'label'],
         num_rows: 25000
     })
     unsupervised: Dataset({
         features: ['text', 'label'],
         num_rows: 50000
     })
 })}

In [8]:
X_train, y_train, X_test, y_test = build_dict_dataset_from_cached(config, train_datasets, layer=0, heads=[0], 
                                                                  features=config.attention_config.features, 
                                                                  split_hidden=True)

In [9]:
X_train[0]['hidden_to'][0], X_train[1000]['hidden_to'][0]

(0.15553807, 0.21780878)

In [10]:
X_train[0]['hidden_to'].shape

(64,)

In [11]:
len(X_train), len(X_test)

(1606, 85)

In [12]:
X_train[0]

{'hidden_to': array([ 0.15553807,  0.14143425,  0.13302758,  0.31399602, -0.65938884,
        -1.0736477 ,  0.6093533 ,  0.13589564,  1.3214531 ,  0.667912  ,
         0.17945454,  0.4693082 ,  0.30648527, -0.56361777, -0.20076695,
         0.9098844 ,  0.59665406,  0.5579244 , -1.4222709 ,  0.18325962,
        -1.105002  ,  1.2219864 , -0.40134352, -0.9862202 , -0.27676457,
        -0.7115183 , -0.48043513, -0.41555184,  0.5822979 , -0.8090197 ,
         0.809246  , -0.16602594, -0.07341174,  0.117907  , -0.9402012 ,
         0.02974713,  0.22651967, -0.3060124 ,  0.05891408,  0.43850982,
         0.10242728,  0.91733706,  0.64147025, -0.3555956 , -1.7645459 ,
        -0.2814301 , -0.00459639,  0.25998768, -1.391335  ,  0.47268033,
         0.24658613, -0.22695369, -0.24742895, -0.01804959,  0.80036235,
        -1.1185328 ,  0.7725253 ,  0.06924633,  0.02469535, -0.3253744 ,
         0.24401177, -0.05173851, -0.20923023,  0.3738086 ], dtype=float32),
 'hidden_from': array([-0.5322135 

In [29]:
def get_dict_batch(samples_arr, device):
    sample = samples_arr[0]
    final = {}
    for k, _ in sample.items():
        batched_feature = []
        for el in samples_arr:
            if 'hidden' in k:
                batched_feature.append(torch.tensor(el[k]))
            else:
                batched_feature.append(el[k])

        if 'hidden' in k:
            batched_feature = torch.stack(batched_feature)
        else:
            batched_feature = torch.tensor(batched_feature)
        final[k] = batched_feature.to(device)
    return final
    
def prepare_batches(dataset, n, device):
    shuffle_idx = np.arange(len(dataset))
    np.random.shuffle(shuffle_idx)
    new_dataset = dataset[shuffle_idx]
    print(new_dataset)
    # looping till length l
    for i in range(0, len(new_dataset), n): 
        yield get_dict_batch[i:i + n]

In [31]:
a = next(iter(prepare_batches(X_train, 5, 'cuda')))

[{'hidden_to': array([-0.19079687,  0.21273507,  0.66632915, -0.44872472, -0.48511204,
        -0.17139994,  0.33636194, -0.01300793, -0.33155817, -0.35787386,
         0.1725571 , -0.14330178, -0.61597705,  0.12558405,  0.9203594 ,
        -0.25555786, -0.7375698 , -1.3028327 ,  0.59556675,  1.0017891 ,
         0.04428435,  0.22370818, -0.09685135,  0.6178254 , -0.425674  ,
        -0.5404954 ,  0.50118387,  0.07698976, -0.42920288,  0.7901867 ,
         0.52212757,  0.09048381,  0.75964904, -0.60198724,  0.26504874,
         0.6656127 , -1.7834715 , -0.78274894,  0.12213472,  0.866579  ,
         0.36710104, -0.6980736 ,  0.33897507,  0.06735259,  0.11314931,
         0.7412936 ,  0.08072751,  0.34569353, -0.91472894, -0.19712013,
         1.1396126 ,  1.4449006 , -0.19016334,  0.26347688, -0.48608357,
        -0.40321687, -0.03245781,  0.11642367, -1.5451432 , -0.10034049,
         0.8672733 , -0.55858016,  0.5449989 ,  0.30212072], dtype=float32), 'hidden_from': array([ 0.05000687

TypeError: 'function' object is not subscriptable

In [29]:
get_dict_batch(X_train[:1], 'cuda')

{'hidden_to': tensor([[ 0.1555,  0.1414,  0.1330,  0.3140, -0.6594, -1.0736,  0.6094,  0.1359,
           1.3215,  0.6679,  0.1795,  0.4693,  0.3065, -0.5636, -0.2008,  0.9099,
           0.5967,  0.5579, -1.4223,  0.1833, -1.1050,  1.2220, -0.4013, -0.9862,
          -0.2768, -0.7115, -0.4804, -0.4156,  0.5823, -0.8090,  0.8092, -0.1660,
          -0.0734,  0.1179, -0.9402,  0.0297,  0.2265, -0.3060,  0.0589,  0.4385,
           0.1024,  0.9173,  0.6415, -0.3556, -1.7645, -0.2814, -0.0046,  0.2600,
          -1.3913,  0.4727,  0.2466, -0.2270, -0.2474, -0.0180,  0.8004, -1.1185,
           0.7725,  0.0692,  0.0247, -0.3254,  0.2440, -0.0517, -0.2092,  0.3738]],
        device='cuda:0'),
 'hidden_from': tensor([[-0.5322,  0.8566,  0.2841,  0.3840, -0.3893,  0.0606,  0.5328,  0.2478,
           0.7118,  0.6445,  0.2871,  0.1810,  0.4628,  0.1429,  0.1734,  0.5267,
          -0.2735,  0.3500,  0.0261,  0.5631,  0.7153,  0.3269,  0.4973,  0.2403,
          -0.0357,  0.9348, -0.1326,  0.00

In [30]:
lin_model = LinearAttention(config.attention_config).to('cuda')

In [31]:
lin_model(**get_dict_batch(X_train[:1], 'cuda'))

tensor([[-0.2485]], device='cuda:0', grad_fn=<AddBackward0>)

In [None]:
torch.stack([torch.tensor([1, 2])])

In [32]:
my_list = ['geeks', 'for', 'geeks', 'like',
           'geeky','nerdy', 'geek', 'love',
               'questions','words', 'life']
  
# Yield successive n-sized
# chunks from l.
def divide_chunks(l, n):
      
    # looping till length l
    for i in range(0, len(l), n): 
        yield l[i:i + n]
  
# How many elements each
# list should have
n = 5
  
x = list(divide_chunks(my_list, n))
print (x)

[['geeks', 'for', 'geeks', 'like', 'geeky'], ['nerdy', 'geek', 'love', 'questions', 'words'], ['life']]


In [None]:
linear_model = BertWrapperLin(initial_model, LinearClassifierBertAttention, config, layer_nums=[6, 7, 8, 9, 10, 11])

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
linear_model = BertWrapperLin(initial_model, LinearClassifierBertAttention, config, layer_nums=[6, 7, 8, 9, 10, 11])

In [None]:
linear_model

In [None]:
encoded_inputs = tokenizer.encode(
                                'Hello! My name is... Hi, my name is... Slim Shady',
                                truncation=True,
                                return_tensors='pt'
                            ).to(config.general.device)

In [None]:
aa = initial_model(encoded_inputs, output_hidden_states=True, output_attentions=True)

In [None]:
len(aa.attentions), aa.attentions[0].shape

In [None]:
linear_model.to(config.general.device)

In [None]:
aa = linear_model(encoded_inputs.to(config.general.device), output_hidden_states=True, output_attentions=True)

In [None]:
len(aa.attentions)

In [None]:
aa.attentions[-1].shape

In [None]:
aa.attentions[-1]

In [None]:
aa.attentions