In [1]:
import warnings
from typing import Any, Dict, Tuple

import matplotlib.pyplot as plt
from torchmeta.datasets.helpers import omniglot
from torchmeta.utils.data import BatchMetaDataLoader

warnings.filterwarnings(action='ignore')

In [2]:
def get_dataloader(
    config: Dict[str, Any]
) -> Tuple[BatchMetaDataLoader, BatchMetaDataLoader, BatchMetaDataLoader]:
    """Build the meta-train, meta-validation and meta-test Omniglot loaders.

    Args:
        config: Settings dictionary. The keys read here are
            ``"folder_name"``, ``"num_shots"``, ``"num_ways"``,
            ``"download"`` and ``"batch_size"``.

    Returns:
        A ``(train_dataloader, val_dataloader, test_dataloader)`` tuple, one
        ``BatchMetaDataLoader`` per meta-split.

    Raises:
        KeyError: If one of the keys listed above is missing from ``config``.
    """

    def _build_loader(**split_flag: bool) -> BatchMetaDataLoader:
        # One dataset + loader per split; only the meta_* flag differs.
        dataset = omniglot(
            folder=config["folder_name"],
            shots=config["num_shots"],
            ways=config["num_ways"],
            shuffle=True,
            download=config["download"],
            **split_flag,
        )
        return BatchMetaDataLoader(
            dataset, batch_size=config["batch_size"], shuffle=True, num_workers=1
        )

    train_dataloader = _build_loader(meta_train=True)
    val_dataloader = _build_loader(meta_val=True)
    # The test loader must request the meta-test split; it previously passed
    # meta_val=True and therefore duplicated the validation split.
    test_dataloader = _build_loader(meta_test=True)

    return train_dataloader, val_dataloader, test_dataloader

In [3]:
import torch

# Settings dictionary handed to get_dataloader and read by the cells below.
config = {
    "folder_name": "dataset",
    "download": True,  # set to False once the dataset has already been downloaded
    "num_shots": 2,
    "num_ways": 5,
    "batch_size": 3,
    "num_batches_train": 6000,
    "num_batches_test": 2000,
    "num_batches_val": 100,
    # Use the GPU when CUDA is available, otherwise fall back to the CPU.
    # (Both branches used to be "cpu", so the GPU was never selected.)
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
}

train_dataloader, val_dataloader, test_dataloader = get_dataloader(config)

Downloading https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_background.zip to dataset\omniglot\images_background.zip


  0%|          | 0/9464212 [00:00<?, ?it/s]

Downloading https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_evaluation.zip to dataset\omniglot\images_evaluation.zip


  0%|          | 0/6462886 [00:00<?, ?it/s]

In [12]:
# Pull a single batch of tasks from the meta-train loader and report the
# shapes of its support and query tensors.
for batch_idx, batch in enumerate(train_dataloader):
    if batch_idx >= config["num_batches_train"]:
        break

    # batch['train'] is the support set and batch['test'] the query set;
    # index 0 holds the inputs and index 1 the labels.
    support_xs = batch['train'][0].to(device=config["device"])
    support_ys = batch['train'][1].to(device=config["device"])
    query_xs = batch['test'][0].to(device=config["device"])
    query_ys = batch['test'][1].to(device=config["device"])

    # sep="\n" puts each shape on its own line. Passing "\n"-terminated
    # strings with the default separator left a stray leading space on every
    # line after the first, which broke the column alignment.
    print(
        f"support_x shape : {support_xs.shape}",  # [3, 10, 1, 28, 28] -> (batch_size, num_shots*num_ways, C, H, W)
        f"support_y shape : {support_ys.shape}",  # [3, 10] -> (batch_size, num_shots*num_ways); labels are 0-4, i.e. only num_ways distinct values
        f"query_x shape   : {query_xs.shape}",    # [3, 10, 1, 28, 28]
        f"query_y shape   : {query_ys.shape}",    # [3, 10]
        sep="\n",
    )

    # Only the first batch is inspected.
    break
    

tensor(0)
tensor(0)
tensor(0)
tensor(4)
tensor(3)
tensor(3)
tensor(1)
tensor(4)
tensor(1)
tensor(2)
tensor(4)
tensor(2)
tensor(1)
tensor(3)
tensor(2)
tensor(1)
tensor(3)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(1)
tensor(1)
tensor(0)
tensor(1)
tensor(3)
tensor(1)
tensor(3)
tensor(0)
tensor(4)
tensor(2)
tensor(0)
tensor(1)
tensor(1)
tensor(2)
tensor(3)
tensor(1)
tensor(1)
tensor(4)
tensor(2)
tensor(3)
tensor(2)
tensor(1)
tensor(4)
tensor(4)
tensor(3)
tensor(3)
tensor(2)
tensor(4)
tensor(1)
tensor(4)
tensor(0)
tensor(3)
tensor(3)
tensor(3)
tensor(4)
tensor(0)
tensor(4)
tensor(4)
tensor(0)
tensor(4)
tensor(2)
tensor(0)
tensor(1)
tensor(0)
tensor(1)
tensor(1)
tensor(0)
tensor(4)
tensor(3)
tensor(1)
tensor(4)
tensor(3)
tensor(0)
tensor(4)
tensor(4)
tensor(1)
tensor(2)
tensor(2)
tensor(1)
tensor(4)
tensor(0)
tensor(1)
tensor(4)
tensor(1)
tensor(3)
tensor(1)
tensor(0)
tensor(4)
tensor(3)
tensor(1)
tensor(0)
tensor(2)
tensor(4)
tensor(4)
tensor(4)
tensor(4)
tensor(1)
tensor(1)
tensor(3)


KeyboardInterrupt: 

In [2]:
# Visualization of the support and query images of each task in the batch.
# Disabled by wrapping the code in a string literal: running it raised a kernel
# error ('Canceled future for execute_request message before replies were done').
# NOTE(review): the cause of that kernel error is not visible here - confirm
# before re-enabling. The code reads support_xs / query_xs from the cell above.
'''for b in range(config["batch_size"]):
    fig = plt.figure(constrained_layout=True, figsize=(18, 4))
    subfigs = fig.subfigures(1, 2, wspace=0.07)

    subfigs[0].set_facecolor("0.75")
    subfigs[0].suptitle("Support set", fontsize="x-large")
    support_axs = subfigs.flat[0].subplots(nrows=2, ncols=5)
    for i, ax in enumerate(support_axs.T.flatten()):
        ax.imshow(support_xs[b][i].permute(1, 2, 0).squeeze(), aspect="auto")

    subfigs[1].set_facecolor("0.75")
    subfigs[1].suptitle("Query set", fontsize="x-large")
    query_axes = subfigs.flat[1].subplots(nrows=2, ncols=5)
    for i, ax in enumerate(query_axes.T.flatten()):
        ax.imshow(query_xs[b][i].permute(1, 2, 0).squeeze(), aspect="auto")

    fig.suptitle("Batch " + str(b), fontsize="xx-large")

    plt.show()'''

'for b in range(config["batch_size"]):\n    fig = plt.figure(constrained_layout=True, figsize=(18, 4))\n    subfigs = fig.subfigures(1, 2, wspace=0.07)\n\n    subfigs[0].set_facecolor("0.75")\n    subfigs[0].suptitle("Support set", fontsize="x-large")\n    support_axs = subfigs.flat[0].subplots(nrows=2, ncols=5)\n    for i, ax in enumerate(support_axs.T.flatten()):\n        ax.imshow(support_xs[b][i].permute(1, 2, 0).squeeze(), aspect="auto")\n\n    subfigs[1].set_facecolor("0.75")\n    subfigs[1].suptitle("Query set", fontsize="x-large")\n    query_axes = subfigs.flat[1].subplots(nrows=2, ncols=5)\n    for i, ax in enumerate(query_axes.T.flatten()):\n        ax.imshow(query_xs[b][i].permute(1, 2, 0).squeeze(), aspect="auto")\n\n    fig.suptitle("Batch " + str(b), fontsize="xx-large")\n\n    plt.show()'