# Initial Prototype

In [1]:
import math

from comp550.dataset import SSTDataset

from tqdm import tqdm

In [2]:
dataset = SSTDataset(cachedir="cache", num_workers=4)
dataset.prepare_data()
dataset.setup("fit")

In [3]:
def _compute_frequency_not_in_class_and_not_in_document(dataloader, token_id, class_, k=0):
    total = k
    frequency = k
    
    for batch in dataloader:
        for sentence, label in zip(batch.sentence.tolist(), batch.label.tolist()):
            if token_id not in sentence and label != class_:
                frequency += 1
            total += 1
    
    return frequency

def _compute_frequency_in_class_and_not_in_document(dataloader, token_id, class_, k=0):
    total = k
    frequency = k
    
    for batch in dataloader:
        for sentence, label in zip(batch.sentence.tolist(), batch.label.tolist()):
            if token_id not in sentence and label == class_:
                frequency += 1
            total += 1
    
    return frequency

def _compute_frequency_not_in_class_and_in_document(dataloader, token_id, class_, k=0):
    total = k
    frequency = k
    
    for batch in dataloader:
        for sentence, label in zip(batch.sentence.tolist(), batch.label.tolist()):
            if token_id in sentence and label != class_:
                frequency += 1
            total += 1
    
    return frequency

def _compute_frequency_in_class_and_in_document(dataloader, token_id, class_, k=0):
    total = k
    frequency = k
    
    for batch in dataloader:
        for sentence, label in zip(batch.sentence.tolist(), batch.label.tolist()):
            if token_id in sentence and label == class_:
                frequency += 1
            total += 1
    
    return frequency

In [4]:
# For prototyping, assume we are only analyzing class 0 in SST (i.e., negative).
class_ = 0

for token_id, token in enumerate(tqdm(dataset.tokenizer.token_to_ids)): 
    # Compute probability of all joint random variable configurations.
    p_not_in_class_and_not_in_document = _compute_frequency_not_in_class_and_not_in_document(
        dataset.train_dataloader(shuffle=False), token_id, class_=class_
    )
    # p_not_in_class_and_not_in_document = 774106 / 801948
    p_in_class_and_not_in_document = _compute_frequency_in_class_and_not_in_document(
        dataset.train_dataloader(shuffle=False), token_id, class_=class_
    )
    # p_in_class_and_not_in_document = 141 / 801948
    p_not_in_class_and_in_document = _compute_frequency_not_in_class_and_in_document(
        dataset.train_dataloader(shuffle=False), token_id, class_=class_
    )
    # p_not_in_class_and_in_document = 27652 / 801948
    p_in_class_and_in_document = _compute_frequency_in_class_and_in_document(
        dataset.train_dataloader(shuffle=False), token_id, class_=class_
    )
    # p_in_class_and_in_document = 49 / 801948
    
    # Compute marginal distributions.
    p_not_in_document = p_not_in_class_and_not_in_document + p_in_class_and_not_in_document
    p_in_document = p_not_in_class_and_in_document + p_in_class_and_in_document
    
    p_not_in_class = p_not_in_class_and_not_in_document + p_not_in_class_and_in_document
    p_in_class = p_in_class_and_not_in_document + p_in_class_and_in_document
    
    importance = (
        p_in_class_and_in_document * math.log2(p_in_class_and_in_document / (p_in_class * p_in_document)) +
        p_not_in_class_and_in_document * math.log2(p_not_in_class_and_in_document / (p_not_in_class * p_in_document)) +
        p_in_class_and_not_in_document * math.log2(p_in_class_and_not_in_document / (p_in_class * p_not_in_document)) +
        p_not_in_class_and_not_in_document * math.log2(p_not_in_class_and_not_in_document / (p_not_in_class * p_not_in_document))
    )
    
    print(f"Token: {token}")
    print(f"U = 0, C = 0: {p_not_in_class_and_not_in_document}")
    print(f"U = 0, C = 1: {p_in_class_and_not_in_document}")
    print(f"U = 1, C = 0: {p_not_in_class_and_in_document}")
    print(f"U = 1, C = 1: {p_in_class_and_in_document}")
    print(f"Importance: {importance}")
    print(
        "Total count = ", 
          p_not_in_class_and_not_in_document + 
          p_in_class_and_not_in_document + 
          p_not_in_class_and_in_document + 
          p_in_class_and_in_document
         )
    
    break

  0%|          | 0/13766 [00:01<?, ?it/s]

Token: [PAD]
U = 0, C = 0: 138
U = 0, C = 1: 90
U = 1, C = 0: 3311
U = 1, C = 1: 3040
Importance: -83441.22408864423
Total count =  6579





# Implementation

In [1]:
import math
from collections import Counter

from comp550.dataset import SSTDataset

from tqdm import tqdm

In [2]:
# Load the dataset.
dataset = SSTDataset(cachedir="cache")
dataset.prepare_data()
dataset.setup("fit")

In [23]:
def _compute_token_frequencies(dataloader, tokenizer, class_, in_example_frequency=True, k=1):
    """Computes the frequency of all tokens in a split (e.g., train) for a particular class
    using Laplace smoothing."""
    vocab_token_ids = set(range(len(tokenizer.token_to_ids)))
    
    # Initalize token counts for add-k smoothing.
    frequencies = Counter()
    for token_id in vocab_token_ids:
        frequencies[token_id] = k
    total = k * len(vocab_token_ids)
    
    for batch in dataloader:
        for sentence, label in zip(batch.sentence.tolist(), batch.label.tolist()):
            if label != class_:
                continue
            
            if in_example_frequency:
                frequencies.update(set(sentence))
            else:
                frequencies.update(vocab_token_ids - set(sentence))
            
            total += 1
    
    return frequencies

frequencies = _compute_token_frequencies(
    dataloader=dataset.train_dataloader(shuffle=False), 
    tokenizer=dataset.tokenizer, 
    class_=1
)

In [25]:
# Each possible RV configuration.
configs = [
    (class_, in_example_frequency) 
    for class_ in range(len(dataset.label_names)) 
    for in_example_frequency in [True, False]
]

frequencies = {}
for config in configs:
    class_, in_example_frequency = config
    
    frequencies[(class_, in_example_frequency)] = _compute_token_frequencies(
        dataloader=dataset.train_dataloader(shuffle=False),
        tokenizer=dataset.tokenizer,
        class_=class_,
        in_example_frequency=in_example_frequency
    )