In [1]:
import datasets
import numpy as np
import torch
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from transformers import BertModel, BertTokenizer


In [2]:
dataset = datasets.load_dataset("stanfordnlp/snli")
dataset.keys()

dict_keys(['test', 'validation', 'train'])

In [3]:
train = dataset["train"]
validation = dataset["validation"]
test = dataset["test"]

In [4]:
model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)

if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(device)

model.to(device)
len(train), len(validation), len(test)

mps


(550152, 10000, 10000)

In [5]:
val_entailment = validation.filter(lambda x: x["label"] == 0)
val_neutral = validation.filter(lambda x: x["label"] == 1)
val_contradiction = validation.filter(lambda x: x["label"] == 2)

len(val_entailment), len(val_neutral), len(val_contradiction)

(3329, 3235, 3278)

In [6]:
def extract_emb(dataset, batch_size=32):
    texts = [data["premise"] + " " + data["hypothesis"] for data in dataset]
    total_samples = len(texts)

    result = torch.zeros(total_samples, 768)

    for i in range(0, total_samples, batch_size):
        batch_texts = texts[i : i + batch_size]
        inputs = tokenizer(
            batch_texts,
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=512,
        )
        inputs = {key: value.to(device) for key, value in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
        cls_embeddings = outputs.last_hidden_state[:, 0, :].cpu()

        result[i : i + len(batch_texts), :] = cls_embeddings

    return result


val_entailment_emb = extract_emb(val_entailment)
val_neutral_emb = extract_emb(val_neutral)
val_contradiction_emb = extract_emb(val_contradiction)

In [7]:
val_contradiction_emb.shape


torch.Size([3278, 768])

In [8]:
pca = PCA(n_components=64)

val_entailment_emb_pca = pca.fit_transform(val_entailment_emb)
print(sum(pca.explained_variance_ratio_))
val_neutral_emb_pca = pca.fit_transform(val_neutral_emb)
print(sum(pca.explained_variance_ratio_))
val_contradiction_emb_pca = pca.fit_transform(val_contradiction_emb)
print(sum(pca.explained_variance_ratio_))

0.7980432920938139
0.7975950572671607
0.8020781574625703
0.8020781574625703


In [9]:
val_entailment_emb_pca.shape

(3329, 64)

In [10]:
k_clusters = int(
    sum((len(val_entailment), len(val_neutral), len(val_contradiction)))
    * 0.02
    // (2 * 3 - 1)
)
k_clusters

39

In [11]:
kmeans = KMeans(n_clusters=k_clusters, random_state=42)

val_entailment_labels = kmeans.fit_predict(val_entailment_emb)
val_neutral_labels = kmeans.fit_predict(val_neutral_emb)
val_contradiction_labels = kmeans.fit_predict(val_contradiction_emb)

In [12]:
def get_borda_score(labels):
    # TODO: 根据borda ranking计算borda score
    raise NotImplementedError


# TODO: 选出每个类中每个簇的borda score较低的样本，视为outliers
