In [1]:
from easyfsl.samplers import TaskSampler
from torch.utils.data import DataLoader

In [2]:
from matplotlib import pyplot as plt
from utilities import sliding_windows
import os
import numpy as np
import pandas as pd
import torch


            

inp = np.load('./datasets/OpportunityUCIDataset/loco_2_mask.npy', allow_pickle=True)
inp.item()['inputs'].shape
inputs, labels = inp.item()['inputs'], inp.item()['labels']
sw = sliding_windows(300, 50)
segmented_samples, segmented_labels = sw(torch.tensor(inputs), torch.tensor(labels))
print(segmented_samples.shape, segmented_labels.shape)
torch.isnan(segmented_samples).any()

torch.Size([10508, 300, 6]) torch.Size([10508, 300])


tensor(False)

In [73]:
# custom Dataset:
from torch.utils.data import Dataset, DataLoader

def majority_vote(series):
    """
    Convert a single time series of shape (300,) to its majority-vote class.

    :param series: np.array of shape (300,), where each element is a class label.
    :return: The majority class for the time series.
    """
    counts = np.bincount(series)
    return np.argmax(counts)


class CustomDataset(Dataset):
    def __init__(self, data, label, transform=None):
        self.data = data
        self.label = label
        self.transform = transform
        self.class_labels = [majority_vote(self.label[idx]) for idx in range(len(self.label))]
        assert len(self.data) == len(self.label)
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        if self.transform:
            raise NotImplementedError
            return self.transform(self.data[idx]), self.label[idx]
        cur_label = self.label[idx]
        cur_label = np.where(cur_label == 0, 0, 1)
        print(cur_label)
        return torch.tensor(np.concatenate((self.data[idx], cur_label[np.newaxis].T), axis=1)), torch.tensor(self.class_labels[idx], dtype=torch.int16)
    def get_labels(self):
        return self.class_labels

In [74]:
train_set = CustomDataset(segmented_samples, segmented_labels)

train_set[0][0].shape

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0]


torch.Size([300, 7])

In [85]:
n_way = 5
n_shot = 5
n_query = 2
n_tasks_per_epoch = 500


In [116]:
train_sampler = TaskSampler(train_set, n_way, n_shot, n_query, n_tasks_per_epoch)
def wrapped_collate_fn(batch):
    """supp_imgs: support images
                way x shot x [B x 6 x 100], list of lists of tensors
            fore_mask: foreground masks for support images
                way x shot x [B x 100], list of lists of tensors
            back_mask: background masks for support images
                way x shot x [B x 100], list of lists of tensors
            qry_imgs: query images
                N x [B x 6 x 100], list of tensors"""
    # 5 -> 7 of tuples:
    # (support_set, query_set, support_labels, query_labels, classes)
    original_output = train_sampler.episodic_collate_fn(batch)
    (example_support_images, example_support_labels, example_query_images, example_query_labels, example_class_ids )= original_output
    example_support_images, example_support_images_labels = example_support_images[:, :, :-1], example_support_images[:, :, -1]
    example_query_images, example_query_images_labels = example_query_images[:, :, :-1], example_query_images[:, :, -1]
    #return (example_support_images, example_support_images_labels, example_support_labels, example_query_images, example_query_images_labels, example_query_labels, example_class_ids)
    
    example_support_images = [ [example_support_images[i+j, :, :] for j in range(n_shot)] for i in range(n_way)]
    example_support_images_labels = [ [example_support_images_labels[i+j, :] for j in range(n_shot)] for i in range(n_way)]
    example_support_labels = [example_support_labels[i] for i in range(n_way)]
    example_query_images = [example_query_images[i, :, :] for i in range(n_query)]
    example_query_images_labels = [example_query_images_labels[i, :] for i in range(n_query)]
    example_query_labels = [example_query_labels[i] for i in range(n_query)]
    
    return (example_support_images, example_support_images_labels, example_support_labels, example_query_images, example_query_images_labels, example_query_labels, example_class_ids)

    

In [117]:
train_loader = DataLoader(
    train_set,
    batch_sampler=train_sampler,
    num_workers=0,
    pin_memory=True,
    collate_fn=wrapped_collate_fn,
)

In [118]:
(example_support_images, example_support_images_labels, example_support_labels, example_query_images, example_query_images_labels, example_query_labels, example_class_ids)= next(iter(train_loader))
# example_support_images, example_support_images_labels = example_support_images[:, :, :-1], example_support_images[:, :, -1]
# example_query_images, example_query_images_labels = example_query_images[:, :, :-1], example_query_images[:, :, -1]
        

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 

In [121]:
print(len(example_support_images), len(example_support_images_labels), len(example_query_images), len(example_query_images_labels), len(example_class_ids))
# for i in range(8):
#     print(example_support_images_labels[i], example_support_labels[i])
print(len(example_support_images[0]), len(example_support_images_labels[0]))
print(example_support_images[0][0].shape, example_support_images_labels[0][0].shape)

5 5 2 2 5
5 5
torch.Size([300, 6]) torch.Size([300])
