In [None]:
import os
import numpy as np
from transformers.models.roberta.tokenization_roberta import RobertaTokenizer
from transformers.optimization import AdamW
from transformers.models.roberta.modeling_roberta import RobertaModel

import torch
import torch.nn as nn
from torch.nn import functional as F
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score

### Load Data

In [None]:
# Load the pre-training corpus. Each non-empty line is read as
# "<utterance>\t<label>"; underscores in the label are replaced by spaces so
# the label can be used as a natural-language description.
Cbase_pretrain_path = '../Dataset/HWU/Pretrain.txt'
Cbase_x = []
Cbase_y = []
# `with` guarantees the handle is closed even if a line is malformed.
with open(Cbase_pretrain_path, encoding='utf-8') as pretrain_file:
    for line in pretrain_file:
        # Strip only the line terminator (the original chopped the last
        # character unconditionally, corrupting a final line without '\n').
        line = line.rstrip('\r\n')
        if not line.strip():
            continue  # tolerate blank lines instead of raising IndexError
        fields = line.split('\t')
        Cbase_x.append(fields[0])
        Cbase_y.append(fields[1].replace('_', ' '))

# Sort the distinct labels so the label -> index mapping is identical on every
# kernel run (set iteration order for strings varies between processes).
label_list = sorted(set(Cbase_y))
# Maps label description -> its index in `label_list`, stored as a string.
label_yd_yi_dic = {label: str(idx) for idx, label in enumerate(label_list)}

# Per-example label index (string form), aligned with Cbase_x / Cbase_y.
Cbase_y_label = [label_yd_yi_dic[label] for label in Cbase_y]

In [None]:
# Maximum token length handed to the tokenizer (longer inputs are truncated).
max_length = 32
# Mini-batch size used by the training DataLoader.
batch_size = 10
# Number of training iterations (one batch each) run by `training`.
N = 3000
# Base learning rate passed to the optimizer.
base_lr = 5e-6
# Multiplier applied to `base_lr` for the classifier-head parameter group.
lr_scale = 100

In [None]:
###### 
from torch.utils.data import Dataset
import torch.utils.data as util_data

class AugmentPairSamples(Dataset):
    """Dataset of (text, label index, label description) triples.

    Args:
        train_x: sequence of input texts.
        train_y: sequence of label identifiers, aligned with ``train_x``.
        train_y_des: sequence of label descriptions, aligned with ``train_x``.

    Raises:
        AssertionError: if the three sequences do not have the same length.
    """

    def __init__(self, train_x, train_y, train_y_des):
        assert len(train_y) == len(train_x)
        # Also validate the descriptions up front; otherwise a mismatch only
        # shows up later as an IndexError inside __getitem__.
        assert len(train_y_des) == len(train_x)
        self.train_x = train_x
        self.train_y = train_y
        self.train_y_des = train_y_des

    def __len__(self):
        """Return the number of examples."""
        return len(self.train_y)

    def __getitem__(self, idx):
        """Return the example at ``idx`` as a dict with keys
        ``'text'``, ``'label'`` and ``'label_descrip'``."""
        return {
            'text': self.train_x[idx],
            'label': self.train_y[idx],
            'label_descrip': self.train_y_des[idx],
        }

# Wrap the loaded corpus in a Dataset and expose it through a shuffling loader.
train_dataset = AugmentPairSamples(
    train_x=Cbase_x,
    train_y=Cbase_y_label,
    train_y_des=Cbase_y,
)
train_loader = util_data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
)

### Load Pretrained Model

In [None]:
import os
# Restrict this process to the first four GPUs.
os.environ['CUDA_VISIBLE_DEVICES']="0,1,2,3"
# Prefer GPU ordinal 3, but clamp to the last visible GPU so the notebook still
# runs on machines exposing fewer than four devices ("cuda:3" would otherwise
# be an invalid ordinal). Falls back to CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = "cuda:" + str(min(3, max(torch.cuda.device_count() - 1, 0)))
else:
    device = "cpu"

In [None]:
# Feature width used to size the classification head's linear layers.
bert_hidden_dim = 1024
# Model identifier passed to `from_pretrained` for both the encoder and the tokenizer.
pretrain_model_dir = 'roberta-large'

class RobertaClassificationHead(nn.Module):
    """Two-layer classification head: dropout -> linear -> tanh -> dropout -> linear.

    Args:
        bert_hidden_dim: width of the incoming feature vectors (also the
            width of the hidden layer).
        num_labels: number of output scores per input row.
    """

    def __init__(self, bert_hidden_dim, num_labels):
        super().__init__()
        width = bert_hidden_dim
        # Attribute names are kept as-is: they are the state_dict keys.
        self.dense = nn.Linear(width, width)
        self.dropout = nn.Dropout(p=0.1)
        self.out_proj = nn.Linear(width, num_labels)

    def forward(self, features):
        """Map ``features`` to ``num_labels`` unnormalised scores per row."""
        hidden = self.dense(self.dropout(features))
        activated = self.dropout(torch.tanh(hidden))
        return self.out_proj(activated)
    
class RobertaForSequenceClassification(nn.Module):
    """RoBERTa encoder followed by a `RobertaClassificationHead`.

    Args:
        tagset_size: number of output classes produced by the head.
    """

    def __init__(self, tagset_size):
        super().__init__()
        self.tagset_size = tagset_size

        # Attribute names are kept as-is: they prefix the state_dict keys.
        self.roberta_single = RobertaModel.from_pretrained(pretrain_model_dir)
        self.single_hidden2tag = RobertaClassificationHead(bert_hidden_dim, tagset_size)

    def forward(self, input_ids, input_mask):
        """Return the head's scores for the encoded inputs.

        The encoder is called positionally with ``(input_ids, input_mask, None)``
        and the second element of its output is fed to the head
        (presumably the pooled sequence representation; confirm against the
        encoder's return type).
        """
        encoder_outputs = self.roberta_single(input_ids, input_mask, None)
        pooled = encoder_outputs[1]
        return self.single_hidden2tag(pooled)

# Build the 3-class sequence classifier and its tokenizer, then initialise the
# classifier from the checkpoint on disk.
pre_model = RobertaForSequenceClassification(3)
pre_tokenizer = RobertaTokenizer.from_pretrained(pretrain_model_dir, do_lower_case=True)
# Load onto CPU first: without map_location, torch.load tries to restore the
# tensors onto the device they were saved from, which fails if that GPU is
# not present here. The model is moved to `device` afterwards.
mnli_state_dict = torch.load('../MNLI_pretrained.pt', map_location='cpu')
# strict=False tolerates key mismatches; report them instead of discarding the
# result so a silently un-initialised encoder does not go unnoticed.
load_report = pre_model.load_state_dict(mnli_state_dict, strict=False)
if load_report.missing_keys or load_report.unexpected_keys:
    print('load_state_dict(strict=False): missing keys =', list(load_report.missing_keys),
          '| unexpected keys =', list(load_report.unexpected_keys))
pre_model.to(device)

In [None]:
class PREModel(nn.Module):
    """Scores every encoded row of a sample with a small MLP and trains it
    against the sample's gold label index.

    The encoder is taken from ``pre_model.roberta_single``; each sample is a
    dict of tokenizer features whose rows are mean-pooled and mapped to one
    scalar score each. (Presumably one row per candidate label; confirm
    against the code that builds ``inputs``.)

    Args:
        pre_tokenizer: tokenizer stored on the instance as ``self.tokenizer``.
        pre_model: object exposing a ``roberta_single`` encoder module.
        device: device the classifier and intermediate tensors live on.
        training_flag: when True, ``forward`` also performs an optimizer step.
    """

    def __init__(self, pre_tokenizer, pre_model, device=device, training_flag=True):

        super(PREModel, self).__init__()
        self.device = device
        self.classifier_loss1 = nn.BCELoss()
        self.classifier_loss2 = nn.CrossEntropyLoss()
        self.training_flag = training_flag
        self.optimizer = None

        ###### SentenceBert Model ######
        self.tokenizer = pre_tokenizer
        self.sentbert = pre_model.roberta_single

        ####### Classifer Model ######
        # NOTE: the attribute name `classifer` (sic) is kept because external
        # code addresses it by that name.
        self.classifer = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(1024, 512),
            nn.Tanh(),
            nn.Dropout(0.1),
            nn.Linear(512, 1)).to(self.device)

    ###### SentenceBert Model ######
    def get_embeddings(self, features, pooling="mean"):
        """Return the attention-mask-weighted mean of the encoder's token outputs.

        Args:
            features: dict of encoder keyword arguments; must contain
                ``'attention_mask'``.
            pooling: accepted for interface compatibility; only mean pooling
                is implemented.
        """
        bert_output = self.sentbert(**features)
        all_output = bert_output[0]
        # Follow the encoder output's device rather than a module-level global.
        attention_mask = features['attention_mask'].unsqueeze(-1).to(all_output.device)
        token_sum = torch.sum(all_output * attention_mask, dim=1)
        # Guard against division by zero for a fully masked row.
        token_count = torch.sum(attention_mask, dim=1).clamp(min=1)
        return token_sum / token_count

    def set_optmizer(self, opt):
        """Attach the optimizer that ``forward`` steps when ``training_flag`` is True."""
        self.optimizer = opt

    def forward(self, inputs, labels, aggregate=True):
        """Score each sample's rows, compute the losses and optionally take a step.

        Args:
            inputs: sequence of feature dicts, one per sample, each accepted
                by ``get_embeddings``.
            labels: sequence of gold row indices (anything ``int()`` accepts),
                one per sample.
            aggregate: unused; kept for interface compatibility.

        Returns:
            Tuple ``(class_loss, bce_loss, ce_loss, labels, cluster_result)``
            where the three losses are detached tensors, ``class_loss`` is
            ``(bce_loss + ce_loss)`` divided by the number of rows per sample,
            and ``cluster_result`` is the list of predicted row indices as
            strings.

        Raises:
            ValueError: if ``labels`` is empty.
            RuntimeError: if ``training_flag`` is True but no optimizer is set.
        """
        if len(labels) == 0:
            raise ValueError("PREModel.forward received an empty batch")
        if self.training_flag and self.optimizer is None:
            raise RuntimeError("training_flag is True but no optimizer is set; call set_optmizer() first")

        score_rows = []
        onehot_rows = []
        gold_indices = []
        for s_idx in range(len(labels)):
            h0 = self.get_embeddings(inputs[s_idx], pooling="mean").to(self.device)
            # squeeze(-1) drops only the score dimension, so a sample with a
            # single row keeps its row axis.
            score_rows.append(self.classifer(h0).squeeze(-1))

            y_idx = int(labels[s_idx])
            y_onehot = torch.zeros(h0.shape[0], device=self.device)
            y_onehot[y_idx] = 1.
            onehot_rows.append(y_onehot)
            gold_indices.append(y_idx)

        # Stack once instead of re-concatenating inside the loop.
        classifier_output = torch.stack(score_rows, dim=0)
        y_onehot_all = torch.stack(onehot_rows, dim=0)
        gold_index_tensor = torch.tensor(gold_indices, dtype=torch.long, device=classifier_output.device)

        classifier_sigmode_vector = torch.sigmoid(classifier_output)
        classifier_softmax_vector = F.softmax(classifier_output, dim=1)
        cluster_result = [str(idx) for idx in torch.argmax(classifier_softmax_vector, dim=1).tolist()]

        class_loss_s1 = self.classifier_loss1(classifier_sigmode_vector, y_onehot_all)
        # CrossEntropyLoss applies log-softmax itself, so it must receive the
        # raw scores; feeding it softmax probabilities applied softmax twice.
        class_loss_s2 = self.classifier_loss2(classifier_output, gold_index_tensor)
        class_loss_all = class_loss_s1 + class_loss_s2
        class_loss = torch.div(class_loss_all, classifier_output.shape[1])

        if self.training_flag:
            loss = class_loss
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        return class_loss.detach(), class_loss_s1.detach(), class_loss_s2.detach(), labels, cluster_result

In [None]:
model = PREModel(pre_tokenizer=pre_tokenizer, pre_model=pre_model, device=device, training_flag=True).to(device)

# `base_lr` and `lr_scale` come from the hyper-parameter cell. They were
# re-declared here, which made the values in that cell dead: editing them had
# no effect on training.
param_groups = [
    # Encoder parameters train at the optimizer's default rate (base_lr).
    {'params': model.sentbert.parameters()},
    # The classifier head trains at a scaled-up rate.
    {'params': model.classifer.parameters(), 'lr': base_lr * lr_scale},
]
optimizer = torch.optim.Adam(param_groups, lr=base_lr)

model.set_optmizer(optimizer)

### Training

In [None]:
def prepare_task_input(batch, is_contrastive=False):
    """Build, for every text in the batch, one encoder input per known label.

    Each row is the concatenation of a tokenized label description and the
    tokenized text with the text's first token dropped, right-padded to a
    common width.

    Args:
        batch: dict with keys ``'text'`` (list of strings) and ``'label'``.
        is_contrastive: the inputs are only built when this is True.

    Returns:
        ``(feats, label)`` when ``is_contrastive`` is True, where ``feats`` is
        a list with one dict per text holding ``'input_ids'`` and
        ``'attention_mask'`` tensors of shape (num_labels, max_len) on
        ``device``, and ``label`` is ``batch['label']`` unchanged.
        ``None`` when ``is_contrastive`` is False.
    """
    if not is_contrastive:
        return None

    text, label = batch['text'], batch['label']
    all_label_descrip = list(label_yd_yi_dic.keys())

    tokenizer = model.tokenizer
    # Use the tokenizer's own pad id; fall back to the previously hard-coded 1.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 1

    features1 = tokenizer.batch_encode_plus(text, max_length=max_length, return_tensors='pt',
                                            padding='longest', truncation=True)
    features2 = tokenizer.batch_encode_plus(all_label_descrip, max_length=max_length, return_tensors='pt',
                                            padding='longest', truncation=True)
    text_ids, text_mask = features1["input_ids"], features1["attention_mask"]
    label_ids, label_mask = features2["input_ids"], features2["attention_mask"]
    max_len = text_ids.shape[1] + label_ids.shape[1]

    # Number of real (non-padding) tokens per sequence, as plain ints.
    text_lens = text_mask.sum(dim=1).tolist()
    label_lens = label_mask.sum(dim=1).tolist()
    num_labels = len(label_lens)

    feats = []
    for i, text_len in enumerate(text_lens):
        # Text tokens without the leading token.
        text_tokens = text_ids[i, 1:text_len]
        pair_ids = torch.full((num_labels, max_len), pad_id, dtype=text_ids.dtype)
        pair_mask = torch.zeros((num_labels, max_len), dtype=torch.long)
        for j, label_len in enumerate(label_lens):
            # One text token was dropped, so the pair holds
            # label_len + text_len - 1 real tokens. The mask previously
            # covered label_len + text_len positions, i.e. it also marked
            # the first padding token as real.
            pair_len = label_len + text_len - 1
            pair_ids[j, :label_len] = label_ids[j, :label_len]
            pair_ids[j, label_len:pair_len] = text_tokens
            pair_mask[j, :pair_len] = 1
        feats.append({'input_ids': pair_ids.to(device), 'attention_mask': pair_mask.to(device)})

    return feats, label
    
# Module-level lists; nothing in the visible notebook appends to or reads them.
# NOTE(review): the names suggest per-step contrastive-loss logs. Confirm they
# are still needed, or remove them.
InsCL_all = []
SupCL_all = []
def training(train_loader):
    """Run ``N`` training iterations, cycling through ``train_loader``.

    Each iteration builds the model inputs with ``prepare_task_input`` and
    calls the model, which performs the optimizer step itself while its
    ``training_flag`` is True. The whole model object is saved to
    ``'../Result/PreModel.pt'`` after every 100th iteration.

    Args:
        train_loader: re-iterable yielding batches accepted by
            ``prepare_task_input``.

    Returns:
        None.
    """
    # Both settings are loop-invariant, so set them once.
    model.train()
    model.training_flag = True

    ckpt = '../Result/PreModel.pt'
    # Create the iterator up front. The original relied on a bare `except:`
    # catching the unbound-variable error on the first pass, which also hid
    # every genuine data-loading failure.
    train_loader_iter = iter(train_loader)
    for i in range(N):
        print('************'+str(i)+'************')
        try:
            batch = next(train_loader_iter)
        except StopIteration:
            # Loader exhausted: start another pass over the data.
            train_loader_iter = iter(train_loader)
            batch = next(train_loader_iter)

        feats, labels = prepare_task_input(batch, is_contrastive=True)
        model(feats, labels, aggregate=True)

        ###### Save Model ######
        if (i+1) % 100 == 0:
            # Make sure the target directory exists before the first save.
            os.makedirs(os.path.dirname(ckpt), exist_ok=True)
            torch.save(model, ckpt)
    return None

In [None]:
training(train_loader)