In [1]:
import sys
sys.path.append("../code")
from dataset import build_dataset, build_dataloader
from config import DataArguments, TrainerArguments, ModelArguments

In [2]:
data_args = DataArguments()
training_args = TrainerArguments()
model_args = ModelArguments()

data_args.data_dir = "../data/"
data_args.asset_dir = "../assets/"
model_args.asset_dir = data_args.asset_dir
data_args.seed = training_args.seed
data_args.max_seq_len = model_args.max_seq_len

data_args.init_pct = 1

train_dataset, model_args.vocab_size, model_args.num_labels = build_dataset(data_args, "train")
train_dataloader = build_dataloader(train_dataset, data_args)

[11/24/2021 13:37:50] INFO - dataset: Initialize Train Dataset.
[11/24/2021 13:37:50] INFO - dataset: Remove abstract.
[11/24/2021 13:37:50] INFO - dataset: Using cached dataset, wasn't able to remove columns.
[11/24/2021 13:37:50] INFO - dataset: Use the full dataset, for train dataset of total 47250 papers.
[11/24/2021 13:37:50] INFO - dataset: Train dataset was successfully initialized.
[11/24/2021 13:37:50] INFO - dataset: Successfully loaded mapper file ..\assets\area2idx.json
[11/24/2021 13:37:50] INFO - dataset: Use tokenized_paperswithcode dataset.
[11/24/2021 13:37:50] INFO - dataset: Preprocessed dataset. Use default vocab_size=83931.
[11/24/2021 13:37:50] INFO - dataset: Successfully converted dataset to dataloader.


In [None]:
from trainer import NaiveTrainer
trainer = NaiveTrainer()

In [6]:
import torch
import torch.nn as nn

loss_fc = nn.BCEWithLogitsLoss()

In [7]:
loss_fc()

tensor([ 0.,  7., 13., 10.,  7.,  7.,  7., 10.,  2., 15.,  7., 14.,  7., 15.,
        10., 10.,  7.,  4., 14., 14.,  8.,  7.,  3., 14., 12., 10.,  8., 12.,
        10.,  7., 15.,  7.,  9.,  7., 14.,  3., 12., 15.,  4., 15., 14.,  7.,
        10.,  7.,  7.,  7.,  4.,  5.,  4., 12.,  3.,  7., 15.,  3., 14., 12.,
        12., 10.,  7., 12.,  3.,  3., 15.,  7.,  8., 15.,  3.,  7.,  2., 14.,
         3.,  7.,  3., 10., 12.,  3.,  7., 15., 14., 15., 13.,  7.,  4., 12.,
         7.,  7.,  4., 14.,  3., 10.,  8.,  0., 15.,  7.,  3.,  7.,  8.,  7.,
        15.,  7., 15., 14., 15., 15.,  4.,  3., 14., 12., 15.,  3.,  1.,  5.,
         3.,  3., 10.,  4., 10., 10.,  7.,  6., 10.,  7.,  3.,  5.,  6., 10.,
        10., 11.])

In [None]:
loss = loss_fc()

In [110]:
import math
from dataclasses import dataclass
from typing import List

import torch
from toma import toma
from tqdm.auto import tqdm

from batchbald_redux import joint_entropy

K = 20

In [104]:
import numpy as np


def get_mixture_prob_dist(p1, p2, m):
    return (1.0 - m) * np.asarray(p1) + m * np.asarray(p2)


p1 = [0.7, 0.1, 0.1, 0.1]
p2 = [0.3, 0.3, 0.2, 0.2]
y1_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]

p1 = [0.1, 0.7, 0.1, 0.1]
p2 = [0.2, 0.3, 0.3, 0.2]
y2_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]

p1 = [0.1, 0.1, 0.7, 0.1]
p2 = [0.2, 0.2, 0.3, 0.3]
y3_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]

p1 = [0.1, 0.1, 0.1, 0.7]
p2 = [0.3, 0.2, 0.2, 0.3]
y4_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]


def nested_to_tensor(l):
    return torch.stack(list(map(torch.as_tensor, l)))


ys_ws = nested_to_tensor([y1_ws, y2_ws, y3_ws, y4_ws])

In [111]:
def compute_conditional_entropy(probs_N_K_C: torch.Tensor) -> torch.Tensor:
    N, K, C = probs_N_K_C.shape

    entropies_N = torch.empty(N, dtype=torch.double)

    pbar = tqdm(total=N, desc="Conditional Entropy", leave=False)

    @toma.execute.chunked(probs_N_K_C, 1024)
    def compute(probs_n_K_C, start: int, end: int):
        nats_n_K_C = probs_n_K_C * torch.log(probs_n_K_C)
        nats_n_K_C[probs_n_K_C == 0] = 0.0

        entropies_N[start:end].copy_(-torch.sum(nats_n_K_C, dim=(1, 2)) / K)
        pbar.update(end - start)

    pbar.close()

    return entropies_N


def compute_entropy(probs_N_K_C: torch.Tensor) -> torch.Tensor:
    N, K, C = probs_N_K_C.shape

    entropies_N = torch.empty(N, dtype=torch.double)

    pbar = tqdm(total=N, desc="Entropy", leave=False)

    @toma.execute.chunked(probs_N_K_C, 1024)
    def compute(probs_n_K_C, start: int, end: int):
        mean_probs_n_C = probs_n_K_C.mean(dim=1)
        nats_n_C = mean_probs_n_C * torch.log(mean_probs_n_C)
        nats_n_C[mean_probs_n_C == 0] = 0.0

        entropies_N[start:end].copy_(-torch.sum(nats_n_C, dim=1))
        pbar.update(end - start)

    pbar.close()

    return entropies_N

In [109]:
!pip install toma

