In [1]:
import polars as pl
import pandas as pd
import gc
import os
import numpy as np
from fastai_fix import *
from tqdm.notebook import tqdm
from data_train_v3 import RandomChunkSampler,LenMatchBatchSampler,IceCubeCache, DeviceDataLoader
#from model_transformer_base import DeepIceModel
from loss import loss, loss_vms, loss_comb
from baselineV3_SE_globalrel_d32_2 import DeepIceModel as TransformerV3_2
from fastxtend.vision.all import EMACallback

In [2]:
# !pip install polars==0.16.8
# !pip install pyarrow
# !pip install fastxtend
# !pip install kornia

In [3]:
# --- experiment identity & data location ---
SELECTION = 'total'
OUT = 'baselineV3_BE_globalrel_d64_0_3emaFT'  # created below; also the Learner's `path` (logs/checkpoints)
PATH = 'data/'

# --- runtime ---
NUM_WORKERS = 16
SEED = 2023

# --- training loader: forwarded as IceCubeCache(L=...) / LenMatchBatchSampler(batch_size=...) ---
bs = 608  # = 512 + 64 + 32
L = 256   # NOTE: shadows fastcore's `L` from the star import in this namespace

# --- validation loader: larger L, smaller batches ---
L_VALID = 512
bs_VALID = 384  # = 256 + 128

def seed_everything(seed, deterministic=True):
    """
    Seed every RNG the run touches and configure cuDNN for reproducibility.

    :param seed: integer seed applied to Python's `random`, NumPy's global RNG,
        and PyTorch's CPU and all-CUDA-device generators.
    :param deterministic: when True (default), force deterministic cuDNN kernels
        and disable cuDNN benchmark autotuning. Autotuning selects algorithms by
        timing and can therefore pick different, nondeterministic kernels between
        runs. Pass False to trade reproducibility for speed (benchmark on).
    """
    import random  # explicit: otherwise only available via the star import

    random.seed(seed)
    # Only affects child processes (e.g. DataLoader workers started later);
    # the current interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one (the model runs under DataParallel).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = deterministic
    # benchmark=True contradicts deterministic=True, so keep them opposite.
    torch.backends.cudnn.benchmark = not deterministic

# Ensure the output directory exists (no-op if it already does), then seed all RNGs.
os.makedirs(name=OUT, exist_ok=True)
seed_everything(seed=SEED)

In [4]:
def WrapperAdamW(param_groups, **kwargs):
    """
    fastai-compatible `opt_func` that builds `torch.optim.AdamW` inside fastai's `OptimWrapper`.

    :param param_groups: parameters / parameter groups handed over by the Learner.
    :param kwargs: forwarded to the AdamW constructor via `OptimWrapper`
        (e.g. the `eps` bound with `partial(WrapperAdamW, eps=1e-7)`).
        Previously these were discarded, so `eps` silently stayed at AdamW's default.
    :returns: an `OptimWrapper` around a freshly constructed AdamW optimizer.
    """
    return OptimWrapper(param_groups, torch.optim.AdamW, **kwargs)




def load_matching_weights(model, weights_path):
    """
    Copy into `model` every checkpoint tensor whose key exists in the model with
    the same shape; all other checkpoint entries are skipped.

    Prints two counts: checkpoint entries that were loaded, and checkpoint entries
    that were skipped (key absent from the model, or shape mismatch). Model
    parameters the checkpoint does not mention keep their current values.

    :param model: The PyTorch module that receives the weights.
    :param weights_path: Path to a saved state dict (.pth or .pt).
    """
    # Deserialize onto the CPU regardless of the device the checkpoint was saved from.
    # NOTE(review): torch.load unpickles arbitrary objects; only use trusted files.
    checkpoint = torch.load(weights_path, map_location=torch.device('cpu'))
    current = model.state_dict()

    def _compatible(key, tensor):
        # Same parameter name and identical shape.
        return key in current and current[key].shape == tensor.shape

    accepted = {key: tensor for key, tensor in checkpoint.items() if _compatible(key, tensor)}
    skipped = len(checkpoint) - len(accepted)

    # Overlay onto the model's full state so the (strict) load sees every expected key.
    current.update(accepted)
    model.load_state_dict(current)

    print(f"Matched weights: {len(accepted)}")
    print(f"Unmatched weights: {skipped}")





In [5]:
# Alternative starting checkpoints tried earlier (kept for provenance):
#   load_matching_weights(model, '/opt/slh/icecube/hb_training_loop/V22FT5/models/model_7.pth')
#   load_matching_weights(model, '/opt/slh/icecube/hb_training_loop/V23/baselineV3_BE_globalrel_d64_0_2.pth')

fname = OUT

# --- training data: 'train' split with reduce_size=0.125, random chunked, length-matched batches ---
train_kwargs = dict(mode='train', L=L, selection=SELECTION, reduce_size=0.125)
ds_train = IceCubeCache(PATH, **train_kwargs)
ds_train_len = IceCubeCache(PATH, mask_only=True, **train_kwargs)
sampler_train = RandomChunkSampler(ds_train_len, chunks=ds_train.chunks)
len_sampler_train = LenMatchBatchSampler(sampler_train, batch_size=bs, drop_last=True)
dl_train = DeviceDataLoader(
    torch.utils.data.DataLoader(
        ds_train,
        batch_sampler=len_sampler_train,
        num_workers=NUM_WORKERS,
        persistent_workers=True,
    )
)

# --- validation data: 'eval' split, sequential order, length-matched batches, in-process loading ---
val_kwargs = dict(mode='eval', L=L_VALID, selection=SELECTION)
ds_val = IceCubeCache(PATH, **val_kwargs)
ds_val_len = IceCubeCache(PATH, mask_only=True, **val_kwargs)
sampler_val = torch.utils.data.SequentialSampler(ds_val_len)
len_sampler_val = LenMatchBatchSampler(sampler_val, batch_size=bs_VALID, drop_last=False)
dl_val = DeviceDataLoader(
    torch.utils.data.DataLoader(ds_val, batch_sampler=len_sampler_val, num_workers=0)
)

data = DataLoaders(dl_train, dl_val)

# --- model: fine-tune from an existing checkpoint (name suggests EMA weights of the previous stage) ---
model = TransformerV3_2(dim=768, dim_base=192, depth=12, head_size=64)
# NOTE(review): no map_location, so tensors are restored onto the device they were saved from.
model.load_state_dict(torch.load('baselineV3_BE_globalrel_d64_0_3ema.pth'))
model = nn.DataParallel(model)

callbacks = [
    GradientClip(3.0),
    CSVLogger(),
    EMACallback(),
    SaveModelCallback(monitor='loss', comp=np.less, every_epoch=True),
    # NOTE(review): in stock fastai, n_acc counts *samples*, not batches. With bs=608,
    # 4096//bs == 6 is smaller than a single batch, so nothing is accumulated;
    # n_acc=4096 may be what was intended. Confirm against fastai_fix before changing,
    # because it alters the training dynamics.
    GradientAccumulation(n_acc=4096 // bs),
]

learn = Learner(
    data,
    model,
    path=OUT,
    loss_func=loss_comb,
    cbs=callbacks,
    metrics=[loss],
    opt_func=partial(WrapperAdamW, eps=1e-7),
).to_fp16()




In [None]:
# Fine-tune for 4 epochs with plain `fit` (no one-cycle schedule), lr = 2e-7, weight decay 0.05.
# NOTE: the saved log below lists only epochs 0-2 and this cell has no execution count,
# so the run appears not to have completed all 4 epochs.
learn.fit(n_epoch=4, lr=2e-7, wd=0.05)

epoch,train_loss,valid_loss,loss,time
0,0.997746,1.030024,0.96407,4:20:03
1,1.005598,1.029966,0.964018,4:19:29
2,1.005712,1.029867,0.963947,4:18:12


  L = max(1,L // 16)
