In [None]:
import torch.nn as nn
import torch
import numpy as np
import torch.nn.functional as F
from transformers import BertModel, BertConfig

class Bert_Encoder(nn.Module):
    """BERT-based encoder producing one vector per input sequence.

    Two encoding patterns are supported (selected by ``config.pattern``):

    * ``'standard'``: the second element of the BERT output is returned
      unchanged (per the original comments, the ``[CLS]`` representation).
    * ``'entity_marker'``: the hidden states at the first ``[E11]`` and the
      first ``[E21]`` marker token of every sequence are concatenated and
      projected to ``output_size`` by ``linear_transform``.
    """

    # Token ids of the [E11] / [E21] entity markers looked up in `forward`.
    # NOTE(review): hard-coded; presumably these are the first and third ids
    # appended after a 30522-token BERT vocabulary — confirm against the
    # tokenizer before using a different checkpoint.
    E11_TOKEN_ID = 30522
    E21_TOKEN_ID = 30524

    def __init__(self, config, out_token=False):
        """Load the pretrained BERT model and build the output projection.

        Args:
            config: object providing ``bert_path``, ``encoder_output_size``,
                ``pattern`` and, for the ``'entity_marker'`` pattern,
                ``vocab_size`` and ``marker_size``.
            out_token: accepted for interface compatibility; not used here.

        Raises:
            ValueError: if ``config.pattern`` is neither ``'standard'`` nor
                ``'entity_marker'``.
        """
        super(Bert_Encoder, self).__init__()

        # load model; it is moved to the GPU when one is available so that
        # construction no longer crashes on CPU-only hosts
        self.encoder = BertModel.from_pretrained(config.bert_path)
        if torch.cuda.is_available():
            self.encoder = self.encoder.cuda()
        self.bert_config = BertConfig.from_pretrained(config.bert_path)

        # the dimension for the final outputs
        self.output_size = config.encoder_output_size
        self.out_dim = self.output_size

        # find which encoding is used
        if config.pattern not in ('standard', 'entity_marker'):
            raise ValueError('Wrong encoding.')
        self.pattern = config.pattern

        hidden_size = self.bert_config.hidden_size
        if self.pattern == 'entity_marker':
            # make room in the embedding table for the added marker tokens
            self.encoder.resize_token_embeddings(config.vocab_size + config.marker_size)
            # two marker states are concatenated, hence the doubled input width
            self.linear_transform = nn.Linear(hidden_size * 2, self.output_size, bias=True)
        else:
            self.linear_transform = nn.Linear(hidden_size, self.output_size, bias=True)

        # Not applied in `forward`; kept so the module's parameter names
        # (and therefore saved checkpoints) stay unchanged.
        self.layer_normalization = nn.LayerNorm([self.output_size])

    def get_output_size(self):
        """Return the dimension of the final outputs."""
        return self.output_size

    def forward(self, inputs):
        """Encode a batch of token-id sequences.

        Args:
            inputs: integer tensor of token ids, [B, N] per the original
                shape comments.

        Returns:
            ``'standard'``: ``self.encoder(inputs)[1]`` unchanged.
            ``'entity_marker'``: the projected concatenation of the two
            marker states, one row per sample.

        Raises:
            IndexError: in ``'entity_marker'`` mode, if a sample lacks the
                ``[E11]`` or the ``[E21]`` marker token.
        """
        if self.pattern == 'standard':
            # in the standard mode, the representation is generated according
            #  to the representation of the [CLS] mark.
            return self.encoder(inputs)[1]

        # in the entity_marker mode, the representation is generated from the
        #  representations of marks [E11] and [E21] of the head and tail entities.
        e11_mask = inputs == self.E11_TOKEN_ID
        e21_mask = inputs == self.E21_TOKEN_ID

        # validate before running BERT so a malformed batch fails fast
        has_both = e11_mask.any(dim=1) & e21_mask.any(dim=1)
        if not bool(has_both.all()):
            missing = (~has_both).nonzero().flatten().tolist()
            raise IndexError(
                'entity marker token(s) missing in batch sample(s) %s' % missing
            )

        # argmax over a 0/1 mask yields the position of the first occurrence
        e11_pos = e11_mask.int().argmax(dim=1)
        e21_pos = e21_mask.int().argmax(dim=1)

        # input the sample to BERT
        tokens_output = self.encoder(inputs)[0]  # [B,N] --> [B,N,H]

        # gather both marker states for every sample in a single indexing op,
        # on whatever device the inputs live on
        rows = torch.arange(inputs.size(0), device=inputs.device)
        head_repr = tokens_output[rows, e11_pos]  # [B,H]
        tail_repr = tokens_output[rows, e21_pos]  # [B,H]
        output = torch.cat([head_repr, tail_repr], dim=1)  # [B,H*2]

        return self.linear_transform(output)

class Encoder(nn.Module):
    """Wrap ``Bert_Encoder`` with a small projection head.

    The head is Linear -> ReLU -> Linear and maps the encoder output
    (width ``encoder.out_dim``) to ``args.feat_dim`` features.
    """

    def __init__(self, args):
        """Build the BERT encoder and the projection head from ``args``."""
        super().__init__()

        self.encoder = Bert_Encoder(args)
        self.output_size = self.encoder.out_dim

        width = self.output_size
        projection_layers = [
            nn.Linear(width, width),
            nn.ReLU(inplace=True),
            nn.Linear(width, args.feat_dim),
        ]
        self.head = nn.Sequential(*projection_layers)

    def bert_forward(self, x):
        """Return ``(encoder_output, projected)`` for the batch ``x``.

        ``projected`` is the head's output L2-normalised along dim 1.
        """
        out = self.encoder(x)
        projected = F.normalize(self.head(out), p=2, dim=1)
        return out, projected


In [None]:
class Config:
    """Plain container of the settings used by the cells in this file."""

    def __init__(self):
        # --- encoder ---
        # Path handed to BertModel/BertConfig.from_pretrained.
        self.bert_path = "datasets/bert-base-uncased"
        # Width of the encoder's final output.
        self.encoder_output_size = 768
        # Encoding strategy: 'standard' or 'entity_marker'.
        self.pattern = "entity_marker"
        # The embedding table is resized to vocab_size + marker_size entries
        # in entity_marker mode.
        self.vocab_size = 30522
        self.marker_size = 4
        # Output width of the projection head.
        self.feat_dim = 64

        # --- runtime ---
        # Device the encoder and the token batches are moved to.
        self.device = 'cuda'
        self.num_workers = 0
        self.seed = 2023

        # --- data / memory ---
        # NOTE(review): the fields below are not read by the model classes in
        # this file; presumably they are consumed by the project's sampler and
        # data loader — confirm there.
        self.num_protos = 20
        self.max_length = 256
        self.dataname = "FewRel"
        self.task_name = "FewRel"
        self.data_path = "sample_data/"
        self.num_of_relation = 80
        
        
# Build the settings object with the defaults defined in Config.
args = Config()

# Instantiate the encoder (BERT + projection head) and move it to args.device.
encoder = Encoder(args=args).to(args.device)

In [None]:
from sklearn.cluster import KMeans
from dataloaders.my_sampler import data_sampler
from dataloaders.data_loader import get_data_loader

# Create the project-local data sampler from the settings and the seed.
sampler = data_sampler(args=args, seed=args.seed)

# Take the first item produced by the sampler.
# NOTE(review): presumably two mappings (sample id -> sample, relation ->
# samples); confirm against dataloaders.my_sampler.
dataset_id2sample, dataset_rel2sample = next(sampler)

# The relations iterated over below; each one is used as a key into
# dataset_rel2sample.
current_relations = sampler.id2rel

def select_data(self, args, encoder, sample_set):
    """Pick representative samples of ``sample_set`` via K-Means.

    Every sample is encoded with ``encoder.bert_forward``; the encoder
    outputs are clustered, and for each cluster the sample closest to the
    cluster centre is selected.

    Args:
        self: unused; kept so existing call sites keep working.
        args: settings object; ``args.device`` is required, and
            ``args.num_protos`` (default 20 when absent) caps the number of
            clusters.
        encoder: model exposing ``eval()`` and ``bert_forward(tokens)``.
            It is left in eval mode.
        sample_set: indexable, sized collection of samples accepted by
            ``get_data_loader``.

    Returns:
        ``(mem_set, current_feat, proto)``: the selected samples, a tensor
        with one selected feature per row, and the mean of those rows.

    Raises:
        ValueError: if ``sample_set`` is empty.
    """
    if len(sample_set) == 0:
        raise ValueError('select_data requires a non-empty sample_set')

    data_loader = get_data_loader(args, sample_set, shuffle=False, drop_last=False, batch_size=1)

    encoder.eval()
    collected = []
    with torch.no_grad():
        for _, tokens, _ in data_loader:
            tokens = torch.stack([x.to(args.device) for x in tokens], dim=0)
            feature, _ = encoder.bert_forward(tokens)
            collected.append(feature.detach().cpu())

    # concatenate in torch, then hand a NumPy array to scikit-learn
    features = torch.cat(collected, dim=0).numpy()

    # never ask for more clusters than there are samples
    num_clusters = min(getattr(args, 'num_protos', 20), len(sample_set))
    # distances[i, k] is the distance of sample i to cluster centre k
    distances = KMeans(n_clusters=num_clusters, random_state=0).fit_transform(features)

    # for every cluster, the index of the sample nearest to its centre
    nearest = distances.argmin(axis=0)

    mem_set = [sample_set[idx] for idx in nearest]
    current_feat = torch.from_numpy(features[nearest])
    return mem_set, current_feat, current_feat.mean(0)

# Per-relation memory built below: the selected samples keyed by relation,
# plus the selected features and their mean (prototype) in iteration order.
memorized_samples = {}
feat_mem = []
proto_mem = []

for relation in current_relations:
    # select_data is a module-level function whose first parameter `self`
    # is unused, so None is passed for it (there is no `self` at this scope).
    memorized_samples[relation], feat, temp_proto = select_data(
        None, args, encoder, dataset_rel2sample[relation]
    )
    feat_mem.append(feat)
    proto_mem.append(temp_proto)
