In [None]:
import wandb
wandb.login()

import jax
import jax.numpy as jnp              
import optax                           
import torch.utils.data as Data
from tqdm import tqdm

from sol import *

In [None]:
# T is compared against the solver's `inner_steps` in the CONFIG cell and is
# part of the dataset directory name.
T = 3
# N is used as config['timespan'] for the full training run and is also part
# of the dataset directory name.
N = 50

In [None]:
#------------------------------------------------------------------#
# CONFIG
#------------------------------------------------------------------#

# Settings shared by the dataset, model-creation, and training cells below.
config = {
    # Dataset directory; its name encodes T and N.
    'data_dir': f'/data/akash/decay_turbulence_T{T}_N{N}/',
    # Checkpoint base directory. NOTE: the RUN cell later overwrites this
    # entry with a per-run path derived from the W&B run id.
    'work_dir': './checkpoints/',
    'epochs': 400,
    'batch_size': 2,
    'learning_rate': 1e-4,
    'weight_decay': 0,
    'seed': 23,
    'dtype': jnp.float32,
    'solver_dtype': jnp.bfloat16,
    # The warm-up cell temporarily sets this to 2 and the training cell
    # resets it to N.
    'timespan': N
}

# Sanity check: if a solver exposing `inner_steps` is in scope (presumably via
# `from sol import *` -- confirm), it must agree with T. Only the name lookup
# is guarded, so a missing solver is reported while a genuine mismatch still
# fails loudly with both values in the message.
try:
    inner_steps
except NameError:
    print("No solver")
else:
    assert inner_steps == T, (
        f"Solver inner_steps ({inner_steps}) does not match T ({T})"
    )
    print("Solver inner steps checked")

In [None]:
# Build the root PRNG key from the configured seed, then split it so that
# `init_rng` is handed to state creation and `rng` remains available afterwards.
seed_key = jax.random.PRNGKey(config['seed'])
rng, init_rng = jax.random.split(seed_key)

# create_train_state is expected to come from `sol` (star import) -- confirm.
state = create_train_state(init_rng, config)

def to_fp16(t):
    """Return a new pytree with every float32 leaf of ``t`` cast to float16.

    Leaves with any other dtype are passed through unchanged.

    Args:
        t: A pytree whose leaves expose ``dtype`` and ``astype``.

    Returns:
        A pytree with the same structure as ``t``.
    """
    def _cast_leaf(x):
        return x.astype(jnp.float16) if x.dtype == jnp.float32 else x

    # `jax.tree_map` is a deprecated alias that newer JAX releases removed;
    # `jax.tree_util.tree_map` is the long-standing stable entry point.
    return jax.tree_util.tree_map(_cast_leaf, t)

# Start from the checkpoint of a previous run. The path is built by plain
# string concatenation, so it relies on config['work_dir'] still being the
# base directory with a trailing slash (the RUN cell overwrites it later).
state = restore_checkpoint(state, config['work_dir']+'1r2jaj79')

# Only down-cast the restored parameters when training in float16.
if config['dtype'] == jnp.float16:
    state.update_value('params', to_fp16(state.params))

# Reset bookkeeping and optimiser-related fields for the new run.
# NOTE(review): the return value of update_value is discarded, so this assumes
# it mutates `state` in place -- confirm against its definition.
state.update_value('examples_seen', 0)
# optax.adamw's second positional parameter is `b1`, not `weight_decay`.
# Passing the weight decay positionally therefore set b1 to it (0 here) and
# left weight_decay at the library default. Keywords bind both as intended.
state.update_value(
    'tx',
    optax.adamw(
        learning_rate=config['learning_rate'],
        weight_decay=config['weight_decay'],
    ),
)
state.update_value('dynamic_scale', optim.DynamicScale())
state.update_value('best_val_loss', float("inf"))

# Warm up

In [None]:
# Warm-up data: use a timespan of 2 and frac=0.2 for the dataset
# (presumably a fraction of the full data -- confirm against CFDDataset).
config['timespan'] = 2
train_dataset = CFDDataset(config=config, frac=0.2)

train_loader = Data.DataLoader(
    train_dataset,
    batch_size=config['batch_size'],
    shuffle=True,
    num_workers=0,
    collate_fn=numpy_collate,
    drop_last=True,
)

In [None]:
# Run exactly one training step on the first warm-up batch and stop.
# enumerate() is kept so that `i` stays defined after the cell, as before.
warmup_batches = enumerate(tqdm(train_loader))
for i, (SP, DP) in warmup_batches:
    state, metrics, pred = train_step(state, SP, DP, timespan=config['timespan'])
    break  # the index is never negative, so the loop always ends here

In [None]:
# Show the metrics returned by the single warm-up training step.
print(metrics)

# Train

In [None]:
# Switch back to the full timespan and build the datasets used for training.
config['timespan'] = N
train_dataset = CFDDataset(config=config, frac=1)
test_dataset = CFDDataset(config=config, train=False, frac=1)

# Options common to both loaders; only shuffling differs between them.
shared_loader_kwargs = dict(
    batch_size=config['batch_size'],
    num_workers=0,
    collate_fn=numpy_collate,
    drop_last=True,
)
train_loader = Data.DataLoader(train_dataset, shuffle=True, **shared_loader_kwargs)
test_loader = Data.DataLoader(test_dataset, shuffle=False, **shared_loader_kwargs)

In [None]:
# RUN

# Leave as None to start a fresh W&B run; set it to an existing run id to
# resume that run (resume="must").
RUN_ID = None
run = (
    wandb.init(id=RUN_ID, project="akash-ddp", resume="must")
    if RUN_ID
    else wandb.init(project="akash-ddp", config=config)
)

# From here on, work_dir points at a per-run path named after the W&B run id;
# the final save_checkpoint cell writes to it.
config['work_dir'] = f'./checkpoints/{run.id}'
# NOTE(review): restoring the state when resuming is currently disabled:
# if RUN_ID:
#     state = restore_checkpoint(state, config['work_dir'])

# Epochs are numbered from 1; each epoch trains and then evaluates.
num_epochs = config['epochs']
for epoch in range(1, num_epochs + 1):
    state, train_metrics, pred = train_epoch(state, train_loader, epoch, config)
    test_metrics = eval_model(state, test_loader, config)

In [None]:
# Close the W&B run opened in the RUN cell.
run.finish()

In [None]:
# Save the final training state to config['work_dir'], which the RUN cell set
# to a per-run path based on the W&B run id.
save_checkpoint(state, config['work_dir'])