In [None]:
import os
import random
import zipfile

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import wandb
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import BertModel, BertTokenizer

In [None]:
# --- Weights & Biases run identification (passed to wandb.init) ---
PROJECT_NAME = "OCNLI-CLS-MLM"
RUN_NAME = "macbert-base_lr3e-5_b32"
ENTITY = "jonahchow"

# Seed fed to the Python, NumPy and torch RNGs in the seeding cell.
RANDOM_SEED = 1145141919
# Name/path handed to `from_pretrained` for both the tokenizer and the BERT encoder.
PRETRAINED_MODEL_NAME = "pretrained/hfl/chinese-macbert-base"

# Per-run output directory; checkpoints and prediction files are written beneath it.
WORK_DIR = f"work_dirs/{RUN_NAME}"

# --- Dataset files (read as JSON lines by OcnliDataset) ---
OCNLI_PATH = "data/ocnli"
OCNLI_TRAIN_FILE = f"{OCNLI_PATH}/train.50k.json"
OCNLI_VAL_FILE = f"{OCNLI_PATH}/dev.json"
OCNLI_TEST_FILE = f"{OCNLI_PATH}/test.json"

# `str.format` templates applied to sentence1 / sentence2; "{0}" leaves the text unchanged.
S1_PROMPT = "{0}"
S2_PROMPT = "{0}"

# --- Logging and checkpointing ---
# Training metrics are sent to wandb when the progress counter is a multiple of this value.
LOG_INTERVAL = 1
CHECKPOINT_DIR = f"{WORK_DIR}/checkpoints"
# Number of top-validation-accuracy checkpoints kept and written as best_<rank>.pt.
CHECKPOINT_SAVE_BEST = 1
# NOTE(review): not referenced anywhere in the visible notebook code.
CHECKPOINT_SAVE_INTERVAL = 1
# When True, epoch_<n>.pt is additionally written after every epoch.
SAVE_EACH_EPOCH = False

# --- Optimisation hyper-parameters ---
NUM_EPOCHS = 5
BATCH_SIZE = 32
# Tokenizer `max_length`; inputs are padded and truncated to this many tokens.
MAX_SEQ_LEN = 128
# Learning rate given to AdamW.
LEARNING_RATE = 3e-5

In [None]:
# Seed every random number generator the notebook relies on with the same value:
# Python's `random`, NumPy, and torch (CPU generator plus all CUDA devices).
for seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_rng(RANDOM_SEED)

# Ask cuDNN to use deterministic algorithms.
torch.backends.cudnn.deterministic = True

In [None]:
class OcnliDataset(Dataset):
    """Sentence-pair dataset loaded from an OCNLI JSON-lines file.

    Each record must provide ``sentence1`` and ``sentence2``. Outside test mode
    it must also provide ``label`` and ``level``; rows whose label is not one of
    ``labels`` are dropped, and both columns are replaced by their integer codes
    from ``label_map`` / ``level_map``.

    Args:
        data_file: Path of the JSON-lines file to read.
        test_mode: When True, labels/levels are neither required nor returned.
        s1_prompt: ``str.format`` template applied to every ``sentence1``.
        s2_prompt: ``str.format`` template applied to every ``sentence2``.
    """

    labels = ["entailment", "neutral", "contradiction"]

    # Class name -> integer class id used as the training target.
    label_map = {
        "entailment": 0,
        "neutral": 1,
        "contradiction": 2
    }

    levels = ["na", "easy", "medium", "hard"]

    # Level name -> integer code.
    level_map = {
        "na": 0,
        "easy": 1,
        "medium": 2,
        "hard": 3
    }

    def __init__(self, data_file, test_mode=False, s1_prompt="{0}", s2_prompt="{0}"):
        self.test_mode = test_mode

        frame = pd.read_json(data_file, lines=True)
        frame["sentence1"] = [s1_prompt.format(s) for s in frame["sentence1"]]
        frame["sentence2"] = [s2_prompt.format(s) for s in frame["sentence2"]]

        if not self.test_mode:
            # Take an explicit copy of the filtered rows: assigning columns onto
            # a boolean-mask slice is chained assignment and triggers pandas'
            # SettingWithCopyWarning.
            frame = frame[frame["label"].isin(self.labels)].copy()
            frame["label"] = frame["label"].map(self.label_map)
            frame["level"] = frame["level"].map(self.level_map)

        # Stored as a list of per-row dicts so __getitem__ is a plain list lookup.
        self.data = frame.to_dict(orient="records")

    def __len__(self):
        """Return the number of retained records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sentence1, sentence2)`` in test mode, otherwise
        ``(sentence1, sentence2, label, level)``."""
        record = self.data[idx]
        if self.test_mode:
            return record["sentence1"], record["sentence2"]

        return record["sentence1"], record["sentence2"], record["label"], record["level"]


In [None]:
class MultiSentencesDataLoader:
    """Batch iterator that tokenizes sentence pairs on the fly.

    Wraps a ``torch.utils.data.DataLoader`` whose collate function encodes each
    batch of ``(sentence1, sentence2, ...)`` items with ``tokenizer`` and moves
    the resulting tensors to ``device``.

    Args:
        tokenizer: Object exposing ``batch_encode_plus`` that returns a mapping
            with ``input_ids``, ``attention_mask`` and ``token_type_ids``.
        dataset: Indexable dataset; items start with the two sentences and,
            outside test mode, carry the integer class label at index 2.
        max_length: Length every sequence is padded / truncated to.
        batch_size: Number of items per batch.
        device: Device the batch tensors are moved to.
        shuffle: Forwarded to the underlying ``DataLoader``.
        drop_last: Forwarded to the underlying ``DataLoader``.
        test_mode: When True, batches contain no label tensor.
    """

    def __init__(self, tokenizer, dataset, max_length, batch_size, device, shuffle=False, drop_last=True, test_mode=False):
        self.tokenizer = tokenizer
        self.dataset = dataset
        self.max_length = max_length
        self.batch_size = batch_size
        self.device = device
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.test_mode = test_mode

        self.dataloader = DataLoader(
            dataset=self.dataset,
            batch_size=self.batch_size,
            collate_fn=self._collate_fn,
            shuffle=self.shuffle,
            drop_last=self.drop_last
        )

    def _collate_fn(self, batch):
        """Tokenize a list of dataset items into device tensors.

        Returns:
            ``(input_ids, attention_mask, token_type_ids)`` in test mode,
            otherwise the same three tensors followed by a ``torch.long``
            tensor of class labels.
        """
        pairs = [(item[0], item[1]) for item in batch]

        encoded = self.tokenizer.batch_encode_plus(
            pairs,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        tensors = tuple(
            encoded[key].to(self.device)
            for key in ("input_ids", "attention_mask", "token_type_ids")
        )
        if self.test_mode:
            return tensors

        # Only the class label (index 2) is consumed by training; the difficulty
        # level at index 3 is intentionally not batched.
        cls_labels = torch.tensor(
            [item[2] for item in batch], dtype=torch.long, device=self.device
        )
        return (*tensors, cls_labels)

    def __iter__(self):
        """Iterate over collated batches."""
        return iter(self.dataloader)

    def __len__(self):
        """Return the number of batches per pass."""
        return len(self.dataloader)

In [None]:
class BertClassificationHead(nn.Module):
    """Two-layer classification head applied to the first token of the last
    hidden state.

    Args:
        hidden_size: Width of the incoming hidden vectors and of the inner layer.
        num_classes: Number of output logits.
        dropout_prob: Dropout probability used before each linear layer.
    """

    def __init__(self, hidden_size=768, num_classes=3, dropout_prob=0.1):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout_prob)
        self.out_proj = nn.Linear(hidden_size, num_classes)

    def forward(self, features, **kwargs):
        """Return class logits.

        Args:
            features: Sequence of hidden-state tensors; only the last entry is
                used and position 0 along its second axis is selected
                (assumed layout: batch, sequence, hidden).
            **kwargs: Accepted and ignored.
        """
        last_layer = features[-1]
        first_token = last_layer[:, 0, :]
        hidden = torch.tanh(self.dense(self.dropout(first_token)))
        return self.out_proj(self.dropout(hidden))

In [None]:
class BertClassifier(nn.Module):
    """Pretrained BERT encoder followed by a ``BertClassificationHead``.

    Args:
        pretrained_model_name: Name/path given to ``BertModel.from_pretrained``.
        num_classes: Number of output classes.
        dropout_prob: Dropout probability used inside the classification head.
        test_mode: Stored on the instance; not read by this class.
    """

    def __init__(self, pretrained_model_name, num_classes, dropout_prob=0.1, test_mode=False):
        super().__init__()
        self.test_mode = test_mode

        # Hidden states of every layer are requested because the head indexes
        # into `hidden_states`.
        encoder = BertModel.from_pretrained(pretrained_model_name, output_hidden_states=True)
        self.bert = encoder
        self.classifier = BertClassificationHead(
            hidden_size=encoder.config.hidden_size,
            num_classes=num_classes,
            dropout_prob=dropout_prob
        )

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Encode the batch and return the classification logits."""
        encoded = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        return self.classifier(encoded.hidden_states)


In [None]:
# Training cell: configures the wandb run, builds the data loaders, model and
# optimizer, trains for NUM_EPOCHS epochs with a validation pass after each one,
# and finally writes the CHECKPOINT_SAVE_BEST best checkpoints (ranked by
# validation accuracy) to CHECKPOINT_DIR as best_<rank>.pt.

os.environ["WANDB_NOTEBOOK_NAME"] = "main.ipynb"
# Swap in the next line for a dry run that does not upload anything:
# wandb.init(project=PROJECT_NAME, name=RUN_NAME, mode="disabled")
wandb.init(project=PROJECT_NAME, name=RUN_NAME, entity=ENTITY)
wandb.config.model_name = PRETRAINED_MODEL_NAME
wandb.config.random_seed = RANDOM_SEED
wandb.config.num_epochs = NUM_EPOCHS
wandb.config.batch_size = BATCH_SIZE
wandb.config.max_seq_len = MAX_SEQ_LEN
wandb.config.learning_rate = LEARNING_RATE
wandb.config.dataset_train = os.path.basename(OCNLI_TRAIN_FILE)
wandb.config.s1_prompt = S1_PROMPT
wandb.config.s2_prompt = S2_PROMPT

tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dataset = OcnliDataset(OCNLI_TRAIN_FILE, s1_prompt=S1_PROMPT, s2_prompt=S2_PROMPT)
train_data_loader = MultiSentencesDataLoader(
    tokenizer=tokenizer,
    dataset=train_dataset,
    max_length=MAX_SEQ_LEN,
    batch_size=BATCH_SIZE,
    device=device,
    shuffle=True,
    drop_last=True
)

val_dataset = OcnliDataset(OCNLI_VAL_FILE, s1_prompt=S1_PROMPT, s2_prompt=S2_PROMPT)
# Validation only counts correct predictions, so it can be batched like
# training instead of running one example at a time.
val_data_loader = MultiSentencesDataLoader(
    tokenizer=tokenizer,
    dataset=val_dataset,
    max_length=MAX_SEQ_LEN,
    batch_size=BATCH_SIZE,
    device=device,
    shuffle=False,
    drop_last=False
)

model = BertClassifier(pretrained_model_name=PRETRAINED_MODEL_NAME, num_classes=3)
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

# Entries are {"val_acc": float, "state_dict": dict}, kept sorted best-first and
# truncated to CHECKPOINT_SAVE_BEST. Starting empty (rather than with
# placeholder entries whose state_dict is None) guarantees that only real
# weights are ever written to disk.
best_checkpoints = []

for epoch in range(NUM_EPOCHS):
    # ---- training pass ----
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0
    train_batches = 0

    with tqdm(train_data_loader, desc=f"Train: Epoch {epoch+1}/{NUM_EPOCHS}", unit="batch") as t:
        # An explicit step counter is used for LOG_INTERVAL: tqdm only syncs
        # `t.n` when it refreshes the display, so it can lag the real count.
        for step, (input_ids, attention_mask, token_type_ids, cls_labels) in enumerate(t, start=1):
            optimizer.zero_grad()

            logits = model(input_ids, attention_mask, token_type_ids)
            loss = criterion(logits, cls_labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            batch_count = cls_labels.size(0)
            correct = logits.argmax(dim=-1).eq(cls_labels).sum().item()
            accuracy = correct / batch_count

            train_loss += batch_loss
            train_correct += correct
            train_total += batch_count
            train_batches += 1

            t.set_postfix(loss=batch_loss, acc=accuracy)

            if step % LOG_INTERVAL == 0:
                wandb.log({
                    "train_cls_loss": batch_loss,
                    "train_loss": batch_loss,
                    "train_cls_acc": accuracy
                })

        # Epoch summary. `train_loss` is a sum of per-batch mean losses, so it
        # is averaged over batches, not over samples.
        if train_batches:
            t.set_postfix(loss=train_loss / train_batches, acc=train_correct / train_total)

    # ---- validation pass ----
    model.eval()
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        with tqdm(val_data_loader, desc=f"Val: Epoch {epoch+1}/{NUM_EPOCHS}", unit="batch") as t:
            for input_ids, attention_mask, token_type_ids, cls_labels in t:
                logits = model(input_ids, attention_mask, token_type_ids)

                val_correct += logits.argmax(dim=-1).eq(cls_labels).sum().item()
                val_total += cls_labels.size(0)

            # Guard against an empty validation set.
            val_acc = val_correct / val_total if val_total else 0.0
            t.set_postfix(acc=val_acc)

    wandb.log({
        "val_cls_acc": val_acc
    })

    # ---- checkpoint bookkeeping ----
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    if SAVE_EACH_EPOCH:
        torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, f"epoch_{epoch+1}.pt"))

    # `state_dict()` returns references to the live parameter tensors, which the
    # optimizer keeps updating in place. Snapshot them (on CPU, to spare GPU
    # memory) so a stored "best" checkpoint is not silently overwritten by
    # later epochs.
    snapshot = {
        name: tensor.detach().cpu().clone()
        for name, tensor in model.state_dict().items()
    }
    best_checkpoints.append({"val_acc": val_acc, "state_dict": snapshot})
    best_checkpoints = sorted(best_checkpoints, key=lambda x: x["val_acc"], reverse=True)[:CHECKPOINT_SAVE_BEST]


for i, checkpoint in enumerate(best_checkpoints):
    torch.save(checkpoint["state_dict"], os.path.join(CHECKPOINT_DIR, f"best_{i+1}.pt"))

wandb.finish()

In [None]:
# Test
#
# Reloads the best checkpoint, predicts a label for every test example and
# writes the predictions as JSON lines plus a zip archive inside WORK_DIR.

test_dataset = OcnliDataset(OCNLI_TEST_FILE, test_mode=True, s1_prompt=S1_PROMPT, s2_prompt=S2_PROMPT)
# Order must be preserved (no shuffling, nothing dropped) because predictions
# are joined back onto the test file by position.
test_data_loader = MultiSentencesDataLoader(
    tokenizer=tokenizer,
    dataset=test_dataset,
    max_length=MAX_SEQ_LEN,
    batch_size=BATCH_SIZE,
    device=device,
    shuffle=False,
    drop_last=False,
    test_mode=True
)

model = BertClassifier(pretrained_model_name=PRETRAINED_MODEL_NAME, num_classes=3).to(device)

# map_location lets a checkpoint saved on one device be loaded on another
# (e.g. trained on GPU, evaluated on CPU).
state_dict = torch.load(os.path.join(CHECKPOINT_DIR, "best_1.pt"), map_location=device)
model.load_state_dict(state_dict)
model.eval()

test_preds = []
with torch.no_grad():
    with tqdm(test_data_loader, desc="Test", unit="batch") as t:
        for input_ids, attention_mask, token_type_ids in t:
            logits = model(input_ids, attention_mask, token_type_ids)
            # Works for any batch size, unlike `.item()`.
            test_preds.extend(logits.argmax(dim=-1).tolist())

# Invert the dataset's own encoding so training and prediction labels cannot drift apart.
id_to_label = {index: name for name, index in OcnliDataset.label_map.items()}

test_df = pd.read_json(OCNLI_TEST_FILE, lines=True)
test_df["label"] = [id_to_label[pred] for pred in test_preds]
test_df = test_df[["label", "id"]]

os.makedirs(WORK_DIR, exist_ok=True)
predict_json_path = os.path.join(WORK_DIR, "ocnli_50k_predict.json")
predict_zip_path = os.path.join(WORK_DIR, "ocnli_50k_predict.zip")
test_df.to_json(predict_json_path, orient="records", lines=True, force_ascii=False)

# Archive with the standard library instead of shelling out to `zip`: no
# external binary is needed, failures raise, and no path goes through a shell.
# `arcname` stores the file without its directory, matching `zip -j`.
with zipfile.ZipFile(predict_zip_path, "w", compression=zipfile.ZIP_DEFLATED) as archive:
    archive.write(predict_json_path, arcname=os.path.basename(predict_json_path))