In [None]:
from torch.utils.data import DataLoader, random_split
import torch

def data_split(dataset, split_ratio, seed=42):
    """Randomly split a dataset into train, validation and test subsets.

    Args:
        dataset: Sized, indexable dataset to split (anything
            ``random_split`` accepts).
        split_ratio: Fractions ``[train, val, test]``, e.g. ``[0.8, 0.1, 0.1]``.
            Only the train and val fractions are read; the test subset
            receives every remaining sample, so truncation never drops data.
        seed: Seed for the generator that drives the random permutation.
            Defaults to 42.

    Returns:
        The ``[train_ds, val_ds, test_ds]`` subsets produced by
        ``random_split``.

    Raises:
        ValueError: If the train or val fraction is negative, or if together
            they claim more samples than the dataset holds.
    """
    train_ratio, val_ratio = split_ratio[0], split_ratio[1]
    if train_ratio < 0 or val_ratio < 0:
        raise ValueError(
            f"split_ratio entries must be non-negative, got {list(split_ratio)}"
        )

    total_count = len(dataset)
    n_train = int(total_count * train_ratio)
    n_val = int(total_count * val_ratio)
    n_test = total_count - n_train - n_val

    # A negative remainder still makes the lengths sum to len(dataset), so
    # random_split would not complain and would return wrongly sized subsets.
    if n_test < 0:
        raise ValueError(
            "train and val ratios exceed the dataset size: "
            f"{n_train} + {n_val} > {total_count}"
        )

    # Dedicated generator keeps the split reproducible without touching the
    # global RNG state.
    g = torch.Generator().manual_seed(seed)

    return random_split(dataset, [n_train, n_val, n_test], generator=g)

def create_loader(dataset, batch_size, seed_worker=None, is_train=True, shuffle=False, seed=42, num_workers=4):
    """Build a ``DataLoader`` for training or evaluation.

    Args:
        dataset: Dataset to load from.
        batch_size: Number of samples per batch.
        seed_worker: Optional ``worker_init_fn`` used by the training loader
            to seed each worker process. Ignored when ``is_train`` is False.
        is_train: When True, build the seeded training loader; otherwise
            build a plain, never-shuffled evaluation loader.
        shuffle: Whether the training loader shuffles. Ignored when
            ``is_train`` is False.
        seed: Seed for the training loader's generator.
        num_workers: Number of worker processes; 0 loads in the main process.

    Returns:
        DataLoader: The configured loader.
    """
    if not is_train:
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            worker_init_fn=None,
        )

    g = torch.Generator()
    g.manual_seed(seed)

    loader_kwargs = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "worker_init_fn": seed_worker,
        "generator": g,
        "num_workers": num_workers,
        "pin_memory": True,
    }
    # prefetch_factor only applies to multi-process loading; recent PyTorch
    # releases raise ValueError if it is set while num_workers == 0.
    if num_workers > 0:
        loader_kwargs["prefetch_factor"] = 2

    return DataLoader(dataset, **loader_kwargs)