In [None]:
from transformers import BertTokenizer, BertModel, get_linear_schedule_with_warmup
from datasets import load_dataset
from evaluate import load
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# You can install and import any other libraries if needed.

import time
import numpy as np
import random
import os

# Directory the training loop writes its best checkpoint into.
os.makedirs("./saved_models", exist_ok=True)


def set_seed(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch with the same value.

    Also switches cuDNN to its deterministic setting and turns off its
    benchmark mode.

    Args:
        seed: Integer seed handed to every random number generator.
    """
    # cuDNN flags are independent of the seeding calls, so set them up front.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


# Seed from the wall clock so every run differs, but report the value: without
# it a run can never be replayed. The modulo keeps the seed inside the 32-bit
# range that np.random.seed accepts.
run_seed = int(time.time()) % (2**32)
print(f"Using random seed {run_seed}")
set_seed(run_seed)

In [None]:
# Some Chinese punctuation marks are tokenized as [UNK], so each
# [original, replacement] pair below swaps one for its English equivalent.
token_replacement = [
    ["：", ":"],
    ["，", ","],
    ["“", '"'],
    ["”", '"'],
    ["？", "?"],
    ["……", "..."],
    ["！", "!"],
]

In [None]:
# Tokenizer used by `collate_fn`; downloaded files are cached under ./cache/.
tokenizer = BertTokenizer.from_pretrained(
    "google-bert/bert-base-uncased", cache_dir="./cache/"
)

In [None]:
class SemevalDataset(Dataset):
    """In-memory list of records for one split of ``sem_eval_2014_task_1``.

    Indexing returns a record whose ``premise`` and ``hypothesis`` strings have
    had every pair in ``token_replacement`` applied.
    """

    def __init__(self, split="train") -> None:
        """Load ``split`` ("train", "validation" or "test") into ``self.data``."""
        super().__init__()
        assert split in ["train", "validation", "test"]
        self.data = load_dataset(
            "sem_eval_2014_task_1",
            split=split,
            trust_remote_code=True,
            cache_dir="./cache/",
        ).to_list()

    def __getitem__(self, index):
        """Return a cleaned copy of the record at ``index``.

        The stored record is not modified, so ``self.data`` keeps the raw text
        no matter how often an item is fetched.
        """
        # Shallow copy is enough: only the two text fields are rebound below.
        d = dict(self.data[index])
        # Replace Chinese punctuation marks with English ones
        for k in ["premise", "hypothesis"]:
            for tok in token_replacement:
                d[k] = d[k].replace(tok[0], tok[1])
        return d

    def __len__(self):
        """Number of records in the loaded split."""
        return len(self.data)


# Preview the first three training records. They are read straight from
# `.data`, so the punctuation replacement done in `__getitem__` is not applied.
data_sample = SemevalDataset(split="train").data[:3]
print(f"Dataset example: \n{data_sample[0]} \n{data_sample[1]} \n{data_sample[2]}")

In [None]:
# Define the hyperparameters
# You can modify these values if needed
# AdamW settings (the optimizer cell gives AdamW the parameters with ndim < 2).
adamw_lr = 0.00148975071320148
adamw_weight_decay = 0.0661990066204037

# Muon settings (the optimizer cell gives Muon the parameters with ndim >= 2).
muon_lr = 0.000638688467548953
muon_momentum = 0.932638804220204
muon_weight_decay = 0.0841106463026747

# Weight of the consistency loss term in the training loop.
alpha = 0.205310994109178
# Fraction of all training steps used for learning-rate warmup.
warmup_ratio = 0.192282667612615

# lr = 3e-5
epochs = 4
train_batch_size = 16
validation_batch_size = 256  # also used for the test DataLoader

In [None]:
# TODO1: Create batched data for DataLoader
# `collate_fn` is a function that defines how the data batch should be packed.
# This function will be called in the DataLoader to pack the data batch.


def collate_fn(batch):
    """Pack a list of dataset records into one model-ready batch.

    TODO1-1: Implement the collate_fn function.

    Args:
        batch: Sequence of dicts, each with the keys ``sentence_pair_id``,
            ``premise``, ``hypothesis``, ``relatedness_score`` and
            ``entailment_judgment``.

    Returns:
        A dict holding the raw ids and sentences (as lists), the tokenizer's
        ``input_ids`` / ``attention_mask`` / ``token_type_ids`` tensors, and
        the labels for both sub-tasks (float scores, long class ids).
    """

    def column(key):
        # Gather one field across every record of the batch.
        return [record[key] for record in batch]

    pair_ids = column("sentence_pair_id")
    premises = column("premise")
    hypotheses = column("hypothesis")
    scores = column("relatedness_score")
    judgments = column("entailment_judgment")

    # Sentence pairs are encoded together and padded to the longest pair.
    encoded = tokenizer(
        premises,
        hypotheses,
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )

    packed = {
        "sentence_pair_id": pair_ids,
        "premise": premises,
        "hypothesis": hypotheses,
    }
    for field in ("input_ids", "attention_mask", "token_type_ids"):
        packed[field] = encoded[field]
    packed["relatedness_score"] = torch.tensor(scores, dtype=torch.float)
    packed["entailment_judgment"] = torch.tensor(judgments, dtype=torch.long)
    return packed


# TODO1-2: Define your DataLoader
# Training batches are reshuffled every epoch; validation and test keep the
# dataset order.
dl_train = DataLoader(
    SemevalDataset(split="train"),
    batch_size=train_batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    # os.cpu_count() returns None when the count cannot be determined; fall
    # back to one worker instead of letting min() raise a TypeError.
    num_workers=min(4, os.cpu_count() or 1),
)
dl_validation = DataLoader(
    SemevalDataset(split="validation"),
    batch_size=validation_batch_size,
    shuffle=False,
    collate_fn=collate_fn,
)
dl_test = DataLoader(
    SemevalDataset(split="test"),
    batch_size=validation_batch_size,
    shuffle=False,
    collate_fn=collate_fn,
)

In [None]:
# TODO2: Construct your model
class MultiLabelModel(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Write your code here
        # Define what modules you will use in the model
        # Please use "google-bert/bert-base-uncased" model (https://huggingface.co/google-bert/bert-base-uncased)
        # Besides the base model, you may design additional architectures by incorporating linear layers, activation functions, or other neural components.
        # Remark: The use of any additional pretrained language models is not permitted.
        self.bert = BertModel.from_pretrained(
            "google-bert/bert-base-uncased", cache_dir="./cache/"
        )
        hidden_size = self.bert.config.hidden_size

        self.gating = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, 1),
            torch.nn.Sigmoid(),
        )

        self.expert0 = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.RMSNorm(hidden_size),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.RMSNorm(hidden_size),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.RMSNorm(hidden_size),
            torch.nn.SiLU(),
        )

        self.expert1 = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.RMSNorm(hidden_size),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.RMSNorm(hidden_size),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.RMSNorm(hidden_size),
            torch.nn.SiLU(),
        )

        self.regression_head = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.RMSNorm(hidden_size),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.RMSNorm(hidden_size),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.RMSNorm(hidden_size),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden_size, 1),  # [1, 5]
        )

        self.classification_head = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.RMSNorm(hidden_size),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.RMSNorm(hidden_size),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.RMSNorm(hidden_size),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden_size, 3),  # 0, 1, 2
        )

    def forward(self, **kwargs):
        # Write your code here
        # Forward pass

        input_ids = kwargs["input_ids"]
        attention_mask = kwargs["attention_mask"]
        token_type_ids = kwargs["token_type_ids"]

        bert_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        cls_representation = bert_output.last_hidden_state[:, 0, :]

        expert_weight = self.gating(cls_representation)
        expert0_features = self.expert0(cls_representation)
        expert1_features = self.expert1(cls_representation)

        shared_features = (
            expert_weight * expert0_features + (1 - expert_weight) * expert1_features
        )
        regression_output = self.regression_head(shared_features)
        classification_output = self.classification_head(shared_features)

        return {
            "relatedness_score": regression_output,
            "entailment_judgment": classification_output,
        }

In [None]:
# TODO3: Define your optimizer and loss function

model = MultiLabelModel().to(device)

# TODO3-1: Define your Optimizer
# Parameters are split by rank: tensors with two or more dimensions go to Muon,
# everything else goes to AdamW. Both lists walk the sub-modules in the same
# order.
# NOTE(review): the split is purely by ndim, so any 2-D table inside
# `model.bert` is also handed to Muon -- confirm that is intended.
muon_params = [
    p
    for layer in (
        model.bert,
        model.gating,
        model.expert0,
        model.expert1,
        model.regression_head,
        model.classification_head,
    )
    for p in layer.parameters()
    if p.ndim >= 2
]
adamw_params = [
    p
    for layer in (
        model.bert,
        model.gating,
        model.expert0,
        model.expert1,
        model.regression_head,
        model.classification_head,
    )
    for p in layer.parameters()
    if p.ndim < 2
]

# Index 0 is Muon, index 1 is AdamW; the training loop steps both.
optimizer = [
    torch.optim.Muon(
        muon_params,
        lr=muon_lr,
        weight_decay=muon_weight_decay,
        momentum=muon_momentum,
    ),
    torch.optim.AdamW(adamw_params, lr=adamw_lr, weight_decay=adamw_weight_decay),
]

# One scheduler step is taken per training batch, so the horizon is counted in
# batches rather than epochs.
num_training_steps = len(dl_train) * epochs
num_warmup_steps = int(num_training_steps * warmup_ratio)
scheduler = [
    get_linear_schedule_with_warmup(
        optimizer=each_optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
    )
    for each_optimizer in optimizer
]


# TODO3-2: Define your loss functions (you should have two)
# Mean squared error for the relatedness score, cross entropy for the
# entailment class.
criterion_regression = torch.nn.MSELoss()
criterion_classification = torch.nn.CrossEntropyLoss()


def consistency_loss(reg_scores, clf_logits):
    """Penalise disagreement between the regression and classification outputs.

    The result is the sum of two terms:

    * the mean squared error between each predicted score and the score
      expected under the predicted class distribution, and
    * the KL divergence (``batchmean``) between a tabulated class distribution
      for the one-point bucket the predicted score falls in and the predicted
      class distribution.

    Args:
        reg_scores: Predicted scores, shape ``[B]`` or ``[B, 1]``.
        clf_logits: Class logits, shape ``[B, 3]``, with the classes ordered
            (neutral, entailment, contradiction) as in the tables below.

    Returns:
        A scalar tensor.
    """
    # Flatten first: a [B, 1] input would otherwise broadcast against the [B]
    # expectation below and silently turn the MSE into a [B, B] comparison.
    reg_scores = reg_scores.reshape(-1)
    target_device = reg_scores.device

    # --- clf_logits -> expected score ---
    # Per-class mean score: bucket midpoints (1.5 .. 4.5) weighted by hard-coded
    # counts and divided by the class total.
    # NOTE(review): the counts look like label statistics of a dataset split --
    # confirm where they come from.
    E_neutral = (1.5 * 451 + 2.5 * 615 + 3.5 * 1398 + 4.5 * 326) / 2790
    E_entail = (1.5 * 1 + 2.5 * 0 + 3.5 * 65 + 4.5 * 1338) / 1404
    E_contra = (1.5 * 0 + 2.5 * 59 + 3.5 * 496 + 4.5 * 157) / 712
    E_vec = torch.tensor([E_neutral, E_entail, E_contra], device=target_device)

    clf_probs = torch.softmax(clf_logits, dim=1)
    expected_reg_from_clf = (clf_probs * E_vec).sum(dim=1)  # Shape: [B]
    reg_consis_loss = torch.nn.functional.mse_loss(reg_scores, expected_reg_from_clf)

    # --- reg_scores -> expected class distribution ---
    # Same counts, normalised per score bucket instead of per class.
    p_1_2 = torch.tensor([451 / 452.0, 1 / 452.0, 0 / 452.0], device=target_device)
    p_2_3 = torch.tensor([615 / 674.0, 0 / 674.0, 59 / 674.0], device=target_device)
    p_3_4 = torch.tensor(
        [1398 / 1959.0, 65 / 1959.0, 496 / 1959.0], device=target_device
    )
    p_4_5 = torch.tensor(
        [326 / 1821.0, 1338 / 1821.0, 157 / 1821.0], device=target_device
    )

    # The outer buckets are open-ended so every finite score lands in exactly
    # one of them; a score below 1.0 would otherwise get an all-zero target.
    mask_1_2 = (reg_scores < 2.0).unsqueeze(-1)
    mask_2_3 = ((reg_scores >= 2.0) & (reg_scores < 3.0)).unsqueeze(-1)
    mask_3_4 = ((reg_scores >= 3.0) & (reg_scores < 4.0)).unsqueeze(-1)
    mask_4_5 = (reg_scores >= 4.0).unsqueeze(-1)

    expected_clf_from_reg = (
        mask_1_2 * p_1_2 + mask_2_3 * p_2_3 + mask_3_4 * p_3_4 + mask_4_5 * p_4_5
    )  # Shape: [B, 3]

    clf_log_probs = torch.nn.functional.log_softmax(clf_logits, dim=1)
    clf_consis_loss = torch.nn.functional.kl_div(
        clf_log_probs, expected_clf_from_reg, reduction="batchmean"
    )

    return reg_consis_loss + clf_consis_loss


# Scoring functions from `evaluate.load`: Pearson correlation for the
# relatedness scores, accuracy and F1 for the entailment classes.
psr = load("pearsonr")
acc = load("accuracy")
f1_metric = load("f1")

In [None]:
# Start below any reachable score so the first finite validation result always
# writes a checkpoint.
best_score = float("-inf")
for ep in range(epochs):
    pbar = tqdm(dl_train)
    pbar.set_description(f"Training epoch [{ep+1}/{epochs}]")
    model.train()
    # TODO4: Write the training loop
    for batch in pbar:
        # Move tensors to the target device; id and sentence lists stay as-is.
        batch = {
            k: v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v
            for k, v in batch.items()
        }
        optimizer[0].zero_grad()
        optimizer[1].zero_grad()

        outputs = model(**batch)

        # squeeze(-1) drops only the trailing size-1 dimension, so a batch with
        # a single example keeps shape [1] instead of collapsing to 0-d.
        reg_pred = outputs["relatedness_score"].squeeze(-1)

        loss_reg = criterion_regression(reg_pred, batch["relatedness_score"])
        loss_clf = criterion_classification(
            outputs["entailment_judgment"], batch["entailment_judgment"]
        )

        # NOTE(review): `ep` runs from 0 to epochs - 1, so with epochs = 4 this
        # branch is never taken and the consistency loss is unused -- confirm
        # whether the threshold is intentional.
        if ep > 3:
            consis_loss = consistency_loss(reg_pred, outputs["entailment_judgment"])
            loss = (1 - alpha) * (loss_reg + loss_clf) + alpha * consis_loss
        else:
            loss = 0.5 * (loss_reg + loss_clf)

        loss.backward()

        # Clip the global gradient norm before either optimizer steps.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer[0].step()
        optimizer[1].step()
        scheduler[0].step()
        scheduler[1].step()

        pbar.set_postfix(loss=loss.item())

    pbar = tqdm(dl_validation)
    pbar.set_description(f"Validation epoch [{ep+1}/{epochs}]")
    model.eval()
    # TODO5: Write the evaluation loop
    # Predictions are collected over the whole split so each metric is computed
    # once per epoch rather than averaged over batches.
    all_reg_preds = []
    all_reg_targets = []
    all_clf_preds = []
    all_clf_targets = []
    with torch.no_grad():
        for batch in pbar:
            batch = {
                k: v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v
                for k, v in batch.items()
            }
            outputs = model(**batch)

            # squeeze(-1) keeps a one-example batch iterable (shape [1]).
            reg_pred = outputs["relatedness_score"].squeeze(-1)
            all_reg_preds.extend(reg_pred.cpu().tolist())
            all_reg_targets.extend(batch["relatedness_score"].cpu().tolist())

            clf_pred = torch.argmax(outputs["entailment_judgment"], dim=1)
            all_clf_preds.extend(clf_pred.cpu().tolist())
            all_clf_targets.extend(batch["entailment_judgment"].cpu().tolist())

    pearson_corr = psr.compute(
        predictions=all_reg_preds, references=all_reg_targets
    )["pearsonr"]
    accuracy = acc.compute(
        predictions=all_clf_preds, references=all_clf_targets
    )["accuracy"]
    f1_macro = f1_metric.compute(
        predictions=all_clf_preds, references=all_clf_targets, average="macro"
    )["f1"]
    f1_weighted = f1_metric.compute(
        predictions=all_clf_preds, references=all_clf_targets, average="weighted"
    )["f1"]

    # Model selection uses the plain average of the two task metrics.
    combined_score = 0.5 * pearson_corr + 0.5 * accuracy
    print(
        f"Epoch {ep+1}: Pearson={pearson_corr}, Accuracy={accuracy}, Macro-F1={f1_macro}, Weighted-F1={f1_weighted}, Combine={combined_score}"
    )

    if combined_score > best_score:
        best_score = combined_score
        torch.save(model.state_dict(), "./saved_models/best_model.ckpt")

In [None]:
# Load the model
model = MultiLabelModel().to(device)
# map_location lets a checkpoint written on a GPU be restored on a CPU-only
# machine (and the other way round).
model.load_state_dict(
    torch.load(
        "./saved_models/best_model.ckpt", weights_only=True, map_location=device
    )
)

# Test Loop
pbar = tqdm(dl_test, desc="Test")
model.eval()

# TODO6: Write the test loop
# Predictions are collected over the whole test split and scored once.
all_reg_preds = []
all_reg_targets = []
all_clf_preds = []
all_clf_targets = []
with torch.no_grad():
    for batch in pbar:
        # Move tensors to the target device; id and sentence lists stay as-is.
        batch = {
            k: v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v
            for k, v in batch.items()
        }
        outputs = model(**batch)

        # squeeze(-1) drops only the trailing size-1 dimension, so a batch with
        # a single example keeps shape [1] and stays iterable.
        reg_pred = outputs["relatedness_score"].squeeze(-1)
        all_reg_preds.extend(reg_pred.cpu().tolist())
        all_reg_targets.extend(batch["relatedness_score"].cpu().tolist())

        clf_pred = torch.argmax(outputs["entailment_judgment"], dim=1)
        all_clf_preds.extend(clf_pred.cpu().tolist())
        all_clf_targets.extend(batch["entailment_judgment"].cpu().tolist())

pearson_corr = psr.compute(
    predictions=all_reg_preds, references=all_reg_targets
)["pearsonr"]
accuracy = acc.compute(
    predictions=all_clf_preds, references=all_clf_targets
)["accuracy"]
f1_macro = f1_metric.compute(
    predictions=all_clf_preds, references=all_clf_targets, average="macro"
)["f1"]
f1_weighted = f1_metric.compute(
    predictions=all_clf_preds, references=all_clf_targets, average="weighted"
)["f1"]

# Same combined score as the one used for checkpoint selection.
combined_score = 0.5 * pearson_corr + 0.5 * accuracy
print(
    f"\nTest: Pearson={pearson_corr}, Accuracy={accuracy}, Macro-F1={f1_macro}, Weighted-F1={f1_weighted}, Combine={combined_score}"
)