Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add super res 512 and 1024 #747

Merged
merged 1 commit into from Sep 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
@@ -0,0 +1,56 @@
_base_: ./imagen_base.yaml

Global:
global_batch_size:
local_batch_size: 1
micro_batch_size: 1


Model:
name: imagen_SR1024
text_encoder_name: t5/t5-11b
text_embed_dim: 1024
timesteps: 1000
in_chans: 3
cond_drop_prob: 0.1
noise_schedules: cosine
pred_objectives: noise
lowres_noise_schedule: linear
lowres_sample_noise_level: 0.2
per_sample_random_aug_noise_level: False
condition_on_text: True
auto_normalize_img: True
p2_loss_weight_gamma: 0.5
dynamic_thresholding: True
dynamic_thresholding_percentile: 0.95
only_train_unet_number: 1
use_recompute: False

Data:
Train:
dataset:
name: ImagenDataset
input_path: ./data/cc12m_base64.lst
shuffle: True
input_resolusion: 1024
max_seq_len: 128
loader:
num_workers: 8
shuffle: True
batch_size: 1
drop_last: True
collate_fn: imagen_collate_fn


Loss:
name: mse_loss
p2_loss_weight_k: 1.0

Distributed:
dp_degree: 1
mp_degree: 1
pp_degree: 1
sharding:
sharding_degree: 1
sharding_stage: 1
sharding_offload: False
@@ -0,0 +1,56 @@
_base_: ./imagen_base.yaml

Global:
global_batch_size:
local_batch_size: 1
micro_batch_size: 1


Model:
name: imagen_SR256
text_encoder_name: t5/t5-11b
text_embed_dim: 1024
timesteps: 1000
in_chans: 3
cond_drop_prob: 0.1
noise_schedules: cosine
pred_objectives: noise
lowres_noise_schedule: linear
lowres_sample_noise_level: 0.2
per_sample_random_aug_noise_level: False
condition_on_text: True
auto_normalize_img: True
p2_loss_weight_gamma: 0.5
dynamic_thresholding: True
dynamic_thresholding_percentile: 0.95
only_train_unet_number: 1
use_recompute: False

Data:
Train:
dataset:
name: ImagenDataset
input_path: ./data/cc12m_base64.lst
shuffle: True
input_resolusion: 256
max_seq_len: 128
loader:
num_workers: 8
shuffle: True
batch_size: 1
drop_last: True
collate_fn: imagen_collate_fn


Loss:
name: mse_loss
p2_loss_weight_k: 1.0

Distributed:
dp_degree: 1
mp_degree: 1
pp_degree: 1
sharding:
sharding_degree: 1
sharding_stage: 1
sharding_offload: False
@@ -0,0 +1,56 @@
_base_: ./imagen_base.yaml

Global:
global_batch_size:
local_batch_size: 1
micro_batch_size: 1


Model:
name: imagen_SR512
text_encoder_name: t5/t5-11b
text_embed_dim: 1024
timesteps: 1000
in_chans: 3
cond_drop_prob: 0.1
noise_schedules: cosine
pred_objectives: noise
lowres_noise_schedule: linear
lowres_sample_noise_level: 0.2
per_sample_random_aug_noise_level: False
condition_on_text: True
auto_normalize_img: True
p2_loss_weight_gamma: 0.5
dynamic_thresholding: True
dynamic_thresholding_percentile: 0.95
only_train_unet_number: 1
use_recompute: False

Data:
Train:
dataset:
name: ImagenDataset
input_path: ./data/cc12m_base64.lst
shuffle: True
input_resolusion: 512
max_seq_len: 128
loader:
num_workers: 8
shuffle: True
batch_size: 1
drop_last: True
collate_fn: imagen_collate_fn


Loss:
name: mse_loss
p2_loss_weight_k: 1.0

Distributed:
dp_degree: 1
mp_degree: 1
pp_degree: 1
sharding:
sharding_degree: 1
sharding_stage: 1
sharding_offload: False
2 changes: 1 addition & 1 deletion ppfleetx/core/engine/eager_engine.py
Expand Up @@ -259,7 +259,7 @@ def _train_one_epoch(self,
# Note(GuoxiaWang): Do not use len(train_data_loader()),
# it will cause a memory leak.
total_train_batch = len(train_data_loader)
total_eval_batch = len(valid_data_loader)
total_eval_batch = len(valid_data_loader) if valid_data_loader is not None else 0
for step, batch in enumerate(train_data_loader):

if epoch_index == self._load_recovery['epoch']:
Expand Down
4 changes: 2 additions & 2 deletions ppfleetx/models/multimodal_model/imagen/__init__.py
Expand Up @@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .modeling import (ImagenModel, imagen_397M_text2im_64,
from .modeling import (ImagenModel, imagen_397M_text2im_64,
imagen_2B_text2im_64, imagen_text2im_64_SR256,
imagen_SR256, imagen_SR1024, ImagenCriterion)
imagen_SR256, imagen_SR1024, imagen_SR512, ImagenCriterion)
8 changes: 8 additions & 0 deletions ppfleetx/models/multimodal_model/imagen/modeling.py
Expand Up @@ -814,6 +814,14 @@ def imagen_SR256(**kwargs):
return model


def imagen_SR512(**kwargs):
    """Build an ``ImagenModel`` with one ``SRUnet1024`` U-Net at image size 512.

    All keyword arguments are forwarded unchanged to ``ImagenModel``.
    """
    return ImagenModel(unets=SRUnet1024(), image_sizes=(512, ), **kwargs)

def imagen_SR1024(**kwargs):
    """Build an ``ImagenModel`` with one ``SRUnet1024`` U-Net at image size 1024.

    All keyword arguments are forwarded unchanged to ``ImagenModel``.
    """
    return ImagenModel(unets=SRUnet1024(), image_sizes=(1024, ), **kwargs)

def imagen_SR64to1024(**kwargs):
    """Build an ``ImagenModel`` with one ``SRUnet64to1024`` U-Net at image size 1024.

    All keyword arguments are forwarded unchanged to ``ImagenModel``.
    """
    return ImagenModel(
        unets=SRUnet64to1024(), image_sizes=(1024, ), **kwargs)
3 changes: 2 additions & 1 deletion ppfleetx/models/multimodal_model/multimodal_module.py
Expand Up @@ -27,9 +27,10 @@
class MultiModalModule(BasicModule):
def __init__(self, configs):
self.nranks = paddle.distributed.get_world_size()

super(MultiModalModule, self).__init__(configs)

self.loss_fn = self.get_loss_fn()

def process_configs(self, configs):
configs = process_configs(configs)
return configs
Expand Down
18 changes: 18 additions & 0 deletions projects/imagen/run_super_resolusion_1024_single.sh
@@ -0,0 +1,18 @@
#! /bin/bash

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Expose only GPU 0 to the training process (single-card run).
export CUDA_VISIBLE_DEVICES=0
# Launch training with the 1024 super-resolution config. The -o argument passes
# Data.Train.loader.num_workers=0 (presumably a config override; confirm in tools/train.py).
python3 tools/train.py -c ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml -o Data.Train.loader.num_workers=0
18 changes: 18 additions & 0 deletions projects/imagen/run_super_resolusion_512_single.sh
@@ -0,0 +1,18 @@
#! /bin/bash

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Expose only GPU 0 to the training process (single-card run).
export CUDA_VISIBLE_DEVICES=0
# Launch training with the 512 super-resolution config. The -o argument passes
# Data.Train.loader.num_workers=8 (presumably a config override; confirm in tools/train.py).
python3 tools/train.py -c ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_512.yaml -o Data.Train.loader.num_workers=8
2 changes: 1 addition & 1 deletion projects/imagen/run_text2im_397M_64x64_single.sh
Expand Up @@ -15,4 +15,4 @@
# limitations under the License.

export CUDA_VISIBLE_DEVICES=0
python tools/train.py -c ppfleetx/configs/multimodal/imagen/imagen_397M_text2im_64x64.yaml -o Data.Train.loader.num_workers=8
python3 tools/train.py -c ppfleetx/configs/multimodal/imagen/imagen_397M_text2im_64x64.yaml -o Data.Train.loader.num_workers=8