In [1]:
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, PreTrainedModel
from transformer_lens import HookedTransformer

In [2]:
from typing import List

from torch import Tensor


class Multiplexer(nn.Module):
    def __init__(self, n_streams, d_model):
        super().__init__()
        self.n_streams = n_streams
        self.phi = [nn.Linear(d_model, d_model) for _ in range(n_streams)]

    def forward(self, x: List[Tensor]):
        assert len(x) == self.n_streams
        x = [phi(xi) for xi, phi in zip(x, self.phi)]
        x = torch.stack(x).mean(dim=0)
        return x

class Demultiplexer(nn.Module):
    def __init__(self, n_streams, d_model):
        super().__init__()
        self.n_streams = n_streams
        self.phi = [nn.Linear(d_model, d_model) for _ in range(n_streams)]

    def forward(self, x):
        x = [phi(x) for phi in self.phi]
        return x

class MultiplexedModel(nn.Module):
    def __init__(self, n_streams, model: HookedTransformer):
        super().__init__()
        self.n_streams = n_streams
        self.model = model
        self.multiplexer = Multiplexer(n_streams, model.cfg.d_model)
        self.demultiplexer = Demultiplexer(n_streams, model.cfg.d_model)
    
    def forward(self, input_ids, attention_mask, start_layer=0, **kwargs):
        assert len(input_ids) == self.n_streams
        assert len(attention_mask) == self.n_streams
    
        resids = [self.model.forward(input_ids[i], attention_mask[i], stop_at_layer=start_layer, **kwargs) for i in range(self.n_streams)]

        resid_comb = self.multiplexer(resids)
        attention_mask_comb = torch.stack(attention_mask).prod(dim=0)

        output_comb = self.model.forward(resid_comb, attention_mask_comb, start_layer=start_layer, **kwargs)
        
        outputs = self.demultiplexer(output_comb)

        return outputs

In [3]:
# Check if a GPU is available and set the device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Specify the model ID
model_id = "meta-llama/Meta-Llama-3-8B"

# Load the tokenizer from the Hugging Face library
tokenizer = AutoTokenizer.from_pretrained(model_id)

transformer_model = HookedTransformer.from_pretrained_no_processing(model_id, torch_dtype=torch.bfloat16)
transformer_model.requires_grad_(False)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model meta-llama/Meta-Llama-3-8B into HookedTransformer


HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-31): 32 x TransformerBlock(
      (ln1): RMSNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): RMSNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): GroupedQueryAttention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): GatedMLP(
        (hook_pre): HookPoint()
        (hook_pre_linear): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out)

In [4]:
model = MultiplexedModel(2, transformer_model).to(device)

In [5]:
from datasets import load_dataset
from transformers import Trainer, TrainingArguments
from random import shuffle

dataset = load_dataset('roneneldan/TinyStories', split='train[:10]')

tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'

def shuffle_(x):
    shuffle(x)
    return x

def reshape_list(l, n_streams=2):
    ll = []
    for _ in range(n_streams):
        ll += [shuffle_([l[j+i] for i in range(n_streams)]) for j in range(0, len(l), n_streams)]
    return ll

def tokenization(batch, n_streams=2):
    batch = tokenizer(batch['text'], padding="max_length", truncation=True, max_length=512, return_tensors="pt")
    for key in batch:
        batch[key] = reshape_list(batch[key], n_streams=n_streams)
    return batch

dataset = dataset.map(tokenization, batched=True)
dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])


Repo card metadata block was not found. Setting CardData to empty.


Map:   0%|          | 0/10 [00:00<?, ? examples/s]

In [6]:
def collate_fn(examples):
    return {
        'input_ids': [torch.stack([example['input_ids'][i] for example in examples]) for i in range(2)],
        'attention_mask': [torch.stack([example['attention_mask'][i] for example in examples]) for i in range(2)]
    }

trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir="output",
        overwrite_output_dir=True,
        per_device_train_batch_size=1,
        num_train_epochs=1,
        report_to="none"
    ),
    train_dataset=dataset,
    data_collator=collate_fn
)

trainer.train()

RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in _worker
    output = module(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_346268/540649779.py", line 42, in forward
    resid_comb = self.multiplexer(resids)
                 ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_346268/540649779.py", line 14, in forward
    x = [phi(xi) for xi, phi in zip(x, self.phi)]
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_346268/540649779.py", line 14, in <listcomp>
    x = [phi(xi) for xi, phi in zip(x, self.phi)]
         ^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)
