# Word attributes

### Definitions

<strong>Shannon information</strong>: amount of information gained when an event with some probability of occurring actually occurs; the less likely the event, the more information it carries<br>
* Mathematically: for some token $ x_i $ in a sequence $ X = \langle x_1, x_2, ... \rangle $ and its associated probability $ p(x_i) $, the information content of the token is given by $$ h(x_i) = -\log_{2}{p(x_i)} \text{ bits} $$<br>
* A token with $ p = 1 $ yields an information content of $ 0 \text{ bits} $: no new information is gained.<br>
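
A minimal sketch of the information-content formula using only Python's standard library. The function name `information_content` is illustrative and does not come from any library.

```python
import math

def information_content(p: float, base: float = 2.0) -> float:
    """Shannon information h(x) = -log_base p(x) of an event with probability p."""
    if not 0.0 < p <= 1.0:
        raise ValueError("p must be in (0, 1]")
    # log(1/p) rather than -log(p), so that p = 1 gives 0.0 instead of -0.0
    return math.log(1.0 / p, base)

# A certain event carries no information. Each halving of p adds one bit.
for p in (1.0, 0.5, 0.25, 0.01):
    print(f"p = {p:<5} -> h = {information_content(p):.3f} bits")
```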

<strong>Shannon entropy</strong>: average number of bits per symbol needed to represent or transmit messages from a source without losing any data (by Shannon's source coding theorem, no lossless code can do better on average)<br>
* Mathematically: the entropy of a distribution $ P $ is given by the expected information content $$ H(X) = -\sum\limits_{x \in \mathcal{X}} {P(x) \log_2{P(x)}} \text{ bits} $$ where $ X $ is a discrete random variable that takes values in the alphabet $ \mathcal{X} $ and is distributed according to $ P : \mathcal{X} \rightarrow [0, 1] $
	* Note: in machine learning, the natural logarithm $ \ln $ is typically used rather than $ \log_{2} $ (as in information theory); quantities are then measured in nats, with $ 1 \text{ nat} = 1/\ln 2 \approx 1.443 \text{ bits} $
* <strong>Entropy is a function of a single distribution</strong> $ P $
* Equivalent to <strong>average information content</strong>
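
As a small illustration of the entropy formula, here is a standard-library sketch. The function name and the example distributions are illustrative. Terms with $ P(x) = 0 $ are skipped, following the usual convention $ 0 \log 0 = 0 $.

```python
import math

def entropy(probs, base=2.0):
    """H(P) = -sum_x P(x) log_base P(x); outcomes with P(x) = 0 contribute nothing."""
    total = sum(probs)
    if not math.isclose(total, 1.0, rel_tol=1e-9):
        raise ValueError(f"probabilities must sum to 1, got {total}")
    return -sum(p * math.log(p, base) for p in probs if p > 0)

fair_coin = [0.5, 0.5]
biased_coin = [0.9, 0.1]
uniform_8 = [1 / 8] * 8

print(entropy(fair_coin))               # 1 bit
print(entropy(biased_coin))             # less than 1 bit: outcomes are more predictable
print(entropy(uniform_8))               # log2(8) = 3 bits, up to floating-point rounding
print(entropy(fair_coin, base=math.e))  # ln 2 nats, the unit used by PyTorch losses
```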


<strong>Cross-entropy:</strong> expected information content of outcomes drawn from the <em>true</em> distribution $ P $ but scored (encoded) with the <em>estimated</em> distribution $ Q $
* Plain English: Imagine we're sending encoded messages whose underlying data is drawn from a data-generating distribution $ P $ (true distribution), but we use an encoding scheme optimized for an estimated distribution $ Q $. The cross-entropy is the expected length of a message encoded according to $ Q $ but drawn according to $ P $.
* Mathematically: $$ H(P, Q) = -\sum\limits_{x \in \mathcal{X}} P(x) \log_{2}{Q(x)} \text{ bits} $$
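* Note: $ H(P) \leq H(P, Q) $, with equality if and only if $ Q = P $ (Gibbs' inequality). In other words, the cross-entropy of a model is bounded below by the true entropy of the language. The gap $ H(P, Q) - H(P) = D_{KL}(P \,\|\, Q) $ is the Kullback-Leibler divergence.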
* Expected message length according to $ Q $ but drawn from $ P $
* <strong>Cross-entropy is thus a function of both $ P $ and $ Q $</strong>
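
A standard-library sketch that makes the inequality concrete. The two toy distributions are made up, and `cross_entropy` here is a local helper, not a library function. If $ Q(x) = 0 $ for an outcome with $ P(x) > 0 $, the cross-entropy is infinite and `math.log` raises an error.

```python
import math

def cross_entropy(p, q, base=2.0):
    """H(P, Q) = -sum_x P(x) log_base Q(x): expected code length when outcomes come
    from P but are encoded with a code optimized for Q."""
    return -sum(px * math.log(qx, base) for px, qx in zip(p, q) if px > 0)

def entropy(p, base=2.0):
    # Entropy is the special case H(P) = H(P, P)
    return cross_entropy(p, p, base)

p_true = [0.5, 0.25, 0.25]   # data-generating distribution P
q_model = [0.25, 0.25, 0.5]  # mismatched estimate Q

h_p = entropy(p_true)                   # 1.5 bits
h_pq = cross_entropy(p_true, q_model)   # 1.75 bits
kl = h_pq - h_p                         # D_KL(P || Q) = 0.25 bits, never negative
print(h_p, h_pq, kl)
```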

<strong>Perplexity</strong>: measures the degree of uncertainty of a model in predicting (i.e. assigning probabilities to) text
* Mathematically: exponentiated cross-entropy (average negative log-likelihood) between the data and the model's predictions $$ \text{PPL}(X) = \exp\left(-\frac{1}{t} \sum\limits_{i=1}^{t} \log{p_{\theta}(x_i \mid x_{< i})}\right) $$
where $ \log{p_{\theta}(x_i \mid x_{< i})} $ is the (natural) log-likelihood of the $i$-th token conditioned on the preceding tokens $ x_{< i} $ according to our model, and $ t $ is the number of tokens in $ X $
* Intuition: a perplexity of $ k $ means the model is, on average, as uncertain as if it were choosing uniformly among $ k $ tokens at each step (lower is better)
* Note: the tokenization procedure has a direct impact on a model's perplexity, so perplexities are only directly comparable between models that share a tokenizer
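
A short PyTorch sketch of the perplexity formula. It assumes a 1-D tensor of natural-log token probabilities, like the ones computed later in this notebook, and the function name is illustrative. A model that is uniform over $ V $ tokens has perplexity exactly $ V $. This is why perplexity is often read as an effective number of equally likely choices per token.

```python
import math
import torch

def perplexity(token_log_probs: torch.Tensor) -> torch.Tensor:
    """exp(mean negative log-likelihood), where token_log_probs[i] = ln p_theta(x_i | x_<i)."""
    return torch.exp(-token_log_probs.mean())

# A model that is uniform over V tokens assigns ln(1/V) to every token
V, t = 50, 10
uniform_log_probs = torch.full((t,), -math.log(V))
print(perplexity(uniform_log_probs))  # approximately 50

# A perfectly confident, always-correct model (p = 1 everywhere) has perplexity 1
print(perplexity(torch.zeros(t)))
```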

In [None]:
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = 'all'

import warnings
warnings.filterwarnings("ignore")

In [None]:
# Ensure that necessary external libraries are installed
!pip install datasets
!pip install matplotlib
!pip install numpy
!pip install seaborn
!pip install torch
!pip install transformers
!pip install tqdm

### <strong>Language modeling</strong> with GPT2 and WikiText-2

In [None]:
import torch
from datasets import load_dataset
from transformers import GPT2TokenizerFast, GPT2LMHeadModel

# Load model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: { device.upper() }')
model_checkpoint = 'gpt2'
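# Note: GPT2's tokenizer does not add BOS/EOS tokens itself; they are added manually in the loop below
tokenizer = GPT2TokenizerFast.from_pretrained(
	model_checkpoint
)
# The model stays on the CPU (its default device); inputs below are created on model.device to match
model = GPT2LMHeadModel.from_pretrained(
	model_checkpoint
)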

# Load dataset
test = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')['text'] # Wikitext
print(f'Number of texts: { len(test) }')

In [None]:
from tqdm import tqdm

max_input_length = model.config.n_positions # GPT2
print(f'GPT2 model maximum input length: { max_input_length }')

# Not used below: each text is truncated to a single window of max_input_length tokens.
# In strided (sliding-window) evaluation, each scored token would have at least max_input_length - stride context tokens
# (provided that many preceding tokens exist); context-only positions are labelled -100 so that they're ignored by the loss
stride = 512

all_out = list()
num_examples = 4

for text in tqdm(test[:num_examples]):
	tmp = dict()
	
	if len(text.split()) < 10: # Ignore short texts
		continue

	# Convert text to model input IDs
	encodings = tokenizer(
		text=text,
		max_length=max_input_length-2, # subtract 2 because we want to add BOS and EOS tokens
		truncation=True,
		return_offsets_mapping=True
	)
	input_ids = torch.tensor(
		data=[tokenizer.bos_token_id]+encodings['input_ids']+[tokenizer.eos_token_id],
		device=model.device
	)

	with torch.no_grad(): # Suppress gradient calculation to speed up computation
		output = model(
			input_ids=input_ids,
			labels=input_ids
		)
	
	# Logits are unnormalized log-probabilities over the vocabulary: softmax(logits[i]) is the model's next-token distribution at position i
	# shift logits: drop the last position, since the distribution predicted after EOS has no target token
	logits_shifted = output['logits'][..., :-1, :].contiguous() # input ids x vocabulary

	# labels = ground truths taken from input_ids
	# logits[i] is the distribution over token i+1 conditioned on tokens 0..i
	# (as consistent with language modeling: what is the probability of token x_t given x_<t?)
	# => shift labels: drop position 0 so that logits[0] is scored against input_ids[1], logits[1] against input_ids[2], etc.
	labels_shifted = input_ids[..., 1:].contiguous() # input ids x 1

	# Here, multi-class cross entropy equals the NLL of the observed token (the true distribution P is one-hot) and is calculated in log-e (nats)
	# CE loss: H(P, Q) = -\sum(x in X) P(x) \ln{Q(x)}, where Q(x) is the estimated distribution
	# Note: this calculation is equivalent to calculating -log_probs, then for each token, selecting the column corresponding to its vocab index (see sanity check below)
	ce_loss = torch.nn.functional.cross_entropy( # input ids x 1
		input=logits_shifted.view(-1, logits_shifted.size(-1)), # Predicted unnormalized logits
		target=labels_shifted.view(-1), # Ground truth class indices (not one-hot)
		reduction='none' # No reduction applied to cross entropy output
	)
	
	# Sanity check: mean CE loss should be close to automatic loss calculation
	loss = output['loss']
	# print(f'\nMean cross entropy: { torch.mean(ce_loss) }')
	# print(f'Automatic model loss calculation: { loss }')
	assert torch.isclose(
		input=torch.mean(ce_loss), # Loss defined as mean NLL
		other=loss
	)
	
	# Softmax to normalize logits
	probs = logits_shifted.softmax(dim=-1) # input ids x vocabulary
	token_probs = torch.tensor([prob[vocab_idx] for vocab_idx, prob in zip(labels_shifted, probs)]) # input ids x 1
	
	# Sanity check: probabilities across each encoding (row) sums to 1
	for idx, vocab_token in enumerate(probs):
		assert torch.isclose(
			input=torch.sum(vocab_token),
			other=torch.tensor(1.),
			rtol=0.01
		)
	# Sanity check: equivalently, sum of all elements in probabilities tensor should equal number of tokens
	sum_probs = torch.sum(probs, dtype=torch.float32)
	len_labels = torch.tensor([labels_shifted.size()[0]], dtype=torch.float32)
	# print(f'Sum of all elements in probabilities tensor: { sum_probs }')
	# print(f'Number of tokens: { len_labels }')
	assert torch.isclose(
		input=sum_probs,
		other=len_labels,
		rtol=0.01
	)

	# Sanity check: NLL of token at token index equals CE loss
	nlls = -1 * logits_shifted.log_softmax(dim=-1)
	ics = torch.tensor([nll[vocab_idx] for vocab_idx, nll in zip(labels_shifted, nlls)])
	assert torch.allclose(
		input=ics,
		other=ce_loss
	)

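	# Entropy (a.k.a. expected IC) at each position: -sum_{v in vocab} p(v) log p(v), i.e. probs * nlls summed along each row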
	per_token_entropy = probs * nlls # input ids x vocabulary
	per_token_entropy = torch.sum(per_token_entropy, dim=1) # input ids x 1, dim=1 means sum along each row

	# Sanity check: number of entropies calculated equals number of tokens minus 1 (shifted logits and labels)
	# print(f'Expected: { torch.tensor(input_ids.size()) - torch.tensor([1]) }')
	# print(f'Actual: { per_token_entropy.size() }\n')
	assert torch.tensor(per_token_entropy.size()) == torch.tensor(input_ids.size()) - torch.tensor([1])

	tmp['text'] = text
	tmp['total_tokens'] = labels_shifted.size()[0]
	tmp['tensor_input'] = input_ids
	tmp['logits'] = logits_shifted
	tmp['labels'] = labels_shifted
	tmp['probs'] = token_probs # probs but at the correct indices => input ids x 1
	tmp['nlls'] = nlls # input ids x vocabulary
	tmp['ics'] = ics # nlls but at the correct indices (equivalently ce_loss and per token NLL) => input ids x 1
	tmp['per_token_entropy'] = per_token_entropy

	all_out.append(tmp)
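
The `stride` value above is not used: each text is simply truncated to one window. Strided evaluation of longer texts works differently. As in the Hugging Face perplexity guide, it slides a window of `max_input_length` tokens forward by `stride` tokens and scores only the tokens that no earlier window has scored, labelling context-only positions -100. Below is a minimal sketch of that bookkeeping. `sliding_windows` and `window_labels` are hypothetical helper names, and the model call itself is omitted.

```python
import torch

def sliding_windows(seq_len: int, max_length: int, stride: int):
    """Yield (begin, end, target_len) for strided evaluation of a sequence of seq_len tokens.
    Each window covers tokens [begin, end); only its last target_len tokens are new targets,
    the rest are context that an earlier window already scored."""
    if stride > max_length:
        raise ValueError("stride must not exceed max_length, or some tokens are never scored")
    prev_end = 0
    for begin in range(0, seq_len, stride):
        end = min(begin + max_length, seq_len)
        target_len = end - prev_end
        yield begin, end, target_len
        prev_end = end
        if end == seq_len:
            break

def window_labels(input_ids: torch.Tensor, begin: int, end: int, target_len: int) -> torch.Tensor:
    """Labels for one window: context-only positions are set to -100 so the loss ignores them."""
    labels = input_ids[begin:end].clone()
    labels[:-target_len] = -100
    return labels

# Toy example: 10 tokens, window of 4, stride of 2
for begin, end, target_len in sliding_windows(10, max_length=4, stride=2):
    print(begin, end, window_labels(torch.arange(10), begin, end, target_len).tolist())
```

Each token is a target in exactly one window. A smaller `stride` gives every target more context at the cost of more forward passes.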

### Analysis

In [None]:
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import seaborn as sns
sns.set() # Set sns as default style

In [None]:
print(f'Number of texts processed: { len(all_out) }')

all_token_probs = torch.cat([text['probs'] for text in all_out])
all_ics = torch.cat([text['ics'] for text in all_out])
all_entropies = torch.cat([text['per_token_entropy'] for text in all_out])
assert all_ics.size() == all_entropies.size()

In [None]:
# Perplexity
mean_nll = torch.mean(all_ics)
perplexity = torch.exp(mean_nll)
print(f'Perplexity: { perplexity }')
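
The losses above are computed with the natural logarithm, so `all_ics`, `all_entropies` and `mean_nll` are in nats. To report them in bits, as in the definitions at the top, divide by ln 2. Perplexity comes out the same either way. A small sketch with stand-in NLL values (not notebook results):

```python
import math
import torch

def nats_to_bits(x: torch.Tensor) -> torch.Tensor:
    # 1 nat = 1 / ln(2) bits (about 1.443 bits)
    return x / math.log(2)

# Stand-in per-token NLLs in nats; in this notebook they would come from all_ics
nll_nats = torch.tensor([0.7, 2.3, 4.6, 0.1])
nll_bits = nats_to_bits(nll_nats)

ppl_from_nats = torch.exp(nll_nats.mean())  # e ** (mean NLL in nats)
ppl_from_bits = 2 ** nll_bits.mean()        # 2 ** (mean NLL in bits): the same value
print(f'Mean NLL: {nll_nats.mean():.3f} nats = {nll_bits.mean():.3f} bits')
print(f'Perplexity: {ppl_from_nats:.3f} vs {ppl_from_bits:.3f}')
```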

In [None]:
#### NLL & ENTROPY
## Density plot
fig1, (ics, entropies) = plt.subplots(1, 2, sharex=True, sharey=True);
fig1.suptitle('Token attributes');

# Information content (NLL)
sns.kdeplot(
	ax=ics,
	data=all_ics,
	bw_method=0.5,
	color='green',
	fill=True,
	label='IC'
);
ics.set(title='Negative log likelihood');
ics.set(xlabel='Nats'); # cross_entropy / log_softmax use the natural log, not log2
ics.set_xlim(left=0);

# Entropy (expected IC)
sns.kdeplot(
	ax=entropies,
	data=all_entropies,
	bw_method=0.5,
	color='blue',
	fill=True,
	label='Entropy'
);
entropies.set(title='Entropy');
entropies.set(xlabel='Nats');
entropies.set_xlim(left=0);
entropies.xaxis.set_major_locator(ticker.MultipleLocator(2));
entropies.xaxis.set_major_formatter(ticker.ScalarFormatter());

plt.show();


#### DIFF(NLL, ENTROPY)
diff_ic_ent = torch.sub(all_ics, all_entropies);
# density_diff_ic_ent = sns.kdeplot(
# 	data=diff_ic_ent,
# 	bw_method=1,
# 	color='red',
# 	fill=True,
# 	label='diff'
# 	# clip=(0, 100)
# )
# density_diff_ic_ent.set(
# 	title='diff(IC, entropy)',
# 	xlabel='Bits',
# 	ylabel='Density'
# )

# abs(diff)
abs_diff_ic_ent = torch.abs(diff_ic_ent);

density_abs_diff_ic_ent = sns.kdeplot(
	data=abs_diff_ic_ent,
	bw_method=1,
	color='orange',
	fill=True,
	label='abs. diff'
	# clip=(0, 100)
);

# sq(diff)
sq_diff_ic_ent = torch.square(diff_ic_ent)
density_sq_diff_ic_ent = sns.kdeplot(
	data=sq_diff_ic_ent,
	bw_method=1,
	color='yellow',
	fill=True,
	label='sq. diff'
	# clip=(0, 100)
);

xmax = torch.max(sq_diff_ic_ent).item() + 1;
density_sq_diff_ic_ent.set_title('Deviation of IC from expected IC');
density_sq_diff_ic_ent.legend(loc=0);
density_sq_diff_ic_ent.set(xlim=(0, xmax));
density_sq_diff_ic_ent.set(xlabel='Nats (abs. diff) / nats² (sq. diff)');

plt.show();

print(f'Mean abs diff: { torch.mean(abs_diff_ic_ent) }');
print(f'Median abs diff: { torch.median(abs_diff_ic_ent) }');
print(f'Max abs diff: { torch.max(abs_diff_ic_ent) }');
print(f'Min abs diff: { torch.min(abs_diff_ic_ent) }');
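
The entropy at a position is the expected IC under the model's own predictive distribution. If tokens were drawn from that distribution, the deviation IC - entropy would therefore average out to roughly zero. On real text, its mean equals the average NLL minus the average entropy. A positive mean suggests the observed tokens are, on average, more surprising than the model itself expects. Below is a minimal check of the first claim with a toy distribution (random logits, not model output):

```python
import torch

torch.manual_seed(0)
V = 20
log_q = torch.log_softmax(torch.randn(V), dim=-1)  # toy predictive distribution Q (log-probs)
entropy_q = -(log_q.exp() * log_q).sum()           # H(Q) = expected IC under Q, in nats

# Draw many tokens from Q itself and measure their information content -ln Q(x)
samples = torch.multinomial(log_q.exp(), num_samples=200_000, replacement=True)
sampled_ics = -log_q[samples]

# When tokens really come from Q, the mean deviation IC - H(Q) is close to 0
mean_deviation = (sampled_ics - entropy_q).mean()
print(f'H(Q) = {entropy_q:.4f} nats, mean IC = {sampled_ics.mean():.4f} nats, '
      f'mean deviation = {mean_deviation:.4f}')
```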

### <strong>Abstractive summarization</strong> with T5 and CNN DailyMail

In [None]:
import torch
from datasets import load_dataset
from transformers import T5TokenizerFast, T5ForConditionalGeneration


# Load model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: { device.upper() }')
model_checkpoint = 't5-small'
tokenizer = T5TokenizerFast.from_pretrained(model_checkpoint)
model = T5ForConditionalGeneration.from_pretrained(model_checkpoint)

# Load dataset
cnn_dm_test = load_dataset('cnn_dailymail', '1.0.0', split='test') # load the test split once
test_articles = cnn_dm_test['article'] # articles
test_summaries = cnn_dm_test['highlights'] # reference summaries
print(f'Number of texts: { len(test_articles) }')

In [None]:
import matplotlib.pyplot as plt


len_articles = [len(tokenizer.encode(article)) for article in test_articles]
len_summaries = [len(tokenizer.encode(summary)) for summary in test_summaries]

fig, axes = plt.subplots(1, 2, figsize=(10, 3.5), sharey=True);
axes[0].hist(len_articles, bins=20, color="C0", edgecolor="C0");
axes[0].set_title('Article Token Length');
axes[0].set_xlabel('Length (Tokens)');
axes[0].set_ylabel('Count');
axes[1].hist(len_summaries, bins=20, color="C0", edgecolor="C0");
axes[1].set_title('Summary Token Length');
axes[1].set_xlabel('Length (Tokens)');
plt.tight_layout();
plt.show();
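
Articles longer than the tokenizer's `model_max_length` are cut by `truncation=True` in the loop below, so the model never sees part of those articles. A small helper can quantify this from the lengths computed above. The helper name is hypothetical, and the limit for `t5-small` is assumed to be 512 (check `tokenizer.model_max_length`).

```python
def truncation_report(lengths, limit):
    """Return (count, fraction) of sequences longer than `limit` tokens, i.e. the ones
    that truncation to `limit` tokens would cut."""
    if not lengths:
        return 0, 0.0
    n_over = sum(1 for n in lengths if n > limit)
    return n_over, n_over / len(lengths)

# Toy lengths; in this notebook: truncation_report(len_articles, tokenizer.model_max_length)
print(truncation_report([100, 600, 512, 513], limit=512))  # (2, 0.5)
```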

In [None]:
from tqdm import tqdm


# Not used in this cell: each article is truncated to a single encoder window.
# (In strided evaluation of a causal LM, overlapping context tokens are labelled -100 so that they're ignored by the loss)
stride = 512

all_out = list()
num_examples = 10

for article, ref_summary in tqdm(zip(test_articles[:num_examples], test_summaries[:num_examples]), total=num_examples):
	# Per-example results
	tmp = dict()

	# Prepare article for summarization per T5 config
	article = 'summarize: ' + article

	# Convert source text to model input IDs
	encodings_source = tokenizer(
		text=article,
		truncation=True,
		return_offsets_mapping=True
	)
	# T5's tokenizer already appends EOS (</s>), so no extra EOS is added here
	input_ids = torch.tensor( # input ids (including EOS token) x 1
		data=encodings_source['input_ids'],
		device=model.device
	)
	input_ids = input_ids.unsqueeze(dim=0) # T5 model needs tensor([[]]), encodings alone return tensor([])
	print(f'Input IDs (encodings + EOS): { input_ids.size() }')

	# Convert target text to model decoder input IDs and labels
	encodings_target = tokenizer(
		text=ref_summary,
		truncation=True,
		return_offsets_mapping=True,
	)
	# The tokenizer also appends EOS (</s>) to the reference summary
	labels = torch.tensor( # input ids (including EOS token) x 1
		data=encodings_target['input_ids'],
		device=model.device
	)
	labels = labels.unsqueeze(dim=0) # T5 model needs tensor([[]]), encodings alone return tensor([])
	print(f'Labels: { labels.size() }')

	with torch.no_grad():
		output = model( # not using model.generate since it doesn't support loss calculation
			input_ids=input_ids,
			labels=labels
		)
	
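	# Logits are unnormalized log-probabilities over the vocabulary (softmax gives the predicted token distribution)
	# Unlike LM, don't need to shift logits for AS: given only `labels`, the model builds its decoder inputs by shifting the labels right, so logits[i] already predicts labels[i]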
	logits = output['logits'].squeeze()

	# labels = ground truths taken from reference summary
	# thus, don't need to shift labels
	labels = labels.contiguous().squeeze(dim=0)

	# NLL = multi-class cross entropy (in this case, since the true distribution is one-hot on the observed token) --> calculated in log-e (nats)
	# CE loss: H(P, Q) = -\sum(x in X) P(x) \ln{Q(x)}, where Q(x) is the estimated distribution
	# Note: this calculation is equivalent to calculating -log_probs, then for each token, selecting the column corresponding to its vocab index (see sanity check below)
	# Equivalent to IC: higher P(w) --> lower IC (relatively little info gained)
	ce_loss = torch.nn.functional.cross_entropy( # input ids x 1
		input=logits.view(-1, logits.size(-1)), # Predicted unnormalized logits
		target=labels.view(-1), # Ground truth class indices (not one-hot)
		reduction='none' # No reduction applied to cross entropy output
	)

	# Sanity check: mean NLL should be close to automatic loss calculation
	# print(f'Mean NLL: { torch.mean(ce_loss) }')
	# print(f'Automatic loss calculation: { output["loss"] }')
	assert torch.isclose(
		input=torch.mean(ce_loss),
		other=output['loss'],
		rtol=0.1
	)

	# Softmax to normalize logits
	probs = logits.softmax(dim=-1) # input ids x vocabulary
	print(f'Probs { probs.size() }:\n{ probs }')
	token_probs = torch.tensor([prob[vocab_idx] for vocab_idx, prob in zip(labels, probs)]) # input ids x 1
	print(f'Token probs { token_probs.size() }:\n{ token_probs }')

	# Sanity check: probabilities across each encoding (row) sums to 1
	for idx, vocab_token in enumerate(probs):
		assert torch.isclose(
			input=torch.sum(vocab_token),
			other=torch.tensor(1.),
			rtol=0.01
		)
	# Sanity check: equivalently, sum of all elements in probabilities tensor should equal number of tokens
	sum_probs = torch.sum(probs, dtype=torch.float32)
	len_labels = torch.tensor([labels.size()[0]], dtype=torch.float32)
	# print(f'Sum of all elements in probabilities tensor: { sum_probs }')
	# print(f'Number of tokens: { len_labels }')
	assert torch.isclose(
		input=sum_probs,
		other=len_labels,
		rtol=0.01
	)

	# Sanity check: NLL of token at token index equals CE loss
	nlls = -1 * logits.log_softmax(dim=-1) # input ids x vocabulary
	ics = torch.tensor([nll[vocab_idx] for vocab_idx, nll in zip(labels, nlls)]) # input ids x 1
	assert torch.allclose(
		input=ics,
		other=ce_loss
	)
	
	# Entropy (a.k.a. expected IC): -sum_{v in vocab} p(v) log p(v) at each token position (sum along row)
	per_token_entropy = probs * nlls # input ids x vocabulary
	per_token_entropy = torch.sum(per_token_entropy, dim=1) # input ids x 1, dim=1 means sum along each row (token)

	# Sanity check: number of entropies calculated equals number of summary tokens
	print(f'Expected: { torch.tensor(labels.size()) }')
	print(f'Actual: { torch.tensor(per_token_entropy.size()) }\n')
	assert torch.tensor(per_token_entropy.size()) == torch.tensor(labels.size())
	
	tmp['article'] = article
	tmp['ref_summary'] = ref_summary
	tmp['tensor_input'] = input_ids
	tmp['logits'] = logits
	tmp['labels'] = labels
	tmp['probs'] = token_probs # probs but at the correct indices => input ids x 1
	tmp['nlls'] = nlls # input ids x vocabulary
	tmp['ics'] = ics # nlls but at the correct indices (equivalently ce_loss and per token NLL) => input ids x 1
	tmp['per_token_entropy'] = per_token_entropy

	all_out.append(tmp)
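
When only `labels` is passed, Hugging Face encoder-decoder models such as T5 and BART create `decoder_input_ids` by shifting the labels one position to the right and prepending a decoder start token. This is why the summarization logits are not shifted. The sketch below re-implements that idea for illustration only. The helper name is hypothetical, T5's decoder start token is its pad token (ID 0), and the toy label IDs are made up.

```python
import torch

def shift_right(labels: torch.Tensor, decoder_start_token_id: int) -> torch.Tensor:
    """Teacher-forcing decoder inputs: prepend the start token and drop the last label, so the
    decoder sees labels[:i] at position i and its logits[i] are scored against labels[i]."""
    start = torch.full_like(labels[..., :1], decoder_start_token_id)
    return torch.cat([start, labels[..., :-1]], dim=-1)

toy_labels = torch.tensor([[37, 1712, 19, 1]])  # made-up summary IDs ending in T5's EOS (ID 1)
toy_decoder_inputs = shift_right(toy_labels, decoder_start_token_id=0)
print(toy_decoder_inputs.tolist())  # [[0, 37, 1712, 19]]
```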

### Analysis

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
sns.set()

In [None]:
print(f'Number of articles summarized: { len(all_out) }')

all_token_probs = torch.cat([text['probs'] for text in all_out])
all_ics = torch.cat([text['ics'] for text in all_out])
all_entropies = torch.cat([text['per_token_entropy'] for text in all_out])
assert all_ics.size() == all_entropies.size()

In [None]:
# Perplexity
mean_nll = torch.mean(all_ics)
perplexity = torch.exp(mean_nll)
print(f'Perplexity: { perplexity }')

In [None]:
#### NLL & ENTROPY
## Density plot
fig1, (ics, entropies) = plt.subplots(1, 2, sharex=False, sharey=False);
fig1.suptitle('Token attributes');

# Information content (NLL)
sns.kdeplot(
	ax=ics,
	data=all_ics,
	bw_method=0.5,
	color='green',
	fill=True,
	label='IC'
);
ics.set(title='Negative log likelihood');
ics.set(xlabel='Nats'); # cross_entropy / log_softmax use the natural log, not log2

# Entropy (expected IC)
sns.kdeplot(
	ax=entropies,
	data=all_entropies,
	bw_method=0.5,
	color='blue',
	fill=True,
	label='Entropy'
);
entropies.set(title='Entropy');
entropies.set(xlabel='Nats');
# entropies.xaxis.set_major_locator(ticker.MultipleLocator(2));
# entropies.xaxis.set_major_formatter(ticker.ScalarFormatter());

plt.show();

# abs(diff)
diff_ic_ent = torch.sub(all_ics, all_entropies);

abs_diff_ic_ent = torch.abs(diff_ic_ent);

density_abs_diff_ic_ent = sns.kdeplot(
	data=abs_diff_ic_ent,
	bw_method=1,
	color='orange',
	fill=True,
	label='abs. diff'
	# clip=(0, 100)
);

# sq(diff)
sq_diff_ic_ent = torch.square(diff_ic_ent)
density_sq_diff_ic_ent = sns.kdeplot(
	data=sq_diff_ic_ent,
	bw_method=1,
	color='yellow',
	fill=True,
	label='sq. diff'
	# clip=(0, 100)
);

xmax = torch.max(sq_diff_ic_ent).item();
density_sq_diff_ic_ent.set_title('Deviation of IC from expected IC');
density_sq_diff_ic_ent.legend(loc=0);
density_sq_diff_ic_ent.set(xlim=(0, xmax));
density_sq_diff_ic_ent.set(xlabel='Nats (abs. diff) / nats² (sq. diff)');

plt.show();

print(f'Mean abs diff: { torch.mean(abs_diff_ic_ent) }');
print(f'Median abs diff: { torch.median(abs_diff_ic_ent) }');
print(f'Max abs diff: { torch.max(abs_diff_ic_ent) }');
print(f'Min abs diff: { torch.min(abs_diff_ic_ent) }');

### <strong>Abstractive summarization</strong> with BART and CNN DailyMail
* Note: BART tokenizer automatically adds BOS and EOS tokens, so there is no need to add them manually to the model inputs/labels (as done in the LM implementation)

In [None]:
import torch
from datasets import load_dataset
from transformers import BartTokenizerFast, BartForConditionalGeneration, BartConfig


# Load model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: { device.upper() }')
model_checkpoint = 'facebook/bart-large-cnn'
tokenizer = BartTokenizerFast.from_pretrained(model_checkpoint)
if tokenizer.pad_token is None:
	tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model = BartForConditionalGeneration.from_pretrained(model_checkpoint)

# Load dataset
cnn_dm_test = load_dataset('cnn_dailymail', '1.0.0', split='test') # load the test split once
test_articles = cnn_dm_test['article'] # articles
test_summaries = cnn_dm_test['highlights'] # reference summaries
print(f'Number of texts: { len(test_articles) }')

In [None]:
import matplotlib.pyplot as plt


len_articles = [len(tokenizer.encode(article)) for article in test_articles]
len_summaries = [len(tokenizer.encode(summary)) for summary in test_summaries]

fig, axes = plt.subplots(1, 2, figsize=(10, 3.5), sharey=True);
axes[0].hist(len_articles, bins=20, color="C0", edgecolor="C0");
axes[0].set_title('Article Token Length');
axes[0].set_xlabel('Length (Tokens)');
axes[0].set_ylabel('Count');
axes[1].hist(len_summaries, bins=20, color="C0", edgecolor="C0");
axes[1].set_title('Summary Token Length');
axes[1].set_xlabel('Length (Tokens)');
plt.tight_layout();
plt.show();

In [None]:
from tqdm import tqdm


max_input_length = model.config.max_position_embeddings # BART
print(f'BART model maximum INPUT length: { max_input_length }')
max_target_length = 100 # manually set
assert max_target_length <= max_input_length
print(f'Manually set maximum TARGET length: { max_target_length }')

# Not used in this cell: each article is truncated to a single encoder window.
# (In strided evaluation of a causal LM, overlapping context tokens are labelled -100 so that they're ignored by the loss)
stride = 512

all_out = list()
num_examples = 10

for article, ref_summary in tqdm(zip(test_articles[:num_examples], test_summaries[:num_examples]), total=num_examples):
	# Per-example results
	tmp = dict()

	# Convert source text to model input IDs
	encodings_source = tokenizer(
		text=article,
		max_length=max_input_length, # the limit includes the BOS and EOS tokens that the tokenizer adds
		truncation=True,
		return_offsets_mapping=True
		# padding=True
	)
	# no need to add BOS and EOS since tokenizer does that already
	input_ids = torch.tensor(encodings_source['input_ids'], device=model.device)
	input_ids = input_ids.unsqueeze(dim=0) # BART model needs tensor([[]]), encodings alone return tensor([])
	print(f'Input IDs (BOS + encodings + EOS): { input_ids.size() }')

	# Convert target text to model decoder input IDs and labels
	encodings_target = tokenizer(
		text=ref_summary,
		max_length=max_target_length, # the limit includes the BOS and EOS tokens that the tokenizer adds
		truncation=True,
		return_offsets_mapping=True,
	)
	label_ids = torch.tensor(encodings_target['input_ids'], device=model.device)
	label_ids = label_ids.unsqueeze(dim=0)
	print(f'Label IDs (BOS + encodings + EOS): { label_ids.size() }')

	# decoder_input_ids = torch.tensor(
	# 	data=[tokenizer.bos_token_id]+encodings_target['input_ids'],
	# 	device=model.device
	# )
	# decoder_input_ids = decoder_input_ids.unsqueeze(dim=0)
	# print(f'Decoder inputs: { decoder_input_ids.size() }')
	# labels = torch.tensor(
	# 	data=encodings_target['input_ids']+[tokenizer.eos_token_id],
	# 	device=model.device
	# )
	# labels = labels.unsqueeze(dim=0)
	# print(f'Labels: { labels.size() }')

	with torch.no_grad():
		output = model( # not using model.generate since it doesn't support loss calculation
			input_ids=input_ids,
			# decoder_input_ids=decoder_input_ids,
			labels=label_ids
		)
	
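	# Logits are unnormalized log-probabilities over the vocabulary (softmax gives the predicted token distribution)
	# Unlike LM, don't need to shift logits for AS: given only `labels`, the model builds its decoder inputs by shifting the labels right, so logits[i] already predicts labels[i]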
	logits = output['logits'].squeeze()

	# labels = ground truths taken from reference summary
	# thus, don't need to shift labels
	label_ids = label_ids.contiguous().squeeze(dim=0)

	# NLL = multi-class cross entropy (in this case, since the true distribution is one-hot on the observed token) --> calculated in log-e (nats)
	# CE loss: H(P, Q) = -\sum(x in X) P(x) \ln{Q(x)}, where Q(x) is the estimated distribution
	# Note: this calculation is equivalent to calculating -log_probs, then for each token, selecting the column corresponding to its vocab index (see sanity check below)
	# Equivalent to IC: higher P(w) --> lower IC (relatively little info gained)
	ce_loss = torch.nn.functional.cross_entropy( # input ids x 1
		input=logits.view(-1, logits.size(-1)), # Predicted unnormalized logits
		target=label_ids.view(-1), # Ground truth class indices (not one-hot)
		reduction='none' # No reduction applied to cross entropy output
	)

	# Sanity check: mean NLL should be close to automatic loss calculation
	# print(f'Mean NLL: { torch.mean(ce_loss) }')
	# print(f'Automatic loss calculation: { output["loss"] }')
	assert torch.isclose(
		input=torch.mean(ce_loss),
		other=output['loss'],
		rtol=0.1
	)

	# Softmax to normalize logits
	probs = logits.softmax(dim=-1) # input ids x vocabulary
	print(f'Probs { probs.size() }:\n{ probs }')
	token_probs = torch.tensor([prob[vocab_idx] for vocab_idx, prob in zip(label_ids, probs)]) # input ids x 1
	print(f'Token probs { token_probs.size() }:\n{ token_probs }')

	# Sanity check: probabilities across each encoding (row) sums to 1
	for idx, vocab_token in enumerate(probs):
		assert torch.isclose(
			input=torch.sum(vocab_token),
			other=torch.tensor(1.),
			rtol=0.01
		)
	# Sanity check: equivalently, sum of all elements in probabilities tensor should equal number of tokens
	sum_probs = torch.sum(probs, dtype=torch.float32)
	len_labels = torch.tensor([label_ids.size()[0]], dtype=torch.float32)
	# print(f'Sum of all elements in probabilities tensor: { sum_probs }')
	# print(f'Number of tokens: { len_labels }')
	assert torch.isclose(
		input=sum_probs,
		other=len_labels,
		rtol=0.01
	)

	# Sanity check: NLL of token at token index equals CE loss
	nlls = -1 * logits.log_softmax(dim=-1) # input ids x vocabulary
	ics = torch.tensor([nll[vocab_idx] for vocab_idx, nll in zip(label_ids, nlls)]) # input ids x 1
	assert torch.allclose(
		input=ics,
		other=ce_loss
	)
	
	# Entropy (a.k.a. expected IC): -sum_{v in vocab} p(v) log p(v) at each token position (sum along row)
	per_token_entropy = probs * nlls # input ids x vocabulary
	per_token_entropy = torch.sum(per_token_entropy, dim=1) # input ids x 1, dim=1 means sum along each row (token)

	# Sanity check: number of entropies calculated equals number of summary tokens
	print(f'Expected: { torch.tensor(label_ids.size()) }')
	print(f'Actual: { torch.tensor(per_token_entropy.size()) }\n')
	assert torch.tensor(per_token_entropy.size()) == torch.tensor(label_ids.size())
	
	tmp['article'] = article
	tmp['ref_summary'] = ref_summary
	tmp['tensor_input'] = input_ids
	tmp['logits'] = logits
	tmp['labels'] = label_ids
	tmp['probs'] = token_probs # probs but at the correct indices => input ids x 1
	tmp['nlls'] = nlls # input ids x vocabulary
	tmp['ics'] = ics # nlls but at the correct indices (equivalently ce_loss and per token NLL) => input ids x 1
	tmp['per_token_entropy'] = per_token_entropy

	all_out.append(tmp)
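
Each loop above gathers the per-token IC with Python list comprehensions. Here is a vectorized alternative, sketched under one assumption: `logits` is a `(T, V)` tensor already aligned with a `(T,)` tensor of target IDs, as in the summarization loops. For the causal-LM loops, pass `logits[:-1]` and `input_ids[1:]`. The helper name is hypothetical, and random logits stand in for model output.

```python
import torch
import torch.nn.functional as F

def token_ic_and_entropy(logits: torch.Tensor, targets: torch.Tensor):
    """Per-position information content (NLL) and entropy, both in nats.
    logits: (T, V) unnormalized scores with logits[i] predicting targets[i]; targets: (T,) IDs."""
    log_probs = F.log_softmax(logits, dim=-1)                       # ln Q(v) at each position
    ics = -log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)  # -ln Q(target)
    entropies = -(log_probs.exp() * log_probs).sum(dim=-1)          # -sum_v Q(v) ln Q(v)
    return ics, entropies

# Toy check with random scores (stand-ins for model output)
torch.manual_seed(0)
toy_logits = torch.randn(6, 30)
toy_targets = torch.randint(0, 30, (6,))
toy_ics, toy_entropies = token_ic_and_entropy(toy_logits, toy_targets)
print(torch.allclose(toy_ics, F.cross_entropy(toy_logits, toy_targets, reduction='none')))
```

`gather` picks the target column for every position in one call, which avoids a Python-level loop over positions.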

### Analysis

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
sns.set()

In [None]:
print(f'Number of articles summarized: { len(all_out) }')

all_token_probs = torch.cat([text['probs'] for text in all_out])
all_ics = torch.cat([text['ics'] for text in all_out])
all_entropies = torch.cat([text['per_token_entropy'] for text in all_out])
assert all_ics.size() == all_entropies.size()

In [None]:
# Perplexity
mean_nll = torch.mean(all_ics)
perplexity = torch.exp(mean_nll)
print(f'Perplexity: { perplexity }')

In [None]:
#### NLL & ENTROPY
## Density plot
fig1, (ics, entropies) = plt.subplots(1, 2, sharex=False, sharey=False);
fig1.suptitle('Token attributes');

# Information content (NLL)
sns.kdeplot(
	ax=ics,
	data=all_ics,
	bw_method=0.5,
	color='green',
	fill=True,
	label='IC'
);
ics.set(title='Negative log likelihood');
ics.set(xlabel='Nats'); # cross_entropy / log_softmax use the natural log, not log2

# Entropy (expected IC)
sns.kdeplot(
	ax=entropies,
	data=all_entropies,
	bw_method=0.5,
	color='blue',
	fill=True,
	label='Entropy'
);
entropies.set(title='Entropy');
entropies.set(xlabel='Nats');
# entropies.xaxis.set_major_locator(ticker.MultipleLocator(2));
# entropies.xaxis.set_major_formatter(ticker.ScalarFormatter());

plt.show();

# abs(diff)
diff_ic_ent = torch.sub(all_ics, all_entropies);

abs_diff_ic_ent = torch.abs(diff_ic_ent);

density_abs_diff_ic_ent = sns.kdeplot(
	data=abs_diff_ic_ent,
	bw_method=1,
	color='orange',
	fill=True,
	label='abs. diff'
	# clip=(0, 100)
);

# sq(diff)
sq_diff_ic_ent = torch.square(diff_ic_ent)
density_sq_diff_ic_ent = sns.kdeplot(
	data=sq_diff_ic_ent,
	bw_method=1,
	color='yellow',
	fill=True,
	label='sq. diff'
	# clip=(0, 100)
);

xmax = torch.max(sq_diff_ic_ent).item();
density_sq_diff_ic_ent.set_title('Deviation of IC from expected IC');
density_sq_diff_ic_ent.legend(loc=0);
density_sq_diff_ic_ent.set(xlim=(0, xmax));
density_sq_diff_ic_ent.set(xlabel='Nats (abs. diff) / nats² (sq. diff)');

plt.show();

print(f'Mean abs diff: { torch.mean(abs_diff_ic_ent) }');
print(f'Median abs diff: { torch.median(abs_diff_ic_ent) }');
print(f'Max abs diff: { torch.max(abs_diff_ic_ent) }');
print(f'Min abs diff: { torch.min(abs_diff_ic_ent) }');

### <strong>Story generation</strong> with GPT2 and `WritingPrompts`

In [None]:
import torch
from datasets import load_dataset
from transformers import GPT2TokenizerFast, GPT2LMHeadModel

# Load model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: { device.upper() }')
model_checkpoint = 'gpt2'
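tokenizer = GPT2TokenizerFast.from_pretrained(model_checkpoint)
if tokenizer.pad_token is None:
	# GPT2 has no pad token; reuse EOS instead of adding a new token (which would require resizing the model's embeddings)
	tokenizer.pad_token = tokenizer.eos_token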
model = GPT2LMHeadModel.from_pretrained(model_checkpoint)

In [None]:
# Load data: one prompt / story per line
import pandas as pd
# from google.colab import drive
# drive.mount('/content/drive')

def read_lines(path):
	# Read a plain-text file line by line (pd.read_fwf would split each line into inferred fixed-width columns, and [0] keeps only the first)
	with open(path, encoding='utf-8') as f:
		return pd.Series([line.rstrip('\n') for line in f])

prompts = read_lines('../data/writingPrompts/test.wp_source')
# prompts = read_lines('/content/drive/MyDrive/ColabNotebooks/data/writingPrompts/test.wp_source')
prompts.replace(r'.*] ', '', regex=True, inplace=True) # strip leading tags such as "[ WP ] "
prompts.replace(r'.?<newline>.?', ' ', regex=True, inplace=True) # or keep?

ref_stories = read_lines('../data/writingPrompts/test.wp_target')
# ref_stories = read_lines('/content/drive/MyDrive/ColabNotebooks/data/writingPrompts/test.wp_target')
ref_stories.replace(r'.?<newline>.?', ' ', regex=True, inplace=True) # or keep?
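
The `'.*] '` pattern above is greedy: it deletes everything up to the *last* `] ` on the line, so a prompt that itself contains square brackets loses text. Below is a possible stricter alternative. It assumes tags have the form `[ XX ] ` (like `[ WP ]`), strips only leading tags, and collapses whitespace around `<newline>` markers. The helper name and example prompt are hypothetical.

```python
import re

# Assumed tag format: one or more leading tags such as "[ WP ] " or "[ EU ] "
LEADING_TAGS = re.compile(r'^(\[ \w+ \] ?)+')
# Line breaks in the data are marked with a literal "<newline>" token
NEWLINE_MARKER = re.compile(r'\s*<newline>\s*')

def clean_prompt(text: str) -> str:
    """Strip only the leading tags and turn <newline> markers into single spaces."""
    text = LEADING_TAGS.sub('', text)
    return NEWLINE_MARKER.sub(' ', text).strip()

example = '[ WP ] A [ secret ] plan <newline> unfolds.'  # made-up prompt
print(clean_prompt(example))         # A [ secret ] plan unfolds.
print(re.sub(r'.*] ', '', example))  # greedy pattern: everything up to the last "] " is lost
```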

In [None]:
from tqdm import tqdm


max_input_length = model.config.n_positions # GPT2
print(f'GPT2 model maximum input length: { max_input_length }')


# Not used below: each text is truncated to a single window of max_input_length tokens.
# In strided (sliding-window) evaluation, each scored token would have at least max_input_length - stride context tokens
# (provided that many preceding tokens exist); context-only positions are labelled -100 so that they're ignored by the loss
stride = 512

all_out = list()
num_examples = 10

for prompt, ref_story in tqdm(zip(prompts[:num_examples], ref_stories[:num_examples]), total=num_examples):
	# Per-example results
	tmp = dict()

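	# Concatenate prompt and reference story (separated by a space so the boundary words don't merge)
	input_text = prompt + ' ' + ref_story
	if len(input_text.split()) > max_input_length: # skip texts whose word count alone exceeds the context window (token count is usually higher)
		continue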

	# Convert text to model input IDs
	encodings = tokenizer(
		text=input_text,
		max_length=max_input_length-2, # subtract 2 because we want to add BOS and EOS tokens
		truncation=True,
		return_offsets_mapping=True # for fast tokenizers only
	)
	input_ids = torch.tensor(
		# BOS token + encodings + EOS token
		data=[tokenizer.bos_token_id]+encodings['input_ids']+[tokenizer.eos_token_id],
		device=model.device
	)

	with torch.no_grad():
		output = model(
			input_ids=input_ids,
			labels=input_ids
		)
	
	# Logits are unnormalized log-probabilities over the vocabulary; softmax turns them into a distribution
	# Drop the last position: it predicts the token after EOS, which has no label
	logits_shifted = output['logits'][..., :-1, :].contiguous() # 1 x (input ids - 1) x vocabulary

	# Labels are the ground-truth tokens taken from input_ids
	# logits[t] is the distribution over the next token given x_<=t, so it must be scored against input_ids[t+1]
	# (as in language modeling: what is the probability of token x_t given x_<t?)
	# => shift labels left by one so that logits_shifted[t] lines up with labels_shifted[t] = input_ids[t+1]
	labels_shifted = input_ids[..., 1:].contiguous().squeeze() # (input ids - 1)

	# Per-token NLL = multi-class cross-entropy against a one-hot target, computed with the natural log (nats)
	# CE loss: H(P, Q) = -\sum_{x in X} P(x) \ln{Q(x)}, where P is one-hot on the observed token and Q is the model distribution
	# Note: equivalent to computing -log_probs and, for each position, selecting the column of the observed token (see sanity check below)
	# Equivalent to IC: higher Q(w) --> lower IC (relatively little information gained)
	ce_loss = torch.nn.functional.cross_entropy( # (input ids - 1)
		input=logits_shifted.view(-1, logits_shifted.size(-1)), # Predicted unnormalized logits
		target=labels_shifted.view(-1), # Ground-truth class indices (not one-hot)
		reduction='none' # Keep one loss value per token
	)
	
	# Sanity check: mean NLL (CE loss) should be close to automatic loss calculation
	# print(f'Mean NLL: { torch.mean(ce_loss) }')
	# print(f'Automatic loss calculation: { output["loss"] }')
	assert torch.isclose(
		input=torch.mean(ce_loss),
		other=output['loss'],
		rtol=0.1
	)
	
	# Softmax to normalize logits
	probs = logits_shifted.softmax(dim=-1).squeeze() # (input ids - 1) x vocabulary
	# Probabilities of the observed tokens => (input ids - 1)
	token_probs = probs.gather(dim=-1, index=labels_shifted.unsqueeze(-1)).squeeze(-1)
	
	# Sanity check: each row (the distribution at one position) sums to 1
	assert torch.allclose(
		probs.sum(dim=-1),
		torch.ones(probs.size(0), device=probs.device),
		rtol=0.01
	)
	# Equivalently, sum of all elements in probabilities tensor should equal number of tokens
	sum_probs = torch.sum(probs, dtype=torch.float32)
	len_labels = torch.tensor([labels_shifted.size()[0]], dtype=torch.float32)
	# print(f'Sum of all elements in probabilities tensor: { sum_probs }')
	# print(f'Number of tokens: { len_labels }')
	assert torch.isclose(
		input=sum_probs,
		other=len_labels,
		rtol=0.01
	)

	# Sanity check: NLL of the observed token equals the CE loss
	nlls = -1 * logits_shifted.log_softmax(dim=-1).squeeze() # (input ids - 1) x vocabulary
	# print(f'labels tensor (size {labels_shifted.size()}):\n{labels_shifted}')
	# print(f'NLLs tensor (size {nlls.size()}):\n{nlls}')
	# NLLs at the observed tokens (equivalently the per-token CE loss), kept on the model's device => (input ids - 1)
	ics = nlls.gather(dim=-1, index=labels_shifted.unsqueeze(-1)).squeeze(-1)
	assert torch.allclose(
		input=ics,
		other=ce_loss
	)

	# Entropy (a.k.a. expected IC): -sum_{v in vocab} p(v) ln p(v) at each position; nlls already carries the minus sign
	per_token_entropy = probs * nlls # (input ids - 1) x vocabulary
	per_token_entropy = torch.sum(per_token_entropy, dim=1) # (input ids - 1), summed along each row

	# Sanity check: one entropy per prediction, i.e. number of tokens minus 1 (shifted logits and labels)
	assert per_token_entropy.size(0) == input_ids.size(0) - 1

	tmp['prompt'] = prompt
	tmp['ref_story'] = ref_story
	tmp['full_input'] = input_text
	tmp['len_full_input'] = input_ids.size()[0]
	tmp['tensor_input'] = input_ids
	tmp['logits'] = logits_shifted.squeeze()
	tmp['labels'] = labels_shifted
	tmp['probs'] = token_probs # probs but at the correct indices => input ids x 1
	tmp['nlls'] = nlls # input ids x vocabulary
	tmp['ics'] = ics # nlls but at the correct indices (equivalently ce_loss and per token NLL) => input ids x 1
	tmp['per_token_entropy'] = per_token_entropy

	all_out.append(tmp)
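
For reference, the per-token quantities from the loop can also be computed for one sequence in a few vectorized lines. The sketch below uses a hypothetical function name, `causal_lm_token_attributes`. It takes unbatched GPT-2-style logits, applies the causal shift, and picks the observed tokens with `gather`. It returns information content and entropy in nats; dividing by ln 2 gives bits. It is checked here on random logits rather than a real model.

```python
import math

import torch
import torch.nn.functional as F


def causal_lm_token_attributes(logits, input_ids):
    """Per-token information content and entropy for a causal LM (hypothetical helper).

    logits:    (seq_len, vocab) raw model outputs for one sequence
    input_ids: (seq_len,) token ids that produced those logits
    Returns (ics, entropies), each of shape (seq_len - 1,), in nats.
    """
    # logits[t] predicts token t + 1: drop the last position and the first label
    log_probs = F.log_softmax(logits[:-1], dim=-1)
    targets = input_ids[1:]
    # Information content: -ln q(x_t | x_<t), picked out with gather
    ics = -log_probs.gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)
    # Entropy of each predictive distribution: -sum_v q(v) ln q(v)
    entropies = -(log_probs.exp() * log_probs).sum(dim=-1)
    return ics, entropies


# Toy check with random logits (no model needed)
torch.manual_seed(0)
vocab_size, seq_len = 5, 4
logits = torch.randn(seq_len, vocab_size)
input_ids = torch.randint(vocab_size, (seq_len,))

ics, entropies = causal_lm_token_attributes(logits, input_ids)
ce = F.cross_entropy(logits[:-1], input_ids[1:], reduction='none')
print(torch.allclose(ics, ce))  # True: IC is the per-token cross-entropy
print(ics / math.log(2))        # the same values in bits
```

With the model from the loop, it would be called as `causal_lm_token_attributes(output['logits'][0], input_ids)`.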

### Analysis

In [None]:
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import seaborn as sns
sns.set() # Set sns as default style

In [None]:
print(f'Number of texts processed: { len(all_out) }')

# Move everything to the CPU so the tensors can be combined and plotted together
all_token_probs = torch.cat([text['probs'].cpu() for text in all_out])
all_ics = torch.cat([text['ics'].cpu() for text in all_out])
all_entropies = torch.cat([text['per_token_entropy'].cpu() for text in all_out])
assert all_ics.size() == all_entropies.size()

In [None]:
# Perplexity = exp(mean NLL); the NLLs use the natural log (nats), so exp is the matching base
mean_nll = torch.mean(all_ics)
perplexity = torch.exp(mean_nll)
print(f'Perplexity: { perplexity }')
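
Since `torch.nn.functional.cross_entropy` uses the natural logarithm, `mean_nll` is in nats and perplexity is its exponential. Equivalently, perplexity is 2 raised to the mean NLL in bits. A worked example with made-up NLL values (not notebook results):

```python
import math

import torch

# Hypothetical per-token NLLs in nats (not results from the notebook)
nlls_nats = torch.tensor([math.log(2), math.log(4), math.log(8)])

mean_nats = nlls_nats.mean()          # (ln 2 + ln 4 + ln 8) / 3 = ln 64 / 3 = ln 4
mean_bits = mean_nats / math.log(2)   # ln 4 / ln 2 = 2 bits
ppl_from_nats = torch.exp(mean_nats)  # e^(ln 4) = 4
ppl_from_bits = 2 ** mean_bits        # 2^2 = 4

print(mean_bits.item(), ppl_from_nats.item(), ppl_from_bits.item())
```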

In [None]:
#### NLL & ENTROPY
## Density plot
fig1, (ics, entropies) = plt.subplots(1, 2, sharex=True, sharey=True);
fig1.suptitle('Token attributes');

# Information content (NLL)
sns.kdeplot(
	ax=ics,
	data=all_ics,
	bw_method=0.5,
	color='green',
	fill=True,
	label='IC'
);
ics.set(title='Negative log likelihood');
ics.set(xlabel='Nats');
ics.set_xlim(left=0);
ics.xaxis.set_major_locator(ticker.MultipleLocator(2));
ics.xaxis.set_major_formatter(ticker.ScalarFormatter());

# Entropy (expected IC)
sns.kdeplot(
	ax=entropies,
	data=all_entropies,
	bw_method=0.5,
	color='blue',
	fill=True,
	label='Entropy'
);
entropies.set(title='Entropy');
entropies.set(xlabel='Nats');
# entropies.set(xlim=(0));
entropies.xaxis.set_major_locator(ticker.MultipleLocator(2));
entropies.xaxis.set_major_formatter(ticker.ScalarFormatter());

plt.show();


#### DIFF(NLL, ENTROPY)
diff_ic_ent = torch.sub(all_ics, all_entropies);
# density_diff_ic_ent = sns.kdeplot(
# 	data=diff_ic_ent,
# 	bw_method=1,
# 	color='red',
# 	fill=True,
# 	label='diff'
# 	# clip=(0, 100)
# )
# density_diff_ic_ent.set(
# 	title='diff(IC, entropy)',
# 	xlabel='Bits',
# 	ylabel='Density'
# )

# abs(diff)
abs_diff_ic_ent = torch.abs(diff_ic_ent);

density_abs_diff_ic_ent = sns.kdeplot(
	data=abs_diff_ic_ent,
	bw_method=1,
	color='orange',
	fill=True,
	label='abs. diff'
	# clip=(0, 100)
);

# sq(diff)
sq_diff_ic_ent = torch.square(diff_ic_ent)
density_sq_diff_ic_ent = sns.kdeplot(
	data=sq_diff_ic_ent,
	bw_method=1,
	color='yellow',
	fill=True,
	label='sq. diff'
	# clip=(0, 100)
);

xmax = torch.max(sq_diff_ic_ent).item() + 1;
density_sq_diff_ic_ent.set_title('Deviation of IC from expected IC');
density_sq_diff_ic_ent.legend(loc=0);
density_sq_diff_ic_ent.set(xlim=(0, xmax));
density_sq_diff_ic_ent.set(xlabel='Nats (abs. diff) / squared nats (sq. diff)');

plt.show();

print(f'Mean abs diff: { torch.mean(abs_diff_ic_ent) }');
print(f'Median abs diff: { torch.median(abs_diff_ic_ent) }');
print(f'Max abs diff: { torch.max(abs_diff_ic_ent) }');
print(f'Min abs diff: { torch.min(abs_diff_ic_ent) }');
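
The comparison of IC with entropy rests on one identity: entropy is the expected IC under the model's own predictive distribution. If the observed tokens were drawn from that distribution, IC minus entropy would average to zero. A systematic offset in the plots above therefore suggests the reference text is more (or less) surprising than the model itself expects. A Monte Carlo sketch with a made-up 4-token distribution:

```python
import torch

torch.manual_seed(0)

# A hypothetical predictive distribution over a 4-token vocabulary
probs = torch.tensor([0.7, 0.2, 0.05, 0.05])
entropy = -(probs * probs.log()).sum()  # expected IC, in nats (about 0.871)

# Sample tokens from the same distribution and measure their IC
samples = torch.multinomial(probs, num_samples=200_000, replacement=True)
ics = -probs.log()[samples]

# Per-sample deviation of IC from its expectation; its mean should be close to 0
deviation = ics - entropy
print(f'entropy: {entropy:.4f}  mean IC: {ics.mean():.4f}  mean deviation: {deviation.mean():.4f}')
```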

### <strong>Story generation</strong> with BART and ```WritingPrompts```

In [None]:
import torch
from transformers import BartTokenizer, BartForConditionalGeneration

# Load model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: { device.upper() }')
model_checkpoint = 'facebook/bart-large-cnn'
tokenizer = BartTokenizer.from_pretrained(model_checkpoint)
if tokenizer.pad_token is None:
	tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# Load the pretrained weights; BartForConditionalGeneration(BartConfig()) would build a randomly initialized model
model = BartForConditionalGeneration.from_pretrained(model_checkpoint).to(device)

In [None]:
# Load data
import pandas as pd
# from google.colab import drive
# drive.mount('/content/drive')

prompts = pd.read_fwf('../data/writingPrompts/test.wp_source', header=None)[0]
# prompts = pd.read_fwf('/content/drive/MyDrive/ColabNotebooks/data/writingPrompts/test.wp_source', header=None)[0]
prompts.replace('.*] ', '', regex=True, inplace=True)
prompts.replace('.?<newline>.?', ' ', regex=True, inplace=True) # or keep?

ref_stories = pd.read_fwf('../data/writingPrompts/test.wp_target', header=None)[0]
# ref_stories = pd.read_fwf('/content/drive/MyDrive/ColabNotebooks/data/writingPrompts/test.wp_target', header=None)[0]
ref_stories.replace('.?<newline>.?', ' ', regex=True, inplace=True) # or keep?

In [None]:
import sys

from tqdm import tqdm


max_input_length = model.config.max_position_embeddings # BART
print(f'BART model maximum INPUT length: { max_input_length }')

# Stride for a sliding-window evaluation of texts longer than the context window.
# Not used below: over-long inputs are truncated instead. In a sliding window,
# each window keeps at least <stride> preceding tokens as context, and the labels of those
# context-only tokens are set to -100 so the loss ignores them.
stride = 512

all_out = list()
num_examples = 10

for prompt, target in tqdm(zip(prompts[:num_examples], ref_stories[:num_examples]), total=num_examples):
	# Per-example results
	tmp = dict()

	# Convert prompt to model input IDs
	encodings_prompt = tokenizer(
		text=prompt,
		max_length=max_input_length-1, # the tokenizer already adds <s> and </s>; the trailing </s> is dropped below
		truncation=True
		# return_offsets_mapping=True # for fast tokenizers only
	)
	input_ids = torch.tensor(
		# want prompt to start with BOS but not end with EOS
		data=encodings_prompt['input_ids'][:-1],
		device=model.device
	)
	input_ids = input_ids.unsqueeze(dim=0) # add a batch dimension: BART expects input of shape (batch, seq_len)

	# Convert reference story to label IDs
	encodings_target = tokenizer(
		text=target,
		max_length=max_input_length-1, # the tokenizer already adds <s> and </s>; the leading <s> is dropped below
		truncation=True,
		padding=True
	)
	labels = torch.tensor(
		# want generated story to end with EOS but not start with BOS
		data=encodings_target['input_ids'][1:],
		device=model.device
	)
	labels = labels.unsqueeze(dim=0)

	with torch.no_grad():
		output = model(
			input_ids=input_ids,
			labels=labels
		)
	loss = output['loss']
	# print(f'model loss: { loss }')
	
	# Logits are unnormalized log-probabilities over the vocabulary; softmax turns them into a distribution
	# No shift needed: BART builds its decoder inputs by shifting the labels right, so logits[t] already predicts labels[t]
	logits = output['logits'].squeeze(dim=0) # label ids x vocabulary

	# Labels = ground truths taken from the reference story
	# (already aligned with the logits, so no shift is needed)
	labels = labels.contiguous().squeeze(dim=0) # label ids

	# Per-token NLL = multi-class cross-entropy against a one-hot target, computed with the natural log (nats)
	# CE loss: H(P, Q) = -\sum_{x in X} P(x) \ln{Q(x)}, where P is one-hot on the observed token and Q is the model distribution
	# Note: equivalent to computing -log_probs and, for each position, selecting the column of the observed token (see sanity check below)
	# Equivalent to IC: higher Q(w) --> lower IC (relatively little information gained)
	ce_loss = torch.nn.functional.cross_entropy( # label ids
		input=logits.view(-1, logits.size(-1)), # Predicted unnormalized logits
		target=labels.view(-1), # Ground-truth class indices (not one-hot)
		reduction='none' # Keep one loss value per token
	)
	# print(f'mean CE loss: { torch.mean(ce_loss) }')
	
	# Sanity check: mean NLL (CE loss) should be close to automatic loss calculation
	# print(f'Mean NLL: { torch.mean(ce_loss) }')
	# print(f'Automatic loss calculation: { output["loss"] }')
	assert torch.isclose(
		input=torch.mean(ce_loss),
		other=output['loss'],
		rtol=0.1
	)
	
	# Softmax to normalize logits
	probs = logits.softmax(dim=-1) # label ids x vocabulary
	# Probabilities of the observed tokens => label ids
	token_probs = probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
	
	# Sanity check: each row (the distribution at one position) sums to 1
	assert torch.allclose(
		probs.sum(dim=-1),
		torch.ones(probs.size(0), device=probs.device),
		rtol=0.01
	)
	# Sanity check: equivalently, sum of all elements in probabilities tensor should equal number of tokens
	sum_probs = torch.sum(probs, dtype=torch.float32)
	len_labels = torch.tensor([labels.size()[0]], dtype=torch.float32)
	# print(f'Sum of all elements in probabilities tensor: { sum_probs }')
	# print(f'Number of tokens: { len_labels }')
	assert torch.isclose(
		input=sum_probs,
		other=len_labels,
		rtol=0.01
	)

	# Sanity check: NLL of the observed token equals the CE loss
	nlls = -1 * logits.log_softmax(dim=-1) # label ids x vocabulary
	# NLLs at the observed tokens (equivalently the per-token CE loss), kept on the model's device => label ids
	ics = nlls.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
	assert torch.allclose(
		input=ics,
		other=ce_loss
	)

	# Entropy (a.k.a. expected IC): -sum_{v in vocab} p(v) ln p(v) at each position; nlls already carries the minus sign
	per_token_entropy = probs * nlls # label ids x vocabulary
	per_token_entropy = torch.sum(per_token_entropy, dim=1) # label ids, summed along each row

	# Sanity check: one entropy per reference-story token (labels); no shift, so no "minus 1"
	assert per_token_entropy.size(0) == labels.size(0)

	tmp['prompt'] = prompt
	tmp['ref_story'] = target
	tmp['len_prompt'] = input_ids.size(-1) # number of prompt (encoder input) tokens
	tmp['len_ref_story'] = labels.size()[0]
	tmp['tensor_input'] = input_ids.squeeze()
	tmp['logits'] = logits
	tmp['labels'] = labels
	tmp['probs'] = token_probs # probs but at the correct indices => input ids x 1
	tmp['nlls'] = nlls # input ids x vocabulary
	tmp['ics'] = ics # nlls but at the correct indices (equivalently ce_loss and per token NLL) => input ids x 1
	tmp['per_token_entropy'] = per_token_entropy

	all_out.append(tmp)
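
In the BART loop the labels are not shifted because `BartForConditionalGeneration` builds the decoder inputs itself when `labels` is passed without `decoder_input_ids`. It shifts the labels one position to the right and prepends `decoder_start_token_id`; the library helper is `shift_tokens_right` in `transformers.models.bart.modeling_bart`. The function below, `shift_right_sketch`, is a hypothetical re-implementation for illustration only, and the token ids are made up. It shows that at step t the decoder sees `labels[t - 1]`, so `logits[t]` is already aligned with `labels[t]`.

```python
import torch


def shift_right_sketch(labels, pad_token_id, decoder_start_token_id):
    """Hypothetical re-implementation of how BART-style models derive decoder inputs from labels.

    labels: (batch, tgt_len). Position t of the result holds labels[:, t - 1];
    position 0 holds decoder_start_token_id.
    """
    decoder_input_ids = labels.new_zeros(labels.shape)
    decoder_input_ids[:, 1:] = labels[:, :-1]
    decoder_input_ids[:, 0] = decoder_start_token_id
    # Positions ignored by the loss (-100) must still be valid ids for the embedding layer
    decoder_input_ids.masked_fill_(decoder_input_ids == -100, pad_token_id)
    return decoder_input_ids


labels = torch.tensor([[100, 200, 300, 2]])  # made-up ids; 2 assumed to be </s>
decoder_input_ids = shift_right_sketch(labels, pad_token_id=1, decoder_start_token_id=2)
# decoder step t receives labels[t - 1] and is trained to predict labels[t]
print(decoder_input_ids)
```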

### Analysis

In [None]:
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import seaborn as sns
sns.set() # Set sns as default style

In [None]:
print(f'Number of texts processed: { len(all_out) }')

# Move everything to the CPU so the tensors can be combined and plotted together
all_token_probs = torch.cat([text['probs'].cpu() for text in all_out])
all_ics = torch.cat([text['ics'].cpu() for text in all_out])
all_entropies = torch.cat([text['per_token_entropy'].cpu() for text in all_out])
assert all_ics.size() == all_entropies.size()

In [None]:
# Perplexity = exp(mean NLL); the NLLs use the natural log (nats), so exp is the matching base
mean_nll = torch.mean(all_ics)
perplexity = torch.exp(mean_nll)
print(f'Perplexity: { perplexity }')

In [None]:
#### NLL & ENTROPY
## Density plot
fig1, (ics, entropies) = plt.subplots(1, 2, sharex=False, sharey=False);
fig1.suptitle('Token attributes');

# Information content (NLL)
sns.kdeplot(
	ax=ics,
	data=all_ics,
	bw_method=0.5,
	color='green',
	fill=True,
	label='IC'
);
ics.set(title='Negative log likelihood');
ics.set(xlabel='Nats');
ics.set_xlim(left=0);
ics.xaxis.set_major_locator(ticker.MultipleLocator(2));
ics.xaxis.set_major_formatter(ticker.ScalarFormatter());

# Entropy (expected IC)
sns.kdeplot(
	ax=entropies,
	data=all_entropies,
	bw_method=0.5,
	color='blue',
	fill=True,
	label='Entropy'
);
entropies.set(title='Entropy');
entropies.set(xlabel='Nats');
# entropies.set(xlim=(0));
# entropies.xaxis.set_major_locator(ticker.MultipleLocator(2));
# entropies.xaxis.set_major_formatter(ticker.ScalarFormatter());

plt.show();


#### DIFF(NLL, ENTROPY)
diff_ic_ent = torch.sub(all_ics, all_entropies);
# density_diff_ic_ent = sns.kdeplot(
# 	data=diff_ic_ent,
# 	bw_method=1,
# 	color='red',
# 	fill=True,
# 	label='diff'
# 	# clip=(0, 100)
# )
# density_diff_ic_ent.set(
# 	title='diff(IC, entropy)',
# 	xlabel='Bits',
# 	ylabel='Density'
# )

# abs(diff)
abs_diff_ic_ent = torch.abs(diff_ic_ent);

density_abs_diff_ic_ent = sns.kdeplot(
	data=abs_diff_ic_ent,
	bw_method=1,
	color='orange',
	fill=True,
	label='abs. diff'
	# clip=(0, 100)
);

# sq(diff)
sq_diff_ic_ent = torch.square(diff_ic_ent)
density_sq_diff_ic_ent = sns.kdeplot(
	data=sq_diff_ic_ent,
	bw_method=1,
	color='yellow',
	fill=True,
	label='sq. diff'
	# clip=(0, 100)
);

xmax = torch.max(sq_diff_ic_ent).item() + 1;
density_sq_diff_ic_ent.set_title('Deviation of IC from expected IC');
density_sq_diff_ic_ent.legend(loc=0);
density_sq_diff_ic_ent.set(xlim=(0, xmax));
density_sq_diff_ic_ent.set(xlabel='Nats (abs. diff) / squared nats (sq. diff)');

plt.show();

print(f'Mean abs diff: { torch.mean(abs_diff_ic_ent) }');
print(f'Median abs diff: { torch.median(abs_diff_ic_ent) }');
print(f'Max abs diff: { torch.max(abs_diff_ic_ent) }');
print(f'Min abs diff: { torch.min(abs_diff_ic_ent) }');