Skip to content

Commit

Permalink
cleaned readme
Browse files Browse the repository at this point in the history
  • Loading branch information
williamFalcon committed Jul 25, 2019
1 parent b989358 commit 74817c2
Showing 1 changed file with 59 additions and 2 deletions.
61 changes: 59 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,65 @@ gpu training, etc... every time you start a project. Let lightning handle all of
data and what happens in the training, testing and validation loop and lightning will do the rest.

To use lightning do 2 things:
1. [Define a Trainer](https://github.com/williamFalcon/pytorch-lightning/blob/master/examples/new_project_templates/trainer_cpu_template.py).
2. [Define a LightningModel](https://github.com/williamFalcon/pytorch-lightning/blob/master/examples/new_project_templates/lightning_module_template.py).
1. [Define a LightningModel](https://williamfalcon.github.io/pytorch-lightning/LightningModule/RequiredTrainerInterface/)
```python
from pytorch_lightning import LightningModule
import torch

class CoolModel(LightningModule):

def __init(self):
self.l1 = torch.nn.Linear(28*28, 10)

def forward(self, x):
return self.l1(x)

def training_step(self, batch, batch_nb):
x, y = batch
y_hat = self.forward(x)
return {'tng_loss': some_loss(y_hat, y)}

def validation_step(self, batch, batch_nb):
x, y = batch
y_hat = self.forward(x)
return {'val_loss': some_loss(y_hat, y)}

def configure_optimizers(self):
return [optim.Adam(self.parameters(), lr=0.02)]

@property
def tng_dataloader(self):
mnist = MNIST('path/to/save', train=True)
return DataLoader(mnist, batch_size=32)

@property
def val_dataloader(self):
mnist = MNIST('path/to/save', train=False)
return DataLoader(mnist, batch_size=32)

@property
def test_dataloader(self):
mnist = MNIST('sam/as/val/for/simplicity', train=False)
return DataLoader(mnist, batch_size=32)
```

2. Fit with a [trainer](https://williamfalcon.github.io/pytorch-lightning/Trainer/)
```python
from pytorch_lightning import Trainer
from test_tube import Experiment

model = CoolModel()

# fit on 32 gpus across 4 nodes
exp = Experiment(save_dir='some/dir')
trainer = Trainer(experiment=exp, nb_gpu_nodes=4, gpus=[0,1,2,3,4,5,6,7])

trainer.fit(model)

# see all experiment metrics here
# tensorboard --log_dir some/dir
```


## What does lightning control for me?
Everything!
Expand Down

0 comments on commit 74817c2

Please sign in to comment.