### Generation with custom `LearnedLogitsWarper`

In [1]:
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast, LogitsWarper, LogitsProcessorList, StoppingCriteriaList, MaxLengthCriteria

#### Attribute calculations

In [2]:
def calc_probs(logits: torch.FloatTensor) -> torch.FloatTensor:
	"""Turn raw logits into probabilities by applying softmax along the last dimension."""
	probs = logits.softmax(dim=-1)
	return probs


def calc_nlls(logits: torch.FloatTensor) -> torch.FloatTensor:
	"""Return the negative log-likelihood of every token, computed along the last dimension.

	The result is `-log_softmax(logits)`, so every entry is non-negative. The previous
	implementation returned `log_softmax(logits)` (log-likelihoods, i.e. the wrong sign for
	the name). Code that only consumes `abs(nlls - sum(probs * nlls))` is unaffected by the
	sign correction, because both terms flip sign together.

	Args:
		logits: Unnormalised scores; the last dimension indexes the vocabulary.

	Returns:
		Tensor of the same shape as `logits` holding per-token negative log-likelihoods.
	"""
	return -torch.nn.functional.log_softmax(logits, dim=-1)


def calc_entropies(probs: torch.FloatTensor, nlls: torch.FloatTensor) -> torch.FloatTensor:
	"""Sum `probs * nlls` over the last dimension, keeping that dimension with size 1.

	When `nlls` holds negative log-probabilities this is the Shannon entropy of each
	distribution; if it holds (non-negated) log-probabilities the result is the negated entropy.

	The reduced axis is kept (`keepdim=True`) so the result broadcasts against `probs` / `nlls`
	for any number of leading dimensions. For 2-D input this matches the former
	`sum(...).unsqueeze(dim=1)` exactly.

	Args:
		probs: Probabilities; the last dimension indexes the vocabulary.
		nlls: Per-token values of the same shape as `probs`.

	Returns:
		Tensor shaped like the inputs but with a last dimension of size 1.
	"""
	return torch.sum(probs * nlls, dim=-1, keepdim=True)


def calc_diff_nlls_ents(nlls: torch.FloatTensor, entropies: torch.FloatTensor) -> torch.FloatTensor:
	"""Subtract `entropies` from `nlls` element-wise (with the usual broadcasting)."""
	difference = torch.sub(nlls, entropies)
	return difference


def is_top_k(probs: torch.FloatTensor, k: int, device: str) -> torch.LongTensor:
	"""Mark, for every row of `probs`, which entries are among the `k` largest.

	Args:
		probs: 2-D tensor; rows are distributions, columns index the vocabulary.
		k: Number of highest-probability entries to flag per row. Values larger than the
			number of columns are clamped, so every entry is flagged in that case.
		device: Device on which the indicator tensor is allocated.

	Returns:
		Long tensor of the same 2-D shape as `probs`, with 1 at top-k positions and 0 elsewhere.
	"""
	# torch.topk raises when k exceeds the dimension size; clamp instead.
	k = min(k, probs.size(-1))
	# scatter_ needs the index tensor on the same device as the tensor being written.
	top_k_indices = torch.topk(probs, k).indices.to(device)
	top_k_indicators = torch.zeros([probs.size(0), probs.size(1)], device=device, dtype=torch.long)
	# One vectorised write replaces the former Python loop over rows and indices.
	top_k_indicators.scatter_(1, top_k_indices, 1)
	return top_k_indicators

#### `LearnedLogitsWarper`

In [3]:
class LearnedLogitsWarper(LogitsWarper):
	"""
	[`LogitsWarper`] that upweights token scores using a learned mapping from LM scores to a one-hot-encoded corpus of human-generated texts.

	For every vocabulary entry, four features are derived from the incoming scores (probability,
	top-k indicator, corpus unigram frequency, and the absolute difference between the token's
	NLL and the distribution's entropy). The returned score is the learned linear combination
	`w . features + b`.

	Args:
		learned_weights_path: Path to the saved state dict containing `linear.weight` and `linear.bias`.
		unigram_freqs_path: Path to the saved tensor of corpus unigram frequencies. It must be
			broadcastable to the shape of the scores (e.g. 1 x `|V|`).

	Attributes:
		device: 'cuda' if CUDA is available, else 'cpu'. Kept for backward compatibility; the
			computations follow the device of the incoming `scores` instead.
		model: The loaded state dict.
		unigram_freqs_path: Path the unigram frequencies are loaded from.
		w: Linear weight of size 1 x `k`, where `k` is the number of features.
		b: Linear bias of size [scalar].
	"""

	def __init__(self, learned_weights_path: str, unigram_freqs_path: str):
		self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
		# NOTE(security): torch.load unpickles the file; only load weights from a trusted source.
		self.model = torch.load(learned_weights_path, map_location=torch.device('cpu'))
		self.unigram_freqs_path = unigram_freqs_path

		self.w = self.model['linear.weight'] # 1 x k
		self.b = self.model['linear.bias'] # 1

		# Filled on first use so the frequency file is read once, not on every generation step.
		self._corpus_unigram_freq = None

	def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
		"""Replace `scores` (batch x |V|) with the learned linear combination of their features.

		`input_ids` is accepted for interface compatibility and is not used.
		"""
		features = self._calc_batched_features(scores) # batch x |V| x k
		# Keep the parameters on the same device as the scores being processed.
		w = self.w.to(scores.device)
		b = self.b.to(scores.device)
		# (batch x |V| x k) x (k x 1) -> batch x |V| x 1 -> batch x |V|
		upweighted_scores = torch.matmul(features, w.transpose(0, 1)).squeeze(-1) + b
		return upweighted_scores

	def calc_features(self, scores: torch.FloatTensor, k: int=100) -> torch.FloatTensor:
		"""Calculate features tensor.

		Args:
			scores: Scores for a single sequence (1 x |V|).
			k: Number of top-probability tokens flagged by the top-k indicator feature
				(unrelated to the number of features).

		Returns:
			Features of size (number of features) x |V|.
		"""
		features = self._calc_batched_features(scores, k).squeeze()
		return torch.transpose(features, dim0=0, dim1=1)

	def _calc_batched_features(self, scores: torch.FloatTensor, k: int=100) -> torch.FloatTensor:
		"""Build the feature tensor of size batch x |V| x (number of features)."""
		probs = calc_probs(scores)
		nlls = calc_nlls(scores)
		entropies = calc_entropies(probs, nlls)
		abs_diff = torch.abs(calc_diff_nlls_ents(nlls, entropies))
		# Create every feature on the scores' device so they can be stacked together.
		top_k = is_top_k(probs, k, scores.device)
		corpus_unigram_freq = self._get_corpus_unigram_freq().to(scores.device).expand_as(probs)

		return torch.stack([probs, top_k, corpus_unigram_freq, abs_diff], dim=-1)

	def _get_corpus_unigram_freq(self):
		"""Return the corpus unigram frequencies, loading them from disk on first use only."""
		if getattr(self, '_corpus_unigram_freq', None) is None:
			# NOTE(security): torch.load unpickles the file; only load from a trusted source.
			self._corpus_unigram_freq = torch.load(self.unigram_freqs_path, map_location=torch.device('cpu'))
		return self._corpus_unigram_freq

#### Generate text

In [4]:
# Load the tokenizer and the language model from the same pretrained checkpoint.
checkpoint = 'gpt2-large'
tokenizer = GPT2TokenizerFast.from_pretrained(checkpoint)
model = GPT2LMHeadModel.from_pretrained(checkpoint)

# Reuse the end-of-sequence id as the padding id.
tokenizer.pad_token_id = tokenizer.eos_token_id

In [5]:
# Paths to the learned linear weights and the corpus unigram frequencies.
weights_path = '../models/weights/lm_28.pkl'
corpus_path = '../data/unigram_freq/wiki_200k.pt'

# Maximum total sequence length used by the stopping criterion below.
seq_max_len = 100

# `device` is only reported in this cell; nothing here moves tensors onto it.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'device: {device}')

# The custom warper is the only entry in the warper list.
logits_warper = LogitsProcessorList()
logits_warper.append(LearnedLogitsWarper(weights_path, corpus_path))

# Generation stops once the sequence reaches `seq_max_len` tokens.
stopping_criteria = StoppingCriteriaList()
stopping_criteria.append(MaxLengthCriteria(max_length=seq_max_len))

device: cpu


In [6]:
# Load the corpus unigram frequencies onto `device` for inspection.
# NOTE(review): `freqs` is not referenced by any later cell visible in this notebook;
# the warper loads the same file itself from `corpus_path`.
freqs = torch.load(corpus_path, map_location=torch.device(device))

In [7]:
# Named `prompt` rather than `input` so the Python builtin `input()` is not shadowed
# in the notebook's global namespace.
prompt = 'After the United States entered the war in April 1917'
input_ids = tokenizer(prompt, return_tensors='pt')['input_ids']

print(input_ids)

tensor([[ 3260,   262,  1578,  1829,  5982,   262,  1175,   287,  3035, 24168]])


In [8]:
# Fix the RNG seed so the sampled continuation is reproducible.
torch.manual_seed(42)

# Sampling is inference only: disable autograd so no gradient state is tracked
# across decoding steps.
with torch.no_grad():
    outputs = model.sample(
        input_ids,
        logits_warper=logits_warper,
        stopping_criteria=stopping_criteria,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id
    )

In [9]:
# Display the generated token ids with size-1 dimensions removed.
outputs.squeeze()

tensor([ 3260,   262,  1578,  1829,  5982,   262,  1175,   287,  3035, 24168,
           11, 19297,  8018,   257, 12273,   761,   329,  1913,  2324,  4788,
           25,   284,  1730,   351,   644,   339,  2936,   561,   307,   281,
         5387, 17645,   286, 39627,    13,  1114, 47700,    11,   777,   366,
           79,  4733,  1964, 35198,     1,  2950,  1115,  4237,    25,  1605,
        25310,  3592,   338, 21403,   319, 11292, 29311,   286,  6541,   290,
          584,  9416,   284,  3284,   290,  4492,    26,   257,  3252,   286,
         7396,  4141,  4588,   287,  4881,    26,   290,  7570, 23594,   284,
         4885,  4925,   290,  9572,    13,   679,   900,   503,   287,  3945,
        25859,   284, 10568,   262,  1115,  7432,   290,  4474,   644,  2627])

In [10]:
# Decode the generated ids (prompt plus continuation) back to text, dropping special tokens.
tokenizer.decode(outputs.squeeze(), skip_special_tokens=True)

'After the United States entered the war in April 1917, Roosevelt recognized a pressing need for strong security policies: to deal with what he felt would be an internal outbreak of Communism. For FDR, these "possible political outbreaks" involved three sources: American Protestant society\'s dependence on overseas shipments of guns and other supplies to Russia and England; a fear of rising French influence in France; and Soviet hostility to Western freedom and intervention. He set out in February 1918 to resolve the three threats and establish what became'