In [1]:
import pickle as pkl
import torch 
import numpy as np
import os
import sys
sys.path.append('/mnt/workspace/slt_baseline/')

In [2]:
import math
import pdb

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable

def import_class(name):
    components = name.split('.')
    mod = __import__(components[0])
    for comp in components[1:]:
        mod = getattr(mod, comp)
    return mod


def conv_branch_init(conv, branches):
    weight = conv.weight
    n = weight.size(0)
    k1 = weight.size(1)
    k2 = weight.size(2)
    nn.init.normal_(weight, 0, math.sqrt(2. / (n * k1 * k2 * branches)))
    nn.init.constant_(conv.bias, 0)


def conv_init(conv):
    if conv.weight is not None:
        nn.init.kaiming_normal_(conv.weight, mode='fan_out')
    if conv.bias is not None:
        nn.init.constant_(conv.bias, 0)


def bn_init(bn, scale):
    nn.init.constant_(bn.weight, scale)
    nn.init.constant_(bn.bias, 0)


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        if hasattr(m, 'weight'):
            nn.init.kaiming_normal_(m.weight, mode='fan_out')
        if hasattr(m, 'bias') and m.bias is not None and isinstance(m.bias, torch.Tensor):
            nn.init.constant_(m.bias, 0)
    elif classname.find('BatchNorm') != -1:
        if hasattr(m, 'weight') and m.weight is not None:
            m.weight.data.normal_(1.0, 0.02)
        if hasattr(m, 'bias') and m.bias is not None:
            m.bias.data.fill_(0)


class TemporalConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1):
        super(TemporalConv, self).__init__()
        pad = (kernel_size + (kernel_size-1) * (dilation-1) - 1) // 2
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(kernel_size, 1),
            padding=(pad, 0),
            stride=(stride, 1),
            dilation=(dilation, 1))

        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x


class MultiScale_TemporalConv(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 dilations=[1,2,3,4],
                 residual=True,
                 residual_kernel_size=1):

        super().__init__()
        assert out_channels % (len(dilations) + 2) == 0, '# out channels should be multiples of # branches'

        # Multiple branches of temporal convolution
        self.num_branches = len(dilations) + 2
        branch_channels = out_channels // self.num_branches
        if type(kernel_size) == list:
            assert len(kernel_size) == len(dilations)
        else:
            kernel_size = [kernel_size]*len(dilations)
        # Temporal Convolution branches
        self.branches = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    branch_channels,
                    kernel_size=1,
                    padding=0),
                nn.BatchNorm2d(branch_channels),
                nn.ReLU(inplace=True),
                TemporalConv(
                    branch_channels,
                    branch_channels,
                    kernel_size=ks,
                    stride=stride,
                    dilation=dilation),
            )
            for ks, dilation in zip(kernel_size, dilations)
        ])

        # Additional Max & 1x1 branch
        self.branches.append(nn.Sequential(
            nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0),
            nn.BatchNorm2d(branch_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(3,1), stride=(stride,1), padding=(1,0)),
            nn.BatchNorm2d(branch_channels)  # 为什么还要加bn
        ))

        self.branches.append(nn.Sequential(
            nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0, stride=(stride,1)),
            nn.BatchNorm2d(branch_channels)
        ))

        # Residual connection
        if not residual:
            self.residual = lambda x: 0
        elif (in_channels == out_channels) and (stride == 1):
            self.residual = lambda x: x
        else:
            self.residual = TemporalConv(in_channels, out_channels, kernel_size=residual_kernel_size, stride=stride)

        # initialize
        self.apply(weights_init)

    def forward(self, x):
        # Input dim: (N,C,T,V)
        res = self.residual(x)
        branch_outs = []
        for tempconv in self.branches:
            out = tempconv(x)
            branch_outs.append(out)

        out = torch.cat(branch_outs, dim=1)
        out += res
        return out


class CTRGC(nn.Module):
    def __init__(self, in_channels, out_channels, rel_reduction=8, mid_reduction=1):
        super(CTRGC, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        if in_channels == 3 or in_channels == 5 or in_channels == 9:
            self.rel_channels = 8
            self.mid_channels = 16
        else:
            self.rel_channels = in_channels // rel_reduction
            self.mid_channels = in_channels // mid_reduction
        self.conv1 = nn.Conv2d(self.in_channels, self.rel_channels, kernel_size=1)
        self.conv2 = nn.Conv2d(self.in_channels, self.rel_channels, kernel_size=1)
        self.conv3 = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1)
        self.conv4 = nn.Conv2d(self.rel_channels, self.out_channels, kernel_size=1)
        self.tanh = nn.Tanh()
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                conv_init(m)
            elif isinstance(m, nn.BatchNorm2d):
                bn_init(m, 1)

    def forward(self, x, A=None, alpha=1):
        # x: [16, 3, 52, 55] => [B, C, T, V]
        # x1, x2: [16, 8, 55]  -> mapping 'Theta' and 'Phi'
        # x3: [16, 64, 52, 55] -> Feature Transformation
        x1, x2, x3 = self.conv1(x).mean(-2), self.conv2(x).mean(-2), self.conv3(x)
        # M 
        x1 = self.tanh(x1.unsqueeze(-1) - x2.unsqueeze(-2))
        # R
        x1 = self.conv4(x1) * alpha + (A.unsqueeze(0).unsqueeze(0) if A is not None else 0)  # N,C,V,V
        x1 = torch.einsum('ncuv,nctv->nctu', x1, x3)
        return x1

class unit_tcn(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=9, stride=1):
        super(unit_tcn, self).__init__()
        pad = int((kernel_size - 1) / 2)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(kernel_size, 1), padding=(pad, 0),
                              stride=(stride, 1))

        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        conv_init(self.conv)
        bn_init(self.bn, 1)

    def forward(self, x):
        x = self.bn(self.conv(x))
        return x


class unit_gcn(nn.Module):
    def __init__(self, in_channels, out_channels, A, coff_embedding=4, adaptive=True, residual=True):
        super(unit_gcn, self).__init__()
        inter_channels = out_channels // coff_embedding
        self.inter_c = inter_channels
        self.out_c = out_channels
        self.in_c = in_channels
        self.adaptive = adaptive
        self.num_subset = A.shape[0]
        self.convs = nn.ModuleList()
        for i in range(self.num_subset):
            self.convs.append(CTRGC(in_channels, out_channels))

        if residual:
            if in_channels != out_channels:
                self.down = nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, 1),
                    nn.BatchNorm2d(out_channels)
                )
            else:
                self.down = lambda x: x
        else:
            self.down = lambda x: 0
        if self.adaptive: # wether to update the Adj Matrix
            self.PA = nn.Parameter(torch.from_numpy(A.astype(np.float32)))
        else:
            self.A = Variable(torch.from_numpy(A.astype(np.float32)), requires_grad=False)
        self.alpha = nn.Parameter(torch.zeros(1)) # zero gradient
        self.bn = nn.BatchNorm2d(out_channels)
        self.soft = nn.Softmax(-2)
        self.relu = nn.ReLU(inplace=True)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                conv_init(m)
            elif isinstance(m, nn.BatchNorm2d):
                bn_init(m, 1)
        bn_init(self.bn, 1e-6)

    def forward(self, x):
        y = None
        if self.adaptive:
            A = self.PA
        else:
            A = self.A.cuda(x.get_device())
        for i in range(self.num_subset):
            z = self.convs[i](x, A[i], self.alpha)

            y = z + y if y is not None else z
        y = self.bn(y)
        y += self.down(x)
        y = self.relu(y)


        return y


class TCN_GCN_unit(nn.Module):
    """
        A single stack of spatial GCN and temporal TCN
    """
    def __init__(self, in_channels, out_channels, A, stride=1, residual=True, adaptive=True, kernel_size=5, dilations=[1,2]):
        super(TCN_GCN_unit, self).__init__()
        self.gcn1 = unit_gcn(in_channels, out_channels, A, adaptive=adaptive)
        self.tcn1 = MultiScale_TemporalConv(out_channels, out_channels, kernel_size=kernel_size, stride=stride, dilations=dilations,
                                            residual=False)
        self.relu = nn.ReLU(inplace=True)
        if not residual:
            self.residual = lambda x: 0

        elif (in_channels == out_channels) and (stride == 1):
            self.residual = lambda x: x

        else:
            self.residual = unit_tcn(in_channels, out_channels, kernel_size=1, stride=stride)

    def forward(self, x):
        y = self.relu(self.tcn1(self.gcn1(x)) + self.residual(x))
        return y


class Model(nn.Module):
    def __init__(self, num_class=60, num_point=25, num_person=2, graph=None, graph_args=dict(), in_channels=3,
                 drop_out=0, adaptive=True):
        super(Model, self).__init__()

        if graph is None:
            raise ValueError()
        else:
            Graph = import_class(graph)
            self.graph = Graph(**graph_args)

        A = self.graph.A # 3,25,25

        self.num_class = num_class
        self.num_point = num_point
        self.data_bn = nn.BatchNorm1d(num_person * in_channels * num_point)

        base_channel = 64
        self.l1 = TCN_GCN_unit(in_channels, base_channel, A, residual=False, adaptive=adaptive)
        self.l2 = TCN_GCN_unit(base_channel, base_channel, A, adaptive=adaptive)
        self.l3 = TCN_GCN_unit(base_channel, base_channel, A, adaptive=adaptive)
        self.l4 = TCN_GCN_unit(base_channel, base_channel, A, adaptive=adaptive)
        self.l5 = TCN_GCN_unit(base_channel, base_channel*2, A, stride=2, adaptive=adaptive)
        self.l6 = TCN_GCN_unit(base_channel*2, base_channel*2, A, adaptive=adaptive)
        self.l7 = TCN_GCN_unit(base_channel*2, base_channel*2, A, adaptive=adaptive)
        self.l8 = TCN_GCN_unit(base_channel*2, base_channel*4, A, stride=2, adaptive=adaptive)
        self.l9 = TCN_GCN_unit(base_channel*4, base_channel*4, A, adaptive=adaptive)
        self.l10 = TCN_GCN_unit(base_channel*4, base_channel*4, A, adaptive=adaptive)

        self.fc = nn.Linear(base_channel*4, num_class)
        nn.init.normal_(self.fc.weight, 0, math.sqrt(2. / num_class))
        bn_init(self.data_bn, 1)
        if drop_out:
            self.drop_out = nn.Dropout(drop_out)
        else:
            self.drop_out = lambda x: x

    def forward(self, x):
        # [16, 3, 52, 55, 1] => [B, C, T, V, Man]
        # if len(x.shape) == 3:
        #     N, T, VC = x.shape
        #     x = x.view(N, T, self.num_point, -1).permute(0, 3, 1, 2).contiguous().unsqueeze(-1)
        # N, C, T, V, M = x.size()

        # x = x.permute(0, 4, 3, 1, 2).contiguous().view(N, M * V * C, T)
        # # x = self.data_bn(x)
        # x = x.view(N, M, V, C, T).permute(0, 1, 3, 4, 2).contiguous().view(N * M, C, T, V)
        
        x = self.l1(x)
        x = self.l2(x)
        x = self.l3(x)
        x = self.l4(x)
        x = self.l5(x)
        x = self.l6(x)
        x = self.l7(x)
        x = self.l8(x)
        x = self.l9(x)
        x = self.l10(x)

        # N*M,C,T,V 
        # [16, 256, 13, 55]
        # c_new = x.size(1)
        # x = x.view(N, M, c_new, -1)
        # x = x.mean(3).mean(1)
        # x = self.drop_out(x)

        # return self.fc(x)
        return x

In [3]:
class PoseBackboneWrapper(nn.Module):
    def __init__(self):
        super(PoseBackboneWrapper, self).__init__()
        self.pose_model = Model(
                num_class=2000, num_point=78, num_person=1, 
                graph='models.graph.openpose_78.Graph',
                graph_args={'labeling_mode': 'spatial'}, drop_out=0)
        pose_weights = torch.load('/mnt/workspace/slt_baseline/models/ckpt/ctr_op78_mix_HF05_F64_e1/runs-82-93316.pt')
        # output [B C T V]
        self.pose_model.load_state_dict(pose_weights, strict=True)
    
    def forward(self, prefix):
        pose_output = self.pose_model(prefix) # B C T V
        pose_pool = pose_output.mean(-1).mean(-1) # B C T
        # prefix = pose_pool.transpose(-1, -2) # B T C
        out_cls = self.pose_model.fc(pose_pool) # B T Class
        return out_cls

In [4]:
def gen_slide(length, span=8, step=2):
    if length <= span:
        diff = span - length
        idxs = np.array(range(length))
        idxs = np.concatenate((idxs, (length-1)*np.ones(diff)))
        idxs = idxs[None,:]
    else:
        num_clips = (length - span + (step - 1)) // step + 1
        offsets = np.arange(num_clips)[:,None] * step
        idxs = offsets + np.arange(span)[None, :]
    # idxs = np.mod(idxs, length) # ensure no out of bounds
    idxs = idxs.clip(max=length-1) 
    return idxs


In [5]:
def read_pose_file(filepath):
    # exclude lower body parts    

    body_pose_exclude = {9, 10, 11, 22, 23, 24, 12, 13, 14, 19, 20, 21}

    body_sample_indices = [x for x in range(25) if x not in body_pose_exclude]

    with open(filepath, 'rb') as f:
        pose_dict = pkl.load(f)
    # [X Y Confidence]
    body_pose = pose_dict['pose_keypoints'] # [F, 25, 3] -> [F, 13, 2]
    hand_left = pose_dict['hand_left_keypoints'] # [F, 21, 3] -> [F, 21, 2]
    hand_right = pose_dict['hand_right_keypoints'] # [F, 21, 3] -> [F, 21, 2]
    face = pose_dict['face_keypoints'] #  [F, 70, 3] -> [F, 70, 2]
    # sample body
    body_pose = body_pose[:, body_sample_indices, :] # 23(17+6) 11 selected
    # sample face
    face_sample_index = [71, 77, 85, 89] + \
                        [40, 42, 44, 45, 47, 49] + \
                        [59, 60, 61, 62, 63, 64] + [65, 66, 67, 68, 69, 70] + \
                        [50]
    face_sample_index = [(x - 23) for x in face_sample_index]
    face = face[:, face_sample_index, :] # 23 Keypoints

    pose_tuple = (body_pose, hand_left, hand_right, face)
    pose_cated = np.concatenate(pose_tuple, axis=1) # [F, 13+21+21+23=78, 3]

    pose_cated[:, :, 0:2] = 2.0 * ((pose_cated[:, :, 0:2] / 256.0) - 0.5) # scale to [-1, 1] by image frame

    return pose_cated

In [6]:
def _normalize_joints(value):
    # scale to [-1, 1]
    scalerValue = np.reshape(value, (-1, 3))
    scalerValue = (scalerValue - np.min(scalerValue, axis=0)) / ((np.max(scalerValue, axis=0) - np.min(scalerValue,axis=0)) + 1e-5)
        
    scalerValue = scalerValue * 2 - 1
    scalerValue = np.reshape(scalerValue, (-1, 78, 3))

    return scalerValue

In [7]:
def inference(input_data, model):
    device = 'cuda:0'
    model = model.to(device)
    input_data = torch.from_numpy(input_data).type(torch.float32).to(device)
    input_data = input_data.permute(0, 3, 1, 2) # B T V C -> B C T V
    with torch.no_grad():
        output = model(input_data) # B T Class
    return output 

In [8]:
def load_openpose_sample(sample_name, wlasl_openpose_root='/mnt/workspace/WLASL/data/openpose'):
    file_path = os.path.join(wlasl_openpose_root, f'{sample_name}', f'{sample_name}.pkl')
    value = read_pose_file(file_path)
    value[:,:,:2] = _normalize_joints(value)[:,:,:2]
    return value

In [9]:
class_path = '/mnt/workspace/CTR-GCN/wlasl2000_label.txt'
with open(class_path, 'r') as f:
    vocab = f.readlines()
    vocab = [x.strip() for x in vocab]
print(len(vocab))
vocab = np.array(vocab)

2000


In [None]:
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
import json
subset = 'asl2000'
split = 'test'

wlasl_root = '/mnt/workspace/WLASL'
split_file = os.path.join(wlasl_root, f'data/splits/{subset}.json')
label_encoder = LabelEncoder()
with open(split_file, 'r') as f:
    content = json.load(f)
# init lable encoder
glosses = sorted([gloss_entry['gloss'] for gloss_entry in content])
label_encoder.fit(glosses)
# 
data_list = []
for entry in content:
    gloss, instances = entry['gloss'], entry['instances']
    gloss_cat = label_encoder.transform([gloss])[0] # label index
    for instance in instances:
        if instance['split'] not in split:
            continue

        video_id = instance['video_id']
        
        instance_entry = video_id, gloss_cat
        data_list.append(instance_entry)
train_dict = dict(data_list)


In [85]:
sample = '57488' # label 963
lable = train_dict[sample]
value = load_openpose_sample(sample)
print(value.shape)
T, V, C = value.shape

(29, 78, 3)


In [86]:
# sliding window inference
slide_window = gen_slide(T, span=8, step=2)
input_data = value[slide_window, ...]
output = inference(value[slide_window, ...], model) # B T Class
print('Output shape:', output.shape)
output = output.mean(-2)
idxs = output.argmax(dim=-1).squeeze().cpu().numpy()
print('Index shape:', idxs.shape)
slide_out = vocab[idxs]
print('Target:', vocab[lable])
print(slide_out)

Output shape: torch.Size([12, 2000])
Index shape: ()
Target: terrible
cool


In [87]:
sample_idx = np.linspace(0, T-1, 52).astype(np.int)
output = inference(value[None, sample_idx, ...], model)
# output = inference(value[None, ...], model)
print('Output shape:', output.shape)
output = output.mean(-2)
idxs = output.argmax(dim=-1).squeeze().cpu().numpy()
print('Index shape:', idxs.shape)
slide_out = vocab[idxs]
print('Target:', vocab[lable])
print(slide_out)

Output shape: torch.Size([1, 2000])
Index shape: ()
Target: terrible
before


In [20]:
from tqdm.notebook import tqdm

model = PoseBackboneWrapper()

total = len(train_dict)
print(total)
count = 0
curlen = 0
progress = tqdm(total=total)
for s, l in train_dict.items():
    value = load_openpose_sample(s)
    T, V, C = value.shape
    sample_idx = np.linspace(0, T-1, 52).astype(np.int)
    output = inference(value[None, sample_idx, :, :], model)
    # output = output.mean(-2)
    idxs = output.argmax(dim=-1).squeeze().cpu().numpy()
    # print(idxs)
    curlen += 1
    if idxs == l:
        count += 1
    progress.set_postfix({'acc': count/curlen})
    progress.update()
progress.close()
print(f'{count}/{total}, {count/total}')


2879


  0%|          | 0/2879 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [66]:
print(train_dict)

{'07092': 210, '07093': 210, '07095': 210, '07072': 210, '17730': 568, '17713': 568, '17716': 568, '17718': 568, '12306': 392, '12332': 392, '12336': 392, '12320': 392, '67519': 392, '05735': 168, '05727': 168, '05739': 168, '05741': 168, '09847': 313, '09867': 313, '09855': 313, '24857': 789, '24960': 789, '67715': 789, '11305': 358, '11311': 358, '11318': 358, '63219': 1955, '63226': 1955, '63233': 1955, '65300': 285, '08915': 285, '08925': 285, '65415': 433, '13635': 433, '67535': 433, '14855': 472, '14883': 472, '14888': 472, '65717': 703, '70234': 703, '21890': 703, '27194': 858, '27209': 858, '27213': 858, '38482': 1166, '38525': 1166, '38531': 1166, '57941': 1797, '57943': 1797, '57944': 1797, '62152': 1913, '62160': 1913, '62166': 1913, '64201': 1991, '64210': 1991, '64209': 1991, '64275': 1993, '64293': 1993, '64297': 1993, '68720': 48, '01986': 48, '01996': 48, '06455': 196, '06481': 196, '06483': 196, '70271': 418, '13198': 418, '13196': 418, '21933': 705, '70361': 705, '219