In [25]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig

In [26]:
model = 'deepseek-ai/deepseek-coder-1.3b-instruct'

In [27]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)

# Quantization Config
quant_config = BitsAndBytesConfig(
   load_in_4bit=True,
   bnb_4bit_quant_type="nf4",
   bnb_4bit_use_double_quant=False,
   bnb_4bit_compute_dtype=torch.bfloat16
)

# Load Model
model = AutoModelForCausalLM.from_pretrained(
    model,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype="auto",
    quantization_config=quant_config
)

In [28]:
from string import Template
from datasets import load_dataset

dataset = load_dataset('google/code_x_glue_cc_clone_detection_big_clone_bench')
dataset_train = dataset['train']
dataset_test = dataset['test']
dataset_val = dataset['validation']

instruction_template = Template(
"""
# Instruction
Do code 1 and code 2 solve identical problems with the same
inputs and outputs? answer with True or False and no explanation.
Example:
code 1: private void setNodekeyInJsonResponse(String service) throws Exception {\n        String filename = this.baseDirectory + service + ".json";\n        Scanner s = new Scanner(new File(filename));\n        PrintWriter fw = new PrintWriter(new File(filename + ".new"));\n        while (s.hasNextLine()) {\n            fw.println(s.nextLine().replaceAll("NODEKEY", this.key));\n        }\n        s.close();\n        fw.close();\n        (new File(filename + ".new")).renameTo(new File(filename));\n    }\n'
code 2: public void transform(String style, String spec, OutputStream out) throws IOException {\n        URL url = new URL(rootURL, spec);\n        InputStream in = new PatchXMLSymbolsStream(new StripDoctypeStream(url.openStream()));\n        transform(style, in, out);\n        in.close();\n    }\n'
Answer:
False

Input
$question
Answer:
"""
    
)




In [29]:
dataset_test = dataset_test.select([1, 2, 3])
dataset_test

Dataset({
    features: ['id', 'id1', 'id2', 'func1', 'func2', 'label'],
    num_rows: 3
})

In [30]:
def contact_input(dataset):
    dataset = dataset.map(
        lambda example: {'input': 'code1:' + str(example['func1']) + '\n' + 'code2:' + str(example['func2'])},
        remove_columns=['func1', 'func2']
    )
    dataset = dataset.map(
        lambda example: {'prompt_input':instruction_template.substitute({"question" : example['input']})},
        remove_columns=['input']
    )

    # If 'label' column is of boolean type
    dataset = dataset.map(lambda example: {'label_output': 'True' if example['label'] else 'False'},
                          remove_columns=['label'])

    return dataset

# dataset_train = contact_input(dataset_train)
dataset_test = contact_input(dataset_test)
# dataset_val = contact_input(dataset_val)


In [31]:
xx = dataset_test[0]
xx['prompt_input']

'\n# Instruction\nDo code 1 and code 2 solve identical problems with the same\ninputs and outputs? answer with True or False and no explanation.\nExample:\ncode 1: private void setNodekeyInJsonResponse(String service) throws Exception {\n        String filename = this.baseDirectory + service + ".json";\n        Scanner s = new Scanner(new File(filename));\n        PrintWriter fw = new PrintWriter(new File(filename + ".new"));\n        while (s.hasNextLine()) {\n            fw.println(s.nextLine().replaceAll("NODEKEY", this.key));\n        }\n        s.close();\n        fw.close();\n        (new File(filename + ".new")).renameTo(new File(filename));\n    }\n\'\ncode 2: public void transform(String style, String spec, OutputStream out) throws IOException {\n        URL url = new URL(rootURL, spec);\n        InputStream in = new PatchXMLSymbolsStream(new StripDoctypeStream(url.openStream()));\n        transform(style, in, out);\n        in.close();\n    }\n\'\nAnswer:\nFalse\n\nInput\ncod

In [32]:
from dataclasses import dataclass
from transformers import AutoTokenizer, BatchEncoding

@dataclass
class SimpleCollator:
    tokenizer: AutoTokenizer
    config: dict 
    
    def __call__(self, examples: list) -> dict:
        batch = BatchEncoding(
            {
                k: [examples[i][k] for i in range(len(examples))]
                for k, v in examples[0].items()
            }
        )

        encoded_inputs = self.tokenizer(
            batch[self.config["input_column"]], 
            max_length = 240, 
            padding=True, 
            truncation=True,
            return_tensors="pt"
        )

        encoded_targets = self.tokenizer(
            batch[self.config["output_column"]], max_length = 10, padding=True, truncation=True,
            return_tensors="pt"
        )
        encoded_inputs["labels"] = encoded_targets["input_ids"]

        return encoded_inputs

collator = SimpleCollator(tokenizer, {"input_column": "prompt_input", "output_column": "label_output"})

In [33]:
# from peft import get_peft_config, get_peft_model, LoraConfig, TaskType
# 
# peft_config = LoraConfig(
#     task_type=TaskType.SEQ_2_SEQ_LM, 
#     inference_mode=False, 
#     target_modules=["q", "k", "v"],
#     r=8, 
#     lora_alpha=32, 
#     lora_dropout=0.5
# )
# 
# model = get_peft_model(model, peft_config)

In [34]:
from torch.utils.data import DataLoader

# Prepare Dataloaders
train_dl = DataLoader(
    dataset_train,
    batch_size=4,
    pin_memory=True,
    shuffle=False,
    collate_fn=collator
)

val_dl = DataLoader(
    dataset_val,
    batch_size=16,
    pin_memory=True,
    shuffle=True,
    collate_fn=collator
)
test_dl = DataLoader(
    dataset_test,
    batch_size=16,
    pin_memory=True,
    shuffle=False,
    collate_fn=collator
)


In [35]:
import tqdm.notebook as tqdm



all_preds = []
for batch in tqdm.tqdm(test_dl, total = len(test_dl)):
    
    preds = model.generate(**batch, max_new_tokens=10)
    outputs = tokenizer.batch_decode(preds, skip_special_tokens=True)
    all_preds.extend(outputs)

all_preds

  0%|          | 0/1 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:32021 for open-end generation.


['\n# Instruction\nDo code 1 and code 2 solve identical problems with the same\ninputs and outputs? answer with True or False and no explanation.\nExample:\ncode 1: private void setNodekeyInJsonResponse(String service) throws Exception {\n        String filename = this.baseDirectory + service + ".json";\n        Scanner s = new Scanner(new File(filename));\n        PrintWriter fw = new PrintWriter(new File(filename + ".new"));\n        while (s.hasNextLine()) {\n            fw.println(s.nextLine().replaceAll("NODEKEY", this.key));\n        }\n        s.close();\n        fw.close();\n        (new File(filename + ".new")).renameTo(new File(filename));\n    }\n\'\ncode 2: public void transform(String style, String spec, OutputStream out) throws IOException {\n        URL url = new URL(rootURL, spec);\n        InputStream in = new PatchXMLSymbolsStream(new StripDoctypeStream(url.openStream()));\n        out.write(',
 '\n# Instruction\nDo code 1 and code 2 solve identical problems with the 

In [36]:
print(all_preds)

['\n# Instruction\nDo code 1 and code 2 solve identical problems with the same\ninputs and outputs? answer with True or False and no explanation.\nExample:\ncode 1: private void setNodekeyInJsonResponse(String service) throws Exception {\n        String filename = this.baseDirectory + service + ".json";\n        Scanner s = new Scanner(new File(filename));\n        PrintWriter fw = new PrintWriter(new File(filename + ".new"));\n        while (s.hasNextLine()) {\n            fw.println(s.nextLine().replaceAll("NODEKEY", this.key));\n        }\n        s.close();\n        fw.close();\n        (new File(filename + ".new")).renameTo(new File(filename));\n    }\n\'\ncode 2: public void transform(String style, String spec, OutputStream out) throws IOException {\n        URL url = new URL(rootURL, spec);\n        InputStream in = new PatchXMLSymbolsStream(new StripDoctypeStream(url.openStream()));\n        out.write(', '\n# Instruction\nDo code 1 and code 2 solve identical problems with the s