
MPS Inf/Nan Loss #13285

Closed
justusschock opened this issue Jun 14, 2022 · 8 comments · Fixed by #14368
Labels: accelerator: mps (Apple Silicon GPU), bug (Something isn't working)
Milestone: pl:1.7.x

@justusschock (Member) commented Jun 14, 2022

I am encountering a bug: some of my neural network's targets are corrupted when using the M1 GPU. This does not happen on CPU. Specifically, some targets are set to large values (~ -2e+25), resulting in an inf/nan loss.

I have isolated the behavior by stepping through the code and verifying at which steps the targets are still in the range (-1,1) as intended. The corrupted values first occur in the batch argument of validation_step(). I implemented a custom collate_fn to verify that nothing is wrong with my datasets and dataloaders. The correct targets leave the collate_fn, but then some of them appear corrupted within validation_step().

As I wrote above, the corrupted targets appear only when using the M1 GPU, whereas on the CPU everything works correctly. I could investigate this issue further if someone could tell me what happens between collate_fn and validation_step, i.e., where else I should step through the code to identify the source of the corruption.

Originally posted by @gloryVine in #13102 (comment)

cc @akihironitta @justusschock

@justusschock (Member, Author)

Hi @gloryVine, I created a separate issue for this :)

Thanks for the report. To be honest, I suspect this is an issue with core PyTorch rather than with our code. We only do the device mapping, and as long as everything is running on MPS, we are doing our job correctly.

As this is still an experimental feature (on both PyTorch's side and ours), it may well be that some operations do not yet work as expected. To fix it, we need to pin down which specific operation misbehaves, so it would be helpful to run the network with the Trainer flag detect_anomaly=True to get more insight. Once you do that, please provide an example tensor for the offending operation so the behaviour can be reproduced and upstreamed to PyTorch :)
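For example, a minimal sketch (model and train_dataloader are placeholders for your own objects):

import pytorch_lightning as pl

# detect_anomaly=True enables torch.autograd anomaly detection for the run,
# so the backward pass raises at the first operation that produces nan/inf.
trainer = pl.Trainer(accelerator="mps", devices=1, detect_anomaly=True)
trainer.fit(model, train_dataloaders=train_dataloader)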

@akihironitta akihironitta added the accelerator: mps Apple Silicon GPU label Jun 14, 2022
@carmocca (Contributor)

@gloryVine commented Jun 18, 2022

@justusschock I ran the trainer with the flag; nothing changed, including the output.

@carmocca I did the following at the position in your first link: immediately after next(data_fetcher), I ran the check (abs(batch[1]) > 1000).any(), which becomes True for some batches (even though my targets are all supposed to be in the (-1, 1) range). One example of a corrupted target is -170545583224674966934149660672.
To make sure that this is not an issue with the dataloader, I set y = data_fetcher.dataloader.dataset.y and ran the checks (abs(y) > 1000).any() and (abs(y.to("mps")) > 1000).any(), both of which are False.
As written in my initial post, I also ran the same check at the end of collate_fn, and the targets are all OK at that point.
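Putting the checks together (a sketch; data_fetcher is the internal fetcher at the position linked above):

batch = next(data_fetcher)
print((abs(batch[1]) > 1000).any())      # True for some batches (corrupted targets)

y = data_fetcher.dataloader.dataset.y    # the raw dataset targets
print((abs(y) > 1000).any())             # False
print((abs(y.to("mps")) > 1000).any())   # False, even after a plain transfer to MPS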

@j0rd1smit (Contributor) commented Aug 15, 2022

@carmocca The issue is caused by PyTorch Lightning, not PyTorch. Here is a minimal MNIST example that works with native PyTorch. However, when I refactor the exact same code into the Lightning format, it does not work due to overflows/underflows in the targets.

Native PyTorch version:

import torch
import torchvision.transforms as transforms
import tqdm
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

batch_size = 128
device = torch.device("mps")

network = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(28 * 28, 64),
    torch.nn.ReLU(),
    torch.nn.Dropout(0.1),
    torch.nn.Linear(64, 64),
    torch.nn.ReLU(),
    torch.nn.Dropout(0.1),
    torch.nn.Linear(64, 10),
)
optimizer = torch.optim.Adam(network.parameters(), lr=0.001)
loss_func = torch.nn.CrossEntropyLoss()

transform = transforms.ToTensor()
dataset_train = MNIST("/tmp/data", train=True, download=True, transform=transform)
dataset_test = MNIST("/tmp/data", train=False, download=True, transform=transform)
train_dataloader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False)

network.to(device)
for epoch_idx in range(10):
    train_losses = []
    train_accs = []
    network.train()
    for x, y in tqdm.tqdm(train_dataloader):
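        # Sanity check: valid MNIST targets lie in [0, 9]; corrupted device
        # transfers would trip the assert below.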
        assert (
            torch.sum(y >= 10) == 0
        ), f"y has more than 10 unique values but got {y}"
        x = x.to(device)
        y = y.to(device)
        logits = network(x)
        y_pred = torch.argmax(logits, dim=1)
        loss = loss_func(logits, y)
        loss.backward()
        accuracy = torch.mean((y_pred == y).float())

        optimizer.step()
        optimizer.zero_grad()

        train_losses.append(loss.detach())
        train_accs.append(accuracy)

    print(
        f"Epoch {epoch_idx} train loss: {torch.mean(torch.stack(train_losses))} train acc: {torch.mean(torch.stack(train_accs))}"
    )

    val_losses = []
    val_accs = []
    network.eval()
    for x, y in tqdm.tqdm(test_dataloader):
        x = x.to(device)
        y = y.to(device)
        logits = network(x)
        y_pred = torch.argmax(logits, dim=1)
        loss = loss_func(logits, y)
        accuracy = torch.mean((y_pred == y).float())

        val_losses.append(loss.detach())
        val_accs.append(accuracy)

    print(
        f"Epoch {epoch_idx} val loss: {torch.mean(torch.stack(val_losses))} val acc: {torch.mean(torch.stack(val_accs))}"
    )

PyTorch Lightning version:

import pytorch_lightning as pl
import torch
import torchvision.transforms as transforms
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()

        self.network = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, 3, 1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(32, 64, 3, 1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),
            torch.nn.Dropout2d(0.1),
            torch.nn.Flatten(),
            torch.nn.Linear(9216, 128),
            torch.nn.ReLU(),
            torch.nn.Dropout1d(0.1),
            torch.nn.Linear(128, 10),
        )

        self.loss_func = torch.nn.CrossEntropyLoss()

    def training_step(self, batch, batch_nb):
        x, y = batch
        # I did not use torch.unique() such that you do not need PYTORCH_ENABLE_MPS_FALLBACK=1
        assert torch.sum(y >= 10) == 0, f"y has more than 10 unique values but got {y}"

        logits = self.network(x)
        y_pred = torch.argmax(logits, dim=1)
        loss = self.loss_func(logits, y)

        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", torch.mean((y_pred == y).float()), prog_bar=True)

        return loss

    def validation_step(self, batch, batch_nb):
        x, y = batch
        print(x.device)
        logits = self.network(x)
        y_pred = torch.argmax(logits, dim=1)
        loss = self.loss_func(logits, y)

        self.log("val_loss", loss, prog_bar=True)
        self.log("vall_acc", torch.mean((y_pred == y).float()), prog_bar=True)

        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.network.parameters(), lr=0.001)

batch_size = 128

network = Model()

dataset_train = MNIST(
    "/tmp/data", train=True, download=True, transform=transforms.ToTensor()
)
dataset_test = MNIST(
    "/tmp/data", train=False, download=True, transform=transforms.ToTensor()
)
train_dataloader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False)

trainer = Trainer(
    accelerator="mps",
    logger=False,
    devices=1,
    enable_checkpointing=False,
)
trainer.fit(
    network, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader
)

Output

AssertionError: y has more than 10 unique values but got tensor([9024291576, 5019830272, 9024292816, 5019831392, 9024295448, 5019831152,
        9024292112, 5019831072,      -4096,          8, 9024291552, 5019831552,
             -4096,          9,      -4096,          3, 9024295376, 5019831232,
             -4096,          8, 9024291584, 5019830512,      -4096,          8,
        9024291616, 5019830432, 9024291448, 5019831312, 9024292104, 5019831472,
             -4096,          3,      -4096,          5,      -4096,          5,
             -4096,          4, 9024291600, 5019830672,      -4096,          8,
             -4096,          6,      -4096,          4,      -4096,          3,
             -4096,          4,      -4096,          1,      -4096,          8,
             -4096,          8,      -4096,          2,      -4096,          9,
        9024291624, 5019830992,      -4096,          4,      -4096,          9,
             -4096,          5, 9024295368, 5019830832,      -4096,          3,
             -4096,          9,      -4096,          8,      -4096,          7,
             -4096,          2,      -4096,          2, 9024295352, 5019830352,
             -4096,          5,      -4096,          5, 9024291608, 5019830752,
             -4096,          3,      -4096,          3, 9024292096, 5019831632,
             -4096,          4,      -4096,          8,      -4096,          6,
             -4096,          4,      -4096,          0,      -4096,          7,
        9024292920, 5019830912,      -4096,          7,      -4096,          1,
             -4096,          4,      -4096,          6,      -4096,          6,
             -4096,          1,      -4096,          3,      -4096,          0,
        9024291592, 5019830592], device='mps:0')

macOS: Monterey 12.5 (21G72)
System: MacBook Pro (16-inch, 2021)
Chip: Apple M1 Pro
Python version: 3.9.12, with requirements.txt:

absl-py==1.2.0; python_version >= "3.7"
aiohttp==3.8.1; python_version >= "3.7"
aiosignal==1.2.0; python_version >= "3.7"
async-timeout==4.0.2; python_version >= "3.7"
attrs==22.1.0; python_version >= "3.7"
cachetools==5.2.0; python_version >= "3.7" and python_version < "4.0" and (python_version >= "3.7" and python_full_version < "3.0.0" or python_full_version >= "3.6.0" and python_version >= "3.7")
certifi==2022.6.15; python_version >= "3.7" and python_version < "4"
charset-normalizer==2.1.0; python_version >= "3.7" and python_version < "4" and python_full_version >= "3.6.0"
colorama==0.4.5; python_version >= "3.7" and python_full_version < "3.0.0" and platform_system == "Windows" or python_full_version >= "3.5.0" and python_version >= "3.7" and platform_system == "Windows"
cycler==0.11.0; python_version >= "3.7"
fonttools==4.35.0; python_version >= "3.7"
frozenlist==1.3.1; python_version >= "3.7"
fsspec==2022.7.1; python_version >= "3.7"
google-auth-oauthlib==0.4.6; python_version >= "3.7"
google-auth==2.10.0; python_version >= "3.7" and python_full_version < "3.0.0" or python_full_version >= "3.6.0" and python_version >= "3.7"
grpcio==1.48.0; python_version >= "3.7"
idna==3.3; python_version >= "3.7" and python_version < "4"
importlib-metadata==4.12.0; python_version < "3.10" and python_version >= "3.7"
kiwisolver==1.4.4; python_version >= "3.7"
markdown==3.4.1; python_version >= "3.7"
markupsafe==2.1.1; python_version >= "3.7"
matplotlib==3.5.3; python_version >= "3.7"
multidict==6.0.2; python_version >= "3.7"
numpy==1.23.2; python_version >= "3.8"
oauthlib==3.2.0; python_version >= "3.7" and python_full_version < "3.0.0" or python_full_version >= "3.4.0" and python_version >= "3.7"
packaging==21.3; python_version >= "3.7"
pillow==9.2.0; python_version >= "3.7"
protobuf==3.19.4; python_version >= "3.7"
pyasn1-modules==0.2.8; python_version >= "3.7" and python_full_version < "3.0.0" or python_full_version >= "3.6.0" and python_version >= "3.7"
pyasn1==0.4.8; python_version >= "3.7" and python_full_version < "3.0.0" and python_version < "4" and (python_version >= "3.7" and python_full_version < "3.0.0" or python_full_version >= "3.6.0" and python_version >= "3.7") or python_full_version >= "3.6.0" and python_version >= "3.7" and python_version < "4" and (python_version >= "3.7" and python_full_version < "3.0.0" or python_full_version >= "3.6.0" and python_version >= "3.7")
pydeprecate==0.3.2; python_version >= "3.7"
pyparsing==3.0.9; python_full_version >= "3.6.8" and python_version >= "3.7"
python-dateutil==2.8.2; python_version >= "3.7" and python_full_version < "3.0.0" or python_full_version >= "3.3.0" and python_version >= "3.7"
pytorch-lightning==1.7.1; python_version >= "3.7"
pyyaml==6.0; python_version >= "3.7"
requests-oauthlib==1.3.1; python_version >= "3.7" and python_full_version < "3.0.0" or python_full_version >= "3.4.0" and python_version >= "3.7"
requests==2.28.1; python_version >= "3.7" and python_version < "4" and (python_version >= "3.7" and python_full_version < "3.0.0" or python_full_version >= "3.4.0" and python_version >= "3.7")
rsa==4.9; python_version >= "3.6" and python_version < "4" and (python_version >= "3.7" and python_full_version < "3.0.0" or python_full_version >= "3.6.0" and python_version >= "3.7")
setuptools-scm==6.4.2; python_version >= "3.7"
six==1.16.0; python_version >= "3.7" and python_full_version < "3.0.0" or python_full_version >= "3.6.0" and python_version >= "3.7"
tensorboard-data-server==0.6.1; python_version >= "3.7"
tensorboard-plugin-wit==1.8.1; python_version >= "3.7"
tensorboard==2.10.0; python_version >= "3.7"
tomli==2.0.1; python_version >= "3.7"
torch==1.12.1; python_full_version >= "3.7.0"
torchmetrics==0.9.3; python_version >= "3.7"
torchvision==0.13.1; python_version >= "3.7"
tqdm==4.64.0; python_version >= "3.7" and python_full_version < "3.0.0" or python_full_version >= "3.4.0" and python_version >= "3.7"
typing-extensions==4.3.0; python_version >= "3.7" and python_full_version >= "3.7.0"
urllib3==1.26.11; python_version >= "3.7" and python_full_version < "3.0.0" and python_version < "4" or python_full_version >= "3.6.0" and python_version < "4" and python_version >= "3.7"
werkzeug==2.2.2; python_version >= "3.7"
yarl==1.8.1; python_version >= "3.7"
zipp==3.8.1; python_version < "3.10" and python_version >= "3.7"

@justusschock (Member, Author)

Hey @j0rd1smit, thanks for the MWE, and sorry for the inconvenience. I could indeed reproduce the issue on my side, but so far I have no clue what we could be doing to cause overflows. I will investigate, though, and will get back to you once I find something!

cc @akihironitta @awaelchli who might have ideas where this could come from

@carmocca carmocca added this to the pl:1.7.x milestone Aug 17, 2022
@carmocca carmocca added the bug Something isn't working label Aug 17, 2022
@j0rd1smit (Contributor) commented Aug 19, 2022

I have been able to solve this issue for myself by making the following change here:

_MPS_DEVICES = ("mps", torch.device("mps:0"))
...
# Request a non-blocking copy only for devices where it is safe,
# i.e. skip both CPU and MPS targets.
if isinstance(data, Tensor) and device not in _CPU_DEVICES and device not in _MPS_DEVICES:
    kwargs["non_blocking"] = True
data_output = data.to(device, **kwargs)

I'd be happy to make a PR for it. However, I wanted to discuss the solution first, because I'm not sure whether it solves the real problem or just one of the symptoms. What do you think, @justusschock?

@akihironitta (Contributor)

@j0rd1smit Great finding! A quick search led me to pytorch/pytorch#83015, and I've just confirmed that, with .to(..., non_blocking=True), some elements of a tensor can sometimes become nan in pure PyTorch. I think this behaviour, together with the patch you suggest, explains the stability difference between PL and PyTorch.
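For context, a minimal pure-PyTorch sketch of the suspected race (a hypothetical repro; whether corruption actually appears depends on the PyTorch build and timing):

import torch

y = torch.randint(0, 10, (1 << 20,))

# Blocking copy: the transfer completes before the tensor can be read.
y_safe = y.to("mps")

# Non-blocking copy: on affected builds the tensor may be read before the
# host-to-device transfer has finished, which shows up as garbage values.
y_fast = y.to("mps", non_blocking=True)

print((y_fast != y_safe).any().item())   # may print True when the race is hit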

@justusschock (Member, Author)

@j0rd1smit That's indeed a great finding! The solution looks reasonable. Please go ahead with a PR!
