In [1]:
%cd ../..

/home/shapkin/effective-inference


## Import libs

In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="3"

import yaml
import h5py
import numpy as np
import torch

from utils.dataset_cache import cache_embeddings, get_dataset_for_regression, build_dataset_from_cached, load_cached_dataset
from utils.config import ConfigWrapper
from utils.prepare_dataset import load_datasets, cut_datasets
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModel
from typing import Tuple, List, Dict, Optional, Union

In [9]:
from torch import nn
from copy import deepcopy

def hidden_to_heads(x, config):
    num_attention_heads = config.attention_config.num_heads
    attention_head_size =  config.attention_config.d_model // config.attention_config.num_heads
    new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
    x = x.view(new_x_shape)
    return x

class LinearAttention(nn.Module):
    def __init__(self, config):
        super(LinearAttention, self).__init__()
        self.config = config
        self.features = config['features']
        self.device = config['device']
        self.batch_size = config['batch_size']
        
        if config.split_heads:
            self.dim_size = config['d_model'] // config['num_heads']

            for head_id in range(len(config['num_heads'])):
                for k in self.features:
                    if k == 'head_num':
                        continue
                        
                    if 'hidden' in k:
                        learnable_parameters = f'torch.nn.Linear(in_features={self.dim_size}, out_features=1)'
                        exec(f"self.{k}_{head_id} = {learnable_parameters}")
                    else:
                        #learnable_parameters = f'torch.nn.Linear(in_features=1, out_features=1)'
                        learnable_parameters = f'nn.Parameter(torch.randn(1), requires_grad=True)'
                        #self.register_parameter(f'{k}', nn.Parameter(torch.normal(mean=torch.tensor(0.0), std=torch.tensor(1.0))))
                        exec(f"self.{k}_{head_id} = {learnable_parameters}")

        else:
            self.dim_size = config['d_model']

            for k in self.features:
                if 'hidden' in k:
                    learnable_parameters = f'torch.nn.Linear(in_features={self.dim_size}, out_features=1)'
                    exec(f"self.{k} = {learnable_parameters}")
                else:
                    #learnable_parameters = f'torch.nn.Linear(in_features=1, out_features=1)'
                    learnable_parameters = f'nn.Parameter(torch.randn(1), requires_grad=True)'
                    #self.register_parameter(f'{k}', nn.Parameter(torch.normal(mean=torch.tensor(0.0), std=torch.tensor(1.0))))
                    exec(f"self.{k} = {learnable_parameters}")
                    
            
    def forward(self, seq_len, **kwargs):
        result = torch.zeros((self.batch_size, seq_len, seq_len), device=self.device)
        if self.config.split_heads:
            assert 'head_num' in kwargs, 'You did NOT provide head_num'
            for arg_name, arg_value in kwargs.items():
                if arg_name == head_num:
                    continue
                # a = arg_value.clone()
                namespace = {'cur_result': None, 'self': self, 'arg_name': arg_name, 'arg_value': arg_value}
                if 'hidden' in arg_name:
                    #print(f"result += self.{arg_name}(arg_value)")
                    exec(f"cur_result = self.{arg_name}_{head_num}(arg_value)", namespace)
                else:
                    #print(f"result += self.{arg_name} * arg_value")
                    exec(f"cur_result = self.{arg_name}_{head_num} * arg_value", namespace)
                if 'from' in arg_name:
                    result += namespace['cur_result'].T
                else:
                    result += namespace['cur_result']
            return result
            
        else:
            for arg_name, arg_value in kwargs.items():
                # a = arg_value.clone()
                namespace = {'cur_result': None, 'self': self, 'arg_name': arg_name, 'arg_value': arg_value}
                if 'hidden' in arg_name:
                    #print(f"result += self.{arg_name}(arg_value)")
                    exec(f"cur_result = self.{arg_name}(arg_value)", namespace)
                else:
                    #print(f"result += self.{arg_name} * arg_value")
                    exec(f"cur_result = self.{arg_name} * arg_value", namespace)
                if 'from' in arg_name:
                    result += namespace['cur_result'].T
                else:
                    result += namespace['cur_result']
            return result
            

from transformers.models.bert.modeling_bert import BertSelfAttention, BertModel, \
    BaseModelOutputWithPastAndCrossAttentions

class LinearClassifierBertAttention(BertSelfAttention):
    """
    Idea: attention weights are predicted by Linear Classifier
    """
    def __init__(self, bert_config, config):
        super(LinearClassifierBertAttention, self).__init__(bert_config)
        self.config = config
        self.linear_config = config.attention_config
        self.linear_model = LinearAttention(config.attention_config)

    def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.FloatTensor] = None,
            head_mask: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
            output_attentions: Optional[bool] = False,
            # special_tokens_idxs: Optional[List[int]] = [0]
    ) -> Tuple[torch.Tensor]:
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
            # special_tokens_idxs = (encoder_hidden_states[0] < 103).nonzero().squeeze()
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
            # special_tokens_idxs = (encoder_hidden_states[0] < 103).nonzero().squeeze()
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        if self.config.attention_config.split_heads:
            hidden_states = hidden_to_heads(hidden_states, self.config)

            for head_num in range(self.linear_config['num_heads']):
                seq_len = hidden_states.shape[1]
                positions = torch.arange(seq_len).view(-1, 1)
                
                full_data_to_linear = {
                    'hidden_from': hidden_states[:, :, head_num, :], 
                    'hidden_to': hidden_states[:, :, head_num, :], 
                    'pos_from': positions,
                    'pos_to': positions,
                    'relev_pos_from': seq_len - positions,
                    'relev_pos_to': seq_len - positions,
                    'inv_pos_from': (positions / seq_len),
                    'inv_pos_to': (positions / seq_len),
                    'inv_relev_pos_from': ((seq_len - positions) / seq_len),
                    'inv_relev_pos_to': ((seq_len - positions) / seq_len),
                    'seq_len': seq_len, 
                    'inv_seq_len': (1 / seq_len)
                }

            data_to_linear = {k:full_data_to_linear[k].to(self.linear_config['device']) for k in self.linear_config['features'] if k != 'num_heads'}
            predicted_attention= self.linear_model(seq_len, **data_to_linear)
            
            attention_probs = torch.exp(predicted_attention)  # torch.nn.functional.softmax(predicted_attention, dim=-1)
            context_layer = torch.matmul(attention_probs, value_layer.squeeze(1))
    
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
            new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
            context_layer = context_layer.view(new_context_layer_shape)
    
            outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
    
            if self.is_decoder:
                outputs = outputs + (past_key_value,)
            return outputs

        else:
            seq_len = hidden_states.shape[1]
            positions = torch.arange(seq_len).view(-1, 1)
            
            full_data_to_linear = {
                'hidden_from': hidden_states, 
                'hidden_to': hidden_states, 
                'pos_from': positions,
                'pos_to': positions,
                'relev_pos_from': seq_len - positions,
                'relev_pos_to': seq_len - positions,
                'inv_pos_from': (positions / seq_len),
                'inv_pos_to': (positions / seq_len),
                'inv_relev_pos_from': ((seq_len - positions) / seq_len),
                'inv_relev_pos_to': ((seq_len - positions) / seq_len),
                'seq_len': seq_len, 
                'inv_seq_len': (1 / seq_len)
            }

            data_to_linear = {k:full_data_to_linear[k].to(self.linear_config['device']) for k in self.linear_config['features']}
            predicted_attention= self.linear_model(seq_len, **data_to_linear)
            print('predicted_attention shape:', predicted_attention.shape)
            print('value_layer shape:', value_layer.shape)
            
            attention_probs = torch.exp(predicted_attention)  # torch.nn.functional.softmax(predicted_attention, dim=-1)
            context_layer = torch.matmul(attention_probs, value_layer.squeeze(1))
    
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
            new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
            context_layer = context_layer.view(new_context_layer_shape)
    
            outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
    
            if self.is_decoder:
                outputs = outputs + (past_key_value,)
        return outputs

class BertWrapperLin(nn.Module):
    def __init__(self, model, new_attention_class, linear_config, layer_nums=None, window_size=2):
        super().__init__()

        self.bert_model = deepcopy(model)
        self.layer_nums = layer_nums

        # Create a list of modules to modify
        modules_to_modify = []
        for i in range(len(self.bert_model.encoder.layer)):
            if (layer_nums is not None and i in layer_nums) or (layer_nums is None):
                mean_attention = new_attention_class(self.bert_model.config, linear_config) # self.bert_model.config, 
                #mean_attention.set_window_size(window_size)
                mean_attention.load_state_dict(self.bert_model.encoder.layer[i].attention.self.state_dict(), strict=False)

                self.bert_model.encoder.layer[i].attention.self = mean_attention

    def forward(self, *args, **kwargs):
        return self.bert_model(*args, **kwargs)

## Project configuration

In [10]:
config_path = 'config.yaml'

with open(config_path, "r") as f:
    config = ConfigWrapper(yaml.load(f, Loader=yaml.FullLoader))

In [11]:
tokenizer = AutoTokenizer.from_pretrained(config.model.model_name, max_length=config.general.max_len)
initial_model = AutoModel.from_pretrained(config.model.model_name).to(config.general.device)

## Load data

In [12]:
X_train, y_train, X_test, y_test = load_cached_dataset(config, layer=0)
X_train.shape, X_test.shape

((1862, 1538), (98, 1538))

In [13]:
config

{'data': {'data_path': 'data',
  'train_datasets': [['', 'imdb']],
  'train_datasets_fields': [['text', '']],
  'eval_datasets': [['glue', 'mrpc']],
  'eval_datasets_fields': [['sentence1', 'sentence2']],
  'cut_size': None,
  'cache_cut_size': 100,
  'prob_of_take': 0.01,
  'cache_features': True,
  'cache_train_features': True,
  'train_features_prefix': 'not_split',
  'cache_train_dataset': True,
  'train_prop': 0.95,
  'layers': 12,
  'heads': 12},
 'model': {'model_name': 'bert-base-uncased', 'attention_aproximation': ''},
 'general': {'device': 'cuda',
  'out_prediction': 'data/outputs.json',
  'out_metrics': 'data/metrics.txt',
  'max_len': 1024,
  'batch_size': 1,
  'd_model': 768},
 'attention_config': {'d_model': 768,
  'device': 'cuda',
  'features': ['hidden_to', 'hidden_from', 'pos_to', 'pos_from'],
  'batch_size': 1,
  'num_heads': 12,
  'split_heads': True}}

In [14]:
linear_model = BertWrapperLin(initial_model, LinearClassifierBertAttention, config, layer_nums=[6, 7, 8, 9, 10, 11])

TypeError: object of type 'int' has no len()

In [None]:
linear_model

In [16]:
encoded_inputs = tokenizer.encode(
                                'Hello! My name is... Hi, my name is... Slim Shady',
                                truncation=True,
                                return_tensors='pt'
                            ).to(config.general.device)

In [17]:
aa = initial_model(encoded_inputs, output_hidden_states=True, output_attentions=True)

In [18]:
len(aa.attentions), aa.attentions[0].shape

(12, torch.Size([1, 12, 20, 20]))

In [None]:
linear_model.to(config.general.device)

In [None]:
aa = linear_model(encoded_inputs.to(config.general.device), output_hidden_states=True, output_attentions=True)

In [None]:
len(aa.attentions)

In [None]:
aa.attentions[-1].shape