# Train code search sentence embeddings

In [1]:
# Upgrade jax and jaxlib quietly (-q); this cell runs before the cell that imports jax.
!pip install --upgrade -q jax jaxlib

[K     |████████████████████████████████| 686kB 5.4MB/s 
[K     |████████████████████████████████| 46.2MB 125kB/s 
[?25h  Building wheel for jax (setup.py) ... [?25l[?25hdone


In [2]:
# Fetch the training script, its requirements file and the helper modules that
# the import cell needs (trainer/loss/custom.py and its siblings).
# Every download is skipped when the target already exists, so the cell can be re-run.
# `wget -P <dir>` saves straight into the destination directory; the former trailing
# "." argument was parsed by wget as a second (invalid) URL and has been dropped.
!test -f train_code_search_net.py || wget -q https://raw.githubusercontent.com/nreimers/flax-sentence-embeddings/main/code-search-net/train_code_search_net.py
!test -f requirements.txt || wget -q https://raw.githubusercontent.com/nreimers/flax-sentence-embeddings/main/code-search-net/requirements.txt
!mkdir -p trainer/loss trainer/utils
!test -f trainer/loss/custom.py || wget -q -P trainer/loss https://raw.githubusercontent.com/nreimers/flax-sentence-embeddings/049d30a44f83a266fbc6f71e22e285e4d0a8d30b/trainer/loss/custom.py
!test -f trainer/loss/basic.py || wget -q -P trainer/loss https://raw.githubusercontent.com/nreimers/flax-sentence-embeddings/049d30a44f83a266fbc6f71e22e285e4d0a8d30b/trainer/loss/basic.py
!test -f trainer/utils/ops.py || wget -q -P trainer/utils https://raw.githubusercontent.com/nreimers/flax-sentence-embeddings/049d30a44f83a266fbc6f71e22e285e4d0a8d30b/trainer/utils/ops.py
!pip install -qr requirements.txt

[K     |████████████████████████████████| 122kB 5.3MB/s 
[K     |████████████████████████████████| 184kB 27.6MB/s 
[K     |████████████████████████████████| 2.5MB 32.1MB/s 
[K     |████████████████████████████████| 245kB 41.3MB/s 
[K     |████████████████████████████████| 1.8MB 38.9MB/s 
[K     |████████████████████████████████| 61kB 6.4MB/s 
[K     |████████████████████████████████| 3.3MB 23.7MB/s 
[K     |████████████████████████████████| 901kB 60.2MB/s 
[K     |████████████████████████████████| 122kB 54.7MB/s 
[K     |████████████████████████████████| 245kB 46.3MB/s 
[K     |████████████████████████████████| 133kB 50.9MB/s 
[K     |████████████████████████████████| 102kB 8.1MB/s 
[K     |████████████████████████████████| 174kB 46.7MB/s 
[K     |████████████████████████████████| 71kB 9.4MB/s 
[?25h  Building wheel for pathtools (setup.py) ... [?25l[?25hdone
  Building wheel for subprocess32 (setup.py) ... [?25l[?25hdone


In [13]:
# requirements
from sklearn.model_selection import train_test_split
import gzip
from tqdm import tqdm
import numpy as np
import csv

# for training script
from dataclasses import dataclass, field, asdict, replace
from functools import partial
from typing import Callable, List, Union

import jax
import jax.numpy as jnp
import optax
from flax import jax_utils, struct, traverse_util
from flax.training import train_state
from flax.serialization import to_bytes, from_bytes
from flax.training.common_utils import shard
from tqdm.auto import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from trainer.loss.custom import multiple_negatives_ranking_loss

import wandb
import json
import os

from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer, FlaxAutoModel

In [4]:
#@title TPU config
USE_TPU = True  #@param {type:"boolean"}
# DEBUG = True  #@param {type:"boolean"}

# if DEBUG:
#   os.environ['XLA_HLO_DEBUG']="1"

# When USE_TPU is set and the notebook runs on a Colab TPU runtime, point JAX at
# the remote TPU worker and print the devices JAX can see; otherwise print a hint.
if USE_TPU:
  # Google Colab "TPU" runtimes are configured in "2VM mode", meaning that JAX
  # cannot see the TPUs because they're not directly attached. Instead we need to
  # setup JAX to communicate with a second machine that has the TPUs attached.
  if 'google.colab' in str(get_ipython()) and 'COLAB_TPU_ADDR' in os.environ:
    import jax.tools.colab_tpu
    jax.tools.colab_tpu.setup_tpu()
    print('Connected to TPU.')
    print('TPU address is: {}'.format(os.environ['COLAB_TPU_ADDR']))
    print('Using {} devices:'.format(jax.device_count()))
    for d in jax.devices():
      print('\t{}'.format(d))
    print('XRT TPU config: {}'.format(os.environ['XRT_TPU_CONFIG']))

    # get the latest JAX and jaxlib
    # !pip install --upgrade -q jax jaxlib

    # Colab runtime set to TPU access
    # import requests
    # import os
    # if 'TPU_DRIVER_MODE' not in globals():
    #   print('setting TPU driver mode.')
    #   url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver_nightly'
    #   resp = requests.post(url)
    #   TPU_DRIVER_MODE = 1

    # TPU driver as backend for JAX
    # from jax.config import config
    # config.FLAGS.jax_xla_backend = "tpu_driver"
    # config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
    # print(config.FLAGS.jax_backend_target)
    # Prevent GPU/TPU warning.
    # import jax; jax.config.update('jax_platform_name', 'cpu')
  else:
    # Either not running inside Colab or the runtime has no TPU attached.
    print('No TPU detected. Can be changed under "Runtime/Change runtime type".')

INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
INFO:absl:Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.


Connected to TPU.
TPU address is: 10.70.228.82:8470
Using 8 devices:
	TPU_0(host=0,(0,0,0,0))
	TPU_1(host=0,(0,0,0,1))
	TPU_2(host=0,(1,0,0,0))
	TPU_3(host=0,(1,0,0,1))
	TPU_4(host=0,(0,1,0,0))
	TPU_5(host=0,(0,1,0,1))
	TPU_6(host=0,(1,1,0,0))
	TPU_7(host=0,(1,1,0,1))
XRT TPU config: tpu_worker;0;10.70.228.82:8470


In [5]:
# Download the gzip-compressed JSON Lines dataset once; skipped when the archive is already present.
# The former trailing "." argument was parsed by wget as a second (invalid) URL and has been dropped.
!test -f codesearchnet.jsonl.gz || wget -q https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/codesearchnet.jsonl.gz
# !gzip -dk codesearchnet.jsonl.gz

In [6]:
def load_jsonl_dataset(file_name):
  """Load a gzip-compressed JSON Lines file into a list.

  The archive is decompressed and parsed one line at a time, so the whole
  decompressed text never has to be held in memory next to the parsed
  records. Blank (whitespace-only) lines are skipped instead of making
  ``json.loads`` fail.

  Args:
    file_name: path of the ``.jsonl.gz`` file; lines are expected to be
      newline-terminated.

  Returns:
    A list with one decoded JSON value per non-blank line, in file order.

  Raises:
    OSError: if the file cannot be opened or is not valid gzip data.
    json.JSONDecodeError: if a non-blank line is not valid JSON.
  """
  with gzip.open(file_name, "rb") as f:
    # Iterating the binary stream yields one b"\n"-terminated line at a time.
    return [json.loads(jline) for jline in f if jline.strip()]

# Load every record of the archive, then hold out 10% of them for validation.
# `random_state` is fixed so the same split is produced on every run.
X = load_jsonl_dataset('./codesearchnet.jsonl.gz')

X_train, X_val = train_test_split(X, test_size=0.1, random_state=42)
# The message previously contained a stray ";" inside the format string.
print("Split data:\n\t{:,} samples for train\n\t{:,} samples for validation.".format(len(X_train), len(X_val)))

Splitted data:
	1,237,560 samples for train
	137,507; samples for validation.


In [7]:
# Show one randomly chosen training record: element 0 is printed under
# "docstring:", element 1 under "function:".
idx = np.random.randint(len(X_train))
sample = X_train[idx]

print('docstring:', '==========\n', sep='\n')
print(sample[0], '\n')

print('function:', '=========\n', sep='\n')
print(sample[1])

docstring:

// SetInstanceId sets the InstanceId field's value. 

function:

func (s *UpdateInstanceCustomHealthStatusInput) SetInstanceId(v string) *UpdateInstanceCustomHealthStatusInput {
	s.InstanceId = &v
	return s
}


In [8]:
# with open('tr.jsonl', 'w') as outfile:
#   for j in tqdm(X_train):
#       json.dump({'doc': j[0], 'function': j[1]}, outfile)
#       outfile.write('\n')

# with open('val.jsonl', 'w') as outfile:
#   for j in tqdm(X_val):
#       json.dump({'doc': j[0], 'function': j[1]}, outfile)
#       outfile.write('\n')

In [9]:
# write csv dataset

header = ['docstring', 'code']


def write_csv_dataset(path, rows, header):
    """Write `rows` to `path` as a CSV file whose first line is `header`.

    The file is opened with ``newline=''`` as the csv module documentation
    asks, so the writer alone decides the line endings and newlines embedded
    in quoted fields are not translated on any platform.

    Args:
        path: destination file; an existing file is overwritten.
        rows: iterable of row sequences; a tqdm progress bar is shown while
            they are written.
        header: sequence of column names written before the data rows.
    """
    with open(path, 'w', encoding='UTF8', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(header)  # write the header
        writer.writerows(tqdm(rows))  # write the data


write_csv_dataset('val.csv', X_val, header)
write_csv_dataset('tr.csv', X_train, header)

HBox(children=(FloatProgress(value=0.0, max=137507.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1237560.0), HTML(value='')))




In [10]:
# from datasets import load_dataset
# dataset_test = load_dataset("csv", data_files='tr.csv', split="train")
# dataset_test[0]

In [11]:
@dataclass
class TrainingArgs:
    """Settings for one training run.

    `batch_size` is not a dataclass field: it is derived in `__post_init__`
    and therefore does not appear in `asdict(...)`.
    """

    # model / optimisation
    model_id: str = "microsoft/codebert-base"
    max_epochs: int = 2
    batch_size_per_device: int = 32
    seed: int = 42
    lr: float = 2e-5
    init_lr: float = 1e-5
    warmup_steps: int = 2000
    weight_decay: float = 1e-3

    # maximum token lengths handed to the tokenizer for the two inputs
    input1_maxlen: int = 128
    input2_maxlen: int = 128

    # logging / checkpointing
    logging_steps: int = 20
    save_dir: str = "checkpoints"

    # CSV files read for the training and validation splits
    tr_data_files: List[str] = field(default_factory=lambda: ["tr.csv"])
    val_data_files: List[str] = field(default_factory=lambda: ["val.csv"])

    def __post_init__(self):
        # global batch = per-device batch times the number of devices JAX reports
        self.batch_size = jax.device_count() * self.batch_size_per_device


def scheduler_fn(lr, init_lr, warmup_steps, num_train_steps):
    """Build the learning-rate schedule: linear warm-up followed by linear decay.

    The rate goes linearly from `init_lr` to `lr` over `warmup_steps` steps,
    then linearly from `lr` to 1e-7 over the remaining
    `num_train_steps - warmup_steps` steps.

    Returns:
        The joined optax schedule.
    """
    warmup = optax.linear_schedule(
        init_value=init_lr,
        end_value=lr,
        transition_steps=warmup_steps,
    )
    decay = optax.linear_schedule(
        init_value=lr,
        end_value=1e-7,
        transition_steps=num_train_steps - warmup_steps,
    )
    # switch from the warm-up schedule to the decay schedule at `warmup_steps`
    return optax.join_schedules(schedules=[warmup, decay], boundaries=[warmup_steps])


def build_tx(lr, init_lr, warmup_steps, num_train_steps, weight_decay):
    """Create the AdamW optimiser together with its learning-rate schedule.

    Weight decay is applied to every parameter except those whose path ends in
    "bias" or in ("LayerNorm", "scale").

    Args:
        lr: peak learning rate reached at the end of warm-up.
        init_lr: learning rate at step 0.
        warmup_steps: number of warm-up steps.
        num_train_steps: total number of optimisation steps.
        weight_decay: AdamW weight-decay coefficient.

    Returns:
        A `(tx, lr)` tuple: the optax transformation and the schedule built by
        `scheduler_fn`.
    """
    def weight_decay_mask(params):
        # flatten_dict maps each parameter PATH (a tuple of names) to its array.
        # The decision has to be made on the path; the previous code tested the
        # array value instead, so no parameter was ever excluded from decay.
        flat_params = traverse_util.flatten_dict(params)
        mask = {
            path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale"))
            for path in flat_params
        }
        return traverse_util.unflatten_dict(mask)
    lr = scheduler_fn(lr, init_lr, warmup_steps, num_train_steps)
    tx = optax.adamw(learning_rate=lr, weight_decay=weight_decay, mask=weight_decay_mask)
    return tx, lr


class TrainState(train_state.TrainState):
    """Flax train state that additionally carries the loss function and the LR schedule.

    Both extra fields are declared with ``pytree_node=False``, i.e. they are
    not array leaves of the state.
    """
    # Called as loss_fn(embedding1, embedding2) in train_step and val_step.
    loss_fn: Callable = struct.field(pytree_node=False)
    # Called as scheduler_fn(step) in train_step to report the current learning rate.
    scheduler_fn: Callable = struct.field(pytree_node=False)


@partial(jax.pmap, axis_name="batch")
def train_step(state, model_input1, model_input2, drp_rng):
    """Run one data-parallel optimisation step (pmapped over the "batch" axis).

    Args:
        state: train state providing `apply_fn`, `params`, `loss_fn`,
            `scheduler_fn`, `step` and `apply_gradients`.
        model_input1: keyword arguments for `state.apply_fn` for the first
            input; must contain "attention_mask".
        model_input2: same, for the second input.
        drp_rng: per-device dropout PRNG key.

    Returns:
        `(state, metrics, new_drp_rng)`: the updated state, a dict with
        "tr_loss" and "lr", and a fresh dropout key for the next step.
    """
    train = True
    new_drp_rng, drp_rng = jax.random.split(drp_rng, 2)

    def loss_fn(params, model_input1, model_input2, drp_rng):
        def _forward(model_input):
            attention_mask = model_input["attention_mask"][..., None]
            # first model output; presumably token-level hidden states (batch, seq, dim)
            embedding = state.apply_fn(**model_input, params=params, train=train, dropout_rng=drp_rng)[0]
            attention_mask = jnp.broadcast_to(attention_mask, jnp.shape(embedding))

            # zero out masked positions, then average over axis 1
            embedding = embedding * attention_mask
            embedding = jnp.mean(embedding, axis=1)

            # L2-normalise: divide by the norm itself. The previous code divided
            # by the *squared* norm, which does not give unit-length vectors.
            # Flooring the squared norm at 1e-24 is the old 1e-12 floor on the
            # norm and keeps the square root's gradient finite at zero.
            squared_norm = jnp.sum(jnp.square(embedding), axis=-1, keepdims=True)
            embedding = embedding / jnp.sqrt(jnp.maximum(squared_norm, 1e-24))

            # gather all the embeddings on same device for calculation loss over global batch
            embedding = jax.lax.all_gather(embedding, axis_name="batch")
            embedding = jnp.reshape(embedding, (-1, embedding.shape[-1]))

            return embedding

        embedding1, embedding2 = _forward(model_input1), _forward(model_input2)
        return state.loss_fn(embedding1, embedding2)

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params, model_input1, model_input2, drp_rng)
    # Each replica back-propagates only through its own forward pass, so the
    # per-device gradients differ. Average them so that every replica applies
    # the same update and the replicated parameters stay in sync.
    grads = jax.lax.pmean(grads, axis_name="batch")
    state = state.apply_gradients(grads=grads)

    metrics = {"tr_loss": loss, "lr": state.scheduler_fn(state.step)}
    return state, metrics, new_drp_rng


@partial(jax.pmap, axis_name="batch")
def val_step(state, model_inputs1, model_inputs2):
    """Compute the validation loss for one batch (pmapped over the "batch" axis).

    Args:
        state: train state providing `apply_fn`, `params` and `loss_fn`.
        model_inputs1: keyword arguments for `state.apply_fn` for the first
            input; must contain "attention_mask".
        model_inputs2: same, for the second input.

    Returns:
        The mean of `state.loss_fn` over the gathered embeddings.
    """
    train = False

    def _forward(model_input):
        attention_mask = model_input["attention_mask"][..., None]
        # first model output; presumably token-level hidden states (batch, seq, dim)
        embedding = state.apply_fn(**model_input, params=state.params, train=train)[0]
        attention_mask = jnp.broadcast_to(attention_mask, jnp.shape(embedding))

        # zero out masked positions, then average over axis 1
        embedding = embedding * attention_mask
        embedding = jnp.mean(embedding, axis=1)

        # L2-normalise: divide by the norm itself. The previous code divided by
        # the *squared* norm, which does not give unit-length vectors.
        # Flooring the squared norm at 1e-24 is the old 1e-12 floor on the norm.
        squared_norm = jnp.sum(jnp.square(embedding), axis=-1, keepdims=True)
        embedding = embedding / jnp.sqrt(jnp.maximum(squared_norm, 1e-24))

        # gather all the embeddings on same device for calculation loss over global batch
        embedding = jax.lax.all_gather(embedding, axis_name="batch")
        embedding = jnp.reshape(embedding, (-1, embedding.shape[-1]))

        return embedding

    embedding1, embedding2 = _forward(model_inputs1), _forward(model_inputs2)
    loss = state.loss_fn(embedding1, embedding2)
    return jnp.mean(loss)


def get_batched_dataset(dataset, batch_size, seed=None):
    """Yield consecutive full batches of `dataset` as plain dicts.

    When `seed` is given the dataset is shuffled with it first. Only
    `len(dataset) // batch_size` batches are produced, so a trailing partial
    batch is dropped.
    """
    if seed is not None:
        dataset = dataset.shuffle(seed=seed)
    num_batches = len(dataset) // batch_size
    for start in range(0, num_batches * batch_size, batch_size):
        yield dict(dataset[start: start + batch_size])


@dataclass
class DataCollator:
    """Tokenize a batch's "docstring" and "code" columns and shard both for pmap.

    Every text is truncated / padded to a fixed maximum length
    (`input1_maxlen` for "docstring", `input2_maxlen` for "code").
    """
    tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer]
    input1_maxlen: int = 128
    input2_maxlen: int = 128

    def __call__(self, batch):
        # Currently only static padding; TODO: change below for adding dynamic padding support
        def _encode(texts, max_length):
            encoded = self.tokenizer(
                texts,
                return_tensors="jax",
                max_length=max_length,
                truncation=True,
                padding="max_length",
            )
            # plain dict of arrays, passed through flax's `shard`
            return shard(dict(encoded))

        sharded_docstrings = _encode(batch["docstring"], self.input1_maxlen)
        sharded_code = _encode(batch["code"], self.input2_maxlen)
        return sharded_docstrings, sharded_code


def save_checkpoint(save_dir, state, save_fn=None, training_args=None):
    """Write parameters, optimizer state and (optionally) the training args to `save_dir`.

    Args:
        save_dir: target directory; created when missing.
        state: replicated train state; it is unreplicated before saving.
        save_fn: optional callable invoked as `save_fn(save_dir, params=...)`.
            When omitted the parameters are serialised to "flax_model.msgpack".
        training_args: optional dataclass instance dumped to
            "training_args.json".
    """
    print(f"saving checkpoint in {save_dir}", end=" ... ")

    os.makedirs(save_dir, exist_ok=True)
    state = jax_utils.unreplicate(state)

    if save_fn is None:
        with open(os.path.join(save_dir, "flax_model.msgpack"), "wb") as params_file:
            params_file.write(to_bytes(state.params))
    else:
        # saving model in HF fashion
        save_fn(save_dir, params=state.params)

    # this will save optimizer states
    with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as opt_file:
        opt_file.write(to_bytes(state.opt_state))

    if training_args is not None:
        with open(os.path.join(save_dir, "training_args.json"), "w") as args_file:
            json.dump(asdict(training_args), args_file)

    print("done!!")


def prepare_dataset(args):
    """Load the train / validation CSV files, shuffle them and trim to full batches.

    Each split is shuffled with `args.seed` and cut to the largest multiple of
    `args.batch_size`.

    Returns:
        `(tr_dataset, val_dataset)`.
    """
    # one DatasetDict so that both splits go through identical processing
    dataset = DatasetDict(
        train=load_dataset("csv", data_files=args.tr_data_files, split="train"),
        validation=load_dataset("csv", data_files=args.val_data_files, split="train"),
    )

    # drop extra batch from the end
    for split in dataset:
        total = len(dataset[split])
        num_samples = total - total % args.batch_size
        dataset[split] = dataset[split].shuffle(seed=args.seed).select(range(num_samples))

    print(dataset)
    return dataset["train"], dataset["validation"]

    
def main(args, logger):
    """Train the sentence-embedding model and evaluate / checkpoint after every epoch.

    Args:
        args: `TrainingArgs` instance (model id, optimisation settings, data
            files, `save_dir`, derived `batch_size`).
        logger: object exposing `log(dict, commit=...)`; a wandb run in this
            notebook.
    """
    os.makedirs(args.save_dir, exist_ok=True)

    model = FlaxAutoModel.from_pretrained(args.model_id)
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)

    data_collator = DataCollator(
        tokenizer=tokenizer,
        input1_maxlen=args.input1_maxlen,
        input2_maxlen=args.input2_maxlen,
    )

    tr_dataset, val_dataset = prepare_dataset(args)

    tx_args = {
        "lr": args.lr,
        "init_lr": args.init_lr,
        "warmup_steps": args.warmup_steps,
        "num_train_steps": (len(tr_dataset) // args.batch_size) * args.max_epochs,
        "weight_decay": args.weight_decay,
    }
    tx, lr = build_tx(**tx_args)

    state = TrainState.create(
        apply_fn=model.__call__,
        params=model.params,
        tx=tx,
        loss_fn=multiple_negatives_ranking_loss,
        scheduler_fn=lr,
    )
    # one copy of the state per device, as required by the pmapped steps
    state = jax_utils.replicate(state)

    rng = jax.random.PRNGKey(args.seed)
    # one dropout key per device
    drp_rng = jax.random.split(rng, jax.device_count())
    for epoch in range(args.max_epochs):
        # training step
        total = len(tr_dataset) // args.batch_size
        # the epoch index seeds the shuffle, so every epoch sees a different order
        batch_iterator = get_batched_dataset(tr_dataset, args.batch_size, seed=epoch)
        for i, batch in tqdm(enumerate(batch_iterator), desc=f"Running epoch-{epoch}", total=total):
            model_input1, model_input2 = data_collator(batch)
            state, metrics, drp_rng = train_step(state, model_input1, model_input2, drp_rng)

            if (i + 1) % args.logging_steps == 0:
                tr_loss = jax_utils.unreplicate(metrics["tr_loss"]).item()
                tqdm.write(str(dict(tr_loss=tr_loss, step=i+1)))
                logger.log({
                    "tr_loss": tr_loss,
                    "step": i + 1,
                }, commit=True)

        # evaluation
        val_loss = jnp.array(0.)
        total = len(val_dataset) // args.batch_size
        val_batch_iterator = get_batched_dataset(val_dataset, args.batch_size, seed=None)
        # Count the batches explicitly: the loop variable is never bound when
        # the validation split holds no full batch, which used to raise a
        # NameError when averaging.
        num_val_batches = 0
        for num_val_batches, batch in tqdm(enumerate(val_batch_iterator, start=1), desc=f"evaluating after epoch-{epoch}", total=total):
            model_input1, model_input2 = data_collator(batch)
            val_step_loss = val_step(state, model_input1, model_input2)
            val_loss += jax_utils.unreplicate(val_step_loss)

        if num_val_batches:
            val_loss = val_loss.item() / num_val_batches
            print(f"val_loss: {val_loss}")
            logger.log({"val_loss": val_loss}, commit=True)
        else:
            print("val_loss: skipped (validation split has no full batch)")

        save_dir = args.save_dir + f"-epoch-{epoch}"
        save_checkpoint(save_dir, state, save_fn=model.save_pretrained, training_args=args)

In [None]:
# Start from the default arguments and register them as the wandb run config.
args = TrainingArgs()
logger = wandb.init(project="code-search-net", config=asdict(args))

# Read the config back from the run and make the checkpoint directory unique
# by appending the run id.
logging_dict = dict(logger.config)
logging_dict["save_dir"] = logging_dict["save_dir"] + f"-{logger.id}"
args = replace(args, **logging_dict)

print(args)
main(args, logger)

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

TrainingArgs(model_id='microsoft/codebert-base', max_epochs=2, batch_size_per_device=32, seed=42, lr=2e-05, init_lr=1e-05, warmup_steps=2000, weight_decay=0.001, input1_maxlen=128, input2_maxlen=128, logging_steps=20, save_dir='checkpoints-2a5xfp0y', tr_data_files=['tr.csv'], val_data_files=['val.csv'])




Downloading and preparing dataset csv/default (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /root/.cache/huggingface/datasets/csv/default-f843dba51fffa426/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0...


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-f843dba51fffa426/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0. Subsequent calls will reuse this data.




Downloading and preparing dataset csv/default (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /root/.cache/huggingface/datasets/csv/default-e124bbb4b2c79ce6/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0...


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-e124bbb4b2c79ce6/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0. Subsequent calls will reuse this data.
DatasetDict({
    train: Dataset({
        features: ['docstring', 'code'],
        num_rows: 1237504
    })
    validation: Dataset({
        features: ['docstring', 'code'],
        num_rows: 137472
    })
})


  "jax.host_count has been renamed to jax.process_count. This alias "
  "jax.host_id has been renamed to jax.process_index. This alias "


HBox(children=(FloatProgress(value=0.0, description='Running epoch-0', max=4834.0, style=ProgressStyle(descrip…

  lax._check_user_dtype_supported(dtype, "arange")
