In [16]:
import torch
import torch.nn as nn
import math

In [17]:
class GlobalAttentionHead(nn.Module):
    """
    Self attention performed globally.
    """
    def __init__(self, in_dim, out_dim):
        super(GlobalAttentionHead, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        
        self.q_mlp = nn.Linear(in_dim, out_dim//4)
        self.k_mlp = nn.Linear(in_dim, out_dim//4)
        self.v_mlp = nn.Linear(in_dim, out_dim)
        
        self.residual = nn.Linear(in_dim, out_dim)
        
        self.batch_norm = nn.BatchNorm1d(out_dim)
        self.activation = nn.ReLU()
        self.softmax = nn.Softmax(dim=-1)
        
        
    def forward(self, x):
        q_feats = self.q_mlp(x) #(N, out_dim//4)
        k_feats = self.k_mlp(x).transpose(0,1)   # (out_dim//4, N)
        v_feats = self.v_mlp(x)  # (N, out_dim)
        
        # Compute attention
        energy = torch.matmul(q_feats, k_feats) / math.sqrt(self.out_dim // 4)
        attention = self.softmax(energy)
        
        at_feats = torch.matmul(energy, v_feats)
        at_feats = self.activation(self.batch_norm(at_feats))
        
        residual = self.residual(x)
        
        return residual + at_feats
        
        
        
        
        

In [18]:
net = GlobalAttentionHead(32, 64)

In [22]:
a = torch.rand(459000, 32)

In [23]:
net(a).shape

torch.Size([45900, 64])

In [None]:
2048 * 32 @ 32 * 2048

In [19]:
import numpy as np
import pandas as pd
import glob

In [64]:
seq = '03'

In [62]:
files = sorted(glob.glob(f'../GenderData/{seq}/*'))[:-5]

In [38]:
files

['../GenderData/04/0.ply',
 '../GenderData/04/000000000.csv',
 '../GenderData/04/000003750.csv',
 '../GenderData/04/000007500.csv',
 '../GenderData/04/000011250.csv',
 '../GenderData/04/000015000.csv',
 '../GenderData/04/000018750.csv',
 '../GenderData/04/000022500.csv',
 '../GenderData/04/000026250.csv',
 '../GenderData/04/000030000.csv',
 '../GenderData/04/000033750.csv',
 '../GenderData/04/000037500.csv',
 '../GenderData/04/000041250.csv',
 '../GenderData/04/000045000.csv',
 '../GenderData/04/000048750.csv',
 '../GenderData/04/000052500.csv',
 '../GenderData/04/000056250.csv',
 '../GenderData/04/000060000.csv',
 '../GenderData/04/000063750.csv',
 '../GenderData/04/000067500.csv',
 '../GenderData/04/000071250.csv',
 '../GenderData/04/000075000.csv',
 '../GenderData/04/000078750.csv',
 '../GenderData/04/000082500.csv',
 '../GenderData/04/000086250.csv',
 '../GenderData/04/000090000.csv',
 '../GenderData/04/000093750.csv',
 '../GenderData/04/000097500.csv',
 '../GenderData/04/000101250

In [65]:
files = sorted(glob.glob(f'../GenderData/{seq}/*'))[:-5]

for file in files:
    df = pd.read_csv(file)
    labels = np.unique(df['Label'])
    print(labels)

[1 2 3 4 5 6 7 8]
[1 2 3 4 5 6 7 8]
[1 2 3 4 5 6 7 8]
[1 2 3 4 5 6 7 8]
[1 2 3 4 5 6 7 8]
[1 2 3 4 5 6 7 8]
[1 2 3 4 5 6 7 8]
[1 2 3 4 5 6 7 8]
[1 2 3 4 5 6 7 8]
[1 2 3 4 5 6 7 8]
[1 2 3 4 5 6 7 8]
[1 2 3 4 5 6 7 8]
[1 2 3 4 5 6 7 8]
[1 2 3 4 5 6 7 8]
[1 2 3 4 5 6 7 8]
[1 2 3 4 5 6 7 8]
[1 2 3 4 5 6 7 8]
[1 2 3 4 5 6 7 8]
[1 2 3 4 5 6 7 8]
[1 2 3 4 5 6 7 8]
[1 2 3 4 5 6 7 8]
[2 3 4 5 6 7 8]
[2 3 4 5 6 7 8]
[2 3 4 5 6 7 8]
[1 2 3 4 5 6 7 8]
[1 2 3 4 5 6 7 8]
[1 2 3 4 5 6 7 8]
[1 2 3 4 5 6 7 8]
[1 2 3 4 5 6 7 8]
[1 2 3 4 5 6 8]
[1 2 3 4 5 6 8]
[2 3 4 5 6 8]
[2 3 4 5 6 8]
[2 3 4 5 6 8]
[2 3 4 5 6 8]
[2 3 4 5 6 8]
[2 3 4 5 6 8]
[2 3 4 5 6 8]
[2 3 4 5 6 8]
[1 2 3 4 5 6 8]
[1 2 3 4 5 6 8]
[1 2 3 4 5 6 8]
[1 2 3 4 5 6 8]
[1 2 3 4 5 6 8]
[1 2 3 4 5 6 8]
[2 3 4 5 6 8]
[2 3 4 5 6 8]
[2 3 4 5 6 8]
[2 3 4 5 6 8]
[2 3 4 5 6 8]
[2 3 4 5 6 8]
[2 3 4 5 6 8]
[2 3 4 5 6 8]
[2 3 4 5 6 8]
[1 2 3 4 5 6 8]
[1 2 3 4 5 6 8]
[1 2 3 4 5 6 8]
[1 2 3 4 5 6 8]
[1 2 3 4 5 6 8]
[1 2 3 4 5 6 8]
[1 2 3 4 5 6 8]
[1 2 3

In [57]:
files

[]

In [28]:
files

['../GenderData/01/000000000.csv',
 '../GenderData/01/000003750.csv',
 '../GenderData/01/000007500.csv',
 '../GenderData/01/000011250.csv',
 '../GenderData/01/000015000.csv',
 '../GenderData/01/000018750.csv',
 '../GenderData/01/000022500.csv',
 '../GenderData/01/000026250.csv',
 '../GenderData/01/000030000.csv',
 '../GenderData/01/000033750.csv',
 '../GenderData/01/000037500.csv',
 '../GenderData/01/000041250.csv',
 '../GenderData/01/000045000.csv',
 '../GenderData/01/000048750.csv',
 '../GenderData/01/000052500.csv',
 '../GenderData/01/000056250.csv',
 '../GenderData/01/000060000.csv',
 '../GenderData/01/000063750.csv',
 '../GenderData/01/000067500.csv',
 '../GenderData/01/000071250.csv',
 '../GenderData/01/000075000.csv',
 '../GenderData/01/000078750.csv',
 '../GenderData/01/000082500.csv',
 '../GenderData/01/000086250.csv',
 '../GenderData/01/000090000.csv',
 '../GenderData/01/000093750.csv',
 '../GenderData/01/000097500.csv',
 '../GenderData/01/000101250.csv',
 '../GenderData/01/0

In [15]:
np.unique(df['Label'])

array([1, 2, 3, 4, 5, 6])