In [None]:
import os
import random
import numpy as np
import torch
import config
from icecube.fastai_fix import *
import math
from icecube.loss_functions import VonMisesFisher3DLoss

def loss(pred,y):
    """Mean angular distance (radians) between predicted and true directions.

    Args:
        pred: raw direction predictions, indexed as (batch, 3); the last dim is
            normalised to unit length here. Columns are matched against
            (sin(az)*sin(zen), cos(az)*sin(zen), cos(zen)).
        y: dict whose ``'target'`` entry holds azimuth in column 0 and zenith
            in column 1 (radians).

    Returns:
        Mean over the batch of arccos(<pred_unit, target_unit>).
    """
    pred = F.normalize(pred.float(), dim=-1)

    az = y['target'][:, 0]
    zen = y['target'][:, 1]
    sin_zen = torch.sin(zen)
    cos_sim = (pred[:, 0]*torch.sin(az)*sin_zen
               + pred[:, 1]*torch.cos(az)*sin_zen
               + pred[:, 2]*torch.cos(zen))

    # Keep acos strictly inside (-1, 1): its derivative is infinite at +-1.
    # The margin must be dtype-aware -- a fixed 1e-8 rounds to 0 in float32,
    # which left the clip a no-op and produced NaN gradients for exact hits.
    margin = torch.finfo(cos_sim.dtype).eps
    cos_sim = cos_sim.clip(-1 + margin, 1 - margin)
    # acos already returns values in [0, pi], so no abs() is needed.
    return torch.acos(cos_sim).mean(-1)

def loss_vms(pred,y):
    """Von Mises-Fisher loss between predicted and true directions.

    The prediction is split into a unit direction ``pred / |pred|`` and its
    magnitude ``|pred|``, concatenated along the last dim (one extra column),
    and passed with the target unit vectors to ``VonMisesFisher3DLoss``.

    Args:
        pred: raw predictions; assumes shape (batch, 3) -- TODO confirm.
        y: dict whose ``'target'`` entry holds azimuth in column 0 and zenith
            in column 1 (radians).

    Returns:
        Whatever ``VonMisesFisher3DLoss()(p, target)`` returns.
    """
    az = y['target'][:, 0]
    zen = y['target'][:, 1]
    sin_zen = torch.sin(zen)
    # Same (sin(az)*sin(zen), cos(az)*sin(zen), cos(zen)) layout as loss()/get_val().
    target_dir = torch.stack([torch.sin(az)*sin_zen, torch.cos(az)*sin_zen, torch.cos(zen)], -1)

    # Convert once; predictions may arrive in half precision under mixed precision.
    pred = pred.float()
    norm = torch.norm(pred, dim=-1).unsqueeze(-1)
    # NOTE(review): the magnitude is presumably used as the concentration (kappa)
    # by VonMisesFisher3DLoss -- confirm. A zero-length prediction divides by zero here.
    p = torch.cat([pred/norm, norm], -1)

    return VonMisesFisher3DLoss()(p, target_dir)

def get_val(pred):
    """Convert direction predictions back to (azimuth, zenith) in radians.

    Inverts the (sin(az)*sin(zen), cos(az)*sin(zen), cos(zen)) layout used by
    the losses. Zenith is in [0, pi]; azimuth is wrapped into (0, 2*pi), with an
    exact 0 left as 0.

    Returns:
        Tensor whose last dim is [azimuth, zenith].
    """
    unit = F.normalize(pred, dim=-1)
    zenith = torch.acos(unit[:, 2].clip(-1, 1))

    # Normalise the horizontal part so its two components are sin(az), cos(az).
    horiz = F.normalize(unit[:, :2], dim=-1)
    sin_az = horiz[:, 0]
    cos_az = horiz[:, 1]

    principal = torch.asin(sin_az.clip(-1, 1))
    # asin only covers (-pi/2, pi/2); reflect into the other half-plane when cos <= 0.
    azimuth = torch.where(cos_az > 0, principal, math.pi - principal)
    azimuth = torch.where(azimuth > 0, azimuth, azimuth + 2.0*math.pi)
    return torch.stack([azimuth, zenith], -1)

def WrapperAdamW(param_groups,**kwargs):
    """fastai ``opt_func`` that wraps ``torch.optim.AdamW`` in ``OptimWrapper``.

    Keyword arguments (e.g. ``eps`` bound via ``functools.partial`` when the
    Learner is built) are forwarded to the AdamW constructor. Without the
    forwarding they would be silently dropped and AdamW's defaults used.
    """
    return OptimWrapper(param_groups, torch.optim.AdamW, **kwargs)

def set_gpu_environ():
    """Sets CUDA_VISIBLE_DEVICES to the GPU(s) under minimal memory load.

    Queries ``nvidia-smi`` for each GPU's used memory, selects every GPU tied
    for the minimum, writes them comma-separated to ``CUDA_VISIBLE_DEVICES``
    and prints that value. Meant to be used in notebooks only; the variable has
    no effect on a CUDA context that is already initialised.

    Raises:
        FileNotFoundError / subprocess.CalledProcessError: if ``nvidia-smi``
            is missing or fails.
        RuntimeError: if ``nvidia-smi`` reports no GPUs.
    """
    import os
    import subprocess

    # nvidia-smi numbers GPUs in PCI bus order while CUDA defaults to
    # "fastest first"; align them so the indices below select the intended
    # cards (an explicit user setting is respected).
    os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")

    # noheader,nounits -> one bare integer (MiB) per GPU, one per line.
    output = subprocess.check_output(
        ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader,nounits"]
    ).decode()
    used_mib = [int(line) for line in output.splitlines() if line.strip()]
    if not used_mib:
        raise RuntimeError("nvidia-smi reported no GPUs")

    lowest = min(used_mib)
    set_visible = ",".join(str(i) for i, used in enumerate(used_mib) if used == lowest)
    os.environ["CUDA_VISIBLE_DEVICES"] = set_visible
    print(set_visible)
# Select the least-loaded GPU(s) before anything in this notebook touches CUDA.
set_gpu_environ()

In [None]:
# Intentionally empty scratch cell.

In [None]:
def seed(seed=0):
    """Seed every RNG used in training so runs are repeatable.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU plus the CUDA
    seeders), then switches cuDNN to deterministic, non-benchmarking mode.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Benchmark mode lets cuDNN pick kernels by timing, which can vary run to run.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def train(cfg):
    """Train a model end to end from an experiment config.

    Seeds all RNGs, builds the model (optionally loading ``cfg.MODEL_WTS``),
    the optimizer, the loss and a warmup LR schedule, then delegates the
    training loop to ``cfg.FIT_FUNC``.

    Args:
        cfg: experiment config. Attributes read here: MODEL_NAME, MODEL_WTS,
            OPT, LR, WD, LOSS_FUNC, N_FILES, BATCH_SIZE, WARM_UP_PCT, EPOCHS,
            SCHEDULER, FIT_FUNC, METRIC, FOLDER, EXP_NAME, DEVICE.
    """
    seed()
    custom_model = cfg.MODEL_NAME()
    if cfg.MODEL_WTS:
        print(f"Loading model weights from {cfg.MODEL_WTS}")
        # Deserialize onto the CPU; load_state_dict copies into the model's own
        # tensors. Without map_location, torch.load restores each tensor to the
        # device it was saved from, which fails if that GPU index is no longer
        # visible (e.g. after set_gpu_environ() narrows CUDA_VISIBLE_DEVICES).
        custom_model.load_state_dict(torch.load(cfg.MODEL_WTS, map_location="cpu"))
    opt = cfg.OPT(
        custom_model.parameters(), lr=cfg.LR, weight_decay=cfg.WD
    )
    loss_func = cfg.LOSS_FUNC()

    # NOTE(review): assumes every training file holds 200,000 events -- TODO confirm.
    events_per_file = 200000
    len_trn_dl = (cfg.N_FILES * events_per_file) // cfg.BATCH_SIZE
    warmup_steps = int(len_trn_dl * cfg.WARM_UP_PCT * cfg.EPOCHS)
    total_steps = int(len_trn_dl * cfg.EPOCHS)

    print(f"Total steps: {total_steps}, Warmup steps: {warmup_steps}")

    scheduler = cfg.SCHEDULER(
        opt,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps,
    )

    cfg.FIT_FUNC(
        epochs=cfg.EPOCHS,
        model=custom_model,
        loss_fn=loss_func,
        opt=opt,
        metric=cfg.METRIC,
        config = cfg,
        folder=cfg.FOLDER/cfg.EXP_NAME,
        exp_name=f"{cfg.EXP_NAME}",
        device=cfg.DEVICE,
        sched=scheduler,
    )
    
def main(config_name):
    """Look up, log and return the experiment config named ``config_name``.

    ``config_name`` is resolved as an attribute path on the ``config`` module
    (dotted names such as ``"group.NAME"`` are followed). Training is not
    started here.

    Raises:
        AttributeError: if ``config`` has no such attribute.
    """
    from operator import attrgetter

    # attrgetter instead of eval(): the name is only ever treated as an
    # attribute path, never executed as arbitrary Python.
    configs = attrgetter(config_name)(config)
    print(f"Training with config: {configs.__dict__}")
    return configs


In [None]:
from functools import partial

In [None]:
# Resolve the experiment config, then build its data loaders and a fresh model.
# NOTE(review): batch size is read from cfg.BS here but cfg.BATCH_SIZE in train() -- confirm both exist.
cfg = main('FA_GRAPH_V0')
dls = cfg.DATALOADER_FUNC(
    bs=cfg.BS,
    L=cfg.L,
    NUM_WORKERS=cfg.NUM_WORKERS,
)
model = cfg.MODEL_NAME()

In [None]:
# Mixed-precision fastai Learner: optimises loss_vms and reports the angular
# `loss` function as a metric. monitor='loss' presumably refers to that metric
# (fastai names metrics after their function) -- confirm.
learn = Learner(
    dls,
    model,
    path=cfg.FOLDER,
    loss_func=loss_vms,
    cbs=[
        GradientClip(cfg.GR_CLIP),
        SaveModelCallback(monitor='loss', comp=np.less, every_epoch=True, fname=cfg.EXP_NAME),
    ],
    metrics=[loss],
    opt_func=partial(WrapperAdamW, eps=1e-7),
).to_fp16()

In [None]:
# Optionally warm-start from a checkpoint.
# NOTE(review): this cell reads cfg.MD_WTS while train() reads cfg.MODEL_WTS -- confirm which the config defines.
if cfg.MD_WTS:
    print(f"Loading model weights from {cfg.MD_WTS}")
    # Deserialize on the CPU; load_state_dict then copies into the model's own
    # tensors, so checkpoints saved from a GPU index that is no longer visible still load.
    learn.model.load_state_dict(torch.load(cfg.MD_WTS, map_location="cpu"))

In [None]:
# Peak learning rate comes from cfg.LR -- the same setting train() passes to its optimizer.
learn.fit_one_cycle(cfg.EPOCHS, lr_max=cfg.LR, wd=cfg.WD, pct_start=cfg.PCT_START)