In [32]:
import pickle 
with open(r'src/Dataloaders/train_loader.pkl', 'rb') as f:
    train_loader = pickle.load(f)

In [33]:
for val in train_loader:
    data = val['input_ids'] 
    break

In [34]:
import os

In [35]:
os.getcwd()

'd:\\DecoderKAN'

In [37]:
# Import necessary libraries
import torch
from src.model import build_transformer  # Ensure this matches your model.py
from transformers import PreTrainedTokenizerFast  # For loading the tokenizer

# Set device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
VOCAB_SIZE = 512
MAX_SEQ_LEN = 44

# Load the tokenizer
tokenizer = PreTrainedTokenizerFast.from_pretrained("src/tokenizer/QED_TOKENIZER")
print(f"Tokenizer loaded with vocab size: {tokenizer.vocab_size}")

# Ensure padding and truncation settings
tokenizer.pad_token = tokenizer.eos_token  # Often needed for consistency
print(f"Pad token: {tokenizer.pad_token}, EOS token: {tokenizer.eos_token}")

# Function to create causal mask (same as training)
def create_causal_mask(size, device):
    mask = torch.triu(torch.ones(size, size), diagonal=1).bool().to(device)
    return ~mask

# Load the saved model
model = build_transformer(vocab_size=VOCAB_SIZE, d_model=512, num_heads=8)
checkpoint = torch.load("transformer_qed_sequence_full.pth", map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(DEVICE)
model.eval()
print(f"Model loaded on {DEVICE} from transformer_qed_sequence_full.pth")

# Example input: Tokenize a physics interaction
# Replace this string with your actual input (e.g., from your CSV)
input_text = "AntiPart e_[ID](X) e_[ID](X)^(*) to s_eps_18941(X) AntiPart s_eta_22311(X)^(*)"
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
encoded = tokenizer(input_text, return_tensors="pt", padding="max_length", max_length=MAX_SEQ_LEN, truncation=True)
src = encoded['input_ids'].to(DEVICE)  # [1, 44]
src_mask = encoded['attention_mask'].unsqueeze(1).unsqueeze(2).expand(1, 1, MAX_SEQ_LEN, MAX_SEQ_LEN).to(DEVICE)

# Start token for target sequence
tgt_start = torch.tensor([[tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1]], dtype=torch.long).to(DEVICE)  # [1, 1]

# Inference: Generate sequence autoregressively
predicted_sequence = tgt_start
with torch.no_grad():
    for _ in range(MAX_SEQ_LEN - 1):  # Generate up to MAX_SEQ_LEN - 1 tokens
        tgt_seq_len = predicted_sequence.size(1)
        causal_mask = create_causal_mask(tgt_seq_len, DEVICE)
        tgt_padding_mask = (predicted_sequence != tokenizer.pad_token_id).unsqueeze(1).unsqueeze(2)
        tgt_mask = causal_mask.unsqueeze(0).unsqueeze(0).expand(1, 1, tgt_seq_len, tgt_seq_len)
        tgt_mask = tgt_mask & tgt_padding_mask.expand(1, 1, tgt_seq_len, tgt_seq_len)
        cross_mask = (src != tokenizer.pad_token_id).unsqueeze(1).unsqueeze(2).expand(1, 1, tgt_seq_len, MAX_SEQ_LEN)

        output = model(src, predicted_sequence, src_mask, tgt_mask, cross_mask)  # [1, tgt_seq_len, VOCAB_SIZE]
        next_token_logits = output[:, -1, :]  # [1, VOCAB_SIZE]
        next_token = next_token_logits.argmax(dim=-1, keepdim=True)  # [1, 1]
        predicted_sequence = torch.cat([predicted_sequence, next_token], dim=1)  # [1, tgt_seq_len + 1]

        # Stop if end token is predicted
        if next_token.item() == tokenizer.eos_token_id:
            break

# Decode the predicted sequence
predicted_tokens = predicted_sequence.tolist()[0]  # List of token IDs, e.g., [1, 108, 11, ...]
decoded_output = tokenizer.decode(predicted_tokens)
print(f"Input text: {input_text}")
print(f"Predicted token IDs: {predicted_tokens}")
print(f"Decoded output: {decoded_output}")

Tokenizer loaded with vocab size: 512
Pad token: None, EOS token: None
Model loaded on cpu from transformer_qed_sequence_full.pth
Input text: AntiPart e_[ID](X) e_[ID](X)^(*) to s_eps_18941(X) AntiPart s_eta_22311(X)^(*)
Predicted token IDs: [1, 5, 65, 33, 18, 103, 71, 49, 33, 14, 89, 54, 239, 46, 33, 38, 32, 15, 61, 135, 33, 90, 61, 140, 12, 15, 8, 94, 33, 95, 7, 110, 140, 12, 222, 8, 38, 32, 17, 104, 109, 8, 2, 8]
Decoded output: [SEP] % gam _ 5 33 }( p _ 1 )_ v ^(*)/( m _ e ^ 2 Ġ+ Ġs _ 12 Ġ+ Ġ1 / 2 * reg _ prop ) Ġ: Ġ1 / 36 * e ^ 4 *( 16 * [PAD] *


In [None]:
data[0].unsqueeze(0)