In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import autograd as ag
import higher  # Import the higher library for meta-learning

import numpy as np 
import os
import copy
import matplotlib.pyplot as plt

import src.datamaker as datamaker
import src.training as training
import src.models as models

from tqdm import tqdm

from importlib import reload
# Re-execute the project modules so edits made under src/ since the last
# import are picked up without restarting the kernel.
reload(datamaker)
reload(training)
reload(models)

# Global matplotlib styling for every figure in this notebook:
# serif text using the Computer Modern Roman font (cmr10) ...
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.serif'] = ['cmr10']
# ... with math text rendered in the matching Computer Modern style.
plt.rcParams['mathtext.fontset'] ='cm'
# White figure background (rather than transparent/themed) for saved/exported figures.
plt.rcParams['figure.facecolor'] = 'white'
# Use an ASCII hyphen for negative tick labels.
# NOTE(review): presumably because cmr10 lacks the Unicode minus glyph — confirm.
plt.rc('axes', unicode_minus=False)
# Format tick labels (e.g. offsets/scientific notation) via mathtext.
# NOTE(review): matplotlib recommends this setting when using cmr10 — confirm it is still needed.
plt.rc('axes.formatter', use_mathtext=True)

In [29]:
def support_query_split(x_train, y_train, spt_frac, generator=None):
    '''
    Randomly split paired data into support and query sets.

    A single random permutation is shared by inputs and targets, so every
    (x, y) pair stays aligned after the split.

    Args:
    - x_train: the input data; must be indexable by an int64 index tensor
      (e.g. a torch.Tensor). Its length is the number of samples.
    - y_train: the target data; must have the same length as x_train.
    - spt_frac: the fraction of the data to use for the support set, in [0, 1].
      The support size is floor(len(x_train) * spt_frac).
    - generator: optional torch.Generator used to draw the permutation, so a
      split can be reproduced independently of the global torch RNG.
      Defaults to None, which uses the global RNG.

    Returns:
    - x_spt: the support input data.
    - y_spt: the support target data.
    - x_qry: the query input data.
    - y_qry: the query target data.

    Raises:
    - ValueError: if x_train and y_train differ in length, or spt_frac is
      not in [0, 1] (NaN included).
    '''
    n = len(x_train)
    if len(y_train) != n:
        raise ValueError(
            f'x_train and y_train must have the same length, got {n} and {len(y_train)}'
        )

    frac = float(spt_frac)
    # Written as a negated range check so NaN is rejected too.
    if not 0.0 <= frac <= 1.0:
        raise ValueError(f'spt_frac must be in [0, 1], got {spt_frac}')

    # Strip floating-point noise before flooring: 100 * 0.29 evaluates to
    # 28.999999999999996, so a bare int() would give 28 instead of 29.
    # Rounding to 9 decimals first keeps floor semantics for genuinely
    # fractional sizes (e.g. 10 * 0.35 = 3.5 -> 3).
    n_spt = int(round(n * frac, 9))

    idx = torch.randperm(n, generator=generator)
    spt_idx, qry_idx = idx[:n_spt], idx[n_spt:]

    return x_train[spt_idx], y_train[spt_idx], x_train[qry_idx], y_train[qry_idx]