In [1]:
# Fetch the project repository into ./sw_gen_pred (relative to the current working directory).
# A later cell adds '/kaggle/working/sw_gen_pred' to sys.path, so this cell is
# expected to run with /kaggle/working as the working directory.
# NOTE(review): no branch/commit is pinned, so the code that runs is whatever the
# default branch points at when the notebook is executed — consider pinning a commit.
!git clone https://github.com/SupotcoA/sw_gen_pred.git

Cloning into 'sw_gen_pred'...
remote: Enumerating objects: 30, done.
remote: Counting objects: 100% (30/30), done.
remote: Compressing objects: 100% (25/25), done.
remote: Total 30 (delta 10), reused 25 (delta 5), pack-reused 0 (from 0)
Receiving objects: 100% (30/30), 520.32 KiB | 10.41 MiB/s, done.
Resolving deltas: 100% (10/10), done.


In [2]:
import sys

# Directory that should contain the cloned project sources.
REPO_DIR = '/kaggle/working/sw_gen_pred'

# Make the project modules importable. The membership check keeps this cell
# idempotent: re-running it must not pile up duplicate sys.path entries.
if REPO_DIR not in sys.path:
    sys.path.append(REPO_DIR)

In [3]:
import os
import zlib

import torch

In [4]:
from train import train
from build_model import build_model
from data import build_dataset
from utils import Logger

import traceback
import shutil

In [None]:
# ---------------------------------------------------------------------------
# Run identity and output location
# ---------------------------------------------------------------------------
global_config = {
    "ver": "gen_pred_test3",
    "description": "scale up",
    "outcome_root": "/kaggle/working",
}
# All artefacts of this run go under <outcome_root>/<ver>.
global_config["outcome_dir_root"] = os.path.join(
    global_config["outcome_root"], global_config["ver"]
)

# ---------------------------------------------------------------------------
# Shared sizes used to derive the model and data dimensions below
# ---------------------------------------------------------------------------
seg_size = 4
# NOTE(review): the dataset cell's recorded output reports a tensor with 8
# columns — confirm that n_features = 9 matches what build_dataset yields.
n_features = 9

# ---------------------------------------------------------------------------
# Transformer settings
# ---------------------------------------------------------------------------
transformer_config = {
    "inp_dim": n_features * seg_size * 2,
    "dim": 128,
    "out_dim": 128,
    "num_layers": 8,
    "num_heads": 8,
    "ff_hidden_dim": 320,
    "max_seq_len": 64,
    "dropout": 0.0,
}

# ---------------------------------------------------------------------------
# MLP settings; z_channels is tied to the transformer's out_dim
# ---------------------------------------------------------------------------
mlp_config = {
    "in_channels": n_features * seg_size,
    "model_channels": 128,
    "out_channels": n_features * seg_size,
    "z_channels": transformer_config["out_dim"],
    "num_res_blocks": 4,
    "grad_checkpointing": False,
}

# ---------------------------------------------------------------------------
# Diffusion settings
# ---------------------------------------------------------------------------
diffusion_config = {
    "num_steps": 64,
}

# ---------------------------------------------------------------------------
# Training schedule and optimiser settings
# ---------------------------------------------------------------------------
train_config = {
    "seg_size": seg_size,
    "num_fm_per_gd": 8,
    "max_seq_len": transformer_config["max_seq_len"],
    "train_steps": 40000,
    "log_every_n_steps": 4000,
    "eval_every_n_steps": 8000,
    "pretrained": None,
    "batch_size": 256,
    "base_learning_rate": 3.0e-4,
    "min_learning_rate": 1.0e-4,
    "use_lr_scheduler": True,
    "warmup_steps": 500,
    "betas": [0.98, 0.999],
    "need_check": False,
    "use_ema": False,
    "ema_decay": 0.9999,
    "ema_steps": 20000,
}
# Saving is enabled only when at least one training step is configured.
train_config["save"] = train_config["train_steps"] > 0

# ---------------------------------------------------------------------------
# Data settings
# ---------------------------------------------------------------------------
# NOTE(review): the image-dataset paths and encoder entries below look like
# carry-overs from an image project; confirm whether build_dataset still reads
# them before removing anything.
img_dataset_paths = {
    "afhq": "/kaggle/input/afhq-512",
    "ffhq": "/kaggle/input/flickrfaceshq-dataset-nvidia-resized-256px",
    "celebahq": "/kaggle/input/celebahq256-images-only",
    "fa": "/kaggle/input/face-attributes-grouped",
    "animestyle": "/kaggle/input/gananime-lite",
    "animefaces": "/kaggle/input/another-anime-face-dataset",
}

data_config = {
    # (batch_size, max_seq_len * seg_size, n_features)
    "shape": (
        train_config["batch_size"],
        train_config["max_seq_len"] * seg_size,
        n_features,
    ),
    "image_size": 256,
    "batch_size": train_config["batch_size"],
    "ae_batch_size": 48,
    "split": [0.8, 0.1, 0.1],
    "space_weather_data_root": "/kaggle/input/sw-2012-v0",
    "data_paths": img_dataset_paths,
    "enc_path": os.path.join(global_config["outcome_dir_root"], "enc"),
    "enc_inp_path": "/kaggle/input/sd-vae-ft-ema-f8-256-faces6-enc",
    "dataset_names": ["afhq", "ffhq", "celebahq", "fa", "animestyle", "animefaces"],
    "ignored_dataset": ["fa"],
    "ignored_dataset_ft": ["ffhq", "celebahq", "animestyle", "animefaces"],
    "valid_dataset_idx": [],
}

In [6]:
# Create the run logger; it writes under the run's output directory and is
# named after the run version.
logger = Logger(
    log_every_n_steps=train_config['log_every_n_steps'],
    log_root=global_config["outcome_dir_root"],
    model_name=global_config['ver'],
)

# Record every configuration dict under the "config" tag. The first entry is
# written without the newline flag; each following entry passes newline=True.
logger.log_text(str(global_config), "config")
for cfg in (mlp_config, transformer_config, diffusion_config, train_config):
    logger.log_text(str(cfg), "config", newline=True)

In [7]:
# Derive the RNG seed from the run name so different runs get different seeds.
# The built-in hash() of a str is salted per interpreter process
# (PYTHONHASHSEED), so it would produce a different seed after every kernel
# restart; CRC32 of the UTF-8 bytes is stable across processes and machines.
seed = 42 + zlib.crc32(global_config['ver'].encode('utf-8')) % 10000
torch.manual_seed(seed)

# Build the train / validation / test datasets from the data configuration.
train_dataset, val_dataset, test_dataset = build_dataset(data_config)

# Logged after the build, matching the original cell order.
logger.log_text(str(data_config), "config", newline=True)

Using data: ACE_Bz_2012 ACE_Psw_2012 ACE_Vsw_2012 OMNI_AE_2012 OMNI_ASYMH_2012 OMNI_PC_2012 OMNI_SYMH_2012
All data cat shape: torch.Size([527040, 8])
train data len 421632
val data len 52704
test data len 52704


In [8]:
# Build the model together with its optimizer and learning-rate scheduler
# from the configuration dicts defined in the configuration cell.
model, optim, lr_scheduler = build_model(
    logger,
    transformer_config,
    mlp_config,
    diffusion_config,
    train_config,
)

T params: 1,528,064, MLP params: 432,392, TTrainable: 1,528,064
running on cuda


In [9]:
# Run training. A failure is reported and logged rather than propagated, so the
# cleanup below (checkpoint + archive) still leaves usable artefacts behind.
try:
    train(
        model, optim, lr_scheduler, train_config,
        train_dataset, val_dataset, test_dataset, logger,
    )
except Exception:
    traceback.print_exc()
    # Full traceback plus the step the logger had reached when the run failed.
    error_report = f"Exception: {traceback.format_exc()} \n" + f"Step: {logger.step}"
    print(error_report)
    logger.log_text(error_report, "error")
finally:
    # Save a checkpoint only when saving is enabled and the log directory does
    # not already contain a .pth file.
    has_checkpoint = any(
        name.endswith('.pth') for name in os.listdir(logger.log_root)
    )
    if not has_checkpoint and train_config['save']:
        logger.log_net(model.cpu(), f"mar_{logger.step}")

    # Zip the whole run directory; the archive is written next to it.
    run_dir = global_config["outcome_dir_root"]
    shutil.make_archive(run_dir, 'zip', run_dir)

Train step 4000
loss: 0.1561
time per kstep: 216
peak GPU mem: 5.3 GB

Train step 8000
loss: 0.0619
time per kstep: 215
peak GPU mem: 5.3 GB

Eval
loss:0.0513+-0.0195

Train step 12000
loss: 0.0589
time per kstep: 216
peak GPU mem: 5.3 GB

Train step 16000
loss: 0.0574
time per kstep: 215
peak GPU mem: 5.3 GB

Eval
loss:0.0506+-0.0202

Train step 20000
loss: 0.0553
time per kstep: 216
peak GPU mem: 5.3 GB

Train step 24000
loss: 0.0524
time per kstep: 216
peak GPU mem: 5.3 GB

Eval
loss:0.0582+-0.0248

Train step 28000
loss: 0.0487
time per kstep: 216
peak GPU mem: 5.3 GB

Train step 32000
loss: 0.0448
time per kstep: 216
peak GPU mem: 5.3 GB

Eval
loss:0.0710+-0.0332

Train step 36000
loss: 0.0412
time per kstep: 216
peak GPU mem: 5.3 GB

Train step 40000
loss: 0.0384
time per kstep: 216
peak GPU mem: 5.3 GB

Eval
loss:0.0869+-0.0452

generating


  plt.ylim([0,0.08])


Test
loss:0.0744+-0.0192

