In [1]:
import transformers
import torch
import datasets
import pdb

In [2]:
# `from_pretrained` is a classmethod: call it on the config class directly instead of first
# instantiating a default config object that is immediately thrown away.
config = transformers.BertConfig.from_pretrained('bert-base-uncased')

In [3]:
# `from_pretrained` is a classmethod that builds its own model, so constructing
# `BertModel(config, add_pooling_layer=False)` first only allocates a randomly initialised model
# that is discarded, and its `add_pooling_layer=False` never reaches the loaded model.
# Passing the config and the flag to `from_pretrained` forwards them to `BertModel.__init__`.
bert = transformers.BertModel.from_pretrained('bert-base-uncased', config=config, add_pooling_layer=False)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
corpus = datasets.load_dataset('bookcorpus',split='train')

Reusing dataset bookcorpus (/mounts/data/corp/huggingface/datasets/bookcorpus/plain_text/1.0.0/44662c4a114441c35200992bea923b170e6f13f2f0beb7c14e43759cec498700)


In [5]:
tokenizer = transformers.BertTokenizerFast.from_pretrained('bert-base-uncased')

In [6]:
# Tokenize the first 16 corpus entries, truncating/padding every entry to exactly 128 tokens,
# and drop the raw 'text' column so only the tokenizer outputs remain.
mini_tokenized = corpus.select(range(16)).map(lambda e: tokenizer(e['text'],truncation=True,padding='max_length',max_length=128),remove_columns=['text'])

Loading cached processed dataset at /mounts/data/corp/huggingface/datasets/bookcorpus/plain_text/1.0.0/44662c4a114441c35200992bea923b170e6f13f2f0beb7c14e43759cec498700/cache-2104a3e5d00ab1ee.arrow


In [7]:
mini_tokenized.set_format('torch')

In [8]:
mini_tokenized

Dataset({
    features: ['attention_mask', 'input_ids', 'token_type_ids'],
    num_rows: 16
})

In [9]:
# Run BERT on the 16 tokenized entries and ask for the hidden states of every layer
# (read below via `output.hidden_states`); attention maps are not requested.
output = bert(input_ids = mini_tokenized['input_ids'],
             attention_mask = mini_tokenized['attention_mask'],
             output_hidden_states=True,
#              output_attentions=True
             )

In [10]:
output.hidden_states[0].shape

torch.Size([16, 128, 768])

In [11]:
# Multiply the last hidden state by the attention mask, expanded over the hidden dimension,
# so positions whose mask value is 0 contribute zeros.
output_attn = output.last_hidden_state * mini_tokenized['attention_mask'].unsqueeze(2).expand_as(output.last_hidden_state)

In [12]:
output_attn.shape

torch.Size([16, 128, 768])

In [13]:
weight_p = torch.nn.Parameter(torch.rand(16,128))

In [14]:
weight = torch.rand(16,128)

In [15]:
# Weighted sum over the sequence dimension written as a batched matrix product:
# (16, 1, 128) @ (16, 128, 768) -> (16, 1, 768).
out = torch.bmm(weight_p.unsqueeze(1),output.last_hidden_state)

In [16]:
# Sanity check: the bmm result above matches the explicit "broadcast the weights over the
# hidden dimension, multiply element-wise, sum over the sequence dimension" formulation.
torch.isclose((weight_p.unsqueeze(2).expand_as(output.last_hidden_state)*output.last_hidden_state).sum(1),
              out.squeeze())

tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        ...,
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]])

In [17]:
class ParameterizedBertPooler(torch.nn.Module):
    """Pool the last ``bert_layers`` hidden states of a BERT output with learned position weights.

    For each selected layer a weight tensor of shape ``(batch_size, 1, seq_len)`` is multiplied
    (batched matrix product) with the attention-masked hidden state, giving one pooled vector
    per layer and batch element.

    Args:
        bert_layers: number of hidden states to pool, counted backwards from the last one.
        seq_len: sequence length of the hidden states.
        batch_size: batch size of the hidden states (the weights carry a batch dimension).
    """

    def __init__(self,bert_layers,seq_len,batch_size):#args,config):
        super().__init__()

        self.bert_layers = bert_layers
        # A ParameterList (rather than a plain Python list) registers the weights with the
        # module, so they show up in .parameters(), follow .to(device) and can be optimised.
        self.weight_parameter = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.rand(batch_size,1,seq_len)) for _ in range(self.bert_layers)]
        )

    def forward(self,bert_output,attention_mask):
        """Return the pooled layers stacked as ``(batch, bert_layers, hidden)``.

        Args:
            bert_output: object exposing ``hidden_states``, a sequence of
                ``(batch, seq_len, hidden)`` tensors.
            attention_mask: ``(batch, seq_len)`` mask multiplied into every hidden state.

        The first pooled layer along dimension 1 is the final hidden state, the second the one
        before it, and so on.
        """
        hidden_states = bert_output.hidden_states
        output = []
        # Index relative to the end instead of a hard-coded 12, so the last layers are picked
        # whatever the depth of the encoder (identical for a 13-entry hidden_states tuple).
        for j, weight in enumerate(self.weight_parameter):
            hidden_state = hidden_states[-1 - j]
            hidden_state_masked = hidden_state * attention_mask.unsqueeze(2).expand_as(hidden_state)
            out = torch.bmm(weight,hidden_state_masked) # batched dot product -> (batch, 1, hidden)
            # squeeze only the singleton product dimension; a bare squeeze() would also drop
            # the batch dimension when the batch size is 1
            output.append(out.squeeze(1))

        return torch.stack(output).transpose(0,1) # swap batch and layer_dimension

In [18]:
class GeneralParameterizedPooler(torch.nn.Module):
    """Pool a list of per-layer token embeddings with learned position weights.

    Each of the ``num_layers`` inputs gets its own weight tensor of shape
    ``(batch_size, 1, seq_len)``; the pooled vector is the batched matrix product of that
    weight with the attention-masked embeddings.

    Args:
        num_layers: number of embedding tensors passed to ``forward``.
        seq_len: sequence length of the embeddings.
        batch_size: batch size of the embeddings (the weights carry a batch dimension).
    """

    def __init__(self,num_layers,seq_len,batch_size):#args,config):
        super().__init__()

        self.num_layers = num_layers
        # A ParameterList (rather than a plain Python list) registers the weights with the
        # module, so they show up in .parameters(), follow .to(device) and can be optimised.
        self.weight_parameter = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.rand(batch_size,1,seq_len)) for _ in range(self.num_layers)]
        )

    def forward(self,token_embeddings,attention_mask):
        """Return the pooled layers stacked as ``(batch, num_layers, hidden)``.

        Args:
            token_embeddings: sequence of ``num_layers`` tensors, each ``(batch, seq_len, hidden)``.
            attention_mask: ``(batch, seq_len)`` mask multiplied into every embedding tensor.
        """
        assert len(token_embeddings)==self.num_layers, 'Number of embeddings must be the same as number of layers'
        output = []

        for weight, hidden_state in zip(self.weight_parameter, token_embeddings):
            hidden_state_masked = hidden_state * attention_mask.unsqueeze(2).expand_as(hidden_state)
            out = torch.bmm(weight,hidden_state_masked) # batched dot product -> (batch, 1, hidden)
            # squeeze only the singleton product dimension; a bare squeeze() would also drop
            # the batch dimension when the batch size is 1
            output.append(out.squeeze(1))

        return torch.stack(output).transpose(0,1) # swap batch and layer_dimension

In [None]:
pooler = ParameterizedBertPooler(3,128,16)

In [None]:
pooler.weight_parameter[0].shape

In [None]:
list(pooler.parameters())

In [None]:
poolerG = GeneralParameterizedPooler(3,128,16)

In [None]:
pooled = poolerG(output.hidden_states[-3:],mini_tokenized['attention_mask'])

In [None]:
pooled.shape

## New parameterization: per-position and per-layer weights

In [None]:
bs = 16
seq_len = 128
num_layers = 3

In [None]:
# One weight vector over the sequence positions per layer (uniform in [-0.5, 0.5)),
# plus one scalar weight per layer (uniform in [0, 1)).
seq_param = [torch.nn.Parameter(torch.rand(seq_len)-0.5) for j in range(num_layers)]
layer_param = torch.nn.Parameter(torch.rand(num_layers))

In [None]:
seq_param[0].shape, layer_param.shape

In [None]:
output.hidden_states[0].shape

In [None]:
seq_param[0].expand([bs,1,seq_len]).shape

In [None]:
output.hidden_states[0].shape

In [None]:
# Pool each hidden state with its position weights (expanded to one copy per batch element),
# then stack the per-layer results and swap the layer and batch dimensions.
# NOTE(review): this indexes output.hidden_states[0..num_layers-1], i.e. the FIRST entries,
# whereas other cells pool output.hidden_states[-3:] — confirm which layers are intended.
out=[torch.bmm(seq_param[i].expand([bs,1,seq_len]),output.hidden_states[i]).squeeze() for i in range(num_layers)]
out = torch.stack(out).transpose(0,1)

In [None]:
out.shape

In [None]:
torch.bmm(layer_param.expand(bs,1,num_layers),out).squeeze().shape

In [34]:
class GeneralParameterizedPooler(torch.nn.Module):
    """Learned pooling over sequence positions and over layers.

    Every layer has a weight vector of length ``seq_len`` that reduces its
    ``(batch, seq_len, hidden)`` embeddings to ``(batch, hidden)``; a second weight vector of
    length ``num_layers`` then combines the per-layer results into a single sentence embedding.

    Args:
        num_layers: number of embedding tensors passed to ``forward``.
        seq_len: sequence length of the embeddings.
        hidden_dim: value reported by ``get_sentence_embedding_dimension``.
        batch_size: stored for backward compatibility only; ``forward`` takes the batch size
            from its input, so batches of any size are accepted.
        already_masked: if True the embeddings are used as given; if False they are first
            multiplied by ``attention_mask``.
    """

    def __init__(self,num_layers,seq_len,hidden_dim,batch_size,already_masked=True):#args,config):
        super().__init__()

        self.num_layers = num_layers
        self.batch_size = batch_size
        self.hidden_dim = hidden_dim
        self.already_masked = already_masked
        self.seq_len = seq_len

        self.seq_weight_parameter = torch.nn.ParameterList([torch.nn.Parameter(torch.rand(seq_len)-0.5) for j in range(self.num_layers)]) # random numbers between [-0.5,0.5]
        self.layer_weight_parameter = torch.nn.Parameter(torch.rand(num_layers)-0.5) # random numbers between [-0.5,0.5]

    def forward(self,token_embeddings,attention_mask):
        """Return ``{'sentence_embedding': tensor}`` with a ``(batch, hidden)`` tensor.

        Args:
            token_embeddings: sequence of ``num_layers`` tensors, each ``(batch, seq_len, hidden)``.
            attention_mask: ``(batch, seq_len)`` mask; only read when ``already_masked`` is False.
        """
        assert len(token_embeddings)==self.num_layers, 'Number of embeddings must be the same as number of layers'
        output = []

        for seq_weight, hidden_state in zip(self.seq_weight_parameter, token_embeddings):
            if self.already_masked:
                hidden_state_masked = hidden_state
            else:
                hidden_state_masked = hidden_state * attention_mask.unsqueeze(2).expand_as(hidden_state)
            # Expand to the batch size actually received rather than the one given at
            # construction time, so e.g. a smaller final batch does not break bmm.
            weight_parameter = seq_weight.expand(hidden_state_masked.size(0),1,self.seq_len)
            out = torch.bmm(weight_parameter,hidden_state_masked) # batched dot product -> (batch, 1, hidden)
            # squeeze only the singleton product dimension so a batch of size 1 keeps its batch axis
            output.append(out.squeeze(1))

        layer_sentence_embeddings = torch.stack(output).transpose(0,1) # swap batch and layer_dimension

        layer_weights = self.layer_weight_parameter.expand(layer_sentence_embeddings.size(0),1,self.num_layers)
        model_sentence_embedding = torch.bmm(layer_weights,layer_sentence_embeddings).squeeze(1)

        return {'sentence_embedding' : model_sentence_embedding}

    def get_sentence_embedding_dimension(self):
        """Return the ``hidden_dim`` given at construction time."""
        return self.hidden_dim

In [36]:
poolerG = GeneralParameterizedPooler(3,128,768,16)

In [None]:
for param in poolerG.parameters():
    print(type(param.data), param.size())

In [None]:
pooled = poolerG(output.hidden_states[-3:],mini_tokenized['attention_mask'])

In [None]:
# The redefined pooler returns a dict, so the tensor has to be looked up by key first.
pooled['sentence_embedding'].shape

In [None]:
(torch.rand(100)-0.5).mean()

In [None]:
# The redefined pooler stores its per-position weights as `seq_weight_parameter`.
poolerG.seq_weight_parameter[1].mean()

In [None]:
conv1 = torch.nn.Conv1d(in_channels=1,out_channels=512,kernel_size=1)

In [None]:
conv1

In [None]:
# `pooled` is the dict returned by the pooler; add a channel dimension to the embedding tensor.
conv1(pooled['sentence_embedding'].unsqueeze(1)).shape

## CONV1D layer

In [None]:
# Concatenate the last three hidden states along dimension 1 (the sequence dimension),
# so the three layers sit end to end in one tensor.
stacked = torch.cat(output.hidden_states[-3:],1)
stacked.shape

In [None]:
conv1 = torch.nn.Conv1d(in_channels=384,out_channels=512,kernel_size=1)

In [None]:
conv1(stacked).shape

In [24]:
class Conv1DLayers(torch.nn.Module):
    """Two kernel-size-1 ``Conv1d`` layers applied to masked, concatenated hidden states.

    The last ``num_trainable_layers`` hidden states are masked, concatenated along dimension 1
    and fed through ``layer1`` -> ReLU -> ``layer2``. The concatenated dimension
    (``num_trainable_layers * seq_len``) acts as the input channels of the first convolution.

    Args:
        num_trainable_layers: how many trailing hidden states to use.
        seq_len: sequence length of each hidden state.
        channels_list: output channels of the first and second convolution
            (``channels_list[0]`` and ``channels_list[1]``).
    """

    def __init__(self,num_trainable_layers,seq_len,channels_list):#args,config):
        super().__init__()

        self.num_trainable_layers = num_trainable_layers
        self.seq_len = seq_len
        self.out_channels = channels_list

        first_width = channels_list[0]
        second_width = channels_list[1]

        self.layer1 = torch.nn.Conv1d(
            in_channels=num_trainable_layers * seq_len,
            out_channels=first_width,
            kernel_size=1,
        )
        self.layer2 = torch.nn.Conv1d(
            in_channels=first_width,
            out_channels=second_width,
            kernel_size=1,
        )

    def forward(self,bert_output,attention_mask):
        """Return the output of the second convolution.

        Args:
            bert_output: indexable object whose element 0 gives the shape the mask is expanded
                to, and which exposes ``hidden_states``.
            attention_mask: mask that is unsqueezed on dimension 2 and multiplied into each
                selected hidden state.
        """
        expanded_mask = attention_mask.unsqueeze(2).expand_as(bert_output[0])
        trailing_states = bert_output.hidden_states[-self.num_trainable_layers:]

        # mask every selected layer, then lay them end to end along dimension 1
        conv_input = torch.cat([state * expanded_mask for state in trailing_states], dim=1)

        activated = torch.relu(self.layer1(conv_input))
        return self.layer2(activated)

In [25]:
conv = Conv1DLayers(3,128,[512,1024])

In [26]:
conv_out = conv(output,mini_tokenized['attention_mask'])

In [29]:
conv_out.shape

torch.Size([16, 1024, 768])

In [37]:
poolerG = GeneralParameterizedPooler(1,1024,768,16)

In [38]:
pooled = poolerG([conv_out],mini_tokenized['attention_mask'])

In [40]:
pooled['sentence_embedding'].shape

torch.Size([16, 768])