In [1]:
import torch
import torch.nn as nn
import pdb

import os
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch.utils.data import random_split, Dataset, TensorDataset, DataLoader
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# make dataset
class GestureDataset(Dataset):
    """Pairs every sample in ``data`` with the entry of ``labels`` at the same index."""

    def __init__(self, data, labels):
        """Keep references to the samples and their labels (nothing is copied)."""
        self.data, self.labels = data, labels

    def __len__(self):
        """Number of samples, taken from ``data`` only."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``{'data': ..., 'label': ...}`` for position ``idx``."""
        return dict(data=self.data[idx], label=self.labels[idx])

In [3]:
# Load the three saved modality tensors.
teng_data = torch.load('/mnt/fyp/data/teng_data.pt')
imu_data = torch.load('/mnt/fyp/data/imu_data.pt')
img_data = torch.load('/mnt/fyp/data/img_data.pt')
# Collapse each image into a single row so it can sit next to the other modalities.
img_data = img_data.reshape(3995, -1)
# Expected layout, one row per sample:
#   img_data : 3995 x (3*224*224)
#   teng_data: 3995 x (50*10)
#   imu_data : 3995 x (50*18)

In [4]:
# Put the three modalities side by side: one row per sample, features appended column-wise.
modalities = [img_data, teng_data, imu_data]
merged_data = torch.cat(modalities, dim=1)
# Sanity check: the displayed shape should be 3995 x 151928.
merged_data.shape

torch.Size([3995, 151928])

In [5]:
count = torch.load('/mnt/fyp/data/count.pt')

# Upper bound on the number of rows kept for each class.
max_per_class = 100

# The loop treats `count[k]` as the number of consecutive rows of `merged_data`
# that belong to class k. `merged_data` itself is left untouched, so this cell
# can be re-run safely; `offset` tracks where the current class starts.
selected_chunks = []
offset = 0
for num in count:
    num = int(num)  # accept plain ints as well as 0-d tensors
    # Keep at most `max_per_class` rows from the start of this class.
    selected_chunks.append(merged_data[offset:offset + min(num, max_per_class), :])
    # Advance by the real class size. Skipping max(num, 100) rows, as before,
    # would swallow rows of the next class whenever a class has < 100 samples.
    offset += num
    # Rows still to be processed (same progress output as before).
    print(merged_data[offset:].shape)

# Concatenate once instead of growing the tensor on every iteration.
selected_merged_data = torch.cat(selected_chunks, dim=0)

torch.Size([3893, 151928])
torch.Size([3790, 151928])
torch.Size([3689, 151928])
torch.Size([3588, 151928])
torch.Size([3486, 151928])
torch.Size([3382, 151928])
torch.Size([3278, 151928])
torch.Size([3171, 151928])
torch.Size([3069, 151928])
torch.Size([2968, 151928])
torch.Size([2866, 151928])
torch.Size([2764, 151928])
torch.Size([2657, 151928])
torch.Size([2556, 151928])
torch.Size([2452, 151928])
torch.Size([2351, 151928])
torch.Size([2250, 151928])
torch.Size([2148, 151928])
torch.Size([2046, 151928])
torch.Size([1945, 151928])
torch.Size([1844, 151928])
torch.Size([1742, 151928])
torch.Size([1640, 151928])
torch.Size([1538, 151928])
torch.Size([1437, 151928])
torch.Size([1336, 151928])
torch.Size([1228, 151928])
torch.Size([1126, 151928])
torch.Size([1025, 151928])
torch.Size([923, 151928])
torch.Size([822, 151928])
torch.Size([719, 151928])
torch.Size([618, 151928])
torch.Size([517, 151928])
torch.Size([413, 151928])
torch.Size([312, 151928])
torch.Size([205, 151928])
torch.Siz

In [6]:
# Display the dimensions of the class-balanced selection.
selected_merged_data.size()

torch.Size([3900, 151928])

In [7]:
import random 

num_classes = 39
samples_per_class = 100
train_fraction = 0.8

# The label layout below is only valid when every class contributed exactly
# `samples_per_class` rows, so fail early with a clear message otherwise.
assert len(selected_merged_data) == num_classes * samples_per_class, (
    f"expected {num_classes * samples_per_class} rows, "
    f"got {len(selected_merged_data)}"
)

# One label per selected row: class 0 repeated, then class 1, ... as a column vector.
selected_labels = torch.tensor(
    [i for i in range(num_classes) for _ in range(samples_per_class)]
).reshape(-1, 1)

# Dedicated, seeded generator: the split is reproducible across kernel restarts
# and does not depend on (or disturb) the global torch RNG state.
split_rng = torch.Generator().manual_seed(0)

# class_data[k] == (train rows of class k, validation rows of class k)
class_data = []
for class_idx in range(num_classes):
    class_mask = (selected_labels == class_idx).squeeze()
    data = selected_merged_data[class_mask]

    num_samples = len(data)
    train_size = int(train_fraction * num_samples)

    # Shuffle within the class, then cut: first part trains, the rest validates.
    random_indices = torch.randperm(num_samples, generator=split_rng)
    train_data = data[random_indices[:train_size]]
    val_data = data[random_indices[train_size:]]

    class_data.append((train_data, val_data))

In [8]:
# Stack the per-class splits in class order. A single torch.cat per split avoids
# re-copying the whole accumulated tensor on every loop iteration.
train_data = torch.cat([train_part for train_part, _ in class_data], dim=0)
val_data = torch.cat([val_part for _, val_part in class_data], dim=0)

In [9]:
# Labels follow the class-ordered layout of the stacked splits:
# 80 training and 20 validation entries for each of the 39 classes.
train_label = [position // 80 for position in range(39 * 80)]
val_label = [position // 20 for position in range(39 * 20)]

In [10]:
# Wrap each split and its labels in a dataset that yields {'data', 'label'} samples.
train_dataset = GestureDataset(data=train_data, labels=train_label)
val_dataset = GestureDataset(data=val_data, labels=val_label)

In [11]:
batch_size = 64

# Training loader: shuffling enabled and the incomplete final batch dropped.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)
# Validation loader: same batch size, otherwise default settings.
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size)

In [15]:
# NOTE(review): this replaces the `train_loader` built earlier with an object
# loaded from disk; the cell that saved it is not part of this notebook
# (execution counts jump from 11 to 15). torch.load unpickles the file, so only
# use it on files you created yourself.
train_loader = torch.load('/mnt/fyp/data/train_loader')

# Peek at the first batch only: feature shape, label shape, and the labels themselves.
for batch in train_loader:
    features, targets = batch['data'], batch['label']
    print(features.shape)
    print(targets.shape)
    print(targets)
    break

torch.Size([64, 151928])
torch.Size([64])
tensor([ 5, 18, 11, 14, 24, 38, 22, 24, 23, 31, 10, 31, 16, 26,  1, 25,  4, 26,
        38, 14, 25, 33, 18, 28, 30, 36,  1, 25,  8, 35, 19, 11, 22, 14, 26, 30,
        31, 19,  6, 17, 23, 34,  6,  5,  9, 31, 18,  1,  8, 35,  5, 28, 26, 13,
        17,  0, 36,  8, 28, 24, 25, 37, 23, 24])
