In [1]:
import os
import youtokentome as yttm
from transformers import AutoModelForCausalLM
import torch

model = AutoModelForCausalLM.from_pretrained("JetBrains_model", gguf_file="flcc.model", torch_dtype=torch.float32)

Converting and de-quantizing GGUF tensors...:   0%|          | 0/57 [00:00<?, ?it/s]

In [28]:
prefix_code = """
public class HelloWorld {
    public static void main(String[] args) {
        System.out.
"""

bpe = yttm.BPE(model="JetBrains_model/flcc.bpe")

tokens = bpe.encode(prefix_code, output_type=yttm.OutputType.ID)

# Print tokenization to verify
print(f"Token IDs: {tokens}")
print(f"Decoded back: {bpe.decode([tokens])[0]}")

# Convert to tensor
input_ids = torch.tensor([tokens]).to(model.device)
attention_mask = torch.ones_like(input_ids).to(model.device)

# Generate completions
with torch.no_grad():
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=20,  # Controls how many new tokens to generate
        do_sample=True,     # Use sampling
        top_p=0.95,         # Nucleus sampling parameter
        top_k=40,           # Limit to top 50 tokens at each step
        temperature=0.6,    # Controls randomness (lower is more deterministic)
        num_return_sequences=1,
        pad_token_id=model.config.eos_token_id,  # Set padding token to EOS token
        repetition_penalty=1.2
    )

# Print the raw output to understand what's happening
print("Raw output tokens:", outputs[0].tolist())

# Get the generated content after the prompt
original_length = len(tokens)
generated_tokens = outputs[0][original_length:].tolist()

# Decode only the newly generated tokens
completion = bpe.decode([generated_tokens])[0]

# Format the output
print("Input prompt:")
print(prefix_code.strip())
print("\nGenerated completion:")
print(f"        System.out.{completion}")

Token IDs: [4, 8251, 5588, 4529, 142, 4, 845, 1907, 6467, 95, 12885, 315, 9499, 34, 6122, 699, 2772, 4, 1887, 6289, 14605, 4]
Decoded back: 
public class HelloWorld {
    public static void main(String[] args) {
        System.out.

Raw output tokens: [4, 8251, 5588, 4529, 142, 4, 845, 1907, 6467, 95, 12885, 315, 9499, 34, 6122, 699, 2772, 4, 1887, 6289, 14605, 4, 8380, 2112, 5588, 136, 4529, 5498, 60, 4, 36, 421, 545, 210, 4, 434, 1380, 132, 12109, 1921, 918, 143]
Input prompt:
public class HelloWorld {
    public static void main(String[] args) {
        System.out.

Generated completion:
        System.out.            println("Hello, World!");
}   }")
private val root: Element = VectorElement()
