# [mtj-softtuner](https://github.com/VE-FORBRYDERNE/mtj-softtuner) [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/VE-FORBRYDERNE/mtj-softtuner/blob/main/mtj-softtuner.ipynb) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
### (Unofficial Mesh Transformer JAX soft-tuning notebook)

Create, in Colab, soft prompts compatible with KoboldAI (United) and [mkultra](https://github.com/corolla-johnson/mkultra) for your favourite GPT-J-6B-based or GPT-Neo-2.7B-based model!

See the paper [The Power of Scale for Parameter-Efficient Prompt Tuning](https://arxiv.org/pdf/2104.08691.pdf) (Lester et al., 2021) for more information about what a soft prompt is.

---

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at [http://www.apache.org/licenses/LICENSE-2.0](http://www.apache.org/licenses/LICENSE-2.0). Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

<br/><br/><br/><br/>

---

<br/><br/><br/><br/>

# 1. Install and set up dependencies

## If you, at any point, restart your Colab instance by using Runtime > Restart Runtime, you will have to run this cell again.

Run the cell below to do some basic setup. Sometimes this cell may throw an error about "deadline exceeded", in which case you should restart your Colab instance (Runtime > Restart Runtime) and then run this cell again.

In [None]:
# @markdown There is an extremely large amount of code in this cell, so the code is hidden by default. Feel free to press "Show code" and look through the code, though.

import os
import sys
import termcolor


class NotebookException(Exception):
    """User-facing notebook error.

    The IPython traceback hook installed in this cell prints exceptions of
    this type as a single red "ERROR:" line instead of a full traceback.
    """


ipython = get_ipython()
if not hasattr(ipython._showtraceback, "NOTEBOOK_EXCEPTION_FLAG"):
    # Wrap IPython's traceback printer only once per kernel.  The flag
    # attribute on the wrapper tells later re-runs of this cell that the
    # wrapper is already installed.

    def __exception_handler(etype, evalue, stb):
        """Print NotebookExceptions as one red line; defer everything else."""
        if not issubclass(etype, NotebookException):
            __exception_handler.old_showtraceback(etype, evalue, stb)
            return
        print(termcolor.colored(f"ERROR:  {evalue}", "red"), file=sys.stderr)

    __exception_handler.NOTEBOOK_EXCEPTION_FLAG = True
    __exception_handler.old_showtraceback = ipython._showtraceback
    ipython._showtraceback = __exception_handler

# Everything below talks to the TPU through COLAB_TPU_ADDR, so stop right
# away on any other runtime type.
if os.environ.get("COLAB_TPU_ADDR") is None:
    raise NotebookException(
        "\nThis notebook only works in a TPU instance.\nGo to Runtime > Change runtime type and change Hardware accelerator to TPU.\n"
    )


print(termcolor.colored("Installing mesh-transformer-jax...", "magenta"))
%cd /content
!rm -r mesh-transformer-jax/
# You can use the official mesh-transformer-jax if you're going to tune a
# GPT-J-6B model, but you must use this fork if you want to tune any other type
# of model.
!git clone https://github.com/VE-FORBRYDERNE/mesh-transformer-jax
!pip install -r mesh-transformer-jax/requirements.txt
!pip install mesh-transformer-jax/ jax==0.2.12 torch progressbar2 ftfy
!pip uninstall -y optax
!pip install -U git+git://github.com/deepmind/optax.git@00c3256466196d0f6c523a8854323b08ef534600


print(termcolor.colored("\nInitializing JAX...\n", "magenta"))
import requests  # Only for connecting to Colab TPU and for nothing else
import progressbar
from tqdm.auto import tqdm
import IPython
import multiprocessing
import time
import datetime
import pickle
import json
import uuid
import base64
import ftfy
import zipfile
import google.colab
from typing import Iterable, List, Optional
import io
import pathlib
import math
import functools
import jax
import jax.numpy as jnp
import numpy as np
import haiku as hk
import optax
import torch
import mesh_transformer
import mesh_transformer.util
import mesh_transformer.layers
import mesh_transformer.transformer_shard


# Sequence length used throughout the notebook; initial soft prompts must be
# shorter than this (see spform_callback).
seq = 2048

# -1 means no initial soft prompt has been accepted yet; spform_callback sets
# it to 0 on success.  May be overwritten by restore_variable further down.
starting_step = -1


# Directory used by save_variable/restore_variable to keep values across
# runtime restarts.
if not os.path.isdir(".notebook_pickle"):
    os.mkdir(".notebook_pickle")

# Required for certain optax optimizers to work properly with haiku modules
# as per https://github.com/deepmind/dm-haiku/issues/191
# NOTE(review): this is set after `import haiku` above; it only has an effect
# if haiku reads the variable lazily rather than at import time -- confirm.
os.environ["HAIKU_FLATMAPPING"] = "0"

# In JAX 0.2.13, jax.tree_multimap was renamed to jax.tree_map and
# jax.tree_multimap became an alias of this new jax.tree_map.  The pinned
# optax commit depends on that newer behaviour, so under the installed
# JAX 0.2.12 we point jax.tree_map at jax.tree_multimap (which behaves like
# the 0.2.13 versions of both functions) for compatibility with optax.
jax.tree_map = jax.tree_multimap


# Colab doesn't seem to like ReplicatedLayerNorm, so we're just going to use
# haiku's standard LayerNorm modules, which we can do because we aren't going
# to train any layernorm parameters.
# Colab doesn't seem to like ReplicatedLayerNorm, so haiku's standard
# LayerNorm is substituted for the two plain layernorm variants.  This is
# acceptable because no layernorm parameters are trained here.
old_getnorm = mesh_transformer.layers.getnorm


def getnorm(type: str):
    """Return a norm module for `type`, overriding the two layernorm kinds.

    "layernorm" gives a LayerNorm with scale and offset; "layernorm-nobias"
    gives one with scale only.  Any other type is delegated to the original
    mesh_transformer implementation.  The module name "replicated_layer_norm"
    is kept so parameter names still match the checkpoint.
    """
    if type not in ("layernorm", "layernorm-nobias"):
        return old_getnorm(type)
    # Positional args: axis=-1, create_scale=True, create_offset=<with bias?>
    with_offset = type == "layernorm"
    return hk.LayerNorm(-1, True, with_offset, name="replicated_layer_norm")


mesh_transformer.layers.getnorm = getnorm


def shatter(in_axes: str, out_axes: str):
    """Build a decorator that turns a function into a JAX xmap over TPU cores.

    The wrapped function runs once per TPU core (8 in parallel).  Every
    argument/return value is either sharded or not:

    * sharded ('s'): split/concatenated along its leading axis, which must be
      8 long, so each core sees one slice (e.g. an (8, 10, 2) array gives
      each core a (10, 2) array).  Sharded returns are stacked along a new
      leading axis.
    * not sharded ('b'): passed whole; it should have a leading axis of size 1.

    The wrapped function needs at least one sharded AND one non-sharded
    argument, and must not use default arguments, *args or **kwargs.

    The first positional argument is always donated to the xmapped
    computation (``donate_argnums=(0,)``), so callers should not rely on its
    buffers after the call.

    Parameters
    ----------
    in_axes : str
        One character per parameter of the wrapped function, each 's'
        (sharded) or 'b' (not sharded).
    out_axes : str
        One character per return value, each 's' or 'b'.

    Returns
    -------
    Callable[[Callable[..., Any]], Callable[..., Any]]
        Decorator taking the function to xmap and returning its xmapped
        version.
    """

    def _axis_spec(spec: str):
        # 'b' -> named axis "batch", anything else -> "shard"; a single entry
        # is unwrapped because xmap expects a bare spec for one argument.
        specs = tuple(["batch" if char == "b" else "shard", ...] for char in spec)
        return specs[0] if len(specs) == 1 else specs

    xmap_in_axes = _axis_spec(in_axes)
    xmap_out_axes = _axis_spec(out_axes)

    def decorator(fun):
        return jax.experimental.maps.xmap(
            fun=fun,
            in_axes=xmap_in_axes,
            out_axes=xmap_out_axes,
            donate_argnums=(0,),
            axis_resources={"shard": "mp", "batch": "dp"},
        )

    return decorator


class EmbeddingShard(mesh_transformer.transformer_shard.EmbeddingShard):
    """
    A version of Mesh Transformer JAX's EmbeddingShard with a trainable
    soft prompt module.

    Token IDs below ``self.in_dim`` are embedded with the model's normal
    embedding matrix.  Token ID ``self.in_dim + i`` selects row ``i`` of the
    trainable soft prompt matrix (module ``softtune_linear``), whose
    ``soft_in_dim`` rows are split across shards with
    ``ceil(soft_in_dim / cores_per_replica)`` rows per shard.
    """

    def __init__(self, config: dict, **kwargs):
        super().__init__(config, **kwargs)
        # Number of soft prompt tokens, read from the notebook-level global
        # ``soft_in_dim`` (set by spform_callback); it must exist before this
        # module is constructed.
        self.softtune_in_dim = soft_in_dim
        self.softtune_in_dim_per_shard = math.ceil(
            self.softtune_in_dim / config["cores_per_replica"]
        )
        self.softtune_proj = hk.Linear(
            self.out_dim,
            w_init=hk.initializers.TruncatedNormal(
                stddev=1 / np.sqrt(self.softtune_in_dim)
            ),
            with_bias=False,
            name="softtune_linear",
        )

    def __call__(self, x: jnp.array, **kwargs) -> jnp.array:
        """Embed token IDs ``x`` on this shard, then combine across shards.

        ``x`` is indexed as ``x[:, jnp.newaxis]`` below, so it is assumed to
        be 1-D.  Keyword arguments read here: ``pe_length`` (default 0), a row
        offset used when adding positional embeddings, and
        ``mtj_softtuner_disable_pe`` (default False), which skips positional
        embeddings entirely.
        """
        pe_length = kwargs.get("pe_length", 0)
        # This shard holds in_dim_per_shard rows of the normal embedding
        # matrix starting at shard_start_index; jax.nn.one_hot yields an
        # all-zero row for IDs outside [0, in_dim_per_shard).
        shard_start_index = jax.lax.axis_index("shard") * self.in_dim_per_shard
        proj_out = self.proj(
            jax.nn.one_hot(x - shard_start_index, self.in_dim_per_shard)
        )
        # Soft prompt IDs (>= in_dim) get no contribution from the normal
        # embedding matrix.
        mask = jnp.broadcast_to((x < self.in_dim)[:, jnp.newaxis], proj_out.shape)
        proj_out = jnp.where(mask, proj_out, 0)
        if (
            not kwargs.get("mtj_softtuner_disable_pe", False)
            and self.positional_embeddings is not None
        ):
            pe_length = jnp.int32(pe_length)
            # Zero-pad this shard's positional embeddings (out_dim_per_shard
            # columns) to full out_dim width and roll them to this shard's
            # column offset.
            shard_roll_index = jnp.int32(
                jax.lax.axis_index("shard") * self.out_dim_per_shard
            )
            pos_embed = jnp.pad(
                self.positional_embeddings,
                ((0, 0), (0, self.out_dim - self.out_dim_per_shard)),
            )
            pos_embed = jnp.roll(pos_embed, shard_roll_index, axis=1)
            # Shift rows up by pe_length, then keep the last len(x) rows.
            # NOTE(review): assumed to mirror the base class's positional
            # embedding alignment -- confirm against the fork.
            pos_embed = jnp.roll(pos_embed, -pe_length, axis=0)[-proj_out.shape[0] :]
            proj_out += pos_embed
        # Same one-hot lookup for the soft prompt rows held by this shard,
        # offset by in_dim so ID in_dim + i maps to soft prompt row i.
        soft_shard_start_index = (
            jax.lax.axis_index("shard") * self.softtune_in_dim_per_shard
        )
        proj_out += self.softtune_proj(
            jax.nn.one_hot(
                x - soft_shard_start_index - self.in_dim, self.softtune_in_dim_per_shard
            )
        )
        # Combine the per-shard partial embeddings (presumably a psum over
        # the "shard" axis -- see mesh_transformer.util.g_psum).
        return mesh_transformer.util.g_psum(proj_out)


# Monkeypatch: code that looks up EmbeddingShard through
# mesh_transformer.transformer_shard at call time now gets this subclass.
mesh_transformer.transformer_shard.EmbeddingShard = EmbeddingShard


class EmbeddingCausalTransformer(mesh_transformer.transformer_shard.CausalTransformer):
    """
    A version of Mesh Transformer JAX's CausalTransformer with a function for
    embedding a 1D array of token IDs and returning the embedding matrix
    """

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)

        # params are sharded across TPU cores ('s'); tokens are not sharded
        # ('b', so they need a leading axis of size 1); the result is not
        # sharded ('b').
        # NOTE(review): shatter() donates argument 0 (the params) to the
        # xmapped call -- confirm self.state["params"] stays usable afterwards.
        @shatter("sb", "b")
        def _get_embedding_matrix(params: dict, tokens: jnp.array) -> jnp.array:
            @hk.without_apply_rng
            @hk.transform
            def inner(tokens: jnp.array):
                transformer = mesh_transformer.transformer_shard.CausalTransformerShard(
                    self.config
                )
                # mtj_softtuner_disable_pe is presumably forwarded to
                # EmbeddingShard.__call__, which then skips positional
                # embeddings so only token embeddings are returned.
                return transformer.embed(tokens, mtj_softtuner_disable_pe=True)

            return inner.apply(params, tokens)

        self._get_embedding_matrix = _get_embedding_matrix

    def get_embedding_matrix(self, tokens: np.array) -> jnp.array:
        """Embeds the given array of tokens.

        Parameters
        ----------
        tokens : numpy.array
            A 1-dimensional NumPy/jax.numpy array with dtype `numpy.uint32` or
            `jax.numpy.uint32` containing the IDs of the tokens you want to
            embed.

        Returns
        -------
        jax.numpy.array
            Embedding matrix for your tokens as a 2-dimensional jax.numpy array
            with dtype `jax.numpy.float32` and shape `(len(tokens), d_model)`,
            where `d_model` is the embedding dimension (or "model dimension")
            of your model.
        """
        # Add the size-1 leading axis required for the non-sharded argument,
        # then strip it from the non-sharded result.
        return self._get_embedding_matrix(
            self.state["params"],
            tokens[np.newaxis, :],
        )[0]


def read_ckpt_custom(pytree, dir: str, shards_in: int, load_opt: bool = True):
    """Loads the model's state from a checkpoint.

    Tensors are read from ``{dir}shard_{i}/{j}.npz`` for every shard ``i`` in
    ``range(shards_in)`` and file index ``j`` in ``range(16)``, in order, and
    matched positionally against the flattened leaves of ``pytree`` (shapes
    are asserted to agree).  Two kinds of leaves are handled specially:

    * the trainable soft prompt (module ``softtune_linear``) is not read from
      the checkpoint; a placeholder array is inserted in its place, and
    * parameters of ``replicated_layer_norm*`` modules are read from shard 0
      only and tiled to every shard.

    If any shape check fails, loading is retried once without
    ``pytree["opt_state"]`` (the checkpoint is assumed to have no optimizer
    state), and the original optimizer state is put back afterwards.  A shape
    mismatch on the retry propagates as ``AssertionError``.

    Besides its arguments, this function reads the notebook globals
    ``network``, ``params`` and ``soft_in_dim``.

    Parameters
    ----------
    pytree
        State of the network (network.state).
    dir : str
        Path to the model checkpoint.  Must contain a trailing slash.
    shards_in : int
        Number of shards the model is broken up into.  Should be 8.
    load_opt : bool, default=True
        Whether or not to load the optimizer state from the model checkpoint
        if it has one.

    Returns
    -------
    loaded_pytree
        Updated version of network.state.
    spmodule : str
        The trainable soft prompt module's module name, so that you can change
        the embedding matrix for the trainable soft prompt (assuming your
        embedding matrix is called `soft_embeddings`) by doing
        `network.state["params"][spmodule]["w"] = soft_embeddings`.
    """
    # Number of .npz files per shard directory.
    pieces = 16

    old_flattened, structure = jax.tree_flatten(pytree)

    # One boolean per leaf of pytree["params"]: True for the soft prompt.
    # NOTE(review): these masks cover only the "params" leaves, but
    # tensor_index below counts leaves of the whole pytree (jax flattens dict
    # keys in sorted order, so any "opt_state" leaves come first).  They only
    # line up when nothing precedes "params" -- e.g. on the retry without
    # opt_state.  Confirm.
    soft_embeddings_mask, _ = jax.tree_flatten(
        hk.data_structures.map(
            lambda module_name, name, value: module_name.split("/~/", 3)[-1]
            == "softtune_linear",
            pytree["params"],
        )
    )
    assert sum(soft_embeddings_mask) == 1

    # True for layer norm parameters, which are identical on every shard.
    desync_mask, _ = jax.tree_flatten(
        hk.data_structures.map(
            lambda module_name, name, value: module_name.split("/~/", 3)[-1].startswith(
                "replicated_layer_norm"
            ),
            pytree["params"],
        )
    )

    original_opt_state = pytree["opt_state"]

    # Count checkpoint tensors (from shard 0) for the progress bar.
    # NOTE(review): the NpzFile objects returned by np.load are never closed.
    n_tensors = 0
    for file_index in range(pieces):
        n_tensors += len(np.load(f"{dir}shard_0/{file_index}.npz").keys())

    def _unshard(bar):
        """Read every checkpoint tensor and move it onto the TPU mesh."""
        unsharded = []
        tensor_index = progress_index = 0

        for file_index in range(pieces):
            array_keys = [*np.load(f"{dir}shard_0/{file_index}.npz").keys()]
            for array_index in range(len(array_keys)):
                unstacked = []
                for shard_index in range(shards_in):
                    # Layer norm tensors: only shard 0's copy is needed.
                    if (
                        tensor_index < len(desync_mask)
                        and desync_mask[tensor_index]
                        and shard_index > 0
                    ):
                        continue
                    # The soft prompt leaf has no checkpoint tensor: insert a
                    # placeholder and advance so the tensor being read lands
                    # on the next leaf.
                    # NOTE(review): this is only reached while reading a
                    # checkpoint tensor, so a soft prompt that was the last
                    # leaf would never get a placeholder.
                    if (
                        tensor_index < len(soft_embeddings_mask)
                        and soft_embeddings_mask[tensor_index]
                    ):
                        unsharded.append(
                            jnp.empty(
                                (
                                    shards_in,
                                    math.ceil(soft_in_dim / shards_in),
                                    params["d_model"],
                                ),
                                dtype=jnp.float32,
                            )
                        )
                        tensor_index += 1

                    # NOTE(review): re-opens the shard's .npz for every tensor.
                    npz = np.load(f"{dir}shard_{shard_index}/{file_index}.npz")
                    array = npz[array_keys[array_index]]
                    # bfloat16 arrays come back from NumPy as raw 2-byte
                    # void ("V2"); reinterpret them in place.
                    if array.dtype == "V2":
                        array.dtype = jnp.bfloat16
                    unstacked.append(array)

                if tensor_index < len(desync_mask) and desync_mask[tensor_index]:
                    # Replicate shard 0's layer norm tensor to every shard.
                    x = network.move_xmap(
                        jnp.tile(unstacked[0], (shards_in, 1)),
                        np.zeros(params["cores_per_replica"]),
                    )
                else:
                    x = network.move_xmap(
                        jnp.stack(unstacked),
                        np.zeros(params["cores_per_replica"]),
                    )
                unsharded.append(x)

                bar.update(progress_index)

                assert (
                    x.shape == old_flattened[tensor_index].shape
                ), f"Incompatible checkpoints {x.shape} vs {old_flattened[tensor_index].shape}"
                progress_index += 1
                tensor_index += 1

        return unsharded

    print(
        "\nPlease wait while we load the model's tensors into the TPU memory.",
        flush=True,
    )
    with progressbar.ProgressBar(
        max_value=n_tensors,
        widgets=[
            progressbar.AnimatedMarker(
                "⡀⡁⡂⡃⡄⡅⡆⡇⡈⡉⡊⡋⡌⡍⡎⡏⡐⡑⡒⡓⡔⡕⡖⡗⡘⡙⡚⡛⡜⡝⡞⡟⡠⡡⡢⡣⡤⡥⡦⡧⡨⡩⡪⡫⡬⡭⡮⡯⡰⡱⡲⡳⡴⡵⡶⡷⡸⡹⡺⡻⡼⡽⡾⡿⢀⢁⢂⢃⢄⢅⢆⢇⢈⢉⢊⢋⢌⢍⢎⢏⢐⢑⢒⢓⢔⢕⢖⢗⢘⢙⢚⢛⢜⢝⢞⢟⢠⢡⢢⢣⢤⢥⢦⢧⢨⢩⢪⢫⢬⢭⢮⢯⢰⢱⢲⢳⢴⢵⢶⢷⢸⢹⢺⢻⢼⢽⢾⢿⣀⣁⣂⣃⣄⣅⣆⣇⣈⣉⣊⣋⣌⣍⣎⣏⣐⣑⣒⣓⣔⣕⣖⣗⣘⣙⣚⣛⣜⣝⣞⣟⣠⣡⣢⣣⣤⣥⣦⣧⣨⣩⣪⣫⣬⣭⣮⣯⣰⣱⣲⣳⣴⣵⣶⣷⣸⣹⣺⣻⣼⣽⣾⣿"
            ),
            "  ",
            progressbar.ETA(),
            "   ",
            progressbar.Counter(),
            f"/{n_tensors}  ",
            progressbar.Percentage(),
            "  ",
            progressbar.Bar(left="[", right="]", marker="█"),
        ],
    ) as bar:
        try:
            unsharded = _unshard(bar)
        except AssertionError:
            # Any shape mismatch is treated as "checkpoint has no optimizer
            # state": retry against a pytree without opt_state.
            load_opt = False  # no opt to load in ckpt
            del pytree["opt_state"]
            old_flattened, structure = jax.tree_flatten(pytree)
            unsharded = _unshard(bar)

    loaded_pytree = jax.tree_unflatten(structure, unsharded)

    print("\nFinished loading the model!\n\n\n")

    if not load_opt:
        loaded_pytree["opt_state"] = original_opt_state
    # Name of the (single) soft prompt module among the loaded params.
    return loaded_pytree, next(
        state
        for state in loaded_pytree["params"]
        if state.split("/~/", 3)[-1] == "softtune_linear"
    )


# Training functions


@shatter("sb", "s")
def _init_opt_state(params: dict, aux: jnp.array):
    """Per-core optimizer-state initializer (xmapped).

    ``params`` is sharded; ``aux`` is an unused, non-sharded dummy that only
    exists because shatter() needs at least one non-sharded argument.  The
    optimizer comes from ``network.config["optimizer"]`` and its state is
    returned in float32, sharded.
    NOTE(review): shatter() donates argument 0 (params) -- confirm callers
    don't reuse the params buffers afterwards.
    """
    optimizer = network.config["optimizer"]
    opt_state = optimizer.init(params)
    return mesh_transformer.util.to_f32(opt_state)


def init_opt_state(params: dict):
    """Return freshly initialized optax state for the given haiku params."""
    # The xmapped initializer needs a non-sharded argument, so a dummy
    # length-1 array is passed along with the sharded params.
    placeholder = np.empty(1)
    return _init_opt_state(params, placeholder)


def train_grad(state, ctx, tgt):
    """
    Compute the loss and its gradient with respect to the soft prompt only.

    The whole model (parameters cast to bfloat16) is treated as one function
    of all of its parameters that returns the loss (lower is better), and
    ``jax.value_and_grad`` differentiates it with respect to every parameter.
    Since only the soft prompt is trained, everything in the gradient except
    the entry for the global ``spmodule`` (the soft prompt module name
    returned by read_ckpt_custom) is discarded.

    Intended to run inside an xmapped function (train_initial /
    train_intermediate), since the embedding layer uses the "shard" axis.

    Returns
    -------
    grad, loss, last_loss, gnorm
        The soft prompt gradient, the ``loss`` and ``last_loss`` outputs of
        the transformer's loss (computed with ``z_loss=True``), and the global
        L2 norm of the soft prompt gradient.
    """

    @hk.without_apply_rng
    @hk.transform
    def inner(ctx, tgt):
        """
        This is the function that we're going to take the gradient of.
        Returns (loss, last_loss); last_loss is carried as auxiliary output.
        """
        transformer = mesh_transformer.transformer_shard.CausalTransformerShard(
            network.config
        )
        out = transformer.loss(ctx, tgt, z_loss=True)
        return out["loss"], out["last_loss"]

    # Compute gradient and also the actual loss value
    val_grad_fn = jax.value_and_grad(inner.apply, has_aux=True)
    (loss, last_loss), grad = val_grad_fn(
        mesh_transformer.util.to_bf16(state["params"]), ctx, tgt
    )
    # Remove everything from the gradient that isn't related to the
    # soft prompt
    grad = grad[spmodule]
    # Calculate the Euclidean norm of the modified gradient
    gnorm = mesh_transformer.util.global_norm(grad)
    # Return the modified gradient, the loss and last loss, and the
    # norm of the modified gradient
    return grad, loss, last_loss, gnorm


@shatter("sbb", "bbbb")
def train_initial(state, ctx, tgt):
    """
    xmapped wrapper around train_grad with no previous gradient to add to
    (presumably the first micro-batch of an accumulation cycle).

    Returns the (grad, loss, last_loss, gnorm) tuple from train_grad, all
    non-sharded.
    """
    return train_grad(state, ctx, tgt)


@shatter("sbbb", "bbbb")
def train_intermediate(state, old_grad, ctx, tgt):
    """
    Run train_grad on one more micro-batch and accumulate its gradient.

    The new soft prompt gradient is added leaf-by-leaf to ``old_grad``.
    Returns (accumulated grad, loss, last_loss, gnorm), where loss, last_loss
    and gnorm describe this micro-batch only.
    """
    step_grad, loss, last_loss, gnorm = train_grad(state, ctx, tgt)
    accumulated = jax.tree_multimap(lambda acc, new: acc + new, old_grad, step_grad)
    return accumulated, loss, last_loss, gnorm


@shatter("sbb", "bbs")
def train_final(state, grad, gnorm):
    """
    Apply one optimizer step to the soft prompt using an accumulated gradient.

    ``grad`` is the gradient accumulated by train_initial/train_intermediate
    and ``gnorm`` a micro-batch gradient norm.  Both are averaged over the
    "batch" axis.  The averaged gradient is passed through the optimizer in
    ``network.config["optimizer"]`` (presumably Adam), and the resulting
    updates are added element-wise to ``state["params"][spmodule]``.

    Returns
    -------
    grad_norm / gradient_accumulation_steps
        L2 norm of the averaged gradient divided by the global
        ``gradient_accumulation_steps``.
    grad_norm_micro
        ``gnorm`` averaged over the "batch" axis.
    dict
        New state (sharded): updated params, ``step`` + 1 and the new
        optimizer state.

    NOTE(review): the optimizer receives the gradient *summed* over
    accumulation steps, while only the reported norm is divided by
    gradient_accumulation_steps -- confirm the optimizer accounts for this.
    """
    grad_norm_micro = jax.lax.pmean(gnorm, "batch")
    grad = jax.lax.pmean(grad, "batch")
    grad_norm = mesh_transformer.util.global_norm(grad)
    # Run the configured optimizer on grad to get updates and
    # new_opt_state (its new internal state)
    updates, new_opt_state = network.config["optimizer"].update(
        grad, state["opt_state"], params=state["params"][spmodule]
    )
    # optax.apply_updates adds updates to state["params"][spmodule]
    # element-wise; updates are first converted to float32 with to_f32.
    # NOTE(review): the dtype of the result depends on apply_updates in the
    # pinned optax commit -- confirm before relying on it.
    state["params"][spmodule] = optax.apply_updates(
        state["params"][spmodule], mesh_transformer.util.to_f32(updates)
    )
    return (
        grad_norm / gradient_accumulation_steps,
        grad_norm_micro,
        {
            "params": state["params"],
            "step": state["step"] + 1,
            "opt_state": new_opt_state,
        },
    )


def show_spinner() -> multiprocessing.Process:
    """
    Start a bouncing, timed progress bar in a separate process.

    The bar redraws every 0.1 s until the process is stopped, so keep the
    return value (for example as ``spinner``) and call
    ``spinner.terminate()`` when the waited-on work is done.
    NOTE(review): the target is a local function, which presumably relies on
    the "fork" start method (default on Linux/Colab).
    """

    def _show_spinner():
        widgets = [
            progressbar.Timer(),
            "  ",
            progressbar.BouncingBar(left="[", right="]", marker="█"),
        ]
        bar = progressbar.ProgressBar(
            max_value=progressbar.UnknownLength, widgets=widgets
        )
        tick = 0
        while True:
            bar.update(tick)
            tick += 1
            time.sleep(0.1)

    process = multiprocessing.Process(target=_show_spinner)
    process.start()
    return process


def save_variable(name: str, val) -> None:
    """Pickle ``val`` to ``.notebook_pickle/<name>`` for ``restore_variable``."""
    path = f".notebook_pickle/{name}"
    with open(path, "wb") as handle:
        pickle.dump(val, handle)


def restore_variable(name: str) -> None:
    """Reload a value written by ``save_variable`` into a global named ``name``.

    Does nothing if ``.notebook_pickle/<name>`` does not exist.  The file is
    unpickled, so it must only ever be written by this notebook.
    """
    path = f".notebook_pickle/{name}"
    if os.path.exists(path):
        with open(path, "rb") as handle:
            globals()[name] = pickle.load(handle)


def _clear_initial_softprompt() -> None:
    """Forget the in-memory initial soft prompt after a rejected submission.

    Removes the ``initial_softprompt`` and ``soft_in_dim`` globals (if they
    exist) and resets ``starting_step`` and ``step`` to -1.  Files previously
    written with ``save_variable`` are left untouched.
    """
    global starting_step, step
    # pop() with a default works whether or not the names are defined yet.
    globals().pop("initial_softprompt", None)
    globals().pop("soft_in_dim", None)
    starting_step = step = -1


def spform_callback(form_input: str):
    """
    Handle the Submit button of the cell that asks for the initial soft prompt.

    ``form_input`` is tokenized with the global ``tokenizer``.  If the result
    is non-empty and at most ``seq - 2`` tokens long, it becomes
    ``initial_softprompt``.  In that case ``soft_in_dim`` is set to its
    length, ``starting_step`` and ``step`` are reset to 0, and those values
    are persisted with ``save_variable``.  Otherwise the in-memory soft
    prompt is cleared.

    Returns
    -------
    str
        HTML status message shown to the user.
    """
    global initial_softprompt, step, starting_step, soft_in_dim
    # Prompts of seq - 1 tokens or more are rejected, so the largest accepted
    # length is seq - 2; the error message reports that real limit.
    max_tokenized_len = seq - 2
    tokenized_input: List[int] = tokenizer.encode(
        form_input, max_length=int(2e9), truncation=True
    )
    n_tokens = len(tokenized_input)
    if n_tokens == 0:
        _clear_initial_softprompt()
        return "ERROR:  Your initial soft prompt cannot be empty!"
    if n_tokens > max_tokenized_len:
        _clear_initial_softprompt()
        return f"ERROR:  Your initial soft prompt is too long!<br/>It is {n_tokens} tokens long,<br/>more than the maximum of {max_tokenized_len}."
    initial_softprompt = tokenized_input
    starting_step = step = 0
    soft_in_dim = n_tokens
    save_variable("initial_softprompt", initial_softprompt)
    save_variable("starting_step", starting_step)
    save_variable("soft_in_dim", soft_in_dim)
    return f"Initial soft prompt set successfully!<br/>({n_tokens} token{'' if n_tokens == 1 else 's'} long)"


# Expose spform_callback to JavaScript in the notebook output under the name
# "spform_callback", so the soft prompt form's Submit button can call it.
google.colab.output.register_callback("spform_callback", spform_callback)


# Three Chart.js line charts side by side: training loss, gradient L2 norm
# (logarithmic y axis) and learning rate.  The embedded JavaScript defines
# push(label, l, g, r), which appends one data point to each chart, and
# update(), which redraws them.
# NOTE(review): Chart.js and numeral.js are loaded from public CDNs, so the
# plots need internet access in the browser.
"""HTML for the plots that get shown during training"""
plot_html = """
    <style>
        .row { display: flex; }
        .col { flex: 1; }
    </style>
    <div class="row">
        <div class="col"><canvas id="plotl"></canvas></div>
        <div class="col"><canvas id="plotg"></canvas></div>
        <div class="col"><canvas id="plotr"></canvas></div>
    </div>
    <script src="https://cdn.jsdelivr.net/npm/chart.js@3.5.1/dist/chart.min.js" integrity="sha256-bC3LCZCwKeehY6T4fFi9VfOU0gztUa+S4cnkIhVPZ5E=" crossorigin="anonymous"></script>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/numeral.js/2.0.6/numeral.min.js" integrity="sha512-USPCA7jmJHlCNRSFwUFq3lAm9SaOjwG8TaB8riqx3i/dAJqhaYilVnaf2eVUH5zjq89BU6YguUuAno+jpRvUqA==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
    <script>
        var labels = [];
        var plotl = new Chart(document.getElementById("plotl").getContext("2d"), {
            type: 'line',
            data: {
                labels: labels,
                datasets: [{
                    label: 'Training Loss',
                    borderColor: 'rgb(239, 41, 41)',
                    cubicInterpolationMode: 'monotone',
                    tension: 0.4,
                    data: []
                }]
            },
            options: {
                animation: { duration: 0 },
                elements: { point: { radius: 0 } },
                scales: { x: { display: true, }, y: { display: true } },
                interaction: { intersect: false, mode: 'nearest', axis: 'x' }
            }
        });
        var plotg = new Chart(document.getElementById("plotg").getContext("2d"), {
            type: 'line',
            data: {
                labels: labels,
                datasets: [{
                    label: 'Gradient L2 Norm',
                    borderColor: 'rgb(114, 159, 207)',
                    cubicInterpolationMode: 'monotone',
                    tension: 0.4,
                    data: []
                }]
            },
            options: {
                animation: { duration: 0 },
                elements: { point: { radius: 0 } },
                scales: { x: { display: true, }, y: { display: true, type: 'logarithmic' } },
                interaction: { intersect: false, mode: 'nearest', axis: 'x' }
            }
        });
        var plotr = new Chart(document.getElementById("plotr").getContext("2d"), {
            type: 'line',
            data: {
                labels: labels,
                datasets: [{
                    label: 'Learning Rate',
                    borderColor: 'rgb(173, 127, 168)',
                    cubicInterpolationMode: 'monotone',
                    tension: 0.4,
                    data: []
                }]
            },
            options: {
                animation: { duration: 0 },
                elements: { point: { radius: 0 } },
                scales: { x: { display: true, }, y: { display: true } },
                interaction: { intersect: false, mode: 'nearest', axis: 'x' },
                plugins: {
                    tooltip: {
                        callbacks: {
                            label: function(context) {
                                return numeral(context.parsed.y).format('0.000e+0');
                            }
                        }
                    }
                }
            }
        });

        function push(label, l, g, r) {
            labels.push(label);
            plotl.data.datasets[0].data.push(l);
            plotg.data.datasets[0].data.push(g);
            plotr.data.datasets[0].data.push(r);
        }

        function update() {
            plotl.update();
            plotg.update();
            plotr.update();
        }
    </script>"""


# Restore any variables that may have been lost by restarting the instance
# Restore any variables that may have been lost by restarting the instance.
# restore_variable reloads each name from .notebook_pickle/<name> only if that
# file exists (i.e. save_variable was called for it earlier); otherwise the
# name keeps whatever value it currently has, if any.
for var in (
    "starting_step",
    "params",
    "ckpt_path",
    "save_file",
    "stparams",
    "initial_softprompt",
    "soft_in_dim",
    "gradient_accumulation_steps",
    "dataset_file",
    "num_sequences",
):
    restore_variable(var)


print(
    termcolor.colored("\n\nConnecting to your Colab instance's TPU...", "magenta"),
    flush=True,
)
spinner = show_spinner()
try:
    colab_tpu_addr = os.environ["COLAB_TPU_ADDR"].split(":")[0]
    # Ask the TPU host (port 8475) for the tpu_driver version named in the
    # URL, then point JAX's tpu_driver backend at the TPU over gRPC.
    requests.post(f"http://{colab_tpu_addr}:8475/requestversion/tpu_driver0.1_dev20210607")
    jax.config.FLAGS.jax_xla_backend = "tpu_driver"
    jax.config.FLAGS.jax_backend_target = "grpc://" + os.environ["COLAB_TPU_ADDR"]
finally:
    # The spinner process loops forever, so stop it even if the request above
    # raises; otherwise it keeps running after the error is shown.
    spinner.terminate()
if jax.device_count() < 8:
    raise NotebookException(
        "We couldn't detect your Colab instance's TPU.\nTry restarting the runtime (Runtime > Restart Runtime) and trying again."
    )


# If params were restored from an earlier session, rebuild the device mesh
# now: a 1 x cores_per_replica grid of TPU devices with axes "dp" and "mp"
# (shatter() maps the xmap axes "batch" -> "dp" and "shard" -> "mp").
if "params" in globals():
    mesh_shape = (1, params["cores_per_replica"])
    devices = np.array(jax.devices()[: params["cores_per_replica"]]).reshape(mesh_shape)
    thread_resources_env = jax.experimental.maps.ResourceEnv(
        jax.experimental.maps.Mesh(devices, ("dp", "mp"))
    )
    jax.experimental.maps.thread_resources.env = thread_resources_env


print(termcolor.colored("\n\nInitializing transformers...", "magenta"))
import transformers

# Hugging Face's fast GPT-2 tokenizer, used (e.g. by spform_callback) to turn
# text into token IDs.
tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")


print(termcolor.colored("\n\nDone.\n\n", "green"), flush=True)

<br/><br/><br/><br/>

---

<br/><br/><br/><br/>

# 2. Download the model or log in to Google Drive

First we have to download and extract the model into your Colab instance, unless you already have the *extracted* model in your Google Drive. If you do already have the extracted model in your Google Drive, you can skip this cell.

You might want to look at the rest of the notebook while the model is downloading/extracting.

In [None]:
# Feel free to modify this cell to download a finetuned GPT-J-6B model instead.
# You can also use GPT-Neo-2.7B models after first converting them with this
# notebook:
# https://colab.research.google.com/github/VE-FORBRYDERNE/mesh-transformer-jax/blob/modelcompat/convert_neo_pytorch_model_to_jax.ipynb


print(termcolor.colored("Installing pv and zstd...", "magenta"))
# pv (pipe viewer) reports progress while the archive is piped into tar below,
# and zstd is the decompressor tar is told to use for the .tar.zstd archive.
# The official version of pv doesn't work in Colab anymore for some reason.
# This fork contains a small patch to address the issue.
!git clone https://github.com/VE-FORBRYDERNE/pv
%cd pv
# Build and install the patched pv from source.
!./configure
!make
!make install
%cd ..
!apt install zstd

print(
    termcolor.colored(
        "\nDownloading GPT-J-6B model into your Colab instance...", "magenta"
    )
)
# -c lets wget continue a partially downloaded archive if this cell is re-run.
!wget -c https://the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd
print(termcolor.colored("\nExtracting the model...", "magenta"))
# Stream the archive through pv into tar, decompressing with zstd; the model is
# extracted into the current directory.
!pv step_383500_slim.tar.zstd | tar -I zstd -x
# Delete the archive once it has been extracted to free disk space.
!rm step_383500_slim.tar.zstd

print(termcolor.colored("\nDone.\n\n", "green"))

Are you using a GPT-J-6B or GPT-Neo-2.7B pretrained model? Use the dropdown below to select your model type and then run the cell below.

In [None]:
model_type = "GPT-J-6B"  # @param ["GPT-J-6B", "GPT-Neo-2.7B"]

# Architecture hyperparameters for the selected model; the soft-tuning cell
# later passes this dict to EmbeddingCausalTransformer. `seq` (the context
# length) is defined elsewhere in the notebook -- presumably 2048, judging by
# the limits described in later cells; TODO confirm.
if model_type == "GPT-Neo-2.7B":
    params = {
        # NOTE(review): presumably tells the model-loading code to treat the
        # checkpoint as a converted GPT-Neo model; confirm where it is read.
        "compat": "neo",
        "layers": 32,
        "d_model": 2560,
        "n_heads": 20,
        "n_vocab": 50257,
        # 50257 + 143 = 50400, the same vocabulary size GPT-J-6B uses below.
        "n_vocab_padding": 143,
        "norm": "layernorm",
        "pe": "fixed",
        "seq": seq,
        # GPT-Neo-2.7B is split across 4 devices instead of 8.
        "cores_per_replica": 4,
    }
else:
    params = {
        "layers": 28,
        "d_model": 4096,
        "n_heads": 16,
        "n_vocab": 50400,
        "norm": "layernorm",
        "pe": "rotary",
        "pe_rotary_dims": 64,
        "seq": seq,
        "cores_per_replica": 8,
    }
# cores_per_replica must be either 1 or an even number greater than 1.
assert (
    params["cores_per_replica"] > 1 and params["cores_per_replica"] % 2 == 0
) or params["cores_per_replica"] == 1
# Persist params so step 1 can restore them after a runtime restart.
save_variable("params", params)

# Arrange the first cores_per_replica devices as a 1 x cores_per_replica mesh
# with axis names ("dp", "mp") and install it as the current xmap resource
# environment (step 1 repeats this when params is restored).
mesh_shape = (1, params["cores_per_replica"])
devices = np.array(jax.devices()[: params["cores_per_replica"]]).reshape(mesh_shape)
thread_resources_env = jax.experimental.maps.ResourceEnv(
    jax.experimental.maps.Mesh(devices, ("dp", "mp"))
)
jax.experimental.maps.thread_resources.env = thread_resources_env

print("OK.")

If your extracted model is stored in your Google Drive, you must run the cell below to give us access to your Google Drive.

In [None]:
# Mount Google Drive under /content/drive/ so a model stored in "My Drive" is
# reachable at /content/drive/MyDrive/... in the cells below.
google.colab.drive.mount("/content/drive/")

Type the path to the extracted model below and then run the cell below.

If you just downloaded the normal GPT-J-6B model, then the default path that's already shown, `/content/step_383500`, is correct, so you just have to run the cell without changing the path.

If you downloaded a finetuned model, you probably know where it is stored.

If your model is in Google Drive, prefix your path with `/content/drive/MyDrive`. For example, if your model were stored in a directory in the root directory of your Google Drive called "MLP", the path would be `/content/drive/MyDrive/MLP`.

In [None]:
ckpt_path = "/content/step_383500"  # @param {type:"string"}

# Normalize Windows-style separators and guarantee a trailing slash, because
# paths inside the checkpoint are built by plain string concatenation below.
ckpt_path = ckpt_path.replace("\\", "/")
if not ckpt_path.endswith("/"):
    ckpt_path += "/"

# On failure, delete ckpt_path so the soft-tuning cell's
# `"ckpt_path" not in globals()` check treats the path as unset.
if not os.path.isdir(ckpt_path):
    del ckpt_path
    raise NotebookException("That is not a path to a valid directory.")
# shard_0/0.npz is used as the marker that the directory holds a checkpoint.
if not os.path.exists(ckpt_path + "shard_0/0.npz"):
    del ckpt_path
    raise NotebookException("There doesn't seem to be a model in that directory.")

save_variable("ckpt_path", ckpt_path)
# A network loaded earlier in this session came from the previous path. Drop it
# (and spmodule, whose absence makes the soft-tuning cell load the model again).
if "network" in globals():
    print(
        "WARNING:  Due to memory constraints, you must restart your instance (Runtime > Restart runtime) before continuing further."
    )
    print("          (You also have to run Step 1 again)")
    del network
    if "spmodule" in globals():
        del spmodule
print("OK.")

<br/><br/><br/><br/>

---

<br/><br/><br/><br/>

# 3. Set up soft-tuning hyperparameters and training data

If you want to save your soft prompt into your Google Drive, run this cell to log in to Google Drive.

In [None]:
# Mount Google Drive under /content/drive/ so the save file can be written to
# /content/drive/MyDrive/... in the next cell.
google.colab.drive.mount("/content/drive/")

If you want to begin a new soft-tuning run, choose the path where we will save to. A file will appear at that path before the soft-tuning process is fully complete; this is intentional, so that you can resume the soft-tuning process later if your Colab instance crashes. If you want to save into your Google Drive, prefix your path with `/content/drive/MyDrive`.

If you want to resume a soft-tuning run for the aforementioned reason, choose the path to an existing MTJSP file.

In [None]:
save_file = "/content/drive/MyDrive/my_softprompt.mtjsp"  # @param {type:"string"}

save_file = save_file.replace("\\", "/")
# Reject directories up front: both "dir/" and an existing directory path.
# Otherwise np.load() below would fail and report a misleading "not a valid
# save file" error. On failure, soft_in_dim and save_file are removed from the
# notebook's globals (assigning first makes the `del` safe even when it was
# never defined), and starting_step is reset to -1 ("no run configured").
if save_file.endswith("/") or os.path.isdir(save_file):
    soft_in_dim = None
    del soft_in_dim
    del save_file
    starting_step = step = -1
    raise NotebookException("save_file should be a file, not a directory.")

# Create the parent directory if needed. os.path.dirname() returns "" for a
# bare filename (nothing to create) and "/" for a file directly under the
# root, so the save file's own name is never turned into a directory.
save_dir = os.path.dirname(save_file)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)

if os.path.exists(save_file):
    # Validate an existing file as a resumable MTJSP save.
    # NOTE: allow_pickle=True is required for the pickled optimizer state, so
    # only point this at MTJSP files you created yourself.
    try:
        npz = np.load(save_file, allow_pickle=True)
        assert npz["step"] > 0
        assert npz["tensor"].ndim == 2 and "opt_state" in npz
        assert npz["tensor"].shape[0] < seq
        assert npz["tensor"].shape[1] == params["d_model"]
        assert all(
            p in npz for p in ("loss", "last_loss", "grad_norm", "grad_norm_micro")
        )
        soft_in_dim = npz["tensor"].shape[0]
        starting_step = step = np.uint32(npz["step"]).item()
    except Exception:
        soft_in_dim = None
        del soft_in_dim
        del save_file
        starting_step = step = -1
        raise NotebookException("MTJSP file exists and is not a valid save file.")
    print("OK.")
    print(f"We will resume soft-tuning at step {starting_step + 1}.")
    save_variable("starting_step", starting_step)
    save_variable("soft_in_dim", soft_in_dim)
else:
    starting_step = step = -1
    print("OK.")
    print("We will begin a new soft-tuning run.")

save_variable("save_file", save_file)

If you are beginning a new soft-tuning run, choose a string to initialize your soft prompt with. It should be roughly 20-200 tokens long. The maximum allowed length is 2047 tokens. It's recommended that your string should end with two newline characters and have no other leading or trailing whitespace on any line.

If you're resuming a soft-tuning run with an existing MTJSP file, you can skip this cell.

In [None]:
#@markdown ### Run this cell once to make a text box appear for you to type in your initial soft prompt.<br/>After that, press the "Submit" button underneath the text box; do not run this cell a second time.
%%html
<form>
    <textarea id="softprompt" rows="10" cols="80">Le Jeu du Prochain Train itself is simplicity in motion. The object: Be the last of your round's six to jump from one side of the tracks to the other - that is, across the tracks - before the train passes.

</textarea>
    <br/>
    <p><input id="submit-softprompt" type="button" value="Submit" /></p>
</form>
<br/>
<p id="softprompt-message"></p>
<script type="text/javascript">
    // Wire the form to the Python kernel: "Submit" sends the textarea contents
    // to the Colab callback named "spform_callback" (presumably registered
    // elsewhere in the notebook) and shows its reply below the form.
    (function() {
        var submit = document.getElementById("submit-softprompt");
        var softprompt = document.getElementById("softprompt");
        var message = document.getElementById("softprompt-message");
        submit.addEventListener("click", async function() {
            var msg = await google.colab.kernel.invokeFunction("spform_callback", [softprompt.value], {});
            // The regex strips leading/trailing spaces and single quotes from
            // the reply's text/plain form (presumably a quoted Python string).
            // NOTE(review): innerHTML renders the reply as HTML; use
            // textContent instead if the reply is meant to be plain text.
            message.innerHTML = msg.data['text/plain'].replace(/(?:^ *'*)|(?:'* *$)/g, "");
        });
        // Clear the status message as soon as the prompt is edited again.
        softprompt.addEventListener("input", function() {
            message.innerHTML = "";
        });
    })();
</script>

If your dataset is stored in Google Drive, you have to log in to Google Drive so we can access it.

In [None]:
# Mount Google Drive under /content/drive/ so a dataset stored in "My Drive" is
# reachable at /content/drive/MyDrive/... in the next cell.
google.colab.drive.mount("/content/drive/")

If your dataset is a single txt file or collection of txt files, we have to convert it to npy format first. If you have already used this notebook to convert your txt dataset to npy, you can skip this cell.

`dataset_path` should be the path to either a single txt file or a folder with one or more txt files in it, and `output_file` should be the path of the npy file to create. Then run the cell, and we will convert your dataset into an npy file at `output_file` (we will create the required directory tree for the output file if the output file's directory doesn't already exist). If your txt files are in Google Drive, prefix your path with `/content/drive/MyDrive`.

`batch_size` is explained in this article: https://medium.com/mini-distill/effect-of-batch-size-on-training-dynamics-21c14f7a716e. In this notebook, `batch_size` is the number of tokens in each training sequence. The maximum possible batch size is 2048 minus the number of tokens in your initial soft prompt, so if your initial soft prompt is 49 tokens long then the maximum allowed batch size is 1999. If your batch_size is too high, we will automatically lower it to the highest possible value, so just leave it at 2048 if you want us to do that. `epochs` is the number of times to repeat your dataset (every repetition after the first one is shuffled).

In [None]:
dataset_path = "/content/drive/MyDrive/dataset.txt"  # @param {type:"string"}
output_file = "/content/drive/MyDrive/output.npy"  # @param {type:"string"}
batch_size = 2048  # @param {type:"integer"}
epochs = 1  # @param {type:"integer"}

dataset_path = dataset_path.replace("\\", "/")
output_file = output_file.replace("\\", "/")
# On every validation failure, delete this cell's form variables so that stale
# or invalid values cannot leak into later cells.
if "starting_step" not in globals() or starting_step == -1:
    del dataset_path
    del output_file
    del batch_size
    del epochs
    raise NotebookException(
        "You did not load from a MTJSP file or define an initial soft prompt."
    )
if not isinstance(batch_size, int) or batch_size < 1:
    del dataset_path
    del output_file
    del batch_size
    del epochs
    raise NotebookException("batch_size must be an integer greater than zero.")
if not isinstance(epochs, int) or epochs < 1:
    del dataset_path
    del output_file
    del batch_size
    del epochs
    raise NotebookException("epochs must be an integer greater than zero.")
if output_file.endswith("/") or os.path.isdir(output_file):
    del dataset_path
    del output_file
    del batch_size
    del epochs
    raise NotebookException("output_file should be a file, not a directory.")
if not os.path.exists(dataset_path):
    del dataset_path
    del output_file
    del batch_size
    del epochs
    raise NotebookException("dataset_path is not set to a valid file or directory.")


# "Batch size" is the number of tokens per training sequence. It can't exceed
# the context length left over after the soft prompt.
batch_size = min(batch_size, seq - soft_in_dim)
assert batch_size >= 0
print(
    termcolor.colored(
        "\nIf you see a warning above about token indices, ignore it.  That warning is normal.\n",
        "magenta",
    )
)
print("Batch size:", batch_size)
print(termcolor.colored("Tokenizing your dataset...\n", "magenta"))

if os.path.isfile(dataset_path):
    files = [dataset_path]
else:
    # Read every regular file in the directory in sorted order. os.listdir()
    # returns names in arbitrary order, which would make the dataset order
    # nondeterministic. It may also list subdirectories, which open() rejects.
    files = sorted(
        os.path.join(dataset_path, filename)
        for filename in os.listdir(dataset_path)
        if os.path.isfile(os.path.join(dataset_path, filename))
    )
tokens = []
for path in files:
    with open(path, encoding="utf-8") as f:
        # Token id 50256 (presumably GPT-2's <|endoftext|>) separates files.
        tokens.extend(tokenizer.encode(ftfy.fix_text(f.read())) + [50256])

print("Dataset size (in tokens):", len(tokens))
if len(tokens) < batch_size + 1:
    raise NotebookException(
        "Your dataset is too small!  The number of tokens has to be greater than the batch size."
    )
# Each row holds batch_size + 1 tokens, so that the inputs and the
# shifted-by-one targets are both batch_size tokens long.
tail = len(tokens) % (batch_size + 1)
if tail:
    print(
        f"We're removing the last {tail} tokens from your dataset to make the length a multiple of {batch_size+1}."
    )
    tokens = tokens[:-tail]

tokens = np.array(tokens, dtype=np.uint16).reshape((-1, batch_size + 1))
if epochs > 1:
    # Extra epochs are appended as row-shuffled copies (fixed seed for
    # reproducibility). The first epoch keeps the original order.
    rng = np.random.Generator(np.random.PCG64(1729))
    tokens = np.concatenate(
        (
            tokens,
            *(rng.permutation(tokens, axis=0) for i in range(epochs - 1)),
        ),
        axis=0,
    )
print(f"Total sequences in your dataset: {tokens.shape[0]}")

# Create the output directory if needed, as promised by the instructions.
output_dir = os.path.dirname(output_file)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
# Write through a binary file handle. np.save() given a *path* appends ".npy"
# when the name lacks it, which would leave the file at output_file empty.
with open(output_file, "wb") as f:
    np.save(f, tokens)

print("OK.")

Here, we set the npy file that we will use for soft-tuning.

`gradient_accumulation_steps` is described here: https://towardsdatascience.com/what-is-gradient-accumulation-in-deep-learning-ec034122cfa. It's preferable to have gradient accumulation steps in the 16-32 range.

In [None]:
dataset_file = "/content/drive/MyDrive/output.npy"  # @param {type:"string"}
gradient_accumulation_steps = 16  # @param {type:"integer"}

if not isinstance(gradient_accumulation_steps, int) or gradient_accumulation_steps < 1:
    del dataset_file
    del gradient_accumulation_steps
    raise NotebookException(
        "gradient_accumulation_steps must be an integer greater than zero."
    )
# isfile rather than exists: a directory would otherwise reach np.load().
if not os.path.isfile(dataset_file):
    del dataset_file
    del gradient_accumulation_steps
    raise NotebookException("Could not find any file at that path.")

dataset = np.load(dataset_file, mmap_mode="r")
# Training splits each row into inputs and shifted-by-one targets, so the array
# must be 2-D (sequences x tokens) with at least 2 tokens per row.
assert dataset.ndim == 2
assert dataset.shape[0] >= 2
assert dataset.shape[1] >= 2
num_sequences = dataset.shape[0]
print("Batch size of your dataset:", dataset.shape[1] - 1)
print("Total sequences in your dataset:", num_sequences)
print()

if num_sequences < gradient_accumulation_steps:
    del dataset_file
    del gradient_accumulation_steps
    del num_sequences
    raise NotebookException(
        "Your dataset is too small!  gradient_accumulation_steps must be less than or equal to the number of sequences."
    )

# soft_in_dim is only known once a soft prompt has been set or loaded. The
# truncation warning is skipped (instead of raising NameError) before that.
if "soft_in_dim" in globals() and dataset.shape[1] - 1 > seq - soft_in_dim:
    print(
        f"WARNING:  The batch size of your dataset is {dataset.shape[1] - 1}\n"
        f"which is larger than the allowed maximum of {seq - soft_in_dim}.\n"
        "Your dataset will be truncated!",
        file=sys.stderr,
    )

# Persist the settings only after every check has passed, so an invalid value
# is never restored after a runtime restart.
save_variable("gradient_accumulation_steps", gradient_accumulation_steps)
save_variable("num_sequences", num_sequences)
save_variable("dataset_file", dataset_file)
print("OK.")

Now it is time to set the other soft-tuning hyperparameters. Edit the numbers below (or don't edit any) and then run the cell.

By default we use the same modified version of the Adam optimization algorithm that Mesh Transformer JAX uses by default for training.

The main things you have to pay attention to are `lr` and `max_grad_norm`; everything else is basically universally OK for training. It is recommended to set `lr` to somewhere in the `1e-5` to `5e-5` range. Higher `lr` results in the trainer having a stronger effect. Values higher than around `7e-5` tend to result in exploding gradients (numerical instability) and should be avoided! You can tell when this happens because the Gradient L2 Norm will start increasing abnormally and eventually "explode" to values in the thousands. If the trainer still isn't strong enough, you should train with more epochs instead of risking the numerical instability.

`max_grad_norm` controls the maximum allowed rate at which the soft prompt can be trained, so if the trainer tries to change the soft prompt by too much in one step, the changes will be scaled down uniformly. Lower values are more restrictive.

* `save_every` (`int` > 0): We'll save an MTJSP file every this many steps so that if you are disconnected from Colab, you can continue from an earlier point in the training.
* `warmup` (`float` between 0.0 and 1.0 inclusive): What portion of the beginning of the total training steps should be warmup steps. The learning rate for warmup steps starts at 0.0 and increases linearly to the maximum learning rate.
* `lr` (`float` > 0): Aforementioned maximum learning rate.
* `end_lr_multiplier` (`float` > 0): After the warmup steps, the remaining training steps have a learning rate controlled by a cosine function that goes from the maximum learning rate to this proportion of the maximum learning rate (i.e. the default is one-tenth of the maximum learning rate).
* `weight_decay` (`float` between 0.0 and 1.0 inclusive): The soft prompt you're going to train is actually a two-dimensional array of floating-point numbers. The absolute values of the numbers tend to be pretty small (less than 1). If the absolute value of one of those numbers grows too large, the trainer tends to try to reduce the absolute values of the entire array, resulting in the entire array being filled with zeros and ruining your soft prompt. This setting effectively restricts the absolute value of the numbers in the array (higher weight decay value means the maximum absolute value is lower) to help prevent this scenario. Of course, setting it too high would lower the absolute value of your array, too, just like what would happen if you'd set it too low, so there's usually an optimal value. People have suggested 0.1 as a good all-around weight decay factor.
* `max_grad_norm` (`float` > 0): Controls the maximum allowed rate at which the soft prompt can be trained, so if the trainer tries to change the soft prompt by too much in one step, the changes will be scaled down uniformly. Lower values are more restrictive. If the "Gradient L2 Norm" is much higher (e.g. at least twice as much) on average than this value, you should probably raise this value.

In [None]:
lr = 3e-5  # @param {type:"number"}
max_grad_norm = 10.0  # @param {type:"number"}
weight_decay = 0.1  # @param {type:"number"}
warmup = 0.1  # @param {type:"number"}
end_lr_multiplier = 0.1  # @param {type:"number"}
save_every = 50  # @param {type:"integer"}

# Validate against the ranges documented above before persisting anything.
# save_every in particular is used as a modulus by the training loop, so 0
# would crash training mid-run with ZeroDivisionError. Each comparison is
# written positively (inside `not (...)`) so that NaN is rejected as well.
if not (isinstance(save_every, int) and save_every > 0):
    raise NotebookException("save_every must be an integer greater than zero.")
if not (isinstance(warmup, (int, float)) and 0.0 <= warmup <= 1.0):
    raise NotebookException("warmup must be a number between 0.0 and 1.0 inclusive.")
if not (isinstance(lr, (int, float)) and lr > 0):
    raise NotebookException("lr must be a number greater than zero.")
if not (isinstance(end_lr_multiplier, (int, float)) and end_lr_multiplier >= 0):
    raise NotebookException("end_lr_multiplier must be a non-negative number.")
if not (isinstance(weight_decay, (int, float)) and 0.0 <= weight_decay <= 1.0):
    raise NotebookException(
        "weight_decay must be a number between 0.0 and 1.0 inclusive."
    )
if not (isinstance(max_grad_norm, (int, float)) and max_grad_norm > 0):
    raise NotebookException("max_grad_norm must be a number greater than zero.")

# Soft-tuning hyperparameters read by the training cell (same keys and order
# as before).
stparams = {
    "lr": lr,
    "max_grad_norm": max_grad_norm,
    "weight_decay": weight_decay,
    "warmup": warmup,
    "end_lr_multiplier": end_lr_multiplier,
    "save_every": save_every,
}

save_variable("stparams", stparams)
print("OK.")

<br/><br/><br/><br/>

---

<br/><br/><br/><br/>

# 4. Soft-tune the model

## If you reached here and at any point after that restarted your Colab instance by using Runtime > Restart Runtime, you will have to run the step 1 cell again and then run the cells below again. You don't have to re-run any cells in between.

This can take quite a while depending on how fast your Colab instance is. Note that currently (2021-10-12), all Colab TPU instances train at the same speed; some just take longer than others to initialize the model.

In [None]:
# Make sure every earlier step has been completed. These names are created (and
# persisted with save_variable) by the cells above, so a missing name means the
# corresponding cell was skipped or failed; report that instead of a NameError.
if "ckpt_path" not in globals():
    raise NotebookException("You didn't specify the path to your model.")
elif "params" not in globals():
    raise NotebookException(
        "You have not specified whether you're using a GPT-J-6B or GPT-Neo-2.7B model."
    )
elif "starting_step" not in globals() or starting_step == -1:
    raise NotebookException(
        "You did not set an initial soft prompt string to begin soft-tuning with or existing save file to resume soft-tuning with."
    )
if any(
    required not in globals()
    for required in ("dataset_file", "num_sequences", "gradient_accumulation_steps")
):
    raise NotebookException("You have not specified the path to your npy dataset file.")
if "stparams" not in globals():
    raise NotebookException("You have not set the soft-tuning hyperparameters.")

step = starting_step

# Set up the scheduler which determines the learning rate for each step
steps = num_sequences // gradient_accumulation_steps
warmup_steps = max(1, round(steps * stparams["warmup"]))
scheduler = mesh_transformer.util.gpt3_schedule(
    warmup_steps,
    max(1, steps - warmup_steps),
    stparams["lr"],
    stparams["end_lr_multiplier"] * stparams["lr"],
)

# Tell Mesh Transformer to create the network as bfloat16
params["early_cast"] = True

if step == 0:
    print("We are starting a brand new soft-tuning session.\n")
else:
    # If we're resuming a soft-tuning session, the soft prompt tensor is
    # already in the save file and we just have to decode it.
    # NOTE: allow_pickle=True is needed for the pickled optimizer state.
    try:
        npz = np.load(save_file, allow_pickle=True)
        assert npz["step"] > 0
        assert npz["tensor"].ndim == 2 and "opt_state" in npz
        assert npz["tensor"].shape[0] < seq
        assert npz["tensor"].shape[1] == params["d_model"]
        assert all(
            p in npz
            for p in (
                "loss",
                "last_loss",
                "grad_norm",
                "grad_norm_micro",
            )
        )
        assert soft_in_dim == npz["tensor"].shape[0]
        step = np.uint32(npz["step"]).item()
    except Exception:
        raise NotebookException("MTJSP file is corrupted.")
    print(f"We're resuming a previous soft-tuning session at step {step+1}.\n")
    soft_embeddings = npz["tensor"]
    # A 2-byte void dtype is reinterpreted in place as bfloat16 (no data
    # conversion) before upcasting to float32.
    if soft_embeddings.dtype == "V2":
        soft_embeddings.dtype = jnp.bfloat16
    soft_embeddings = jnp.float32(soft_embeddings)

# Load the model
if "spmodule" not in globals():
    print(termcolor.colored("Initializing network...", "magenta"), flush=True)
    params["optimizer"] = optax.scale(0)
    network = EmbeddingCausalTransformer(params)
    print(termcolor.colored("\n\nLoading pretrained model...", "magenta"), flush=True)
    network.state, spmodule = read_ckpt_custom(
        network.state, ckpt_path, params["cores_per_replica"]
    )
    network.state = network.move_xmap(
        network.state, np.zeros(params["cores_per_replica"])
    )
    network.state["params"][spmodule]["w"] = jnp.float32(
        network.state["params"][spmodule]["w"]
    )

# Set up the optimizer, which is the algorithm we use to train the soft prompt
params["optimizer"] = network.config["optimizer"] = optax.chain(
    optax.scale(1 / gradient_accumulation_steps),
    mesh_transformer.util.clip_by_global_norm(float(stparams["max_grad_norm"])),
    optax.scale_by_adam(mu_dtype=jnp.float32),
    mesh_transformer.util.additive_weight_decay(stparams["weight_decay"]),
    optax.scale(-1),
    optax.scale_by_schedule(scheduler),
)

if step == 0:
    # If we're starting a soft-tuning session from scratch, we initialize the
    # soft prompt tensor by using the model to "embed" the tokens from the
    # initial soft prompt string, producing a matrix (2D array) that we use as
    # the soft prompt tensor.
    soft_embeddings = network.get_embedding_matrix(
        np.array(initial_softprompt, dtype=np.uint32)
    )
    soft_embeddings = jnp.float32(soft_embeddings)
    # We also have to initialize the optimizer state in that case
    network.state["opt_state"] = init_opt_state(network.state["params"][spmodule])
else:
    # Optimizer state is already saved otherwise
    network.state["opt_state"] = mesh_transformer.util.to_f32(tuple(npz["opt_state"]))

# Pad the embedding matrix with zeros at the bottom so that its number of
# rows is a multiple of 8 (or 4 for GPT-Neo-2.7B).
# -(rows % -n) is the number of rows needed to reach the next multiple of n.
rows = soft_embeddings.shape[0]
padding_amount = -(rows % -params["cores_per_replica"])
soft_embeddings = jnp.pad(soft_embeddings, ((0, padding_amount), (0, 0)))
# Split the matrix row-wise into 8 (or 4) submatrices (so that if the original
# matrix had R rows and C columns, then each submatrix has R/8 rows and C
# columns) and then concatenate the 8 submatrices together along a new
# leading axis into a 3-dimensional array so that it can be sharded by
# xmapped functions
soft_embeddings = soft_embeddings.reshape(
    (params["cores_per_replica"], -1, params["d_model"])
)
# Put this 3D array into the network so we can train it
network.state["params"][spmodule]["w"] = soft_embeddings


###############################################################################


def save_mtjsp(
    loss,
    last_loss,
    grad_norm,
    grad_norm_micro,
):
    """Write the current soft prompt and optimizer state to `save_file`.

    The soft prompt is read from the network's sharded
    (cores_per_replica, rows_per_core, d_model) parameter. It is flattened back
    into a 2-D matrix and trimmed to its first `soft_in_dim` rows, which drops
    the zero rows added as padding for sharding. The global `step` and the
    given metrics are stored alongside it. `starting_step` is then updated and
    persisted, so a restarted session resumes after this step.
    """
    global starting_step
    tensor = network.state["params"][spmodule]["w"]
    tensor = tensor.reshape((-1, tensor.shape[2]))
    tensor = tensor[:soft_in_dim]
    # Store the optimizer state as a 1-D object array with one entry per
    # optimizer-chain state (the loader rebuilds it with tuple()). The builtin
    # `object` is used because the `np.object` alias was removed in NumPy 1.24.
    # Filling an empty array also stops np.array() from merging equally-shaped
    # entries into a higher-dimensional array.
    opt_state = network.state["opt_state"]
    opt_state_array = np.empty(len(opt_state), dtype=object)
    for index, entry in enumerate(opt_state):
        opt_state_array[index] = entry
    with open(save_file, "wb") as f:
        np.savez_compressed(
            f,
            tensor=tensor,
            opt_state=opt_state_array,
            step=np.uint32(step),
            loss=np.float32(loss),
            last_loss=np.float32(last_loss),
            grad_norm=np.float32(grad_norm),
            grad_norm_micro=np.float32(grad_norm_micro),
        )
    starting_step = step
    save_variable("starting_step", starting_step)


def push_data(
    loss,
    last_loss,
    grad_norm,
    grad_norm_micro,
):
    """
    Updates the training plots with the given data.
    """
    IPython.display.display(
        IPython.display.Javascript(
            f"push({step}, {loss}, {grad_norm}, {scheduler(step)}); update();"
        )
    )


def train_step(use_tqdm=True):
    """Run one training step over the next `gradient_accumulation_steps` sequences.

    Reads the module-level `step`, `dataset`, `gradient_accumulation_steps`,
    `params`, `soft_in_dim`, `seq` and `network`. It uses the helpers
    `train_initial`, `train_intermediate` and `train_final`, which are defined
    elsewhere in the notebook. `network.state` is replaced by the state that
    `train_final` returns.

    Args:
        use_tqdm: if true, show a progress bar over the micro-batches.

    Returns:
        Tuple ``(loss, last_loss, grad_norm, grad_norm_micro)``. loss, last_loss
        and gnorm are averaged over the micro-batches; each value is then
        reduced with ``np.array(...).mean()``.
    """
    # Get the next batch from the dataset
    # `step` is 1-based, so step N reads rows
    # [(N - 1) * gradient_accumulation_steps, N * gradient_accumulation_steps).
    data = dataset[
        (step - 1) * gradient_accumulation_steps : step * gradient_accumulation_steps
    ]
    # Concatenate the soft prompt at the beginning
    # Each row is prefixed with the ids vocab_size .. vocab_size + soft_in_dim - 1,
    # i.e. ids just past the (padded) vocabulary. Presumably the embedding
    # module maps these ids to rows of the trainable soft-prompt matrix
    # (TODO confirm against EmbeddingCausalTransformer).
    vocab_size = params["n_vocab"] + params.get("n_vocab_padding", 0)
    header = np.tile(
        np.arange(vocab_size, vocab_size + soft_in_dim, dtype=np.uint32),
        (data.shape[0], 1),
    )
    # Keep at most seq + 1 tokens, so the inputs and targets below are each at
    # most `seq` tokens long (this is where over-long rows get truncated).
    data = np.concatenate((header, data), axis=1)[:, : seq + 1]

    # Inputs and next-token targets: the same rows shifted by one token.
    ctx = data[:, :-1]
    tgt = data[:, 1:]

    # The first row starts the accumulated gradient; each remaining row is
    # added to it by train_intermediate. Every call receives a single row with
    # a leading axis of length 1 (np.newaxis).
    grad, loss, last_loss, gnorm = train_initial(
        network.state,
        ctx[np.newaxis, 0],
        tgt[np.newaxis, 0],
    )
    r = range(1, ctx.shape[0])
    if use_tqdm:
        r = tqdm(
            r,
            initial=1,
            total=ctx.shape[0],
            desc="GRADIENT ACCUMULATION",
            leave=False,
        )
    for i in r:
        grad, _loss, _last_loss, _gnorm = train_intermediate(
            network.state,
            grad,
            ctx[np.newaxis, i],
            tgt[np.newaxis, i],
        )
        loss += _loss
        last_loss += _last_loss
        gnorm += _gnorm
    # Average the per-micro-batch metrics.
    loss /= ctx.shape[0]
    last_loss /= ctx.shape[0]
    gnorm /= ctx.shape[0]
    # Hand the accumulated gradient to train_final, which returns the new state.
    grad_norm, grad_norm_micro, network.state = train_final(
        network.state,
        grad,
        gnorm,
    )
    # Release the accumulated gradient now that the update has been applied.
    del grad

    return (
        np.array(loss).mean(),
        np.array(last_loss).mean(),
        np.array(grad_norm).mean(),
        np.array(grad_norm_micro).mean(),
    )


step += 1


# Train
first_step = step
if step > steps:
    # Every step was already trained (and saved) by a previous run. Without
    # this guard the saves below would use loss/grad_norm values that were
    # never computed in this session and fail with NameError.
    print(
        termcolor.colored(
            f"Soft-tuning is already complete ({steps} steps). There is nothing left to train.\n",
            "green",
        )
    )
else:
    # Load the dataset
    dataset = np.load(dataset_file, mmap_mode="r")
    assert dataset.ndim >= 2
    assert dataset.shape[1] >= 2
    assert dataset.shape[0] == num_sequences
    print(
        termcolor.colored(
            "Compiling trainer, this may take several minutes\n", "magenta"
        ),
        flush=True,
    )
    # Simultaneously compile the trainer and train for one step
    loss, last_loss, grad_norm, grad_norm_micro = train_step(use_tqdm=False)
    # Show the plots for learning rate, etc.
    IPython.display.clear_output()
    IPython.display.display(IPython.core.display.HTML(plot_html))
    # Update plot
    push_data(
        loss,
        last_loss,
        grad_norm,
        grad_norm_micro,
    )
    # Create a save file for step 1
    if step == 1 or step % stparams["save_every"] == 0:
        save_mtjsp(
            loss,
            last_loss,
            grad_norm,
            grad_norm_micro,
        )
    for i in tqdm(
        range(first_step, steps),
        initial=first_step,
        total=steps,
        desc="SOFT-TUNING PROGRESS",
    ):
        step += 1
        # Train for one step and update the plot
        loss, last_loss, grad_norm, grad_norm_micro = train_step(use_tqdm=True)
        push_data(
            loss,
            last_loss,
            grad_norm,
            grad_norm_micro,
        )
        # Save whenever step is divisible by save_every
        if step % stparams["save_every"] == 0:
            save_mtjsp(
                loss,
                last_loss,
                grad_norm,
                grad_norm_micro,
            )
    # Final save, recorded one past the last trained step. A resumed session
    # starts at the saved step + 1, so it lands past `steps` and trains nothing.
    step += 1
    save_mtjsp(
        loss,
        last_loss,
        grad_norm,
        grad_norm_micro,
    )

Use this cell to log in to Google Drive if required.

In [None]:
# Mount Google Drive under /content/drive/ so the exported soft prompt can be
# written to /content/drive/MyDrive/... by the cells below.
google.colab.drive.mount("/content/drive/")

Once you finish soft-tuning, you can use the following cell to convert your MTJSP file to a KoboldAI United-compatible ZIP file:

(`supported` should be the name of a model or a comma-separated list of such names that the soft prompt is intended to be used with)

In [None]:
output_file = "/content/drive/MyDrive/my_softprompt.zip"  # @param {type:"string"}
name = "Untitled"  # @param {type:"string"}
author = ""  # @param {type:"string"}
supported = "Generic 6B"  # @param {type:"string"}
description = "Baby shoes"  # @param {type:"string"}

# Without this check a missing save_file would be reported as a corrupted file.
if "save_file" not in globals():
    raise NotebookException("You have not specified the path to your MTJSP file.")

try:
    # NOTE: allow_pickle=True is needed for the pickled optimizer state, so
    # only load MTJSP files you created yourself.
    npz = np.load(save_file, allow_pickle=True)
    assert npz["step"] > 0
    assert npz["tensor"].ndim == 2 and "opt_state" in npz
    assert npz["tensor"].shape[0] < seq
    assert npz["tensor"].shape[1] == params["d_model"]
    assert all(
        p in npz
        for p in (
            "loss",
            "last_loss",
            "grad_norm",
            "grad_norm_micro",
        )
    )
    _step = np.uint32(npz["step"]).item()
except Exception:
    raise NotebookException("MTJSP file is corrupted.")

tensor = npz["tensor"]
# A 2-byte void dtype is reinterpreted in place as bfloat16 (no data conversion).
if tensor.dtype == "V2":
    tensor.dtype = jnp.bfloat16

# meta.json contents. "supported" is the comma-separated model list, with
# whitespace trimmed and empty entries (e.g. from a trailing comma) dropped.
# "author" is omitted entirely when left blank.
meta = {
    "name": name,
    "author": author,
    "supported": [model.strip() for model in supported.split(",") if model.strip()],
    "description": description,
}
if not author.strip():
    del meta["author"]

output_dir = os.path.dirname(output_file)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# tensor.npy is written LZMA-compressed; meta.json is then appended as a stored
# (uncompressed) member.
with zipfile.ZipFile(output_file, "w", compression=zipfile.ZIP_LZMA) as z:
    with z.open("tensor.npy", "w") as f:
        np.save(f, tensor, allow_pickle=False)
with zipfile.ZipFile(output_file, "a", compression=zipfile.ZIP_STORED) as z:
    with z.open("meta.json", "w") as f:
        f.write(json.dumps(meta, indent=2).encode("utf-8"))

print("OK.")

Or this one, to convert your MTJSP file to an mkultra-compatible JSON file.

In [None]:
output_file = "/content/drive/MyDrive/my_softprompt.json"  # @param {type:"string"}
soft_prompt_name = "Untitled"  # @param {type:"string"}
soft_prompt_description = "Baby shoes"  # @param {type:"string"}

# Without this check a missing save_file would be reported as a corrupted file.
if "save_file" not in globals():
    raise NotebookException("You have not specified the path to your MTJSP file.")

try:
    # NOTE: allow_pickle=True is needed for the pickled optimizer state, so
    # only load MTJSP files you created yourself.
    npz = np.load(save_file, allow_pickle=True)
    assert npz["step"] > 0
    assert npz["tensor"].ndim == 2 and "opt_state" in npz
    assert npz["tensor"].shape[0] < seq
    assert npz["tensor"].shape[1] == params["d_model"]
    assert all(
        p in npz
        for p in (
            "loss",
            "last_loss",
            "grad_norm",
            "grad_norm_micro",
        )
    )
    _step = np.uint32(npz["step"]).item()
except Exception:
    raise NotebookException("MTJSP file is corrupted.")

tensor = npz["tensor"]
if tensor.dtype == "V2":
    # Reinterpret the 2-byte void entries as bfloat16, then upcast with NumPy.
    # torch cannot build a tensor from a NumPy bfloat16 array, since that dtype
    # is not a native NumPy type.
    tensor.dtype = jnp.bfloat16
    tensor = torch.tensor(np.asarray(tensor, dtype=np.float32))
else:
    tensor = torch.tensor(tensor)

output_dir = os.path.dirname(output_file)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# mkultra JSON: metadata plus the soft prompt as a base64-encoded pickled
# torch tensor.
with open(output_file, "w") as f:
    json.dump(
        {
            "metadata": {
                "step": _step,
                "loss": float(npz["loss"].item()),
                "uuid": str(uuid.uuid4()),
                "name": soft_prompt_name,
                "description": soft_prompt_description,
                "epoch": datetime.datetime.now().timestamp(),
            },
            "tensor": base64.b64encode(
                pickle.dumps(
                    tensor,
                    protocol=4,
                ),
            ).decode("ascii"),
        },
        f,
    )

print("OK.")