In [1]:
import torch
from transformers import Trainer, TrainingArguments
from transformers import AlbertConfig, AlbertTokenizer, AlbertForPreTraining, ConvbertForPreTraining, ConvbertModel
from dataset import SOPDataset, MyTrainer, collate_batch
import os

# Directory holding the training output; get_last_checkpoint() picks the newest
# checkpoint-<step> entry inside it. Set the CONVBERT_MODEL_DIR environment
# variable to point elsewhere instead of editing this machine-local default.
# Alternative directory previously used here: 'D:/ConvbertData/albert_model_dir'.
model_dir = os.environ.get('CONVBERT_MODEL_DIR', 'E:/ConvbertData/convbert_12/output')

def get_last_checkpoint(dir_name):
    """Return the path of the highest-numbered ``checkpoint-<step>`` entry in *dir_name*.

    Only entries whose name is exactly ``checkpoint-`` followed by decimal
    digits are considered; anything else in the directory (logs, ``runs``,
    ``checkpoint_best``, archives such as ``checkpoint-12.zip``) is ignored
    instead of crashing the integer parse.

    Args:
        dir_name: directory to scan (non-recursive).

    Returns:
        ``os.path.join(dir_name, name)`` for the entry with the largest step.

    Raises:
        FileNotFoundError: if *dir_name* contains no ``checkpoint-<step>`` entry.
    """
    best_step = -1
    best_name = None
    for filename in os.listdir(dir_name):
        prefix, sep, step_text = filename.partition('-')
        # isdecimal() (not isdigit()) guarantees int() below cannot raise.
        if prefix != 'checkpoint' or not sep or not step_text.isdecimal():
            continue
        step = int(step_text)
        if step > best_step:
            best_step = step
            best_name = filename
    if best_name is None:
        raise FileNotFoundError(
            "no 'checkpoint-<step>' entry found in %r" % (dir_name,))
    return os.path.join(dir_name, best_name)

tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')

# Run on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the newest checkpoint found under model_dir and move it onto `device`.
model = ConvbertModel.from_pretrained(get_last_checkpoint(model_dir)).to(device)

# Dataset over the cached text data; not referenced by the cells shown here.
train_dataset = SOPDataset(
    directory='E:/ConvbertData/text_data/cache',
    tokenizer=tokenizer,
    batch_size=1,
    mlm_probability=0.15,
)

  from ._conv import register_converters as _register_converters


In [2]:
# Run one short sentence through the model and keep its per-layer attention.
inputs = tokenizer.encode("The cat sat on the mat", return_tensors='pt')
with torch.no_grad():  # inference only: don't build or retain an autograd graph
    outputs = model(inputs.to(device), output_attentions=True)
print(len(outputs))          # 3 in the recorded run
print(outputs[0].shape)      # torch.Size([1, 8, 768]) in the recorded run
print(outputs[1].shape)      # torch.Size([1, 768]) in the recorded run
print(len(outputs[2]))       # one attention tensor per layer (12 recorded)
attention = outputs[-1]  # Output includes attention weights when output_attentions=True
tokens = tokenizer.convert_ids_to_tokens(inputs[0])

3
torch.Size([1, 8, 768])
torch.Size([1, 768])
12


In [3]:
# Expand each layer's per-token convolution windows into a dense
# (query x key) matrix per head so bertviz can display them.
# NOTE(review): assumes each attention[layer] has seq_len * num_heads rows of
# kernel_size weights, ordered token-major with batch size 1 (this is what the
# .view below relies on) — confirm against the model implementation.
num_heads = 12
kernel_size = 63
padding_l = kernel_size // 2  # window slots to the left of the centre token

layers = []
with torch.no_grad():
    for layer_att in attention:
        seq_len = layer_att.shape[0] // num_heads
        windows = layer_att.view(seq_len, num_heads, kernel_size)
        dense = windows.new_zeros(num_heads, seq_len, seq_len)
        for word in range(seq_len):
            start = word - padding_l               # key position of window slot 0
            lo = max(0, start)                     # clip the window to [0, seq_len)
            hi = min(seq_len, start + kernel_size)
            # Keys outside the window stay 0; this works for any seq_len,
            # including sequences longer than the window.
            dense[:, word, lo:hi] = windows[word, :, lo - start:hi - start]
        layers.append(dense.unsqueeze(0))  # (1, heads, query, key)

In [4]:
# NOTE(review): print_attentions below also relies on this import; consider
# moving it into the first (imports) cell so Restart & Run All stays safe.
from bertviz import head_view
# Interactive per-head view of the dense matrices in `layers`, labelled with `tokens`.
head_view(layers, tokens)

<IPython.core.display.Javascript object>

In [12]:
def _expand_conv_attention(layer_att, num_heads=12, kernel_size=63):
    """Scatter one layer's convolution windows into a dense attention matrix.

    Args:
        layer_att: tensor with ``seq_len * num_heads`` rows of ``kernel_size``
            weights. NOTE(review): assumes batch size 1 and token-major
            (seq_len, num_heads) row order, which the ``view`` below relies
            on — confirm against the model implementation.
        num_heads: attention heads per layer.
        kernel_size: width of the convolution window.

    Returns:
        Tensor of shape ``(1, num_heads, seq_len, seq_len)``. Entry
        ``[0, h, i, j]`` holds window slot ``j - i + kernel_size // 2`` of token
        ``i`` for head ``h``. It is 0 when key ``j`` falls outside that window.
    """
    seq_len = layer_att.shape[0] // num_heads
    padding_l = kernel_size // 2
    windows = layer_att.view(seq_len, num_heads, kernel_size)
    dense = windows.new_zeros(num_heads, seq_len, seq_len)
    for word in range(seq_len):
        start = word - padding_l               # key position of window slot 0
        lo = max(0, start)                     # clip the window to [0, seq_len)
        hi = min(seq_len, start + kernel_size)
        dense[:, word, lo:hi] = windows[word, :, lo - start:hi - start]
    return dense.unsqueeze(0)


def print_attentions(text, num_heads=12, kernel_size=63):
    """Show a bertviz head view of the model's convolution attention for *text*.

    Uses the notebook-level ``tokenizer``, ``model``, ``device`` and
    ``head_view``. Every layer returned by the model is expanded into a dense
    (query x key) matrix. Unlike the previous padding-based version, this
    also works for inputs longer than the convolution window.

    Args:
        text: raw input string.
        num_heads: heads per layer (default matches the original hard-coded 12).
        kernel_size: convolution window width (default matches the original 63).
    """
    inputs = tokenizer.encode(text, return_tensors='pt')
    tokens = tokenizer.convert_ids_to_tokens(inputs[0])
    with torch.no_grad():  # inference only
        outputs = model(inputs.to(device), output_attentions=True)
        attention = outputs[-1]
        layers = [_expand_conv_attention(layer_att, num_heads, kernel_size)
                  for layer_att in attention]
    head_view(layers, tokens)

In [13]:
# Short two-word example.
print_attentions("hi mark")

<IPython.core.display.Javascript object>

In [14]:
# Longer single-sentence example.
print_attentions("New German laws allow regulators to prohibit any anti-competitive behaviour at an earlier stage")

<IPython.core.display.Javascript object>