# Test the `map()` Function to Format Datasets

## Install and Import Dependencies

In [1]:
!pip install -qU datasets pprint

[31mERROR: Could not find a version that satisfies the requirement pprint (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for pprint[0m[31m
[0m

In [2]:
from datasets import load_dataset
import pprint
import torch

## Load Dataset

In [3]:
dataset = load_dataset("claudios/code_search_net", "go", split="train[:10]")

Downloading readme:   0%|          | 0.00/13.6k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/139M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.39M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.97M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/317832 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/14291 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/14242 [00:00<?, ? examples/s]

In [4]:
print(dataset)

Dataset({
    features: ['repository_name', 'func_path_in_repository', 'func_name', 'whole_func_string', 'language', 'func_code_string', 'func_documentation_string', 'func_code_url'],
    num_rows: 10
})


## Define Prompt Format Function for a Single Sample

In [5]:
# EOS Token is required to stop open-ended generation and eventual hallucination
# EOS_TOKEN = tokenizer.eos_token
EOS_TOKEN = "<|end_of_text|>"

In [6]:
def formatFunctionSample(sample):
    language = sample['language']
    instruction = f"What does this {language} function do?"
    inputText = sample['func_code_string']
    outputText = sample['func_documentation_string']

    return {
        "instruction": instruction,
        "input": inputText,
        "output": outputText
    }

## Use the `map()` Method to Format the Dataset

In [7]:
formattedDataset = dataset.map(formatFunctionSample, remove_columns=dataset.column_names)

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

## Sample Formatted Dataset

In [8]:
len(formattedDataset)

10

In [9]:
pprint.pp(formattedDataset[1])

{'instruction': 'What does this go function do?',
 'input': 'func (s *UpdateSkillGroupInput) SetDescription(v string) '
          '*UpdateSkillGroupInput {\n'
          '\ts.Description = &v\n'
          '\treturn s\n'
          '}',
 'output': "// SetDescription sets the Description field's value."}


## Create Dataset of *Alpaca* Formatted Samples

In [10]:
sample = formattedDataset[2]

In [11]:
alpacaFormatString = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""

alpacaFormatString += EOS_TOKEN

In [13]:
print(alpacaFormatString.format(sample['instruction'], sample['input'], sample['output']))

Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
What does this go function do?

### Input:
func (s *UpdateSkillGroupInput) SetSkillGroupArn(v string) *UpdateSkillGroupInput {
	s.SkillGroupArn = &v
	return s
}

### Response:
// SetSkillGroupArn sets the SkillGroupArn field's value.<|end_of_text|>


In [16]:
def formatFunctionDataset(samples):
    alpacaFormattedSamples = []
    for sample in samples:
        alpacaFormattedSamples.append(alpacaFormatString.format(sample['instruction'], sample['input'], sample['output']))
        
    return { "text": alpacaFormattedSamples }

In [19]:
pprint.pp(formattedDataset[0])

{'instruction': 'What does this go function do?',
 'input': 'func (s *UpdateSkillGroupInput) Validate() error {\n'
          '\tinvalidParams := request.ErrInvalidParams{Context: '
          '"UpdateSkillGroupInput"}\n'
          '\tif s.Description != nil && len(*s.Description) < 1 {\n'
          '\t\tinvalidParams.Add(request.NewErrParamMinLen("Description", 1))\n'
          '\t}\n'
          '\tif s.SkillGroupName != nil && len(*s.SkillGroupName) < 1 {\n'
          '\t\tinvalidParams.Add(request.NewErrParamMinLen("SkillGroupName", '
          '1))\n'
          '\t}\n'
          '\n'
          '\tif invalidParams.Len() > 0 {\n'
          '\t\treturn invalidParams\n'
          '\t}\n'
          '\treturn nil\n'
          '}',
 'output': '// Validate inspects the fields of the type to determine if they '
           'are valid.'}


In [24]:
type(formattedDataset)

datasets.arrow_dataset.Dataset

In [27]:
texts = formatFunctionDataset(formattedDataset)
pprint.pp(texts)

{'text': ['Below is an instruction that describes a task, paired with an input '
          'that provides further context. Write a response that appropriately '
          'completes the request.\n'
          '\n'
          '### Instruction:\n'
          'What does this go function do?\n'
          '\n'
          '### Input:\n'
          'func (s *UpdateSkillGroupInput) Validate() error {\n'
          '\tinvalidParams := request.ErrInvalidParams{Context: '
          '"UpdateSkillGroupInput"}\n'
          '\tif s.Description != nil && len(*s.Description) < 1 {\n'
          '\t\tinvalidParams.Add(request.NewErrParamMinLen("Description", 1))\n'
          '\t}\n'
          '\tif s.SkillGroupName != nil && len(*s.SkillGroupName) < 1 {\n'
          '\t\tinvalidParams.Add(request.NewErrParamMinLen("SkillGroupName", '
          '1))\n'
          '\t}\n'
          '\n'
          '\tif invalidParams.Len() > 0 {\n'
          '\t\treturn invalidParams\n'
          '\t}\n'
          '\treturn nil\n'

In [28]:
finalDataset = formattedDataset.map(formatFunctionDataset)

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

TypeError: string indices must be integers