In [1]:
import pickle
import random
import numpy as np

# 2D Paper Code

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

class Encoder(nn.Module):
    """1D convolutional encoder that halves the temporal length at every layer.

    Args:
        channels: channel sizes, input first. Consecutive pairs define the
            conv layers. With ``compress=True`` the last pair is realised as a
            1x1 convolution instead of a down-sampling layer.
        kernel_size: kernel size of every down-sampling convolution.
        global_pool: optional callable invoked as ``global_pool(x, length)``
            after the conv stack, where ``length`` is the remaining temporal
            size (e.g. ``F.max_pool1d``).
        convpool: optional pooling module class. When given, each layer is a
            stride-1 convolution followed by ``convpool(kernel_size=2, stride=2)``;
            otherwise each layer is a stride-2 convolution.
        compress: append a 1x1 convolution ``channels[-2] -> channels[-1]``.
    """

    def __init__(self, channels, kernel_size=8, global_pool=None, convpool=None, compress=False):
        super().__init__()

        model = []
        # LeakyReLU holds no parameters, so one instance is shared by all layers.
        acti = nn.LeakyReLU(0.2)

        # With compress the final channel pair is reserved for the 1x1 conv.
        nr_layer = len(channels) - 2 if compress else len(channels) - 1

        for i in range(nr_layer):
            if convpool is None:
                # Down-sample inside the convolution itself (stride 2).
                pad = (kernel_size - 2) // 2
                model.append(nn.ReflectionPad1d(pad))
                model.append(nn.Conv1d(channels[i], channels[i+1],
                                       kernel_size=kernel_size, stride=2))
                model.append(acti)
            else:
                # Stride-1 convolution, then an explicit pooling step.
                pad = (kernel_size - 1) // 2
                model.append(nn.ReflectionPad1d(pad))
                model.append(nn.Conv1d(channels[i], channels[i+1],
                                       kernel_size=kernel_size, stride=1))
                model.append(acti)
                model.append(convpool(kernel_size=2, stride=2))

        self.global_pool = global_pool
        self.compress = compress

        self.model = nn.Sequential(*model)

        if self.compress:
            self.conv1x1 = nn.Conv1d(channels[-2], channels[-1], kernel_size=1)

    def forward(self, x):
        """Encode ``x``; the conv layers consume it as (batch, channels[0], length)."""
        x = self.model(x)
        if self.global_pool is not None:
            # Pool over the whole remaining temporal axis.
            x = self.global_pool(x, x.shape[-1])
        # The 1x1 conv used to be nested under the global_pool branch, so
        # compress=True without global_pool left it unused and returned
        # channels[-2] channels. Apply it whenever compress is requested.
        if self.compress:
            x = self.conv1x1(x)
        return x


class Decoder(nn.Module):
    """1D convolutional decoder that doubles the temporal length at every stage.

    Each stage is: nearest-neighbour upsampling (x2), reflection padding, then
    a stride-1 convolution. The first two stages are followed by dropout
    (p=0.2). Every stage except the last is followed by LeakyReLU(0.2), so the
    final convolution output has no activation applied to it.

    Args:
        channels: channel sizes, input first; consecutive pairs define the stages.
        kernel_size: kernel size of every convolution.
    """

    def __init__(self, channels, kernel_size=7):
        super(Decoder, self).__init__()

        half_kernel = (kernel_size - 1) // 2
        leaky_relu = nn.LeakyReLU(0.2)
        last_stage = len(channels) - 2

        stages = []
        for idx, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:])):
            stages += [
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.ReflectionPad1d(half_kernel),
                nn.Conv1d(c_in, c_out, kernel_size=kernel_size, stride=1),
            ]
            # Dropout only on the two earliest stages.
            if idx < 2:
                stages.append(nn.Dropout(p=0.2))
            # Open question kept from the original author: whether a tanh
            # should be added after the last stage. Currently it stays linear.
            if idx != last_stage:
                stages.append(leaky_relu)

        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Decode ``x``; the conv layers consume it as (batch, channels[0], length)."""
        return self.model(x)


class AutoEncoder2x(nn.Module):
    """Two-encoder / one-decoder autoencoder.

    ``mot_encoder`` keeps a temporal axis; ``static_encoder`` is expected to
    produce a code of temporal length 1 (e.g. via ``global_pool``), which is
    tiled along time and concatenated channel-wise with the motion code before
    decoding.

    The motion encoder consumes every input channel. The body encoder consumes
    the first ``body_en_channels[0]`` channels. Earlier, ``forward`` passed all
    channels while the other methods hard-coded ``x[:, :-2, :]``, so no single
    configuration worked for every method. Slicing by the configured channel
    count reproduces both behaviours (``body == mot`` -> all channels,
    ``body == mot - 2`` -> drop the last two).

    Args:
        mot_en_channels: channel sizes of the motion encoder, input first.
        body_en_channels: channel sizes of the body encoder, input first.
        de_channels: channel sizes of the decoder. Its first entry must equal
            the sum of both encoders' output channels and its last entry must
            equal the motion encoder's input channels.
        global_pool, convpool, compress: forwarded to the body encoder only.
    """

    def __init__(self, mot_en_channels, body_en_channels, de_channels, global_pool=None, convpool=None, compress=False):
        super().__init__()
        assert mot_en_channels[0] == de_channels[-1] and \
               mot_en_channels[-1] + body_en_channels[-1] == de_channels[0]

        # Number of leading input channels handed to the body encoder.
        self.body_in_channels = body_en_channels[0]

        self.mot_encoder = Encoder(mot_en_channels)
        self.static_encoder = Encoder(body_en_channels, kernel_size=7, global_pool=global_pool, convpool=convpool, compress=compress)
        self.decoder = Decoder(de_channels)

    def _body_code(self, x):
        """Run the body encoder on the leading ``body_in_channels`` channels of ``x``."""
        return self.static_encoder(x[:, :self.body_in_channels, :])

    def _encode(self, x):
        """Return ``(motion_code, body_code)`` for ``x``."""
        m = self.mot_encoder(x)
        b = self._body_code(x)
        return m, b

    def _decode(self, m, b):
        """Tile body code ``b`` to the temporal length of ``m`` and decode the pair."""
        return self.decoder(torch.cat([m, b.repeat(1, 1, m.shape[-1])], dim=1))

    @staticmethod
    def _flatten(v):
        """Flatten everything but the batch axis."""
        return v.reshape(v.shape[0], -1)

    def cross(self, x1, x2):
        """Reconstruct both inputs and both motion/body swaps.

        Returns:
            ``(out1, out2, out12, out21)`` where ``outij`` decodes the motion
            code of ``xi`` together with the body code of ``xj``.
        """
        m1, b1 = self._encode(x1)
        m2, b2 = self._encode(x2)

        out1 = self._decode(m1, b1)
        out2 = self._decode(m2, b2)
        out12 = self._decode(m1, b2)
        out21 = self._decode(m2, b1)

        return out1, out2, out12, out21

    def transfer(self, x1, x2):
        """Decode the motion code of ``x1`` with the body code of ``x2``."""
        m1 = self.mot_encoder(x1)
        b2 = self._body_code(x2)

        return self._decode(m1, b2)

    def cross_with_triplet(self, x1, x2, x12, x21):
        """Like :meth:`cross`, additionally returning flattened latent codes.

        Returns:
            ``(outputs, motionvecs, bodyvecs)``: ``outputs`` is
            ``[out1, out2, out12, out21]``; ``motionvecs`` holds the flattened
            motion codes of ``x1, x2, x12, x21``; ``bodyvecs`` holds the
            flattened body codes of ``x1, x2, x21, x12``.
        """
        m1, b1 = self._encode(x1)
        m2, b2 = self._encode(x2)

        outputs = [self._decode(m1, b1),
                   self._decode(m2, b2),
                   self._decode(m1, b2),
                   self._decode(m2, b1)]

        m12, b12 = self._encode(x12)
        m21, b21 = self._encode(x21)

        motionvecs = [self._flatten(v) for v in (m1, m2, m12, m21)]
        # NOTE(review): b21 precedes b12 here while m12 precedes m21 above.
        # The order is kept exactly as before; confirm it matches the loss code.
        bodyvecs = [self._flatten(v) for v in (b1, b2, b21, b12)]

        return outputs, motionvecs, bodyvecs

    def forward(self, x):
        """Plain reconstruction of ``x`` from its own motion and body codes."""
        m, b = self._encode(x)
        return self._decode(m, b)

# Data

In [3]:
# Load the pickled dataset. Later cells iterate X as a mapping from file name
# to per-clip data.
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# load this file if it comes from a trusted source.
with open('ntu/X.pkl', 'rb') as f:
    X = pickle.load(f)

In [4]:
# clean data
to_del = []
for file in X:
    if type(X[file]) == list:
        to_del.append(file)
print('to delete', len(to_del))
for file in to_del:
    del X[file]

to delete 28814


In [5]:
# Force every clip to exactly T frames along axis 0: longer clips are cut to
# their first T frames, shorter ones are extended by repeating the last frame.
# input is of shape (frames, 75)
T = 75
for file in X:
    n_frames = X[file].shape[0]
    if n_frames > T:
        X[file] = X[file][:T, :]
    elif n_frames < T:
        # mode='edge' replicates the final row for the missing frames
        X[file] = np.pad(X[file], ((0, T - n_frames), (0, 0)), mode='edge')

In [6]:
# Replace every stored clip with a torch tensor copy. Only values are
# reassigned, so mutating X while iterating it is safe.
for file, clip in X.items():
    X[file] = torch.tensor(clip)

In [7]:
# Two lookup tables keyed by slices of the file name:
#   a[file[16:20]][file[8:12]] -> list of file names sharing both codes
#   p[file[8:12]]              -> set of file[16:20] codes seen with that key
# NOTE(review): presumably file[8:12] is the performer code and file[16:20]
# the action code of the NTU naming scheme — confirm against the dataset.
a = {}
p = {}
for file in X:
    action = file[16:20]
    person = file[8:12]
    a.setdefault(action, {}).setdefault(person, []).append(file)
    p.setdefault(person, set()).add(action)

In [8]:
# Build cross-paired training samples. Each kept draw picks two keys of p
# (p1, p2) and two codes they share (a1, a2), then stores
#   train_x: [clip(a1, p1), clip(a2, p2)]
#   train_y: [clip(a1, p2), clip(a2, p1)]
# `samples` is the number of attempts; draws without two shared codes are
# skipped, so fewer than `samples` pairs may be produced.
samples = 500
train_x = []
train_y = []

# Loop-invariant, so build the population once.
performers = list(p.keys())

for _ in range(samples):
    # sample two random p
    p1, p2 = random.sample(performers, 2)
    # find overlapping a
    a12 = p[p1].intersection(p[p2])
    # Two distinct codes are drawn below, so one shared code is not enough:
    # random.sample(population, 2) raises ValueError on a 1-element population.
    if len(a12) < 2:
        continue
    # sample two random a (sorted so the draw is reproducible under a fixed
    # seed regardless of set iteration order)
    a1, a2 = random.sample(sorted(a12), 2)
    # sample x and y
    x1 = random.choice(a[a1][p1])
    x2 = random.choice(a[a2][p2])
    y1 = random.choice(a[a1][p2])
    y2 = random.choice(a[a2][p1])
    train_x.append([x1, x2])
    train_y.append([y1, y2])

In [9]:
class Data(Dataset):
    """Dataset of cross-paired clips addressed by key.

    Args:
        X: sequence of ``[key_a, key_b]`` input key pairs.
        y: sequence of ``[key_a, key_b]`` target key pairs, aligned with ``X``.
        store: optional mapping from key to clip. When omitted, the
            notebook-level ``X`` mapping is looked up on every access, which
            is how this class behaved before the parameter existed.
    """

    def __init__(self, X, y, store=None):
        self.X = X
        self.y = y
        self.store = store

    def __getitem__(self, index):
        """Return ``(x_a, x_b, y_a, y_b)``: the four clips of sample ``index``."""
        # Inside this method the bare name `X` is the notebook-level mapping,
        # not self.X (which holds key pairs). Prefer the injected store so the
        # dataset does not depend on hidden kernel state.
        clips = X if self.store is None else self.store
        inputs = self.X[index]
        targets = self.y[index]
        return clips[inputs[0]], clips[inputs[1]], clips[targets[0]], clips[targets[1]]

    def __len__(self):
        """Number of key pairs."""
        return len(self.X)

In [10]:
# Wrap the sampled key pairs in a dataset and iterate it in shuffled
# mini-batches of 16.
data = Data(train_x, train_y)
dl = DataLoader(dataset=data, batch_size=16, shuffle=True)

# Model

In [11]:
# Channel plans, input channels first. The decoder starts from the
# concatenated motion + body code and ends at the 75 input channels.
mot_en_channels = [75, 64, 96, 128]
body_en_channels = [75, 32, 48, 64]
de_channels = [mot_en_channels[-1] + body_en_channels[-1], 128, 64, 75]
model = AutoEncoder2x(
    mot_en_channels,
    body_en_channels,
    de_channels,
    global_pool=F.max_pool1d,
    convpool=nn.MaxPool1d,
    compress=False,
)

In [12]:
# Smoke test: push every batch through the autoencoder and print the output shape.
for (x1, x2, y1, y2) in dl:
    # Clips are stored as (frames, features), so a batch is
    # (batch, frames, features). Conv1d expects (batch, channels, length), so
    # swap the last two axes. Without this the frame axis is convolved as
    # channels, which only runs because T == 75 equals the feature count.
    print(model(x1.float().transpose(1, 2)).shape)

torch.Size([16, 75, 72])
torch.Size([16, 75, 72])
torch.Size([16, 75, 72])
torch.Size([16, 75, 72])
torch.Size([16, 75, 72])
torch.Size([16, 75, 72])
torch.Size([16, 75, 72])
torch.Size([16, 75, 72])
torch.Size([16, 75, 72])
torch.Size([16, 75, 72])
torch.Size([16, 75, 72])
torch.Size([16, 75, 72])
torch.Size([16, 75, 72])
torch.Size([16, 75, 72])
torch.Size([16, 75, 72])
torch.Size([16, 75, 72])
torch.Size([16, 75, 72])
torch.Size([7, 75, 72])
