Skip to content

Commit

Permalink
updated docs
Browse files Browse the repository at this point in the history
  • Loading branch information
williamFalcon committed Jul 25, 2019
1 parent 4562580 commit b0d38d5
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 16 deletions.
27 changes: 13 additions & 14 deletions pytorch_lightning/root_module/root_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@

class LightningModule(GradInformation, ModelIO, ModelHooks):

def __init__(self, hparams):
def __init__(self):
super(LightningModule, self).__init__()
self.hparams = hparams

self.dtype = torch.FloatTensor
self.exp_save_path = None
Expand Down Expand Up @@ -64,18 +63,6 @@ def configure_optimizers(self):
"""
raise NotImplementedError

def summarize(self):
model_summary = ModelSummary(self)
print(model_summary)

def freeze(self):
for param in self.parameters():
param.requires_grad = False

def unfreeze(self):
for param in self.parameters():
param.requires_grad = True

@data_loader
def tng_dataloader(self):
"""
Expand Down Expand Up @@ -128,5 +115,17 @@ def load_from_metrics(cls, weights_path, tags_csv, on_gpu, map_location=None):
model.load_state_dict(checkpoint['state_dict'], strict=False)
return model

def summarize(self):
model_summary = ModelSummary(self)
print(model_summary)

def freeze(self):
for param in self.parameters():
param.requires_grad = False

def unfreeze(self):
for param in self.parameters():
param.requires_grad = True



53 changes: 51 additions & 2 deletions tests/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,55 @@
import shutil
import pdb

import pytorch_lightning as ptl
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST


class CoolModel(ptl.LightningModule):

def __init(self):
super(CoolModel, self).__init__()
# not the best model...
self.l1 = torch.nn.Linear(28 * 28, 10)

def forward(self, x):
return torch.relu(self.l1(x))

def my_loss(self, y_hat, y):
return F.cross_entropy(y_hat, y)

def training_step(self, batch, batch_nb):
x, y = batch
y_hat = self.forward(x)
return {'tng_loss': self.my_loss(y_hat, y)}

def validation_step(self, batch, batch_nb):
x, y = batch
y_hat = self.forward(x)
return {'val_loss': self.my_loss(y_hat, y)}

def validation_end(self, outputs):
avg_loss = torch.stack([x for x in outputs['val_loss']]).mean()
return avg_loss

def configure_optimizers(self):
return [torch.optim.Adam(self.parameters(), lr=0.02)]

@ptl.data_loader
def tng_dataloader(self):
return DataLoader(MNIST('path/to/save', train=True), batch_size=32)

@ptl.data_loader
def val_dataloader(self):
return DataLoader(MNIST('path/to/save', train=False), batch_size=32)

@ptl.data_loader
def test_dataloader(self):
return DataLoader(MNIST('path/to/save', train=False), batch_size=32)


def get_model():
# set up model with these hyperparams
Expand Down Expand Up @@ -94,11 +143,9 @@ def run_prediction(dataloader, trained_model):
def main():

save_dir = init_save_dir()
model, hparams = get_model()

# exp file to get meta
exp = get_exp(False)
exp.argparse(hparams)
exp.save()

# exp file to get weights
Expand All @@ -113,6 +160,8 @@ def main():
distributed_backend='dp',
)

model = CoolModel()

result = trainer.fit(model)

# correct result and ok accuracy
Expand Down

0 comments on commit b0d38d5

Please sign in to comment.