In [None]:
def unison_shuffled_copies(a, b):
    """Shuffle two equal-length indexables with one shared random permutation.

    The permutation comes from ``np.random``, so it follows the global NumPy
    seed. Row ``i`` of ``a`` stays paired with row ``i`` of ``b`` after shuffling.

    Args:
        a, b: objects supporting ``len`` and fancy indexing with an integer
            array (e.g. NumPy arrays or torch tensors).

    Returns:
        ``(a_shuffled, b_shuffled)``, both newly indexed copies.

    Raises:
        AssertionError: if ``len(a) != len(b)``.
    """
    n = len(a)
    assert n == len(b)
    order = np.random.permutation(n)
    return a[order], b[order]

In [None]:
# Reptile meta-learning hyper-parameters.
meta_step_size = 0.25  # base outer (meta) step size; see the training loop

meta_iters = 1000  # number of outer-loop iterations (one sampled episode each)

eval_interval = 1  # evaluate every `eval_interval` meta-iterations
train_shots = 40   # support samples per class in training episodes
eval_shots = 4     # support samples per class in evaluation episodes
classes = len(set(labels))  # number of classes ("ways") per episode

# Passed to Tensor.chunk in Dataset.get_mini_dataset, i.e. the number of
# chunks an episode is split into.
batch_size = 1
# Note: samples per episode = classes * shots.

n_times = X.shape[-1]           # size of the last axis of X
n_channels = len(epochs.picks)  # number of picked channels

seed = 1330
splits = 5
lr = 1e-3

# Seed every RNG the notebook draws from (fold split, episode sampling,
# shuffling, model init) so runs are reproducible.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

skf = StratifiedKFold(n_splits=splits, random_state=seed, shuffle=True)

# Only the first of the `splits` folds is used as the train/test partition.
train_index, test_index = next(skf.split(X, y))
X_train, X_test = X[train_index], X[test_index]
y_train, y_test = y[train_index], y[test_index]

In [None]:
class Dataset:
    """Per-class sample pool used to draw few-shot (N-way, K-shot) episodes.

    Samples of one split are grouped by label in ``self.data`` (label -> list of
    samples); ``self.labels`` lists the labels present.
    """

    def __init__(self, training):
        """Group the samples of one split by label.

        Args:
            training: True to use the notebook-level ``X_train``/``y_train``
                split, False to use ``X_test``/``y_test``.
        """
        # The original tested a non-empty (always truthy) string here, so both
        # instances were silently built from the test split.
        if training:
            X_dataset, y_dataset = X_train, y_train
        else:
            X_dataset, y_dataset = X_test, y_test

        self.data = {}
        for value, label in zip(X_dataset, y_dataset):
            # Convert NumPy / 0-d torch scalars to plain Python values: torch
            # tensors hash by identity, which would make every sample its own key.
            key = label.item() if hasattr(label, "item") else label
            self.data.setdefault(key, []).append(value)
        self.labels = list(self.data.keys())

    def get_mini_dataset(self, shots, num_classes, split=False):
        """Sample one few-shot episode.

        Args:
            shots: support samples drawn per class.
            num_classes: number of classes in the episode. The classes are
                relabelled 0..num_classes-1 in the order they were drawn.
            split: if True, also hold out one query sample per class.

        Returns:
            A ``zip`` iterator of ``(X_chunk, label_chunk)`` pairs over the
            shuffled support set. With ``split=True``, returns
            ``(support_iterator, query_X, query_labels)`` where the query tensors
            hold exactly one sample per class. Labels are float tensors.
        """
        temp_labels = torch.zeros(num_classes * shots)
        temp_X = torch.zeros((num_classes * shots, n_channels, n_times))
        if split:
            # One held-out query sample per class. Sizing these by eval_shots
            # left zero-filled rows labelled 0 that were counted in the accuracy.
            test_labels = torch.zeros(num_classes)
            test_X = torch.zeros((num_classes, n_channels, n_times))

        # Draw distinct classes when possible; with replacement, two temporary
        # labels could refer to the same real class and the task is ill-posed.
        if num_classes <= len(self.labels):
            label_subset = random.sample(self.labels, k=num_classes)
        else:
            label_subset = random.choices(self.labels, k=num_classes)

        for class_idx, class_obj in enumerate(label_subset):
            pool = self.data[class_obj]
            start, stop = class_idx * shots, (class_idx + 1) * shots
            # The enumeration index is the temporary label within this episode.
            temp_labels[start:stop] = class_idx
            if split:
                # Hold out one sample as the query and draw the support shots
                # from the remainder so the query never leaks into the support
                # set (falls back to the full pool if it has a single sample).
                query_idx = random.randrange(len(pool))
                rest = pool[:query_idx] + pool[query_idx + 1:]
                test_labels[class_idx] = class_idx
                test_X[class_idx] = pool[query_idx]
                temp_X[start:stop] = torch.stack(random.choices(rest or pool, k=shots))
            else:
                temp_X[start:stop] = torch.stack(random.choices(pool, k=shots))

        temp_X, temp_labels = unison_shuffled_copies(temp_X, temp_labels)
        # NOTE: Tensor.chunk(n) splits into n pieces, so `batch_size` is the number
        # of inner-loop batches per episode, not the samples per batch. Zipping
        # the chunks directly (instead of stacking them) also tolerates uneven chunks.
        dataset = zip(temp_X.chunk(batch_size), temp_labels.chunk(batch_size))

        if split:
            test_X, test_labels = unison_shuffled_copies(test_X, test_labels)
            return dataset, test_X, test_labels
        return dataset

train_dataset = Dataset(training=True)
test_dataset = Dataset(training=False)

In [None]:
training = []  # accuracy of each evaluation on episodes from train_dataset
testing = []   # accuracy of each evaluation on episodes from test_dataset


def _snapshot(module):
    """Return an independent copy of ``module.state_dict()``.

    ``state_dict()`` returns tensors that share storage with the live
    parameters, so without cloning, a later ``optimizer.step()`` would also
    change the "snapshot".
    """
    return {key: value.detach().clone() for key, value in module.state_dict().items()}


for meta_iter in range(meta_iters):
    # Fresh inner-loop optimizer for every sampled task.
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    frac_done = meta_iter / meta_iters

    # Linearly anneal the outer (meta) step size from meta_step_size toward 0.
    cur_meta_step_size = (1 - frac_done) * meta_step_size

    old_vars = _snapshot(model)

    mini_dataset = train_dataset.get_mini_dataset(
        train_shots, classes
    )

    # Inner loop: adapt the model to the sampled task.
    for X_values, y_labels in mini_dataset:
        y_labels = y_labels.to(dtype=torch.long)
        preds, loss, out_values = model(X_values, y_labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Reptile outer update: move the pre-task weights a fraction
    # cur_meta_step_size of the way toward the task-adapted weights.
    new_vars = model.state_dict()
    meta_vars = {}
    for key, new in new_vars.items():
        if torch.is_floating_point(new):
            meta_vars[key] = old_vars[key] + (new - old_vars[key]) * cur_meta_step_size
        else:
            # Integer buffers (e.g. counters) cannot be interpolated; keep the latest value.
            meta_vars[key] = new

    model.load_state_dict(meta_vars)

    # Evaluation loop
    if meta_iter % eval_interval == 0:
        accuracies = []
        for dataset in (train_dataset, test_dataset):
            # Sample an episode: support set for fine-tuning plus one held-out
            # query sample per class.
            train_set, test_X, test_labels = dataset.get_mini_dataset(
                eval_shots, classes, split=True
            )
            # Snapshot the meta-weights so the fine-tuning below can be undone.
            old_vars = _snapshot(model)
            # Separate optimizer so evaluation does not alter the training
            # optimizer's state.
            eval_optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

            for X_values, y_labels in train_set:
                y_labels = y_labels.to(dtype=torch.long)

                preds, test_loss, out_values = model(X_values, y_labels)

                eval_optimizer.zero_grad()
                test_loss.backward()
                eval_optimizer.step()

            test_labels = test_labels.to(dtype=torch.long)
            with torch.no_grad():
                test_preds, test_loss, test_out_values = model(test_X, test_labels)

            accuracy = (test_preds.argmax(1) == test_labels).float().mean().item()
            accuracies.append(accuracy)

            # Restore the meta-weights (discard the evaluation fine-tuning).
            model.load_state_dict(old_vars)

        training.append(accuracies[0])
        testing.append(accuracies[1])

        if meta_iter % 100 == 0:
            print(f"batch {meta_iter}: train={accuracies[0]} test={accuracies[1]}")

In [None]:
# Smooth the per-evaluation accuracy curves for display and plot them against
# the meta-iteration at which each evaluation ran.
window_length = 100


def smooth(values, window_length):
    """Return a Hamming-window moving average of ``values`` with the same length.

    The series is padded at both ends with reversed copies of its edges so the
    window is fully populated at the boundaries. The centred slice of the
    'valid' convolution is then kept so ``smoothed[i]`` lines up with
    ``values[i]``. The window is shrunk to the series length for short runs;
    series with fewer than 2 points are returned unchanged.
    """
    values = np.asarray(values, dtype=float)
    w = min(window_length, len(values))
    if w < 2:
        return values
    padded = np.r_[values[w - 1 : 0 : -1], values, values[-1 : -w : -1]]
    window = np.hamming(w)
    smoothed = np.convolve(window / window.sum(), padded, mode="valid")
    # 'valid' yields len(values) + w - 1 points; keep those centred on the data.
    start = (w - 1) // 2
    return smoothed[start : start + len(values)]


train_y = smooth(training, window_length)
test_y = smooth(testing, window_length)

# Evaluation k ran at meta-iteration k * eval_interval.
x = np.arange(len(test_y)) * eval_interval

fig, ax = plt.subplots()
ax.plot(x, test_y, label="test")
ax.plot(x, train_y, label="train")
ax.set(
    title="Few-shot accuracy during meta-training (smoothed)",
    xlabel="meta-iteration",
    ylabel="accuracy",
)
ax.legend()
ax.grid()