
Fabric leaks the default device on exception #18705

@carmocca

Bug description

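When an exception is raised while entering `fabric.init_module()`, the default device that Fabric sets is not restored. See the reproduction below.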

What version are you seeing the problem on?

master

How to reproduce the bug

import pytest
import torch

import lightning as L
from lightning.fabric.plugins import HalfPrecision


class MyPlugin(HalfPrecision):
    def init_context(self):
        # arbitrary error
        raise NotImplementedError


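# While the init_module() context is active, CUDA is the default device
fabric = L.Fabric(devices=1, accelerator="cuda", plugins=MyPlugin("16-true"))

with pytest.raises(NotImplementedError):
    with fabric.init_module():
        pass

# Once the context has exited (here via the exception), new tensors
# should be created on the CPU again
default_device = torch.tensor(0).device
assert default_device.type == "cpu", default_device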

Error messages and logs

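The final assertion fails: after the `NotImplementedError` propagates out of `fabric.init_module()`, `torch.tensor(0)` is still created on the CUDA device instead of the CPU.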

Environment

Current master

More info

This blocks #18704
