Merged · 29 commits
**README.md** (12 changes: 8 additions & 4 deletions)

```diff
@@ -81,36 +81,40 @@ class CoolModel(pl.LightningModule):
     def forward(self, x):
         return torch.relu(self.l1(x.view(x.size(0), -1)))

     def my_loss(self, y_hat, y):
         return F.cross_entropy(y_hat, y)

     def training_step(self, batch, batch_nb):
+        # REQUIRED
         x, y = batch
         y_hat = self.forward(x)
-        return {'loss': self.my_loss(y_hat, y)}
+        return {'loss': F.cross_entropy(y_hat, y)}

     def validation_step(self, batch, batch_nb):
+        # OPTIONAL
         x, y = batch
         y_hat = self.forward(x)
         return {'val_loss': self.my_loss(y_hat, y)}

     def validation_end(self, outputs):
+        # OPTIONAL
         avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
         return {'avg_val_loss': avg_loss}

     def configure_optimizers(self):
+        # REQUIRED
         return [torch.optim.Adam(self.parameters(), lr=0.02)]

     @pl.data_loader
     def tng_dataloader(self):
+        # REQUIRED
         return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

     @pl.data_loader
     def val_dataloader(self):
+        # OPTIONAL
         return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

     @pl.data_loader
     def test_dataloader(self):
+        # OPTIONAL
         return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
```
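Since validation becomes optional end to end in this PR, a module can drop every hook marked `# OPTIONAL` above. The following is a minimal sketch of such a train-only module, not part of the diff; it reuses the era-specific names shown in the README (`pl.data_loader`, `tng_dataloader`, `batch_nb`):

```python
import os

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

import pytorch_lightning as pl


# hypothetical module, assuming this era's API: only the REQUIRED hooks
# are defined; after this PR the trainer simply skips validation.
class TrainOnlyModel(pl.LightningModule):
    def __init__(self):
        super(TrainOnlyModel, self).__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        x, y = batch
        return {'loss': F.cross_entropy(self.forward(x), y)}

    def configure_optimizers(self):
        return [torch.optim.Adam(self.parameters(), lr=0.02)]

    @pl.data_loader
    def tng_dataloader(self):
        return DataLoader(MNIST(os.getcwd(), train=True, download=True,
                                transform=transforms.ToTensor()), batch_size=32)
```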
**docs/LightningModule/RequiredTrainerInterface.md** (12 changes: 5 additions & 7 deletions)

```diff
@@ -10,16 +10,14 @@ Otherwise, to Define a Lightning Module, implement the following methods:
 **Required**:

 - [training_step](RequiredTrainerInterface.md#training_step)
-- [validation_step](RequiredTrainerInterface.md#validation_step)
-- [validation_end](RequiredTrainerInterface.md#validation_end)
-
-- [configure_optimizers](RequiredTrainerInterface.md#configure_optimizers)
-
-- [tng_dataloader](RequiredTrainerInterface.md#tng_dataloader)
-- [test_dataloader](RequiredTrainerInterface.md#test_dataloader)
+- [tng_dataloader](RequiredTrainerInterface.md#tng_dataloader)
+- [configure_optimizers](RequiredTrainerInterface.md#configure_optimizers)

 **Optional**:
+- [validation_step](RequiredTrainerInterface.md#validation_step)
+- [validation_end](RequiredTrainerInterface.md#validation_end)
 - [val_dataloader](RequiredTrainerInterface.md#val_dataloader)
+- [test_dataloader](RequiredTrainerInterface.md#test_dataloader)

 - [on_save_checkpoint](RequiredTrainerInterface.md#on_save_checkpoint)
 - [on_load_checkpoint](RequiredTrainerInterface.md#on_load_checkpoint)
```
**pytorch_lightning/models/trainer.py** (30 changes: 23 additions & 7 deletions)

```diff
@@ -13,6 +13,7 @@
 import torch.multiprocessing as mp
 import torch.distributed as dist

+from pytorch_lightning.root_module.root_module import LightningModule
 from pytorch_lightning.root_module.memory import get_gpu_memory_map
 from pytorch_lightning.root_module.model_saving import TrainerIO
 from pytorch_lightning.pt_overrides.override_data_parallel import (
```
```diff
@@ -312,6 +313,14 @@ def __is_function_implemented(self, f_name):
         f_op = getattr(model, f_name, None)
         return callable(f_op)

+    def __is_overriden(self, f_name):
+        model = self.__get_model()
+        super_object = super(model.__class__, model)
+
+        # when code pointers are different, it was overriden
+        is_overriden = getattr(model, f_name).__code__ is not getattr(super_object, f_name).__code__
+        return is_overriden
+
     @property
     def __tng_tqdm_dic(self):
         tqdm_dic = {
```
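The `__is_overriden` check relies on a standard CPython detail: a method defined only on the base class resolves to the same code object whether you reach it through the instance or through `super()`; a subclass override introduces a new code object. A standalone sketch of the idea, with hypothetical classes that are not from the PR:

```python
class Base:
    def validation_step(self, batch, batch_nb):
        pass  # default no-op hook


class WithVal(Base):
    def validation_step(self, batch, batch_nb):
        return {'val_loss': 0.0}


class WithoutVal(Base):
    pass


def is_overriden(model, f_name):
    # note: super(model.__class__, model) assumes one level of subclassing,
    # mirroring the trainer's own assumption
    super_object = super(model.__class__, model)
    # when the code pointers differ, the subclass supplied its own body
    return getattr(model, f_name).__code__ is not getattr(super_object, f_name).__code__


print(is_overriden(WithVal(), 'validation_step'))     # True
print(is_overriden(WithoutVal(), 'validation_step'))  # False
```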
```diff
@@ -345,13 +354,13 @@ def __layout_bookeeping(self):
         self.nb_tng_batches = int(self.nb_tng_batches * self.train_percent_check)

         # determine number of validation batches
-        self.nb_val_batches = len(self.val_dataloader)
+        self.nb_val_batches = len(self.val_dataloader) if self.val_dataloader is not None else 0
         self.nb_val_batches = int(self.nb_val_batches * self.val_percent_check)
         self.nb_val_batches = max(1, self.nb_val_batches)
         self.nb_val_batches = self.nb_val_batches

         # determine number of test batches
-        self.nb_test_batches = len(self.test_dataloader)
+        self.nb_test_batches = len(self.test_dataloader) if self.test_dataloader is not None else 0
         self.nb_test_batches = int(self.nb_test_batches * self.test_percent_check)

         # determine when to check validation
```
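The bookkeeping change is the usual guard-then-scale pattern: treat a missing loader as zero batches before applying the percent check. A small sketch under that reading; the helper name is mine, not the trainer's:

```python
def count_batches(dataloader, percent_check):
    # a missing (None) loader contributes zero batches
    nb_batches = len(dataloader) if dataloader is not None else 0
    return int(nb_batches * percent_check)
```

Note that the real code still clamps the validation count to `max(1, ...)` even when the loader is None; it is the early return in `validate()` below, not this count, that actually skips work when no validation logic exists.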
```diff
@@ -372,6 +381,10 @@ def validate(self, model, dataloader, max_batches):
         :param max_batches: Scalar
         :return:
         """
+        # skip validation if model has no validation_step defined
+        if not self.__is_overriden('validation_step'):
+            return {}
+
         # enable eval mode
         model.zero_grad()
         model.eval()
@@ -418,11 +431,13 @@
             if self.progress_bar and self.prog_bar is not None:
                 self.prog_bar.update(1)

-        # give model a chance to do something with the outputs
-        if self.data_parallel:
-            val_results = model.module.validation_end(outputs)
-        else:
-            val_results = model.validation_end(outputs)
+        # give model a chance to do something with the outputs (and method defined)
+        val_results = {}
+        if self.__is_overriden('validation_end'):
+            if self.data_parallel:
+                val_results = model.module.validation_end(outputs)
+            else:
+                val_results = model.validation_end(outputs)

         # enable train mode again
         model.train()
```
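Put together, `validate()` now degrades gracefully at two points: it bails out before any forward passes when `validation_step` is absent, and it skips aggregation when `validation_end` is absent. A condensed sketch of that control flow, written as a free function with an `is_overriden` callable standing in for the name-mangled private method; GPU placement, progress bars, and DataParallel unwrapping are omitted:

```python
import torch


def validate(model, dataloader, max_batches, is_overriden):
    # no validation_step: nothing to run at all
    if not is_overriden('validation_step'):
        return {}

    model.zero_grad()
    model.eval()

    outputs = []
    with torch.no_grad():
        for batch_nb, batch in enumerate(dataloader):
            if batch_nb >= max_batches:
                break
            outputs.append(model.validation_step(batch, batch_nb))

    # the aggregation hook is optional too
    val_results = {}
    if is_overriden('validation_end'):
        val_results = model.validation_end(outputs)

    model.train()
    return val_results
```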
```diff
@@ -439,6 +454,7 @@ def get_dataloaders(self, model):
         :return:
         """
         self.tng_dataloader = model.tng_dataloader
+
         self.test_dataloader = model.test_dataloader
         self.val_dataloader = model.val_dataloader

```
**pytorch_lightning/root_module/root_module.py** (16 changes: 9 additions & 7 deletions)

```diff
@@ -36,18 +36,20 @@ def forward(self, *args, **kwargs):
     def validation_step(self, data_batch, batch_nb):
         """
         return whatever outputs will need to be aggregated in validation_end
+        OPTIONAL
         :param data_batch:
         :return:
         """
-        raise NotImplementedError
+        pass

     def validation_end(self, outputs):
         """
         Outputs has the appended output after each validation step
+        OPTIONAL
         :param outputs:
         :return: dic_with_metrics for tqdm
         """
-        raise NotImplementedError
+        pass

     def training_step(self, data_batch, batch_nb):
         """
```
```diff
@@ -67,26 +69,26 @@ def configure_optimizers(self):
     @data_loader
     def tng_dataloader(self):
         """
-        Implement a function to load an h5py of this data
+        Implement a PyTorch DataLoader
         :return:
         """
         raise NotImplementedError

     @data_loader
     def test_dataloader(self):
         """
-        Implement a function to load an h5py of this data
+        Implement a PyTorch DataLoader
         :return:
         """
-        raise NotImplementedError
+        return None

     @data_loader
     def val_dataloader(self):
         """
-        Implement a function to load an h5py of this data
+        Implement a PyTorch DataLoader
         :return:
         """
-        raise NotImplementedError
+        return None

     @classmethod
     def load_from_metrics(cls, weights_path, tags_csv, on_gpu, map_location=None):
```
**pytorch_lightning/testing/__init__.py** (3 changes: 3 additions & 0 deletions)

```diff
@@ -0,0 +1,3 @@
+from .lm_test_module import LightningTestModel
+from .no_val_end_module import NoValEndTestModel
+from .no_val_module import NoValModel
```
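The fixture modules' contents are not shown in this diff. From the names and the trainer changes above, presumably `NoValModel` defines no validation hooks at all (exercising the early return in `validate()`) while `NoValEndTestModel` defines `validation_step` but not `validation_end` (exercising the empty `val_results` fallback). A hypothetical, self-contained reconstruction of the latter shape, with a synthetic dataset instead of the repo's actual fixture code:

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


# hypothetical stand-in for NoValEndTestModel: validation_step is overridden,
# validation_end deliberately is not.
class NoValEndModel(pl.LightningModule):
    def __init__(self):
        super(NoValEndModel, self).__init__()
        self.l1 = torch.nn.Linear(8, 2)

    def forward(self, x):
        return self.l1(x)

    def training_step(self, batch, batch_nb):
        x, y = batch
        return {'loss': F.cross_entropy(self.forward(x), y)}

    def validation_step(self, batch, batch_nb):
        x, y = batch
        return {'val_loss': F.cross_entropy(self.forward(x), y)}

    # no validation_end: the trainer's __is_overriden check sees the
    # base-class code object and leaves val_results as {}

    def configure_optimizers(self):
        return [torch.optim.Adam(self.parameters(), lr=0.02)]

    @pl.data_loader
    def tng_dataloader(self):
        ds = TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,)))
        return DataLoader(ds, batch_size=16)

    @pl.data_loader
    def val_dataloader(self):
        ds = TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,)))
        return DataLoader(ds, batch_size=16)
```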