In [1]:
import numpy as np
import math
import scipy

from matplotlib import pyplot as plt
import matplotlib.cm as cm
plt.rcParams['figure.figsize'] = [10, 10]
plt.rc('font', size=20)

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

import transformers
# import datasets
from datasets import load_dataset, DatasetDict, Dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
# from torch.utils.data import Dataset

from functools import partial

# from collections import Counter

import argparse
import time
from tqdm import tqdm # Loading bar
print('Done.')

import os
import pandas as pd
import string
import re
import unicodedata
import random

Done.


In [2]:
from utils import complex_conj_transpose, batched_complex_conj_transpose, complex_exp, complex_exp_v2, complex_hadamard, complex_matmul, complex_division
from utils import batched_complex_conj_transpose, batched_complex_hadamard, batched_complex_matmul, batched_complex_division
from utils import batched_complex_exp, batched_complex_hadamard_full, batched_complex_matmul_full
print('Done.')

Done.


In [3]:
from dynamics import stochastic_LTI, DynamicSim
from dynamics import construct_mapping
from dynamics import get_nth_measurement, get_random_measurements
from dynamics import linear_spiral, linear_spiral_3D, Lorenz, rand_coupling_matrix, Van_der_Pol_osc
print('Done.')

Done.


In [4]:
from precision_attention import compute_residuals, compute_kernel_v1, compute_estimates_and_residuals_vectorized, get_time_diffs, compute_neg_kernel, clamp_exponent_arg
from precision_attention import compute_kernel, batched_compute_estimates_and_residuals_vectorized, compute_estimates_and_residuals_irregular_times, compute_nu
from precision_attention import compute_precision_v1
# from precision_attention import precise_attn, precise_attn_with_correction, precise_attn_full
from precision_attention import compute_precision, compute_precision_tanh
print('Done.')

Done.


In [5]:
from model import compute_lambda_h
from model import init_complex_matrix, build_nearly_identity, initialize_to_correct_model
from model import init_weight_masks, apply_weight_masks
from model import Complex_MSE_Loss, Batched_Complex_MSE_Loss, inverse_penalty
from model import BatchedPrecisionAttentionBlock
from model import HadamardLayer, TemporalNorm, TemporalWhiteningLayer
from model import PrecisionNet_1layer, PrecisionNet
print('Done.')

Done.


In [6]:
parser = argparse.ArgumentParser('DA')
parser.add_argument('--gpu', type=int, default=0) # (Default: 0)
args = parser.parse_args(args=[])
args.device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')
print(args.device)

torch.manual_seed(2025)
np.random.seed(2025)

cuda:0


### Load in Wikitext data

In [7]:
# Visualize data

local_data_dir = r"C:\Users\Pracioppo\Desktop\Peter DynAttn Proj\data\wikitext"

# Replace 'your_file.parquet' with the actual path to your Parquet file

try:
    # Load the Parquet file into a Pandas DataFrame
    df = pd.read_parquet(local_data_dir)

    print("--- DataFrame Head (First 5 rows) ---")
    print(df.head())

    print("\n--- DataFrame Info (Columns, Non-Null Counts, Dtypes) ---")
    df.info()

    print("\n--- DataFrame Description (Statistical Summary) ---")
    print(df.describe()) # For numerical columns

    print(f"\n--- DataFrame Shape (Rows, Columns) ---")
    print(f"Shape: {df.shape}")

    print(f"\n--- DataFrame Columns ---")
    print(df.columns.tolist())

except FileNotFoundError:
    print(f"Error: The file '{file_path}' was not found.")
except Exception as e:
    print(f"An error occurred: {e}")

--- DataFrame Head (First 5 rows) ---
                                                text
0                                                   
1                              = Robert Boulter = \n
2                                                   
3   Robert Boulter is an English film , televisio...
4   In 2006 , Boulter starred alongside Whishaw i...

--- DataFrame Info (Columns, Non-Null Counts, Dtypes) ---
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 44836 entries, 0 to 44835
Data columns (total 1 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   text    44836 non-null  object
dtypes: object(1)
memory usage: 350.4+ KB

--- DataFrame Description (Statistical Summary) ---
         text
count   44836
unique  26538
top          
freq    15717

--- DataFrame Shape (Rows, Columns) ---
Shape: (44836, 1)

--- DataFrame Columns ---
['text']


In [12]:
# --- Define the local path where you saved the files ---
# IMPORTANT: Replace this with the actual path on your computer!
local_data_dir = r"C:\Users\Pracioppo\Desktop\Peter DynAttn Proj\data\wikitext"

# Define the full paths to your parquet files
train_file = f"{local_data_dir}/train-00000-of-00001.parquet"
validation_file = f"{local_data_dir}/validation-00000-of-00001.parquet"
test_file = f"{local_data_dir}/test-00000-of-00001.parquet"

try:
    print("Attempting to load datasets via Pandas workaround...")

    # Load each parquet file into a Pandas DataFrame
    train_df = pd.read_parquet(train_file)
    validation_df = pd.read_parquet(validation_file)
    test_df = pd.read_parquet(test_file)

    print("Parquet files loaded into Pandas DataFrames successfully.")

    # Convert Pandas DataFrames to Hugging Face Dataset objects
    train_dataset = Dataset.from_pandas(train_df)
    validation_dataset = Dataset.from_pandas(validation_df)
    test_dataset = Dataset.from_pandas(test_df)

    print("Pandas DataFrames converted to Hugging Face Dataset objects.")

    # Create a DatasetDict from these Dataset objects
    raw_wiki_dataset = DatasetDict({
        'train': train_dataset,
        'validation': validation_dataset,
        'test': test_dataset
    })

    print("\nDataset loaded successfully via Pandas workaround!")
    print(raw_wiki_dataset)

except FileNotFoundError:
    print(f"Error: One or more local files not found in '{local_data_dir}'.")
    print("Please ensure the 'local_data_dir' path is correct and the files (train-*.parquet, validation-*.parquet, test-*.parquet) are present.")
except Exception as e:
    print(f"An unexpected error occurred during Pandas loading: {e}")
    print("Please ensure you have pandas and pyarrow installed:")
    print("pip install pandas pyarrow")

Attempting to load datasets via Pandas workaround...
Parquet files loaded into Pandas DataFrames successfully.
Pandas DataFrames converted to Hugging Face Dataset objects.

Dataset loaded successfully via Pandas workaround!
DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 36718
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 3760
    })
    test: Dataset({
        features: ['text'],
        num_rows: 4358
    })
})


In [13]:
def clean_text(examples):
    cleaned_texts = []
    allowed_chars_pattern = r"[a-zA-Z0-9\s.,!?'\"-]"
#     allowed_chars_pattern = r"[\x20-\x7E\t\n\r]"
    for text in examples["text"]:
        # Strip leading/trailing whitespace (including newlines)
        stripped_text = text.strip()

        # Filter out empty strings and Wikipedia-style headings
        if not stripped_text or \
           (stripped_text.startswith('=') and stripped_text.endswith('=') and len(stripped_text) < 50) or \
           len(stripped_text) < 50: # Filter very short lines that are likely not useful sentences
            continue # Skip this example if it matches criteria

        # Remove common Wikitext artifacts like '@-@', '= =' using string.replace()
        cleaned_text = stripped_text.replace("@", " ") # Replace " @ " with a space
        cleaned_text = cleaned_text.replace(" = = ", " ")  # Replace " = = " with a space

        # Remove non standard chars
#         cleaned_text = unicodedata.normalize('NFKC', cleaned_text)
        cleaned_text = re.sub(r'[^' + allowed_chars_pattern + ']', '', cleaned_text)
        cleaned_text = re.sub(r'\(.*?\)', '', cleaned_text)
        
        # Remove words that are entirely in uppercase
        words = cleaned_text.split() # Split the current cleaned_text into words
        filtered_words = [
            word for word in words 
            if not (word.strip(string.punctuation).isalpha() and word.strip(string.punctuation).isupper())
        ]
        cleaned_text = ' '.join(filtered_words)
        
        # Normalize spaces (replace multiple spaces with a single space)
        cleaned_text = ' '.join(cleaned_text.split())
        
        # Remove spaces before punctuation
        cleaned_text = re.sub(r'\s+(' + r"[.,!?;:\"\'%]" + ')', r'\1', cleaned_text)

        if cleaned_text: # Ensure it's not empty after all cleaning steps
            cleaned_texts.append(cleaned_text)

    return {"text": cleaned_texts}

In [14]:
wiki_dataset = raw_wiki_dataset.map(
    clean_text,
    batched=True,
    remove_columns=["text"], # The 'text' column will be replaced by the cleaned one
)

  0%|          | 0/37 [00:00<?, ?ba/s]

  0%|          | 0/4 [00:00<?, ?ba/s]

  0%|          | 0/5 [00:00<?, ?ba/s]

In [17]:
print(f"Original train split size: {len(raw_wiki_dataset['train'])}")
print(f"Cleaned train split size: {len(wiki_dataset['train'])}")
print(f"Example cleaned text from train split (first 500 chars):")
# print(cleaned_dataset['train'][0]['text'][:500])

i = np.random.choice(len(wiki_dataset['train']))
print(i)
print(wiki_dataset['train'][i]['text'][:500])

len(wiki_dataset['train'][i]['text'])

Original train split size: 36718
Cleaned train split size: 16214
Example cleaned text from train split (first 500 chars):
6718


618

### Alternatively, load in from txt files

In [18]:
txt_dir = r"C:\Users\Pracioppo\Desktop\Peter DynAttn Proj\data\poets"

names = ["shakespeare", "shelley", "milton"]
all_cleaned_text = []

for name in names:

    file_path = txt_dir + "\\" + name + ".txt"

    # 1. Load the text file
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            full_text = f.read()
    except FileNotFoundError:
        print(f"Error: The file '{file_path}' was not found.")
    except Exception as e:
        print(f"An error occurred while reading the file: {e}")

    data_for_cleaning = {"text": [full_text]}

    cleaned_result = clean_text(data_for_cleaning)

    all_cleaned_text.extend(cleaned_result["text"])
    
combined_text = " ".join(all_cleaned_text) 

In [19]:
def split_segments_by_nth_period(text, n=3):
    """
    Splits a given text into segments, where each segment ends after n periods.
    """
    segments = []
    current_start_index = 0
    period_count = 0

    for i, char in enumerate(text):
        if char == '.':
            period_count += 1
        
        # If we've hit 'n' periods, or if we're at the very end of the text
        if period_count == n or (i == len(text) - 1 and current_start_index < len(text)):
            segment = text[current_start_index : i + 1].strip()
            if segment: # Only add non-empty segments
                segments.append(segment)
            current_start_index = i + 1
            period_count = 0 # Reset count for the next segment
    
    # Handle any remaining text if the last segment didn't end with 'n' periods
    # and wasn't caught by the (i == len(text) - 1) condition
    if current_start_index < len(text):
        remaining_text = text[current_start_index:].strip()
        if remaining_text:
            segments.append(remaining_text)
            
    return segments

segments = split_segments_by_nth_period(combined_text)
print(f"Number of segments: {len(segments)}")

random.shuffle(segments) # This shuffles the list in-place

Number of segments: 22863


In [20]:
pre_split_dataset = Dataset.from_dict({"text": segments})

train_test_split = pre_split_dataset.train_test_split(test_size=0.2, seed=42) # Using a seed for reproducibility

# Access the training and initial test sets
train_dataset = train_test_split['train']
test_val_dataset = train_test_split['test'] # This will be split further into test and validation

# Now, split the test_val_dataset into actual test and validation sets
# We'll split the remaining 20% in half, so 10% for test and 10% for validation of the overall data
test_val_split = test_val_dataset.train_test_split(test_size=0.5, seed=42) 

test_dataset = test_val_split['train'] # This becomes the true test set
val_dataset = test_val_split['test']   # This becomes the validation set

In [21]:
poets_dataset = DatasetDict({
    'train': train_dataset,
    'validation': val_dataset, # Use 'validation' as the standard key for validation set
    'test': test_dataset
})

In [25]:
print('WIKI DATASET:')
for split_name, dataset_obj in wiki_dataset.items():
    print(f"  {split_name}: {len(dataset_obj)} examples")

for split_name, dataset_obj in wiki_dataset.items():
    total_chars = sum(len(example['text']) for example in dataset_obj)
    print(f"  {split_name}: {total_chars} characters")

print('==================')

print('POETS DATASET:')
for split_name, dataset_obj in poets_dataset.items():
    print(f"  {split_name}: {len(dataset_obj)} examples")

for split_name, dataset_obj in poets_dataset.items():
    total_chars = sum(len(example['text']) for example in dataset_obj)
    print(f"  {split_name}: {total_chars} characters")

WIKI DATASET:
  train: 16214 examples
  validation: 1734 examples
  test: 1945 examples
  train: 10035232 characters
  validation: 1057589 characters
  test: 1180421 characters
POETS DATASET:
  train: 18290 examples
  validation: 2287 examples
  test: 2286 examples
  train: 5858316 characters
  validation: 724226 characters
  test: 715568 characters


### Tokenizer

In [26]:
# Load tokenizer

# from transformers import BertTokenizer, BertModel
# tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# local_bert_dir = r"C:\Users\Pracioppo\Desktop\Peter DynAttn Proj\data\local_bert_dir"
# tokenizer = BertTokenizer.from_pretrained(local_bert_dir)
# print("Tokenizer loaded!")

local_bert_mini_dir = r"C:\Users\Pracioppo\Desktop\Peter DynAttn Proj\data\prajjwal1_bert_mini"
tokenizer = AutoTokenizer.from_pretrained(local_bert_mini_dir)
print("Tokenizer loaded!")

# Ensure tokenizer has a pad_token if your model or DataCollator needs it
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    print("Added padding.")

Tokenizer loaded!


In [27]:
def tokenize_function(examples):
    # This function uses the 'tokenizer' object defined globally above
    return tokenizer(examples["text"]) # No truncation/max_length here for LM concatenation

In [30]:
# Apply Tokenization
print("\nTokenizing dataset...")
tokenized_wiki_dataset = wiki_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"], # This removes the original 'text' column after tokenization
)


Tokenizing dataset...


  0%|          | 0/17 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

In [33]:
# Define Grouping Function for Language Modeling
def group_texts(examples):
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    total_length = (total_length // block_size) * block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

In [36]:
# Define Block Size for Language Modeling
block_size = 128

# Apply Grouping
print(f"\nGrouping texts into blocks of size {block_size} for language modeling...")
lm_dataset = tokenized_wiki_dataset.map(
    group_texts,
    batched=True
)
lm_dataset.set_format("torch")
print("Texts grouped for language modeling.")


Grouping texts into blocks of size 128 for language modeling...


  0%|          | 0/17 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

Texts grouped for language modeling.


In [32]:
# Apply Tokenization
print("\nTokenizing dataset...")
tokenized_poets_dataset = poets_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"], # This removes the original 'text' column after tokenization
)

In [37]:
# Define Block Size for Language Modeling
block_size = 128

# Apply Grouping
print(f"\nGrouping texts into blocks of size {block_size} for language modeling...")
lm_dataset = tokenized_poets_dataset.map(
    group_texts,
    batched=True
)
lm_dataset.set_format("torch")
print("Texts grouped for language modeling.")


Grouping texts into blocks of size 128 for language modeling...


  0%|          | 0/19 [00:00<?, ?ba/s]

  0%|          | 0/3 [00:00<?, ?ba/s]

  0%|          | 0/3 [00:00<?, ?ba/s]

Texts grouped for language modeling.


In [None]:
# def custom_similarity(query_emb, vocab_embs, alpha = 1):
#     """
#     Calculates custom similarity
#     """

#     diff = query_emb.unsqueeze(0) - vocab_embs
#     p = 1/(alpha + diff**2)
    
#     scores = torch.mean(p,axis=1)
    
#     return scores

In [None]:
def embedding_to_nearest_token(vocab_embed, query_embed, similar_fn=1):
    """
    Given an embedding, find the nearest embedding in the vocabulary and output the corresponding token.
    """

    if similar_fn == 1:    
        diff = query_emb.unsqueeze(0) - vocab_embs
        p = 1/(alpha + diff**2)
        similarities = torch.mean(p,axis=1)
    else:
        similarities = F.cosine_similarity(query_embed.unsqueeze(0), vocab_embeddings, dim=1)

    # Find the top K similarities and their indices
    top_k_similarities, top_k_indices = torch.topk(similarities, k=1)

    # Get the token ID
    token_id = top_k_indices.item()
    # Convert the token ID back to a readable token string
    token_string = tokenizer.convert_ids_to_tokens(token_id)
    similarity_score = top_k_similarities.item()

    return token_string, similarity_score

# query_embed = initial_input_embeddings[0,0]
# vocab_embed = model.embeddings.word_embeddings.weight

# token_string, similarity_score = embedding_to_nearest_token(vocab_embed, query_embed)

# print(f"Token: '{token_string}', Similarity: {similarity_score:.4f}")

### Embedding Layer

In [None]:
class CustomEmbeddings(nn.Module):
    def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, pad_token_id=0, layer_norm_eps=1e-12, hidden_dropout_prob=0.1):
        super().__init__()
        # Word Embeddings: Maps token IDs to vectors
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id)
        
        # Positional Embeddings: Maps token positions to vectors
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        
        # Token Type Embeddings: Maps segment IDs (e.g., 0 for first sentence, 1 for second) to vectors
        self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size)

        # Layer Normalization: Normalizes the sum of embeddings
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        
        # Dropout: Regularization to prevent overfitting
        self.dropout = nn.Dropout(hidden_dropout_prob)

        # Register position_ids buffer (used if position_ids are not provided in forward)
        self.register_buffer("position_ids", torch.arange(max_position_embeddings).expand((1, -1)))

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        input_shape = input_ids.size()
        seq_length = input_shape[1]

        # 1. Get Word Embeddings
        inputs_embeds = self.word_embeddings(input_ids)

        # 2. Get Token Type Embeddings (if not provided, default to all 0s)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=input_ids.device)
        token_type_embeds = self.token_type_embeddings(token_type_ids)

        # 3. Get Positional Embeddings (if not provided, generate default sequence)
        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length] # Use the pre-registered buffer
        position_embeds = self.position_embeddings(position_ids)

        # 4. Sum them up
        embeddings = inputs_embeds + token_type_embeds + position_embeds
        
        # 5. Apply Layer Normalization and Dropout
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)

        return embeddings