In [None]:
import time
from functools import wraps

def execution_timer(func):
  @wraps(func)
  def wrapper(*args, **kwargs):
    start_time = time.perf_counter()
    result = func(*args, **kwargs)
    end_time = time.perf_counter()
    elapsed_time = end_time - start_time
    print(f"Function {func.__name__} executed in {elapsed_time:.4f} seconds")
    return result
  return wrapper

In [None]:
from google.colab import drive

# This command mounts your Google Drive content to the Colab instance
print("Mounting Google Drive...")
drive.mount('/content/drive')
print("Google Drive mounted successfully!")

Mounting Google Drive...
Mounted at /content/drive
Google Drive mounted successfully!


In [None]:
import heapq

class Node:
  def __init__(self, symbol=None, frequency=None):
    self.symbol = symbol
    self.frequency = frequency
    self.left = None
    self.right = None

  def __lt__(self, other):
    return self.frequency < other.frequency

  def is_leaf(self):
    return self.left is None and self.right is None

def build_tree(leaf_nodes, frequencies):
  # Create a priority queue of nodes
  priority_queue = [Node(val, freq) for val, freq in zip(leaf_nodes, frequencies)]
  heapq.heapify(priority_queue)

  internal_node_counter = 0
  # Build the Huffman tree
  while len(priority_queue) > 1:
    left_child = heapq.heappop(priority_queue)
    right_child = heapq.heappop(priority_queue)
    merged_node = Node(
      symbol=f'Internal Node {internal_node_counter}', frequency=left_child.frequency + right_child.frequency
    )
    merged_node.left = left_child
    merged_node.right = right_child
    heapq.heappush(priority_queue, merged_node)
  return priority_queue[0]

def generate_paths(node, code, path_dict):
  if node is not None:
    if node.symbol is not None and not isinstance(node.symbol, str):
      path_dict[node.symbol] = code
    generate_paths(node.left, code + [0], path_dict)
    generate_paths(node.right, code + [1], path_dict)
  return path_dict

def max_depth(node):
  if node is None:
    return 0
  left_depth = max_depth(node.left)
  right_depth = max_depth(node.right)
  return max(left_depth, right_depth) + 1

In [None]:
import random
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer

# 1. Initialize the Tokenizer (Same as before)
# The tokenizer is framework-agnostic.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

vocabulary = tokenizer.get_vocab()
print(f"Vocabulary size: {len(vocabulary)}")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Vocabulary size: 50257


In [None]:
root = build_tree(list(vocabulary.values()), [1] * len(vocabulary))
paths = generate_paths(root, [], {})
print(f"Paths for vocabulary: {len(paths)}")
print(f"Path from root to leaf of {0}: {paths[0]}")
print(f"All items in vocabulary present in tree: {all(item in paths for item in vocabulary.values())}")
print(f"Number of items in vocabulary not present in tree: {len([item for item in vocabulary.values() if item not in paths])}")

Paths for vocabulary: 50257
Path from root to leaf of 0: [1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0]
All items in vocabulary present in tree: True
Number of items in vocabulary not present in tree: 0


In [20]:
from transformers import DataCollatorForLanguageModeling
from torch.utils.data import DataLoader, Dataset

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
MAX_SEQ_LEN = 128
BATCH_SIZE = 8
SHUFFLE_DATA = True

from datasets import load_dataset
datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')

README.md: 0.00B [00:00, ?B/s]

wikitext-2-raw-v1/test-00000-of-00001.pa(…):   0%|          | 0.00/733k [00:00<?, ?B/s]

wikitext-2-raw-v1/train-00000-of-00001.p(…):   0%|          | 0.00/6.36M [00:00<?, ?B/s]

wikitext-2-raw-v1/validation-00000-of-00(…):   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

In [21]:
datasets['train'] = datasets['train'].select(range(10000))
datasets['validation'] = datasets['validation'].select(range(500))
datasets['test'] = datasets['test'].select(range(100))

def tokenize(examples):
    return tokenizer(examples["text"])

tokenized_datasets = datasets.map(tokenize, batched=True, num_proc=4, remove_columns=["text"])

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def group_texts(examples, block_size = MAX_SEQ_LEN):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
        # customize this part to your needs.
    total_length = (total_length // block_size) * block_size
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

lm_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    batch_size=1000,
    num_proc=4,
)

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False # False for Causal Language Modeling (CLM)
)

train_dataloader = DataLoader(
    lm_datasets["train"],
    batch_size=BATCH_SIZE,
    shuffle=SHUFFLE_DATA,
    num_workers=2, # Use multiple processes for faster data loading
    pin_memory=True, # Speeds up transfer to GPU
    collate_fn=data_collator,
)
test_dataloader = DataLoader(
    lm_datasets["validation"],
    batch_size=BATCH_SIZE,
    shuffle=SHUFFLE_DATA,
    num_workers=2, # Use multiple processes for faster data loading
    pin_memory=True, # Speeds up transfer to GPU
    collate_fn=data_collator,
)

print(f"Total number of training sequences: {len(lm_datasets["train"])}")
print(f"Number of batches per epoch: {len(train_dataloader)}")

print("\n--- Example Batch Inspection ---")
for batch in train_dataloader:
    print(f"Input IDs shape: {batch['input_ids'].shape}")
    print(f"Labels shape: {batch['labels'].shape}")
    print(f"Attention Mask shape: {batch['attention_mask'].shape}")
    print("\nFirst sequence (decoded):")
    print(tokenizer.decode(batch['input_ids'][0], skip_special_tokens=True))
    break # Stop after showing the first batch
print("\n--- Example Batch Inspection ---")
for batch in test_dataloader:
    print(f"Input IDs shape: {batch['input_ids'].shape}")
    print(f"Labels shape: {batch['labels'].shape}")
    print(f"Attention Mask shape: {batch['attention_mask'].shape}")
    print("\nFirst sequence (decoded):")
    print(tokenizer.decode(batch['input_ids'][0], skip_special_tokens=True))
    break # Stop after showing the first batch

Map (num_proc=4):   0%|          | 0/100 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10000 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/500 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/100 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10000 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/500 [00:00<?, ? examples/s]

Total number of training sequences: 5011
Number of batches per epoch: 627

--- Example Batch Inspection ---
Input IDs shape: torch.Size([8, 128])
Labels shape: torch.Size([8, 128])
Attention Mask shape: torch.Size([8, 128])

First sequence (decoded):
 . He also played for the Fairfield Stallions in the New England Collegiate Baseball League in 1994 . 
 In August 2008 , he gave the SUNY / Stony Brook athletics department $ 500 @,@ 000 for a new baseball facility . In recognition of this " lead gift " from the Joe Nathan Charitable Foundation , the college named it " Joe Nathan Field . " 
 = = Professional career = = 
 = = = Minor Leagues = = = 
 He began his minor league career in Class A for the Bellingham Giants . After an unsuccessful year at the plate the Giants tried to convert Nathan into a pitcher

--- Example Batch Inspection ---
Input IDs shape: torch.Size([8, 128])
Labels shape: torch.Size([8, 128])
Attention Mask shape: torch.Size([8, 128])

First sequence (decoded):
 of the 

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Config

class HierarchicalSoftmaxNodeLanguageGeneration(nn.Module):
  def __init__(self, root, tokenizer, hidden_size):
    super().__init__()
    self.root = root
    self.tokenizer = tokenizer
    self.hidden_size = hidden_size
    print(f"Using hidden size: {self.hidden_size}")
    self.paths = generate_paths(root, [], {})
    self.node_name_map = {}
    self.node_weights = nn.ModuleDict()
    self.param_counter = 0

    def initialize_node_parameters(node):
      if node is None or node.is_leaf():
        return None
      node_str = str(self.param_counter)
      self.node_name_map[node] = node_str
      self.node_weights[node_str] = nn.Linear(self.hidden_size, 1, bias=False)
      self.param_counter += 1
      initialize_node_parameters(node.left)
      initialize_node_parameters(node.right)

    initialize_node_parameters(self.root)
    print(f"HSM initialized with {len(self.node_weights)} internal nodes")

  def forward(self, hidden_state, target_ids):
      h = hidden_state.view(-1, hidden_state.size(-1))
      targets = target_ids.view(-1)

      total_loss = torch.tensor(0.0)
      total_loss.requires_grad = True
      total_valid_tokens = 0

      for h_i, target_id in zip(h, targets):
        path_step_loss = []
        target = target_id.item()
        if target not in self.paths:
          continue
        choices = self.paths[target]
        curr = self.root
        for choice in choices:
          if curr.is_leaf():
            break
          node_str = self.node_name_map[curr]
          W = self.node_weights[node_str]
          binary_loss = F.binary_cross_entropy_with_logits(
              W(h_i),
              torch.tensor([float(choice)], device=device),
              reduction='sum'
          )
          path_step_loss.append(binary_loss)
          curr = curr.left if not choice else curr.right
        total_loss = total_loss + torch.stack(path_step_loss).sum()
        total_valid_tokens += 1
      return total_loss / max(1, total_valid_tokens)


In [37]:
from transformers import GPT2Model, GPT2Config
from transformers import GenerationConfig
import torch.nn as nn

class GPT2HierarchicalSoftmaxModel(nn.Module):
  def __init__(self, config, root, tokenizer, hidden_size):
    super().__init__()
    self.transformer = GPT2Model(config)
    self.generation_config = GenerationConfig()
    self.hsm_head = HierarchicalSoftmaxNodeLanguageGeneration(root, tokenizer, hidden_size)

  def forward(self, input_ids, attention_mask=None, labels=None):
    outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
    hidden_states = outputs.last_hidden_state
    if labels is not None:
      shifted_hidden_states = hidden_states[:, :-1, :].contiguous()
      shifted_labels = labels[:, 1:].contiguous()
      labels = shifted_labels.view(-1)
      max_label = labels.max().item()
      if max_label >= len(vocabulary):
          print(f"ERROR: Label ID ({max_label}) is >= vocab_size ({len(vocabulary)})!")
          raise ValueError("Labels contain Out-of-Vocabulary token IDs.")
      loss = self.hsm_head(shifted_hidden_states, shifted_labels)
      return {"loss": loss, "hidden_states": shifted_hidden_states}
    return {"hidden_states": hidden_states}


def _greedy_predict(model, input_ids, hidden_state, **kwargs):
    probability = 1.0
    temperature = 1.0
    if 'temperature' in kwargs:
        temperature = kwargs['temperature']
    curr = model.hsm_head.root
    while not curr.is_leaf():
      node_str = model.hsm_head.node_name_map[curr]
      W = model.hsm_head.node_weights[node_str]
      choice = F.sigmoid(W(hidden_state) / temperature) > 0.5
      probability *= F.sigmoid(W(hidden_state) / temperature)
      curr = curr.left if not choice else curr.right
    return curr.symbol


def _top_k_generate(model, hidden_state, k = 5, max_iterations = 1000, **kwargs):
  beam = [(0.0, model.hsm_head.root)]
  final_candidates = []
  iterations = 0
  temperature = 1.0
  if 'temperature' in kwargs:
      temperature = kwargs['temperature']

  while beam and len(final_candidates) < k and iterations < max_iterations:
      iterations += 1
      neg_log_prob, curr_node = heapq.heappop(beam)
      neg_log_prob_tensor = torch.tensor(neg_log_prob)
      if curr_node.is_leaf():
          probability = torch.exp(-neg_log_prob_tensor).item()
          final_candidates.append((curr_node.symbol, probability))
          continue
      node_str = model.hsm_head.node_name_map[curr_node]
      W = model.hsm_head.node_weights[node_str]
      p_right = F.sigmoid(W(hidden_state) / temperature)
      p_left = 1.0 - p_right
      log_p_right = torch.log(p_right)
      log_p_left = torch.log(p_left)
      if curr_node.left:
          new_neg_log_prob_left = neg_log_prob - log_p_left.item()
          heapq.heappush(beam, (new_neg_log_prob_left, curr_node.left))
      if curr_node.right:
          new_neg_log_prob_right = neg_log_prob - log_p_right.item()
          heapq.heappush(beam, (new_neg_log_prob_right, curr_node.right))
  final_candidates.sort(key=lambda x: x[1], reverse=True)
  return final_candidates

def _top_k_predict(model, input_ids, hidden_state, k = 5, max_iterations = 1000, **kwargs):
  final_candidates = _top_k_generate(model, hidden_state, k, max_iterations, **kwargs)
  raw_probabilities = torch.tensor([item[1] for item in final_candidates])
  sum_probabilities = torch.sum(raw_probabilities)
  if sum_probabilities.item() == 0:
      num_candidates = len(raw_probabilities)
      distribution = (torch.ones(num_candidates) / num_candidates).squeeze()
  else:
      distribution = raw_probabilities / sum_probabilities.squeeze()
  sampled_index_tensor = torch.multinomial(distribution, num_samples=1, replacement=False)
  return final_candidates[sampled_index_tensor][0]


def _top_p_predict(model, input_ids, hidden_state, p = 0.6, max_iterations = 1000, **kwargs):
  temperature = 1.0
  if 'temperature' in kwargs:
      temperature = kwargs['temperature']
  beam = [(0.0, model.hsm_head.root)]
  curr_p = 0.0
  final_candidates = []
  iterations = 0

  while beam and curr_p < p and iterations < max_iterations:
      iterations += 1
      neg_log_prob, curr_node = heapq.heappop(beam)
      neg_log_prob_tensor = torch.tensor(neg_log_prob)
      if curr_node.is_leaf():
          probability = torch.exp(-neg_log_prob_tensor).item()
          curr_p += probability
          final_candidates.append((curr_node.symbol, probability))
          continue
      node_str = model.hsm_head.node_name_map[curr_node]
      W = model.hsm_head.node_weights[node_str]
      p_right = F.sigmoid(W(hidden_state) / temperature)
      p_left = 1.0 - p_right
      log_p_right = torch.log(p_right)
      log_p_left = torch.log(p_left)
      if curr_node.left:
          new_neg_log_prob_left = neg_log_prob - log_p_left.item()
          heapq.heappush(beam, (new_neg_log_prob_left, curr_node.left))
      if curr_node.right:
          new_neg_log_prob_right = neg_log_prob - log_p_right.item()
          heapq.heappush(beam, (new_neg_log_prob_right, curr_node.right))
  final_candidates.sort(key=lambda x: x[1], reverse=True)
  raw_probabilities = torch.tensor([item[1] for item in final_candidates])
  sum_probabilities = torch.sum(raw_probabilities)
  if sum_probabilities.item() == 0:
      num_candidates = len(raw_probabilities)
      distribution = (torch.ones(num_candidates) / num_candidates).squeeze()
  else:
      distribution = raw_probabilities / sum_probabilities.squeeze()
  sampled_index_tensor = torch.multinomial(distribution, num_samples=1, replacement=False)
  return final_candidates[sampled_index_tensor][0]

def _beam_predict(model, input_ids, hidden_state, num_beams = 5, max_iterations = 10000, beam_depth = 2, **kwargs):
  inputs = input_ids.flatten().tolist()
  beam_results = [(inputs, hidden_state, 0.0) for i in range(num_beams)]
  i = 0
  iterations = 0

  while i < beam_depth and iterations < max_iterations:
    candidates = []
    for prev_seq, prev_hidden_state, prev_probability in beam_results:
      if prev_seq and prev_seq[-1] == model.hsm_head.tokenizer.eos_token_id:
        continue
      if len(prev_seq) > len(inputs) + beam_depth:
        continue
      next_token_candidates = _top_k_generate(model, prev_hidden_state, num_beams, max_iterations, **kwargs)
      for next_token_candidate, probability in next_token_candidates:
        updated_seq = prev_seq + [next_token_candidate]
        updated_hidden_state = model.transformer(torch.tensor([updated_seq], device=input_ids.device)).last_hidden_state[0, -1, :]
        candidates.append((updated_seq, updated_hidden_state, prev_probability + torch.log(torch.tensor(probability))))
        iterations += 1
    candidates.sort(key=lambda x: x[2], reverse=True)
    beam_results = candidates[:num_beams]
    i += 1
  beam_results.sort(key=lambda x: x[2], reverse=True)
  raw_probabilities = torch.exp(torch.tensor([item[2] for item in beam_results], dtype=torch.float32))
  sum_probabilities = torch.sum(raw_probabilities)
  if sum_probabilities.item() == 0:
      num_candidates = len(raw_probabilities)
      distribution = (torch.ones(num_candidates) / num_candidates).squeeze()
  else:
      distribution = raw_probabilities / sum_probabilities.squeeze()
  sampled_index_tensor = torch.multinomial(distribution, num_samples=1, replacement=False)
  return beam_results[sampled_index_tensor][0][len(inputs):]

@torch.no_grad()
@execution_timer
def generate(model, input_ids, generation_config=None, generate_method=None, **kwargs):
  model.eval()
  generation_config = generation_config or model.generation_config
  cur_length = input_ids.shape[1]
  max_length = generation_config.max_length or cur_length + generation_config.max_new_tokens

  while cur_length < max_length:
    outputs = model.transformer(input_ids)
    hidden_state = outputs.last_hidden_state[0, -1, :]
    next_tokens = generate_method(input_ids=input_ids, hidden_state=hidden_state, **kwargs)
    if not isinstance(next_tokens, list):
      next_tokens = [next_tokens]
    if next_tokens[-1] == model.hsm_head.tokenizer.eos_token_id:
      break
    next_token_tensor = torch.tensor([next_tokens], device=input_ids.device)
    input_ids = torch.cat([input_ids, next_token_tensor], dim=-1)
    cur_length += 1

  return input_ids


In [22]:
import torch
import torch.optim as optim
from tqdm import tqdm
import math

@execution_timer
def train_llm(model, data_loader):
  model.train()
  progress_bar = tqdm(data_loader, desc=f"Training")
  for batch in progress_bar:
    input_ids = batch['input_ids'].to(device)
    labels = batch['labels'].to(device)
    optimizer.zero_grad()
    outputs = model(input_ids=input_ids, labels=labels)
    loss = outputs["loss"]
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    progress_bar.set_postfix({'loss': loss.item()})

@execution_timer
def test_llm(model, data_loader) -> dict:
    model.eval()
    total_eval_loss = 0.0
    with torch.no_grad():
        progress_bar = tqdm(data_loader, desc="Testing (Evaluation)")
        for batch in progress_bar:
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids=input_ids, labels=labels)
            loss = outputs["loss"]
            total_eval_loss += loss.item()
            progress_bar.set_postfix({'batch_loss': f'{loss.item():.4f}'})
    avg_loss = total_eval_loss / len(data_loader)
    try:
        perplexity = math.exp(avg_loss)
    except OverflowError:
        perplexity = float('inf') # Handle cases where loss is too high

    print(f"\n--- Evaluation Complete ---")
    print(f"Average Loss (NLL): {avg_loss:.4f}")
    print(f"Perplexity (PPL): {perplexity:.2f}")

    return {"avg_loss": avg_loss, "perplexity": perplexity}

In [None]:
config = GPT2Config.from_pretrained("gpt2")
hsm_model = GPT2HierarchicalSoftmaxModel(config, root, tokenizer, 768)

Using hidden size: 768
HSM initialized with 50256 internal nodes


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hsm_model.to(device)
optimizer = optim.AdamW(hsm_model.parameters(), lr=5e-5)

for epoch in range(5):
  print(f"\n--- Epoch {epoch+1} ---")
  train_llm(hsm_model, train_dataloader)


--- Epoch 1 ---


Training: 100%|██████████| 627/627 [1:05:48<00:00,  6.30s/it, loss=6.9]


Function train_llm executed in 3948.4051 seconds

--- Epoch 2 ---


Training: 100%|██████████| 627/627 [1:04:45<00:00,  6.20s/it, loss=6.76]


Function train_llm executed in 3886.0910 seconds

--- Epoch 3 ---


Training: 100%|██████████| 627/627 [1:04:07<00:00,  6.14s/it, loss=6.11]


Function train_llm executed in 3847.8026 seconds

--- Epoch 4 ---


Training: 100%|██████████| 627/627 [1:04:20<00:00,  6.16s/it, loss=5.78]


Function train_llm executed in 3860.2947 seconds

--- Epoch 5 ---


Training: 100%|██████████| 627/627 [1:04:55<00:00,  6.21s/it, loss=5.33]


Function train_llm executed in 3895.7355 seconds


In [None]:
torch.save(hsm_model.state_dict(), "/content/drive/MyDrive/GPT2HSMModel.pt")

In [34]:
config = GPT2Config.from_pretrained("gpt2")
hsm_model_loaded = GPT2HierarchicalSoftmaxModel(config, root, tokenizer, 768)
hsm_model_loaded.load_state_dict(torch.load("/content/drive/MyDrive/GPT2HSMModel.pt", map_location=torch.device('cpu')))
hsm_model_loaded.eval()

Using hidden size: 768
HSM initialized with 50256 internal nodes


GPT2HierarchicalSoftmaxModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (hsm_head): HierarchicalSoftmaxNodeLanguageGeneration(
    (node_weights): ModuleDict(
      (0): Linear(in_f

In [39]:
from functools import partial

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hsm_model_loaded.to(device)

print(f"\n--- Testing Generation ---")
input_ids = tokenizer.encode('preference for grassland rather', return_tensors='pt').to(device)
print(f"Input shape: {input_ids.shape}")
greedy_output=generate(hsm_model_loaded, input_ids, generation_config=GenerationConfig(max_new_tokens=5), generate_method=partial(_greedy_predict, model=hsm_model_loaded))
generated_text = tokenizer.decode(greedy_output[0], skip_special_tokens=True)
print(f"Generated text: {generated_text}")


--- Testing Generation ---
Input shape: torch.Size([1, 6])
Function generate executed in 0.3147 seconds
Generated text: preference for grassland rather than the city . The French was to be the city 's character


In [40]:
from functools import partial

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hsm_model_loaded.to(device)

for temperature in [0.1, 1, 2, 5, 100]:
  print(f"\n--- Testing Generation ---")
  input_ids = tokenizer.encode('preference for grassland rather', return_tensors='pt').to(device)
  print(f"Input shape: {input_ids.shape}")
  greedy_output=generate(hsm_model_loaded, input_ids, generation_config=GenerationConfig(max_new_tokens=5), generate_method=partial(_greedy_predict, model=hsm_model_loaded, temperature=temperature))
  generated_text = tokenizer.decode(greedy_output[0], skip_special_tokens=True)
  print(f"Generated text: {generated_text}")


--- Testing Generation ---
Input shape: torch.Size([1, 6])
Function generate executed in 0.4136 seconds
Generated text: preference for grassland rather than the city . The French was to be the city 's character

--- Testing Generation ---
Input shape: torch.Size([1, 6])
Function generate executed in 0.3001 seconds
Generated text: preference for grassland rather than the city . The French was to be the city 's character

--- Testing Generation ---
Input shape: torch.Size([1, 6])
Function generate executed in 0.3170 seconds
Generated text: preference for grassland rather than the city . The French was to be the city 's character

--- Testing Generation ---
Input shape: torch.Size([1, 6])
Function generate executed in 0.3027 seconds
Generated text: preference for grassland rather than the city . The French was to be the city 's character

--- Testing Generation ---
Input shape: torch.Size([1, 6])
Function generate executed in 0.2922 seconds
Generated text: preference for grassland rather

In [41]:
from functools import partial

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hsm_model_loaded.to(device)

print(f"\n--- Testing Generation ---")
input_ids = tokenizer.encode('preference for grassland rather', return_tensors='pt').to(device)
print(f"Input shape: {input_ids.shape}")
greedy_output=generate(hsm_model_loaded, input_ids, generation_config=GenerationConfig(max_new_tokens=5), generate_method=partial(_top_k_predict, model=hsm_model_loaded))
generated_text = tokenizer.decode(greedy_output[0], skip_special_tokens=True)
print(f"Generated text: {generated_text}")


--- Testing Generation ---
Input shape: torch.Size([1, 6])
Function generate executed in 0.6247 seconds
Generated text: preference for grassland rather than that it is not a " . 
 The first of the


In [14]:
from functools import partial

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hsm_model_loaded.to(device)

for temperature in [0.1, 1, 2]:
  print(f"\n--- Testing Generation ---")
  input_ids = tokenizer.encode('preference for grassland rather', return_tensors='pt').to(device)
  print(f"Input shape: {input_ids.shape}")
  greedy_output=generate(hsm_model_loaded, input_ids, generation_config=GenerationConfig(max_new_tokens=5), generate_method=partial(_top_k_predict, model=hsm_model_loaded, temperature=temperature, max_iterations=10000))
  generated_text = tokenizer.decode(greedy_output[0], skip_special_tokens=True)
  print(f"Generated text: {generated_text}")


--- Testing Generation ---
Input shape: torch.Size([1, 6])
Function generate executed in 0.4234 seconds
Generated text: preference for grassland rather than the city . The ship was to the game , and the North

--- Testing Generation ---
Input shape: torch.Size([1, 6])
Function generate executed in 0.6962 seconds
Generated text: preference for grassland rather than the most of the " . In a " The album 's

--- Testing Generation ---
Input shape: torch.Size([1, 6])
Function generate executed in 1.3960 seconds
Generated text: preference for grassland rather anchored to be the first two @-@ day in his first @


In [15]:
from functools import partial

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hsm_model_loaded.to(device)

print(f"\n--- Testing Generation ---")
input_ids = tokenizer.encode('preference for grassland rather', return_tensors='pt').to(device)
print(f"Input shape: {input_ids.shape}")
greedy_output=generate(hsm_model_loaded, input_ids, generation_config=GenerationConfig(max_new_tokens=5), generate_method=partial(_top_p_predict, model=hsm_model_loaded, p=0.9))
generated_text = tokenizer.decode(greedy_output[0], skip_special_tokens=True)
print(f"Generated text: {generated_text}")


--- Testing Generation ---
Input shape: torch.Size([1, 6])
Function generate executed in 2.9338 seconds
Generated text: preference for grassland rather metallic Canadian @-@ off film 's main end was built as


In [16]:
from functools import partial

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hsm_model_loaded.to(device)

for temperature in [0.1, 1, 2]:
  print(f"\n--- Testing Generation ---")
  input_ids = tokenizer.encode('preference for grassland rather', return_tensors='pt').to(device)
  print(f"Input shape: {input_ids.shape}")
  greedy_output=generate(hsm_model_loaded, input_ids, generation_config=GenerationConfig(max_new_tokens=5), generate_method=partial(_top_p_predict, model=hsm_model_loaded, temperature=temperature, max_iterations=10000, p=0.9))
  generated_text = tokenizer.decode(greedy_output[0], skip_special_tokens=True)
  print(f"Generated text: {generated_text}")


--- Testing Generation ---
Input shape: torch.Size([1, 6])
Function generate executed in 0.3733 seconds
Generated text: preference for grassland rather than one in the city . The ship was to be the city '

--- Testing Generation ---
Input shape: torch.Size([1, 6])
Function generate executed in 20.9812 seconds
Generated text: preference for grassland rather Fei and their use of general ticicry south . 


--- Testing Generation ---
Input shape: torch.Size([1, 6])
Function generate executed in 37.6493 seconds
Generated text: preference for grassland rather enjoying October at the low . It was he playedortion of the close


In [17]:
from functools import partial

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hsm_model_loaded.to(device)

print(f"\n--- Testing Generation ---")
input_ids = tokenizer.encode('preference for grassland rather', return_tensors='pt').to(device)
print(f"Input shape: {input_ids.shape}")
greedy_output=generate(hsm_model_loaded, input_ids, generation_config=GenerationConfig(max_new_tokens=5), generate_method=partial(_beam_predict, model=hsm_model_loaded, num_beams=5, beam_depth=1, max_iterations=1000))
generated_text = tokenizer.decode(greedy_output[0], skip_special_tokens=True)
print(f"Generated text: {generated_text}")



--- Testing Generation ---
Input shape: torch.Size([1, 6])
Function generate executed in 7.2243 seconds
Generated text: preference for grassland rather than the " . 
 = = = = = = = 


In [18]:
from functools import partial

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hsm_model_loaded.to(device)

for temperature in [0.1, 1, 2]:
  print(f"\n--- Testing Generation ---")
  input_ids = tokenizer.encode('preference for grassland rather', return_tensors='pt').to(device)
  print(f"Input shape: {input_ids.shape}")
  greedy_output=generate(hsm_model_loaded, input_ids, generation_config=GenerationConfig(max_new_tokens=1), generate_method=partial(_beam_predict, model=hsm_model_loaded, temperature=temperature, num_beams=5, beam_depth=1, max_iterations=1000))
  generated_text = tokenizer.decode(greedy_output[0], skip_special_tokens=True)
  print(f"Generated text: {generated_text}")


--- Testing Generation ---
Input shape: torch.Size([1, 6])
Function generate executed in 4.7851 seconds
Generated text: preference for grassland rather than the city . The city is the city 's case , and

--- Testing Generation ---
Input shape: torch.Size([1, 6])
Function generate executed in 9.1089 seconds
Generated text: preference for grassland rather than the " . 
 = = = = = = = 

--- Testing Generation ---
Input shape: torch.Size([1, 6])
Function generate executed in 13.2353 seconds
Generated text: preference for grassland rather than the " . 
 = = = = = = 



In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Model, GPT2Config, GenerationConfig
import numpy as np

def ce_loss(logits, labels) -> float:
    C = logits.max().item()
    shifted_logits = logits - C
    exp_shifted_logits = torch.exp(shifted_logits)
    sum_exp_shifted_logits = torch.sum(exp_shifted_logits, dim=1, keepdim=True)
    log_probabilities = shifted_logits - torch.log(sum_exp_shifted_logits)
    batch_indices = torch.arange(labels.shape[0])
    log_p_true_class = log_probabilities[batch_indices, labels]
    mean_loss = -torch.mean(log_p_true_class)
    return mean_loss

class GPT2RegularSoftmaxModel(nn.Module):
    def __init__(self, config, vocab_size):
        super().__init__()
        self.transformer = GPT2Model(config)
        self.generation_config = GenerationConfig()
        self.lm_head = nn.Linear(config.n_embd, vocab_size, bias=False)
        self.vocab_size = vocab_size


    def forward(self, input_ids, attention_mask=None, labels=None):
        outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        hidden_states = outputs.last_hidden_state

        if labels is not None:
            shifted_hidden_states = hidden_states[:, :-1, :].contiguous()
            shifted_labels = labels[:, 1:].contiguous()
            logits = self.lm_head(shifted_hidden_states)
            logits = logits.view(-1, self.vocab_size)
            labels = shifted_labels.view(-1)
            max_label = labels.max().item()
            if max_label >= self.vocab_size:
                print(f"ERROR: Label ID ({max_label}) is >= vocab_size ({self.vocab_size})!")
                raise ValueError("Labels contain Out-of-Vocabulary token IDs.")
            loss = ce_loss(logits, labels)
            return {"loss": loss, "logits": logits, "hidden_states": shifted_hidden_states}

        return {"hidden_states": hidden_states}

    def _greedy_predict(self, hidden_state):
        h_i = hidden_state.unsqueeze(0)
        logits = self.lm_head(h_i)
        next_token = torch.argmax(logits, dim=-1).item()
        return next_token


    @torch.no_grad()
    @execution_timer
    def generate(self, input_ids, generation_config=None, **kwargs):
        self.eval()
        tokenizer = kwargs.get('tokenizer')
        if tokenizer is None:
             raise ValueError("Tokenizer must be provided in kwargs for generation.")

        eos_token_id = tokenizer.eos_token_id

        generation_config = generation_config or self.generation_config
        cur_length = input_ids.shape[1]
        max_length = generation_config.max_length or cur_length + generation_config.max_new_tokens

        if input_ids.shape[0] != 1:
            raise ValueError("This generation implementation supports only batch size 1.")

        while cur_length < max_length:
            outputs = self.transformer(input_ids)
            hidden_state = outputs.last_hidden_state[0, -1, :] # [E]
            next_token = self._greedy_predict(hidden_state)
            if next_token == eos_token_id:
                break
            next_token_tensor = torch.tensor([[next_token]], device=input_ids.device)
            input_ids = torch.cat([input_ids, next_token_tensor], dim=-1)
            cur_length += 1
        return input_ids

In [24]:
config = GPT2Config.from_pretrained("gpt2")
sm_model = GPT2RegularSoftmaxModel(config, tokenizer.vocab_size)

In [27]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sm_model.to(device)
optimizer = optim.AdamW(sm_model.parameters(), lr=2e-5)

for epoch in range(5):
  print(f"\n---Epoch {epoch + 1}---")
  train_llm(sm_model, train_dataloader)


---Epoch 1---


Training: 100%|██████████| 627/627 [03:10<00:00,  3.29it/s, loss=6.22]


Function train_llm executed in 190.3996 seconds

---Epoch 2---


Training: 100%|██████████| 627/627 [03:11<00:00,  3.28it/s, loss=6.02]


Function train_llm executed in 191.2819 seconds

---Epoch 3---


Training: 100%|██████████| 627/627 [03:11<00:00,  3.28it/s, loss=5.67]


Function train_llm executed in 191.3247 seconds

---Epoch 4---


Training: 100%|██████████| 627/627 [03:11<00:00,  3.28it/s, loss=6.27]


Function train_llm executed in 191.1678 seconds

---Epoch 5---


Training: 100%|██████████| 627/627 [03:12<00:00,  3.26it/s, loss=5.46]

Function train_llm executed in 192.0879 seconds





In [28]:
torch.save(sm_model.state_dict(), "/content/drive/MyDrive/GPT2SMModel.pt")

In [30]:
config = GPT2Config.from_pretrained("gpt2")
sm_model_loaded = GPT2RegularSoftmaxModel(config, tokenizer.vocab_size)
sm_model_loaded.load_state_dict(torch.load("/content/drive/MyDrive/GPT2SMModel.pt", map_location=torch.device('cpu')))
sm_model_loaded.eval()

GPT2RegularSoftmaxModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [32]:
from functools import partial

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sm_model_loaded.to(device)

print(f"\n--- Testing Generation ---")
input_ids = tokenizer.encode('He also played for the Fairfield', return_tensors='pt').to(device)
print(f"Input shape: {input_ids.shape}")
greedy_output=sm_model.generate(input_ids, generation_config=GenerationConfig(max_new_tokens=5), tokenizer=tokenizer)
generated_text = tokenizer.decode(greedy_output[0], skip_special_tokens=True)
print(f"Generated text: {generated_text}")


--- Testing Generation ---
Input shape: torch.Size([1, 7])
Function generate executed in 0.1906 seconds
Generated text: He also played for the Fairfield . The first season , the first season , the first season ,
