In [None]:
import transformers
import torch
import datasets
import pdb

In [None]:
# Load the published configuration for 'bert-base-uncased'.
# `from_pretrained` is a classmethod, so it is called on the config class
# looked up in CONFIG_MAPPING rather than on a default-constructed instance
# that would be thrown away immediately.
config = transformers.CONFIG_MAPPING['bert'].from_pretrained('bert-base-uncased')

In [None]:
# Load the pretrained encoder without its pooling head.
# `from_pretrained` is a classmethod that builds a brand-new model, so options
# such as `config` and `add_pooling_layer` have to be passed to it directly.
# Constructing `BertModel(config, add_pooling_layer=False)` first and then
# calling `.from_pretrained(...)` on that instance discards both the instance
# and the options given to its constructor.
bert = transformers.BertModel.from_pretrained(
    'bert-base-uncased',
    config=config,
    add_pooling_layer=False,
)

In [None]:
corpus = datasets.load_dataset('bookcorpus',split='train')

In [None]:
tokenizer = transformers.BertTokenizerFast.from_pretrained('bert-base-uncased')

In [None]:
mini_tokenized = corpus.select(range(16)).map(lambda e: tokenizer(e['text'],truncation=True,padding='max_length',max_length=128),remove_columns=['text'])

In [None]:
mini_tokenized.set_format('torch')

In [None]:
mini_tokenized

In [None]:
# Run the mini batch through BERT, asking for every layer's hidden states and
# attention maps in addition to the default outputs.
encoder_inputs = {
    'input_ids': mini_tokenized['input_ids'],
    'attention_mask': mini_tokenized['attention_mask'],
}
output = bert(
    **encoder_inputs,
    output_hidden_states=True,
    output_attentions=True,
)

In [None]:
# Shape of the first entry in `hidden_states`
# (presumably the embedding-layer output for HF BERT — TODO confirm).
output.hidden_states[0].shape

In [None]:
# How many attention tensors the model returned.
len(output.attentions)

In [None]:
# Shape of the first attention tensor.
output.attentions[0].shape

In [None]:
# Names of the fields available on the model output.
output.keys()

In [None]:
# NOTE(review): this displays the last three hidden-state tensors in full,
# which produces a very large cell output; the shapes alone may be enough.
output.hidden_states[-3:]

In [None]:
# torch.cat joins along dim 0 by default, so the last three hidden states are
# concatenated along their first dimension (no new dimension is created).
torch.cat(output.hidden_states[-3:]).shape

In [None]:
# torch.stack, in contrast, inserts a NEW leading dimension indexing the
# individual hidden-state tensors.
out_stacked = torch.stack(output.hidden_states)

In [None]:
out_stacked.shape

In [None]:
# Swap dims 0<->1 and dims 2<->3 using two successive transposes.
out_stacked.transpose(0,1).transpose(2,3).shape

In [None]:
# The same reordering expressed as a single permute call.
out_stacked.permute(1, 0, 3, 2).shape

In [None]:
out_stacked.reshape(16,-1,128).shape

In [None]:
filt1 = torch.nn.Conv1d(in_channels=9984, out_channels=12, kernel_size=1)

In [None]:
filtered = filt1(out_stacked.reshape(16,-1,128))

In [None]:
filtered.shape

In [None]:
attn_mask = mini_tokenized['attention_mask'].unsqueeze(1).expand_as(filtered)

In [None]:
attn_mask

In [None]:
(filtered*attn_mask).shape

In [None]:
# Stand-alone shape check for a kernel-size-1 Conv1d: the channel dimension
# (100 -> 10) changes while the batch (32) and length (20) dimensions do not.
a = torch.randn(32, 100, 20)
m = torch.nn.Conv1d(in_channels=100, out_channels=10, kernel_size=1)
out = m(a)
print(out.size(), m, sep='\n')

In [None]:
class ConvLayers(torch.nn.Module):
    """Two kernel-size-1 Conv1d layers applied to concatenated BERT hidden states.

    The last ``bert_layers`` hidden states are placed side by side along the
    channel axis (``bert_layers * hidden_dim`` channels per token), passed
    through ``layer1`` and then ``layer2``, and finally multiplied by the
    attention mask.
    """

    def __init__(self,bert_layers,hidden_dim,channels_list):
        """Build the two convolutions.

        Args:
            bert_layers: how many trailing hidden states ``forward`` will use.
            hidden_dim: width of a single hidden state vector.
            channels_list: output channel counts; entry 0 is used for
                ``layer1`` and entry 1 for ``layer2`` (only these two entries
                are read).
        """
        super().__init__()

        self.bert_layers = bert_layers
        self.hidden_dim = hidden_dim
        self.out_channels = channels_list

        # layer1 consumes every selected layer's features for a token at once.
        self.layer1 = torch.nn.Conv1d(bert_layers * hidden_dim, channels_list[0], kernel_size=1)
        # layer2 continues from layer1's output width.
        self.layer2 = torch.nn.Conv1d(channels_list[0], channels_list[1], kernel_size=1)

    def forward(self,bert_output,attention_mask):
        """Apply both convolutions and mask the result.

        Args:
            bert_output: model output exposing ``hidden_states``; its element
                ``[0]`` supplies the batch size and sequence length.
            attention_mask: mask that is given a channel axis and expanded to
                the shape of the convolution output.

        Returns:
            The output of ``layer2`` multiplied element-wise by the expanded
            attention mask.
        """
        reference_shape = bert_output[0].shape
        n_examples, n_tokens = reference_shape[0], reference_shape[1]

        # Stack directly on dim 1 -> (batch, layers, seq_len, hidden), then
        # swap the last two axes -> (batch, layers, hidden, seq_len).
        selected_states = bert_output.hidden_states[-self.bert_layers:]
        per_example = torch.stack(selected_states, dim=1).transpose(2, 3)

        # Fold the layer and hidden axes into one channel axis.
        conv_in = per_example.reshape(n_examples, -1, n_tokens)

        features = self.layer2(self.layer1(conv_in))

        mask = attention_mask.unsqueeze(1).expand_as(features)
        return features * mask

In [None]:
class GeneralConvLayers(torch.nn.Module):
    """A chain of kernel-size-1 Conv1d layers over concatenated BERT hidden states.

    The last ``bert_layers`` hidden states are concatenated along the channel
    axis (``bert_layers * hidden_dim`` channels per token) and passed through
    one convolution per entry of ``channels_list``. The result is multiplied
    by the attention mask.

    The convolutions are registered as ``layer1``, ``layer2``, ... so a
    two-entry ``channels_list`` yields the same submodule names (and therefore
    the same state-dict keys) as a hand-written two-layer version.
    """

    def __init__(self,bert_layers,hidden_dim,channels_list):
        """Build one convolution per entry in ``channels_list``.

        Args:
            bert_layers: how many trailing hidden states ``forward`` will use.
            hidden_dim: width of a single hidden state vector.
            channels_list: non-empty sequence of output channel counts, one
                per convolution, applied in order.

        Raises:
            ValueError: if ``channels_list`` is empty.
        """
        super().__init__()

        if len(channels_list) == 0:
            raise ValueError("channels_list must contain at least one output channel count")

        self.hidden_dim = hidden_dim
        self.out_channels = channels_list
        self.bert_layers = bert_layers

        in_channels = bert_layers * hidden_dim
        for index, out_channels in enumerate(channels_list, start=1):
            # Assigning a Module as an attribute registers it as a submodule.
            setattr(
                self,
                f"layer{index}",
                torch.nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1),
            )
            # The next convolution continues from this one's output width.
            in_channels = out_channels

    def forward(self,bert_output,attention_mask):
        """Run every convolution in order and mask the result.

        Args:
            bert_output: model output exposing ``hidden_states``; its element
                ``[0]`` supplies the batch size and sequence length.
            attention_mask: mask that is given a channel axis and expanded to
                the shape of the final convolution output.

        Returns:
            The output of the last convolution multiplied element-wise by the
            expanded attention mask.
        """
        bs = bert_output[0].shape[0]
        seq_len = bert_output[0].shape[1]

        stacked = torch.stack(bert_output.hidden_states[-self.bert_layers:])

        # (layers, batch, seq_len, hidden) -> (batch, layers, hidden, seq_len)
        permuted = stacked.permute(1, 0, 3, 2)

        # Merge the layer and hidden axes into a single channel axis.
        features = permuted.reshape(bs, -1, seq_len)

        for index in range(1, len(self.out_channels) + 1):
            features = getattr(self, f"layer{index}")(features)

        attn_mask = attention_mask.unsqueeze(1).expand_as(features)

        return features * attn_mask

In [None]:
layer = ConvLayers(13,768,[512,256])

In [None]:
layer(output,mini_tokenized['attention_mask']).shape