In [1]:
from opennre import encoder, model, framework
import opennre
import os
import json
import numpy as np
from tqdm import tqdm_notebook as tqdm
import torch
from torch import nn
import math

In [2]:
# Every data file below is resolved relative to this directory.
root_path = '..'

# Token -> id vocabulary shipped with the 50-d GloVe files.
# The context manager closes the handle as soon as the JSON is parsed.
with open(os.path.join(root_path, 'pretrain/glove/glove.6B.50d_word2id.json')) as word2id_file:
    word2id = json.load(word2id_file)

# Embedding matrix stored next to the vocabulary file.
word2vec = np.load(os.path.join(root_path, 'pretrain/glove/glove.6B.50d_mat.npy'))

# Relation name -> id mapping of the nyt10-aug benchmark (resolved through
# root_path like the other inputs instead of a hard-coded '../').
with open(os.path.join(root_path, 'benchmark/nyt10-aug/nyt10_rel2id.json')) as rel2id_file:
    rel2id = json.load(rel2id_file)

In [3]:
# Sentence encoder built from the vocabulary and embedding matrix loaded above.
sentence_encoder = opennre.encoder.PCNNEncoder(
    token2id=word2id,
    max_length=100,
    word_size=50,  # matches the 50-d GloVe file loaded into `word2vec`
    position_size=5,
    hidden_size=230,
    blank_padding=True,
    kernel_size=3,
    padding_size=1,
    word2vec=word2vec,
    dropout=0.5
)
# Bag-level loader over the augmented training split.
# NOTE(review): the two positional arguments `160` and `True` are presumably the
# batch size and the shuffle flag -- confirm against the BagRELoader signature.
train_loader = framework.BagRELoader(
                '../benchmark/nyt10-aug/lt_train_augmented.txt',
                rel2id,
                sentence_encoder.tokenize,
                160,
                True,
                bag_size=0,
                entpair_as_bag=False)
# Bag encoder wrapping the sentence encoder; it is given the number of relations
# and the relation -> id mapping.
bag_encoder = opennre.model.IntraBagAttention(sentence_encoder, len(rel2id), rel2id)

2021-03-08 15:48:33,414 - root - INFO - Initializing word embedding with word2vec.


In [4]:
def conv_bn_relu(in_c, out_c):
    """Build a 3x3 convolution -> batch-norm -> in-place ReLU stack.

    Args:
        in_c: number of input channels of the convolution.
        out_c: number of output channels (also the batch-norm feature count).
    Returns:
        An ``nn.Sequential`` holding the three layers in that order.  The
        convolution pads by 1 and has no bias.
    """
    convolution = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1, bias=False)
    normalisation = nn.BatchNorm2d(out_c)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(convolution, normalisation, activation)
def calc_init_centroid(bag_reps, num_sbags_width, num_sbags_height):
    """Compute initial superbag centroids and the initial label map.

    The input map is average-pooled down to a
    ``num_sbags_height x num_sbags_width`` grid to obtain one centroid per
    superbag.  A grid of superbag ids of the same size is then upsampled
    (nearest neighbour) back to the input's spatial size, which gives every
    position of the input its initial superbag id.

    Args:
        bag_reps: tensor whose last two dimensions are (height, width).
            NOTE(review): the final reshapes hard-code a batch size and
            channel count of 1 -- confirm callers pass (1, 1, height, width).
        num_sbags_width: number of superbags along the width.
        num_sbags_height: number of superbags along the height.
    Returns:
        (centroids, init_label_map): centroids reshaped to (1, 1, -1) and the
        label map reshaped to (1, height * width).
    """
    # Spatial size of the input; the label grid is upsampled back to it.
    # (Previously `height` / `width` were never defined in this function.)
    height, width = bag_reps.shape[-2:]

    centroids = nn.functional.adaptive_avg_pool2d(bag_reps, (num_sbags_height, num_sbags_width))
    with torch.no_grad():
        num_sbags = num_sbags_width * num_sbags_height
        # One id per cell of the pooled grid, cast to the centroids' tensor type.
        labels = torch.arange(num_sbags).reshape(1, 1, *centroids.shape[-2:]).type_as(centroids)
        init_label_map = nn.functional.interpolate(labels, size=(height, width), mode="nearest")

    init_label_map = init_label_map.reshape(1, -1)
    centroids = centroids.reshape(1, 1, -1)

    return centroids, init_label_map

@torch.no_grad()
def get_abs_indices(init_label_map, num_sbags_width):
    """Build (batch, superbag, position) index triples for nine neighbour offsets.

    For every position of ``init_label_map`` the nine offsets
    ``{-1, 0, +1}`` shifted by ``{-width, 0, +width}`` are added to its
    initial superbag id.

    Args:
        init_label_map: (batch, n_pixel) tensor of initial superbag ids.
        num_sbags_width: number of superbags along the width.
    Returns:
        Long tensor of shape (3, batch * 9 * n_pixel) whose rows are the batch
        indices, the offset superbag indices and the position indices.
    """
    batch_size, n_pixel = init_label_map.shape
    dev = init_label_map.device

    step = torch.arange(-1, 2.0, device=dev)
    neighbour_offsets = torch.cat((step - num_sbags_width, step, step + num_sbags_width), dim=0)

    batch_ids = torch.arange(batch_size, device=dev).view(-1, 1, 1).repeat(1, 9, n_pixel)
    sbag_ids = init_label_map.unsqueeze(1) + neighbour_offsets.view(1, -1, 1)
    pixel_ids = torch.arange(n_pixel, device=dev).view(1, 1, -1).repeat(batch_size, 9, 1)

    rows = [ids.reshape(-1).long() for ids in (batch_ids, sbag_ids, pixel_ids)]
    return torch.stack(rows, dim=0)

@torch.no_grad()
def get_hard_abs_labels(affinity_matrix, init_label_map, num_sbags_width):
    """Turn soft neighbour affinities into hard superbag labels.

    The neighbour with the highest affinity along dimension 1 is selected for
    every position, and its offset is added to the position's initial label.

    Args:
        affinity_matrix: tensor whose dimension 1 indexes the nine neighbour
            offsets ``{-1, 0, +1}`` shifted by ``{-width, 0, +width}``.
        init_label_map: initial superbag ids; must be broadcastable against
            the selected offsets.
        num_sbags_width: number of superbags along the width.
    Returns:
        Long tensor holding the hard superbag label of every position.
    """
    dev = affinity_matrix.device
    step = torch.arange(-1, 2.0, device=dev)
    neighbour_offsets = torch.cat((step - num_sbags_width, step, step + num_sbags_width), dim=0)

    best_neighbour = torch.max(affinity_matrix, dim=1).indices
    hard_label = init_label_map + neighbour_offsets[best_neighbour]
    return hard_label.long()

In [7]:

class BagCluster(nn.Module):
    """
    Superbag clustering module for bag-level relation extraction (work in progress).

    Holds a bag encoder, a classification layer and three convolutional
    feature scales.  ``forward`` currently only prepares the initial superbag
    centroids / label map and returns the flattened bag features; the
    iterative refinement is not implemented yet, and the convolutional scales,
    ``fc``, ``softmax`` and ``drop`` are not used by ``forward``.
    """

    def __init__(self, bag_encoder, num_class, rel2id, cluster_num, num_iter):
        """
        Args:
            bag_encoder: encoder for bag; must expose
                ``sentence_encoder.hidden_size``
            num_class: number of classes
            rel2id: dictionary of relation name -> id mapping
            cluster_num: number of cluster (superbags)
            num_iter: number of iteration (stored, not used yet)
        """
        super().__init__()
        self.bag_encoder = bag_encoder
        self.num_class = num_class
        self.cluster_num = cluster_num
        self.num_iter = num_iter
        self.feature_dim = 128
        self.fc = nn.Linear(self.bag_encoder.sentence_encoder.hidden_size, num_class)
        self.softmax = nn.Softmax(-1)
        self.rel2id = rel2id
        self.drop = nn.Dropout()
        # Inverse mapping id -> relation name (built without shadowing builtin `id`).
        self.id2rel = {rel_id: rel for rel, rel_id in rel2id.items()}

        # Three feature scales; the 2nd and 3rd each start with a stride-2 max-pool.
        self.scale1 = nn.Sequential(
            conv_bn_relu(1, 64),
            conv_bn_relu(64, 64)
        )
        self.scale2 = nn.Sequential(
            nn.MaxPool2d(3, 2, padding=1),
            conv_bn_relu(64, 64),
            conv_bn_relu(64, 64)
        )
        self.scale3 = nn.Sequential(
            nn.MaxPool2d(3, 2, padding=1),
            conv_bn_relu(64, 64),
            conv_bn_relu(64, 64)
        )

        # NOTE(review): 64*3 presumably is the concatenation of the three scales;
        # the meaning of the extra 5 channels is not visible here -- confirm.
        self.output_conv = nn.Sequential(
            nn.Conv2d(64*3+5, self.feature_dim-5, 3, padding=1),
            nn.ReLU(True)
        )

    def forward(self, labels, bag_reps, train=True):
        """
        Args:
            labels: bag labels (not used yet)
            bag_reps: (B, 3H) bag representations of a batch, or an already
                image-shaped (N, C, height, width) tensor
            train: training flag (not used yet)
        Return:
            bag features flattened to (N, C, height * width); a 2-D input is
            treated as one single-channel map, i.e. the result is (1, 1, B * 3H).
        """
        # The superbag count is the `cluster_num` given to the constructor
        # (`num_sbags` used to be an undefined name here).
        num_sbags = self.cluster_num

        # Pooling / interpolation in calc_init_centroid need an image-shaped
        # tensor, so a plain (B, 3H) matrix becomes a (1, 1, B, 3H) map.
        if bag_reps.dim() == 2:
            bag_reps = bag_reps[None, None]

        height, width = bag_reps.shape[-2:]
        # Split the superbags over both axes following the aspect ratio;
        # keep at least one per axis so the pooled grid is never empty.
        num_sbags_width = max(1, int(math.sqrt(num_sbags * width / height)))
        num_sbags_height = max(1, int(math.sqrt(num_sbags * height / width)))

        # TODO: the three values below are preparation for the iterative
        # refinement, which is not implemented yet; they are currently unused.
        spixel_feature, init_label_map = calc_init_centroid(bag_reps, num_sbags_width, num_sbags_height)
        abs_indices = get_abs_indices(init_label_map, num_sbags_width)

        # Flatten the spatial dimensions (shape entries, not tensor rows, are
        # unpacked; `bag_feature` used to be read before it was assigned).
        bag_feature = bag_reps.reshape(*bag_reps.shape[:2], -1)
        permuted_bag_feature = bag_feature.permute(0, 2, 1).contiguous()

        return bag_feature
            


In [4]:
import random
def convert_index(bag_num, label, sbag_num, is_random=True):
    """
    Initialise the index mapping between bags and superbags.

    Args:
        bag_num: number of bags
        label: label of every bag (only read when ``is_random`` is False)
        sbag_num: number of superbags
        is_random: if True, the first ``sbag_num`` bags seed one superbag each
            and the remaining bags are assigned at random; otherwise bags are
            grouped by label, in order of first appearance.
    Returns:
        cor_map: bag -> superbag list; ``cor_map[i] == j`` means bag ``i``
            belongs to superbag ``j``
        sbag_map: superbag -> bags; a list of ``sbag_num`` lists of bag indices
    """
    if is_random:
        cor_map = [
            bag_idx if bag_idx < sbag_num else math.floor(random.random() * sbag_num)
            for bag_idx in range(bag_num)
        ]
    else:
        cor_map = [0] * bag_num
        # A list (not a dict) so that labels only need to support `==`.
        seen_labels = []
        for bag_idx, bag_label in enumerate(label):
            if bag_label not in seen_labels:
                seen_labels.append(bag_label)
            cor_map[bag_idx] = seen_labels.index(bag_label)

    sbag_map = []
    for _ in range(sbag_num):
        sbag_map.append([])
    for bag_idx, owner in enumerate(cor_map):
        sbag_map[owner].append(bag_idx)
    return cor_map, sbag_map

In [32]:
def SbagFeature(bag_feat, label, num_sbag):
    """
    Compute superbag features as the mean of their member bag features.

    Args:
        bag_feat: bag features, (B, 3H)
        label: labels of the bags; forwarded to ``convert_index`` but unused
            there because the assignment is always initialised randomly
            (``is_random=True``)
        num_sbag: number of superbags
    Returns:
        (sbag_feat, cor_map, sbag_map): the (num_sbag, 3H) superbag features,
        the bag -> superbag list and the superbag -> bags list produced by
        ``convert_index``.
    """
    bag_num, hidden_size = bag_feat.shape
    cor_map, sbag_map = convert_index(bag_num, label, num_sbag, is_random=True)
    ave_feat = []
    for member_ids in sbag_map:
        if member_ids:
            # Indexing errors now propagate instead of being printed and
            # swallowed, which used to corrupt the average silently.
            ave_feat.append(bag_feat[member_ids].sum(0) / len(member_ids))
        else:
            # A superbag without members (e.g. fewer bags than superbags) used
            # to crash torch.stack; it now gets an all-zero feature.
            ave_feat.append(bag_feat.new_zeros(hidden_size))
    return torch.stack(ave_feat), cor_map, sbag_map

In [33]:
def Passoc(bag_feat, sbag_feat, sb2b_index, scale_value=1):
    '''
    Element-wise squared difference between every bag and every superbag.

    Note that the result is NOT summed over the feature dimension.

    :param bag_feat: (B, 3H) bag features
    :param sbag_feat: (D, 3H) superbag features, D being the number of superbags
    :param sb2b_index: superbag -> bags list; only its length (D) is used
    :param scale_value: factor multiplied onto the squared differences
    :return: (B, D, 3H) tensor of scaled squared differences
    '''
    n_bags, dim = bag_feat.shape
    n_sbags = len(sb2b_index)

    # Tile both operands to a common (B, D, 3H) layout.
    bags = bag_feat[:, None, :].repeat(1, n_sbags, 1)
    sbags = sbag_feat.reshape(1, -1, dim).repeat(n_bags, 1, 1)

    squared_diff = (sbags - bags).pow(2.0)
    return squared_diff * scale_value

In [34]:
def compute_assignments(sbag_feature, bag_rep, sb2b_index):
    """Soft-assign every bag to the superbags.

    Args:
        sbag_feature: (D, 3H) superbag features.
        bag_rep: (B, 3H) bag features.
        sb2b_index: superbag -> bags list (its length gives D).
    Returns:
        (B, D, 3H) association weights that sum to 1 over the superbag
        dimension; per feature, the closest superbag gets the largest weight.
    """
    # scale_value=-1 turns the squared differences into *negative* distances.
    # The default (+1) made the softmax below favour the farthest superbag.
    pixel_spixel_neg_dist = Passoc(bag_rep, sbag_feature, sb2b_index, scale_value=-1.0)
    # Numerically stable softmax over the superbag dimension.
    pixel_spixel_assoc = torch.softmax(pixel_spixel_neg_dist, dim=1)

    return pixel_spixel_assoc

In [35]:
def SpixelFeature2(bag_feature, weight, num_spixels):
    '''
    Weighted average of the bag features per superbag.

    :param bag_feature: (B, 3H) bag features
    :param weight: (B, D, 3H) association weights between bags and superbags
    :param num_spixels: unused
    :return: (D, 3H) superbag features; entries whose total weight is not
             above 0.001 are set to zero
    '''
    n_bags, dim = bag_feature.shape

    # Broadcast every bag against its D weights: (B, 1, 3H) * (B, D, 3H).
    weighted = bag_feature[:, None, :] * weight

    # Accumulate over the bags.
    total_weight = weight.sum(0)
    weighted_sum = weighted.sum(0)

    averaged = weighted_sum / (total_weight + 1e-5)
    keep = (total_weight > 0.001).float()
    return averaged * keep

In [36]:
def exec_iter(sbag_feature, bag_rep, sb2b_index):
    """Run one clustering iteration.

    First recomputes the bag-superbag associations from the current superbag
    features, then re-estimates the superbag features from them.

    Returns:
        (sbag_feat, pixel_assoc): the updated superbag features and the
        associations they were computed from.
    """
    association = compute_assignments(sbag_feature, bag_rep, sb2b_index)
    n_sbags = len(sb2b_index)
    updated_feature = SpixelFeature2(bag_rep, association, n_sbags)
    return updated_feature, association

In [37]:
# Decoding step, used for computing the compactness loss.
def decode_features(pixel_spixel_assoc, spixel_feat):
    """
    Sum the superbag features over the superbag dimension and normalise each row.

    :param pixel_spixel_assoc: B*D*3H association of each bag with each sbag;
        only its rank is checked
    :param spixel_feat: B*D*3H sbag feature
    :return: B*3H tensor whose rows sum to 1
    :raises ValueError: if ``pixel_spixel_assoc`` is not 3-dimensional, or if
        the summed features contain a negative (or NaN) value
    """
    if pixel_spixel_assoc.dim() != 3:
        raise ValueError(
            'pixel_spixel_assoc must be 3-dimensional, got %d dimensions'
            % pixel_spixel_assoc.dim())

    recon_feat = spixel_feat.sum(1) + 1e-10  # B*3H

    # The row normalisation below only makes sense for non-negative values.
    # `not (min >= 0)` also catches NaN.  This used to drop into pdb.
    if not bool(recon_feat.min() >= 0.):
        raise ValueError('fails: summed superbag features must be non-negative')

    recon_feat = recon_feat / recon_feat.sum(1, keepdim=True)

    return recon_feat

In [38]:
def convert_label(label, num=50):
    """One-hot encode labels against their sorted unique values.

    Args:
        label: 1-D sequence of labels.
        num: number of one-hot slots.  If there are more unique labels than
            slots, the shape of the unique-label array is printed and the
            surplus labels are left unencoded.
    Returns:
        (label2, problabel): ``problabel`` is a float32 array of shape
        (1, num, len(label)) with a 1 where a label equals the slot's unique
        value; ``label2`` is the squeezed argmax of ``problabel`` over the
        slot axis.
    """
    label = torch.tensor(label)
    n_items = label.shape[0]
    problabel = np.zeros((1, num, n_items), dtype=np.float32)

    for slot, value in enumerate(np.unique(label).tolist()):
        if slot >= num:
            print(np.unique(label).shape)
            break
        problabel[:, slot, :] = (label == value)

    label2 = np.squeeze(np.argmax(problabel, axis=1))
    return label2, problabel

In [39]:
def compute_final_bag_rep(bag_rep, sbag_feat, b2sb_index):
    """Average every bag feature with the feature of the superbag it belongs to.

    Args:
        bag_rep: bag features, indexed by bag.
        sbag_feat: superbag features, indexed by superbag.
        b2sb_index: bag -> superbag list; entry ``i`` is the superbag of bag ``i``.
    Returns:
        Stacked tensor with one averaged feature per entry of ``b2sb_index``.
    """
    blended = [
        (bag_rep[bag_idx] + sbag_feat[sbag_idx]) / 2.0
        for bag_idx, sbag_idx in enumerate(b2sb_index)
    ]
    return torch.stack(blended)

In [11]:
tokenizer = sentence_encoder.tokenizer
# Dump the tokenizer's inverse vocabulary as JSON.  The context manager
# guarantees the file is flushed and closed; it used to be left open.
with open('./id2voc.txt', 'w+') as file:
    file.write(json.dumps(tokenizer.inv_vocab))


8491501

In [60]:
# Inspect the `bag_scope` entry at index 1225 of the training dataset
# (1225 is the id that `name2id` returns for the place_of_birth bag queried in a later cell).
# NOTE(review): the entries are presumably sentence indices belonging to that bag -- confirm.
train_loader.dataset.bag_scope[1225]

[23814,
 23815,
 23816,
 23817,
 23818,
 23819,
 23820,
 69692,
 69693,
 69694,
 69695,
 69696,
 69697,
 69698]

In [62]:
# Look up the dataset id of one bag, keyed by a tuple of two entity ids and a relation name.
train_loader.dataset.name2id[('m.02b3v0', 'm.013hxv', '/people/person/place_of_birth')]

1225

In [63]:
# bag_cluster = BagCluster(bag_encoder, len(rel2id), rel2id, 100, 100)
num_iter = 50
train = True
# Debugging cell: decode and print one sentence of the first batch, then stop
# (see the `break`).  The loop variable is named `batch_idx` so the builtin
# `iter` is no longer shadowed at notebook scope.  The commented-out lines are
# the clustering pipeline that is currently disabled.
for batch_idx, data in enumerate(train_loader):
#     if torch.cuda.is_available():
#         for i in range(len(data)):
#             try:
#                 data[i] = data[i].cuda()
#             except:
#                 pass
    label = data[0]
    bag_name = data[1]
    scope = data[2]
    tokens = data[3]
    args = data[3:]
    # Name of the first bag, and the tokens decoded from row [0][2] of the token-id tensor.
    print(bag_name[0], sentence_encoder.tokenizer.convert_ids_to_tokens(tokens[0][2].numpy().tolist()))
#     sentence_encoder.tokenizer.convert_ids_to_tokens(tokens[0][0])
#     bag_rep = bag_encoder(label, scope, *args, bag_size=0)
#     sbag_feature, b2sb_index, sb2b_index = SbagFeature(bag_rep, label, 20)
    break
#     for i in range(num_iter):
#         sbag_feature, _ = exec_iter(sbag_feature, bag_rep, sb2b_index)
        
#     final_bag_assoc = compute_assignments(sbag_feature, bag_rep, sb2b_index)
#     if train:
#         new_sbag_feat = SpixelFeature2(bag_rep, final_bag_assoc, len(sb2b_index))
#         print(new_sbag_feat.shape)
# #         new_spix_indices = compute_final_spixel_labels(final_bag_assoc, p2sp_index)
# #         recon_feat2 = Semar(new_spixel_feat, new_spix_indices)
# #         problabel = convert_label(label)
# #         print(problabel[0], problabel[0].shape)
# #         spixel_label = SpixelFeature2(problabel, final_bag_assoc, len(sb2b_index))
#         final_bag_feat = compute_final_bag_rep(bag_rep, new_sbag_feat, b2sb_index)
    
#         print(final_bag_feat, final_bag_feat.shape)
    

('m.053x3n', 'm.0fnb4', '/people/person/place_of_birth') ['these', 'include', "''", 'the', 'best', 'poems', 'of', 'shamsur', 'rahman', ',', "''", 'published', 'last', 'year', 'in', 'new', 'delhi', ';', 'and', "''", 'the', 'devotee', ',', 'the', 'combatant', ':', 'choose', 'poems', 'of', 'shamsur', 'rahman', ',', "''", 'published', 'in', '2000', 'in', 'dhaka', '.', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']


In [59]:
def convert_index(bag_num, labels, sbag_num, is_random=True):
    """
    Initialise the index mapping between bags and superbags.

    NOTE: this definition shadows an earlier ``convert_index`` in this notebook.

    Args:
        bag_num: number of bags
        labels: label of every bag (only read when ``is_random`` is False);
            items exposing ``.item()`` (e.g. 0-d tensors) are converted to
            plain Python values first
        sbag_num: number of superbags
        is_random: if True, the first ``sbag_num`` bags seed one superbag each
            and the remaining bags are assigned at random.  Otherwise the
            superbags are seeded from the labels: one superbag per distinct
            label (in order of first appearance), then further superbags from
            the first still-unassigned bags until ``sbag_num`` is reached, and
            finally every remaining bag joins a randomly chosen superbag that
            carries its label.
    Returns:
        cor_map: bag -> superbag list; ``cor_map[i] == j`` means bag ``i``
            belongs to superbag ``j``
        sbag_map: superbag -> bags; a list of ``sbag_num`` lists of bag indices
            (a superbag may be empty when there are fewer bags than superbags)
    Raises:
        ValueError: in label mode, if ``labels`` does not have ``bag_num``
            items or holds more distinct labels than ``sbag_num``.
    """
    cor_map = [-1] * bag_num
    if is_random:
        for bag_idx in range(bag_num):
            if bag_idx < sbag_num:
                cor_map[bag_idx] = bag_idx
            else:
                cor_map[bag_idx] = math.floor(random.random() * sbag_num)
    else:
        # Plain Python values hash and compare by value (0-d tensors do not).
        labels = [l.item() if hasattr(l, 'item') else l for l in labels]
        if len(labels) != bag_num:
            raise ValueError(
                'expected %d labels, got %d' % (bag_num, len(labels)))

        # First pin down one superbag per distinct label.  A dict keeps the
        # first-occurrence order, so the result is deterministic (unlike set()).
        first_bag_of_label = {}
        for bag_idx, bag_label in enumerate(labels):
            first_bag_of_label.setdefault(bag_label, bag_idx)
        superbag_labels = list(first_bag_of_label)
        if len(superbag_labels) > sbag_num:
            raise ValueError(
                '%d distinct labels do not fit into %d superbags'
                % (len(superbag_labels), sbag_num))
        # The seed bag of each label gets the *superbag* index (the previous
        # code stored the bag's own index here).
        for sbag_idx, bag_idx in enumerate(first_bag_of_label.values()):
            cor_map[bag_idx] = sbag_idx

        # Fewer superbags than requested: open additional superbags that reuse
        # the label of the first bag that is still unassigned.
        for _ in range(sbag_num - len(superbag_labels)):
            if -1 not in cor_map:
                # No bag left to seed from; the remaining superbags stay empty.
                break
            empty_idx = cor_map.index(-1)
            cor_map[empty_idx] = len(superbag_labels)
            superbag_labels.append(labels[empty_idx])

        # Every remaining bag joins a random superbag that carries its label.
        for bag_idx, assigned in enumerate(cor_map):
            if assigned < 0:
                bag_label = labels[bag_idx]
                candidates = [sbag_idx for sbag_idx, sbag_label in enumerate(superbag_labels)
                              if sbag_label == bag_label]
                cor_map[bag_idx] = random.choice(candidates)

    sbag_map = [[] for _ in range(sbag_num)]
    for bag_idx, sbag_index in enumerate(cor_map):
        sbag_map[sbag_index].append(bag_idx)
    return cor_map, sbag_map

In [34]:
# Two small float matrices used by the nearest-row experiment in the next cells.
a = torch.tensor([[0., 1., 3.],
                  [3., 3., 3.]])
b = torch.tensor([[1., 4., 3.],
                  [2., 3., 4.]])

In [43]:
# For every row of `a`, record the index of the row of `b` with the smallest
# squared Euclidean distance; the distances are printed along the way.
assignments = []
for i in a:
    distances = (i - b).pow(2).sum(dim=1)
    print(distances)
    assignments.append(distances.argmin().item())

tensor([10.,  9.])
tensor([5., 2.])


In [44]:
# Show the nearest-row index collected for every row of `a`.
print(assignments)

[1, 1]


In [6]:
# NOTE(review): `d` is not defined in any visible cell, and `a` as defined above is a
# float tensor; this cell (execution count 6) appears to rely on earlier kernel state
# -- confirm or remove.
d[a].sum(0)

tensor([4., 4., 4.])

In [62]:
# NOTE(review): `d` is not defined in any visible cell -- confirm where it comes from.
d.long()

tensor([0, 1, 2, 3, 4, 5, 6, 0, 3, 5])

In [68]:
# Number of entries in the bag encoder's id -> relation mapping.
len(bag_encoder.id2rel)

58