In [2]:
import math
import copy
import scipy.special
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import sys
sys.path.insert(0, "/home/rob/Git/meta-fsl-nas/metanas")

import metanas.utils.genotypes as gt

In [8]:
class Model(nn.Module):
    """Model for a single DARTS cell"""
    def __init__(self, n_ops=7, n_nodes=3):
        self.n_ops = len(gt.PRIMITIVES_FEWSHOT)
        self.n_nodes = 3

        
        self.encoded_states = []
        self.states = []
        self.topk = []
        
        self.alphas = []
        self.norm_alphas = []

        # Adjacency matrix
        self.A = np.ones((self.n_nodes+2, self.n_nodes+2)) - np.eye(self.n_nodes+2)

        # Remove the 2 input nodes from A
        self.A[0, 1] = 0
        self.A[1, 0] = 0

        for i in range(n_nodes):
            a = nn.Parameter(
                1e-3 * torch.randn(i + 2, n_ops))
            self.alphas.append(a)
            self.norm_alphas.append(F.softmax(a, dim=-1))

    def print_topk(self):
        for i, edges in enumerate(self.norm_alphas):
            # edges: Tensor(n_edges, n_ops)
            edge_max, _ = torch.topk(edges[:, :], 1)
            # selecting the top-k input nodes, k=2
            _, topk_edge_indices = torch.topk(edge_max.view(-1), k=2)
            
            print(topk_edge_indices)
    
    def parse(self, alpha, k=2, primitives=gt.PRIMITIVES_FEWSHOT):
        gene = []
        for edges in alpha:
            edge_max, primitive_indices = torch.topk(
                edges[:, :], 1
            )
            
#             print(edges[:,:], primitive_indices, edge_max, "\n")

            topk_edge_values, topk_edge_indices = torch.topk(
                edge_max.view(-1), k)

#             print(topk_edge_values, topk_edge_indices, "\n")
            
            node_gene = []
            for edge_idx in topk_edge_indices:
                prim_idx = primitive_indices[edge_idx]
                prim = primitives[prim_idx]
                node_gene.append((prim, edge_idx.item()))

            gene.append(node_gene)
        return gene
    
    def calculate_states(self):
        s_idx = 0
        
#         print(self.states)
#         if current_states is not None:
        prev_topk = copy.deepcopy(self.topk)
        prev_edge = copy.deepcopy(self.encoded_states)
        
        self.topk = []
        self.states = []
        self.encoded_states = []
        self.edge_to_index = {}
        self.edge_to_alpha = {}

        for i, edges in enumerate(self.norm_alphas):
            # edges: Tensor(n_edges, n_ops)
            edge_max, edge_idx = torch.topk(edges[:, :], 1)
            
            # selecting the top-k input nodes, k=2
            _, topk_edge_indices = torch.topk(edge_max.view(-1), k=2)

            edge_one_hot = torch.zeros_like(edges[:,:])
            
            
            for hot_e, op in zip(edge_one_hot, edge_idx):
                hot_e[op.item()] = 1

            for j, edge in enumerate(edge_one_hot):
                self.edge_to_index[(j, i+2)] = s_idx
                self.edge_to_index[(i+2, j)] = s_idx+1

                self.edge_to_alpha[(j, i+2)] = (i, j)
                self.edge_to_alpha[(i+2, j)] = (i, j)

                self.encoded_states.append(edge.numpy())
                
                # For undirected edge we add the edge twice
                self.states.append([
                        f"from:{j} to:{i+2}",
                        int(j in topk_edge_indices)])
                
                self.topk.append([
                        f"from:{j} to:{i+2}",
                        int(j in topk_edge_indices)])

#                 self.states.append((
#                         (f"from:{i+2}",
#                         f"to:{j}"),
#                         [int(j in topk_edge_indices)]))
                s_idx += 2
    
        d = {'prev_topk': np.array(prev_topk),
             'prev_edges': np.array(prev_edge)}
    
        self.encoded_states = np.array(self.encoded_states)
#         change = (np.array(prev_topk) < np.array(self.topk))
        return d, self.states
    
    def _inverse_softmax(self, x, C):
        return torch.log(x) + C
    
    def increase_op(self, cur_node, next_node, op_idx, prob=0.1, n_ops=7):
#         t_max = 5.0
#         t_min = 0.1
#         max_step = 6
#         curr_step = 1
#         # Temperature
#         temp = t_max - curr_step * (t_max - t_min)/max_step-1

        C = math.log(10.)

        row_idx, edge_idx = self.edge_to_alpha[(cur_node, next_node)]
        
        # Set short-hands
        curr_op = self.norm_alphas[row_idx][edge_idx][op_idx]
        curr_edge = self.norm_alphas[row_idx][edge_idx]
        
        
        # Allow for increasing to 0.99
        if curr_op + prob > 1.0:
            surplus = curr_op + prob - 0.99
            prob -= surplus

        if curr_op + prob < 1.0:
            # Increase chosen op
            with torch.no_grad():
                curr_op += prob

            # Prevent 0.00 normalized alpha values, resulting in
            # -inf
            with torch.no_grad():
                curr_edge += 0.01

            # Set the meta-model, update the env state in
            # self.update_states()
            with torch.no_grad():
                self.alphas[
                    row_idx][edge_idx] = self._inverse_softmax(
                    curr_edge, C)
        
        # /temp
        self.norm_alphas = [
            F.softmax(alpha, dim=-1).detach().cpu()
            for alpha in self.alphas]
    
    def decrease_op(self, cur_node, next_node, op_idx, prob=0.1, n_ops=7):
        C = math.log(10.)

        row_idx, edge_idx = self.edge_to_alpha[(cur_node, next_node)]
        
        # Set short-hands
        curr_op = self.norm_alphas[row_idx][edge_idx][op_idx]
        curr_edge = self.norm_alphas[row_idx][edge_idx]
        
        
        # Allow for increasing to 0.99
        if curr_op - prob < 0.0:
            surplus = prob - curr_op + 0.01
#             print(surplus)
            prob -= surplus
#             print(prob)

        if curr_op - prob > 0.0:
            # Increase chosen op
            with torch.no_grad():
                curr_op -= prob
                
            # Prevent 0.00 normalized alpha values, resulting in
            # -inf
            with torch.no_grad():
                curr_edge += 0.01
            
            # Set the meta-model, update the env state in
            # self.update_states()
            with torch.no_grad():
                self.alphas[
                    row_idx][edge_idx] = self._inverse_softmax(
                    curr_edge, C)
            
        self.norm_alphas = [
            F.softmax(alpha, dim=-1).detach().cpu()
            for alpha in self.alphas]

In [19]:
model = Model()

print("primitives:", gt.PRIMITIVES_FEWSHOT, "\n")
model.print_topk(), model.parse(model.norm_alphas, k=2)

_, b = model.calculate_states()
# print("init:", b)

# print(model.norm_alphas[1])

# d, _ = model.calculate_states()

primitives: ['max_pool_3x3', 'avg_pool_3x3', 'skip_connect', 'conv_1x5_5x1', 'conv_3x3', 'sep_conv_3x3', 'dil_conv_3x3'] 

tensor([1, 0])
tensor([0, 2])
tensor([2, 0])


In [20]:
model.norm_alphas

[tensor([[0.1428, 0.1427, 0.1429, 0.1430, 0.1430, 0.1426, 0.1430],
         [0.1427, 0.1428, 0.1428, 0.1430, 0.1430, 0.1430, 0.1427]],
        grad_fn=<SoftmaxBackward>),
 tensor([[0.1428, 0.1431, 0.1426, 0.1429, 0.1430, 0.1429, 0.1428],
         [0.1429, 0.1429, 0.1430, 0.1427, 0.1429, 0.1429, 0.1427],
         [0.1427, 0.1429, 0.1430, 0.1428, 0.1429, 0.1429, 0.1428]],
        grad_fn=<SoftmaxBackward>),
 tensor([[0.1429, 0.1428, 0.1428, 0.1431, 0.1427, 0.1430, 0.1427],
         [0.1429, 0.1429, 0.1429, 0.1428, 0.1429, 0.1427, 0.1429],
         [0.1429, 0.1429, 0.1432, 0.1426, 0.1428, 0.1428, 0.1428],
         [0.1430, 0.1429, 0.1428, 0.1428, 0.1431, 0.1425, 0.1429]],
        grad_fn=<SoftmaxBackward>)]

In [21]:
model.topk

[['from:0 to:2', 1],
 ['from:1 to:2', 1],
 ['from:0 to:3', 1],
 ['from:1 to:3', 0],
 ['from:2 to:3', 1],
 ['from:0 to:4', 1],
 ['from:1 to:4', 0],
 ['from:2 to:4', 1],
 ['from:3 to:4', 0]]

In [22]:
# model.increase_op(0, 3, 5)
model.increase_op(0, 3, 5)

model.calculate_states()[1]

[['from:0 to:2', 1],
 ['from:1 to:2', 1],
 ['from:0 to:3', 1],
 ['from:1 to:3', 0],
 ['from:2 to:3', 1],
 ['from:0 to:4', 1],
 ['from:1 to:4', 0],
 ['from:2 to:4', 1],
 ['from:3 to:4', 0]]

In [23]:
model.norm_alphas

[tensor([[0.1428, 0.1427, 0.1429, 0.1430, 0.1430, 0.1426, 0.1430],
         [0.1427, 0.1428, 0.1428, 0.1430, 0.1430, 0.1430, 0.1427]]),
 tensor([[0.1306, 0.1308, 0.1304, 0.1307, 0.1307, 0.2161, 0.1306],
         [0.1429, 0.1429, 0.1430, 0.1427, 0.1429, 0.1429, 0.1427],
         [0.1427, 0.1429, 0.1430, 0.1428, 0.1429, 0.1429, 0.1428]]),
 tensor([[0.1429, 0.1428, 0.1428, 0.1431, 0.1427, 0.1430, 0.1427],
         [0.1429, 0.1429, 0.1429, 0.1428, 0.1429, 0.1427, 0.1429],
         [0.1429, 0.1429, 0.1432, 0.1426, 0.1428, 0.1428, 0.1428],
         [0.1430, 0.1429, 0.1428, 0.1428, 0.1431, 0.1425, 0.1429]])]

In [30]:
for w in self.workers:
    

[[12, 12, 32, 12], [321, 312, 12]]

In [87]:
model.print_topk()

tensor([1, 0])
tensor([1, 2])
tensor([1, 3])


In [25]:

print(model.norm_alphas[1])

tensor([[0.1430, 0.1432, 0.1426, 0.1429, 0.1426, 0.1427, 0.1430],
        [0.1429, 0.1428, 0.1427, 0.1427, 0.1431, 0.1427, 0.1431],
        [0.1426, 0.1429, 0.1428, 0.1428, 0.1430, 0.1429, 0.1430]],
       grad_fn=<SoftmaxBackward>)


In [62]:
for i in range(100):
    a = np.random.randint(4, 5)
    b = np.random.randint(0, len(gt.PRIMITIVES_FEWSHOT))
    model.increase_op(a, 3, b)
    model.increase_op(a, 3, b)

In [64]:

print(model.norm_alphas[1])

tensor([[0.1426, 0.1429, 0.1429, 0.1430, 0.1430, 0.1428, 0.1429],
        [0.0248, 0.0150, 0.0420, 0.5894, 0.0915, 0.0163, 0.2210],
        [0.0150, 0.0914, 0.0150, 0.5899, 0.0248, 0.0429, 0.2210]])


In [85]:

model.increase_op(1, 3, 1)

In [86]:

print(model.norm_alphas[1])

tensor([[0.1426, 0.1429, 0.1429, 0.1430, 0.1430, 0.1428, 0.1429],
        [0.0303, 0.7786, 0.0312, 0.0569, 0.0335, 0.0299, 0.0396],
        [0.0150, 0.0914, 0.0150, 0.5899, 0.0248, 0.0429, 0.2210]])


In [None]:
import os
import time
import copy
import glob
import shelve

import igraph as ig
from igraph import Graph
from PIL import Image

In [None]:
def generate_graph_path(path, last_steps=None, paths_left=5):
    d = shelve.open(path)
    walks = sum(d.values(), [])
    d.close()
    
    if last_steps is not None:
        walks = walks[:last_steps]
        
    # TODO: Starting path might be variable
    path = [(0,2)]
    weights = [1]

    walks_temp = []
    walks_curr = copy.deepcopy(walks)
    max_k = len(walks)

    for i in range(max_k):
        edge_dict = {}
        walks_temp = []

        for j, walk in enumerate(walks_curr):
            # Check if current walk is long enough
            if i >= len(walk):
                continue
            else:
                # Current step
                w = walk[i]

                if w in edge_dict:
                    edge_dict[w] += 1
                else:
                    edge_dict[w] = 1

                walks_temp.append(walk)

        # Stop if the path ended or,
        if len(edge_dict) == 0:
            break
        # Or if only 5 walks are left 
        if sum([v for v in edge_dict.values()]) < paths_left:
            break

        # Step with highest count
        max_edge = max(edge_dict, key=edge_dict.get)
        path.append(max_edge)
        weights.append(edge_dict[max_edge]/(sum(edge_dict.values())))

        for walk in walks_temp:
            if walk[i] != max_edge:
                walks_temp.remove(walk)
        walks_curr = copy.deepcopy(walks_temp)
        
        return path, weights

In [None]:
def generate_gif(path, weights, save_paths, format_path):
    for f in save_paths:
        os.remove(f)

    edges = [(0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]
    edge_color = ["gray"]*len(edges)

    for i, (edge, weight) in enumerate(zip(path, weights)):
        edges = [(0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]
        g = Graph(edges)

        if edge in edges:
            index = edges.index(edge)
        else:
            index = edges.index((edge[1], edge[0]))

        edge_color[index] = 'red'

        lb = [""]*len(edges)
        lb[index] = f"step {i}: {edge[0]} -> {edge[1]}, weight: {weight:.2f}"

        # 5 Nodes
        g.vs["label"] = ["0", "1", "2", "3", "4"]
        g.vs["input"] = [True, True, False, False, False]
        g.es["color"] = edge_color
        g.es["label"] = lb

        ig.plot(
            g, 
            vertex_size=40, 
            edge_width=[3],
            vertex_color=['yellow', 'yellow', 'blue', 'blue', 'blue'],
            target=format_path.format(i),
            bbox=(800, 800),
            margin=200
        )
        edge_color[index] = 'purple'

    frames = []
    for img in sorted(glob.glob(save_paths), key=os.path.getmtime):
        frames.append(Image.open(img))

    frames[0].save('graph_walk.gif', format='GIF', append_images=frames[1:],
        save_all=True, duration=1800, loop=0)

In [None]:
path = "/home/rob/Git/meta-fsl-nas/metanas/results/triplemnist/ppo_metad2a_environment_1/seed_2/graph_walk.shlv"
save_paths = glob.glob("/home/rob/Git/meta-fsl-nas/notebooks/path/*.png")
format_path = "/home/rob/Git/meta-fsl-nas/notebooks/path/{0}.png"

path, weights = generate_graph_path(path)
generate_gif(path, weights, save_paths, format_path)

In [None]:
import cv2 as cv
import glob

In [None]:
def calc_avg_mean_std_dataset(path):
    mean_sum = np.array([0., 0., 0.])
    std_sum = np.array([0., 0., 0.])
    
    n_images = 0
    for file in list(glob.glob(path)):
        img = cv.imread(file)
        img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
        img = img/255
        mean, std = cv.meanStdDev(img)
        
        mean_sum += np.squeeze(mean)
        std_sum += np.squeeze(std)
        n_images += 1
    return (mean_sum / n_images, std_sum / n_images)

path = '/home/rob/Desktop/meta5/*/*/*.png'

mean, std = calc_avg_mean_std_dataset(path)
print(mean, std)

print("{:0.4f}, {:0.4f}".format(mean[0], std[0]))

In [10]:
import torch
import pickle

# Load model back in

In [2]:
vars_path = '/home/rob/Git/meta-fsl-nas/metanas/results/omniprint/ppo_debug/seed_2/_s2/vars1.pkl'
model_path = '/home/rob/Git/meta-fsl-nas/metanas/results/omniprint/ppo_debug/seed_2/_s2/pyt_save/model1.pt'

In [None]:
type(torch.load(model_path)['ac'])

In [None]:
a = pickle.load(open(vars_path, 'rb'))

a

In [57]:
import os
import json
import h5py
import numpy as np
from PIL import Image

from torchmeta.utils.data import Dataset, ClassDataset, CombinationMetaDataset
from torchmeta.datasets.utils import download_file_from_google_drive
from torchmeta.datasets.helpers import helper_with_default


def omniprint(folder, shots, ways, shuffle=True, test_shots=None,
              seed=None, **kwargs):
    return helper_with_default(OmniPrint, folder, shots, ways,
                               shuffle=shuffle, test_shots=test_shots,
                               seed=seed, **kwargs)


class OmniPrint(CombinationMetaDataset):
    def __init__(self, root, num_classes_per_task=None, meta_train=False,
                 meta_val=False, meta_test=False, meta_split=None,
                 transform=None, target_transform=None, dataset_transform=None,
                 class_augmentations=None, download=False,
                 print_split='meta1',  # Addition for the OmniPrint dataset
                 ):
        dataset = OmniPrintClassDataset(
            root, meta_train=meta_train,
            meta_val=meta_val, meta_test=meta_test,
            print_split=print_split, transform=transform,
            meta_split=meta_split,
            class_augmentations=class_augmentations,
            download=download)
        super(OmniPrint, self).__init__(
            dataset, num_classes_per_task,
            target_transform=target_transform,
            dataset_transform=dataset_transform)


class OmniPrintClassDataset(ClassDataset):
    gdrive_id = '1JBXYMTsdlm8RaEBPqrJbDRzs3hJ4q_gH'
    folder = 'omniprint'

    zip_filename = '{0}.zip'
    filename = '{0}_{1}_data.hdf5'
    filename_labels = '{0}_{1}_labels.json'

    def __init__(self, root, meta_train=False, meta_val=False, meta_test=False,
                 meta_split=None, print_split=None, transform=None,
                 class_augmentations=None, download=False):
        super(OmniPrintClassDataset, self).__init__(
            meta_train=meta_train,
            meta_val=meta_val,
            meta_test=meta_test,
            meta_split=meta_split,
            class_augmentations=class_augmentations)

        self.root = os.path.join(os.path.expanduser(
            root), self.folder)
        self.print_split = print_split
        self.transform = transform

        self.split_filename = os.path.join(
            self.root,
            self.filename.format(print_split, self.meta_split))
        self.split_filename_labels = os.path.join(
            self.root,
            self.filename_labels.format(print_split, self.meta_split))

        self._data = None
        self._labels = None

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('OmniPrint integrity check failed')
        self._num_classes = len(self.labels)

    def __getitem__(self, index):
        character_name = '/'.join(self.labels[index % self.num_classes])
        data = self.data[character_name]
        transform = self.get_transform(index, self.transform)
        target_transform = self.get_target_transform(index)

        return OmniPrintDataset(
            index, data, character_name,
            transform=transform, target_transform=target_transform)

    @property
    def num_classes(self):
        return self._num_classes

    @property
    def data(self):
        if self._data is None:
            self._data = h5py.File(self.split_filename, 'r')
        return self._data

    @property
    def labels(self):
        if self._labels is None:
            with open(self.split_filename_labels, 'r') as f:
                self._labels = json.load(f)
        return self._labels

    def _check_integrity(self):
        return (os.path.isfile(self.split_filename)
                and os.path.isfile(self.split_filename_labels))

    def close(self):
        if self._data is not None:
            self._data.close()
            self._data = None

    def download(self):
        import zipfile
        import shutil
        import glob
        from tqdm import tqdm

        if self._check_integrity():
            return

        zip_foldername = os.path.join(
            self.root, self.zip_filename.format(self.folder))
        # Download the datasets
        if not os.path.isfile(zip_foldername):
            download_file_from_google_drive(
                self.gdrive_id, self.root,
                self.zip_filename.format(self.folder))

        # Unzip the dataset
        if not os.path.isdir(zip_foldername):
            with zipfile.ZipFile(zip_foldername) as f:
                for member in tqdm(f.infolist(), desc='Extracting '):
                    try:
                        f.extract(member, self.root)
                    except zipfile.BadZipFile:
                        print('Error: Zipfile is corrupted')

        for print_split in ['meta1', 'meta2', 'meta3', 'meta4', 'meta5']:
            for split in tqdm(['train', 'val', 'test'], desc=print_split):
                filename_labels = os.path.join(
                    self.root, self.filename_labels.format(print_split, split))

                with open(filename_labels, 'r') as f:
                    labels = json.load(f)

                filename = os.path.join(
                    self.root, self.filename.format(print_split, split))

                with h5py.File(filename, 'w') as f:
                    group = f.create_group(print_split)
                    for _, alphabet, character in labels:
                        filenames = glob.glob(
                            os.path.join(
                                self.root, print_split,
                                alphabet, character, '*.png'))
                        dataset = group.create_dataset('{0}/{1}'.format(
                            alphabet, character),
                            (len(filenames), 32, 32),
                            dtype='uint8')

                        for i, char_filename in enumerate(filenames):
                            image = Image.open(
                                char_filename, mode='r').convert('L')
                            dataset[i] = image

            shutil.rmtree(os.path.join(self.root, print_split))


class OmniPrintDataset(Dataset):
    def __init__(self, index, data, character_name,
                 transform=None, target_transform=None):
        super(OmniPrintDataset, self).__init__(
            index, transform=transform,
            target_transform=target_transform)
        self.data = data
        self.character_name = character_name

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image = Image.fromarray(self.data[index])
        target = self.character_name

        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            target = self.target_transform(target)
        return (image, target)


In [58]:

from torchmeta.utils.data import BatchMetaDataLoader


dataset = omniprint(
        "/home/rob/Desktop",
        15,
        5,
        print_split='meta1',
        meta_split='val',
        test_shots=1,
        download=True,
        seed=1,
)

dataloader = BatchMetaDataLoader(
    dataset, batch_size=20, num_workers=1, shuffle=True
)

train_it = iter(dataloader)

In [59]:
# type(next(iter(dataloader))['train'][0][0])

In [49]:
next(train_it)['train'][0].shape

torch.Size([20, 75, 1, 32, 32])

In [60]:

from collections import namedtuple
from torch.utils.data import DataLoader, RandomSampler, TensorDataset
from torchmeta.utils.data import BatchMetaDataLoader

Task = namedtuple("Task", ["train_loader", "valid_loader", "test_loader"])

meta_batch_size = 1
task_batch_size = 15*20
shots = 5
ways = 5
validation_set = True

batch = next(train_it)
train_batch_x, train_batch_y = batch["train"]
test_batch_x, test_batch_y = batch["test"]
num_tasks = meta_batch_size

meta_train_batch = list()
for task_idx in range(num_tasks):
    
    # Parse validation set
    train_idx, valid_idx = train_test_split(
        np.arange(len(train_batch_y[task_idx].numpy())),
        test_size=0.2, random_state=42, shuffle=True,
        stratify=train_batch_y[task_idx].numpy())
    ####
    
    # Train loader
    dset_train = TensorDataset(train_batch_x[task_idx][train_idx], train_batch_y[task_idx][train_idx])
    train_loader = DataLoader(dset_train, batch_size=20)

    # Validation loader
    dset_val = TensorDataset(train_batch_x[task_idx][valid_idx], train_batch_y[task_idx][valid_idx])
    val_loader = DataLoader(dset_val, batch_size=5)
    
    # Test loader
    dset_test = TensorDataset(test_batch_x[task_idx], test_batch_y[task_idx])
    test_loader = DataLoader(dset_test, batch_size=shots * ways)
    
    
    print(train_batch_x[task_idx][train_idx].shape)
    print(train_batch_x[task_idx][valid_idx].shape)
    print(test_batch_x[task_idx].shape)
    
    meta_train_batch.append(Task(train_loader, val_loader, test_loader))
    
meta_train_batch

torch.Size([60, 1, 32, 32])
torch.Size([15, 1, 32, 32])
torch.Size([5, 1, 32, 32])


[Task(train_loader=<torch.utils.data.dataloader.DataLoader object at 0x7fda153230b8>, valid_loader=<torch.utils.data.dataloader.DataLoader object at 0x7fda15323208>, test_loader=<torch.utils.data.dataloader.DataLoader object at 0x7fda147bc550>)]

In [61]:
meta_train_batch[0].valid_loader

<torch.utils.data.dataloader.DataLoader at 0x7fda15323208>

In [62]:
next(iter(meta_train_batch[0].valid_loader))[0].shape

torch.Size([5, 1, 32, 32])

In [63]:
task = meta_train_batch[0]

for step, ((train_X, train_y), (val_X, val_y)) in enumerate(
    zip(task.train_loader, task.valid_loader)):
    print(val_X.shape, train_X.shape)

torch.Size([5, 1, 32, 32]) torch.Size([20, 1, 32, 32])
torch.Size([5, 1, 32, 32]) torch.Size([20, 1, 32, 32])
torch.Size([5, 1, 32, 32]) torch.Size([20, 1, 32, 32])
