# Initial Prototype

In [6]:
import math

from comp550.dataset import SSTDataset

from tqdm import tqdm

In [7]:
dataset = SSTDataset(cachedir="cache", num_workers=0)
dataset.prepare_data()
dataset.setup("fit")

In [11]:
def _compute_frequency_not_in_class_and_not_in_document(dataloader, token_id, class_, k=0):
    total = k
    frequency = k
    
    for batch in dataset.dataloader('train', shuffle=False):
        for observation in dataset.uncollate(batch):
            if token_id not in observation['sentence'] and observation['label'] != class_:
                frequency += 1
            total += 1
    
    return frequency / total

def _compute_frequency_in_class_and_not_in_document(dataloader, token_id, class_, k=0):
    total = k
    frequency = k
    
    for batch in dataset.dataloader('train', shuffle=False):
        for observation in dataset.uncollate(batch):
            if token_id not in observation['sentence'] and observation['label'] == class_:
                frequency += 1
            total += 1
    
    return frequency / total

def _compute_frequency_not_in_class_and_in_document(dataloader, token_id, class_, k=0):
    total = k
    frequency = k
    
    for batch in dataset.dataloader('train', shuffle=False):
        for observation in dataset.uncollate(batch):
            if token_id in observation['sentence'] and observation['label'] != class_:
                frequency += 1
            total += 1
    
    return frequency / total

def _compute_frequency_in_class_and_in_document(dataloader, token_id, class_, k=0):
    total = k
    frequency = k
    
    for batch in dataset.dataloader('train', shuffle=False):
        for observation in dataset.uncollate(batch):
            if token_id in observation['sentence'] and observation['label'] == class_:
                frequency += 1
            total += 1
    
    return frequency / total

In [17]:
# For prototyping, assume we are only analyzing class 0 in SST (i.e., negative).
class_ = 0

for token_id, token in zip(range(10), tqdm(dataset.tokenizer.token_to_ids)): 
    # Compute probability of all joint random variable configurations.
    p_not_in_class_and_not_in_document = _compute_frequency_not_in_class_and_not_in_document(
        dataset, token_id, class_=class_
    )
    # p_not_in_class_and_not_in_document = 774106 / 801948
    p_in_class_and_not_in_document = _compute_frequency_in_class_and_not_in_document(
        dataset, token_id, class_=class_
    )
    # p_in_class_and_not_in_document = 141 / 801948
    p_not_in_class_and_in_document = _compute_frequency_not_in_class_and_in_document(
        dataset, token_id, class_=class_
    )
    # p_not_in_class_and_in_document = 27652 / 801948
    p_in_class_and_in_document = _compute_frequency_in_class_and_in_document(
        dataset, token_id, class_=class_
    )
    # p_in_class_and_in_document = 49 / 801948

    # Compute marginal distributions.
    p_not_in_document = p_not_in_class_and_not_in_document + p_in_class_and_not_in_document
    p_in_document = p_not_in_class_and_in_document + p_in_class_and_in_document
    
    p_not_in_class = p_not_in_class_and_not_in_document + p_not_in_class_and_in_document
    p_in_class = p_in_class_and_not_in_document + p_in_class_and_in_document
    
    try:
        importance = (
            p_in_class_and_in_document * math.log2(p_in_class_and_in_document / (p_in_class * p_in_document)) +
            p_not_in_class_and_in_document * math.log2(p_not_in_class_and_in_document / (p_not_in_class * p_in_document)) +
            p_in_class_and_not_in_document * math.log2(p_in_class_and_not_in_document / (p_in_class * p_not_in_document)) +
            p_not_in_class_and_not_in_document * math.log2(p_not_in_class_and_not_in_document / (p_not_in_class * p_not_in_document))
        )
    except ZeroDivisionError:
        importance = float('nan')

    print(f"Token: {token}")
    print(f"word = 0, class = 0: {p_not_in_class_and_not_in_document}")
    print(f"word = 0, class = 1: {p_in_class_and_not_in_document}")
    print(f"word = 1, class = 0: {p_not_in_class_and_in_document}")
    print(f"word = 1, class = 1: {p_in_class_and_in_document}")
    print(f"class = 1: {p_in_class}")
    print(f"class = 0: {p_not_in_class}")
    print(f"word = 1: {p_in_document}")
    print(f"word = 0: {p_not_in_document}")
    print(f"Importance: {importance}")
    

  0%|          | 1/13766 [00:01<6:51:48,  1.80s/it]

Token: [PAD]
word = 0, class = 0: 0.5242438060495516
word = 0, class = 1: 0.4757561939504484
word = 1, class = 0: 0.0
word = 1, class = 1: 0.0
class = 1: 0.4757561939504484
Importance: nan


  0%|          | 2/13766 [00:03<5:53:48,  1.54s/it]

Token: [CLS]
word = 0, class = 0: 0.0
word = 0, class = 1: 0.0
word = 1, class = 0: 0.5242438060495516
word = 1, class = 1: 0.4757561939504484
class = 1: 0.4757561939504484
Importance: nan


  0%|          | 3/13766 [00:04<6:19:02,  1.65s/it]

Token: [EOS]
word = 0, class = 0: 0.0
word = 0, class = 1: 0.0
word = 1, class = 0: 0.5242438060495516
word = 1, class = 1: 0.4757561939504484
class = 1: 0.4757561939504484
Importance: nan


  0%|          | 4/13766 [00:06<6:20:44,  1.66s/it]

Token: [MASK]
word = 0, class = 0: 0.5242438060495516
word = 0, class = 1: 0.4757561939504484
word = 1, class = 0: 0.0
word = 1, class = 1: 0.0
class = 1: 0.4757561939504484
Importance: nan


  0%|          | 5/13766 [00:08<6:00:43,  1.57s/it]

Token: [UNK]
word = 0, class = 0: 0.5242438060495516
word = 0, class = 1: 0.4757561939504484
word = 1, class = 0: 0.0
word = 1, class = 1: 0.0
class = 1: 0.4757561939504484
Importance: nan


  0%|          | 6/13766 [00:09<5:48:11,  1.52s/it]

Token: and
word = 0, class = 0: 0.25429396564827483
word = 0, class = 1: 0.28058975528195773
word = 1, class = 0: 0.2699498404012768
word = 1, class = 1: 0.19516643866849065
class = 1: 0.4757561939504484
Importance: 0.007947550128166982


  0%|          | 7/13766 [00:10<5:40:10,  1.48s/it]

Token: a
word = 0, class = 0: 0.25429396564827483
word = 0, class = 1: 0.24912600699194407
word = 1, class = 0: 0.2699498404012768
word = 1, class = 1: 0.22663018695850434
class = 1: 0.47575619395044844
Importance: 0.001071146203936061


  0%|          | 7/13766 [00:12<6:42:28,  1.76s/it]


ValueError: math domain error