In [1]:
import datasets
import numpy as np
import torch
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from transformers import BertModel, BertTokenizer


In [None]:
# Load the dataset published under the Hugging Face Hub id "stanfordnlp/snli".
dataset = datasets.load_dataset("stanfordnlp/snli")
# Display the keys of the loaded object (the "train", "validation" and "test"
# entries are read from it in the next cell).
dataset.keys()

In [3]:
# Bind each split of the loaded dataset to its own name.
train, validation, test = (
    dataset[split_name] for split_name in ("train", "validation", "test")
)

In [None]:
# Load the pretrained BERT tokenizer and encoder used to embed the sentences.
model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)

# Pick the compute device: CUDA first, then Apple MPS, otherwise the CPU.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(device)

model.to(device)
# Show the size of each split.
tuple(len(split) for split in (train, validation, test))

In [None]:
def _label_is(label_id):
    """Return a predicate that keeps examples whose "label" equals ``label_id``."""
    return lambda example: example["label"] == label_id


# Split the test set by label id: 0 -> entailment, 1 -> neutral, 2 -> contradiction.
test_entailment = test.filter(_label_is(0))
test_neutral = test.filter(_label_is(1))
test_contradiction = test.filter(_label_is(2))

# Drop validation examples labelled -1 (presumably examples without a gold
# label -- verify against the dataset card).
validation = validation.filter(lambda example: not example["label"] == -1)

# Show the size of every filtered subset.
tuple(
    len(subset)
    for subset in (test_entailment, test_neutral, test_contradiction, validation)
)

In [7]:
def extract_emb(dataset, batch_size=32):
    """Embed every premise/hypothesis pair with the global BERT ``model``.

    Each example's "premise" and "hypothesis" strings are joined with a single
    space, tokenized with the global ``tokenizer`` (padded per batch, truncated
    to 512 tokens) and passed through ``model`` without gradient tracking. The
    hidden state at sequence position 0 of the last layer is kept as the
    example's embedding.

    Args:
        dataset: Iterable of mappings with "premise" and "hypothesis" strings.
        batch_size: Number of texts encoded per forward pass.

    Returns:
        A CPU ``torch.Tensor`` of shape ``(len(dataset), hidden_size)`` with one
        embedding per example, in input order.
    """
    texts = [row["premise"] + " " + row["hypothesis"] for row in dataset]
    total_samples = len(texts)

    # Take the embedding width from the model config instead of hard-coding
    # 768, so the function keeps working if ``model_name`` is changed to a
    # checkpoint with a different hidden size.
    result = torch.zeros(total_samples, model.config.hidden_size)

    for start in range(0, total_samples, batch_size):
        batch_texts = texts[start : start + batch_size]
        inputs = tokenizer(
            batch_texts,
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=512,
        )
        inputs = {key: value.to(device) for key, value in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)

        # Position 0 of the last hidden layer; moved to the CPU so the result
        # buffer never has to live on the accelerator.
        cls_embeddings = outputs.last_hidden_state[:, 0, :].cpu()
        result[start : start + len(batch_texts), :] = cls_embeddings

    return result


# Embed the three per-label test subsets (in entailment, neutral,
# contradiction order) and then the validation set.
test_entailment_emb, test_neutral_emb, test_contradiction_emb = (
    extract_emb(subset)
    for subset in (test_entailment, test_neutral, test_contradiction)
)

validation_emb = extract_emb(validation)

In [None]:
# Sanity check: display the shape of the contradiction embedding matrix.
test_contradiction_emb.shape


In [None]:
# Reduce the BERT embeddings to 64 dimensions with a single shared projection.
#
# The PCA is fitted ONCE, on the validation embeddings (the data the
# downstream classifiers are trained on), and the same fitted projection is
# then applied to every test subset. Calling fit_transform separately per
# subset would learn a different basis each time, so features of one subset
# would not be comparable with features of another.
pca = PCA(n_components=64, random_state=42)

validation_emb_pca = pca.fit_transform(validation_emb)
# Fraction of the validation variance retained by the 64 components.
print(sum(pca.explained_variance_ratio_))

test_entailment_emb_pca = pca.transform(test_entailment_emb)
test_neutral_emb_pca = pca.transform(test_neutral_emb)
test_contradiction_emb_pca = pca.transform(test_contradiction_emb)

In [None]:
# Sanity check: display the shape of the PCA-reduced entailment embeddings.
test_entailment_emb_pca.shape

In [None]:
# Number of clusters: 2% of the labelled test examples, floor-divided by
# (2 * 3 - 1), truncated to an int.
n_test_labeled = len(test_entailment) + len(test_neutral) + len(test_contradiction)
k_clusters = int((n_test_labeled * 0.02) // (2 * 3 - 1))
k_clusters

In [12]:
# Cluster each label's test embeddings independently. The same estimator is
# refitted for every subset (entailment, neutral, contradiction, in that
# order), so after this cell `kmeans` holds the contradiction fit.
kmeans = KMeans(n_clusters=k_clusters, random_state=42)

test_entailment_labels, test_neutral_labels, test_contradiction_labels = (
    kmeans.fit_predict(class_emb)
    for class_emb in (test_entailment_emb, test_neutral_emb, test_contradiction_emb)
)

In [None]:
# Train an ensemble of small LSTM classifiers on the PCA-reduced validation
# embeddings, keep each member's best-scoring weights, then record every
# member's confidence in the gold label of each test example.

n_models = 5
BATCH_SIZE = 32
N_EPOCHS = 100
N_CLASSES = 3  # label ids 0, 1, 2

# Build the training tensors once instead of re-creating them for every batch.
val_features = torch.as_tensor(validation_emb_pca, dtype=torch.float32).to(device)
val_labels = torch.tensor(validation[:]["label"]).to(device)
n_val = len(val_features)


def _new_lstm():
    """Build one untrained ensemble member on `device`.

    Used both for training and for reloading a snapshot, so the two
    architectures cannot drift apart.
    """
    return torch.nn.LSTM(val_features.shape[1], N_CLASSES, 3).to(device)


def _class_scores(net, features):
    """Return per-sample class scores of shape (batch, N_CLASSES).

    nn.LSTM expects (seq_len, batch, features) by default, so the batch is
    placed on dim 1 and every sample is an independent length-1 sequence.
    Putting the batch on dim 0 would instead treat the whole mini-batch as one
    sequence and let samples influence each other's predictions.
    """
    output, _ = net(features.unsqueeze(0))
    return output.squeeze(0)


lstms = {}

for model_idx in range(n_models):
    lstm = _new_lstm()
    optimizer = torch.optim.Adam(lstm.parameters(), lr=1e-5)
    criterion = torch.nn.CrossEntropyLoss()

    # -inf guarantees that the first epoch always stores a snapshot.
    best_accuracy = float("-inf")

    for epoch in range(N_EPOCHS):
        # Re-enter training mode every epoch: evaluation below switches to
        # eval mode, and cuDNN RNNs refuse to run backward in eval mode.
        lstm.train()
        total_loss = 0
        n_batches = 0
        # The batch offset uses its own name so it cannot clobber `model_idx`.
        for start in range(0, n_val, BATCH_SIZE):
            batch = val_features[start : start + BATCH_SIZE]
            labels = val_labels[start : start + BATCH_SIZE]

            optimizer.zero_grad()
            loss = criterion(_class_scores(lstm, batch), labels)
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
            n_batches += 1

        print(
            f"Model: {model_idx + 1} *** Epoch {epoch} *** Loss: {total_loss / n_batches}"
        )

        # Accuracy on the same (validation) data, without tracking gradients.
        lstm.eval()
        correct = 0
        with torch.no_grad():
            for start in range(0, n_val, BATCH_SIZE):
                batch = val_features[start : start + BATCH_SIZE]
                labels = val_labels[start : start + BATCH_SIZE]
                predicted = _class_scores(lstm, batch).argmax(dim=1)
                correct += (predicted == labels).sum().item()
        accuracy = 100 * correct / n_val

        if accuracy > best_accuracy:
            best_accuracy = accuracy
            # state_dict() returns references to the live parameters; clone
            # them so later optimizer steps cannot overwrite the snapshot.
            lstms[model_idx] = {
                name: tensor.detach().clone()
                for name, tensor in lstm.state_dict().items()
            }

        print(f"Accuracy: {accuracy:.4f}")
    print(f"Best accuracy: {best_accuracy:.4f}")


# confidence[r, m] is model m's softmax probability for the gold label of test
# row r. Rows are aligned with `test`; rows whose label is not 0, 1 or 2 are
# left at 0.
test_true_labels = np.asarray(test[:]["label"])
test_emb_pca_by_class = {
    0: test_entailment_emb_pca,
    1: test_neutral_emb_pca,
    2: test_contradiction_emb_pca,
}

confidence = np.zeros((len(test), n_models))

for model_idx in range(n_models):
    lstm = _new_lstm()
    lstm.load_state_dict(lstms[model_idx])
    lstm.eval()
    with torch.no_grad():
        for class_id, class_emb in test_emb_pca_by_class.items():
            # The per-class subsets were produced by filtering `test`, so the
            # k-th embedding of a class belongs to the k-th test row carrying
            # that label.
            rows = np.flatnonzero(test_true_labels == class_id)
            features = torch.as_tensor(class_emb, dtype=torch.float32)
            assert len(rows) == len(features), "class embeddings out of sync with test"

            for start in range(0, len(rows), BATCH_SIZE):
                batch = features[start : start + BATCH_SIZE].to(device)
                probs = torch.nn.functional.softmax(_class_scores(lstm, batch), dim=1)
                confidence[rows[start : start + BATCH_SIZE], model_idx] = (
                    probs[:, class_id].cpu().numpy()
                )


In [None]:
# Per-cluster sample counts and per-model confidence ranks for every label.

# Number of samples assigned to each cluster id (kept as float arrays).
test_entailment_cnt = np.bincount(test_entailment_labels, minlength=k_clusters).astype(float)
test_neutral_cnt = np.bincount(test_neutral_labels, minlength=k_clusters).astype(float)
test_contradiction_cnt = np.bincount(test_contradiction_labels, minlength=k_clusters).astype(float)

# `confidence` has one row per example of the full test set, whereas the
# cluster assignments only cover one label each. Select a label's rows first,
# then apply that label's cluster mask; masking `confidence` directly with a
# per-label mask is a shape mismatch.
test_true_labels = np.asarray(test[:]["label"])
cluster_labels_by_class = {
    "entailment": (0, test_entailment_labels),
    "neutral": (1, test_neutral_labels),
    "contradiction": (2, test_contradiction_labels),
}

rankings = {
    "entailment": {},
    "neutral": {},
    "contradiction": {},
}

for class_name, (class_id, cluster_labels) in cluster_labels_by_class.items():
    class_confidence = confidence[test_true_labels == class_id]

    for i in range(k_clusters):
        in_cluster = cluster_labels == i
        rankings[class_name][i] = {}

        for j in range(n_models):
            cluster_confidence = class_confidence[in_cluster, j]
            # argsort alone yields the sorting permutation (position -> sample).
            # Applying it twice yields each sample's rank (sample -> position,
            # 0 = lowest confidence), which stays aligned by sample across
            # models. A stable sort makes tie-breaking deterministic.
            order = np.argsort(cluster_confidence, kind="stable")
            rankings[class_name][i][j] = np.argsort(order, kind="stable")


In [None]:
# Combine the per-model rankings of every cluster into one score array per
# (label, cluster): start from cluster_size * n_models and subtract each
# model's ranking array elementwise.
cluster_counts = {
    "entailment": test_entailment_cnt,
    "neutral": test_neutral_cnt,
    "contradiction": test_contradiction_cnt,
}

# Create the per-label containers up front; assigning scores[name][i] into an
# empty top-level dict raises KeyError.
scores = {class_name: {} for class_name in cluster_counts}

for i in range(k_clusters):
    for class_name, counts in cluster_counts.items():
        cluster_score = counts[i] * n_models
        for j in range(n_models):
            # Scalar minus array broadcasts to one score per cluster member.
            cluster_score = cluster_score - rankings[class_name][i][j]
        scores[class_name][i] = cluster_score
