In [None]:
from transformers import AutoProcessor, AutoConfig, AutoModelForCausalLM
import torch
from nnsight import LanguageModel
import nnsight

In [None]:
# Model identifier shared by the processor and config loads below and by the
# model wrapper constructed later in the notebook.
STAGE1_MODEL = "m-a-p/YuE-s1-7B-anneal-en-cot"


# `cfg` is passed to the model wrapper as `config=`; `processor` is only
# referenced by a commented-out `tokenizer=` argument in the visible cells.
processor = AutoProcessor.from_pretrained(STAGE1_MODEL)
cfg = AutoConfig.from_pretrained(STAGE1_MODEL)

In [None]:
class YuE(LanguageModel):
    """nnsight ``LanguageModel`` subclass that loads its checkpoint in bfloat16.

    Both loader hooks (``_load_meta`` and ``_load``) are overridden so that the
    model is always built with ``AutoModelForCausalLM.from_pretrained`` and
    ``torch_dtype=torch.bfloat16``.

    NOTE(review): ``_load_meta`` loads the full weights, exactly like ``_load``,
    so the checkpoint is read once per hook call. Confirm that this is intended
    rather than a weightless skeleton for the "meta" hook.
    """

    def _yue_from_pretrained(self, repo_id: str, tokenizer_kwargs=None, **kwargs):
        """Shared loading path used by both loader hooks.

        Args:
            repo_id: Model identifier forwarded to the config, tokenizer and
                model loaders; also stored on ``self.repo_id``.
            tokenizer_kwargs: Optional mapping expanded into
                ``self._load_tokenizer``. ``None`` means "no extra options".
            **kwargs: Forwarded unchanged to ``self._load_config``.

        Returns:
            The model returned by ``AutoModelForCausalLM.from_pretrained``,
            loaded with ``torch_dtype=torch.bfloat16``.
        """
        self.repo_id = repo_id

        self._load_config(repo_id, **kwargs)

        # A fresh dict per call replaces the shared mutable ``{}`` default.
        self._load_tokenizer(repo_id, **(tokenizer_kwargs or {}))
        return AutoModelForCausalLM.from_pretrained(
            repo_id,
            torch_dtype=torch.bfloat16,
            # attn_implementation="flash_attention_2",
        )

    def _load_meta(self, repo_id: str, tokenizer_kwargs=None, **kwargs):
        """Load config, tokenizer and the bf16 model without moving the model.

        Args:
            repo_id: Model identifier to load.
            tokenizer_kwargs: Optional mapping of tokenizer loading options.
            **kwargs: Forwarded to ``self._load_config``.

        Returns:
            The loaded model, left on whatever device ``from_pretrained`` put it.
        """
        return self._yue_from_pretrained(repo_id, tokenizer_kwargs, **kwargs)

    def _load(
        self,
        repo_id: str,
        tokenizer_kwargs=None,
        **kwargs,
    ):
        """Load config, tokenizer and the bf16 model, then move it to a device.

        Args:
            repo_id: Model identifier to load.
            tokenizer_kwargs: Optional mapping of tokenizer loading options.
            **kwargs: Forwarded to ``self._load_config``. If it contains
                ``device_map``, that value is passed to ``model.to(...)``.

        Returns:
            The loaded model, moved to ``kwargs["device_map"]`` when that key is
            present and not ``None``; otherwise the model as loaded.
        """
        model = self._yue_from_pretrained(repo_id, tokenizer_kwargs, **kwargs)

        # Previously a missing ``device_map`` raised KeyError; now the model is
        # simply returned where ``from_pretrained`` left it.
        device = kwargs.get("device_map")
        if device is None:
            return model
        return model.to(device)

In [None]:
# Build the nnsight wrapper around the stage-1 checkpoint. `config=cfg` reuses
# the config loaded earlier; `device_map="cuda"` is the device the loaded model
# is moved to by `YuE._load`.
nn_model = YuE(
    STAGE1_MODEL,
    config=cfg,
    # tokenizer=processor.tokenizer,
    device_map="cuda",
)

# Smoke test: a short generation whose tracing body is empty (`...`), so no
# interventions are registered.
with nn_model.generate("Hello world", max_new_tokens=10):
    ...

In [None]:
# Bare expression: the notebook shows the wrapper's repr as the cell output.
nn_model

In [None]:
# Number of new tokens requested from `generate`; the inner loop below also
# runs this many times.
tokens = 20
prompt = "Recreate the essence of a classic video game theme with chiptune sounds and nostalgic melodies."

# One generation run per layer index in the list (currently only index 2).
for n in [2]:
    ablate_layer = nn_model.model.layers[n]
    # `[prompt] * 1` is a one-element batch containing the prompt.
    with nn_model.generate([prompt] * 1, max_new_tokens=tokens):
        # `outputs` is rebound on every pass of the outer loop, so after the
        # loop it refers to the list from the last layer index only.
        outputs = nnsight.list().save()

        for _ in range(tokens):
            # Ablation: overwrite element 0 of the layer's output in place with
            # the layer's input, so the layer passes its input through.
            # NOTE(review): assumes `output[0]` and `input` have matching
            # shapes — confirm against the nnsight version in use.
            ablate_layer.output[0][:] = ablate_layer.input[:]
            # NOTE(review): `generator.output` is appended on every iteration;
            # confirm it yields a per-step value rather than the final result.
            outputs.append(nn_model.generator.output)
            # Presumably advances the trace to the next generation step —
            # TODO confirm. It is also called after the last iteration.
            nn_model.next()

        print("Step")

In [None]:
# Bare expression: the notebook shows the wrapper's `device` attribute.
nn_model.device

In [None]:
# Teardown: move the model to CPU, drop this notebook's `nn_model` name, then
# ask PyTorch to empty its CUDA cache.
# NOTE(review): other names still bound in the kernel (e.g. `ablate_layer`,
# `outputs`) may keep objects alive — confirm they do not hold GPU tensors.
# Re-running this cell without rebuilding `nn_model` raises NameError.
nn_model.cpu()
del nn_model
torch.cuda.empty_cache()