Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates for lightning 0.10.0 #264

Merged
merged 5 commits into from Oct 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 4 additions & 6 deletions pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py
Expand Up @@ -106,17 +106,15 @@ def step(self, batch, batch_idx):

def training_step(self, batch, batch_idx):
loss, logs = self.step(batch, batch_idx)
result = pl.TrainResult(minimize=loss)
result.log_dict(
self.log_dict(
{f"train_{k}": v for k, v in logs.items()}, on_step=True, on_epoch=False
)
return result
return loss

def validation_step(self, batch, batch_idx):
loss, logs = self.step(batch, batch_idx)
result = pl.EvalResult(checkpoint_on=loss)
result.log_dict({f"val_{k}": v for k, v in logs.items()})
return result
self.log_dict({f"val_{k}": v for k, v in logs.items()})
return loss

def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=self.lr)
Expand Down
10 changes: 4 additions & 6 deletions pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py
Expand Up @@ -139,17 +139,15 @@ def step(self, batch, batch_idx):

def training_step(self, batch, batch_idx):
loss, logs = self.step(batch, batch_idx)
result = pl.TrainResult(minimize=loss)
result.log_dict(
self.log_dict(
{f"train_{k}": v for k, v in logs.items()}, on_step=True, on_epoch=False
)
return result
return loss

def validation_step(self, batch, batch_idx):
loss, logs = self.step(batch, batch_idx)
result = pl.EvalResult(checkpoint_on=loss)
result.log_dict({f"val_{k}": v for k, v in logs.items()})
return result
self.log_dict({f"val_{k}": v for k, v in logs.items()})
return loss

def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=self.lr)
Expand Down
10 changes: 4 additions & 6 deletions pl_bolts/models/gans/basic/basic_gan_module.py
Expand Up @@ -136,18 +136,16 @@ def generator_step(self, x):

# log to prog bar on each step AND for the full epoch
# use the generator loss for checkpointing
result = pl.TrainResult(minimize=g_loss, checkpoint_on=g_loss)
result.log('g_loss', g_loss, on_epoch=True, prog_bar=True)
return result
self.log('g_loss', g_loss, on_epoch=True, prog_bar=True)
return g_loss

def discriminator_step(self, x):
# Measure discriminator's ability to classify real from generated samples
d_loss = self.discriminator_loss(x)

# log to prog bar on each step AND for the full epoch
result = pl.TrainResult(minimize=d_loss)
result.log('d_loss', d_loss, on_epoch=True, prog_bar=True)
return result
self.log('d_loss', d_loss, on_epoch=True, prog_bar=True)
return d_loss

def configure_optimizers(self):
lr = self.hparams.learning_rate
Expand Down
2 changes: 1 addition & 1 deletion pl_bolts/models/regression/logistic_regression.py
Expand Up @@ -2,7 +2,7 @@

import pytorch_lightning as pl
import torch
from pytorch_lightning.metrics.classification import accuracy
from pytorch_lightning.metrics.functional import accuracy
from torch import nn
from torch.nn import functional as F
from torch.optim import Adam
Expand Down
10 changes: 4 additions & 6 deletions pl_bolts/models/self_supervised/byol/byol_module.py
Expand Up @@ -136,19 +136,17 @@ def training_step(self, batch, batch_idx):
loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx)

# log results
result = pl.TrainResult(minimize=total_loss)
result.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, 'train_loss': total_loss})
self.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, 'train_loss': total_loss})

return result
return total_loss

def validation_step(self, batch, batch_idx):
loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx)

# log results
result = pl.EvalResult(early_stop_on=total_loss, checkpoint_on=total_loss)
result.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, 'train_loss': total_loss})
self.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, 'train_loss': total_loss})

return result
return total_loss

def configure_optimizers(self):
optimizer = Adam(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)
Expand Down
48 changes: 20 additions & 28 deletions pl_bolts/models/self_supervised/cpc/cpc_module.py
Expand Up @@ -35,7 +35,7 @@ class CPCV2(pl.LightningModule):
def __init__(
self,
datamodule: pl.LightningDataModule = None,
encoder: Union[str, torch.nn.Module, pl.LightningModule] = 'cpc_encoder',
encoder_name: str = 'cpc_encoder',
patch_size: int = 8,
patch_overlap: int = 4,
online_ft: int = True,
Expand All @@ -50,7 +50,7 @@ def __init__(
"""
Args:
datamodule: A Datamodule (optional). Otherwise set the dataloaders directly
encoder: A string for any of the resnets in torchvision, or the original CPC encoder,
encoder_name: A string for any of the resnets in torchvision, or the original CPC encoder,
or a custon nn.Module encoder
patch_size: How big to make the image patches
patch_overlap: How much overlap should each patch have.
Expand All @@ -66,28 +66,20 @@ def __init__(
super().__init__()
self.save_hyperparameters()

# HACK - datamodule not pickleable so we remove it from hparams.
# TODO - remove datamodule from init. data should be decoupled from models.
del self.hparams['datamodule']

self.online_evaluator = self.hparams.online_ft

if pretrained:
self.hparams.dataset = pretrained
self.online_evaluator = True

# link data
# if datamodule is None:
# datamodule = CIFAR10DataModule(
# self.hparams.data_dir,
# num_workers=self.hparams.num_workers,
# batch_size=batch_size
# )
# datamodule.train_transforms = CPCTrainTransformsCIFAR10()
# datamodule.val_transforms = CPCEvalTransformsCIFAR10()
assert datamodule
self.datamodule = datamodule

# init encoder
self.encoder = encoder
if isinstance(encoder, str):
self.encoder = self.init_encoder()
self.encoder = self.init_encoder()

# info nce loss
c, h = self.__compute_final_nb_c(self.hparams.patch_size)
Expand All @@ -97,20 +89,22 @@ def __init__(
self.num_classes = self.datamodule.num_classes

if pretrained:
self.load_pretrained(encoder)
self.load_pretrained(self.hparams.encoder_name)

print(self.hparams)

def load_pretrained(self, encoder):
def load_pretrained(self, encoder_name):
available_weights = {'resnet18'}

if encoder in available_weights:
load_pretrained(self, f'CPCV2-{encoder}')
elif available_weights not in available_weights:
rank_zero_warn(f'{encoder} not yet available')
if encoder_name in available_weights:
load_pretrained(self, f'CPCV2-{encoder_name}')
elif encoder_name not in available_weights:
rank_zero_warn(f'{encoder_name} not yet available')

def init_encoder(self):
dummy_batch = torch.zeros((2, 3, self.hparams.patch_size, self.hparams.patch_size))

encoder_name = self.hparams.encoder
encoder_name = self.hparams.encoder_name
if encoder_name == 'cpc_encoder':
return cpc_resnet101(dummy_batch)
else:
Expand Down Expand Up @@ -160,18 +154,16 @@ def training_step(self, batch, batch_nb):
nce_loss = self.shared_step(batch)

# result
result = pl.TrainResult(nce_loss)
result.log('train_nce_loss', nce_loss)
return result
self.log('train_nce_loss', nce_loss)
return nce_loss

def validation_step(self, batch, batch_nb):
# calculate loss
nce_loss = self.shared_step(batch)

# result
result = pl.EvalResult(checkpoint_on=nce_loss)
result.log('val_nce', nce_loss, prog_bar=True)
return result
self.log('val_nce', nce_loss, prog_bar=True)
return nce_loss

def shared_step(self, batch):
try:
Expand Down
10 changes: 4 additions & 6 deletions pl_bolts/models/self_supervised/simclr/simclr_module.py
Expand Up @@ -157,16 +157,14 @@ def forward(self, x):
def training_step(self, batch, batch_idx):
loss = self.shared_step(batch, batch_idx)

result = pl.TrainResult(minimize=loss)
result.log('train_loss', loss, on_epoch=True)
return result
self.log('train_loss', loss, on_epoch=True)
return loss

def validation_step(self, batch, batch_idx):
loss = self.shared_step(batch, batch_idx)

result = pl.EvalResult(checkpoint_on=loss)
result.log('avg_val_loss', loss)
return result
self.log('avg_val_loss', loss)
return loss

def shared_step(self, batch, batch_idx):
(img1, img2), y = batch
Expand Down
15 changes: 6 additions & 9 deletions pl_bolts/models/self_supervised/ssl_finetuner.py
Expand Up @@ -59,21 +59,18 @@ def on_train_epoch_start(self) -> None:

def training_step(self, batch, batch_idx):
loss, acc = self.shared_step(batch)
result = pl.TrainResult(loss)
result.log('train_acc', acc, prog_bar=True)
return result
self.log('train_acc', acc, prog_bar=True)
return loss

def validation_step(self, batch, batch_idx):
loss, acc = self.shared_step(batch)
result = pl.EvalResult(checkpoint_on=loss, early_stop_on=loss)
result.log_dict({'val_acc': acc, 'val_loss': loss}, prog_bar=True)
return result
self.log_dict({'val_acc': acc, 'val_loss': loss}, prog_bar=True)
return loss

def test_step(self, batch, batch_idx):
loss, acc = self.shared_step(batch)
result = pl.EvalResult()
result.log_dict({'test_acc': acc, 'test_loss': loss})
return result
self.log_dict({'test_acc': acc, 'test_loss': loss})
return loss

def shared_step(self, batch):
x, y = batch
Expand Down
2 changes: 1 addition & 1 deletion requirements/base.txt
@@ -1,2 +1,2 @@
pytorch-lightning>=0.9.1rc3
pytorch-lightning>=0.10.0
torch>=1.6