
[Bug Report] Mixtral generates nonsense #570

Closed
1 task done
joelburget opened this issue May 4, 2024 · 42 comments
@joelburget
Contributor

Describe the bug
[screenshot]

I followed the instructions in docs/source/content/special_cases.md as best I could tell (ran the model both in full precision and with HookedTransformer.from_pretrained_no_processing), yet my generations were nonsensical.

Code example

from transformer_lens import HookedTransformer

# Load Mixtral without weight processing, sharded across 4 GPUs.
model = HookedTransformer.from_pretrained_no_processing(
    "mistralai/Mixtral-8x7B-v0.1",
    n_devices=4,
)

# Test that the model actually works (display() is the IPython/notebook helper).
for i in range(5):
    display(
        model.generate(
            "Once upon a time",
            verbose=False,
            max_new_tokens=50,
        )
    )

System Info
Describe the characteristic of your environment:

  • Describe how transformer_lens was installed: pip install sae-lens
  • What OS are you using? Ubuntu 22.04 (runpod/pytorch:2.0.1-py3.10-cuda11.8.0-devel-ubuntu22.04)
  • Python version: 3.10

Additional context
Running on a 4x A100 SXM system on Runpod.

Checklist

  • I have checked that there is no similar issue in the repo (required)
@joelburget joelburget changed the title Mixtral generates nonsense [Bug Report] Mixtral generates nonsense May 4, 2024
@jbloomAus
Collaborator

Generate is not a good way to check a model is running properly. Can you run the following and share the results?

a_large_chunk_of_text = "Generate is not a good way to check a model is running properly. Can you run the following and share the results?"
loss = model(a_large_chunk_of_text, return_type="loss")
print(loss.item())

@joelburget
Contributor Author

Looks like the loss is ~6.82

[screenshot]

@jbloomAus
Copy link
Collaborator

This could be related to issue #591. If I were you, I'd pip install opt_einsum and then see if you get the same issue. opt_einsum should be picked up automatically by PyTorch if it's installed, and the other issue was solved by using it. Let me know if you still get garbage afterward.
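(For anyone reproducing this: a quick way to confirm that PyTorch is actually picking up opt_einsum is a check like the sketch below; torch.backends.opt_einsum is the standard PyTorch hook for it, everything else is illustrative.)

import torch

# PyTorch routes multi-operand einsums through opt_einsum when it is
# installed; both flags below should be True after `pip install opt_einsum`.
print(torch.backends.opt_einsum.is_available())
print(torch.backends.opt_einsum.enabled)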

@joelburget
Contributor Author

Still seems quite bad, even after verifying that I have opt_einsum installed. For good measure I checked a few more times and got 5.5, 6.7, and 8.7. So not quite as bad as 12.5, but still problematic.
[screenshot]

@jbloomAus
Collaborator

jbloomAus commented May 21, 2024 via email

@joelburget
Contributor Author

I'm not using any special tokens at all. The a_large_chunk_of_text code block is what you pasted verbatim.

@jbloomAus
Collaborator

jbloomAus commented May 21, 2024 via email

@joelburget
Contributor Author

Unfortunately I don't think that's the issue. Here are the instructions from the Mixtral HF repo:

[screenshot]

Running those in a Colab we can see that the only special token they prepend is BOS (1):

[screenshot]

TransformerLens of course does this by default, but I double-checked: (1) adding prepend_bos doesn't change the behavior, and (2) the tokenizer output is the same as with the official instructions.

[screenshot]
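(A minimal sketch of the tokenization check described above, assuming model is the HookedTransformer loaded earlier in this thread; the prompt is illustrative.)

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
prompt = "Once upon a time"

# HuggingFace tokenization: prepends only BOS (token id 1) by default.
hf_ids = tokenizer(prompt, return_tensors="pt").input_ids

# TransformerLens tokenization: prepend_bos=True is already the default.
tl_ids = model.to_tokens(prompt, prepend_bos=True)

print(hf_ids)
print(tl_ids)
assert hf_ids.tolist() == tl_ids.cpu().tolist()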

@joelburget
Contributor Author

I also checked following the instructions on the HF repo without using TransformerLens. The generation looks reasonable. I'm not sure exactly how to interpret the loss or if it matters.

[screenshot]
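(For reference, a hedged sketch of the HuggingFace-only loss check; the device placement and dtype choices here are illustrative, not necessarily what was run above.)

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "mistralai/Mixtral-8x7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(name)
hf_model = AutoModelForCausalLM.from_pretrained(
    name, torch_dtype=torch.bfloat16, device_map="auto"
)

text = "Generate is not a good way to check a model is running properly."
inputs = tokenizer(text, return_tensors="pt").to(hf_model.device)

# Passing labels=input_ids makes the forward pass return the mean next-token
# cross-entropy loss, comparable to TransformerLens's return_type="loss".
with torch.no_grad():
    out = hf_model(**inputs, labels=inputs["input_ids"])
print(out.loss.item())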

@bryce13950 bryce13950 self-assigned this Jun 4, 2024
@bryce13950
Collaborator

I wonder if this is a memory issue. The only difference between the TransformerLens config and the HuggingFace one is n_ctx, which is capped at 2048 for memory reasons, whereas it is 32768 in HuggingFace. The generation isn't entirely nonsense; it's mostly French, with occasional Spanish or English, and what it generates does for the most part make sense as a continuation of the prompt in those languages. I could imagine that a lack of memory could cause something like languages getting mixed up, and I think it may be a good place to start investigating.
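(The context-length difference is easy to confirm directly from the two configs; a small sketch, assuming model is the HookedTransformer loaded earlier.)

from transformers import AutoConfig

hf_config = AutoConfig.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
print(hf_config.max_position_embeddings)  # 32768 in the HuggingFace config
print(model.cfg.n_ctx)                    # capped at 2048 in TransformerLens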

@neelnanda-io
Collaborator

neelnanda-io commented Jun 4, 2024 via email

@bryce13950
Collaborator

Well I think I have a pretty decent idea on how to start with this. I will work with @joelburget to try and isolate the error, and hopefully we can find it relatively quickly.

@joelburget
Contributor Author

joelburget commented Jun 9, 2024

I ran some experiments in this notebook: https://gist.github.com/joelburget/bae5ea4d997e804b2a65d02d5b61f5bc

  1. all weights seem to match
  2. when running a random tensor through block 1:
    a. 89 / 4096 mlp outputs match. Those that don't match differ only slightly, e.g. 0.38261502981185913 vs 0.38261523842811584
    b. 254 / 4096 attention outputs match.
    c. 3218 / 4096 block outputs match. Again we see small differences like -0.20334099233150482 vs -0.20334100723266602.

I'm surprised that there are more matching block outputs than mlp or attention, but whatever, I guess this is possible.

Ideas for where to go from here?
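(For reference, the per-element match counts above can be reproduced with a helper along these lines; tl_out and hf_out are hypothetical tensors holding the two implementations' outputs for the same input.)

import torch

def report_match(tl_out: torch.Tensor, hf_out: torch.Tensor, atol: float = 1e-6) -> None:
    # Count exactly equal elements and elements equal up to a small tolerance.
    tl_flat = tl_out.float().flatten()
    hf_flat = hf_out.float().flatten()
    exact = (tl_flat == hf_flat).sum().item()
    close = torch.isclose(tl_flat, hf_flat, atol=atol, rtol=0).sum().item()
    print(f"{exact} / {tl_flat.numel()} exact, {close} / {tl_flat.numel()} within atol={atol}")

Exact equality is a very strict bar for low-precision arithmetic, so the isclose count is usually the more meaningful number.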

@bryce13950
Collaborator

I am going to play around with this further tomorrow to see if I can figure anything out. I wonder if the problem lies in the MLP component itself. Maybe there is a discrepancy there that only reveals itself in larger operations? We played around with the MoE component a bit last week, and that seems to be working as expected. We had started looking at the weights, so having those ruled out is definitely a good step. I should have quite a bit of time to look at this, provided that nothing else comes up. I am going to start by playing around with the MLP outputs to see if I can figure anything out.

@bryce13950
Collaborator

@joelburget I have been messing around with this for the last 3-4 hours. Unfortunately, I was not able to load the model. I am working on the branch mixtral-playing, and there are changes on it that I am trying to see the results for. If you have time, can you try running this on my branch to see if you get any different results? I was not able to load the model on a system with 4x 24 GB GPUs, and while looking at this I realized that the way a device is selected is pretty suboptimal; for this model, loading is simply not possible with those specs because most of the model gets placed on the first device. There is a much better way to do this, and I think I will end up getting sidetracked implementing that before coming back to this.
For the time being, if you can run my branch and let me know if there are any differences, that would be very helpful. Maybe it is working on that branch? It would be super cool if it was, but it's going to be a couple more days before I can get back to debugging this issue.

@Butanium
Contributor

Butanium commented Jun 13, 2024

@haeggee suggested this:
The softmax is usually done in torch.float dtype; see e.g. https://github.com/mistralai/mistral-inference/blob/c24ac864ab623ca39bda4f48c334eed6e55f13a2/src/mistral_inference/moe.py#L29.
This is missing in the TL code and might explain some of the difference, but that would really need to be checked.
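(A sketch of the upcasting pattern being described, closer to the transformers version: softmax over all experts in float32, then top-k and renormalize. The names here are illustrative, not the actual TL code.)

import torch

def route_tokens(gate_logits: torch.Tensor, num_experts_per_tok: int):
    # Softmax in float32 even when the model runs in bf16/fp16, then cast back.
    weights = torch.softmax(gate_logits.float(), dim=-1)
    weights, selected_experts = torch.topk(weights, num_experts_per_tok, dim=-1)
    weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights.to(gate_logits.dtype), selected_experts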

@bryce13950
Collaborator

There's quite a bit out of sync with the transformers implementation https://github.com/huggingface/transformers/blob/b7672826cad31e30319487af876e608d8af7d37b/src/transformers/models/mixtral/modeling_mixtral.py#L843C1-L843C69. We looked at this quite a bit in our little coding session, but we were not able to pinpoint the actual cause. I am going to break down the forward pass here, and add a full suite of unit tests with input/output grabbed from transformers, to be able to test it in a more isolated manner.
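(A hedged sketch of the kind of isolated test meant here; reference.pt and tl_moe are hypothetical stand-ins for a saved transformers input/output pair and the TL MoE module loaded with the same weights.)

import torch

def test_moe_block_matches_transformers_reference():
    # Input/output pair captured from the transformers MixtralSparseMoeBlock.
    ref = torch.load("reference.pt")
    x, expected = ref["input"], ref["output"]

    actual = tl_moe(x)

    torch.testing.assert_close(actual, expected, rtol=1e-3, atol=1e-3)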

@bryce13950
Collaborator

I think this has something to do with the W_Gate variable. That is the biggest difference between our implementation and the HuggingFace implementation. I may be completely off on this, but it may be the next thing to look at. I made a few changes on my previously mentioned branch to make it more in line with HuggingFace, and if anyone wants to mess around with that gate before I have a chance to look at it again, it may be a good place to start. There are still other differences between the two implementations, but that variable appears to be the most substantial at a glance.

@bryce13950
Collaborator

Holy shit
[image]

I think I got it

@bryce13950
Collaborator

OK, so not 100% there, but it seems like it is closer. I got this result by changing the dtype on the W_Gate to torch.float. The second inference was mixed between English and French again, and the third was completely French. I think the issue lies somehow in the einops operation right at the top. Don't have more time to look at it today, but I think this is real progress.

@joelburget
Contributor Author

I tried @bryce13950's change, which unfortunately didn't seem to help. I tried both his branch (see mixtral-playing.ipynb) and a modification with the line weights /= weights.sum(dim=-1, keepdim=True) commented out (see mixtral-playing-2.ipynb). Both had significant differences from the Huggingface version: https://gist.github.com/joelburget/ffd4705167ee410b04e63758d2689e45

@joelburget
Contributor Author

joelburget commented Jun 14, 2024

This time I generated histograms showing that MLPs and Attention outputs are both off by similar amounts.

ETA: I hadn't noticed that attention outputs are ~1e7 so they're not "off by similar amounts" at all.
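(The histograms can be produced with something along these lines, a sketch rather than the notebook's actual code; tl_out and hf_out are hypothetical tensors holding the two implementations' outputs for the same component and input.)

import matplotlib.pyplot as plt

diff = (tl_out.float() - hf_out.float()).flatten().cpu().numpy()
plt.hist(diff, bins=100)
plt.xlabel("TL output - HF output")
plt.ylabel("count")
plt.show()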

@Butanium
Contributor

Butanium commented Jun 14, 2024

Could it be a sliding window attention issue? Mixtral shouldn't use sliding window attention.
See: https://www.reddit.com/r/LocalLLaMA/comments/18k0fek/psa_you_can_and_may_want_to_disable_mixtrals/
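(Whether sliding-window attention is in play can be checked straight from the config; sliding_window is the standard MixtralConfig field.)

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
# Mixtral should report no sliding window (None); anything else would point
# at the issue described in the linked thread.
print(cfg.sliding_window)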

@bryce13950
Collaborator

@Butanium I will definitely play with that this afternoon

@joelburget
Contributor Author

@bryce13950's latest change seems promising. Still seems not exactly right, but closer?

[screenshots]

@bryce13950
Collaborator

@joelburget Yeah, I think the issue is a composite of a few different problems. The big breakthrough was last night with setting the W_Gate to use torch.float. That seems to have almost entirely solved the problem between layers. The attention now definitely seems like the most reasonable place to look next, so I will check out the notes @Butanium shared to see if that reveals more clues. I was looking at that segment of the config much earlier, but partially ruled it out to look at the expert routing. Looking at it now, there were probably problems in both.

@bryce13950
Collaborator

More progress. I changed the root dtype of the config to bfloat16, and the accuracy between MLP outputs more than doubled. These are the histograms I have now...

[histograms]

It looks like this is very close. The attention sliding window was already disabled, so that's not the issue there. I will try to poke around a bit more to see if I can get that closer.
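(For anyone reproducing the bfloat16 run: loading the TL model in bfloat16 looks roughly like this; dtype is the standard from_pretrained argument, and the rest mirrors the original report.)

import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained_no_processing(
    "mistralai/Mixtral-8x7B-v0.1",
    dtype=torch.bfloat16,
    n_devices=4,
)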

@bryce13950
Collaborator

OK, so another note. I tried changing n_ctx from 2048 to hf_config.max_position_embeddings, and that seemed to make things visibly worse, so that can be ruled out!

@bryce13950
Collaborator

OK, so here is the last output for the day.

[image]

We are now consistently generating English the first pass, and then it starts diverging, which leads me to believe that the problem now lies in a mutable variable somewhere through the generation. I think we are probably very close though.

@joelburget
Contributor Author

joelburget commented Jun 16, 2024

We are now consistently generating English the first pass, and then it starts diverging, which leads me to believe that the problem now lies in a mutable variable somewhere through the generation.

Interesting, my first guess would have been that there's some small error which accumulates over many tokens. Or, what kind of mutable variable are you thinking of?

Out of curiosity, I adapted my script to GPT2 and compared HF vs TL. The differences are ~1e-5, bigger than the differences we're currently getting for Mixtral (at least for the first layer alone). I'll admit I'm a bit confused why there are any differences at all, or when we can be confident that the differences are immaterial.

[screenshots]

@bryce13950
Collaborator

When I am referring to mutable variables, I basically mean any variable changing within a pass, so the example you gave is exactly what I am referring to, among potentially other things. If we can figure out what exact variable mutation is causing the error, we should then be able to figure out where that error is coming from. We should be able to do that by looking at individual variables, preventing them from changing in any way between passes, and then seeing if the passes continue to generate English. If they do, we should be able to work backwards and figure out where the variable is changing.
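(One way to rule out in-place mutation of the weights between passes is to snapshot the parameters and diff them after a generation; a sketch, assuming model is the HookedTransformer from earlier.)

import torch

# Snapshot every parameter before generating.
before = {name: p.detach().clone() for name, p in model.named_parameters()}

model.generate("Once upon a time", max_new_tokens=50, verbose=False)

# Any parameter that changed during generation shows up here.
for name, p in model.named_parameters():
    if not torch.equal(before[name], p.detach()):
        print("mutated:", name)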

As for the other errors, it honestly doesn't really surprise me. I haven't run the smaller Mixtral models since the PR to add them was open, but the smaller ones were working at that point. I think there is a distinct possibility that these sorts of errors are negligible in smaller models, but then manifest as what we are currently seeing in these large models. Hopefully this whole effort will shed some light on how to keep things better in check with HuggingFace, so that we can reliably support even larger models in the future.

@Butanium
Contributor

Are Mixtral logits similar on a single prompt? Is this just a .generate issue?
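(A sketch of the single-prompt comparison being asked about, assuming model is the TL HookedTransformer and hf_model is the HuggingFace version, both already loaded.)

import torch

prompt = "Once upon a time"
tokens = model.to_tokens(prompt)

with torch.no_grad():
    tl_logits = model(tokens)                                # [1, seq, vocab]
    hf_logits = hf_model(tokens.to(hf_model.device)).logits  # [1, seq, vocab]

diff = (tl_logits.float().cpu() - hf_logits.float().cpu()).abs()
print(diff.max().item(), diff.mean().item())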

@bryce13950
Collaborator

@Butanium The histograms shown here are from a single pass through generate; I would think it would be the same as a single forward pass on a prompt. Honestly, it takes so long to run the full script that I haven't run it with a single prompt to verify. I can definitely play around to see if there is anything to that once I have a pause in testing various states.

We have a new baseline. This was my last run.

[histograms]

And the first generation...

[image]

I also added a loop to the bottom of the script to run 10 more passes with the two different prompts after the histograms are generated. This is the result from the 2nd to 5th passes.

[image]

Big step forward. They start degrading a very small amount in the second one, and then greatly in the third one. Only 1 of the 5 is generating French now, and only partially. Getting there.

@bryce13950
Collaborator

I completely replaced the MoE and MLP implementation with something that matches transformers exactly, outside of reshaping tensors to make them compatible with the rest of TransformerLens. That was a huge step forward, and it was working perfectly for the first 3 inferences. The fourth one (a "Once upon a time" generation) spat out German, which seems better, since German is in the same language family as English. After that it went back to English for the next "Hello my name is" generation. With that, I think it is now probably pretty safe to rule out the actual MoE logic as the problem, since it is identical to the one we are trying to replicate. I am going to try to connect it back to the existing GatedMLP implementation and see if it retains the same performance. I then need to go through the MoE, connect hooks, and make a couple more changes.
@joelburget If you can, can you set this up on my branch again, and just have it loop a lot over a variety of prompts to get an idea of how often it is not generating English at this point?

@bryce13950
Collaborator

Here's a quick grab of the 2nd-6th generations I saw

[image]

@joelburget
Contributor Author

Interesting @bryce13950, I started doing something similar, see #641. I'll take a look through your changes a bit later!

@bryce13950
Collaborator

@joelburget Very good! Your implementation is much improved. I am going to be pretty much rewriting the MLP structure during a shared coding session on Thursday. If you can address the couple of changes I requested on the PR before then, I will get your PR wrapped up beforehand and start the session with it included in my branch. If not, then I will start the session by wrapping up your PR.
If you have time to run the generate function through a lot of generations, that would be incredibly helpful for data gathering at the moment. I am wondering if the German generation yesterday was a fluke. Unfortunately, I had to wrap up when I was looking at it, and I was not able to get anything past the 6th iteration. I was very surprised that it went back to English, though, and I would have liked to have seen further iterations. If you run this for 100 loops with the prompt "Once upon a time", I would be very curious to see how often it spits out something that is not English. Maybe it will only happen a handful of times, at which point I think we should wrap this up and look into moving on to larger initiatives to improve accuracy across the whole project.
I spent a lot of time debugging this yesterday, and I probably can't spend a lot of time on this until Thursday. If you have the ability to run that loop either today, or tomorrow, it would be very helpful.
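(A sketch of the requested loop; it uses the third-party langdetect package as a rough "not English" heuristic, which is an assumption on my part rather than anything used above. model is assumed to be the HookedTransformer already loaded.)

from collections import Counter
from langdetect import detect  # pip install langdetect

counts = Counter()
for _ in range(100):
    text = model.generate("Once upon a time", max_new_tokens=50, verbose=False)
    try:
        counts[detect(text)] += 1
    except Exception:
        counts["undetected"] += 1

print(counts)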

@bryce13950
Collaborator

Alright, so this is probably going to be the last time I really look into this for a while. I discovered that if you generate only "Once upon a time", the generation still degrades rather rapidly. I am not sure that focusing on generation for that specific phrase is a great use of time. That phrase is an idiom, and not just any idiom, but one that doesn't really make a whole lot of sense if you break it down word by word. We are all very used to it, but from a logical standpoint it is strange. When you translate it into different languages, the words themselves do not translate directly, and the phrase often becomes a series of words that make as little or even less sense than they do in English. All that is to say that it is a really weird and difficult piece of language.
When you generate "Hello my name is", the generation seems relatively stable, and when I mixed the two together, everything generated for "Hello my name is" was in English, with some occasional degradation. The generations for "Once upon a time" slowly started drifting into French and German as the loop went on, but even with the degradation there, "Hello my name is" continued to generate English. I am going to clean up my branch, add hooks to the reworked components, and put the changes into a release. We should keep this open to keep track of it, but I don't think it makes much sense to focus on it directly. This has revealed that there are potentially quite a few inaccuracies in some of TransformerLens's implementations, and fixing those inaccuracies may accumulate into improving compatibility for this model as well.

Here is my last generation of something like 20-30 passes...

Hello my name is KaitlynTeachin13. My mom Catherine, is obviously the teacher that I have mentioned. She was also absent at my recitals. My mom never came to one of my recitals. I was very hurt that she couldn’

Once upon a time frame, there was an ancestry who lived in a indispensable globe of snobs. Except u be naive that, it was a state where motion seemed to ruin into product every time u be naive that again a hunch,
Hello my name is Cortez Thompson and I am currently with the Fresno County Office of Education. Working in a class that teaches all education students how to work with exceptional children. Many of these children come from low income homes, most of these children have I
Once upon a time, netEnt das Netz von Überwindung ungerechten Gesetzgebung versammeln Gesetz Streitigkeiten im Umgang mit Netlen. Kurz gesagt, wenn Sie nach enthält alle geręd
Hello my name is _______d&SO

I'm Backing-up this Pledge from AllenClassic4
and on Behalf of Bo, I’d

Like to Thank the Entire Bartow County Jury
Once upon a time Orks, in a time not so far ago, when humanity just begun to wander its first steps into outer space then came an lord Ork, Gorzag KRUM, a warboss who wanted rule all Um dia não t
Hello my name is Rebecca I am 29 years old and this is my story. CENSORED, because my god these models can go dark places
Once upon a time, un dénommé Jean-Roger Philippe lança au monde un joli bout de musique : Nicolas Godin de son sombre viseur, Uwe Schmidt de son bout du bâton. Un alignement
Hello my name is José Hernandez. I am a junior attending St. Thomas Aquinas College and studying Business administration. I started off did not really wanted to do business because I felt that it was very hard since I was not good with doing calculations. Growing
Once upon a time when children were inmates, they were given different punishments depending on the severity of their offenses. Three main forms of physical punishment were the common practice: corporal, solitary confinement (aka, the “separate cell”
Hello my name is actress singer songwriter. I recently have been accepted into a course to further improve my knowledge and skills as a musician to help me professionally perform as a musical actress. The course I am aiming to undertake is a musical theatre course. Not
Once upon a time going on still now…

Our Story Started 80 Years Ago

When members of the Publishers’ Publicity Association of New York were organizing their own association separate from members of Chicago’s, they decided to
Hello my name is brad. Im a professional ferrari ochro scot full full quiper yon. Ive also built me. Yes i used to build watches for a living took the Yoni and i was schisma and they were like all
Once upon a time vor einiger Zeit war lächeln deutsch-rusche, aber dann gabs was mit Hitler und Stalin, und vor Beginn des Krieges hat Stalin leute zur Forsetzung der Prämie nicht geben kann, die auch ple
Hello my name is james i think i can know ive learnt alot on this site converting my arduino uno to a small wifi device. Basically this makes you upgrade to arduino “YUN” which none of shop have here
Once upon a time dans un cottage traditionnel au milieu de la Côte D'Ur, au pays des ballons rouges, du blanc et bleu. Dans cette maison à côté d'une forêt ainsi qu'un lac,
Hello my name is Eli but for the it people we aren name so for the and name are so for the name not only is so for
Once upon a time encontraba en una librería y me percaté de hojeando un cuento de hadas y súbitamente me sobresaltó la seguridad expresada a algunos padres sobre ciertas caracter
Hello my name is Houng Taing and I am a sophomore at City College of New York. I am 23 and I study political science. I am Cambodian and grew up in western greene. I really love history and have been taking
Once upon a time Réunion, dortoir du stagiaire…

Une légende que le stagiaire récurrent et passionné vous répète,
oui, tous les stagiaires parisiens squattaient le stage de Réunion

@bryce13950
Collaborator

In this week's shared code session, we spent some time adding Baichuan to TransformerLens. I was curious to add this model since it generates both English and Chinese. On top of that, a lot of modern Chinese usage mixes some English into the language.
e.g. the acronym AI is probably just that in Chinese, so you might have something like "我爱AI" (I love AI). (That's just an example that I pulled out of the air which may or may not be true, but the point is that these sorts of things exist in Chinese today.)

Given this relatively uncommon quirk of the language, I was curious how it would work within TransformerLens, and it did generate a lot of nonsense, with the languages mixed. I then merged the last bit of work @joelburget did on the attention side of things, and that did seem to improve it marginally. I am writing tests for a big rework of all MLP-like components in the library at the moment, and I am very curious to see what changes both here and with the Baichuan models that were added the other day. I should have that done in the first half of the week, and we should have a release ready at that point with this greatly improved, and a lot more understood about TransformerLens's accuracy!

@bryce13950 bryce13950 mentioned this issue Jul 4, 2024
@bryce13950
Collaborator

👍 We are just about perfect.

[image]

Still a little bit off here, but very close.

[image]

100% perfect

[image]

This is now the furthest off, but I don't think it matters. I have a couple things left to wrap up as far as final touches on my branch, and I will then close it out tomorrow. I will watch a few more generations for a bit to see if they decay at all.

@bryce13950
Collaborator

Alright! I did 5 generations, and no sign of degradation. 2 for "Hello my name is", and 3 for "Once upon a time,"

[image]

I think we are good!

@bryce13950
Collaborator

Merged into dev. Release coming up shortly.
