Skip to content

Commit

Permalink
updated docs
Browse files Browse the repository at this point in the history
  • Loading branch information
williamFalcon committed Jul 25, 2019
1 parent 600c755 commit d272f29
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 14 deletions.
18 changes: 12 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,27 +42,34 @@ To use lightning do 2 things:
```python
import pytorch_lightning as ptl
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

class CoolModel(ptl.LightningModule):

def __init(self):
# not the best model...
self.l1 = torch.nn.Linear(28*28, 10)

def forward(self, x):
return self.l1(x)
return torch.relu(self.l1(x))

def my_loss(self, y_hat, y):
return F.cross_entropy(y_hat, y)

def training_step(self, batch, batch_nb):
x, y = batch
y_hat = self.forward(x)
return {'tng_loss': some_loss(y_hat, y)}
return {'tng_loss': self.my_loss(y_hat, y)}

def validation_step(self, batch, batch_nb):
x, y = batch
y_hat = self.forward(x)
return {'val_loss': some_loss(y_hat, y)}
return {'val_loss': self.my_loss(y_hat, y)}

def configure_optimizers(self):
return [optim.Adam(self.parameters(), lr=0.02)]
return [torch.optim.Adam(self.parameters(), lr=0.02)]

@ptl.data_loader
def tng_dataloader(self):
Expand All @@ -74,8 +81,7 @@ class CoolModel(ptl.LightningModule):

@ptl.data_loader
def test_dataloader(self):
return DataLoader(MNIST('path/to/save', train=False), batch_size=32)

return DataLoader(MNIST('path/to/save', train=False), batch_size=32)
```

2. Fit with a [trainer](https://williamfalcon.github.io/pytorch-lightning/Trainer/)
Expand Down
47 changes: 47 additions & 0 deletions docs/LightningModule/RequiredTrainerInterface.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,53 @@ Otherwise, to Define a Lightning Module, implement the following methods:
- [update_tng_log_metrics](RequiredTrainerInterface.md#update_tng_log_metrics)
- [add_model_specific_args](RequiredTrainerInterface.md#add_model_specific_args)

---
**Minimal example**
```python
import pytorch_lightning as ptl
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

class CoolModel(ptl.LightningModule):

def __init(self):
# not the best model...
self.l1 = torch.nn.Linear(28*28, 10)

def forward(self, x):
return torch.relu(self.l1(x))

def my_loss(self, y_hat, y):
return F.cross_entropy(y_hat, y)

def training_step(self, batch, batch_nb):
x, y = batch
y_hat = self.forward(x)
return {'tng_loss': self.my_loss(y_hat, y)}

def validation_step(self, batch, batch_nb):
x, y = batch
y_hat = self.forward(x)
return {'val_loss': self.my_loss(y_hat, y)}

def configure_optimizers(self):
return [torch.optim.Adam(self.parameters(), lr=0.02)]

@ptl.data_loader
def tng_dataloader(self):
return DataLoader(MNIST('path/to/save', train=True), batch_size=32)

@ptl.data_loader
def val_dataloader(self):
return DataLoader(MNIST('path/to/save', train=False), batch_size=32)

@ptl.data_loader
def test_dataloader(self):
return DataLoader(MNIST('path/to/save', train=False), batch_size=32)
```

---

### training_step
Expand Down
8 changes: 0 additions & 8 deletions pytorch_lightning/root_module/root_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,6 @@ def configure_optimizers(self):
"""
raise NotImplementedError

def loss(self, *args, **kwargs):
"""
Expand model_out into your components
:param model_out:
:return:
"""
raise NotImplementedError

def summarize(self):
model_summary = ModelSummary(self)
print(model_summary)
Expand Down

0 comments on commit d272f29

Please sign in to comment.