In [5]:
!export HF_HOME=/vol/bitbucket/rm1623/.cache/

Traceback (most recent call last):
  File "/vol/bitbucket/rm1623/miniconda3/envs/llm/lib/python3.10/site-packages/transformers/utils/import_utils.py", line 1390, in _get_module
    return importlib.import_module("." + module_name, self.__name__)
  File "/vol/bitbucket/rm1623/miniconda3/envs/llm/lib/python3.10/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1050, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1027, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1006, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 688, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 883, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/vol/bitbucket/rm1623/miniconda3/envs/llm/lib/python3.10/site-packages/transformers/testing_utils.py", line 131, in <module>
    fro

In [2]:
import os

os.environ["HF_HOME"] = "/vol/bitbucket/rm1623/.cache/"

In [4]:
import torch
from transformers import (
    AutoTokenizer,
    AutoConfig,
    OPTPreTrainedModel,
)
from transformers.modeling_outputs import CausalLMOutputWithPast

from typing import Optional, Tuple, Union, List
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import OPTPreTrainedModel, OPTModel

import torch
import torch.nn as nn

class OPTForCausalLM(OPTPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = OPTModel(config)

        # the lm_head weight is automatically tied to the embed tokens weight
        self.bit_size = torch.log2(torch.tensor(config.vocab_size)).ceil().int().item()
        self.lm_head = nn.Sequential(
            nn.Linear(config.word_embed_proj_dim, self.bit_size, bias=False),
            # nn.Sigmoid(),  not to be used with BCEWithLogitsLoss
        )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        self.model.decoder.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model.decoder = decoder

    def get_decoder(self):
        return self.model.decoder

    def int_to_bin_tensor(self, val):
        length = self.bit_size
        bin_str = format(val, "0" + str(length) + "b")
        bin_tensor = torch.tensor([int(bit) for bit in bin_str])
        return bin_tensor

    def bin_tensor_to_int(self, bin_tensor):
        bin_str = "".join(str(int(bit.item())) for bit in bin_tensor)
        return int(bin_str, 2)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r""" """

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.lm_head(outputs[0]).contiguous()  # (bs, seq_length, bit_size)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            # convert the labels to binary - currently they are indexes of the tokenizer
            binary_tensors = [
                self.int_to_bin_tensor(val.item()) for val in shift_labels.flatten()
            ]
            # get the binary tokens in the same shape as the original tensor
            binary_tensors = torch.stack(binary_tensors).view(*shift_labels.shape, -1)
            binary_tensors = binary_tensors.to(logits.device)
            # add L1 loss
            # loss_fct = nn.L1Loss()
            loss_fct = nn.BCEWithLogitsLoss()
            loss = loss_fct(shift_logits.float(), binary_tensors.float())

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs
    ):
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(
                    past_state.index_select(0, beam_idx.to(past_state.device))
                    for past_state in layer_past
                ),
            )
        return reordered_past

In [10]:
!pip install --force-reinstall peft==0.5.0

Collecting peft==0.5.0
  Downloading peft-0.5.0-py3-none-any.whl.metadata (22 kB)
Collecting numpy>=1.17 (from peft==0.5.0)
  Downloading numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting packaging>=20.0 (from peft==0.5.0)
  Downloading packaging-24.0-py3-none-any.whl.metadata (3.2 kB)
Collecting psutil (from peft==0.5.0)
  Downloading psutil-5.9.8-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (21 kB)
Collecting pyyaml (from peft==0.5.0)
  Downloading PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.1 kB)
Collecting torch>=1.13.0 (from peft==0.5.0)
  Downloading torch-2.3.0-cp310-cp310-manylinux1_x86_64.whl.metadata (26 kB)
Collecting transformers (from peft==0.5.0)
  Downloading transformers-4.40.2-py3-none-any.whl.

In [19]:
from peft import PeftModel, PeftConfig
original_model = "facebook/opt-350m"
finetuned_model = "/vol/bitbucket/rm1623/llms/opt-350m-alpaca-bce"


# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(original_model, cache_dir="/vol/bitbucket/rm1623/.cache/")

# Load the model configuration

config = PeftConfig.from_pretrained(finetuned_model)

# Load the trained model
model = OPTForCausalLM.from_pretrained(
    original_model, 
    # config=config, 
    return_dict=True, 
    device_map="auto"
)

loading configuration file config.json from cache at /vol/bitbucket/rm1623/.cache/models--facebook--opt-350m/snapshots/08ab08cc4b72ff5593870b5d527cf4230323703c/config.json
Model config OPTConfig {
  "_name_or_path": "facebook/opt-350m",
  "_remove_final_layer_norm": false,
  "activation_dropout": 0.0,
  "activation_function": "relu",
  "architectures": [
    "OPTForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 2,
  "do_layer_norm_before": false,
  "dropout": 0.1,
  "enable_bias": true,
  "eos_token_id": 2,
  "ffn_dim": 4096,
  "hidden_size": 1024,
  "init_std": 0.02,
  "layer_norm_elementwise_affine": true,
  "layerdrop": 0.0,
  "max_position_embeddings": 2048,
  "model_type": "opt",
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "pad_token_id": 1,
  "prefix": "</s>",
  "torch_dtype": "float16",
  "transformers_version": "4.38.2",
  "use_cache": true,
  "vocab_size": 50272,
  "word_embed_proj_dim": 512
}

loading file vocab.json from cache at /vol/bitbucket/rm

In [20]:
tokenizer.pad_token = tokenizer.eos_token

In [22]:
fine_tuned_model = PeftModel.from_pretrained(model, finetuned_model)

In [24]:
fine_tuned_model

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): OPTForCausalLM(
      (model): OPTModel(
        (decoder): OPTDecoder(
          (embed_tokens): Embedding(50272, 512, padding_idx=1)
          (embed_positions): OPTLearnedPositionalEmbedding(2050, 1024)
          (project_out): Linear(in_features=1024, out_features=512, bias=False)
          (project_in): Linear(in_features=512, out_features=1024, bias=False)
          (layers): ModuleList(
            (0-23): 24 x OPTDecoderLayer(
              (self_attn): OPTAttention(
                (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
                (v_proj): Linear(
                  in_features=1024, out_features=1024, bias=True
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.1, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=1024, out_features=64, bias=False)
                  )
      

In [31]:
# Prepare the input text
input_text = "Once upon a time"
input_ids = tokenizer.encode(input_text, return_tensors="pt")
print(input_ids)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
input_ids = input_ids.to(device)

# Perform inference
with torch.no_grad():
    outputs = model.generate(input_ids, max_length=50, num_return_sequences=1)

print(outputs)
# Decode binary outputs back to token IDs
logits = outputs[0]  # assuming single sequence generation
predicted_ids = []
for i in range(logits.size(0)):
    # Extract the bit tensor
    bit_tensor = logits[i]
    print(bit_tensor)
    # Convert bit tensor to integer token ID
    # token_id = model.bin_tensor_to_int(bit_tensor)
    # predicted_ids.append(token_id)

# Convert token IDs to tokens
generated_text = tokenizer.decode(logits, skip_special_tokens=True)

print(generated_text)

tensor([[    2, 11475,  2115,    10,    86]])
tensor([[    2, 11475,  2115,    10,    86,     7,     7,     0,     7,     7,
             0,     7,     0,     7,     0,     7,     0,     7,     0,     7,
             0,     7,     0,     7,     0,     7,     0,     7,     0,     7,
             0,     7,     0,     7,     0,     7,     0,     7,     0,     7,
             0,     7,     0,     7,     0,     7,     0,     7,     0,     7]],
       device='cuda:0')
tensor(2, device='cuda:0')
tensor(11475, device='cuda:0')
tensor(2115, device='cuda:0')
tensor(10, device='cuda:0')
tensor(86, device='cuda:0')
tensor(7, device='cuda:0')
tensor(7, device='cuda:0')
tensor(0, device='cuda:0')
tensor(7, device='cuda:0')
tensor(7, device='cuda:0')
tensor(0, device='cuda:0')
tensor(7, device='cuda:0')
tensor(0, device='cuda:0')
tensor(7, device='cuda:0')
tensor(0, device='cuda:0')
tensor(7, device='cuda:0')
tensor(0, device='cuda:0')
tensor(7, device='cuda:0')
tensor(0, device='cuda:0')
tensor(7, d