In [1]:
from src.models import DeepIceModel, EncoderWithDirectionReconstructionV22, EncoderWithDirectionReconstructionV23
import polars as pl
import pandas as pd
from src.fastai_fix import *
from tqdm.notebook import tqdm
from src.dataset import RandomChunkSampler,LenMatchBatchSampler,IceCubeCache, DeviceDataLoader
from src.loss import loss, loss_vms
from fastxtend.vision.all import EMACallback
from tqdm import tqdm
from src.utils import seed_everything, WrapperAdamW

  warn(f"Failed to load image Python extension: {e}")


[1;34mgraphnet[0m: [32mINFO    [0m 2023-04-21 01:10:58 - get_logger - Writing log to [1mlogs/graphnet_20230421-011058.log[0m


In [2]:
class CONFIG:
    """Settings for a single training run, consumed as ``config`` by ``train``.

    The class itself serves as a namespace of constants (attributes are read
    as ``config.X``); it is not instantiated anywhere in the visible code.
    """

    # ---- data -------------------------------------------------------------
    PATH = '../data/'        # data root passed to every IceCubeCache
    SELECTION = 'total'      # forwarded as IceCubeCache(selection=...); meaning defined in src.dataset
    L = 192                  # IceCubeCache(L=...) for training (presumably a max sequence length — confirm)
    L_VALID = 192            # IceCubeCache(L=...) for validation
    NUM_WORKERS = 8          # workers for the training DataLoader only; validation uses num_workers=0

    # ---- batching & schedule ----------------------------------------------
    BS = 3 * 1024            # training batch size given to LenMatchBatchSampler
    # NOTE(review): train() builds GradientAccumulation(n_acc=4096 // BS); with
    # BS = 3072 that evaluates to 1 — confirm this is the intended accumulation.
    BS_VALID = 3 * 1024      # validation batch size
    EPOCHS = 8               # not read in the visible part of train(); presumably used by the fit call
    SEED = 2023              # not read in the visible code; presumably meant for seed_everything

    # ---- model ------------------------------------------------------------
    MODEL = DeepIceModel
    MODEL_KWARGS = dict(dim=384, dim_base=128, depth=8, head_size=32)  # splatted into MODEL(**MODEL_KWARGS)
    # Path to a state_dict loaded into the model before training; any falsy
    # value skips loading. The spelling "WEITHS" is kept because train()
    # reads config.WEITHS.
    WEITHS = False

    # ---- objective --------------------------------------------------------
    LOSS_FUNC = loss_vms     # Learner(loss_func=...)
    METRIC = loss            # reported through Learner(metrics=[...])

    # ---- output -----------------------------------------------------------
    OUT = 'BASELINE'         # passed to Learner(path=...)

In [None]:
def train(config):
    ds_train = IceCubeCache(config.PATH, mode='train', L=config.L, selection=config.SELECTION,reduce_size=0.125)
    ds_train_len = IceCubeCache(config.PATH, mode='train', L=config.L, reduce_size=0.125, selection=config.SELECTION, mask_only=True)
    sampler_train = RandomChunkSampler(ds_train_len, chunks=ds_train.chunks)
    len_sampler_train = LenMatchBatchSampler(sampler_train, batch_size=config.BS, drop_last=True)
    dl_train = DeviceDataLoader(torch.utils.data.DataLoader(ds_train, 
                batch_sampler=len_sampler_train, num_workers=config.NUM_WORKERS, persistent_workers=True))

    ds_val = IceCubeCache(config.PATH, mode='eval', L=config.L_VALID, selection=config.SELECTION)
    ds_val_len = IceCubeCache(config.PATH, mode='eval', L=config.L_VALID, selection=config.SELECTION, mask_only=True)
    sampler_val = torch.utils.data.SequentialSampler(ds_val_len)
    len_sampler_val = LenMatchBatchSampler(sampler_val, batch_size=config.BS_VALID, drop_last=False)
    dl_val= DeviceDataLoader(torch.utils.data.DataLoader(ds_val, batch_sampler=len_sampler_val,
                num_workers=0))


    data = DataLoaders(dl_train,dl_val)
    model = config.MODEL(**config.MODEL_KWARGS)
    if config.WEITHS:
        print('Loading weights from ...',config.WEITHS)
        model.load_state_dict(torch.load(config.WEITHS))
    model = nn.DataParallel(model)
    model = model.cuda()
    learn = Learner(data,
                    model,  
                    path = config.OUT, 
                    loss_func=config.LOSS_FUNC,
                    cbs=[GradientClip(3.0),
                        CSVLogger(),
                        SaveModelCallback(monitor='loss',comp=np.less,every_epoch=True),
                        GradientAccumulation(n_acc=4096//config.BS)],
                        metrics=[config.METRIC], 
                        opt_func=partial(WrapperAdamW,eps=1e-7)).to_fp16()