In [None]:
import json
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from tqdm import tqdm
from transformers import BertTokenizerFast, AutoTokenizer, BertModel, AutoModel

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

def tokenize_data(file_path, tokenizer_name="hfl/chinese-electra-180g-large-discriminator"):
    """Load a sentence-pair JSON file and tokenize both sides of every pair.

    Args:
        file_path: path to a UTF-8 JSON file containing a list of records
            laid out as ``[id, text1, text2, relation]``.
        tokenizer_name: Hugging Face checkpoint whose tokenizer is used. Keep it
            in sync with the encoder loaded by the classifier.

    Returns:
        ``(datas, text1_tokenized, text2_tokenized)``:
        - ``datas`` is a list of dicts with keys 'id', 'text1', 'text2' and 'relation'.
        - The other two are the tokenizer's batch outputs for all 'text1' / 'text2'
          strings, encoded without special tokens.
    """
    def read_data(file_path):
        """Parse the JSON list of [id, text1, text2, relation] records into dicts."""
        with open(file_path, 'r', encoding="utf-8") as f:
            data_in = json.load(f)
        return [{'id': x[0], 'text1': x[1], 'text2': x[2], 'relation': x[3]} for x in data_in]

    datas = read_data(file_path)
    # text_Dataset indexes these outputs per example and reads `.ids`. Integer
    # indexing of a batch encoding presumably requires a fast tokenizer, which
    # AutoTokenizer returns by default when one exists — confirm for new checkpoints.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    # Special tokens are left out because text_Dataset inserts [CLS]/[SEP] itself.
    text1_tokenized = tokenizer([data["text1"] for data in datas], add_special_tokens=False, truncation=True)
    text2_tokenized = tokenizer([data["text2"] for data in datas], add_special_tokens=False, truncation=True)

    return datas, text1_tokenized, text2_tokenized

# Both splits live in the same input directory; change it here to run elsewhere.
DATA_DIR = '/kaggle/input/nlp-final/Final Project Task 1'
train_datas, train_text1_tokenized, train_text2_tokenized = tokenize_data(f'{DATA_DIR}/team_train.json')
dev_datas, dev_text1_tokenized, dev_text2_tokenized = tokenize_data(f'{DATA_DIR}/team_dev.json')
print(len(train_datas))

In [None]:
class text_Dataset(Dataset):
    """Sentence-pair dataset that yields fixed-length BERT-style model inputs.

    Each item is laid out as ``[CLS] text1 [SEP] text2 [SEP]`` and then padded.
    Each text is truncated to at most ``max_len`` tokens, so every sequence has
    length ``max_seq_len = 2 * max_len + 3``. The default of 254 gives 511.

    Args:
        datas: list of dicts; the 'relation' value of each is returned as the label.
        tokenized_text1, tokenized_text2: indexable collections aligned with
            ``datas``. Each item exposes ``.ids``, the token ids without special
            tokens (e.g. the output of ``tokenize_data``).
        max_len: per-text truncation length.
        cls_id, sep_id: ids inserted as [CLS] / [SEP]. The defaults 101 / 102
            match the BERT-Chinese-style vocabulary assumed here — TODO confirm
            when switching tokenizers.
        pad_id: id used to fill the padding positions.
    """

    def __init__(self, datas, tokenized_text1, tokenized_text2, max_len=254, cls_id=101, sep_id=102, pad_id=0):
        self.datas = datas
        self.tokenized_text1 = tokenized_text1
        self.tokenized_text2 = tokenized_text2
        self.max_len = max_len
        self.cls_id = cls_id
        self.sep_id = sep_id
        self.pad_id = pad_id
        # Input sequence length = [CLS] + text1 + [SEP] + text2 + [SEP]
        self.max_seq_len = 1 + self.max_len + 1 + self.max_len + 1

    def __len__(self):
        return len(self.datas)

    def __getitem__(self, idx):
        """Return ``(input_ids, token_type_ids, attention_mask, label)`` for item ``idx``.

        The first three are 1-D tensors of length ``max_seq_len``. ``label`` is
        the raw 'relation' value.
        """
        label = self.datas[idx]['relation']

        # Truncate each text separately, then add the special tokens.
        input_ids_text1 = [self.cls_id] + self.tokenized_text1[idx].ids[:self.max_len] + [self.sep_id]
        input_ids_text2 = self.tokenized_text2[idx].ids[:self.max_len] + [self.sep_id]

        input_ids, token_type_ids, attention_mask = self.padding(input_ids_text1, input_ids_text2)
        return torch.tensor(input_ids), torch.tensor(token_type_ids), torch.tensor(attention_mask), label

    def padding(self, input_ids_text1, input_ids_text2):
        """Concatenate the two segments and right-pad them to ``max_seq_len``.

        Returns ``(input_ids, token_type_ids, attention_mask)`` as lists:
        - Segment ids are 0 on the first segment, 1 on the second and 0 on padding.
        - The attention mask is 1 on real tokens and 0 on padding.
        """
        content_len = len(input_ids_text1) + len(input_ids_text2)
        padding_len = self.max_seq_len - content_len
        # Indices of input sequence tokens in the vocabulary
        input_ids = input_ids_text1 + input_ids_text2 + [self.pad_id] * padding_len
        # Segment indices distinguishing the first and second portions of the input
        token_type_ids = [0] * len(input_ids_text1) + [1] * len(input_ids_text2) + [0] * padding_len
        # Mask so attention ignores the padding positions
        attention_mask = [1] * content_len + [0] * padding_len

        return input_ids, token_type_ids, attention_mask


In [None]:
class BERTClassifier(nn.Module):
    """Pre-trained transformer encoder followed by a dropout + linear classification head.

    Args:
        num_classes: number of output logits.
        model_name: Hugging Face checkpoint loaded with ``AutoModel``. Should match
            the tokenizer used to build the inputs.
        dropout: dropout probability applied to the first-token representation.
    """

    def __init__(self, num_classes, model_name="hfl/chinese-electra-180g-large-discriminator", dropout=0.25):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.d_dim = self.bert.config.hidden_size
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(self.d_dim, num_classes)

    def forward(self, input_ids, token_type_ids, attention_mask):
        """Return unnormalized class logits with shape (batch, num_classes)."""
        output = self.bert(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
        # Classify from the last hidden state of the first ([CLS]) token rather
        # than `pooler_output`, which ELECTRA-style encoders do not provide.
        cls_hidden = output.last_hidden_state[:, 0, :]
        return self.fc(self.dropout(cls_hidden))

In [None]:
from accelerate import Accelerator

class Trainer():
    """Fine-tune a sentence-pair classifier with early stopping and checkpointing.

    Training setup:
    - Adam (weight decay 1e-5).
    - A CosineAnnealingWarmRestarts schedule (T_0=10, T_mult=2) stepped once per epoch.
    - Gradient accumulation handled by ``accelerate``.

    After each epoch the model is evaluated on ``valid_loader``:
    - The validation line is printed and appended to ``./{exp_name}_log.txt``.
    - The best-accuracy weights go to ``{exp_name}_best.ckpt``.
    - A resumable checkpoint goes to ``model.pt``.

    Args:
        config: dict with keys 'n_epochs', 'lr', 'patience' and 'model'. Other
            keys such as 'batch_size' are ignored here.
        train_loader: DataLoader yielding
            ``(input_ids, token_type_ids, attention_mask, labels)`` batches.
        valid_loader: DataLoader yielding batches of the same form.
        gradient_accumulation_steps: number of batches whose gradients are
            accumulated before each optimizer step.
        exp_name: prefix for the validation log file and the best-model checkpoint.
    """

    def __init__(self, config, train_loader, valid_loader, gradient_accumulation_steps=16, exp_name="ver"):
        # Read settings by key. Unpacking config.values() depended on the dict's
        # insertion order and broke silently if keys were added or reordered.
        self.n_epochs = config['n_epochs']
        self.lr = config['lr']
        self.patience = config['patience']
        self.model = config['model']
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.exp_name = exp_name
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=1e-5)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer, T_0=10, T_mult=2)
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.accelerator = Accelerator(gradient_accumulation_steps=self.gradient_accumulation_steps)
        # valid_loader is not prepared, so its batches are moved to the
        # accelerator's device explicitly in _validate().
        self.model, self.optimizer, self.train_loader = self.accelerator.prepare(
            self.model, self.optimizer, self.train_loader)

    def _to_device(self, batch):
        """Move every tensor in a batch to the device the model was prepared on."""
        return [t.to(self.accelerator.device) for t in batch]

    @staticmethod
    def _mean(values):
        """Arithmetic mean of a list of numbers, or 0.0 for an empty list."""
        return sum(values) / len(values) if values else 0.0

    def _train_epoch(self):
        """Run one pass over train_loader; return (mean batch loss, sample accuracy)."""
        self.model.train()
        losses = []
        correct = 0
        seen = 0
        for batch in self.train_loader:
            input_ids, token_type_ids, attention_mask, labels = self._to_device(batch)
            # accumulate() makes optimizer.step()/zero_grad() take effect only once
            # every gradient_accumulation_steps batches (and at the end of the
            # loader), and accelerator.backward() scales the loss to match.
            # Without them the prepared optimizer steps on every batch and the
            # configured accumulation never happens.
            with self.accelerator.accumulate(self.model):
                out = self.model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
                loss = self.criterion(out, labels)
                self.accelerator.backward(loss)
                self.optimizer.step()
                self.optimizer.zero_grad()

            losses.append(loss.item())
            # Count correct samples, so a smaller final batch is not over-weighted.
            correct += (out.argmax(dim=-1) == labels).sum().item()
            seen += labels.size(0)
        return self._mean(losses), (correct / seen if seen else 0.0)

    @torch.no_grad()
    def _validate(self):
        """Evaluate on valid_loader without gradients; return (mean batch loss, sample accuracy)."""
        self.model.eval()
        losses = []
        correct = 0
        seen = 0
        for batch in self.valid_loader:
            input_ids, token_type_ids, attention_mask, labels = self._to_device(batch)
            out = self.model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
            losses.append(self.criterion(out, labels).item())
            correct += (out.argmax(dim=-1) == labels).sum().item()
            seen += labels.size(0)
        return self._mean(losses), (correct / seen if seen else 0.0)

    def _log(self, message):
        """Print ``message`` and append it to ``./{exp_name}_log.txt``."""
        print(message)
        with open(f"./{self.exp_name}_log.txt", "a", encoding="utf-8") as log_file:
            print(message, file=log_file)

    def _save_checkpoint(self, epoch, train_loss, valid_loss, best_acc):
        """Write a resumable training checkpoint to model.pt (overwritten every epoch)."""
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.accelerator.unwrap_model(self.model).state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'train_loss': train_loss,
            'valid_loss': valid_loss,
            'best_acc': best_acc,
        }, "model.pt")

    def train(self):
        """Train for up to n_epochs epochs.

        Stops early once validation accuracy has failed to improve for more than
        ``patience`` consecutive epochs.
        """
        stale = 0
        best_acc = 0.0

        for epoch in range(self.n_epochs):
            print(f'Epoch_{epoch} starts!!')
            train_loss, train_acc = self._train_epoch()
            print(f"[ Train | {epoch + 1:03d}/{self.n_epochs:03d} ] loss = {train_loss:.5f}, acc = {train_acc:.5f}")

            valid_loss, valid_acc = self._validate()
            improved = valid_acc > best_acc
            summary = f"[ Valid | {epoch + 1:03d}/{self.n_epochs:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f}"
            self._log(summary + (" -> best" if improved else ""))

            if improved:
                print(f"Best model found at epoch {epoch}, saving model")
                # Only the best weights are kept, to avoid exceeding the output storage limit.
                torch.save(self.accelerator.unwrap_model(self.model).state_dict(), f"{self.exp_name}_best.ckpt")
                best_acc = valid_acc
                stale = 0
            else:
                stale += 1
                if stale > self.patience:
                    print(f"No improvement for {stale} consecutive epochs, early stopping")
                    break

            self.scheduler.step()
            self._save_checkpoint(epoch, train_loss, valid_loss, best_acc)


In [None]:
# Seed PyTorch's RNGs before building the model so head initialization, dropout
# and shuffling are repeatable across runs (GPU kernels may still be nondeterministic).
SEED = 42
torch.manual_seed(SEED)

# Key order is kept as-is for code that reads config positionally.
config = {'batch_size': 4,  # 8 was used previously
          'n_epochs': 100,
          'lr': 1e-5,
          'patience': 40,
          'model': BERTClassifier(num_classes=3).to(device),
          }

train_set = text_Dataset(train_datas, train_text1_tokenized, train_text2_tokenized)
train_loader = DataLoader(train_set, batch_size=config['batch_size'], shuffle=True, num_workers=0, pin_memory=True)
dev_set = text_Dataset(dev_datas, dev_text1_tokenized, dev_text2_tokenized)
# Validation metrics do not depend on order, so iterate the dev set deterministically.
dev_loader = DataLoader(dev_set, batch_size=config['batch_size'], shuffle=False, num_workers=0, pin_memory=True)

print(len(train_loader), len(dev_loader))

In [None]:
# Build the trainer from the settings above and fine-tune, validating on the dev split every epoch.
trainer = Trainer(config, train_loader, valid_loader=dev_loader)
trainer.train()