PL computes wrong accuracy with drop_last=False in PyTorch Geometric #6889

Closed
rusty1s opened this issue Apr 8, 2021 · 3 comments
Labels: bug (Something isn't working), help wanted (Open to be worked on), priority: 0 (High priority task)


rusty1s commented Apr 8, 2021

🐛 Bug

PyTorch Lightning computes the wrong accuracy when using a DataLoader with drop_last=False in PyTorch Geometric.
The problem seems to be that PL cannot determine the correct batch_size of the mini-batches.

from typing import Optional

import torch
import torch.nn.functional as F
from torch.nn import Linear
from pytorch_lightning.metrics import Accuracy
from pytorch_lightning import (LightningDataModule, LightningModule, Trainer,
                               seed_everything)

from torch_geometric.data import DataLoader
from torch_geometric.datasets import TUDataset
from torch_geometric.nn import GCNConv, global_mean_pool


class Dataset(LightningDataModule):
    def __init__(self):
        super().__init__()

    def prepare_data(self):
        TUDataset('./data', name='MUTAG')

    def setup(self, stage: Optional[str] = None):
        dataset = TUDataset('./data', name='MUTAG')
        self.train_dataset = dataset[:3]
        self.val_dataset = dataset[3:6]
        self.test_dataset = dataset[6:9]

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=2)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=2)

    def test_dataloader(self):
        # Note: reuses the validation split; the per-batch output below comes from it.
        return DataLoader(self.val_dataset, batch_size=2)


class GNN(LightningModule):
    def __init__(self):
        super().__init__()
        self.conv = GCNConv(7, 64)
        self.lin = Linear(64, 2)
        self.acc = Accuracy()

    def forward(self, x, edge_index, batch):
        x = self.conv(x, edge_index).relu()
        x = global_mean_pool(x, batch)
        return self.lin(x)

    def training_step(self, data, batch_idx):
        data = data.to(self.device)
        y_hat = self(data.x, data.edge_index, data.batch)
        train_loss = F.cross_entropy(y_hat, data.y)
        return train_loss

    def validation_step(self, data, batch_idx):
        data = data.to(self.device)
        y_hat = self(data.x, data.edge_index, data.batch)
        acc = self.acc(y_hat.softmax(dim=-1), data.y)
        self.log('val_acc', acc, on_step=False, on_epoch=True)
        return acc

    def test_step(self, data, batch_idx):
        data = data.to(self.device)
        y_hat = self(data.x, data.edge_index, data.batch)
        acc = self.acc(y_hat.softmax(dim=-1), data.y)
        print('batch_size', data.num_graphs, 'accuracy', acc, 'shape', y_hat.shape)
        self.log('test_acc', acc, on_step=False, on_epoch=True)
        return acc

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.01)


def main():
    seed_everything(42)
    datamodule = Dataset()
    model = GNN()
    trainer = Trainer(max_epochs=1, progress_bar_refresh_rate=0)
    trainer.fit(model, datamodule=datamodule)
    trainer.test()


if __name__ == "__main__":
    main()

Here, I am using a dataset with 3 examples and a batch_size of 2, so with drop_last=False the last mini-batch holds a single example. In test_step, the accuracy of each individual mini-batch is:

batch_size 2 accuracy 0.5 shape [2, 2]
batch_size 1 accuracy 0.0 shape [1, 2]

while PyTorch Lightning reports an overall accuracy of 0.25, i.e., the unweighted mean of the two per-batch values.

Expected behavior

PL should report an accuracy of 0.33 (1 correct prediction out of 3 examples).
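The gap between the two numbers is the difference between averaging per-batch accuracies and weighting them by batch size. A minimal sketch using the values printed above (plain Python arithmetic, not Lightning's actual reduction code):

```python
# Per-batch results printed by test_step in the report above.
batch_sizes = [2, 1]
batch_accs = [0.5, 0.0]

# Unweighted mean over steps: every batch counts equally, regardless of its size.
unweighted = sum(batch_accs) / len(batch_accs)

# Mean weighted by batch size: equivalent to correct predictions / total examples.
correct = sum(acc * n for acc, n in zip(batch_accs, batch_sizes))
weighted = correct / sum(batch_sizes)

print(f"unweighted: {unweighted:.2f}")  # unweighted: 0.25
print(f"weighted:   {weighted:.2f}")    # weighted:   0.33
```

The reported 0.25 matches the first formula, which is what you would get if both batches were treated as having the same size.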

Environment

  • torch-geometric==master

Additional context

It seems like PL has trouble determining the batch_size of mini-batches when the data doesn't follow the conventional [batch_size, ...] format. However, it shouldn't have a problem doing so, since the batch_size can easily be inferred from the self.acc(y_hat, data.y) call.

@rusty1s rusty1s added bug Something isn't working help wanted Open to be worked on labels Apr 8, 2021
@tchaton tchaton added the priority: 0 High priority task label Apr 8, 2021
@Borda Borda added priority: 1 Medium priority task and removed priority: 0 High priority task labels Apr 12, 2021

tchaton commented Apr 30, 2021

Hey @rusty1s,

There is definitely a bug in Lightning, but it is hard to resolve without a deeper refactor.
However, you can easily make this work by passing the metric object directly to self.log, which is supported and much safer.
Also, there should be one metric instance per stage :)

# Imports and the Dataset DataModule are unchanged from the original snippet.
class GNN(LightningModule):
    def __init__(self):
        super().__init__()
        self.conv = GCNConv(7, 64)
        self.lin = Linear(64, 2)
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()

    def forward(self, x, edge_index, batch):
        x = self.conv(x, edge_index).relu()
        x = global_mean_pool(x, batch)
        return self.lin(x)

    def training_step(self, data, batch_idx):
        data = data.to(self.device)
        y_hat = self(data.x, data.edge_index, data.batch)
        train_loss = F.cross_entropy(y_hat, data.y)
        return train_loss

    def validation_step(self, data, batch_idx):
        data = data.to(self.device)
        y_hat = self(data.x, data.edge_index, data.batch)
        self.val_acc(y_hat.softmax(dim=-1), data.y)
        self.log('val_acc', self.val_acc, on_step=False, on_epoch=True)

    def test_step(self, data, batch_idx):
        data = data.to(self.device)
        y_hat = self(data.x, data.edge_index, data.batch)
        self.test_acc(y_hat.softmax(dim=-1), data.y)
        self.log('test_acc', self.test_acc, on_step=False, on_epoch=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.01)


def main():
    seed_everything(42)
    datamodule = Dataset()
    model = GNN()
    trainer = Trainer(max_epochs=1, progress_bar_refresh_rate=0)
    trainer.fit(model, datamodule=datamodule)
    trainer.test()


if __name__ == "__main__":
    main()
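Why passing the metric object is safer can be sketched as follows, assuming the metric keeps running counts of correct predictions and seen examples and only computes the ratio at epoch end, so the batch size comes from the targets themselves. `StreamingAccuracy` is a hypothetical, simplified class in plain Python, not the `pytorch_lightning.metrics.Accuracy` implementation:

```python
class StreamingAccuracy:
    """Hypothetical, simplified stand-in for a stateful accuracy metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.correct = 0
        self.total = 0

    def update(self, preds, target):
        # Accumulate counts; the number of examples is taken from the targets.
        self.correct += sum(int(p == t) for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self):
        # Single division at the end: correct predictions / all examples seen.
        return self.correct / self.total


# One instance per stage, so validation batches never leak into the test result.
val_acc = StreamingAccuracy()
test_acc = StreamingAccuracy()

val_acc.update([1, 1], [1, 1])

# Hypothetical predicted classes and labels matching the per-batch accuracies
# from the report: batch 1 has 1 of 2 correct, batch 2 has 0 of 1 correct.
for preds, target in [([0, 1], [0, 0]), ([1], [0])]:
    test_acc.update(preds, target)

print(val_acc.compute())   # 1.0
print(test_acc.compute())  # 0.3333333333333333
```

Because the reduction happens on counts rather than on per-step averages, an uneven last batch is weighted correctly without the logger having to guess the batch size.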

@weihua916

Has this issue been resolved on master? I installed master, and still see the accuracy of 0.25 when running Matthias' code.

@weihua916

Sorry, I just saw the comment describing the workaround. Thanks!
