In [1]:
import sys

In [2]:
import os
import torch

In [3]:
from train import train
from build_model import build_model
from data import build_dataset
from utils import Logger

import traceback
import shutil

In [4]:
global_config = dict(
    ver="gen_pred_test3",
    description="scale up",
    outcome_root="/test",
)
global_config["outcome_dir_root"] = os.path.join(global_config["outcome_root"],
                                                 global_config["ver"])
seg_size=4
n_features=9
transformer_config=dict(
    inp_dim = n_features*seg_size*2,
    dim = 16,
    out_dim = 16,
    num_layers = 4,
    num_heads = 8,
    ff_hidden_dim = 16,
    max_seq_len = 64,
    dropout = 0.0
)


mlp_config=dict(
    in_channels=n_features*seg_size,
    model_channels=16,
    out_channels=n_features*seg_size,
    z_channels=transformer_config['out_dim'],
    num_res_blocks=2,
    grad_checkpointing=False
)


diffusion_config=dict(
    num_steps = 16
)

train_config=dict(
    seg_size=seg_size,
    num_fm_per_gd=8,
    max_seq_len=transformer_config['max_seq_len'],
    train_steps=0,
    log_every_n_steps=4,
    eval_every_n_steps=4,
    pretrained=None,
    batch_size=8,
    base_learning_rate=3.0e-4,
    min_learning_rate=1.0e-4,
    use_lr_scheduler=True,
    warmup_steps=500,
    betas=[0.98, 0.999],
    need_check=False,
    use_ema=False,
    ema_decay=0.9999,
    ema_steps=20000
)
train_config['save']=train_config['train_steps']>0

img_dataset_paths={'afhq':'/kaggle/input/afhq-512',
               'ffhq':'/kaggle/input/flickrfaceshq-dataset-nvidia-resized-256px',
               'celebahq':'/kaggle/input/celebahq256-images-only',
               'fa':'/kaggle/input/face-attributes-grouped',
               'animestyle':'/kaggle/input/gananime-lite',
               'animefaces':'/kaggle/input/another-anime-face-dataset',
              }

data_config = dict(
    shape=(train_config['batch_size'],
           train_config['max_seq_len']*seg_size,
           n_features),
    image_size=256,
    seg_size=seg_size,
    batch_size=train_config['batch_size'],
    ae_batch_size=48,
    split=[0.8,0.1,0.1],
    space_weather_data_root="data/data",
    data_paths=img_dataset_paths,
    enc_path=os.path.join(global_config["outcome_dir_root"], "enc"),
    enc_inp_path='/kaggle/input/sd-vae-ft-ema-f8-256-faces6-enc',
    dataset_names=['afhq', 'ffhq', 'celebahq', 'fa', 'animestyle', 'animefaces'],
    ignored_dataset=['fa'],
    ignored_dataset_ft=['ffhq', 'celebahq', 'animestyle', 'animefaces'],
    valid_dataset_idx=[]
)

In [5]:
logger = Logger(log_every_n_steps=train_config['log_every_n_steps'],
                log_root=global_config["outcome_dir_root"],
                model_name=global_config['ver']
               )

logger.log_text(str(global_config), "config")
logger.log_text(str(mlp_config), "config", newline=True)
logger.log_text(str(transformer_config), "config", newline=True)
logger.log_text(str(diffusion_config), "config", newline=True)
logger.log_text(str(train_config), "config", newline=True)

In [6]:
torch.manual_seed(42+hash(global_config['ver'])%10000)

train_dataset, val_dataset, test_dataset = build_dataset(data_config)

logger.log_text(str(data_config), "config", newline=True)

Using data: ACE_IMF_Bx ACE_IMF_By ACE_IMF_Bz ACE_Psw ACE_Vsw OMNI_AE OMNI_ASYMH OMNI_PC OMNI_SYMH
All data cat shape: torch.Size([10519200, 9])
All mask cat shape: torch.Size([10519200, 9])
train data len 8415360
val data len 1051920
test data len 1051920


In [7]:
model, optim, lr_scheduler = build_model(logger,
                                         transformer_config,
                                         mlp_config,
                                         diffusion_config,
                                         train_config)

T params: 8,800, MLP params: 9,188, TTrainable: 8,800
running on cuda


In [8]:
try:
    train(model, optim, lr_scheduler, train_config,
          train_dataset, val_dataset, test_dataset, logger)
except Exception as e:
    traceback.print_exc()
    info = traceback.format_exc()
    info = f"Exception: {str(info)} \n"+\
            f"Step: {logger.step}"
    print(info)
    logger.log_text(info, "error")
finally:
    if not any([fn.endswith('.pth') for fn in os.listdir(logger.log_root)]):
        if train_config['save']:
            logger.log_net(model.cpu(),f"mar_{logger.step}")
    shutil.make_archive(global_config["outcome_dir_root"],
                        'zip',
                        global_config["outcome_dir_root"])

tensor(0.8512, device='cuda:0')
Exception: Traceback (most recent call last):
  File "C:\Users\Andy\AppData\Local\Temp\ipykernel_26240\3757478769.py", line 2, in <cell line: 1>
    train(model, optim, lr_scheduler, train_config,
  File "d:\hw2025\space weather\project\src\train.py", line 17, in train
    pipeline(model, logger, train_dataset)
  File "c:\Users\Andy\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\autograd\grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "d:\hw2025\space weather\project\src\probe.py", line 10, in pipeline
    diff_loss(model, dataset, logger, num_test_steps=200)
  File "d:\hw2025\space weather\project\src\probe.py", line 60, in diff_loss
    loss = model.gen(mask, x0, scope=diff_step) #[num_diff_steps,s]
  File "c:\Users\Andy\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\autograd\grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "d:\hw2025\space weather\pr

Traceback (most recent call last):
  File "C:\Users\Andy\AppData\Local\Temp\ipykernel_26240\3757478769.py", line 2, in <cell line: 1>
    train(model, optim, lr_scheduler, train_config,
  File "d:\hw2025\space weather\project\src\train.py", line 17, in train
    pipeline(model, logger, train_dataset)
  File "c:\Users\Andy\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\autograd\grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "d:\hw2025\space weather\project\src\probe.py", line 10, in pipeline
    diff_loss(model, dataset, logger, num_test_steps=200)
  File "d:\hw2025\space weather\project\src\probe.py", line 60, in diff_loss
    loss = model.gen(mask, x0, scope=diff_step) #[num_diff_steps,s]
  File "c:\Users\Andy\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\autograd\grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "d:\hw2025\space weather\project\src\model.py", line 99, in gen
    ra