In [None]:
import numpy as np
import pickle

# Load the pre-computed MNLI results consumed by the rest of the notebook.
# NOTE(review): unpickling can execute arbitrary code — only open result files
# you produced yourself or otherwise trust.
with open("all_results_mnli.pkl", "rb") as pkl_file:
    results = pickle.loads(pkl_file.read())

In [None]:
import torch
import torch.nn as nn
import random 

class Ensemble(nn.Module):
    """Linear stacking head that combines per-model scores into a single logit.

    The model is one ``nn.Linear(in_features, 1)`` layer with no hidden layer
    and no activation. Its output is an unnormalised logit: the notebook trains
    it with ``nn.BCEWithLogitsLoss`` and applies ``torch.sigmoid`` at inference.

    Args:
        *args, **kwargs: Accepted and ignored, kept for backward compatibility
            with the original signature.
        in_features: Number of input scores per example (keyword-only).
            Defaults to 3, matching the originally hard-coded layer.
    """

    def __init__(self, *args, in_features: int = 3, **kwargs) -> None:
        super().__init__()
        # The attribute is still named ``fc2`` so previously saved state_dict keys load.
        self.fc2 = nn.Linear(in_features, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` of shape (..., in_features) to logits of shape (..., 1)."""
        return self.fc2(x)

def set_seed(random_seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        random_seed: Integer seed applied to every random number generator.

    Returns:
        A fresh ``torch.Generator`` seeded with ``random_seed``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(random_seed)

    # Give up cuDNN autotuning speed in exchange for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

    generator = torch.Generator()
    return generator.manual_seed(random_seed)


In [None]:
# Stack the pickled results into a float32 tensor. Later cells index it as
# X[experiment, column, example]. Columns 0-2 are fed to the ensemble and
# column 3 is used as the label (presumably probabilities / 0-1 labels, since
# every column is thresholded at 0.5 — TODO confirm against the pickle producer).
X = torch.tensor(results, dtype=torch.float32)
assert X.ndim == 3, f"expected a 3-D results tensor, got shape {tuple(X.shape)}"
assert X.shape[1] >= 4, f"need >= 4 columns (3 inputs + 1 label), got {X.shape[1]}"

In [None]:
# Shuffle the example axis (axis 2) under a fixed seed.
# NOTE(review): this reassigns X, so re-running this cell shuffles the
# already-shuffled tensor again; re-run from the tensor-construction cell
# to get back to a clean state.
set_seed(42)
indices = np.random.permutation(X.shape[2])  # same draw as choice(n, n, replace=False)
X = X[:, :, indices]

# Hold-out split of experiment 0 at 1000 examples.
# NOTE(review): the evaluation loop recomputes X_train / X_test with its own
# training size, so this split is not what the reported metrics use.
X_train, X_test = X[0, :, :1000], X[0, :, 1000:]

In [None]:
from torch.optim import Adam
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

In [None]:
# Repeated random-split evaluation. For each seed: shuffle the example axis,
# then for every experiment fit a fresh stacking head on the first N_TRAIN
# examples and score the three base columns plus the ensemble on the rest.
N_RUNS = 100
N_TRAIN = 500
N_EPOCHS = 500
LEARNING_RATE = 0.01
THRESHOLD = 0.5  # decision threshold for base scores, labels and ensemble probability
METHOD_NAMES = ("Baseline", "crlft", "sequential", "ensemble")
METRICS = {
    "F1": f1_score,
    "Precision": precision_score,
    "Recall": recall_score,
    "Accuracy": accuracy_score,
}


def _train_ensemble(X_train, n_epochs=N_EPOCHS, lr=LEARNING_RATE):
    """Fit a fresh ``Ensemble`` on one training split with full-batch Adam.

    ``X_train.T`` columns 0-2 are the inputs and column 3 is the BCE target.
    Returns ``(model, losses)``, where ``losses`` has one float per epoch.
    """
    features = X_train.T[:, :3]  # loop-invariant: hoisted out of the epoch loop
    targets = X_train.T[:, 3]
    model = Ensemble()
    opt = Adam(model.parameters(), lr=lr)
    loss_fn = nn.BCEWithLogitsLoss()
    losses = []
    for _ in range(n_epochs):
        opt.zero_grad()
        loss = loss_fn(model(features).flatten(), targets)
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return model, losses


def _threshold_predictions(model, X_test, threshold=THRESHOLD):
    """Return one binary prediction array per entry of METHOD_NAMES.

    The first three arrays are the thresholded base columns (bool). The last
    is the thresholded sigmoid of the ensemble logit (0.0/1.0 float).
    """
    scores = X_test.T
    with torch.no_grad():
        ensemble_prob = torch.sigmoid(model(scores[:, :3]).flatten())
    predictions = [(scores[:, col] > threshold).numpy() for col in range(3)]
    predictions.append((ensemble_prob > threshold).float().numpy())
    return predictions


metric_rows = {name: [] for name in METRICS}
all_losses = []

for run in range(N_RUNS):
    # Reset every run: only the final run's loss curves survive for the plot cell.
    all_losses = []
    set_seed(run)
    indices = np.random.permutation(X.shape[2])
    # Permute a per-run copy. Reassigning X here compounded the shuffles across
    # runs and changed every split whenever the cell was re-run.
    X_run = X[:, :, indices]

    for exp in range(X_run.shape[0]):
        X_train = X_run[exp, :, :N_TRAIN]
        X_test = X_run[exp, :, N_TRAIN:]

        ensemble, losses = _train_ensemble(X_train)
        all_losses.append(losses)

        y_true = (X_test.T[:, 3] > THRESHOLD).numpy()
        predictions = _threshold_predictions(ensemble, X_test)
        for name, metric in METRICS.items():
            metric_rows[name].append([metric(y_true, y_hat) for y_hat in predictions])

# Each array has one row per (run, experiment) and one column per METHOD_NAMES entry.
all_f1 = np.array(metric_rows["F1"])
all_precisions = np.array(metric_rows["Precision"])
all_recalls = np.array(metric_rows["Recall"])
all_accuracies = np.array(metric_rows["Accuracy"])

for name, table in zip(METRICS, (all_f1, all_precisions, all_recalls, all_accuracies)):
    print(name)
    for col, method in enumerate(METHOD_NAMES):
        print("{}: {:.4f} +- {:.4f}".format(method, np.mean(table[:, col]), np.std(table[:, col])))

In [None]:
import matplotlib.pyplot as plt

# Training-loss curves, one per experiment. The evaluation cell resets
# `all_losses` at the start of every run, so only the final run's curves appear.
fig, ax = plt.subplots()
for exp, losses in enumerate(all_losses):
    ax.plot(losses, label=f"exp {exp}")
ax.set(
    title="Ensemble training loss (final run)",
    xlabel="epoch",
    ylabel="BCE-with-logits loss",
)
if all_losses:
    ax.legend()
plt.show()