Skip to content

Commit

Permalink
updated docs
Browse files Browse the repository at this point in the history
  • Loading branch information
williamFalcon committed Jul 25, 2019
1 parent 9b99a02 commit e182559
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 20 deletions.
21 changes: 11 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,43 +49,44 @@ from torchvision.datasets import MNIST
class CoolModel(ptl.LightningModule):

def __init(self):
super(CoolModel, self).__init__()
# not the best model...
self.l1 = torch.nn.Linear(28*28, 10)
self.l1 = torch.nn.Linear(28 * 28, 10)

def forward(self, x):
return torch.relu(self.l1(x))

def my_loss(self, y_hat, y):
return F.cross_entropy(y_hat, y)

def training_step(self, batch, batch_nb):
x, y = batch
y_hat = self.forward(x)
return {'tng_loss': self.my_loss(y_hat, y)}

def validation_step(self, batch, batch_nb):
x, y = batch
y_hat = self.forward(x)
return {'val_loss': self.my_loss(y_hat, y)}

def validation_end(self, outputs):
avg_loss = torch.stack([x for x in outputs['val_loss']]).mean()
return avg_loss

def configure_optimizers(self):
return [torch.optim.Adam(self.parameters(), lr=0.02)]

@ptl.data_loader
def tng_dataloader(self):
return DataLoader(MNIST('path/to/save', train=True), batch_size=32)

@ptl.data_loader
def val_dataloader(self):
return DataLoader(MNIST('path/to/save', train=False), batch_size=32)

@ptl.data_loader
def test_dataloader(self):
return DataLoader(MNIST('path/to/save', train=False), batch_size=32)
return DataLoader(MNIST('path/to/save', train=False), batch_size=32)
```

2. Fit with a [trainer](https://williamfalcon.github.io/pytorch-lightning/Trainer/)
Expand Down
21 changes: 11 additions & 10 deletions docs/LightningModule/RequiredTrainerInterface.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,43 +38,44 @@ from torchvision.datasets import MNIST
class CoolModel(ptl.LightningModule):

def __init(self):
super(CoolModel, self).__init__()
# not the best model...
self.l1 = torch.nn.Linear(28*28, 10)
self.l1 = torch.nn.Linear(28 * 28, 10)

def forward(self, x):
return torch.relu(self.l1(x))

def my_loss(self, y_hat, y):
return F.cross_entropy(y_hat, y)

def training_step(self, batch, batch_nb):
x, y = batch
y_hat = self.forward(x)
return {'tng_loss': self.my_loss(y_hat, y)}

def validation_step(self, batch, batch_nb):
x, y = batch
y_hat = self.forward(x)
return {'val_loss': self.my_loss(y_hat, y)}

def validation_end(self, outputs):
avg_loss = torch.stack([x for x in outputs['val_loss']]).mean()
return avg_loss

def configure_optimizers(self):
return [torch.optim.Adam(self.parameters(), lr=0.02)]

@ptl.data_loader
def tng_dataloader(self):
return DataLoader(MNIST('path/to/save', train=True), batch_size=32)

@ptl.data_loader
def val_dataloader(self):
return DataLoader(MNIST('path/to/save', train=False), batch_size=32)

@ptl.data_loader
def test_dataloader(self):
return DataLoader(MNIST('path/to/save', train=False), batch_size=32)
return DataLoader(MNIST('path/to/save', train=False), batch_size=32)
```

---
Expand Down

0 comments on commit e182559

Please sign in to comment.