In [None]:
from bucket_sampler import SortedSampler, BucketBatchSampler
from EMA import *
import random
import numpy as np
import torch

# Experiment configuration shared by the later cells (paths, optimisation,
# bucketing, UDA and runtime settings).
config = dict(
    # data / model locations
    train_file_path='/content/drive/MyDrive/train.json',
    dev_file_path='/content/drive/MyDrive/dev.json',
    test_file_path='/content/drive/MyDrive/test.json',
    output_path='/content/drive/MyDrive/output',
    model_path='/content/drive/MyDrive/BERT_model',
    # optimisation
    batch_size=16,
    num_epoches=1,
    max_seq_len=64,
    learning_rate=2e-5,
    weight_decay=0.01,
    # length-bucketed batching
    use_bucket=True,
    bucket_multiplier=200,
    # UDA: unsup batch = int(batch_size * unsup_data_ratio); target sharpening
    # temperature and confidence threshold for the consistency-loss mask
    unsup_data_ratio=1.5,
    uda_softmax_temp=0.4,
    uda_confidence_threshold=0.8,
    # runtime / logging / EMA
    device='cuda',
    n_gpus=0,
    logging_step=300,
    ema_start_step=500,
    ema_start=False,
    seed=2021,
)

if torch.cuda.is_available():
  # Scale the global batch by the number of visible GPUs; train() wraps the
  # model in nn.DataParallel when more than one GPU is present.
  config['n_gpus'] = torch.cuda.device_count()
  config['batch_size'] *= config['n_gpus']
else:
  config['device'] = 'cpu'

def seed_everything(seed):
  """Seed Python's `random`, NumPy's global RNG and PyTorch (CPU and all CUDA devices) with `seed`."""
  for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seeder(seed)

seed_everything(seed=config['seed'])


In [None]:
from tqdm import tqdm
import json
import pandas as pd

def parse_data(path, data_type='train'):
  """Read a JSON-lines sentence-pair file into a DataFrame.

  Each non-blank line must be a JSON object with 'sentence1' and 'sentence2'
  keys and, unless `data_type == 'test'`, a 'label' convertible with int().

  Args:
    path: path to the UTF-8 encoded JSON-lines file.
    data_type: 'train', 'dev' or 'test'. For 'test' every label is set to 0
      and the 'label' field is not read.

  Returns:
    pandas.DataFrame with columns ['text_a', 'text_b', 'labels'].
  """
  records = []
  with open(path, 'r', encoding='utf8') as f:
    for raw_line in tqdm(f.readlines(), desc=f'Reading {data_type} data'):
      # Blank lines (e.g. a trailing empty line) used to crash json.loads.
      if not raw_line.strip():
        continue
      example = json.loads(raw_line)
      label = int(example['label']) if data_type != 'test' else 0
      records.append((example['sentence1'], example['sentence2'], label))

  return pd.DataFrame(records, columns=['text_a', 'text_b', 'labels'])

In [None]:
from transformers import BertTokenizer
# Load the tokenizer files stored with the BERT checkpoint at config['model_path'].
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=config['model_path'])

In [None]:
def build_bert_inputs(inputs, label, sentence_a, sentence_b, tokenizer):
  """Encode one sentence pair and append its features to `inputs` in place.

  `inputs` must map 'input_ids', 'token_type_ids', 'attention_mask' and
  'labels' to lists (read_data passes a defaultdict(list)). No truncation or
  padding happens here; the collators handle both at batch time.
  """
  encoded = tokenizer.encode_plus(sentence_a, sentence_b, add_special_tokens=True,
                                  return_token_type_ids=True, return_attention_mask=True)
  for feature_name in ('input_ids', 'token_type_ids', 'attention_mask'):
    inputs[feature_name].append(encoded[feature_name])
  inputs['labels'].append(label)



In [None]:
def build_unsup_bert_inputs(inputs, label, sentence_a, sentence_b, tokenizer):
  """Encode a sentence pair in both orders and append the paired features.

  For 'input_ids', 'token_type_ids' and 'attention_mask' the appended value is
  a tuple (encoding of a-then-b, encoding of b-then-a); `label` is appended to
  'labels' as is. `inputs` must map those keys to lists.
  """
  def encode(first, second):
    return tokenizer.encode_plus(first, second, add_special_tokens=True,
                                 return_token_type_ids=True, return_attention_mask=True)

  forward_enc = encode(sentence_a, sentence_b)
  backward_enc = encode(sentence_b, sentence_a)

  for feature_name in ('input_ids', 'token_type_ids', 'attention_mask'):
    inputs[feature_name].append((forward_enc[feature_name], backward_enc[feature_name]))
  inputs['labels'].append(label)


In [None]:
from collections import defaultdict
def read_data(config, tokenizer):
  """Load the train/dev/test files and tokenize them into feature dicts.

  Returns:
    dict with keys 'train', 'dev', 'test' and 'unsup_data'. Each value is a
    defaultdict(list) holding 'input_ids', 'token_type_ids', 'attention_mask'
    and 'labels' lists.

  Notes:
    - Every 'test' example is encoded twice (a-b, then b-a), so the test
      split holds 2 * len(test file) entries, interleaved pairwise
      (presumably for averaging predictions over both orders - confirm with
      the inference code).
    - 'unsup_data' collects every example of all three splits, each stored as
      an (a-b, b-a) encoding pair.
      NOTE(review): dev examples therefore also feed the unsupervised loss.
  """
  split_frames = {
      'train': parse_data(config['train_file_path'], data_type='train'),
      'dev': parse_data(config['dev_file_path'], data_type='dev'),
      'test': parse_data(config['test_file_path'], data_type='test'),
  }
  processed_data = {}
  unsup_data = defaultdict(list)

  for data_type, df in split_frames.items():
    inputs = defaultdict(list)
    # Read columns by name: row[0]/row[2] on iterrows() rows relies on
    # positional Series indexing, deprecated in pandas >= 2.1, and iterrows is
    # much slower than zipping the columns.
    rows = zip(df['text_a'], df['text_b'], df['labels'])

    for sentence_a, sentence_b, file_label in tqdm(rows, desc=f'Preprocessing {data_type} data', total=len(df)):
      label = 0 if data_type == 'test' else file_label
      build_bert_inputs(inputs, label, sentence_a, sentence_b, tokenizer)

      if data_type.startswith('test'):
        # Also encode the swapped pair for the test split.
        build_bert_inputs(inputs, label, sentence_b, sentence_a, tokenizer)

      build_unsup_bert_inputs(unsup_data, label, sentence_a, sentence_b, tokenizer)

    processed_data[data_type] = inputs

  processed_data['unsup_data'] = unsup_data
  return processed_data

In [None]:
data = read_data(config=config, tokenizer=tokenizer)

Reading train data: 100%|██████████| 34334/34334 [00:00<00:00, 214124.73it/s]
Reading dev data: 100%|██████████| 4316/4316 [00:00<00:00, 197353.19it/s]
Reading test data: 100%|██████████| 3861/3861 [00:00<00:00, 198992.49it/s]
Preprocessing train data: 100%|██████████| 34334/34334 [01:23<00:00, 413.23it/s]
Preprocessing dev data: 100%|██████████| 4316/4316 [00:08<00:00, 514.24it/s]
Preprocessing test data: 100%|██████████| 3861/3861 [00:09<00:00, 409.51it/s]


In [None]:
# Peek at the token ids of the first training example (nothing if empty).
first_train_ids = next(iter(data['train']['input_ids']), None)
if first_train_ids is not None:
  print(first_train_ids)

[101, 6010, 6009, 955, 1446, 5023, 7583, 6820, 3621, 1377, 809, 2940, 2768, 1044, 2622, 1400, 3315, 1408, 102, 955, 1446, 3300, 1044, 2622, 1168, 3309, 6820, 3315, 1408, 102]


In [None]:
from torch.utils.data import Dataset
class AFQMCDataset(Dataset):
  """Map-style dataset over pre-tokenized sentence pairs.

  `data_dict` maps 'input_ids', 'token_type_ids', 'attention_mask' and
  'labels' to index-aligned lists (assumed equally long). Item `index` is the
  4-tuple of those entries, in that order.
  """

  def __init__(self, data_dict):
    super().__init__()
    self.data_dict = data_dict

  def __getitem__(self, index):
    field_names = ('input_ids', 'token_type_ids', 'attention_mask', 'labels')
    return tuple(self.data_dict[name][index] for name in field_names)

  def __len__(self):
    return len(self.data_dict['input_ids'])

In [None]:
class Collator():
  """Collate AFQMCDataset items into a dict of padded int64 tensors.

  Every sequence is right-padded with 0 to the longest sequence in the batch,
  capped at `max_seq_len`. A longer sequence is cut to `max_seq_len` tokens,
  with its last kept input id replaced by `tokenizer.sep_token_id`.
  """

  def __init__(self, max_seq_len, tokenizer):
    self.max_seq_len = max_seq_len
    self.tokenizer = tokenizer

  def pad_and_truncate(self, input_ids_list, token_type_ids_list, attention_mask_list, labels_list, max_seq_len):
    """Return (input_ids, token_type_ids, attention_mask, labels) tensors.

    The first three have shape (batch, max_seq_len); labels has shape (batch,).
    """
    n_rows = len(input_ids_list)
    input_ids = torch.zeros((n_rows, max_seq_len), dtype=torch.long)
    token_type_ids = torch.zeros_like(input_ids)
    attention_mask = torch.zeros_like(input_ids)

    for row, (ids, types, mask) in enumerate(zip(input_ids_list, token_type_ids_list, attention_mask_list)):
      length = len(ids)
      if length > max_seq_len:
        # Too long: keep max_seq_len tokens and force a closing [SEP].
        input_ids[row] = torch.tensor(ids[:max_seq_len - 1] + [self.tokenizer.sep_token_id], dtype=torch.long)
        token_type_ids[row] = torch.tensor(types[:max_seq_len])
        attention_mask[row] = torch.tensor(mask[:max_seq_len])
        continue
      input_ids[row, :length] = torch.tensor(ids, dtype=torch.long)
      token_type_ids[row, :length] = torch.tensor(types, dtype=torch.long)
      attention_mask[row, :length] = torch.tensor(mask, dtype=torch.long)

    return input_ids, token_type_ids, attention_mask, torch.tensor(labels_list, dtype=torch.long)

  def __call__(self, examples):
    input_ids_list, token_type_ids_list, attention_mask_list, labels_list = zip(*examples)
    batch_len = min(max(map(len, input_ids_list)), self.max_seq_len)

    input_ids, token_type_ids, attention_mask, labels = self.pad_and_truncate(
        input_ids_list, token_type_ids_list, attention_mask_list, labels_list, batch_len)

    return {
        'input_ids': input_ids,
        'token_type_ids': token_type_ids,
        'attention_mask': attention_mask,
        'labels': labels,
    }


In [None]:
# Standalone collator for interactive use; build_dataloader creates its own instances.
collate_fn = Collator(max_seq_len=config['max_seq_len'], tokenizer=tokenizer)

In [None]:
from torch.utils.data import Dataset
class UnsupAFQMCDataset(Dataset):
  """Dataset over paired encodings (as built by build_unsup_bert_inputs).

  Each 'input_ids' / 'token_type_ids' / 'attention_mask' entry is a 2-tuple of
  views. Item `index` is a flat 7-tuple: the three features of view 0, the
  three features of view 1, then the label.
  """

  def __init__(self, data_dict):
    super().__init__()
    self.data_dict = data_dict

  def __getitem__(self, index):
    ids_pair, types_pair, mask_pair = (self.data_dict[key][index]
                                       for key in ('input_ids', 'token_type_ids', 'attention_mask'))
    label = self.data_dict['labels'][index]
    first_view = (ids_pair[0], types_pair[0], mask_pair[0])
    second_view = (ids_pair[1], types_pair[1], mask_pair[1])
    return first_view + second_view + (label,)

  def __len__(self):
    return len(self.data_dict['input_ids'])

In [None]:
class UnsupCollator():
  """Collate UnsupAFQMCDataset items (two views per example) into tensors.

  Both views are padded/truncated to one common length: the longest sequence
  across *both* views in the batch, capped at `max_seq_len`. The a-b and b-a
  tensors therefore always have the same width. Sequences are right-padded
  with 0. Truncated sequences keep a closing `tokenizer.sep_token_id`.
  """

  def __init__(self, max_seq_len, tokenizer):
    self.max_seq_len = max_seq_len
    self.tokenizer = tokenizer

  def pad_and_truncate(self, input_ids_list, token_type_ids_list, attention_mask_list, labels_list, max_seq_len):
    """Return (input_ids, token_type_ids, attention_mask, labels) int64 tensors.

    Same padding/truncation rule as Collator.pad_and_truncate.
    """
    input_ids = torch.zeros((len(input_ids_list), max_seq_len), dtype=torch.long)
    token_type_ids = torch.zeros_like(input_ids)
    attention_mask = torch.zeros_like(input_ids)

    for i in range(len(input_ids_list)):
      seq_len = len(input_ids_list[i])

      if seq_len <= max_seq_len:
        input_ids[i, :seq_len] = torch.tensor(input_ids_list[i], dtype=torch.long)
        token_type_ids[i, :seq_len] = torch.tensor(token_type_ids_list[i], dtype=torch.long)
        attention_mask[i, :seq_len] = torch.tensor(attention_mask_list[i], dtype=torch.long)
      else:
        # Keep max_seq_len tokens and force a closing [SEP].
        input_ids[i] = torch.tensor(input_ids_list[i][:max_seq_len-1] + [self.tokenizer.sep_token_id], dtype=torch.long)
        token_type_ids[i] = torch.tensor(token_type_ids_list[i][:max_seq_len])
        attention_mask[i] = torch.tensor(attention_mask_list[i][:max_seq_len])

    labels = torch.tensor(labels_list, dtype=torch.long)
    return input_ids, token_type_ids, attention_mask, labels

  def __call__(self, examples):
    (ab_input_ids_list, ab_token_type_ids_list, ab_attention_mask_list,
     ba_input_ids_list, ba_token_type_ids_list, ba_attention_mask_list, labels_list) = list(zip(*examples))
    # Measure both views: measuring only the a-b view would silently truncate
    # b-a sequences that happen to be longer.
    cur_max_seq_len = max(len(input_ids) for input_ids in ab_input_ids_list + ba_input_ids_list)
    max_seq_len = min(cur_max_seq_len, self.max_seq_len)

    ab_input_ids, ab_token_type_ids, ab_attention_mask, labels = self.pad_and_truncate(
        ab_input_ids_list, ab_token_type_ids_list, ab_attention_mask_list, labels_list, max_seq_len)
    ba_input_ids, ba_token_type_ids, ba_attention_mask, _ = self.pad_and_truncate(
        ba_input_ids_list, ba_token_type_ids_list, ba_attention_mask_list, labels_list, max_seq_len)

    data_dict = {
        'ab_input_ids': ab_input_ids,
        'ab_token_type_ids': ab_token_type_ids,
        'ab_attention_mask': ab_attention_mask,
        'ba_input_ids': ba_input_ids,
        'ba_token_type_ids': ba_token_type_ids,
        'ba_attention_mask': ba_attention_mask,
        'labels': labels
    }
    return data_dict


In [None]:
from torch.utils.data import DataLoader, RandomSampler
def build_dataloader(config, data, tokenizer):
  """Create the unsupervised, train, dev and test DataLoaders.

  Args:
    config: experiment config. Reads 'max_seq_len', 'batch_size',
      'unsup_data_ratio', 'use_bucket', 'bucket_multiplier' and, optionally,
      'num_workers' (default 4, the previously hard-coded value).
    data: output of read_data (keys 'train', 'dev', 'test', 'unsup_data').
    tokenizer: passed to the collators, which use its sep_token_id when truncating.

  Returns:
    (unsup_dataloader, train_dataloader, dev_dataloader, test_dataloader).
    Train/unsup loaders are shuffled (length-bucketed if config['use_bucket']);
    dev/test loaders keep dataset order. Unsupervised batches hold
    int(batch_size * unsup_data_ratio) examples.
  """
  # Hard-coding 4 workers triggers DataLoader's worker-count warning on small
  # machines (the 'cpuset_checked' output in this notebook); now configurable.
  num_workers = config.get('num_workers', 4)
  unsup_batch_size = int(config['batch_size'] * config['unsup_data_ratio'])

  train_dataset = AFQMCDataset(data['train'])
  dev_dataset = AFQMCDataset(data['dev'])
  test_dataset = AFQMCDataset(data['test'])
  unsup_dataset = UnsupAFQMCDataset(data['unsup_data'])

  collate_fn = Collator(config['max_seq_len'], tokenizer)
  unsup_collate_fn = UnsupCollator(config['max_seq_len'], tokenizer)

  def _bucketed_loader(dataset, batch_size, collate):
    # Random order fed to BucketBatchSampler, keyed on the length of each
    # item's first field (input_ids; a-b input_ids for unsup items) -
    # presumably to batch similar lengths together and reduce padding.
    batch_sampler = BucketBatchSampler(RandomSampler(dataset), batch_size=batch_size,
                                       drop_last=False, sort_key=lambda idx: len(dataset[idx][0]),
                                       bucket_size_multiplier=config['bucket_multiplier'])
    return DataLoader(dataset=dataset, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate)

  def _shuffled_loader(dataset, batch_size, collate):
    return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=collate)

  make_loader = _bucketed_loader if config['use_bucket'] else _shuffled_loader
  train_dataloader = make_loader(train_dataset, config['batch_size'], collate_fn)
  unsup_dataloader = make_loader(unsup_dataset, unsup_batch_size, unsup_collate_fn)

  dev_dataloader = DataLoader(dataset=dev_dataset, batch_size=config['batch_size'], shuffle=False,
                              num_workers=num_workers, collate_fn=collate_fn)
  test_dataloader = DataLoader(dataset=test_dataset, batch_size=config['batch_size'], shuffle=False,
                               num_workers=num_workers, collate_fn=collate_fn)

  return unsup_dataloader, train_dataloader, dev_dataloader, test_dataloader

In [None]:
unsup_dataloader, train_dataloader, dev_dataloader, test_dataloader = build_dataloader(config=config, data=data, tokenizer=tokenizer)

  cpuset_checked))


In [None]:
# Inspect the first collated training batch: the whole dict, then each (key, tensor) pair.
first_batch = next(iter(train_dataloader), None)
if first_batch is not None:
  print(first_batch)
  for key_and_tensor in first_batch.items():
    print(key_and_tensor)

  cpuset_checked))


{'input_ids': tensor([[ 101,  955, 1446, 1372, 5543, 1146,  115,  115,  115, 3309, 6820,  749,
          720,  102, 6010, 6009,  955, 1446, 6820, 3621, 5543, 1146, 1914, 2208,
         3309,  102],
        [ 101, 2769, 2682, 2828, 5381, 1555, 6587, 3121, 2768,  955, 1446,  102,
          711,  784,  720, 6010, 6009,  955, 1446, 8024, 1359, 2768, 5381, 1555,
         6587,  102],
        [ 101, 2769, 4638, 5709, 1446, 6820, 3621, 5353, 1103,  749, 2582,  720,
         1905, 4415,  102, 2582,  720, 2798, 5543, 6820, 5709, 1446,  677, 4638,
         7178,  102],
        [ 101,  711,  784,  720, 2769, 4638, 5709, 1446, 7360, 1168,  115,  115,
          115,  102, 5709, 1446,  711,  784,  720, 7360,  856,  749,  671,  674,
         1914,  102],
        [ 101,  784,  720, 3198,  952, 2612, 1908, 6010, 6009, 5709, 1446, 7583,
         2428,  886, 4500,  102, 2582,  720,  886, 5709, 1446, 2612, 1908,  886,
         4500,  102],
        [ 101, 2582,  720, 1357, 3867, 5709, 1446, 6158, 5632, 122

In [None]:
from sklearn.metrics import f1_score,accuracy_score
def evaluation(config, model, val_dataloader):
  """Evaluate `model` on `val_dataloader` without gradients.

  Each batch is moved to config['device'] and passed with mode='val', so the
  model returns (loss, logits, ...). The model is left in eval mode.

  Returns:
    (mean loss per batch, F1 score, accuracy), where F1/accuracy come from
    sklearn's f1_score / accuracy_score on argmax predictions.
  """
  model.eval()
  batch_labels, batch_preds = [], []
  loss_total = 0.
  progress = tqdm(val_dataloader, desc='Evaluation', total=len(val_dataloader))

  with torch.no_grad():
    for batch in progress:
      batch_labels.append(batch['labels'])
      model_inputs = {name: tensor.to(config['device']) for name, tensor in batch.items()}
      model_inputs['mode'] = 'val'

      loss, logits = model(**model_inputs)[:2]
      if config['n_gpus'] > 1:
        # Under nn.DataParallel the loss comes back with one entry per device.
        loss = loss.mean()
      loss_total += loss.item()

      batch_preds.append(logits.argmax(dim=-1).detach().cpu())

  avg_val_loss = loss_total / len(val_dataloader)
  y_true = torch.cat(batch_labels, dim=0).numpy()
  y_pred = torch.cat(batch_preds, dim=0).numpy()

  f1 = f1_score(y_true, y_pred)
  acc = accuracy_score(y_true, y_pred)
  return avg_val_loss, f1, acc


In [None]:
from transformers import BertForSequenceClassification
import torch.nn as nn
class BertForAFQMC(BertForSequenceClassification):
  """BertForSequenceClassification with a `mode` switch on forward.

  forward returns a tuple:
    - mode == 'val' with labels: (cross-entropy loss, logits);
    - otherwise: (logits,). The training losses (supervised + UDA) are
      computed outside the model from these logits.
  """

  def forward(self, input_ids, token_type_ids, attention_mask, labels=None, mode='train'):
    # output_hidden_states=True was requested before but the per-layer hidden
    # states were never used; dropping it saves memory and leaves outputs[1]
    # (the pooled output) unchanged.
    outputs = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)

    pooled_output = outputs[1]
    pooled_output = self.dropout(pooled_output)

    logits = self.classifier(pooled_output)
    outputs = (logits, )

    # Only compute the loss when labels exist; mode='val' with labels=None
    # used to crash on labels.view.
    if mode == 'val' and labels is not None:
      loss_fct = nn.CrossEntropyLoss()
      loss = loss_fct(logits, labels.view(-1))

      outputs = (loss, ) + outputs
    return outputs

In [None]:
def get_data(sup_batch, unsup_batch, config):
  """Assemble the two model inputs for one UDA step.

  Returns:
    grad_data: for each non-label feature, the supervised rows stacked above
      the b-a unsupervised rows (the shorter side right-padded with 0 to the
      longer width), plus the supervised 'labels'; all on config['device'].
    no_grad_data: the a-b unsupervised features (no labels) on config['device'].
  """
  sup_len = sup_batch['input_ids'].size(1)
  unsup_len = unsup_batch['ba_input_ids'].size(1)
  target_len = max(sup_len, unsup_len)
  device = config['device']

  def right_pad(tensor, width):
    filler = torch.zeros((tensor.size(0), width), dtype=tensor.dtype)
    return torch.cat([tensor, filler], dim=-1)

  grad_data, no_grad_data = {}, {}
  for key, sup_value in sup_batch.items():
    if key == 'labels':
      grad_data[key] = sup_value.to(device)
      continue

    ba_value = unsup_batch[f'ba_{key}']
    ab_value = unsup_batch[f'ab_{key}']

    if sup_len == target_len:
      ba_value = right_pad(ba_value, target_len - unsup_len)
    else:
      sup_value = right_pad(sup_value, target_len - sup_len)

    grad_data[key] = torch.cat([sup_value, ba_value], dim=0).to(device)
    no_grad_data[key] = ab_value.to(device)

  return grad_data, no_grad_data

In [None]:
def forward_no_grad(no_grad_data, config, model):
  """Compute sharpened target distributions and a confidence mask, without gradients.

  Returns:
    unsup_loss_mask: float tensor, 1.0 where the largest sharpened probability
      is strictly greater than config['uda_confidence_threshold'], else 0.0.
    no_grad_probs: softmax(logits / config['uda_softmax_temp']) over dim 1.
  """
  with torch.no_grad():
    logits = model(**no_grad_data)[0]
    sharpened = torch.softmax(logits / config['uda_softmax_temp'], dim=1)
    confident = sharpened.max(dim=-1).values > config['uda_confidence_threshold']
    return confident.float(), sharpened


In [None]:
def get_tsa_threshold(total_steps, global_steps, num_classes=2):
  """Training Signal Annealing (TSA) threshold, exponential schedule.

  alpha_t = exp((t / T - 1) * 5);  threshold = alpha_t * (1 - 1/K) + 1/K.
  The threshold rises from about 1/K at step 0 to exactly 1.0 at t = T.
  forward_with_grad keeps a supervised example only while its correct-label
  probability is <= this threshold.

  The previous formula, exp(t/T * 5) / 2 + 0.5, equals 1.0 at step 0 and only
  grows. Every probability was therefore <= the threshold and TSA never
  masked any example.

  Args:
    total_steps: T, total number of training steps (must be > 0).
    global_steps: t, current step; t/T is clipped to [0, 1].
    num_classes: K, number of classes (2 for this binary task, the previously
      hard-coded value).

  Returns:
    Threshold in [1/K, 1].
  """
  progress = min(max(global_steps / total_steps, 0.0), 1.0)
  alpha = np.exp((progress - 1.0) * 5)
  return alpha * (1.0 - 1.0 / num_classes) + 1.0 / num_classes

In [None]:
def forward_with_grad(unsup_loss_mask, unsup_probs, config, cur_bs, model, grad_data, total_steps, global_steps):
  """UDA loss on one combined batch: supervised rows first, then `cur_bs` unsupervised rows.

  Supervised part: per-example cross-entropy. Only examples whose
  correct-label probability is <= the TSA threshold are kept; the loss is
  averaged over them (denominator at least 1).
  Unsupervised part: per-example KL(unsup_probs || softmax(unsup_logits)),
  masked by unsup_loss_mask and averaged over the kept rows (denominator at
  least 1).

  Returns:
    (sup_loss + unsup_loss, tsa_threshold, unsup_loss, sup_loss)
  """
  tsa_threshold = get_tsa_threshold(total_steps, global_steps)

  logits = model(**grad_data)[0]
  n_sup = logits.size(0) - cur_bs
  sup_logits, unsup_logits = logits.split([n_sup, cur_bs])

  # --- supervised cross-entropy with TSA masking ---
  sup_labels = grad_data['labels'][:n_sup]
  ce_per_example = nn.CrossEntropyLoss(reduction='none')(sup_logits, sup_labels)
  gold_probs = torch.softmax(sup_logits, dim=-1).gather(dim=-1, index=sup_labels.view(-1, 1))
  tsa_keep = gold_probs.le(tsa_threshold).squeeze().float()
  ce_per_example *= tsa_keep
  sup_loss = ce_per_example.sum() / max(tsa_keep.sum(), 1)

  # --- unsupervised consistency (KL) loss ---
  unsup_log_probs = torch.log_softmax(unsup_logits, dim=1)
  kl_per_example = nn.KLDivLoss(reduction='none')(unsup_log_probs, unsup_probs).sum(dim=-1)
  kl_per_example *= unsup_loss_mask
  unsup_loss = kl_per_example.sum() / max(unsup_loss_mask.sum(), 1)

  loss = sup_loss + unsup_loss
  if config['n_gpus'] > 1:
    loss, sup_loss, unsup_loss = loss.mean(), sup_loss.mean(), unsup_loss.mean()

  return loss, tsa_threshold, unsup_loss, sup_loss


In [None]:
from transformers import AdamW
from tqdm import trange
import os
def train(config, train_dataloader, dev_dataloader, unsup_dataloader=None):
  """Fine-tune BertForAFQMC with UDA (supervised CE + unsupervised consistency).

  Each step consumes one unsupervised batch. Supervised batches come from
  `train_dataloader`, which is restarted whenever it is exhausted. Every
  config['logging_step'] steps the model is evaluated on `dev_dataloader`. If
  dev accuracy improves, the model is saved under config['output_path'].

  EMA starts at the first logging step where global_steps >=
  config['ema_start_step']. From then on the EMA weights are swapped in
  (apply_shadow/restore) for evaluation and saving.

  Args:
    config: experiment configuration dict.
    train_dataloader: labelled batches (dicts from Collator).
    dev_dataloader: labelled dev batches used for model selection.
    unsup_dataloader: two-view unlabelled batches (dicts from UnsupCollator).
      Required despite the None default, because it drives the training loop.

  Returns:
    (model, best_model_path); best_model_path is '' if no checkpoint was saved.

  Raises:
    TypeError: if unsup_dataloader is None.
  """
  if unsup_dataloader is None:
    raise TypeError('train() requires unsup_dataloader: it drives the UDA training loop.')

  model = BertForAFQMC.from_pretrained(config['model_path'])

  optimizer = AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
  model.to(config['device'])

  total_steps = len(unsup_dataloader) * config['num_epoches']
  epoch_iterator = trange(config['num_epoches'])
  global_steps = 0
  train_loss = 0.
  logging_loss = 0.
  best_acc = 0.
  best_model_path = ''
  # EMA state is tracked locally. It used to be keyed on config['ema_start'],
  # which this function sets to True, so re-running train() with the same
  # config dict called ema.update() before `ema` existed (UnboundLocalError).
  ema = None

  if config['n_gpus'] > 1:
    model = nn.DataParallel(model)

  def unwrap(m):
    # Underlying BertForAFQMC whether or not it is wrapped in DataParallel.
    return m.module if hasattr(m, 'module') else m

  train_iterator = iter(train_dataloader)

  for _ in epoch_iterator:
    unsup_iterator = tqdm(unsup_dataloader, desc='Training', total=len(unsup_dataloader))
    model.train()

    for unsup_batch in unsup_iterator:
      cur_bs = unsup_batch['ab_input_ids'].size(0)
      try:
        sup_batch = next(train_iterator)
      except StopIteration:
        # Supervised data ran out before the unsupervised epoch ended: cycle it.
        train_iterator = iter(train_dataloader)
        sup_batch = next(train_iterator)

      grad_data, no_grad_data = get_data(sup_batch, unsup_batch, config)

      unsup_loss_mask, unsup_probs = forward_no_grad(no_grad_data, config, model)

      loss, tsa_threshold, unsup_loss, sup_loss = forward_with_grad(
          unsup_loss_mask, unsup_probs, config, cur_bs, model, grad_data, total_steps, global_steps
      )

      model.zero_grad()
      loss.backward()

      nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
      optimizer.step()

      if ema is not None:
        ema.update()

      train_loss += loss.item()
      global_steps += 1

      unsup_iterator.set_postfix_str(f'running training loss: {loss.item():.4f}')

      if global_steps % config['logging_step'] == 0:
        if global_steps >= config['ema_start_step'] and ema is None:
          print('\n>>> EMA starting ...')
          # Still recorded in config for backward compatibility; no longer read here.
          config['ema_start'] = True
          ema = EMA(unwrap(model), decay=config.get('ema_decay', 0.999))

        print_train_loss = (train_loss - logging_loss) / config['logging_step']
        logging_loss = train_loss

        if ema is not None:
          ema.apply_shadow()
        val_loss, f1, acc = evaluation(config, model, dev_dataloader)

        print_log = f'\n>>> training loss: {print_train_loss:.6f}, valid loss: {val_loss:.6f}, '

        if acc > best_acc:
          model_save_path = os.path.join(config['output_path'], f'checkpoint-{global_steps}-{acc:.6f}')
          # The non-DataParallel branch used to reference an undefined name
          # (`model_config`), raising NameError on CPU / single-GPU runs.
          unwrap(model).save_pretrained(model_save_path)
          best_acc = acc
          best_model_path = model_save_path

        print_log += f'valid f1: {f1:.6f}, valid acc: {acc:.6f}'
        print(print_log)
        model.train()
        if ema is not None:
          ema.restore()

  return model, best_model_path



In [None]:
model, best_model_path = train(config, train_dataloader, dev_dataloader, unsup_dataloader=unsup_dataloader)