In [None]:
!git clone https://github.com/LinyingLyu/ChronoGPT.git
%cd ChronoGPT

# Text generation

In [None]:
import torch
import torch.nn.functional as F
import tiktoken
# hf_hub_download is named explicitly so the downloads below do not rely on the
# star import from ChronoGPT_inference happening to re-export it.
from huggingface_hub import HfApi, hf_hub_download, login
from ChronoGPT_inference import *
import gc

# ----------------------------- Setup -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cache_dir = 'cache'  # Update this path as needed

tokenizer = tiktoken.get_encoding("gpt2")
max_length = 30            # total sequence length: prompt tokens + generated tokens
num_return_sequences = 5   # number of samples drawn in parallel from the same prompt
top_k = 30                 # sample only among the k most probable next tokens
seed = 11111               # seed for the sampling generator (reproducible draws)

# -------------------------- Load Model --------------------------
repo_id = "manelalab/chrono-gpt-v1-20241231"
config_path = hf_hub_download(repo_id=repo_id, filename="config.pt", cache_dir=cache_dir)
bin_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin", cache_dir=cache_dir)
# NOTE(review): torch.load unpickles the downloaded files; only point repo_id
# at a repository you trust.
config = torch.load(config_path, map_location='cpu')
print(f"Model config: {config}")
model = ChronoGPT(**config)
model = model.to(device)
model = model.half()

state_dict = torch.load(bin_path, map_location=device)
model.load_state_dict(state_dict)
# Inference only: switch off any train-time behaviour (e.g. dropout) the
# module may define.
model.eval()
del state_dict
# Collect first so memory released by the collector can also be handed back
# when the CUDA cache is emptied.
gc.collect()
torch.cuda.empty_cache()

# ------------------------ Prepare Input -------------------------
prompt = "Hello, I am a language model,"
tokens = tokenizer.encode(prompt)
tokens = torch.tensor(tokens, dtype=torch.long).unsqueeze(0)
# One identical copy of the prompt per requested sample.
tokens = tokens.repeat(num_return_sequences, 1).to(device)

# -------------------- Sampling Initialization -------------------
xgen = tokens.clone()
sample_rng = torch.Generator(device=device)
sample_rng.manual_seed(seed)

# ------------------------- Text Generation -----------------------
# Gradients are never needed here, so tracking is disabled once for the whole loop.
with torch.no_grad():
    while xgen.size(1) < max_length:
        logits, _ = model(xgen)

        logits = logits[:, -1, :]  # Last token logits
        probs = F.softmax(logits, dim=-1)
        topk_probs, topk_indices = torch.topk(probs, top_k, dim=-1)

        # torch.multinomial accepts unnormalised row weights, so the top-k
        # probabilities do not need to be renormalised before sampling.
        sampled_idx = torch.multinomial(topk_probs, 1, generator=sample_rng)
        # Map the position inside the top-k list back to a vocabulary id.
        next_token = torch.gather(topk_indices, -1, sampled_idx)

        xgen = torch.cat([xgen, next_token], dim=1)


# ------------------------- Decode Output -------------------------
for i in range(num_return_sequences):
    decoded_tokens = xgen[i, :max_length].tolist()
    decoded_text = tokenizer.decode(decoded_tokens)
    print(f"Rank sample {i}:\n{decoded_text}\n")

# Embeddings extraction

In [None]:
import torch
import torch.nn.functional as F
import tiktoken
# hf_hub_download is named explicitly so the downloads below do not rely on the
# star import from ChronoGPT_inference happening to re-export it.
from huggingface_hub import HfApi, hf_hub_download, login
from ChronoGPT_inference import *
import gc  # used by gc.collect() below; previously only imported in the text-generation cell

# ----------------------------- Setup -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cache_dir = 'cache'  # Update this path as needed

tokenizer = tiktoken.get_encoding("gpt2")
# Maximum number of input tokens that are embedded; longer texts are truncated.
# NOTE(review): 30 mirrors the value this cell used to inherit from the
# text-generation cell, so results are unchanged when cells run in order.
# Raise it if longer inputs should be embedded.
max_length = 30

# -------------------------- Load Model --------------------------
repo_id = "manelalab/chrono-gpt-v1-20241231"
config_path = hf_hub_download(repo_id=repo_id, filename="config.pt", cache_dir=cache_dir)
bin_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin", cache_dir=cache_dir)
# NOTE(review): torch.load unpickles the downloaded files; only point repo_id
# at a repository you trust.
config = torch.load(config_path, map_location='cpu')
print(f"Model config: {config}")
model = ChronoGPT(**config)
model = model.to(device)
model = model.half()

state_dict = torch.load(bin_path, map_location=device)
model.load_state_dict(state_dict)
# Inference only: switch off any train-time behaviour (e.g. dropout) the
# module may define.
model.eval()
del state_dict
# Collect first so memory released by the collector can also be handed back
# when the CUDA cache is emptied.
gc.collect()
torch.cuda.empty_cache()

# ----------------------- Embedding Generation ---------------------
text = "Obviously, the time continuum has been disrupted, creating a new temporal event sequence resulting in this alternate reality."

# Encode, truncate to max_length tokens and add a leading batch dimension of 1.
inputs = torch.tensor(tokenizer.encode(text))[:max_length].reshape(1,-1).to(device)
# No gradients are needed to read out embeddings, so skip autograd bookkeeping.
with torch.no_grad():
    logits, emb = model(inputs)
# NOTE(review): emb is indexed like a sequence here; what each entry
# represents is defined by ChronoGPT_inference — confirm there.
print('Dimension of embeddings:', emb[0].shape)