Commit
simclr (#32)
* added mixed

* added mixed

* added mixed

* added mixed

* added mixed

* added mixed

* added mixed

* finished simclr

* added mixed

* added mixed

* added mixed

* added moco

* added moco

* added moco

* added moco

* added moco

* added moco

* added train step return

* added mixed

* added mixed

* added mixed

* added mixed

* added moco

* added moco

* added moco

* added moco

* added moco

* added moco

* added train step return

* times tests

Co-authored-by: Jirka <jirka@pytorchlightning.ai>
williamFalcon and Borda committed Jun 4, 2020
1 parent 3879c3a commit b7ab0ef
Showing 9 changed files with 126 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
@@ -19,7 +19,7 @@ references:
      name: Testing
      command: |
        python --version ; pip --version ; pip list
-         coverage run --source pl_bolts -m py.test pl_bolts tests -v --doctest-modules --junitxml=test-reports/pytest_junit.xml
+         coverage run --source pl_bolts -m py.test pl_bolts tests -v --junitxml=test-reports/pytest_junit.xml
        coverage report
        codecov
      no_output_timeout: 30m
2 changes: 1 addition & 1 deletion .github/workflows/ci-testing.yml
@@ -92,7 +92,7 @@ jobs:
      # TOXENV: py${{ matrix.python-version }}
      run: |
        # tox --sitepackages
-         coverage run --source pl_bolts -m py.test pl_bolts tests -v --doctest-modules --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml
+         coverage run --source pl_bolts -m py.test pl_bolts tests -v --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml
    - name: Upload pytest test results
      uses: actions/upload-artifact@master
2 changes: 1 addition & 1 deletion .run_local_tests.sh
@@ -8,5 +8,5 @@ rm -rf ./tests/cometruns*
rm -rf ./tests/wandb*
rm -rf ./tests/tests/*
rm -rf ./lightning_logs
- python -m coverage run --source pytorch_lightning_bolts -m py.test pytorch_lightning_bolts tests -v --doctest-modules --flake8
+ python -m coverage run --source pytorch_lightning_bolts -m py.test pytorch_lightning_bolts tests -v --flake8
python -m coverage report -m
22 changes: 22 additions & 0 deletions pl_bolts/datamodules/submit.py
@@ -0,0 +1,22 @@
+ import os
+ from argparse import ArgumentParser
+
+
+ def submit(master_address, master_port, world_size, node_rank, local_rank):
+     os.environ['MASTER_ADDR'] = str(master_address)
+     os.environ['MASTER_PORT'] = str(master_port)
+     os.environ['NODE_RANK'] = str(node_rank)
+     os.environ['WORLD_SIZE'] = str(world_size)
+     os.environ['LOCAL_RANK'] = str(local_rank)
+
+
+ if __name__ == '__main__':
+     parser = ArgumentParser()
+     parser.add_argument('master_address', default=str)
+     parser.add_argument('master_port', default=str)
+     parser.add_argument('node_rank', default=str)
+     parser.add_argument('world_size', default=str)
+     parser.add_argument('local_rank', default=str)
+
+
+     # grid train main.py --local --world_size 16 --local_gpus '0,1,2,3' --node_rank 0

Review thread on this file:
@Borda (Member), Jun 4, 2020: can we use pl_bolts/submit.py instead? :]
@williamFalcon (Contributor), Jun 4, 2020: this isn't supposed to be there.
Again, this is in heavy dev. Please stop reviewing PRs because I'm not making any effort with the PRs.
I just need the commit hashes so I can train.
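As committed, the __main__ block builds a parser but never calls parse_args() or submit(), and default=str reads like it was meant to be type=str. A minimal sketch of how such a launcher is typically completed (hypothetical, not part of this commit); the closing comment assumes torch.distributed's env:// rendezvous, which reads MASTER_ADDR, MASTER_PORT, WORLD_SIZE and RANK from the environment:

import os
from argparse import ArgumentParser


def submit(master_address, master_port, world_size, node_rank, local_rank):
    # export the rendezvous info; torch.distributed (init_method='env://') and
    # PyTorch Lightning read these variables when setting up multi-node training
    os.environ['MASTER_ADDR'] = str(master_address)
    os.environ['MASTER_PORT'] = str(master_port)
    os.environ['NODE_RANK'] = str(node_rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['LOCAL_RANK'] = str(local_rank)


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('master_address', type=str)
    parser.add_argument('master_port', type=str)
    parser.add_argument('node_rank', type=str)
    parser.add_argument('world_size', type=str)
    parser.add_argument('local_rank', type=str)
    args = parser.parse_args()

    # set the env vars, then hand off to the training script; a global RANK would
    # still need to be derived (e.g. node_rank * gpus_per_node + local_rank) before
    # calling torch.distributed.init_process_group('nccl', init_method='env://')
    submit(args.master_address, args.master_port, args.world_size,
           args.node_rank, args.local_rank)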
2 changes: 1 addition & 1 deletion pl_bolts/models/self_supervised/amdim/amdim_module.py
@@ -72,7 +72,7 @@ def training_step(self, batch, batch_nb):

        return result

-     def training_end(self, outputs):
+     def training_epoch_end(self, outputs):
        r1_x1 = outputs['r1_x1']
        r5_x1 = outputs['r5_x1']
        r7_x1 = outputs['r7_x1']
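The rename above moves AMDIM onto the epoch-level hook. A minimal sketch of that hook's shape in pytorch_lightning of this era (illustrative only, not the AMDIM code): training_epoch_end receives what each training_step returned over the epoch.

import torch
import pytorch_lightning as pl


class HookShapeExample(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(x.float().mean(), y.float().mean())
        # anything returned here is collected and handed to training_epoch_end
        return {'loss': loss, 'batch_acc': torch.tensor(1.0)}

    def training_epoch_end(self, outputs):
        # outputs: the list of dicts returned by training_step during the epoch
        avg_acc = torch.stack([o['batch_acc'] for o in outputs]).mean()
        return {'log': {'train_acc_epoch': avg_acc}}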
5 changes: 3 additions & 2 deletions pl_bolts/models/self_supervised/cpc/cpc_module.py
@@ -113,8 +113,9 @@ def training_step(self, batch, batch_nb):
        if self.online_evaluator:
            if self.hparams.dataset == 'stl10':
                img_1, y = labeled_batch
-                 with torch.no_grad():
-                     Z = self(img_1)
+
+             with torch.no_grad():
+                 Z = self(img_1)

            # just in case... no grads into unsupervised part!
            z_in = Z.detach()
2 changes: 1 addition & 1 deletion pl_bolts/models/self_supervised/moco/moco2_module.py
@@ -352,5 +352,5 @@ def concat_all_gather(tensor):

    model = MocoV2(**args.__dict__)

-     trainer = pl.Trainer.from_argparse_args(args, fast_dev_run=True)
+     trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(model)
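With fast_dev_run no longer hard-coded, a quick smoke test has to be requested from the command line instead. A small usage sketch (assuming the Trainer argparse helpers available in pytorch_lightning at the time), not part of the diff:

from argparse import ArgumentParser

import pytorch_lightning as pl

parser = ArgumentParser()
# add_argparse_args exposes the Trainer flags, e.g. --fast_dev_run, --max_epochs, --gpus
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args([])  # on a real CLI: python moco2_module.py --fast_dev_run True ...

trainer = pl.Trainer.from_argparse_args(args)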
113 changes: 92 additions & 21 deletions pl_bolts/models/self_supervised/simclr/simclr_module.py
@@ -9,6 +9,8 @@
from pl_bolts.datamodules.ssl_imagenet_dataloaders import SSLImagenetDataLoaders
from pl_bolts.losses.self_supervised_learning import nt_xent_loss
from pl_bolts.models.self_supervised.simclr.simclr_transforms import SimCLRDataTransform
+ from pl_bolts.models.self_supervised.evaluator import SSLEvaluator
+ from pl_bolts import metrics
from pl_bolts.metrics import mean
from pl_bolts.optimizers.layer_adaptive_scaling import LARS

@@ -29,6 +31,8 @@ def forward(self, x):
class Projection(nn.Module):
    def __init__(self, input_dim=1024, output_dim=128):
        super().__init__()
+         self.output_dim = output_dim
+         self.input_dim = input_dim
        self.model = nn.Sequential(
            nn.Linear(input_dim, 512, bias=False),
            nn.BatchNorm1d(512),
@@ -42,9 +46,11 @@ def forward(self, x):

class SimCLR(pl.LightningModule):
    def __init__(self, dataset, data_dir, lr, wd, input_height, batch_size,
-                  num_workers=0, optimizer='adam', step=30, gamma=0.5, temperature=0.5, **kwargs):
+                  online_ft=False, num_workers=0, optimizer='adam',
+                  step=30, gamma=0.5, temperature=0.5, **kwargs):
        super().__init__()

+         self.online_evaluator = online_ft
        self.batch_size = batch_size
        self.input_height = input_height
        self.gamma = gamma
@@ -55,11 +61,22 @@ def __init__(self, dataset, data_dir, lr, wd, input_height, batch_size,
        self.temp = temperature
        self.data_dir = data_dir
        self.num_workers = num_workers
+         self.dataset_name = dataset
        self.dataset = self.get_dataset(dataset)
        self.loss_func = self.init_loss()
        self.encoder = self.init_encoder()
        self.projection = self.init_projection()
+
+         if self.online_evaluator:
+             z_dim = self.projection.output_dim
+             num_classes = self.dataset.num_classes
+             self.non_linear_evaluator = SSLEvaluator(
+                 n_input=z_dim,
+                 n_classes=num_classes,
+                 p=0.2,
+                 n_hidden=1024
+             )

    def init_loss(self):
        return nt_xent_loss
@@ -86,14 +103,42 @@ def forward(self, x):
        return h, z

    def training_step(self, batch, batch_idx):
+         if self.dataset_name == 'stl10':
+             labeled_batch = batch[1]
+             unlabeled_batch = batch[0]
+             batch = unlabeled_batch
+
        (img_1, img_2), y = batch
        h1, z1 = self.forward(img_1)
        h2, z2 = self.forward(img_2)

        # return h1, z1, h2, z2
        loss = self.loss_func(z1, z2, self.temp)
-         logs = {'loss': loss.item()}
-         return dict(loss=loss, log=logs)
+         log = {'train_ntx_loss': loss}
+
+         # don't use the training signal, just finetune the MLP to see how we're doing downstream
+         if self.online_evaluator:
+             if self.dataset_name == 'stl10':
+                 (img_1, img_2), y = labeled_batch
+
+             with torch.no_grad():
+                 h1, z1 = self.forward(img_1)
+
+             # just in case... no grads into unsupervised part!
+             z_in = z1.detach()
+
+             z_in = z_in.reshape(z_in.size(0), -1)
+             mlp_preds = self.non_linear_evaluator(z_in)
+             mlp_loss = F.cross_entropy(mlp_preds, y)
+             loss = loss + mlp_loss
+             log['train_mlp_loss'] = mlp_loss
+
+         result = {
+             'loss': loss,
+             'log': log
+         }
+
+         return result

    # def training_step_end(self, output_parts):
    #     h1s, z1s, h2s, z2s = output_parts
@@ -104,31 +149,65 @@ def training_step(self, batch, batch_idx):
    #     print(f'Rank = {rank}', [z2.shape for z2 in z2s])

    def validation_step(self, batch, batch_idx):
+         if self.dataset_name == 'stl10':
+             labeled_batch = batch[1]
+             unlabeled_batch = batch[0]
+             batch = unlabeled_batch
+
        (img_1, img_2), y = batch
        h1, z1 = self.forward(img_1)
        h2, z2 = self.forward(img_2)
        loss = self.loss_func(z1, z2, self.temp)
-         logs = {'val_loss': loss.item()}
-         return dict(val_loss=loss, log=logs)
+         result = {'val_loss': loss}
+
+         if self.online_evaluator:
+             if self.dataset_name == 'stl10':
+                 (img_1, img_2), y = labeled_batch
+                 h1, z1 = self.forward(img_1)
+
+             z_in = z1.reshape(z1.size(0), -1)
+             mlp_preds = self.non_linear_evaluator(z_in)
+             mlp_loss = F.cross_entropy(mlp_preds, y)
+             acc = metrics.accuracy(mlp_preds, y)
+             result['mlp_acc'] = acc
+             result['mlp_loss'] = mlp_loss
+
+         return result

    def validation_epoch_end(self, outputs: list):
        val_loss = mean(outputs, 'val_loss')
-         logs = dict(
+
+         log = dict(
            val_loss=val_loss,
        )
-         return dict(val_loss=val_loss, log=logs)
+
+         if self.online_evaluator:
+             mlp_acc = mean(outputs, 'mlp_acc')
+             mlp_loss = mean(outputs, 'mlp_loss')
+             log['val_mlp_acc'] = mlp_acc
+             log['val_mlp_loss'] = mlp_loss
+
+         return dict(val_loss=val_loss, log=log, progress_bar={'val_acc': log['val_mlp_acc']})

    def prepare_data(self):
        self.dataset.prepare_data()

    def train_dataloader(self):
        train_transform = SimCLRDataTransform(input_height=self.input_height)
-         loader = self.dataset.train_dataloader(self.batch_size, transforms=train_transform)
+
+         if self.dataset_name == 'stl10':
+             loader = self.dataset.train_dataloader_mixed(self.batch_size, transforms=train_transform)
+         else:
+             loader = self.dataset.train_dataloader(self.batch_size, transforms=train_transform)
        return loader

    def val_dataloader(self):
        test_transform = SimCLRDataTransform(input_height=self.input_height, test=True)
-         loader = self.dataset.val_dataloader(self.batch_size, transforms=test_transform)
+
+         if self.dataset_name == 'stl10':
+             loader = self.dataset.val_dataloader_mixed(self.batch_size, transforms=test_transform)
+         else:
+             loader = self.dataset.val_dataloader(self.batch_size, transforms=test_transform)
        return loader

    def configure_optimizers(self):
@@ -160,9 +239,9 @@ def add_model_specific_args(parent_parser):

        # Training
        parser.add_argument('--expdir', type=str, default='simclrlogs')
-         parser.add_argument('--optim', choices=['adam', 'lars'], default='adam')
-         parser.add_argument('--batch_size', type=int, default=2)
-         parser.add_argument('--lr', type=float, default=0.0001)
+         parser.add_argument('--optim', choices=['adam', 'lars'], default='lars')
+         parser.add_argument('--batch_size', type=int, default=128)
+         parser.add_argument('--lr', type=float, default=0.00006)
        parser.add_argument('--mom', type=float, default=0.9)
        parser.add_argument('--eta', type=float, default=0.001)
        parser.add_argument('--step', type=float, default=30)
@@ -171,17 +250,9 @@ def add_model_specific_args(parent_parser):

        # Model
        parser.add_argument('--temp', type=float, default=0.5)
        parser.add_argument('--trans', type=str, default='randcrop,flip')
        parser.add_argument('--num_workers', default=8, type=int)
        return parser

- # model = SimCLR(
- #     hparams=args,
- #     encoder=EncoderModel(),
- #     projection=Projection(),
- #     loss_func=nt_xent_loss,
- #     temperature=args.temp,
- #     transform_list=list(args.trans.split(','))
- # )


if __name__ == '__main__':
    from argparse import ArgumentParser
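For reference, the nt_xent_loss applied to (z1, z2) above is SimCLR's normalized temperature-scaled cross-entropy. A compact sketch of the idea, for intuition only; the actual pl_bolts implementation may differ in details such as masking and distributed gathering:

import torch
import torch.nn.functional as F


def nt_xent_sketch(z1, z2, temperature=0.5):
    # cosine similarities between all 2N projections, with self-similarity masked out
    z1, z2 = F.normalize(z1, dim=1), F.normalize(z2, dim=1)
    z = torch.cat([z1, z2], dim=0)                        # (2N, D)
    sim = torch.mm(z, z.t()) / temperature                # (2N, 2N)
    n = z1.size(0)
    mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(mask, float('-inf'))
    # for row i, the positive is the other view of the same image
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(sim, targets)


# example: two batches of 8 projections with 128 dims
# loss = nt_xent_sketch(torch.randn(8, 128), torch.randn(8, 128), temperature=0.5)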
5 changes: 4 additions & 1 deletion setup.cfg
@@ -3,7 +3,10 @@ norecursedirs =
    .git
    dist
    build
- addopts = --strict
+ addopts =
+     --strict
+     --doctest-modules
+     --durations=0

[coverage:report]
exclude_lines =
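Moving --doctest-modules into addopts (together with --durations=0, which prints per-test timings and presumably accounts for the "times tests" commit) applies it to every pytest run, which is why the flag could be dropped from the CI commands and .run_local_tests.sh above. A hypothetical example of what --doctest-modules now picks up from any pl_bolts module docstring:

def scale(values, factor=2.0):
    """Scale every value by a constant factor.

    >>> scale([1.0, 2.0], factor=3.0)
    [3.0, 6.0]
    """
    return [v * factor for v in values]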
