Updates for lightning 0.10.0 (#264)
* 🎨 refactor to use self.log instead of results obj

* 📌 pin reqs

* 🐛 use functional accuracy

* 🐛 remove result

* 🐛 clean hparams up to fix cpc
nateraw committed Oct 12, 2020
1 parent 7b1d895 commit a8e3fb5
Showing 9 changed files with 48 additions and 69 deletions.
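The recurring change in this commit: the deprecated pl.TrainResult / pl.EvalResult objects are replaced by self.log / self.log_dict calls inside each step method, and each step now returns the loss tensor directly. Below is a minimal sketch of that pattern, not code from the repo (the module, metric keys, and callbacks are illustrative); since checkpoint_on / early_stop_on are gone, checkpointing and early stopping monitor a logged key via callbacks instead:

```python
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping


class LitExample(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self.layer(x), y)
        # old: result = pl.TrainResult(minimize=loss); result.log('train_loss', loss)
        self.log('train_loss', loss, on_step=True, on_epoch=False)
        return loss  # the returned tensor is what gets minimized

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self.layer(x), y)
        # old: result = pl.EvalResult(checkpoint_on=loss, early_stop_on=loss)
        self.log('val_loss', loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


# checkpoint_on / early_stop_on are replaced by callbacks that monitor a logged key
checkpoint_cb = ModelCheckpoint(monitor='val_loss')
early_stop_cb = EarlyStopping(monitor='val_loss')
```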
10 changes: 4 additions & 6 deletions pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py
@@ -106,17 +106,15 @@ def step(self, batch, batch_idx):
 
     def training_step(self, batch, batch_idx):
         loss, logs = self.step(batch, batch_idx)
-        result = pl.TrainResult(minimize=loss)
-        result.log_dict(
+        self.log_dict(
             {f"train_{k}": v for k, v in logs.items()}, on_step=True, on_epoch=False
         )
-        return result
+        return loss
 
     def validation_step(self, batch, batch_idx):
         loss, logs = self.step(batch, batch_idx)
-        result = pl.EvalResult(checkpoint_on=loss)
-        result.log_dict({f"val_{k}": v for k, v in logs.items()})
-        return result
+        self.log_dict({f"val_{k}": v for k, v in logs.items()})
+        return loss
 
     def configure_optimizers(self):
         return torch.optim.Adam(self.parameters(), lr=self.lr)
10 changes: 4 additions & 6 deletions pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py
@@ -139,17 +139,15 @@ def step(self, batch, batch_idx):
 
     def training_step(self, batch, batch_idx):
         loss, logs = self.step(batch, batch_idx)
-        result = pl.TrainResult(minimize=loss)
-        result.log_dict(
+        self.log_dict(
             {f"train_{k}": v for k, v in logs.items()}, on_step=True, on_epoch=False
         )
-        return result
+        return loss
 
     def validation_step(self, batch, batch_idx):
         loss, logs = self.step(batch, batch_idx)
-        result = pl.EvalResult(checkpoint_on=loss)
-        result.log_dict({f"val_{k}": v for k, v in logs.items()})
-        return result
+        self.log_dict({f"val_{k}": v for k, v in logs.items()})
+        return loss
 
     def configure_optimizers(self):
         return torch.optim.Adam(self.parameters(), lr=self.lr)
10 changes: 4 additions & 6 deletions pl_bolts/models/gans/basic/basic_gan_module.py
@@ -136,18 +136,16 @@ def generator_step(self, x):
 
         # log to prog bar on each step AND for the full epoch
         # use the generator loss for checkpointing
-        result = pl.TrainResult(minimize=g_loss, checkpoint_on=g_loss)
-        result.log('g_loss', g_loss, on_epoch=True, prog_bar=True)
-        return result
+        self.log('g_loss', g_loss, on_epoch=True, prog_bar=True)
+        return g_loss
 
     def discriminator_step(self, x):
         # Measure discriminator's ability to classify real from generated samples
         d_loss = self.discriminator_loss(x)
 
         # log to prog bar on each step AND for the full epoch
-        result = pl.TrainResult(minimize=d_loss)
-        result.log('d_loss', d_loss, on_epoch=True, prog_bar=True)
-        return result
+        self.log('d_loss', d_loss, on_epoch=True, prog_bar=True)
+        return d_loss
 
     def configure_optimizers(self):
         lr = self.hparams.learning_rate
2 changes: 1 addition & 1 deletion pl_bolts/models/regression/logistic_regression.py
@@ -2,7 +2,7 @@
 
 import pytorch_lightning as pl
 import torch
-from pytorch_lightning.metrics.classification import accuracy
+from pytorch_lightning.metrics.functional import accuracy
 from torch import nn
 from torch.nn import functional as F
 from torch.optim import Adam
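The logistic regression model now imports accuracy from the functional metrics namespace. A hedged sketch of the call this enables, assuming the pytorch_lightning.metrics.functional API shipped around 0.10 (these metrics later moved to the separate torchmetrics package):

```python
import torch
from pytorch_lightning.metrics.functional import accuracy

logits = torch.randn(8, 3)           # toy batch: 8 samples, 3 classes
target = torch.randint(0, 3, (8,))   # toy integer labels
acc = accuracy(torch.argmax(logits, dim=1), target)  # plain tensors in, tensor out
```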
10 changes: 4 additions & 6 deletions pl_bolts/models/self_supervised/byol/byol_module.py
@@ -136,19 +136,17 @@ def training_step(self, batch, batch_idx):
         loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx)
 
         # log results
-        result = pl.TrainResult(minimize=total_loss)
-        result.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, 'train_loss': total_loss})
+        self.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, 'train_loss': total_loss})
 
-        return result
+        return total_loss
 
     def validation_step(self, batch, batch_idx):
         loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx)
 
         # log results
-        result = pl.EvalResult(early_stop_on=total_loss, checkpoint_on=total_loss)
-        result.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, 'train_loss': total_loss})
+        self.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, 'train_loss': total_loss})
 
-        return result
+        return total_loss
 
     def configure_optimizers(self):
         optimizer = Adam(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)
48 changes: 20 additions & 28 deletions pl_bolts/models/self_supervised/cpc/cpc_module.py
@@ -35,7 +35,7 @@ class CPCV2(pl.LightningModule):
     def __init__(
         self,
         datamodule: pl.LightningDataModule = None,
-        encoder: Union[str, torch.nn.Module, pl.LightningModule] = 'cpc_encoder',
+        encoder_name: str = 'cpc_encoder',
         patch_size: int = 8,
         patch_overlap: int = 4,
         online_ft: int = True,
@@ -50,7 +50,7 @@ def __init__(
         """
         Args:
             datamodule: A Datamodule (optional). Otherwise set the dataloaders directly
-            encoder: A string for any of the resnets in torchvision, or the original CPC encoder,
+            encoder_name: A string for any of the resnets in torchvision, or the original CPC encoder,
                 or a custon nn.Module encoder
             patch_size: How big to make the image patches
             patch_overlap: How much overlap should each patch have.
@@ -66,28 +66,20 @@ def __init__(
         super().__init__()
         self.save_hyperparameters()
 
+        # HACK - datamodule not pickleable so we remove it from hparams.
+        # TODO - remove datamodule from init. data should be decoupled from models.
+        del self.hparams['datamodule']
+
         self.online_evaluator = self.hparams.online_ft
 
         if pretrained:
             self.hparams.dataset = pretrained
             self.online_evaluator = True
 
-        # link data
-        # if datamodule is None:
-        #     datamodule = CIFAR10DataModule(
-        #         self.hparams.data_dir,
-        #         num_workers=self.hparams.num_workers,
-        #         batch_size=batch_size
-        #     )
-        #     datamodule.train_transforms = CPCTrainTransformsCIFAR10()
-        #     datamodule.val_transforms = CPCEvalTransformsCIFAR10()
         assert datamodule
         self.datamodule = datamodule
 
         # init encoder
-        self.encoder = encoder
-        if isinstance(encoder, str):
-            self.encoder = self.init_encoder()
+        self.encoder = self.init_encoder()
 
         # info nce loss
         c, h = self.__compute_final_nb_c(self.hparams.patch_size)
@@ -97,20 +89,22 @@ def __init__(
             self.num_classes = self.datamodule.num_classes
 
         if pretrained:
-            self.load_pretrained(encoder)
+            self.load_pretrained(self.hparams.encoder_name)
+
+        print(self.hparams)
 
-    def load_pretrained(self, encoder):
+    def load_pretrained(self, encoder_name):
         available_weights = {'resnet18'}
 
-        if encoder in available_weights:
-            load_pretrained(self, f'CPCV2-{encoder}')
-        elif available_weights not in available_weights:
-            rank_zero_warn(f'{encoder} not yet available')
+        if encoder_name in available_weights:
+            load_pretrained(self, f'CPCV2-{encoder_name}')
+        elif encoder_name not in available_weights:
+            rank_zero_warn(f'{encoder_name} not yet available')
 
     def init_encoder(self):
         dummy_batch = torch.zeros((2, 3, self.hparams.patch_size, self.hparams.patch_size))
 
-        encoder_name = self.hparams.encoder
+        encoder_name = self.hparams.encoder_name
         if encoder_name == 'cpc_encoder':
             return cpc_resnet101(dummy_batch)
         else:
@@ -160,18 +154,16 @@ def training_step(self, batch, batch_nb):
         nce_loss = self.shared_step(batch)
 
         # result
-        result = pl.TrainResult(nce_loss)
-        result.log('train_nce_loss', nce_loss)
-        return result
+        self.log('train_nce_loss', nce_loss)
+        return nce_loss
 
     def validation_step(self, batch, batch_nb):
         # calculate loss
         nce_loss = self.shared_step(batch)
 
         # result
-        result = pl.EvalResult(checkpoint_on=nce_loss)
-        result.log('val_nce', nce_loss, prog_bar=True)
-        return result
+        self.log('val_nce', nce_loss, prog_bar=True)
+        return nce_loss
 
     def shared_step(self, batch):
         try:
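The "clean hparams up" part of the commit is visible above: save_hyperparameters() captures every __init__ argument, so the non-picklable datamodule is deleted from hparams and the encoder is now selected by a plain string (encoder_name) rather than an nn.Module. A short sketch of that pattern with hypothetical names, not code from the repo:

```python
import torch
import pytorch_lightning as pl


class LitWithCleanHparams(pl.LightningModule):
    # Illustrative only: shows the hparams-cleanup pattern used by CPCV2 above.

    def __init__(self, datamodule=None, encoder_name='cpc_encoder', hidden_dim=64):
        super().__init__()
        self.save_hyperparameters()      # records datamodule, encoder_name, hidden_dim
        del self.hparams['datamodule']   # keep the non-picklable object out of hparams
        self.datamodule = datamodule     # still usable as a regular attribute
        self.encoder = self.init_encoder()

    def init_encoder(self):
        # a real model would branch on the string name, as CPCV2.init_encoder does
        if self.hparams.encoder_name == 'cpc_encoder':
            return torch.nn.Linear(32, self.hparams.hidden_dim)
        return torch.nn.Identity()
```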
10 changes: 4 additions & 6 deletions pl_bolts/models/self_supervised/simclr/simclr_module.py
@@ -157,16 +157,14 @@ def forward(self, x):
     def training_step(self, batch, batch_idx):
         loss = self.shared_step(batch, batch_idx)
 
-        result = pl.TrainResult(minimize=loss)
-        result.log('train_loss', loss, on_epoch=True)
-        return result
+        self.log('train_loss', loss, on_epoch=True)
+        return loss
 
     def validation_step(self, batch, batch_idx):
         loss = self.shared_step(batch, batch_idx)
 
-        result = pl.EvalResult(checkpoint_on=loss)
-        result.log('avg_val_loss', loss)
-        return result
+        self.log('avg_val_loss', loss)
+        return loss
 
     def shared_step(self, batch, batch_idx):
         (img1, img2), y = batch
15 changes: 6 additions & 9 deletions pl_bolts/models/self_supervised/ssl_finetuner.py
@@ -59,21 +59,18 @@ def on_train_epoch_start(self) -> None:
 
     def training_step(self, batch, batch_idx):
         loss, acc = self.shared_step(batch)
-        result = pl.TrainResult(loss)
-        result.log('train_acc', acc, prog_bar=True)
-        return result
+        self.log('train_acc', acc, prog_bar=True)
+        return loss
 
     def validation_step(self, batch, batch_idx):
         loss, acc = self.shared_step(batch)
-        result = pl.EvalResult(checkpoint_on=loss, early_stop_on=loss)
-        result.log_dict({'val_acc': acc, 'val_loss': loss}, prog_bar=True)
-        return result
+        self.log_dict({'val_acc': acc, 'val_loss': loss}, prog_bar=True)
+        return loss
 
     def test_step(self, batch, batch_idx):
         loss, acc = self.shared_step(batch)
-        result = pl.EvalResult()
-        result.log_dict({'test_acc': acc, 'test_loss': loss})
-        return result
+        self.log_dict({'test_acc': acc, 'test_loss': loss})
+        return loss
 
     def shared_step(self, batch):
         x, y = batch
2 changes: 1 addition & 1 deletion requirements/base.txt
@@ -1,2 +1,2 @@
-pytorch-lightning>=0.9.1rc3
+pytorch-lightning>=0.10.0
 torch>=1.6
