Labels: feature (Is an improvement or enhancement), help wanted (Open to be worked on)
Description
🐛 Bug
The metric `val_loss` is not found for `ReduceLROnPlateau` or for the progress bar display. However, printing `val_loss` inside `validation_step` and `validation_epoch_end` works fine (it displays `Tensor(value)`).
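
Concretely, the check described above looks roughly like this (a sketch; the `print` call is purely a diagnostic and is not part of the model below):

```python
def validation_step(self, batch, batch_idx):
    x, y = batch
    loss = self.loss(self.forward(x), y)
    print(loss)  # prints the expected Tensor(value), so the metric is computed
    return {'val_loss': loss}
```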
Code sample
```python
from argparse import Namespace

import torch
import pytorch_lightning as pl
from torchvision import models

# `weight` (class weights for the loss) and `train_sampler` are defined
# elsewhere in the reporter's code and are not shown here.

class MyModel(pl.LightningModule):
    def __init__(self, train_df, val_df, test_df, hparams=Namespace(lr=0.02)):
        # Initialization
        super(MyModel, self).__init__()
        self.train_df = train_df
        self.val_df = val_df
        self.test_df = test_df
        self.hparams = hparams

        # Model structure: ResNet-18 backbone without its final FC layer,
        # followed by a small classification head.
        backbone = models.resnet18(pretrained=False)
        self.features_extractor = torch.nn.Sequential(*list(backbone.children())[:-1])
        self.fc = torch.nn.Sequential(*[
            torch.nn.Linear(backbone.fc.in_features, 256, bias=True),
            torch.nn.Linear(256, 32, bias=True),
            torch.nn.Linear(32, 4, bias=True)
        ])

        # Loss
        self._loss = torch.nn.CrossEntropyLoss(weight=weight.float())

    def forward(self, x):
        x = self.features_extractor(x)
        x = x.squeeze(-1).squeeze(-1)
        x = self.fc(x)
        return x

    def loss(self, logits, y):
        return self._loss(logits, y)

    def training_step(self, batch, batch_idx):
        # 1. Inference
        x, y = batch
        y_hat = self.forward(x)
        # 2. Loss
        loss = self.loss(y_hat, y)
        # 3. Output
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.loss(y_hat, y)
        return {'val_loss': loss}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=0.01)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        return [optimizer], [scheduler]

    def prepare_data(self):
        self.train_ds = ClassificationDataset(self.train_df, 'data/images')
        self.val_ds = ClassificationDataset(self.val_df, 'data/images')

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_ds, batch_size=256, num_workers=4, sampler=train_sampler)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.val_ds, batch_size=64, num_workers=4)
```
Error
```python
model = MyModel(train_df, val_df, test_df, hparams=Namespace(lr=0.001))
trainer = pl.Trainer(gpus=1, max_epochs=2, train_percent_check=0.01, weights_summary='top')
trainer.fit(model)
```

```
---------------------------------------------------------------------------
MisconfigurationException                 Traceback (most recent call last)
<ipython-input-412-55f3b29fc11e> in <module>
      4 # Trainer
      5 trainer = pl.Trainer(gpus=1, max_epochs=2, train_percent_check=0.01, weights_summary='top')
----> 6 trainer.fit(model)

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloader, val_dataloaders, test_dataloaders)
    702 
    703         elif self.single_gpu:
--> 704             self.single_gpu_train(model)
    705 
    706         elif self.use_tpu:  # pragma: no-cover

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/distrib_parts.py in single_gpu_train(self, model)
    475         self.optimizers = optimizers
    476 
--> 477         self.run_pretrain_routine(model)
    478 
    479     def tpu_train(self, tpu_core_idx, model):

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in run_pretrain_routine(self, model)
    862 
    863         # CORE TRAINING LOOP
--> 864         self.train()
    865 
    866     def test(self, model: Optional[LightningModule] = None):

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py in train(self)
    364 
    365                 # update LR schedulers
--> 366                 self.update_learning_rates(interval='epoch')
    367 
    368                 if self.max_steps and self.max_steps == self.global_step:

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py in update_learning_rates(self, interval)
    779                         avail_metrics = ','.join(list(self.callback_metrics.keys()))
    780                         raise MisconfigurationException(
--> 781                             f'ReduceLROnPlateau conditioned on metric {monitor_key}'
    782                             f' which is not available. Available metrics are: {avail_metrics}.'
    783                             ' Condition can be set using `monitor` key in lr scheduler dict'

MisconfigurationException: ReduceLROnPlateau conditioned on metric val_loss which is not available. Available metrics are: . Condition can be set using `monitor` key in lr scheduler dict
```
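
As the exception message suggests, the scheduler can be returned in its dict form with an explicit `monitor` key naming the metric to condition on. A minimal sketch of `configure_optimizers` in that form (`'val_loss'` matches the key logged in `validation_epoch_end` above):

```python
def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=0.01)
    scheduler = {
        # Dict form: the `monitor` key tells Lightning which logged metric
        # ReduceLROnPlateau should condition on.
        'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer),
        'monitor': 'val_loss',   # must be present in the callback metrics
        'interval': 'epoch',
        'frequency': 1,
    }
    return [optimizer], [scheduler]
```

Note that this only names the metric explicitly; `val_loss` still has to be produced by a validation run before the scheduler update for the lookup to succeed.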
Environment
- CUDA:
- GPU:
- Tesla P100-PCIE-16GB
- available: True
- version: 10.1
- Packages:
- numpy: 1.18.1
- pyTorch_debug: False
- pyTorch_version: 1.4.0
- pytorch-lightning: 0.7.3
- tensorboard: 2.2.1
- tqdm: 4.43.0
- System:
- OS: Linux
- architecture:
- 64bit
- processor:
- python: 3.7.6
- version: #1 SMP Debian 4.9.210-1 (2020-01-20)
Additional context
Dataset
```python
import pathlib

import pandas as pd
import torch
from torchvision import datasets, transforms

class ClassificationDataset(torch.utils.data.Dataset):
    def __init__(self, df: pd.DataFrame, root_dir: pathlib.Path, test=False):
        self.df = df
        self.test = test
        self.root_dir = root_dir
        self.transforms = transforms.Compose([
            transforms.Resize(size=(224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.number_of_categories = len(self.df.time_cat.cat.categories)

    def __getitem__(self, index):
        if torch.is_tensor(index):
            index = index.tolist()
        # Load the image, apply the transforms, and look up the class code.
        sample = datasets.folder.default_loader(
            pathlib.Path(self.root_dir) / pathlib.Path(self.df.iloc[index]['filename'])
        )
        sample = self.transforms(sample)
        y = int(self.df.time_cat.cat.codes.iloc[index])
        return (sample, y)

    def __len__(self):
        return self.df.shape[0]
```
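
For reference, a quick way to exercise the dataset in isolation (assuming a dataframe with the `filename` and `time_cat` columns used above, such as the `val_df` from the reproduction):

```python
ds = ClassificationDataset(val_df, 'data/images')
x, y = ds[0]
print(x.shape, y)  # expected: torch.Size([3, 224, 224]) and an integer class code
print(len(ds))     # number of rows in the dataframe
```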