In [None]:
# Install the libraries used by the rest of the notebook.
# NOTE(review): transformers is pinned but bitsandbytes is not. The code further down unpacks
# the result of quantize_blockwise() as `(tensor, (absmax, code))` -- confirm the installed
# release still returns that structure, and consider pinning a known-good version.
!pip install transformers==4.14.1
!pip install bitsandbytes

# Capture the output lines of `nvidia-smi` and join them into a single string.
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
# The word 'failed' in that output is treated as "no GPU attached".
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers==4.14.1
  Downloading transformers-4.14.1-py3-none-any.whl (3.4 MB)
[K     |████████████████████████████████| 3.4 MB 11.3 MB/s 
Collecting sacremoses
  Downloading sacremoses-0.0.53.tar.gz (880 kB)
[K     |████████████████████████████████| 880 kB 69.3 MB/s 
[?25hCollecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.11.1-py3-none-any.whl (182 kB)
[K     |████████████████████████████████| 182 kB 64.9 MB/s 
Collecting tokenizers<0.11,>=0.10.1
  Downloading tokenizers-0.10.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
[K     |████████████████████████████████| 3.3 MB 63.2 MB/s 
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
  Created wheel for sacremoses: filename=sacremoses-0.0.53-py3-none-any.whl size=895260 sha256=dfda

In [None]:
#@title Create wrappers
import transformers

import torch
import torch.nn.functional as F
from torch import nn
from torch.cuda.amp import custom_fwd, custom_bwd

from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise

from tqdm.auto import tqdm

class FrozenBNBLinear(nn.Module):
    """Linear layer whose weight is kept as frozen, blockwise-quantized 8-bit data.

    The quantized weight and its quantization state (`absmax`, `code`) are stored as
    buffers with gradients disabled; only `bias` (if given) and an optional `adapter`
    module can carry trainable parameters.

    Args:
        weight: quantized weight of shape (out_features, in_features).
        absmax: first element of the quantization state handed to `dequantize_blockwise`.
        code: second element of the quantization state handed to `dequantize_blockwise`.
        bias: optional `nn.Parameter`, or None for a layer without bias.
    """

    def __init__(self, weight, absmax, code, bias=None):
        assert isinstance(bias, nn.Parameter) or bias is None
        super().__init__()
        self.out_features, self.in_features = weight.shape
        self.register_buffer("weight", weight.requires_grad_(False))
        self.register_buffer("absmax", absmax.requires_grad_(False))
        self.register_buffer("code", code.requires_grad_(False))
        # Optional module whose output is added to the layer output; assign it after construction.
        self.adapter = None
        self.bias = bias

    def forward(self, input):
        """Return the linear projection of `input`, plus `adapter(input)` when an adapter is set."""
        output = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)
        # Compare against None instead of relying on truthiness: container modules define
        # __len__, so e.g. an empty nn.Sequential() (an identity adapter) would be skipped.
        if self.adapter is not None:
            output += self.adapter(input)
        return output

    @classmethod
    def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear":
        """Quantize the weight of an existing `nn.Linear` and wrap it, reusing its bias."""
        weights_int8, state = quantize_blockise_lowmemory(linear.weight)
        return cls(weights_int8, *state, linear.bias)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.in_features}, {self.out_features})"
 
 
class DequantizeAndLinear(torch.autograd.Function):
    """Autograd function that applies a linear layer with blockwise-quantized weights.

    The weights are dequantized on the fly in `forward` and again in `backward`; only the
    quantized tensor and its quantization state are saved between the two passes.
    Gradients are returned for `input` and `bias` only.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,
                absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):
        # Remember what backward needs: the tensors, and whether a bias gradient is expected.
        ctx._has_bias = bias is not None
        ctx.save_for_backward(input, weights_quantized, absmax, code)
        dense_weight = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
        return F.linear(input, dense_weight, bias)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output: torch.Tensor):
        # The quantized weight and its state (arguments 1..3) must never require gradients.
        assert not any(ctx.needs_input_grad[1:4])
        input, weights_quantized, absmax, code = ctx.saved_tensors
        dense_weight = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
        # grad_output: [*batch, out_features]
        if ctx._has_bias:
            # Collapse all leading (batch) dimensions and sum over them.
            grad_bias = grad_output.flatten(0, -2).sum(dim=0)
        else:
            grad_bias = None
        grad_input = grad_output @ dense_weight
        return grad_input, None, None, None, grad_bias
 
 
class FrozenBNBEmbedding(nn.Module):
    """Embedding table whose weight is kept as frozen, blockwise-quantized 8-bit data.

    The quantized weight and its quantization state (`absmax`, `code`) are stored as
    buffers with gradients disabled; an optional `adapter` module can be attached after
    construction and its output is added to the looked-up embeddings.

    Args:
        weight: quantized table of shape (num_embeddings, embedding_dim).
        absmax: first element of the quantization state handed to `dequantize_blockwise`.
        code: second element of the quantization state handed to `dequantize_blockwise`.
    """

    def __init__(self, weight, absmax, code):
        super().__init__()
        self.num_embeddings, self.embedding_dim = weight.shape
        self.register_buffer("weight", weight.requires_grad_(False))
        self.register_buffer("absmax", absmax.requires_grad_(False))
        self.register_buffer("code", code.requires_grad_(False))
        # Optional module whose output is added to the embeddings; assign it after construction.
        self.adapter = None

    def forward(self, input, **kwargs):
        """Look up `input` indices; extra keyword arguments are forwarded to `F.embedding`."""
        with torch.no_grad():
            # note: both quantized weights and input indices are *not* differentiable
            weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)
            output = F.embedding(input, weight_deq, **kwargs)
        # Compare against None instead of relying on truthiness: container modules define
        # __len__, so an adapter built on an empty container would be skipped.
        if self.adapter is not None:
            output += self.adapter(input)
        return output

    @classmethod
    def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding":
        """Quantize the weight of an existing `nn.Embedding` and wrap it."""
        weights_int8, state = quantize_blockise_lowmemory(embedding.weight)
        return cls(weights_int8, *state)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})"
 
 
def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):
    """Quantize `matrix` with `quantize_blockwise`, one chunk at a time.

    The tensor is flattened and fed to `quantize_blockwise` in slices of `chunk_size`
    elements, so only one cloned chunk is handed to the quantizer at once. The `code`
    returned for the first chunk is passed back in for all later chunks.

    Args:
        matrix: tensor to quantize; may be non-contiguous. An empty tensor is not
            supported (there would be nothing to concatenate).
        chunk_size: elements per chunk; must be a positive multiple of 4096.
            NOTE(review): 4096 is presumably the quantizer's block size, so that chunk
            boundaries coincide with block boundaries -- confirm against bitsandbytes.

    Returns:
        `(matrix_i8, (absmax, code))`: the quantized values reshaped like `matrix`, the
        concatenated per-chunk `absmax` tensors, and the shared `code`.
    """
    assert chunk_size > 0 and chunk_size % 4096 == 0
    code = None
    chunks = []
    absmaxes = []
    # reshape(-1) also accepts non-contiguous tensors (e.g. a transposed weight), where
    # view(-1) raises; for contiguous input it is still a view without a copy.
    flat_tensor = matrix.reshape(-1)
    for start in range(0, flat_tensor.numel(), chunk_size):
        input_chunk = flat_tensor[start:start + chunk_size].clone()
        quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code)
        chunks.append(quantized_chunk)
        absmaxes.append(absmax_chunk)

    matrix_i8 = torch.cat(chunks).reshape_as(matrix)
    absmax = torch.cat(absmaxes)
    return matrix_i8, (absmax, code)
 
 
def convert_to_int8(model, verbose=False):
    """Replace Linear and Embedding children of `model` with frozen 8-bit modules, in place.

    Every `nn.Linear` / `nn.Embedding` found as a child of any sub-module is swapped for a
    `FrozenBNBLinear` / `FrozenBNBEmbedding` of the same shape. The replacements are
    zero-filled placeholders (weights, `absmax`, `code`); real values are expected to be
    loaded into them afterwards, e.g. from a checkpoint. A Linear layer's existing bias
    parameter is reused. Adapters are left unset (None).

    Args:
        model: module to convert; modified in place. `model` itself is not replaced,
            only modules below it.
        verbose: when True, print the name and repr of every replaced module. Off by
            default because large models contain hundreds of such layers.
    """
    for module in list(model.modules()):
        # Snapshot the children: they are replaced via setattr while we walk them.
        for name, child in list(module.named_children()):
            if isinstance(child, nn.Linear):
                replacement = FrozenBNBLinear(
                    weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),
                    # one absmax entry per started block of 4096 weights
                    absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
                    code=torch.zeros(256),
                    bias=child.bias,
                )
            elif isinstance(child, nn.Embedding):
                replacement = FrozenBNBEmbedding(
                    weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),
                    absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
                    code=torch.zeros(256),
                )
            else:
                continue

            if verbose:
                print(name, child)
            setattr(module, name, replacement)


class GPTJBlock(transformers.models.gptj.modeling_gptj.GPTJBlock):
    """GPT-J block whose attention and MLP sub-modules are converted to 8-bit on construction."""

    def __init__(self, config):
        super().__init__(config)

        # Build the stock block first, then swap the layers inside both sub-modules.
        for submodule in (self.attn, self.mlp):
            convert_to_int8(submodule)


class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel):
    """GPT-J base model that runs `convert_to_int8` on itself right after the stock constructor."""
    def __init__(self, config):
        super().__init__(config)
        convert_to_int8(self)
        

class GPTJForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):
    """GPT-J causal LM that runs `convert_to_int8` on itself right after the stock constructor."""
    def __init__(self, config):
        super().__init__(config)
        convert_to_int8(self)


# Rebind the library's module-level GPTJBlock name to the 8-bit subclass, so that code
# looking the class up on `modeling_gptj` from now on gets the patched version.
# Only GPTJBlock is patched here; GPTJModel / GPTJForCausalLM are used via the local subclasses.
transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock  # monkey-patch GPT-J

In [None]:
#@title Load model
# NOTE(review): the commented-out block below looks like an alternative loading path
# (another checkpoint, loaded through AutoModelForCausalLM). It is disabled; remove it
# once it is confirmed to be obsolete.
# import torch, transformers

# tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
# model = transformers.AutoModelForCausalLM.from_pretrained("OpenDungeon/gpt-j-8bit-ffbgem",
#                                                           device_map="auto",
#                                                           load_in_8bit=True,
#                                                           low_cpu_mem_usage=True).cuda()
# model.eval()

# NOTE(review): `config` is not referenced by any other visible cell -- confirm it is needed.
config = transformers.GPTJConfig.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
# Load through the local GPTJForCausalLM subclass, whose constructor swaps Linear and
# Embedding layers for 8-bit placeholders before the checkpoint is loaded into them.
model = GPTJForCausalLM.from_pretrained("hivemind/gpt-j-6B-8bit", low_cpu_mem_usage=True)

# Use the GPU when one is available. As the last expression of the cell, the value
# returned by `model.to(...)` is displayed as the cell output.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

Downloading:   0%|          | 0.00/930 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/619 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/779k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.31M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/3.94k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/357 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.00k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/5.75G [00:00<?, ?B/s]

k_proj Linear(in_features=4096, out_features=4096, bias=False)
v_proj Linear(in_features=4096, out_features=4096, bias=False)
q_proj Linear(in_features=4096, out_features=4096, bias=False)
out_proj Linear(in_features=4096, out_features=4096, bias=False)
fc_in Linear(in_features=4096, out_features=16384, bias=True)
fc_out Linear(in_features=16384, out_features=4096, bias=True)
k_proj Linear(in_features=4096, out_features=4096, bias=False)
v_proj Linear(in_features=4096, out_features=4096, bias=False)
q_proj Linear(in_features=4096, out_features=4096, bias=False)
out_proj Linear(in_features=4096, out_features=4096, bias=False)
fc_in Linear(in_features=4096, out_features=16384, bias=True)
fc_out Linear(in_features=16384, out_features=4096, bias=True)
k_proj Linear(in_features=4096, out_features=4096, bias=False)
v_proj Linear(in_features=4096, out_features=4096, bias=False)
q_proj Linear(in_features=4096, out_features=4096, bias=False)
out_proj Linear(in_features=4096, out_features=4096, 

GPTJForCausalLM(
  (transformer): GPTJModel(
    (wte): FrozenBNBEmbedding(50400, 4096)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0): GPTJBlock(
        (ln_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attn): GPTJAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (k_proj): FrozenBNBLinear(4096, 4096)
          (v_proj): FrozenBNBLinear(4096, 4096)
          (q_proj): FrozenBNBLinear(4096, 4096)
          (out_proj): FrozenBNBLinear(4096, 4096)
        )
        (mlp): GPTJMLP(
          (fc_in): FrozenBNBLinear(4096, 16384)
          (fc_out): FrozenBNBLinear(16384, 4096)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
      (1): GPTJBlock(
        (ln_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attn): GPTJAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0

In [None]:
#@title Helper functions

import time

def PrintContinuation(prompt, local_model, single_hook=None, batch=1, limit_tokens = 50):
    """Sample `batch` continuations of `prompt` token by token.

    Uses the notebook-level `tokenizer`. The same prompt is replicated `batch` times and
    each stream is sampled independently. Before sampling, the logits of the 10
    highest-scoring tokens of every stream are raised by 10.

    Args:
        prompt: text to continue; only its last 1500 tokens are fed to the model.
        local_model: model providing `.device` and `.forward(..., use_cache, past_key_values)`.
        single_hook: optional callable receiving the decoded text of each new token of
            the first stream; returning a falsy value stops generation.
        batch: number of parallel continuations.
        limit_tokens: generation stops after the step with index `limit_tokens`,
            i.e. after `limit_tokens + 1` tokens per stream.

    Returns:
        `(output, batch_time)`: the list of `batch` generated strings, and the average
        wall-clock seconds per step measured from step 5 onwards. `batch_time` is 0 when
        generation stopped before step 5, or ran out of loop iterations without stopping.
    """
    # NOTE(review): these two names are published as notebook globals (presumably for
    # interactive inspection); kept so that anything reading them keeps working.
    global input_dict, outputs

    warmup_steps = 5   # steps excluded from the timing measurement
    start_time = None  # stays None until the warm-up steps are over

    def seconds_per_step(step):
        # Generation may stop during warm-up (hook returned False, or limit_tokens < 5);
        # nothing was timed in that case.
        if start_time is None:
            return 0
        return (time.perf_counter() - start_time) / (step - (warmup_steps - 1))

    past_key_values = None  # cache returned by the model, fed back in on the next step
    input_dict = tokenizer([prompt], return_tensors='pt')
    for k, v in input_dict.items():
        v = v[0, -1500:] # limit attention to suffix
        v = torch.stack([v] * batch)
        v = v.to(local_model.device)
        input_dict[k] = v

    output = [""] * batch
    batch_time = 0

    with torch.inference_mode():
        # The "+ 20" is only head-room: the loop normally exits through the break below.
        for i in range(limit_tokens + 20):
            if i == warmup_steps:
                start_time = time.perf_counter()

            outputs = local_model.forward(**input_dict, use_cache=True, past_key_values=past_key_values)
            last_logits = outputs.logits[:, -1]

            # Boost the 10 best candidates of each stream before sampling.
            for j in range(batch):
                last_logits[j, last_logits[j].topk(k=10).indices] += 10

            past_key_values = outputs.past_key_values
            token_ix = torch.multinomial(last_logits.softmax(-1), 1)
            output = [stream + tokenizer.decode(ix) for stream, ix in zip(output, token_ix)]

            hook_stopped = single_hook is not None and not single_hook(tokenizer.decode(token_ix[0]))
            if hook_stopped or i == limit_tokens:
                batch_time = seconds_per_step(i)
                break

            # With the cache in place only the newly sampled token is fed next.
            input_dict = dict(input_ids=token_ix)
    return output, batch_time


def CutSentence(text):
    """Trim `text` so that it ends at its last sentence break ('.', '!', '?' or newline).

    The text is returned unchanged when it contains no such character, or when the only
    one found is its very first character.
    """
    last_break = max((text.rfind(mark) for mark in ".!?\n"), default=-1)
    if last_break <= 0:
        return text
    return text[:last_break + 1]

In [None]:
#@title GM interface
# Speaker tags that start every line of the dialogue transcript.
MASTER_PREFIX = "Dungeon Master: "
PLAYER_PREFIX = "John: "

# Opening transcript: a one-line framing sentence followed by alternating
# Dungeon Master / player turns. The whole text is used as the model prompt and grows
# as the game goes on.
story = f"""It is a fantasy role-play game.

{MASTER_PREFIX}You are John, a wizard living in the kingdom of Larion. You have a staff and a spellbook. You finish your long journey and finally arrive at the ruin you've been looking for. You have come here searching for a mystical spellbook of great power called the book of essence. You look around and see the ancient ruins of an elf tower. The ruins have not been touched for decades. You look at the tower, and you can see a set of stone stairs that seem to lead somewhere deep inside the tower.
{PLAYER_PREFIX}I walk upstairs
{MASTER_PREFIX}You climb up the stairs in the ruined tower. There is a door on the second floor of the tower, the door seems to be made of enchanted wood.
{PLAYER_PREFIX}I ask the door if I may to come in
{MASTER_PREFIX}The door sighs open and you walk into the room."""

def ReprStory(text):
    """Convert plain story text into an HTML fragment for display in an HTML widget.

    The characters `&`, `<` and `>` are escaped so that story or model text such as
    "<|endoftext|>" is shown literally instead of being parsed as markup; newlines
    become `<br>` tags.
    """
    import html  # local import so the function does not depend on another cell's imports

    return html.escape(text, quote=False).replace('\n', '<br>')

# `widgets` is otherwise imported only by the later "Co-writer interface" cell, so on a
# fresh kernel (Restart & Run All) this cell would stop with a NameError. Import it here.
import ipywidgets as widgets

# Read-only label showing the transcript so far.
top_label = widgets.HTML(value=ReprStory(story))
# Free-text field for the player's next move.
user_option = widgets.Textarea(
    placeholder='What do you do?',
    disabled=False
)

# Button that submits the move.
submiter = widgets.Button(
    description="Continue",
    button_style='success'
)

# Input widgets shown below the transcript label.
user_inputs = [user_option, submiter]

def DoCuntinue(_):
    """Button callback: add the player's move to the story and generate the master's reply.

    Appends the text area content as a player line, opens a Dungeon Master line, streams
    the generated tokens into the global `story` (refreshing the label on every token),
    then trims the story to its last complete sentence and clears the text area.
    The button is hidden while generating and is always shown again afterwards, even if
    generation raises.

    Args:
        _: the clicked button, passed by `on_click`; unused.
    """
    global story
    story += "\n" + PLAYER_PREFIX + user_option.value + "\n" + MASTER_PREFIX
    top_label.value = ReprStory(story)

    def uprint(text):
        """Append one generated token; return False to stop generation."""
        global story
        story += text
        top_label.value = ReprStory(story)
        # Stop once the token contains the player's name (the prefix without ": "),
        # i.e. the model has started writing the player's turn.
        return not (PLAYER_PREFIX[:-2] in text)

    submiter.layout.visibility = 'hidden'
    try:
        PrintContinuation(story, model, uprint, 1, 50)

        story = CutSentence(story)
        top_label.value = ReprStory(story)
        user_option.value = ""
    finally:
        # Without this, one failed generation would leave the interface without a button.
        submiter.layout.visibility = 'visible'

# Run DoCuntinue whenever the button is clicked.
submiter.on_click(DoCuntinue)

# Render the interface: transcript label on top, then the text area and the button.
display(widgets.VBox(
    [top_label] + user_inputs
))


VBox(children=(HTML(value="It is a fantasy role-play game.<br><br>Dungeon Master: You are John, a wizard livin…

In [None]:
#@title Co-writer interface

import ipywidgets as widgets

# Menu entries with a special meaning: take the text typed by the user / undo the last step.
PROMPT_CUSTOM = 'I have a better idea'
PROMPT_ROLLBACK = 'Rollback the last move'
# Initial choices offered before any text has been generated.
prompts = ['Once upon a time',
'This is fantasy story about a wizard living in the kingdom of Larion. He has a staff and a spellbook. He finishes his long journey and finally arrive at the ruin he have been looking for.',
PROMPT_ROLLBACK,
PROMPT_CUSTOM]

MASTER_PREFIX = "Dungeon Master: "
PLAYER_PREFIX = "John: "
# The constant was originally misspelled; the old name stays available as an alias.
PLAYER_PREDIX = PLAYER_PREFIX

# Story fragments accepted so far, in order; joined with newlines to form the prompt.
story_log = []


# NOTE: this cell reuses the global names of the GM interface cell (top_label,
# user_option, submiter, user_inputs); running it rebinds them to the widgets below.

# Label showing the story so far; starts with a short instruction.
top_label = widgets.HTML(value="Choose adventure")
# Radio list of candidate continuations plus the rollback / custom entries.
selector = widgets.RadioButtons(
    options=prompts,
    layout={'width': 'max-content'}
)
# Free-text field, used when the custom entry is selected.
user_option = widgets.Textarea(
    placeholder='Your own idea',
    disabled=False
)
# Button that submits the current choice.
submiter = widgets.Button(
    description="Continue",
    button_style='success'
)

# Input widgets shown below the label; hidden as a group while text is being generated.
user_inputs = [selector, user_option, submiter]

def DoCuntinue(_):
    """Button callback of the co-writer: apply the choice, then offer five continuations.

    Note that this definition replaces the `DoCuntinue` of the GM interface cell.

    The selected entry is applied to the global `story_log` (rollback removes the last
    fragment, the custom entry appends the text area content, anything else appends the
    selected text). The plain-text story is then sent to the model, and the generated
    continuations -- each trimmed to its last complete sentence -- become the new radio
    options, followed by the rollback and custom entries. The input widgets are hidden
    while generating and are always shown again afterwards, even if generation raises.
    When the story is empty after applying the choice, nothing is generated and the
    initial menu is shown again.

    Args:
        _: the clicked button, passed by `on_click`; unused.
    """
    global story_log
    import html  # local import so the callback does not depend on another cell's imports

    def as_html(text):
        # Escape markup characters so generated text is shown literally, then keep line breaks.
        return html.escape(text, quote=False).replace('\n', '<br>')

    if selector.value == PROMPT_ROLLBACK:
        story_log = story_log[:-1]
    elif selector.value == PROMPT_CUSTOM:
        story_log.append(user_option.value)
    else:
        story_log.append(selector.value)

    # Plain text goes to the model; only the label gets the HTML version.
    story_text = '\n'.join(story_log)
    if not story_text:
        # Nothing to continue from (rolled back to the start, or an empty custom entry).
        story_log = []
        top_label.value = "Choose adventure"
        selector.options = prompts
        user_option.value = ""
        return

    pre_story = as_html(story_text)
    top_label.value = pre_story + '<br>'

    def uprint(text):
        """Stream tokens of the first continuation into the label; never stops generation."""
        top_label.value += as_html(text)
        return True

    for widget in user_inputs:
        widget.layout.visibility = 'hidden'
    try:
        options, _seconds = PrintContinuation(story_text, model, uprint, 5, 50)
        options = [CutSentence(o) for o in options]

        # Drop the streamed preview again; the continuations are offered as choices instead.
        top_label.value = pre_story
        options.append(PROMPT_ROLLBACK)
        options.append(PROMPT_CUSTOM)
        selector.options = options
        user_option.value = ""
    finally:
        # Without this, one failed generation would leave the interface without inputs.
        for widget in user_inputs:
            widget.layout.visibility = 'visible'

# Run DoCuntinue whenever the button is clicked.
submiter.on_click(DoCuntinue)

# Render the interface: story label on top, then the selector, the text area and the button.
display(widgets.VBox(
    [top_label] + user_inputs
))

VBox(children=(HTML(value='Choose adventure'), RadioButtons(layout=Layout(width='max-content'), options=('Once…