In [1]:
import datasets
import numpy as np
import torch
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from transformers import BertModel, BertTokenizer


In [2]:
load_embeddings = True
load_models = True


In [3]:
dataset = datasets.load_dataset("stanfordnlp/snli")
dataset.keys()


In [4]:
train = dataset["train"]
validation = dataset["validation"]
test = dataset["test"]

train = train.filter(lambda x: x["label"] != -1)
validation = validation.filter(lambda x: x["label"] != -1)
test = test.filter(lambda x: x["label"] != -1)

In [5]:
test_entailment_ids = np.where(np.array(test["label"]) == 0)[0]
test_neutral_ids = np.where(np.array(test["label"]) == 1)[0]
test_contradiction_ids = np.where(np.array(test["label"]) == 2)[0]


In [6]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(device)


In [7]:
model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)
batch_size = 256


model.to(device)
len(train), len(validation), len(test)


In [8]:
def extract_emb(dataset, _batch_size=batch_size):
    texts = [data["premise"] + " " + data["hypothesis"] for data in dataset]
    total_samples = len(texts)

    result = torch.zeros(total_samples, 768)

    for i in range(0, total_samples, _batch_size):
        batch_texts = texts[i : i + _batch_size]
        inputs = tokenizer(
            batch_texts,
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=512,
        )
        inputs = {key: value.to(device) for key, value in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
        cls_embeddings = outputs.last_hidden_state[:, 0, :].cpu()

        result[i : i + len(batch_texts), :] = cls_embeddings

    return result


In [9]:
# Load embeddings
if load_embeddings:
    # test_entailment_emb = torch.from_numpy(np.load("test_entailment_emb.npy"))
    # test_neutral_emb = torch.from_numpy(np.load("test_neutral_emb.npy"))
    # test_contradiction_emb = torch.from_numpy(np.load("test_contradiction_emb.npy"))
    validation_emb = torch.from_numpy(np.load("validation_emb.npy"))
    train_emb = torch.from_numpy(np.load("train_emb.npy"))
    test_emb = torch.from_numpy(np.load("test_emb.npy"))
else:
    # test_entailment_emb = extract_emb(test_entailment)
    # test_neutral_emb = extract_emb(test_neutral)
    # test_contradiction_emb = extract_emb(test_contradiction)

    validation_emb = extract_emb(validation)
    train_emb = extract_emb(train)
    test_emb = extract_emb(test)

    # Save embeddings
    # np.save("test_entailment_emb.npy", test_entailment_emb.numpy())
    # np.save("test_neutral_emb.npy", test_neutral_emb.numpy())
    # np.save("test_contradiction_emb.npy", test_contradiction_emb.numpy())
    np.save("validation_emb.npy", validation_emb.numpy())
    np.save("train_emb.npy", train_emb.numpy())
    np.save("test_emb.npy", test_emb.numpy())


In [10]:
pca = PCA(n_components=64)

# test_entailment_emb_pca = pca.fit_transform(test_entailment_emb)
# # print(sum(pca.explained_variance_ratio_))
# test_neutral_emb_pca = pca.fit_transform(test_neutral_emb)
# # print(sum(pca.explained_variance_ratio_))
# test_contradiction_emb_pca = pca.fit_transform(test_contradiction_emb)
# # print(sum(pca.explained_variance_ratio_))

validation_emb_pca = pca.fit_transform(validation_emb)
# print(sum(pca.explained_variance_ratio_))
train_emb_pca = pca.fit_transform(train_emb)
test_emb_pca = pca.fit_transform(test_emb)


test_entailment_emb_pca = test_emb_pca[test_entailment_ids]
test_neutral_emb_pca = test_emb_pca[test_neutral_ids]
test_contradiction_emb_pca = test_emb_pca[test_contradiction_ids]

In [11]:
test_entailment_emb_pca.shape

In [12]:
n_models = 5

models = {}


class Classifier(torch.nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()
        self.gru = torch.nn.GRU(64, 64, 10, bidirectional=True)
        self.dropout = torch.nn.Dropout(0.1)
        self.fc1 = torch.nn.Linear(128, 64)
        self.activation = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(64, 3)

    def forward(self, x):
        x, _ = self.gru(x)
        x = self.dropout(x)
        x = self.fc1(x)
        x = self.activation(x)
        x = self.fc2(x)
        return x


if load_models:
    for i in range(n_models):
        models[i] = torch.load(f"model_{i}.pt", map_location=device)
else:
    for i in range(n_models):
        model_ = Classifier()
        model_.to(device)
        optimizer = torch.optim.AdamW(model_.parameters(), lr=1e-5)
        criterion = torch.nn.CrossEntropyLoss()

        best_accuracy = 0
        best_loss = 10000

        early_stop_cnt = 0
        prev_acc = 0

        for epoch in range(100):
            train_loss = 0
            eval_loss = 0
            model_.train()
            for j in range(0, len(train_emb_pca), batch_size):
                batch = train_emb_pca[j : j + batch_size]
                batch = torch.tensor(batch, dtype=torch.float32).to(device)
                labels = torch.tensor(train[j : j + batch_size]["label"]).to(device)

                optimizer.zero_grad()
                outputs = model_(batch)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()

            print(
                f"Model: {i + 1} *** Epoch {epoch} *** Loss: {train_loss / (len(train_emb_pca) / batch_size)}"
            )

            correct = 0
            total = 0
            model.eval()

            with torch.no_grad():
                for j in range(0, len(validation_emb_pca), batch_size):
                    batch = validation_emb_pca[j : j + batch_size]
                    batch = torch.tensor(batch, dtype=torch.float32).to(device)
                    labels = torch.tensor(validation[j : j + batch_size]["label"]).to(
                        device
                    )

                    output = model_(batch)
                    loss = criterion(output, labels)
                    eval_loss += loss.item()
                    _, predicted = torch.max(output.data, 1)
                    total += labels.size(0)
                    # print(predicted)
                    # print(output.shape)
                    # print(np.argmax(predicted.cpu().detach().numpy(), axis=1).shape)
                    correct += (
                        (
                            predicted.cpu().detach().numpy()
                            == labels.cpu().detach().numpy()
                        )
                        .sum()
                        .item()
                    )

            # if eval_loss / (len(validation_emb_pca) / batch_size) < best_loss:
            #     best_loss = eval_loss / (len(validation_emb_pca) / batch_size)
            #     models[i] = model_.state_dict()
            acc = 100 * correct / total
            if acc <= prev_acc:
                early_stop_cnt += 1
            else:
                prev_acc = acc
                early_stop_cnt = 0

            if acc > best_accuracy:
                best_accuracy = acc
                models[i] = model_.state_dict()

            print(
                f"Model: {i + 1} *** Epoch {epoch} *** Eval Loss: {eval_loss / (len(validation_emb_pca) / batch_size)}"
            )
            print(f"Accuracy: {100 * correct / total:.4f}")
            if early_stop_cnt >= 5:
                break
        print(f"Best acc: {best_accuracy}")

    for i in range(n_models):
        torch.save(models[i], f"model_{i}.pt")


In [13]:
confidence = np.zeros((len(test), n_models))


for i in range(n_models):
    model_ = Classifier()
    model_.load_state_dict(models[i])
    model_.to(device)
    model_.eval()
    with torch.no_grad():
        for j in range(0, len(test), batch_size):
            batch = test_emb_pca[j : j + batch_size]
            batch = torch.tensor(batch, dtype=torch.float32).to(device)

            labels = torch.tensor(test[j : j + batch_size]["label"])

            output = torch.nn.functional.softmax(model_(batch), dim=1)

            num_out = output.shape[0]

            row_indices = torch.arange(num_out)

            confidence[j : j + num_out, i] = (
                output[row_indices, labels[:num_out]].cpu().detach().numpy()
            )


In [15]:
k_clusters = int(
    sum((len(test_entailment_ids), len(test_entailment_ids), len(test_entailment_ids)))
    * 0.02
    // (2 * 3 - 1)
)

kmeans = KMeans(n_clusters=k_clusters, random_state=42, n_init="auto")

test_entailment_labels = kmeans.fit_predict(test_entailment_emb_pca)
test_neutral_labels = kmeans.fit_predict(test_neutral_emb_pca)
test_contradiction_labels = kmeans.fit_predict(test_contradiction_emb_pca)


In [16]:
test_entailment_cnt = np.zeros(k_clusters)
test_neutral_cnt = np.zeros(k_clusters)
test_contradiction_cnt = np.zeros(k_clusters)

for i in range(k_clusters):
    test_entailment_cnt[i] = np.sum(test_entailment_labels == i)
    test_neutral_cnt[i] = np.sum(test_neutral_labels == i)
    test_contradiction_cnt[i] = np.sum(test_contradiction_labels == i)


In [41]:
rankings = {
    "entailment": {},
    "neutral": {},
    "contradiction": {},
}

for i in range(k_clusters):
    rankings["entailment"][i] = {}
    rankings["neutral"][i] = {}
    rankings["contradiction"][i] = {}

    for j in range(n_models):
        # print(confidence[np.where(test_entailment_labels == i), j][0])
        # print(rankings["entailment"][i][j])
        # print(len(test_entailment_ids[test_entailment_labels == i]))
        # print(confidence[test_entailment_ids[test_entailment_labels == i], j])
        rankings["entailment"][i][j] = confidence[
            test_entailment_ids[test_entailment_labels == i], j
        ]
        rankings["entailment"][i][j] = np.argsort(rankings["entailment"][i][j])

        rankings["neutral"][i][j] = confidence[
            test_neutral_ids[test_neutral_labels == i], j
        ]
        rankings["neutral"][i][j] = np.argsort(rankings["neutral"][i][j])

        rankings["contradiction"][i][j] = confidence[
            test_contradiction_ids[test_contradiction_labels == i], j
        ]
        rankings["contradiction"][i][j] = np.argsort(rankings["contradiction"][i][j])


In [46]:
scores = {
    "entailment": {},
    "neutral": {},
    "contradiction": {},
}

for i in range(k_clusters):
    scores["entailment"][i] = test_entailment_cnt[i] * n_models
    scores["neutral"][i] = test_neutral_cnt[i] * n_models
    scores["contradiction"][i] = test_contradiction_cnt[i] * n_models
    for j in range(n_models):
        scores["entailment"][i] -= rankings["entailment"][i][j]
        scores["neutral"][i] -= rankings["neutral"][i][j]
        scores["contradiction"][i] -= rankings["contradiction"][i][j]


In [50]:
scores["entailment"][0]

array([157., 110., 171., 376., 305., 359., 244., 233., 269., 238., 238.,
       310., 271., 204., 225., 239., 151., 224., 244., 254., 294., 260.,
       179., 242., 301., 226., 340., 205., 320., 290., 364., 132., 249.,
       275., 285., 316., 200., 306., 264., 266., 351., 299., 265., 407.,
       313., 244., 266., 282., 297., 421., 315., 176., 266., 343., 247.,
       292., 304., 286., 225., 186., 243., 323., 233., 256., 374., 187.,
       285., 429., 259., 418., 335., 399., 282., 260., 251., 391., 144.,
       265., 161., 327.,  96., 157., 338., 274., 244., 225., 227., 439.,
       329., 322., 263., 279., 294., 211., 262., 311., 295., 387., 121.,
       233., 148., 148., 359., 136., 296., 218.])