In [1]:
import torch
from peft import PeftModel    
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/CodeLlama-7b-Instruct-hf"
# Optional PEFT adapter to merge into the base model
# (e.g. "codellama-instruct-repair"). Leave as None to run the plain model.
adapters_name = None

# Load the base model in bfloat16; device_map="auto" lets accelerate place the
# weights on whatever device(s) are available.
m = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
if adapters_name is not None:
    # Attach the adapter, then fold its weights into the base model so that
    # generation runs on an ordinary (non-PEFT) model object.
    m = PeftModel.from_pretrained(m, adapters_name)
    m = m.merge_and_unload()

tok = AutoTokenizer.from_pretrained(model_name)
# Reuse EOS as the pad token so batched tokenizer calls can pad
# (the base tokenizer presumably defines no pad token of its own).
tok.pad_token = tok.eos_token
# Decoder-only models must be left-padded for batched generation, so that every
# prompt in a batch ends exactly where the generated tokens begin.
tok.padding_side = "left"

print(f"Successfully loaded the model {model_name} into memory")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Successfully loaded the model meta-llama/CodeLlama-7b-Instruct-hf into memory


In [2]:
# Program-repair instruction: the buggy Base64 input-stream `read` method (CODEC-101), the failing JUnit test, and its assertion error, followed by the request to return only the fixed function.
prompt = "You are an automatic program repair tool. Your task is to fix the provided buggy code.\n\nThe following code contains a buggy function:\n```java\n    /**\n     * Attempts to read <code>len</code> bytes into the specified <code>b</code> array starting at <code>offset</code>\n     * from this InputStream.\n     * \n     * @param b\n     *            destination byte array\n     * @param offset\n     *            where to start writing the bytes\n     * @param len\n     *            maximum number of bytes to read\n     * \n     * @return number of bytes read\n     * @throws IOException\n     *             if an I/O error occurs.\n     * @throws NullPointerException\n     *             if the byte array parameter is null\n     * @throws IndexOutOfBoundsException\n     *             if offset, len or buffer size are invalid\n     */\n    public int read(byte b[], int offset, int len) throws IOException {\n        if (b == null) {\n            throw new NullPointerException();\n        } else if (offset < 0 || len < 0) {\n            throw new IndexOutOfBoundsException();\n        } else if (offset > b.length || offset + len > b.length) {\n            throw new IndexOutOfBoundsException();\n        } else if (len == 0) {\n            return 0;\n        } else {\n            /*\n             Rationale for while-loop on (readLen == 0):\n             -----\n             Base64.readResults() usually returns > 0 or EOF (-1).  In the\n             rare case where it returns 0, we just keep trying.\n\n             This is essentially an undocumented contract for InputStream\n             implementors that want their code to work properly with\n             java.io.InputStreamReader, since the latter hates it when\n             InputStream.read(byte[]) returns a zero.  
Unfortunately our\n             readResults() call must return 0 if a large amount of the data\n             being decoded was non-base64, so this while-loop enables proper\n             interop with InputStreamReader for that scenario.\n             -----\n             This is a fix for CODEC-101\n            */\n                if (!base64.hasData()) {\n                    byte[] buf = new byte[doEncode ? 4096 : 8192];\n                    int c = in.read(buf);\n                    // A little optimization to avoid System.arraycopy()\n                    // when possible.\n                    if (c > 0 && b.length == len) {\n                        base64.setInitialBuffer(b, offset, len);\n                    }\n                    if (doEncode) {\n                        base64.encode(buf, 0, c);\n                    } else {\n                        base64.decode(buf, 0, c);\n                    }\n                }\n            return base64.readResults(b, offset, len);\n        }\n    }\n\n```\n\nThe code fails the following test:\n```java\n    /**\n     * Test for the CODEC-101 bug:  InputStream.read(byte[]) should never return 0\n     * because Java's builtin InputStreamReader hates that.\n     *\n     * @throws Exception for some failure scenarios.\n     */\n    public void testCodec101() throws Exception {\n        byte[] codec101 = StringUtils.getBytesUtf8(Base64TestData.CODEC_101_MULTIPLE_OF_3);\n        ByteArrayInputStream bais = new ByteArrayInputStream(codec101);\n        Base64InputStream in = new Base64InputStream(bais);\n        byte[] result = new byte[8192];\n        int c = in.read(result);\n        assertTrue(\"Codec101: First read successful [c=\" + c + \"]\", c > 0);\n\n        c = in.read(result);\n        assertTrue(\"Codec101: Second read should report end-of-stream [c=\" + c + \"]\", c < 0);\n    }\n\n```\n\nWith the following test error:\n```\njunit.framework.AssertionFailedError: Codec101: First read successful [c=0]\n```\n\nPlease 
provide a fixed version of the buggy function, and only that function:\n"

def format_prompt(prompt):
    return f"<s>[INST] {prompt} [\INST]"

In [7]:
m.eval()

with torch.no_grad():
    inputs = tok([format_prompt(prompt), format_prompt(prompt)], return_tensors="pt")
    outputs = m.generate(**inputs, max_length=16384)
    print(tok.batch_decode(outputs, skip_special_tokens=True)[0].split("[\INST]")[1])
    print(tok.batch_decode(outputs, skip_special_tokens=True)[1].split("[\INST]")[1])

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


 I have fixed the buggy function.
```java
    /**
     * Attempts to read <code>len</code> bytes into the specified <code>b</code> array starting at <code>offset</code>
     * from this InputStream.
     * 
     * @param b
     *            destination byte array
     * @param offset
     *            where to start writing the bytes
     * @param len
     *            maximum number of bytes to read
     * 
     * @return number of bytes read
     * @throws IOException
     *             if an I/O error occurs.
     * @throws NullPointerException
     *             if the byte array parameter is null
     * @throws IndexOutOfBoundsException
     *             if offset, len or buffer size are invalid
     */
    public int read(byte b[], int offset, int len) throws IOException {
        if (b == null) {
            throw new NullPointerException();
        } else if (offset < 0 || len < 0) {
            throw new IndexOutOfBoundsException();
        } else if (offset > b.length || off