# BERT's Attention and Dependency Syntax

This notebook contains code for comparing BERT's attention to dependency syntax annotations (see Sections 4.2 and 5 of [What Does BERT Look At? An Analysis of BERT's Attention](https://arxiv.org/abs/1906.04341)).

In [1]:
import collections
import pickle
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt

### Loading the data

Download the data used in this notebook from [here](https://drive.google.com/open?id=1DEIBQIl0Q0az5ZuLoy4_lYabIfLSKBg-). However, note that since Penn Treebank annotations are not public, this is dummy data where all labels are ROOT. See the README for extracting attention maps on your own data.

In [2]:
def load_pickle(fname, encoding="ASCII"):
  """Load and return the object stored in the pickle file `fname`.

  Args:
    fname: path of the pickle file to read.
    encoding: codec used to decode 8-bit string instances found in the pickle.
      The default ("ASCII") is the default of `pickle.load` itself, so existing
      callers are unaffected. Pass "latin1" when reading the downloaded data
      under Python 3.

  Returns:
    The unpickled object.

  NOTE(review): unpickling can execute arbitrary code, so only load files
  that come from a trusted source.
  """
  with open(fname, "rb") as f:
    return pickle.load(f, encoding=encoding)
    
# Load the pickled examples from a hard-coded local path. Later cells index
# dev_data by position and read each entry's "words" and "attns" keys. The
# trailing comment preserves the path that was here before.
dev_data = load_pickle("E:/data/caption2_attn.pkl")#("./data/depparse/dev_attn.pkl")


In [3]:
# Module-level word positions of the first and second matched keyword; the
# processing loops further down overwrite them for each example.
p1=0
p2=0


# Results file (hard-coded local path), opened in append mode ('a+'). The
# handle `f` is a global: plot_attn1/plot_attn2 write to it directly, and it is
# closed by the f.close() call that follows the first processing loop.
file='E:/data/targetscene2.txt'
f=open(file,'a+')


In [4]:
def plot_attn2(examples, layer, head, place1,place2,count,hide_sep=False,
               threshold=0.1, out=None):
  """Record the tokens whose attention to two target words exceeds a threshold.

  Despite its name, this function draws nothing. For every example it reads one
  attention head, collects each token `m` for which attn[m][n] is greater than
  `threshold` for either target column `n`, prints the target word for every
  hit, and finally prints and writes one line: `count` followed by the
  collected tokens, comma-separated and newline-terminated.

  Args:
    examples: iterable of dicts holding a "words" list and an "attns"
      structure indexed as attns[layer][head][m][n].
    layer: layer index into example["attns"].
    head: head index into example["attns"][layer].
    place1: index into example["words"] of the first target word.
    place2: index into example["words"] of the second target word.
    count: string label placed at the start of the output line.
    hide_sep: if True, zero the first and last attention columns (the [CLS]
      and [SEP] positions) and renormalise each row before thresholding.
    threshold: attention weight a token must exceed to be recorded.
    out: writable text stream; when None, the module-level file handle `f`
      is used.
  """
  for example in examples:
    attn = example["attns"][layer][head]

    if hide_sep:
      attn = np.array(attn)  # copy, so the stored attention map is untouched
      attn[:, 0] = 0
      attn[:, -1] = 0
      row_sums = attn.sum(axis=-1, keepdims=True)
      # Rows that put all their weight on [CLS]/[SEP] now sum to zero; leave
      # them as zeros instead of dividing 0/0 into NaN.
      np.divide(attn, row_sums, out=attn, where=row_sums > 0)

    words = ["[CLS]"] + example["words"] + ["[SEP]"]
    # +1 shifts word indices past the prepended [CLS] token.
    target_columns = (place1 + 1, place2 + 1)

    targetobjects = []
    for m in range(len(words)):
      for n in target_columns:
        if attn[m][n] > threshold:
          targetobjects.append(words[m])
          print(words[n])

    objects = ",".join([count] + targetobjects) + "\n"
    print(objects)
    # Resolve the global handle only at write time, and only when no explicit
    # stream was supplied.
    (f if out is None else out).write(objects)

In [5]:

def plot_attn1(examples, layer, head, place1,count,hide_sep=False,
               threshold=0.1, out=None):
  """Record the tokens whose attention to one target word exceeds a threshold.

  Despite its name, this function draws nothing. For every example it reads one
  attention head, collects each token `m` for which attn[m][place1 + 1] is
  greater than `threshold`, prints the target word for every hit, and finally
  prints and writes one line: `count` followed by the collected tokens,
  comma-separated and newline-terminated.

  Args:
    examples: iterable of dicts holding a "words" list and an "attns"
      structure indexed as attns[layer][head][m][n].
    layer: layer index into example["attns"].
    head: head index into example["attns"][layer].
    place1: index into example["words"] of the target word.
    count: string label placed at the start of the output line.
    hide_sep: if True, zero the first and last attention columns (the [CLS]
      and [SEP] positions) and renormalise each row before thresholding.
    threshold: attention weight a token must exceed to be recorded.
    out: writable text stream; when None, the module-level file handle `f`
      is used.
  """
  for example in examples:
    attn = example["attns"][layer][head]

    if hide_sep:
      attn = np.array(attn)  # copy, so the stored attention map is untouched
      attn[:, 0] = 0
      attn[:, -1] = 0
      row_sums = attn.sum(axis=-1, keepdims=True)
      # Rows that put all their weight on [CLS]/[SEP] now sum to zero; leave
      # them as zeros instead of dividing 0/0 into NaN.
      np.divide(attn, row_sums, out=attn, where=row_sums > 0)

    words = ["[CLS]"] + example["words"] + ["[SEP]"]
    # +1 shifts the word index past the prepended [CLS] token.
    n = place1 + 1

    targetobjects = []
    for m in range(len(words)):
      if attn[m][n] > threshold:
        targetobjects.append(words[m])
        print(words[n])

    objects = ",".join([count] + targetobjects) + "\n"
    print(objects)
    # Resolve the global handle only at write time, and only when no explicit
    # stream was supplied.
    (f if out is None else out).write(objects)
    

    
# Keywords probed with attention head (layer 8, head 5).
PREPOSITIONS = ['in','on','with','around','down','at']
# Upper bound on how many examples this pass processes.
PREPOSITION_MAX_EXAMPLES = 100

try:
    # Never index past the end of dev_data when fewer examples were loaded.
    for i in range(min(PREPOSITION_MAX_EXAMPLES, len(dev_data))):
        objects=dev_data[i]["words"]

        # Positions of every keyword in this caption. Collecting them up front
        # avoids using 0 as a "nothing found yet" sentinel, which mis-handled a
        # keyword at index 0: the next keyword overwrote p1 and p2 kept a stale
        # value from an earlier caption.
        positions = [j for j, word in enumerate(objects) if word in PREPOSITIONS]
        count = len(positions)
        if count >= 1:
            p1 = positions[0]
        if count >= 2:
            p2 = positions[-1]

        # Two-digit, zero-padded example label ("00" .. "99").
        im = str(i).zfill(2)

        if count==1 :
            print(p1)
            plot_attn1([dev_data[i]], 8, 5,p1,im)
        elif count==2:
            print(p1,p2)
            plot_attn2([dev_data[i]], 8, 5,p1,p2,im)
        else:
            # NOTE(review): captions with zero or more than two keywords are
            # only printed; no line is written to the results file for them.
            # Confirm this is intended.
            im=im+'\n'
            print(im)
finally:
    # Close (and flush) the results file even if an example raised.
    f.close()

#in ['person','people','boy','boys','girl','girls','man','men','woman','women','player']
#['in','on','with','around','down','at']8,5

4 7
on
on
on
00,on,a,chair

3
on
01,next

3
on
02,next

2 5
with
03,person

5
down
04,sidewalk

5
05

2 7
06

3
07

2 6
08

2 6
09

3 7
10

5
11

3
12

13

3
in
in
14,a,.

15

4
16

17

3
18

19

6
20

3
21

3 6
at
at
22,at,table

5
23

3
24

25

3
26

27

28

5
at
at
29,at,table

3
30

4
31

32

2
33

3
34

3 6
with
35,[CLS]

5
36

5
37

5
in
in
38,in,room

2
39

6
40

5
41

42

3 6
at
at
43,at,table

3
44

5
in
in
45,in,room

3 6
in
46,a

2
47

2 5
in
in
48,[CLS],.

3 6
49

3
in
in
in
in
in
in
50,in,a,next,to,a,.

5
around
51,table

3
in
52,front

5
53

3 6
at
at
54,at,table

2
55

4
56

2 5
with
57,person

6
58

5
59

5
60

61

3
in
in
62,a,.

5
at
63,desk

3
64

2 7
on
65,in

3 6
on
in
in
66,chair,in,a

2
67

68

3
69

70

3
71

6
72

3
at
at
73,at,table

74

5
75

3
with
76,building

3 10
77

5
78

6
79

3
80

5
81

4
82

5
83

84

5
85

5
86

87

6
88

5
89

3 6
at
at
90,at,desk

2 6
in
in
91,in,room

7
92

3 5
with
93,.

5
94

3 6
95

3 6
96

5
97

5
98

4
99



In [None]:
# Words containing "ing" that must not be treated as matches.
ING_EXCLUDED = ['painting','living','building']
# Upper bound on how many examples this pass processes.
ING_MAX_EXAMPLES = 100

# The preceding cell closes the shared results file, so writing through it here
# would raise "I/O operation on closed file". Reopen it in append mode.
# NOTE(review): this appends to the same file as the preposition pass; point
# `file` somewhere else first if the two result sets should be kept apart.
if f.closed:
    f = open(file, 'a+')

try:
    # Never index past the end of dev_data when fewer examples were loaded.
    for i in range(min(ING_MAX_EXAMPLES, len(dev_data))):
        objects=dev_data[i]["words"]

        # Positions of every word containing "ing" that is not excluded. The
        # exclusion list now applies to every match, not just the first one, so
        # an excluded word can no longer be counted as a match and paired with
        # a stale p1. Collecting positions up front also avoids using 0 as a
        # "nothing found yet" sentinel.
        positions = [
            j for j, word in enumerate(objects)
            if 'ing' in word and word not in ING_EXCLUDED
        ]
        count = len(positions)
        if count >= 1:
            p1 = positions[0]
        if count >= 2:
            p2 = positions[-1]

        # Two-digit, zero-padded example label ("00" .. "99").
        im = str(i).zfill(2)

        if count==1 :
            print(p1)
            plot_attn1([dev_data[i]], 4, 3,p1,im)
        elif count==2:
            print(p1,p2)
            plot_attn2([dev_data[i]], 4, 3,p1,p2,im)
        else:
            # NOTE(review): captions with zero or more than two matches are
            # only printed; no line is written to the results file for them.
            # Confirm this is intended.
            im=im+'\n'
            print(im)
finally:
    # Close (and flush) the results file even if an example raised.
    f.close()
        
        
        