In [30]:
import io

import torch
import pandas as pd
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from og.dictionary import AutoEncoder
from datasets import load_dataset
import random
import safetensors
import torch.nn as nn



In [2]:
# Load the Pythia model and its tokenizer from ONE checkpoint id, so the two
# can never silently refer to different checkpoints.
PYTHIA_CHECKPOINT = "EleutherAI/pythia-410m"
model = AutoModelForCausalLM.from_pretrained(PYTHIA_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(PYTHIA_CHECKPOINT)

# Give the tokenizer a pad token (reusing EOS) so padded tokenization works.
# Presumably the Pythia/GPT-NeoX tokenizer ships without one — confirm.
tokenizer.pad_token = tokenizer.eos_token

The `GPTNeoXSdpaAttention` class is deprecated in favor of simply modifying the `config._attn_implementation`attribute of the `GPTNeoXAttention` class! It will be removed in v4.48


In [19]:
# Collect the first N_DOCUMENTS documents from The Pile (uncopyrighted subset).
#
# NOTE: the `random_sentences` name is kept because later cells use it, but the
# contents are whole *documents*, not sentences. They are also the FIRST
# N_DOCUMENTS records of the stream, not a random sample.
# TODO: for a random sample, call `dataset.shuffle(seed=..., buffer_size=...)`
# before iterating. That would change which document later cells see at a
# given index (e.g. index 125).
N_DOCUMENTS = 1000

# Streaming mode iterates over records without downloading the whole dataset.
dataset = load_dataset("monology/pile-uncopyrighted", split="train", streaming=True)

random_sentences = []
for example in dataset:
    random_sentences.append(example["text"])
    if len(random_sentences) >= N_DOCUMENTS:
        break

# Pile text contains non-ASCII characters (e.g. the en dash in document 125),
# so write UTF-8 explicitly rather than relying on the platform's default
# encoding.
# NOTE: documents contain embedded newlines, so this file is NOT one document
# per line. Splitting it by line will not recover the original documents.
with open("random_sentences.txt", "w", encoding="utf-8") as out_file:
    for text in random_sentences:
        out_file.write(text + "\n")


In [24]:
# Show the full text of the example document that the generation cell reuses.
example_text = random_sentences[125]
print(example_text)

Alexander Bell Donald

Alexander Bell Donald (18 August 1842–7 March 1922) was a New Zealand seaman, sailmaker, merchant and ship owner. He was born in Inverkeithing, Fife, Scotland on 18 August 1842.

References

Category:1842 births
Category:1922 deaths
Category:Scottish emigrants to New Zealand
Category:People from Inverkeithing


In [23]:
# Baseline: generate a continuation with the unmodified model, before any
# autoencoder is involved.
test_input = random_sentences[125]

# A single prompt needs no padding. padding="max_length" can append pad (EOS)
# tokens after the prompt, and a decoder-only model would then continue from
# the padding instead of from the real text. Truncation stays disabled so the
# whole document is used as the prompt.
inputs = tokenizer(test_input, return_tensors="pt", truncation=False)

with torch.no_grad():
    output = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        # Make the generation length explicit instead of relying on the
        # generation-config default (which presumably yielded ~20 new tokens
        # in the saved output).
        max_new_tokens=20,
        # generate() previously fell back to eos_token_id (0) with a warning.
        # Pass it explicitly instead.
        pad_token_id=tokenizer.eos_token_id,
    )

# Decode the prompt plus continuation, dropping special tokens such as EOS.
decoded_output = tokenizer.decode(output[0].tolist(), skip_special_tokens=True)
print(decoded_output)


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


Alexander Bell Donald

Alexander Bell Donald (18 August 1842–7 March 1922) was a New Zealand seaman, sailmaker, merchant and ship owner. He was born in Inverkeithing, Fife, Scotland on 18 August 1842.

References

Category:1842 births
Category:1922 deaths
Category:Scottish emigrants to New Zealand
Category:People from Inverkeithing
Category:People from Fife
Category:New Zealand merchants
Category:New Zealand ship owners


Setting up the Autoencoder

In [31]:
# Autoencoder setup: location of the SAE weights on disk. The path is
# relative to the notebook's working directory.
safetensor_file = "dictionaries/sae.safetensors"

In [27]:
# Read every tensor in the checkpoint into a plain dict, materialised on CPU.
with safetensors.safe_open(safetensor_file, framework="pt", device="cpu") as handle:
    state_dict = {name: handle.get_tensor(name) for name in handle.keys()}

# List each stored tensor's name and shape. The tensors are already loaded,
# so the file handle does not need to stay open for this.
for name, tensor in state_dict.items():
    print(f'{name}: {tensor.shape}')

W_dec: torch.Size([65536, 1024])
b_dec: torch.Size([1024])
encoder.bias: torch.Size([65536])
encoder.weight: torch.Size([65536, 1024])


In [32]:
# SAE hyperparameters. `d_in` must match the width of the checkpoint tensors
# printed above: encoder.weight is [65536, 1024], and 64 * 1024 == 65536 ==
# num_latents. The previous value, 410, appears to have come from the model
# name ("410m") rather than from its hidden size.
config = {'d_in': 1024, 'expansion_factor': 64, 'normalize_decoder': True, 'num_latents': 65536, 'k': 32,
          'signed': False}


In [33]:
class Autoencoder(nn.Module):
    """Top-k sparse autoencoder whose parameters match the checkpoint layout.

    The checkpoint stores four tensors:

    - ``encoder.weight`` ``[num_latents, d_in]``
    - ``encoder.bias`` ``[num_latents]``
    - ``W_dec`` ``[num_latents, d_in]``
    - ``b_dec`` ``[d_in]``

    This module registers exactly those parameters, so ``load_state_dict``
    succeeds in strict mode.

    The encode/decode math follows the EleutherAI ``sae``/``sparsify``
    convention that these names suggest:

    1. subtract ``b_dec``, apply the linear encoder, then ReLU;
    2. keep only the ``k`` largest latents;
    3. decode with ``W_dec`` and add ``b_dec`` back.

    NOTE(review): this convention is inferred from the key names and shapes.
    Confirm it against the code that trained the checkpoint.

    Args:
        config: dict with keys ``d_in``, ``expansion_factor``,
            ``normalize_decoder``, ``k``, ``signed`` and optionally
            ``num_latents``. If ``num_latents`` is missing or 0, it falls back
            to ``d_in * expansion_factor``.

    Raises:
        NotImplementedError: if ``config['signed']`` is true.
        ValueError: if ``k`` is not in ``[1, num_latents]``.
    """

    def __init__(self, config):
        super().__init__()
        self.expansion_factor = config['expansion_factor']
        self.normalize_decoder = config['normalize_decoder']
        self.k = config['k']
        self.signed = config['signed']
        self.d_in = config['d_in']
        self.num_latents = config.get('num_latents') or self.d_in * self.expansion_factor

        # Validate before allocating the (potentially very large) weights.
        if self.signed:
            raise NotImplementedError("signed latents are not supported by this Autoencoder")
        if not 1 <= self.k <= self.num_latents:
            raise ValueError(f"k must be in [1, {self.num_latents}], got {self.k}")

        # Encoder d_in -> num_latents. Kept as nn.Linear so its parameters are
        # named `encoder.weight` / `encoder.bias`, exactly as in the checkpoint.
        self.encoder = nn.Linear(self.d_in, self.num_latents)
        with torch.no_grad():
            self.encoder.bias.zero_()
        # Decoder matrix: one d_in-dimensional row per latent. It starts as a
        # copy of the encoder weights; load_state_dict overwrites it.
        self.W_dec = nn.Parameter(self.encoder.weight.detach().clone())
        # Decoder bias: subtracted before encoding, added after decoding.
        self.b_dec = nn.Parameter(torch.zeros(self.d_in))

        if self.normalize_decoder:
            self.set_decoder_norm_to_unit_norm()

    @torch.no_grad()
    def set_decoder_norm_to_unit_norm(self):
        """Rescale each row of ``W_dec`` (one per latent) to unit L2 norm, in place."""
        eps = torch.finfo(self.W_dec.dtype).eps
        row_norms = self.W_dec.norm(dim=1, keepdim=True)
        # clamp_min guards against dividing an all-zero row by zero.
        self.W_dec.div_(row_norms.clamp_min(eps))

    def encode(self, x):
        """Return ``(top_acts, top_indices)`` for the ``k`` largest post-ReLU latents.

        ``x`` has last dimension ``d_in`` and must share the parameters'
        dtype/device. Both outputs have shape ``x.shape[:-1] + (k,)``. The
        order of entries within the top-k is unspecified (``sorted=False``).
        """
        pre_acts = torch.relu(self.encoder(x - self.b_dec))
        top_acts, top_indices = pre_acts.topk(self.k, dim=-1, sorted=False)
        return top_acts, top_indices

    def decode(self, top_acts, top_indices):
        """Map sparse ``(top_acts, top_indices)`` back to a ``d_in``-wide reconstruction."""
        # Scatter the k active values into a dense latent vector in which
        # every other latent is zero, then project with the decoder matrix.
        latents = top_acts.new_zeros(top_acts.shape[:-1] + (self.num_latents,))
        latents.scatter_(-1, top_indices, top_acts)
        return latents @ self.W_dec + self.b_dec

    def forward(self, x):
        """Reconstruct ``x`` through the top-k bottleneck. The output has ``x``'s shape."""
        top_acts, top_indices = self.encode(x)
        return self.decode(top_acts, top_indices)

In [None]:

# Build the module from `config`, then copy the checkpoint tensors into it.
# strict=True (the default) makes any missing/unexpected key raise immediately.
ae = Autoencoder(config)
ae.load_state_dict(state_dict, strict=True)

In [None]:
# TODO: save this state dict, then delete the .safetensors file; after that we're ready to go.