### REQUIRED PIP INSTALLS

In [None]:
!pip install transformers kaggle datasets tqdm

### DOWNLOAD THE DATASET

In [None]:
import kaggle
from pathlib import Path

def download_dataset_from_kaggle(path="data"):
    """
    Download the CodeSearchNet dataset from Kaggle and unzip it into ``path``.

    Make sure to have the Kaggle API token in ~/.kaggle/kaggle.json

    Args:
        path (str, optional): Directory the dataset is extracted into.
            Defaults to "data", the same default ``load_local_dataset`` reads from.

    Returns:
        str: Path to the downloaded dataset.
    """
    path = Path(path)
    # Create the target directory itself (not only its parent) so the
    # download does not rely on the Kaggle client creating it.
    path.mkdir(parents=True, exist_ok=True)
    kaggle.api.authenticate()
    kaggle.api.dataset_download_files(
        "omduggineni/codesearchnet", path=path, unzip=True
    )
    # The docstring promises the dataset location; return it as documented.
    return str(path)
    
#download_dataset_from_kaggle()

### LOAD THE DATASET

In [2]:
import glob

from datasets import load_dataset
from pathlib import Path

# CodeSearchNet language processed by this notebook: it selects the data
# sub-directory, the summarisation checkpoint and the output file names.
LANG = "python"

def load_local_dataset(lang="all", path="data"):
    """
    Load a local dataset from the downloaded Kaggle dataset.

    Args:
        lang (str, optional): Language sub-directory to load (e.g. "python"),
            or "all" to collect the JSONL files of every language found under
            ``path``. Defaults to "all".
        path (str, optional): Path to the downloaded dataset. Defaults to "data".

    Returns:
        DatasetDict: dataset with "train", "validation" and "test" splits
            loaded from the local JSONL files.

    Raises:
        FileNotFoundError: If no ``*.jsonl`` file is found for one of the splits.
    """
    path = Path(path)

    if lang != "all":
        # Per-language layout: <path>/<lang>/<lang>/final/jsonl/<split dir>/*.jsonl
        root, prefix = path / lang / lang / "final" / "jsonl", ""
    else:
        # Search every language directory below `path`.
        root, prefix = path, "**/"

    # Split name exposed by the DatasetDict -> directory name on disk.
    split_dirs = {"train": "train", "validation": "valid", "test": "test"}

    data_files = {}
    for split, split_dir in split_dirs.items():
        # Escape the root so characters like '[' in the path are not read as glob syntax.
        pattern = f"{glob.escape(root.as_posix())}/{prefix}{split_dir}/*.jsonl"
        # glob() order depends on the filesystem. Sort so the row order is
        # reproducible, because later summary files are aligned to rows by position.
        files = sorted(glob.glob(pattern, recursive=True))
        if not files:
            raise FileNotFoundError(
                f"No JSONL files found for split '{split}' (pattern: {pattern})"
            )
        data_files[split] = files

    return load_dataset("json", data_files=data_files)

# Load the train/validation/test splits for the configured language.
dataset = load_local_dataset(lang=LANG, path="data")

Found cached dataset json (/home/nando/.cache/huggingface/datasets/json/default-dea67a2e36b13223/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)


  0%|          | 0/3 [00:00<?, ?it/s]

### LOAD THE MODEL

In [4]:
import torch
from transformers import RobertaTokenizer, T5ForConditionalGeneration

# Fall back to the CPU when no GPU is visible instead of failing on .to("cuda").
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = RobertaTokenizer.from_pretrained('Salesforce/codet5-base', truncation_side='right')
# The summarisation checkpoint name is parameterised by LANG.
model = T5ForConditionalGeneration.from_pretrained(f'Salesforce/codet5-base-codexglue-sum-{LANG}').to(DEVICE)

### GENERATE THE SUMMARIES AND ANNOTATE THE DATASET

In [None]:
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import json

BATCH_SIZE = 32
# Upper bound on the length of each generated summary, in tokens.
MAX_SUMMARY_TOKENS = 1024


def tokenization_collator(batch_sample):
    """Tokenize the ``original_string`` field of a batch of dataset rows.

    Args:
        batch_sample (list[dict]): rows of a dataset split; each must contain
            an ``original_string`` key.

    Returns:
        BatchEncoding: ``input_ids`` and ``attention_mask`` tensors, padded to the
        longest sequence in the batch, truncated (``truncation=True``) and moved
        to the same device as ``model``.
    """
    code = [sample['original_string'] for sample in batch_sample]
    encoded = tokenizer(code,
                        return_tensors='pt',
                        padding='longest',
                        truncation=True)
    return encoded.to(model.device)


for partition in ['train', 'test', 'validation']:
    # The DataLoader keeps shuffle=False (the default), so output lines follow
    # dataset order. The merge step below relies on that line-by-line alignment.
    dataloader = DataLoader(dataset[partition],
                            batch_size=BATCH_SIZE,
                            collate_fn=tokenization_collator)

    with open(partition + f'-{LANG}.jsonl', 'w', encoding='utf-8') as f:
        # len(dataloader) is the exact number of batches (the last one may be partial).
        for batch in tqdm(dataloader, total=len(dataloader), desc=partition):
            # Pass the attention mask explicitly so padding tokens are ignored.
            generated_ids = model.generate(batch['input_ids'],
                                           attention_mask=batch['attention_mask'],
                                           max_length=MAX_SUMMARY_TOKENS)
            # Do not squeeze here. With a final batch of size 1, squeezing drops the
            # batch dimension, and batch_decode would then emit one string per token.
            summaries = tokenizer.batch_decode(generated_ids,
                                               skip_special_tokens=True)
            for summary in summaries:
                f.write(json.dumps({'summary': summary}) + '\n')

  0%|          | 0/12881 [00:00<?, ?it/s]

In [4]:
import gzip, json

# Summaries longer than this many characters are dropped from the merged dataset.
MAX_SUMMARY_CHARS = 350

# Join every split with its generated summaries into one gzip'd JSONL file.
with gzip.open(f'dataset_{LANG}.jsonl.gz', 'wt', encoding='utf-8') as w:
    for partition in ['train', 'test', 'validation']:
        split = dataset[partition]
        with open(partition + f'-{LANG}.jsonl', encoding='utf-8') as f:
            # Summary line i belongs to row i of the split. If the counts differ
            # (e.g. after an interrupted generation run), the pairing cannot be
            # trusted, so fail before writing this split.
            n_summaries = sum(1 for _ in f)
            if n_summaries != len(split):
                raise ValueError(
                    f"{f.name} has {n_summaries} summaries but the "
                    f"'{partition}' split has {len(split)} rows"
                )
            f.seek(0)

            # Stream both sides instead of readlines() plus random row access.
            for row, line in zip(split, f):
                summary = json.loads(line)
                if len(summary['summary']) > MAX_SUMMARY_CHARS:
                    continue
                record = {**row, **summary}
                w.write(json.dumps(record) + '\n')


In [5]:
from huggingface_hub import notebook_login

# Interactive Hugging Face login (renders a widget in the notebook output).
# It authenticates this session for the push_to_hub call in the next cell.
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [6]:
# Reload the merged gzip'd JSONL file and publish it to the Hugging Face Hub.
# NOTE(review): the merged file has no split column, so train, test and
# validation rows all land in a single "train" split on the Hub.
processed_dataset = load_dataset(
    'json',
    data_files=f'./dataset_{LANG}.jsonl.gz',
)
processed_dataset.push_to_hub(
    f'Nan-Do/code-search-net-{LANG}'
)

Downloading and preparing dataset json/default to /home/nando/.cache/huggingface/datasets/json/default-8d86563eb7816a96/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset json downloaded and prepared to /home/nando/.cache/huggingface/datasets/json/default-8d86563eb7816a96/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

Pushing split train to the Hub.


Pushing dataset shards to the dataset hub:   0%|          | 0/4 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/114 [00:00<?, ?ba/s]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/114 [00:00<?, ?ba/s]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/114 [00:00<?, ?ba/s]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/114 [00:00<?, ?ba/s]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]