In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

In [2]:
# Check if MPS is available
device = "mps" if torch.backends.mps.is_available() else "cpu"

# Check if CUDA is available
# device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/deepseek-coder-1.3b-instruct", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/deepseek-coder-1.3b-instruct",
    trust_remote_code=True,
    torch_dtype=torch.float16  # MPS prefers float16 instead of bfloat16
).to(device)

# If you are using CUDA, you can use bfloat16 instead of float16
# model = AutoModelForCausalLM.from_pretrained(
#     "deepseek-ai/deepseek-coder-1.3b-instruct",
#     trust_remote_code=True,
#     torch_dtype=torch.bfloat16
# ).to(device)

print("Running on:", device)

Running on: mps


In [4]:
from sentence_transformers import SentenceTransformer, util

In [5]:
# Load a sentence similarity model
similarity_model = SentenceTransformer('sentence-transformers/gtr-t5-large')

In [6]:
def generate_code_description(messages):
    inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
    # tokenizer.eos_token_id is the id of <|EOT|> token
    outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, top_k=50, top_p=0.95, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
    description = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
    return description

In [7]:
def compare_descriptions(generated_desc, user_desc):
    embeddings = similarity_model.encode([generated_desc, user_desc], convert_to_tensor=True)
    similarity_score = util.pytorch_cos_sim(embeddings[0], embeddings[1]).item()
    return similarity_score

In [8]:
def initial_code_verification(messages, user_description, threshold=0.75):
    generated_description = generate_code_description(messages)
    similarity_score = compare_descriptions(generated_description, user_description)

    result = {
        "generated_description": generated_description,
        "similarity_score": similarity_score,
        "matches_expectation": similarity_score >= threshold
    }

    return result

In [9]:
prefix = "Write a short summary for the following code: "

code_snippet = """
def factorial(n):
    if n == 0:
        return 1
    else:
        return n * factorial(n-1)
"""

In [10]:
messages=[
    { 'role': 'user', 'content': prefix+code_snippet}
]

In [11]:
user_description = "This function calculates the factorial of a given number recursively."

In [12]:
result = initial_code_verification(messages, user_description)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:32021 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


In [13]:
print("User Description:", user_description)

User Description: This function calculates the factorial of a given number recursively.


In [14]:
print("Generated Description:", result["generated_description"])

Generated Description: This Python function calculates the factorial of a number using recursion. The factorial of a number is the product of all positive integers less than or equal to that number. For example, the factorial of 5 is 5*4*3*2*1 = 120.

The function takes an integer `n` as input and returns the factorial of `n`. If `n` is 0, the function returns 1 (since the factorial of 0 is defined to be 1). Otherwise, the function calls itself with the argument `n-1` and multiplies the result by `n`. This continues until `n` is 0, at which point the function returns 1.



In [15]:
print("Similarity Score:", result["similarity_score"])

Similarity Score: 0.8555039167404175


In [16]:
print("Matches Expectation:", result["matches_expectation"])

Matches Expectation: True
