In [1]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('/opt/slh/icecube/')

import os
os.environ["CUDA_VISIBLE_DEVICES"]="1,0"
#os.environ["CUDA_VISIBLE_DEVICES"]="1"
#os.environ["NCCL_P2P_DISABLE"] = "1"

In [6]:
import polars as pl
import pandas as pd
import gc
import os
import numpy as np
from icecube.fastai_fix import *
from tqdm.notebook import tqdm
from icecube.data_train_v3 import RandomChunkSampler,LenMatchBatchSampler,IceCubeCache, DeviceDataLoader
from icecube.loss import loss, loss_vms
from icecube.models import EncoderWithDirectionReconstructionV22
from fastxtend.vision.all import EMACallback
from tqdm import tqdm

In [7]:
SELECTION = 'total'
OUT = 'V22'
PATH = '../data/'

NUM_WORKERS = 24
SEED = 2023
bs = 1024 - 256
L = 192

def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(SEED)
os.makedirs(OUT, exist_ok=True)

In [8]:
def WrapperAdamW(param_groups,**kwargs):
    return OptimWrapper(param_groups,torch.optim.AdamW)




def load_matching_weights(model, weights_path):
    """
    Load model weights from a given path if they match, otherwise skip.
    Prints the number of matched and unmatched weights.

    :param model: The PyTorch model for which weights should be loaded.
    :param weights_path: The path to the saved weights file (.pth or .pt).
    """
    # Load the saved state dictionary
    saved_state_dict = torch.load(weights_path, map_location=torch.device('cpu'))

    # Get the model's state dictionary
    model_state_dict = model.state_dict()

    # Create a new state dictionary to store matching weights
    matching_state_dict = {}

    # Initialize counters for matched and unmatched weights
    matched_weights = 0
    unmatched_weights = 0

    # Iterate through the saved state dictionary
    for name, saved_weight in saved_state_dict.items():
        # Check if the name exists in the model's state dictionary and the shapes match
        if name in model_state_dict and model_state_dict[name].shape == saved_weight.shape:
            # If it matches, add it to the matching state dictionary
            matching_state_dict[name] = saved_weight
            matched_weights += 1
        else:
            print(f"Skipping weight: {name} - Shape mismatch or not found in model")
            unmatched_weights += 1

    # Update the model's state dictionary with the matching state dictionary
    model_state_dict.update(matching_state_dict)

    # Load the updated state dictionary into the model
    model.load_state_dict(model_state_dict)

    print(f"Matched weights: {matched_weights}")
    print(f"Unmatched weights: {unmatched_weights}")





In [None]:
fname = OUT

ds_train = IceCubeCache(PATH, mode='train', L=L, selection=SELECTION,reduce_size=0.125)
ds_train_len = IceCubeCache(PATH, mode='train', L=L, reduce_size=0.125, selection=SELECTION, mask_only=True)
sampler_train = RandomChunkSampler(ds_train_len, chunks=ds_train.chunks)
len_sampler_train = LenMatchBatchSampler(sampler_train, batch_size=bs, drop_last=True)
dl_train = DeviceDataLoader(torch.utils.data.DataLoader(ds_train, 
            batch_sampler=len_sampler_train, num_workers=NUM_WORKERS, persistent_workers=True))

ds_val = IceCubeCache(PATH, mode='eval', L=L, selection=SELECTION)
ds_val_len = IceCubeCache(PATH, mode='eval', L=L, selection=SELECTION, mask_only=True)
sampler_val = torch.utils.data.SequentialSampler(ds_val_len)
len_sampler_val = LenMatchBatchSampler(sampler_val, batch_size=bs, drop_last=False)
dl_val= DeviceDataLoader(torch.utils.data.DataLoader(ds_val, batch_sampler=len_sampler_val,
            num_workers=0))


data = DataLoaders(dl_train,dl_val)
model = EncoderWithDirectionReconstructionV22(dim=384, dim_base=128, depth=8, head_size=32)
#load_matching_weights(model, '/opt/slh/icecube/hb_training_loop/V20FT3/models/model_7.pth')
model = nn.DataParallel(model)
model = model.cuda()
learn = Learner(data,
                model,  
                path = OUT, 
                loss_func=loss_vms,
                cbs=[GradientClip(3.0),
                     CSVLogger(),
                     SaveModelCallback(monitor='loss',comp=np.less,every_epoch=True),
                     GradientAccumulation(n_acc=4096//bs)],
                     metrics=[loss], 
                     opt_func=partial(WrapperAdamW,eps=1e-7)).to_fp16()




In [None]:
learn.fit_one_cycle(8, lr_max=5e-4, wd=0.05, pct_start=0.01)

epoch,train_loss,valid_loss,loss,time
0,1.433667,1.432677,1.001155,2:57:28
1,1.391384,1.422149,0.995766,2:57:50
2,1.356437,1.399272,0.99036,2:57:42
3,1.341106,1.382831,0.986583,2:57:33
4,1.341254,1.366049,0.98249,2:57:27


  L = max(1,L // 16)
  if not isinstance(inputs, collections.Container) or isinstance(inputs, torch.Tensor):
IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

