In [10]:
import warnings
from typing import Any, Dict, Tuple

import torch
from torchmeta.toy import Sinusoid
from torchmeta.utils.data import BatchMetaDataLoader

warnings.filterwarnings(action='ignore')

In [11]:
def get_dataloader(
    config: Dict[str, Any]
) -> Tuple[BatchMetaDataLoader, BatchMetaDataLoader, BatchMetaDataLoader]:
    """Build train / validation / test meta-dataloaders for Sinusoid regression.

    Every task draws ``2 * config["num_shots"]`` points, so the caller can split
    each task into a support set and a query set of ``num_shots`` points each.
    All three datasets are created with ``noise_std=None``.

    Args:
        config: Dict providing integer entries ``"num_shots"``, ``"batch_size"``,
            ``"num_batches_train"``, ``"num_batches_val"`` and ``"num_batches_test"``.

    Returns:
        ``(train_dataloader, val_dataloader, test_dataloader)``. The split named
        ``s`` contains ``config[f"num_batches_{s}"] * config["batch_size"]``
        tasks, so it yields ``config[f"num_batches_{s}"]`` batches of
        ``config["batch_size"]`` tasks.
    """

    def _make_loader(num_batches_key: str) -> BatchMetaDataLoader:
        # One Sinusoid dataset sized so that batching by config["batch_size"]
        # gives exactly config[num_batches_key] batches.
        dataset = Sinusoid(
            # x2: each task is later split into a support set and a query set.
            num_samples_per_task=config["num_shots"] * 2,
            num_tasks=config[num_batches_key] * config["batch_size"],
            noise_std=None,
        )
        return BatchMetaDataLoader(dataset, batch_size=config["batch_size"])

    # Built in the same order as before: train, then validation, then test.
    train_dataloader = _make_loader("num_batches_train")
    val_dataloader = _make_loader("num_batches_val")
    test_dataloader = _make_loader("num_batches_test")
    return train_dataloader, val_dataloader, test_dataloader

In [12]:
import torch

# Meta-supervised regression tasks have no notion of "N-way" (in few-shot
# classification N is the number of classes); only the number of shots matters.
config = {
    "folder_name": "dataset",  # not read by get_dataloader (Sinusoid data is generated, not loaded)
    "num_shots": 5,            # points per support set (and per query set) in each task
    "batch_size": 3,           # tasks per meta-batch
    "num_batches_train": 6000,
    "num_batches_test": 2000,
    "num_batches_val": 100,
    # "cuda" when a GPU is available, otherwise "cpu".
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
}

train_dataloader, val_dataloader, test_dataloader = get_dataloader(config)

In [16]:
# Peek at the first meta-batch only. Each task holds 2 * num_shots points:
# the first num_shots form the support set, the remaining ones the query set.
batch = next(iter(train_dataloader))
xs, ys = batch  # the batch is an [inputs, targets] pair of tensors
print(xs.shape)  # e.g. torch.Size([3, 10, 1]) -> (batch_size, 2 * num_shots, 1)

support_xs = xs[:, : config["num_shots"], :].to(device=config["device"], dtype=torch.float)
query_xs = xs[:, config["num_shots"]:, :].to(device=config["device"], dtype=torch.float)
support_ys = ys[:, : config["num_shots"], :].to(device=config["device"], dtype=torch.float)
query_ys = ys[:, config["num_shots"]:, :].to(device=config["device"], dtype=torch.float)

# One concatenated string: passing several arguments to print() would join
# them with " " and indent lines 2-4 by a stray space.
print(
    f"support_x shape : {support_xs.shape}\n"  # e.g. [3, 5, 1] -> (batch_size, num_shots, 1 = x coordinate)
    f"support_y shape : {support_ys.shape}\n"  # e.g. [3, 5, 1] -> (batch_size, num_shots, 1 = y coordinate)
    f"query_x shape   : {query_xs.shape}\n"    # e.g. [3, 5, 1]
    f"query_y shape   : {query_ys.shape}"      # e.g. [3, 5, 1]
)

torch.Size([3, 10, 1])
support_x shape : torch.Size([3, 5, 1])
 support_y shape : torch.Size([3, 5, 1])
 query_x shape   : torch.Size([3, 5, 1])
 query_y shape   : torch.Size([3, 5, 1])
tensor([4.5216])
tensor([-0.3021])
