In [1]:
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from datasets import load_dataset
from torch.optim import Adam
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from transformers.modeling_outputs import CausalLMOutputWithPast
# Reached only if every import above succeeded.
print('All libraries imported.')

All libraries imported.


In [2]:
# === ALL CLASS DEFINITIONS ARE NOW INSIDE THE NOTEBOOK ===

class VSA:
    """Minimal vector-symbolic-architecture helper built on FFTs.

    Binding is circular convolution and unbinding is circular correlation.
    Both are computed in the frequency domain and returned as real arrays
    (the real part of the inverse FFT).
    """

    def __init__(self, dim: int):
        # Nominal vector dimensionality; stored only, not checked by bind/unbind.
        self.dim = dim

    def bind(self, vec1: np.ndarray, vec2: np.ndarray) -> np.ndarray:
        """Return the circular convolution of ``vec1`` and ``vec2``."""
        spectrum = np.fft.fft(vec1) * np.fft.fft(vec2)
        return np.fft.ifft(spectrum).real

    def unbind(self, vec1: np.ndarray, vec2: np.ndarray) -> np.ndarray:
        """Return the circular correlation of ``vec1`` with ``vec2``.

        For ``vec2 = bind(vec1, b)`` this yields ``b`` scaled per frequency by
        ``|fft(vec1)|**2``, so recovery is exact only when ``vec1`` has a
        unit-magnitude spectrum (e.g. a shifted delta) and approximate otherwise.
        """
        spectrum = np.conj(np.fft.fft(vec1)) * np.fft.fft(vec2)
        return np.fft.ifft(spectrum).real

class HGC_Layer(nn.Module):
    """Project hidden states to ``hkm_dim`` and mix them with a learned memory row.

    ``hkm`` is a single learned vector of shape ``(1, hkm_dim)``. Each projected
    query is reduced to its dot product with ``hkm``. That score then rescales
    ``hkm`` before a ``tanh``:
    ``tanh((query_projection(h) @ hkm.T) @ hkm)``.

    NOTE(review): ``hkm`` starts from ``torch.randn`` and is never normalised, so
    the pre-tanh values scale with ``||hkm||**2`` and may saturate — confirm intended.
    """

    def __init__(self, hidden_size, hkm_dim=2048):
        super().__init__()
        self.hkm_dim = hkm_dim
        # Creation order (projection first, then memory) keeps the RNG draws,
        # and hence seeded initialisation, identical.
        self.query_projection = nn.Linear(hidden_size, hkm_dim)
        self.hkm = nn.Parameter(torch.randn(1, hkm_dim))

    def forward(self, hidden_states):
        """Return ``tanh((query_projection(h) @ hkm.T) @ hkm)``; the last dim becomes ``hkm_dim``."""
        memory = self.hkm
        queries = self.query_projection(hidden_states)
        # (..., hkm_dim) @ (hkm_dim, 1) -> one score per position
        scores = queries @ memory.T
        # (..., 1) @ (1, hkm_dim) -> memory vector scaled by each score
        return torch.tanh(scores @ memory)

class GPT2_With_HGC(GPT2LMHeadModel):
    """GPT-2 language model whose final hidden states are augmented by an HGC branch.

    The HGC branch is ``HGC_Layer`` followed by a linear projection back to
    ``config.n_embd``. It is added residually to the transformer's last hidden
    state before ``lm_head`` produces the logits.
    """

    def __init__(self, config):
        super().__init__(config)
        self.hgc_layer = HGC_Layer(hidden_size=config.n_embd, hkm_dim=2048)
        self.hgc_output_projection = nn.Linear(self.hgc_layer.hkm_dim, config.n_embd)
        # NOTE(review): these modules are created after super().__init__ returns,
        # so they presumably keep PyTorch's default init rather than GPT-2's.
        # Because they are added residually, the pretrained hidden states are
        # perturbed from step 0 — confirm this is intended.

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Run GPT-2, add the HGC residual, and compute LM logits (and loss).

        Args:
            input_ids: token ids forwarded to ``self.transformer``.
            attention_mask: attention mask forwarded to ``self.transformer``.
            labels: optional target ids aligned with ``input_ids``. They are
                shifted by one inside; positions equal to -100 are ignored
                (``nn.CrossEntropyLoss`` default ``ignore_index``).
            **kwargs: forwarded to ``self.transformer`` (e.g. ``past_key_values``,
                ``position_ids``, ``output_hidden_states``). ``return_dict`` and
                ``num_items_in_batch`` are consumed here and not forwarded.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss`` (None when ``labels`` is None),
            ``logits``, ``past_key_values``, ``hidden_states`` and ``attentions``.
        """
        # Attribute access on the inner output below requires a ModelOutput, so a
        # caller-supplied return_dict=False must not reach the transformer.
        kwargs.pop("return_dict", None)
        # Loss-only kwarg that some Trainer versions inject; it is not a transformer
        # input. NOTE(review): it is not used to normalise the loss below, so loss
        # scaling under gradient accumulation > 1 may differ on those versions.
        kwargs.pop("num_items_in_batch", None)

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            **kwargs,
        )
        hidden_states = transformer_outputs.last_hidden_state

        # HGC branch, projected back to n_embd and added residually.
        hgc_info = self.hgc_layer(hidden_states)
        conditioned_hidden_states = hidden_states + self.hgc_output_projection(hgc_info)

        lm_logits = self.lm_head(conditioned_hidden_states)

        loss = None
        if labels is not None:
            # Shift so that the logits at position t are scored against token t+1.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

# Reached only if the class definitions above executed without error.
print('Custom HGC classes defined successfully.')

Custom HGC classes defined successfully.


In [3]:
# --- Data Loading and Preparation ---
# GPT-2's tokenizer has no pad token by default; reuse EOS so fixed-length padding works.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

# Only the 'validation' split is used; it is divided 80/20 (seeded) into train/test.
dataset = load_dataset('truthful_qa', 'generation')
source_split = dataset['validation']
train_test = source_split.train_test_split(test_size=0.2, seed=42)
def format_and_tokenize(examples):
    """Render each (question, best_answer) pair as a Q/A prompt and tokenize the batch.

    Used with ``Dataset.map(..., batched=True)``, so ``examples`` maps column
    names to lists. Each sequence is truncated/padded to exactly 128 tokens.
    """
    formatted_texts = []
    for question, answer in zip(examples['question'], examples['best_answer']):
        formatted_texts.append(f"Question: {question}\nAnswer: {answer}")
    return tokenizer(formatted_texts, truncation=True, padding='max_length', max_length=128)
# Tokenize both splits, dropping the raw text columns so only tokenizer outputs remain.
tokenized_train, tokenized_test = (
    train_test[split_name].map(
        format_and_tokenize,
        batched=True,
        remove_columns=train_test[split_name].column_names,
    )
    for split_name in ('train', 'test')
)
# Return torch tensors when indexed.
tokenized_train.set_format(type='torch')
tokenized_test.set_format(type='torch')
print('Data prepared.')

Data prepared.


In [4]:
# --- Build the HGC model on top of pre-trained GPT-2 weights ---
SEED = 42

print('Loading pre-trained GPT-2 model...')
pretrained_model = GPT2LMHeadModel.from_pretrained('gpt2')

print('Initializing custom HGC model architecture...')
config = pretrained_model.config
# The HGC layers are randomly initialised right here. The Trainer (built further
# down) only applies TrainingArguments.seed later, so seed now to make re-runs
# start from the same HGC weights.
torch.manual_seed(SEED)
hgc_model = GPT2_With_HGC(config)

print('Transferring weights to HGC model...')
# strict=False is required because the HGC modules have no pre-trained
# counterpart. Check that they are the *only* mismatch instead of silently
# ignoring every missing/unexpected key.
load_result = hgc_model.load_state_dict(pretrained_model.state_dict(), strict=False)
assert not load_result.unexpected_keys, f'Unexpected keys in pre-trained weights: {load_result.unexpected_keys}'
non_hgc_missing = [key for key in load_result.missing_keys if not key.startswith('hgc_')]
assert not non_hgc_missing, f'No pre-trained weights for: {non_hgc_missing}'
print('Weight transfer complete.')

# --- Setup Trainer ---
# mlm=False selects causal-LM collation (labels built from input_ids).
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
# NOTE(review): absolute, machine-specific paths; consider a configurable root dir.
OUTPUT_DIR = 'C:/HGC/models/hgc_gpt2_truthfulqa'
RESULTS_CSV = Path('C:/HGC/data/hgc_results.csv')

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    eval_strategy='epoch',
    save_strategy='epoch',
    logging_steps=100,
    load_best_model_at_end=True,
    # Pass dataset columns through unfiltered; after tokenisation they are
    # presumably only the tokenizer outputs (input_ids, attention_mask).
    remove_unused_columns=False,
    seed=SEED,
)

trainer = Trainer(
    model=hgc_model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    data_collator=data_collator,
)

# --- Train and Evaluate ---
print('Starting fine-tuning of HGC model...')
trainer.train()

print('\nEvaluating the HGC model...')
eval_results = trainer.evaluate()
eval_loss = eval_results['eval_loss']

# NOTE(review): eval_loss is presumably Trainer's average of per-batch mean losses,
# so exp() of it only approximates token-level perplexity when batches hold
# different numbers of real (non -100) tokens.
perplexity = np.exp(eval_loss)
print(f'\nHGC Model Perplexity: {perplexity:.2f}')

print(f'Saving HGC model to {OUTPUT_DIR}...')
trainer.save_model(OUTPUT_DIR)
results_df = pd.DataFrame([{'model': 'HGC-Augmented GPT-2', 'perplexity': perplexity, 'eval_loss': eval_loss}])
# to_csv does not create missing parent directories.
RESULTS_CSV.parent.mkdir(parents=True, exist_ok=True)
results_df.to_csv(RESULTS_CSV, index=False)
print('\nHGC model fine-tuning complete.')

Loading pre-trained GPT-2 model...
Initializing custom HGC model architecture...
Transferring weights to HGC model...
Weight transfer complete.
Starting fine-tuning of HGC model...


Epoch,Training Loss,Validation Loss
1,2.1933,2.083067


There were missing keys in the checkpoint model loaded: ['lm_head.weight'].



Evaluating the HGC model...



HGC Model Perplexity: 8.03
Saving HGC model to C:/HGC/models/hgc_gpt2_truthfulqa...

HGC model fine-tuning complete.
