In [1]:
import torch 
from torch.nn import functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import json
from tqdm.notebook import tqdm
import math
from sklearn.metrics import accuracy_score
import os
import gc

import transformers
from transformers.models.speech_to_text import Speech2TextConfig
from transformers.models.speech_to_text.modeling_speech_to_text import shift_tokens_right, Speech2TextDecoder
from Squeezeformer import SqueezeformerEncoder
from timm.layers.norm_act import BatchNormAct2d

Dataset

In [2]:
# Notebook-wide configuration constants.

# Landmark selection consumed by extract_columns(): for every group the ids are
# expanded into x/y/z column names (e.g. 'x_right_hand_0', 'x_pose_13').
FEATURES = {
    'hand': {
        'left': range(21),
        'right': range(21)
    },
    'pose': {
        'left': [13, 15, 17, 19, 21],
        'right': [14, 16, 18, 20, 22]
    },
    # Head landmarks are currently disabled; uncomment to add 'x_head_<i>' columns.
    #'head': range(468)
}

# Device every model/batch is moved to; falls back to CPU when CUDA is unavailable.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(DEVICE)
# Fractions and batch size handed to get_data_loaders() by train_model().
VAL_SPLIT = 0.8
TEST_SPLIT = 0.2
BATCH_SIZE = 64

# NOTE(review): FRAME_LEN is not referenced by any code visible in this notebook.
FRAME_LEN = 128
MAX_LEN = 384 # max number of frames (default target length of InterpolateOrPad)
MAX_PHRASE = 31 + 3 # max len from data + start, pad, end tokens (token length used by TokenizePhrase)

# Placeholder characters that TokenizePhrase appends to the character map as
# extra vocabulary entries (pad, start, end - in that order).
START_TOKEN = 'S'
PAD_TOKEN = 'P'
END_TOKEN = 'E'

cuda


In [3]:
def extract_columns(features: dict = FEATURES) -> list:
    """Expand a feature specification into the ordered list of landmark column names.

    Args:
        features: Mapping that may contain the keys ``'hand'`` and ``'pose'``
            (each a dict with ``'right'`` and ``'left'`` landmark ids) and
            ``'head'`` (an iterable of landmark ids).

    Returns:
        Column names ordered as hands, then pose, then head. Within hands and
        pose the right side comes before the left side, and every landmark
        contributes its ``x``, ``y`` and ``z`` column consecutively.
    """
    axes = ('x', 'y', 'z')
    sides = ('right', 'left')
    columns = []

    if 'hand' in features:
        for side in sides:
            for landmark in features['hand'][side]:
                columns.extend(f'{axis}_{side}_hand_{landmark}' for axis in axes)

    if 'pose' in features:
        # Pose column names carry no side marker; the side only decides the order.
        for side in sides:
            for landmark in features['pose'][side]:
                columns.extend(f'{axis}_pose_{landmark}' for axis in axes)

    if 'head' in features:
        for landmark in features['head']:
            columns.extend(f'{axis}_head_{landmark}' for axis in axes)

    return columns


class FingerspellingDataset(Dataset):
    """Map-style dataset yielding one fingerspelling sequence per metadata row.

    All landmark parquet files of the chosen split are loaded eagerly in
    ``__init__``; ``__getitem__`` then selects the frames of one sequence and
    returns them as a ``(n_frames, n_landmarks, 3)`` numpy array together with
    the target phrase.

    Args:
        features: Landmark selection passed to ``extract_columns``.
        train: ``True`` loads ``train_landmarks`` / ``train.csv``; ``False``
            loads ``supplemental_landmarks`` / ``supplemental_metadata.csv``.
        transform: Optional callable applied to the ``{'data', 'phrase'}``
            sample dict before it is returned.
        data_root: Directory containing the competition files. Pass
            ``'/kaggle/input/asl-fingerspelling'`` when running on Kaggle.
    """

    def __init__(self,
                 features: dict = FEATURES,
                 train: bool = True,
                 transform = None,
                 data_root: str = 'C:/MASTERS/SYDE Project/asl-fingerspelling'
                ):

        if train:
            dataset_path = f'{data_root}/train_landmarks'
            dataset_file = f'{data_root}/train.csv'
        else:
            dataset_path = f'{data_root}/supplemental_landmarks'
            # The supplemental landmarks are indexed by their own metadata
            # file; pairing them with train.csv would look up file ids that
            # do not exist in this directory.
            dataset_file = f'{data_root}/supplemental_metadata.csv'

        self.dataset_path = dataset_path
        self.dataset_df = pd.read_csv(dataset_file)
        self.feature_columns = extract_columns(features)
        self.transform = transform

        # Load every parquet file once, keeping only the selected columns.
        # Anything else living in the directory is skipped.
        self.parquet_df = {}
        for file in sorted(os.listdir(dataset_path)):
            file_id, extension = os.path.splitext(file)
            if extension != '.parquet':
                continue
            self.parquet_df[file_id] = pq.read_table(
                f"{self.dataset_path}/{file}",
                columns=['sequence_id'] + self.feature_columns
            ).to_pandas()

        # Plain numpy copies for fast row selection in __getitem__.
        # NOTE(review): __getitem__ filters on the DataFrame *index*, so
        # 'sequence_id' is assumed to come back as the index (not a column)
        # after to_pandas() - confirm against the parquet metadata.
        self.parquet_np = {
            file_id: frame.to_numpy() for file_id, frame in self.parquet_df.items()
        }

        # Positions of the x / y / z columns inside feature_columns.
        self.X_IDX = [i for i, col in enumerate(self.feature_columns) if col.startswith("x_")]
        self.Y_IDX = [i for i, col in enumerate(self.feature_columns) if col.startswith("y_")]
        self.Z_IDX = [i for i, col in enumerate(self.feature_columns) if col.startswith("z_")]

    def __len__(self):
        """Number of sequences listed in the metadata CSV."""
        return len(self.dataset_df)

    def __getitem__(self, index):
        """Return ``{'data': ndarray (n_frames, n_landmarks, 3), 'phrase': str}``."""
        # convert to a plain python index (if index is a tensor)
        if torch.is_tensor(index):
            index = index.tolist()

        # locate sample in the metadata dataframe
        row = self.dataset_df.iloc[index]
        sequence_id, file_id, phrase = row['sequence_id'], row['file_id'], row['phrase']

        # select the frames belonging to this sequence
        key = str(file_id)
        frames = self.parquet_np[key][self.parquet_df[key].index == sequence_id]

        # (n_frames, n_columns) -> (n_frames, n_landmarks, 3) with xyz last
        frames = np.stack(
            [frames[:, indices] for indices in (self.X_IDX, self.Y_IDX, self.Z_IDX)],
            axis=-1
        )

        sample = {
            'data': frames, # numpy.ndarray
            'phrase': phrase, # string
        }

        # apply transformation(s)
        if self.transform:
            sample = self.transform(sample)

        return sample

In [4]:
class NormalizeAndFillNaNs(object):
    """Standardise landmark coordinates per axis and replace NaNs with zeros.

    Accepts either ``(seq_len, n_landmarks, 3)`` data (what
    ``FingerspellingDataset`` produces) or the flat ``(seq_len, 3 * n_landmarks)``
    layout ``[all x | all y | all z]``, which is first unfolded to
    ``(seq_len, n_landmarks, 3)``.
    """

    def __init__(self):
        super(NormalizeAndFillNaNs, self).__init__()

    def normalize(self, x):
        """Subtract the mean and divide by the (population) std of each axis.

        Statistics are computed over non-NaN entries only. The ``view`` below
        requires NaNs to cover whole landmarks (all coordinates at once).
        """
        nonan = x[~torch.isnan(x)].view(-1, x.shape[-1])
        x = x - nonan.mean(0)[None, None, :]
        x = x / nonan.std(0, unbiased=False)[None, None, :]
        return x

    def fill_nans(self, x):
        """Replace every NaN in ``x`` with 0 (in place) and return ``x``."""
        x[torch.isnan(x)] = 0
        return x

    def __call__(self, sample):
        x, phrase = sample['data'], sample['phrase']

        # Only flat input needs unfolding. Data that is already
        # (seq_len, n_landmarks, 3) must be left alone: reshaping it to
        # (seq_len, 3, n_landmarks) would mix x/y/z values of different landmarks.
        if x.dim() == 2:
            # seq_len, 3 * n_landmarks -> seq_len, n_landmarks, 3
            x = x.reshape(x.shape[0], 3, -1).permute(0, 2, 1)

        # Normalize & fill nans
        x = self.normalize(x)
        x = self.fill_nans(x)

        return {
            'data': x,
            'phrase': phrase
        }

class TokenizePhrase(object):
    """Turn the sample's phrase into a fixed-length token tensor plus mask.

    The character map is loaded from JSON and extended with three extra ids,
    appended in the order pad, start, end. Every phrase becomes
    ``[start] + characters + [end]`` right-padded with the pad id to exactly
    ``MAX_PHRASE`` tokens.

    Args:
        char_map_path: Path of ``character_to_prediction_index.json``. On
            Kaggle use
            ``'/kaggle/input/asl-fingerspelling/character_to_prediction_index.json'``.
    """

    def __init__(self, char_map_path: str = 'C:/MASTERS/SYDE Project/asl-fingerspelling/character_to_prediction_index.json'):
        with open(char_map_path, "r") as f:
            self.char_to_num = json.load(f)
        n = len(self.char_to_num)
        self.char_to_num[PAD_TOKEN] = n
        self.char_to_num[START_TOKEN] = n + 1
        self.char_to_num[END_TOKEN] = n + 2

        # inverse lookup for decoding predictions
        self.num_to_char = {j:i for i,j in self.char_to_num.items()}

    def __call__(self, sample):
        """Add ``phrase_tokens`` / ``phrase_mask`` and pass the landmark tensors through.

        Raises:
            KeyError: If the phrase contains a character missing from the map
                or the sample lacks one of the expected keys.
        """
        phrase = sample['phrase']

        start_token_id = self.char_to_num[START_TOKEN]
        end_token_id = self.char_to_num[END_TOKEN]
        pad_token_id = self.char_to_num[PAD_TOKEN]

        # Leave room for the start AND the end token; otherwise long phrases
        # would yield MAX_PHRASE + 1 tokens and could not be batched.
        max_chars = MAX_PHRASE - 2
        phrase_tokens = [self.char_to_num[char] for char in phrase][:max_chars]
        phrase_tokens = [start_token_id] + phrase_tokens + [end_token_id]
        phrase_mask = [1] * len(phrase_tokens)

        to_pad = MAX_PHRASE - len(phrase_tokens)
        phrase_tokens = torch.tensor(phrase_tokens + [pad_token_id] * to_pad)
        phrase_mask = torch.tensor(phrase_mask + [0] * to_pad)

        return {
            'left_hand': sample['left_hand'], # tensor
            'right_hand': sample['right_hand'], # tensor
            'left_pose': sample['left_pose'], # tensor
            'right_pose': sample['right_pose'], # tensor
            'data_mask': sample['data_mask'], # tensor

            'phrase': phrase, # string
            'phrase_tokens': phrase_tokens, # tensor (long)
            'phrase_mask': phrase_mask # tensor (long)
        }
    
class InterpolateOrPad(object):
    """Bring every sequence to exactly ``max_length`` frames and build its mask.

    Shorter sequences are right-padded with zeros (mask 0 on the padding);
    longer ones are resampled down to ``max_length`` frames along the time
    axis (mask all ones).

    Args:
        max_length: Target number of frames.
    """

    def __init__(self, max_length: int = MAX_LEN):
        self.max_length = max_length

    def __call__(self, sample):
        """Return ``{'data': (max_length, L, C), 'data_mask': (max_length,), 'phrase'}``."""
        data, phrase = sample['data'], sample['phrase']
        diff = self.max_length - data.shape[0]

        if diff < 0:
            # Too long: a negative pad size would raise, so shrink instead.
            # (T, L, C) -> (L, C, T) puts time last, which is the axis
            # F.interpolate resamples for 3-D input.
            data = F.interpolate(data.permute(1, 2, 0), size=self.max_length)
            data = data.permute(2, 0, 1).contiguous()
            data_mask = torch.ones_like(data[:, 0, 0])
        else:
            # Pad with zeros created in the dtype/device of the data.
            padding = data.new_zeros((diff,) + tuple(data.shape[1:]))
            data_mask = torch.cat([torch.ones_like(data[:, 0, 0]), data.new_zeros(diff)])
            data = torch.cat([data, padding])

        return {
            'data': data,
            'data_mask': data_mask,
            'phrase': phrase
        }
    
class SplitData(object):
    """Split the landmark axis into left/right hand and left/right pose groups.

    Landmark positions are derived from the column order returned by
    ``extract_columns(features)``: three consecutive columns (x, y, z) form
    one landmark, so column ``i`` belongs to landmark ``i // 3``.

    Args:
        features: Landmark selection; must be the one the data was built with.
    """

    def __init__(self, features: dict = FEATURES):
        columns = extract_columns(features)

        self.X_IDX = [i for i, col in enumerate(columns) if col.startswith("x_")]
        self.Y_IDX = [i for i, col in enumerate(columns) if col.startswith("y_")]
        self.Z_IDX = [i for i, col in enumerate(columns) if col.startswith("z_")]

        def landmark_ids(keep):
            # sorted() gives a deterministic order; a bare set() does not promise one.
            return sorted({i // 3 for i, col in enumerate(columns) if keep(col)})

        def pose_id(col):
            # Parse the full trailing number so one-digit ids work too.
            return int(col.rsplit('_', 1)[-1])

        # Only hand columns carry 'right' / 'left' in their name.
        self.RHAND_IDX = landmark_ids(lambda col: "right" in col)
        self.LHAND_IDX = landmark_ids(lambda col: "left" in col)
        # Pose columns are unmarked, so the side is looked up in the *given*
        # feature spec (not the module-level default).
        self.RPOSE_IDX = landmark_ids(lambda col: "pose" in col and pose_id(col) in features['pose']['right'])
        self.LPOSE_IDX = landmark_ids(lambda col: "pose" in col and pose_id(col) in features['pose']['left'])

    def __call__(self, sample):
        """Return the four landmark groups plus ``data_mask`` and ``phrase``."""
        data, data_mask, phrase = sample['data'], sample['data_mask'], sample['phrase']
        return {
            'left_hand': data[:, self.LHAND_IDX],
            'right_hand': data[:, self.RHAND_IDX],
            'left_pose': data[:, self.LPOSE_IDX],
            'right_pose': data[:, self.RPOSE_IDX],
            'data_mask': data_mask,

            'phrase': phrase,
        }

class ToTensor(object):
    """Wrap the sample's numpy frame array in a torch tensor; the phrase is passed through."""

    def __call__(self, sample):
        # from_numpy keeps dtype and shares memory with the source array.
        as_tensor = torch.from_numpy(sample['data'])
        return dict(data=as_tensor, phrase=sample['phrase'])
    
class Resample(object):
    """Random temporal resampling: stretch or squeeze a sequence by linear interpolation.

    Args:
        rate: ``(low, high)`` range the length factor is drawn from.
    """

    def __init__(self, rate=(0.8,1.2)):
        self.rate = rate

    def interp1d_(self, x, new_size):
        """Linearly interpolate ``x`` along its first axis to ``new_size`` steps.

        The first and last output steps coincide with the first and last
        input steps.
        """
        n = len(x)
        if n == 0:
            return x
        if n == 1:
            # Nothing to interpolate between: repeat the only frame.
            return x.expand(new_size, *x.shape[1:]).clone()

        indices = torch.linspace(0, n - 1, steps=new_size, dtype=torch.float32)
        indices_floor = torch.clamp(torch.floor(indices).to(torch.int64), 0, n - 2)
        # The fraction must be taken relative to the *clamped* index so the
        # final position gets weight 1 on frame n-1 instead of 0 on frame n-2.
        indices_frac = indices - indices_floor
        indices_frac = indices_frac.view(-1, *([1] * (x.dim() - 1)))

        x0 = x[indices_floor]
        x1 = x[indices_floor + 1]
        return x0 + (x1 - x0) * indices_frac

    def __call__(self, sample):
        """With probability ~0.2 resample ``sample['data']``; otherwise pass it through."""
        frames, phrase = sample['data'], sample['phrase']
        if torch.rand(1)>0.8:
            rate = torch.empty(1).uniform_(self.rate[0], self.rate[1])
            length = frames.shape[0]
            # never shrink a sequence to zero frames
            new_size = max(1, int(rate * length))
            new_frames = self.interp1d_(frames, new_size)
        else:
            new_frames = frames
        return {'data': new_frames, 'phrase': phrase}
    
class SpatialRandomAffine(object):
    """Random scale / shear / rotation / shift of the landmark coordinates.

    Applied with probability ~0.25. Shear and rotation act on the first two
    coordinates (x, y) only; scale and shift act on all coordinates. Any of
    the four steps can be disabled by passing ``None`` for its range. The
    input tensor is never modified in place.

    Args:
        scale: ``(low, high)`` range of the uniform scale factor.
        shear: ``(low, high)`` range of the shear coefficient (applied to one
            randomly chosen axis).
        shift: ``(low, high)`` range of the offset added to every coordinate.
        degree: ``(low, high)`` range of the rotation angle in degrees; the
            rotation is taken around ``(0.5, 0.5)`` shifted by the shear.
    """

    def __init__(self,
                 scale=(0.8, 1.2),
                 shear=(-0.15, 0.15),
                 shift=(-0.1, 0.1),
                 degree=(-30, 30)
                ):
        self.scale = scale
        self.shear = shear
        self.shift = shift
        self.degree = degree

    def __call__(self, sample):
        data, phrase = sample['data'], sample['phrase']
        if torch.rand(1)>0.75:
            # Build all helper tensors in the data's own float dtype so that
            # float64 input does not hit a dtype mismatch in matmul.
            dtype = data.dtype if torch.is_floating_point(data) else torch.float32
            center = torch.tensor([0.5, 0.5], dtype=dtype)

            if self.scale is not None:
                scale = torch.rand(1).item() * (self.scale[1] - self.scale[0]) + self.scale[0]
                data = scale * data

            if self.shear is not None:
                shear_x = shear_y = torch.rand(1).item() * (self.shear[1] - self.shear[0]) + self.shear[0]
                # shear along exactly one of the two axes
                if torch.rand(1).item() < 0.5:
                    shear_x = 0.0
                else:
                    shear_y = 0.0
                shear_mat = torch.tensor([
                    [1.0, shear_x],
                    [shear_y, 1.0]
                ], dtype=dtype)
                xy = torch.matmul(data[..., :2], shear_mat)
                center = center + torch.tensor([shear_y, shear_x], dtype=dtype)
                data = torch.cat([xy, data[..., 2:]], dim=-1)

            if self.degree is not None:
                # Out-of-place subtraction: `data[..., :2]` is a view, and an
                # in-place `-=` would silently alter the caller's tensor when
                # scale and shear are disabled.
                xy = data[..., :2] - center
                degree = torch.rand(1).item() * (self.degree[1] - self.degree[0]) + self.degree[0]
                radian = math.radians(degree)
                c = math.cos(radian)
                s = math.sin(radian)
                rotate_mat = torch.tensor([
                    [c, s],
                    [-s, c]
                ], dtype=dtype)
                xy = torch.matmul(xy, rotate_mat) + center
                data = torch.cat([xy, data[..., 2:]], dim=-1)

            if self.shift is not None:
                shift = torch.rand(1).item() * (self.shift[1] - self.shift[0]) + self.shift[0]
                data = data + shift

        return {'data': data, 'phrase': phrase}
    
    
class TemporalMask(object):
    """Randomly blank out one contiguous run of frames.

    Applied with probability ~0.5. The run covers a fraction of the sequence
    drawn from ``size`` and starts at a random offset; every value in the
    affected frames is set to ``mask_value``. Works for any number of
    landmarks/coordinates and returns a copy instead of editing the input.

    Args:
        size: ``(low, high)`` range of the masked fraction of the sequence.
        mask_value: Value written into the masked frames.
    """

    def __init__(self, size=(0.2,0.4), mask_value=float('nan')):
        self.size = size
        self.mask_value = mask_value

    def __call__(self, sample):
        data, phrase = sample['data'], sample['phrase']
        if torch.rand(1)>0.5:
            l = data.shape[0]
            mask_size = torch.rand(1).item() * (self.size[1] - self.size[0]) + self.size[0]
            mask_size = int(l * mask_size)
            mask_offset = torch.randint(0, l - mask_size + 1, (1,)).item()
            # Slice assignment broadcasts over whatever trailing shape the
            # data has, so no landmark count needs to be hard-coded.
            data = data.clone()
            data[mask_offset:mask_offset + mask_size] = self.mask_value

        return {'data': data, 'phrase': phrase}
    
class SpatialMask(object):
    """Randomly blank out the landmarks that fall inside a square x/y window.

    Applied with probability ~0.5. The window's lower corner is drawn from
    ``[0, 1)`` per axis and its side length from ``size``; landmarks whose
    first two coordinates lie strictly inside have all coordinates replaced
    by ``mask_value``.
    """

    def __init__(self, size=(0.2,0.4), mask_value=float('nan')):
        self.size = size
        self.mask_value = mask_value

    def __call__(self, sample):
        # TODO: determine if this works as intended (does x/y refer to xyz coordinates?)
        xyz = sample['data']
        phrase = sample['phrase']

        if torch.rand(1).item() <= 0.5:
            return {'data': xyz, 'phrase': phrase}

        # Draw order matters for reproducibility: y offset, x offset, then size.
        low_y = torch.rand(1).item()
        low_x = torch.rand(1).item()
        side = torch.rand(1).item() * (self.size[1] - self.size[0]) + self.size[0]

        xs = xyz[..., 0]
        ys = xyz[..., 1]
        inside_x = (low_x < xs) & (xs < low_x + side)
        inside_y = (low_y < ys) & (ys < low_y + side)
        inside = inside_x & inside_y

        masked = torch.where(inside.unsqueeze(-1), torch.tensor(self.mask_value), xyz)
        return {'data': masked, 'phrase': phrase}
    
# TODO: determine indices for left/right 
LEFT = []
RIGHT = []

class FlipLeftRight(object):
    """Randomly mirror a sample horizontally and swap paired landmarks.

    Applied with probability ~0.5: the first coordinate becomes ``1 - x`` and,
    along the landmark axis, entry ``left[i]`` trades places with ``right[i]``.
    With the default (empty) index lists only the mirroring takes place.
    """

    def __init__(self, left=LEFT, right=RIGHT):
        self.left = left
        self.right = right

    # TODO: fix, not sure if non 3d data format will work (ie requires xyz to extend into 3rd dimension)
    def __call__(self, sample):
        xyz = sample['data']
        phrase = sample['phrase']

        if torch.rand(1).item() <= 0.5:
            return {'data': xyz, 'phrase': phrase}

        xs, ys, zs = torch.unbind(xyz, dim=-1)
        mirrored = torch.stack([1 - xs, ys, zs], dim=-1)

        # Put the landmark axis first so landmarks can be addressed directly.
        by_landmark = mirrored.transpose(0, 1)

        # List indexing copies, so these hold the pre-swap values.
        from_left = by_landmark[self.left]
        from_right = by_landmark[self.right]

        for pos, left_idx in enumerate(self.left):
            right_idx = self.right[pos]
            by_landmark[left_idx] = from_right[pos]
            by_landmark[right_idx] = from_left[pos]

        return {'data': by_landmark.transpose(0, 1), 'phrase': phrase}

class TemporalCrop(object):
    """Cut a random window of ``length`` consecutive frames out of a longer sequence.

    Sequences with at most ``length + 1`` frames, or any sequence when
    ``length`` is ``None``, are returned unchanged.
    """

    def __init__(self, length=32):
        self.length = length

    def __call__(self, sample):
        data = sample['data']
        phrase = sample['phrase']

        n_frames = data.shape[0]
        # TODO remove self.length<l-1 once fixed
        needs_crop = self.length is not None and self.length < n_frames - 1
        if needs_crop:
            start = torch.randint(0, n_frames - self.length + 1, (1,)).item()
            data = data[start:start + self.length]

        return {
            'data': data,
            'phrase': phrase
        }
    
def get_data_loaders(features: dict = FEATURES, val_split: float = 0.8, test_split: float = 0.2, batch_size: int = 16):
    """Build the train / validation / test data loaders.

    The training split is augmented; the supplemental split (divided into
    validation and test) is only cropped, normalised, padded and tokenised.

    Args:
        features: Landmark selection used for loading and splitting.
        val_split: Fraction of the supplemental data used for validation.
        test_split: Fraction of the supplemental data used for testing.
        batch_size: Batch size of the train and validation loaders.

    Returns:
        ``(train_loader, val_loader, test_loader)``.

    Raises:
        ValueError: If ``val_split + test_split`` is not 1.
    """
    if not math.isclose(val_split + test_split, 1.0):
        raise ValueError("val_split and test_split must sum to 1")

    def build_transform(augment: bool):
        steps = [ToTensor()]
        if augment:
            # The augmentations work on raw coordinates (rotation centre 0.5,
            # mask window offsets in [0, 1)) and the masks insert NaNs, so
            # they must run BEFORE normalisation / NaN filling.
            steps += [Resample(), SpatialRandomAffine(), SpatialMask(), TemporalMask()]
        steps += [
            # Kept for evaluation too: it bounds the sequence length ahead of padding.
            TemporalCrop(),
            NormalizeAndFillNaNs(),
            InterpolateOrPad(),
            SplitData(features),
            TokenizePhrase()
        ]
        return transforms.Compose(steps)

    # load datasets
    train_data = FingerspellingDataset(features=features, train=True, transform=build_transform(augment=True))
    test_val_data = FingerspellingDataset(features=features, train=False, transform=build_transform(augment=False))

    # Derive the validation size from the remainder so the two parts always
    # add up to the dataset size, whatever the float rounding of the products.
    dataset_size = len(test_val_data)
    n_test = math.ceil(test_split * dataset_size)
    n_val = dataset_size - n_test
    test_data, val_data = torch.utils.data.random_split(test_val_data, [n_test, n_val])

    # setup data loaders
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=1000)

    return train_loader, val_loader, test_loader

In [5]:
# train_data = FingerspellingDataset() # need to add transforms
# train_loader,_,_ = get_data_loaders()

In [6]:
# tr_it = iter(train_loader)

# print(next(tr_it)['left_hand'].shape)
# interpolateD = InterpolateOrPad()
# print(train_data[1]['data'].shape)
# print(interpolateD(train_data[1])['data'].shape)
# print(interpolateD(train_data[1])['data_mask'].shape)
# splitD = SplitData()
# tokenD = TokenizePhrase()
# print(list(tokenD(splitD(interpolateD(train_data[1]))).keys()))
# print(tokenD(splitD(interpolateD(train_data[1])))['left_hand'].shape)
# print(tokenD(splitD(interpolateD(train_data[1])))['phrase_tokens'])
# print(tokenD(splitD(interpolateD(train_data[1])))['phrase_mask'])



Model Architecture

In [7]:
class FeatureExtractor(nn.Module):
    """Convolutional stem followed by a linear projection and a padding-aware batch norm.

    ``forward`` permutes ``data`` so its last axis becomes the convolution
    channel axis, flattens the result per time step, projects it to
    ``out_dim`` and batch-normalises only the positions where ``mask == 1``.
    Positions where the mask is zero are returned as zeros.
    """

    def __init__(self,
                 n_landmarks,out_dim, conv_ch = 3):
        super().__init__()

        # The stem keeps the time axis and halves (rounding up) the landmark
        # axis, with 32 output channels per remaining position.
        in_channels = 32 * math.ceil(n_landmarks / 2)
        self.in_channels = in_channels
        self.stem_linear = nn.Linear(in_channels, out_dim, bias=False)
        self.stem_bn = nn.BatchNorm1d(out_dim, momentum=0.95)
        self.conv_stem = nn.Conv2d(conv_ch, 32, kernel_size=(3, 3), stride=(1, 2), padding=(1, 1), bias=False)
        self.bn_conv = BatchNormAct2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True,act_layer = nn.SiLU,drop_layer=None)

    def forward(self, data, mask):
        lead_dims = data.shape[:2]

        # channels-first for the conv, back to channels-last, then flatten per step
        feats = self.bn_conv(self.conv_stem(data.permute(0, 3, 1, 2)))
        feats = feats.permute(0, 2, 3, 1).reshape(*lead_dims, -1)

        not_padded = mask.to(torch.bool)
        x = self.stem_linear(feats)

        # Batchnorm without pads: normalise the valid steps only and write
        # them back into the flattened activations.
        bs, slen, nfeat = x.shape
        x = x.view(-1, nfeat)
        valid = mask.view(-1) == 1
        normed = self.stem_bn(x[valid].unsqueeze(0).permute(0, 2, 1)).permute(0, 2, 1)
        x[valid] = normed[0]
        x = x.view(bs, slen, nfeat)

        # Padding mask
        return x.masked_fill(~not_padded.unsqueeze(-1), 0.0)

    
class Decoder(nn.Module):
    """Speech2Text transformer decoder with a linear language-model head.

    Args:
        decoder_config: Configuration handed to ``Speech2TextDecoder``; its
            ``d_model``, ``vocab_size`` and special token ids are read here.
    """

    def __init__(self, decoder_config):
        super(Decoder, self).__init__()

        self.config = decoder_config
        self.decoder = Speech2TextDecoder(decoder_config)
        self.lm_head = nn.Linear(decoder_config.d_model, decoder_config.vocab_size, bias=False)

        self.decoder_start_token_id = decoder_config.decoder_start_token_id
        self.decoder_pad_token_id = decoder_config.pad_token_id
        self.decoder_end_token_id= decoder_config.eos_token_id # used for early stopping in generate()

    def forward(self,x, labels=None, attention_mask = None, encoder_attention_mask = None):
        """Teacher-forced pass; returns logits of shape ``(batch, target_len, vocab_size)``.

        Args:
            x: Encoder hidden states.
            labels: Target token ids; shifted right to form the decoder input.
            attention_mask: Optional mask over the decoder input.
            encoder_attention_mask: Optional mask over the encoder states.

        Raises:
            ValueError: If ``labels`` is ``None`` (use ``generate`` for inference).
        """
        if labels is None:
            raise ValueError("labels are required for the teacher-forced forward pass; use generate() for inference")

        decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

        decoder_outputs = self.decoder(input_ids=decoder_input_ids,
                                       encoder_hidden_states=x,
                                       attention_mask = attention_mask,
                                       encoder_attention_mask = encoder_attention_mask)
        lm_logits = self.lm_head(decoder_outputs.last_hidden_state)
        return lm_logits

    @torch.no_grad()
    def generate(self, x, max_new_tokens=33, encoder_attention_mask=None):
        """Greedy decoding without gradient tracking.

        Starts every sequence with the start token and appends the arg-max
        token step by step, for at most ``max_new_tokens - 1`` steps or until
        every sequence contains the end token.

        Returns:
            Token ids of shape ``(batch, <= max_new_tokens)`` including the
            leading start token.
        """
        decoder_input_ids = torch.full((x.shape[0], 1), self.decoder_start_token_id, device=x.device, dtype=torch.long)
        for _ in range(max_new_tokens-1):
            decoder_outputs = self.decoder(input_ids=decoder_input_ids,encoder_hidden_states=x, encoder_attention_mask=encoder_attention_mask)
            # Only the newest position decides the next token, so only that
            # one is projected to the vocabulary.
            next_logits = self.lm_head(decoder_outputs.last_hidden_state[:, -1:])
            decoder_input_ids = torch.cat([decoder_input_ids, next_logits.argmax(-1)], dim=1)

            if (decoder_input_ids == self.decoder_end_token_id).any(dim=-1).all():
                break

        return decoder_input_ids
    
def count_parameters(model):
    """Return the total number of scalar entries in the model's trainable parameters."""
    total = 0
    for param in model.parameters():
        # frozen parameters (requires_grad == False) are not counted
        if param.requires_grad:
            total += param.numel()
    return total

class Net(nn.Module):
    """Landmark-to-text model: feature extractors -> Squeezeformer encoder -> Speech2Text decoder.

    ``forward`` takes the batch dict produced by the data pipeline
    (``left_hand``, ``right_hand``, ``left_pose``, ``right_pose``,
    ``data_mask``, ``phrase_tokens``) and returns ``{'loss': ...}``; in eval
    mode it additionally returns greedy predictions under ``'generated_ids'``
    and ``'seq_len'``.

    Args:
        max_phrase: Width of the padded prediction tensor in eval mode.
            Defaults to the notebook constant ``MAX_PHRASE``.
            NOTE(review): assumed to equal the target token length - confirm.
        val_mode: ``'padded'`` decodes the whole padded batch using the
            encoder mask; ``'cutted'`` decodes groups of equal true length
            with the padding cut off.

    Raises:
        ValueError: If ``val_mode`` is not one of the two supported modes.
    """

    def __init__(self, max_phrase: int = None, val_mode: str = 'padded'):
        super(Net,self).__init__()

        if val_mode not in ('padded', 'cutted'):
            raise ValueError(f"val_mode must be 'padded' or 'cutted', got {val_mode!r}")
        # Both attributes are read by the eval branch of forward(); without
        # them evaluation fails with AttributeError.
        self.max_phrase = MAX_PHRASE if max_phrase is None else max_phrase
        self.val_mode = val_mode

        dim=208
        n_handlandmarks = 21
        n_poselandmarks = 5
        n_landmarks = 2*n_handlandmarks+2*n_poselandmarks

        # Start from the pretrained config and override what this model needs.
        # Token ids: pad = 59, start = 60, end = 61 (vocabulary of 63 entries).
        d_cfg = Speech2TextConfig.from_pretrained("facebook/s2t-small-librispeech-asr")
        d_cfg.encoder_layers = 0
        d_cfg.decoder_layers = 2
        d_cfg.d_model = dim
        d_cfg.max_target_positions = 1024 #?
        d_cfg.num_hidden_layers = 1
        d_cfg.vocab_size = 63
        d_cfg.bos_token_id = 60
        d_cfg.eos_token_id = 61
        d_cfg.decoder_start_token_id = 60
        d_cfg.pad_token_id = 59
        d_cfg.num_conv_layers = 0
        d_cfg.conv_kernel_sizes = []
        d_cfg.max_length = dim
        d_cfg.input_feat_per_channel = dim
        d_cfg.num_beams = 1
        d_cfg.attention_dropout = 0.2
        d_cfg.decoder_ffn_dim = 512
        d_cfg.init_std = 0.02

        # One extractor over all landmarks plus one per body part; the four
        # part outputs (dim // 4 each) are concatenated back to `dim`.
        self.feature_extractor = FeatureExtractor(n_landmarks=n_landmarks,out_dim=dim)
        self.feature_extractor_lhand = FeatureExtractor(n_handlandmarks,out_dim=dim//4)
        self.feature_extractor_rhand = FeatureExtractor(n_handlandmarks,out_dim=dim//4)
        self.feature_extractor_lpose = FeatureExtractor(n_poselandmarks,out_dim=dim//4)
        self.feature_extractor_rpose = FeatureExtractor(n_poselandmarks,out_dim=dim//4)

        self.encoder = SqueezeformerEncoder(
                      input_dim=dim,
                      encoder_dim=dim,
                      num_layers=5,
                      num_attention_heads= 4,
                      feed_forward_expansion_factor=1,
                      conv_expansion_factor= 2,
                      input_dropout_p=0.1,
                      feed_forward_dropout_p= 0.1,
                      attention_dropout_p= 0.1,
                      conv_dropout_p= 0.1,
                      conv_kernel_size= 51,)
        self.decoder = Decoder(d_cfg)
        self.loss_fn = nn.CrossEntropyLoss() #done
        print('n_params:',count_parameters(self))

    def forward(self, batch):
        """Compute the loss (and, in eval mode, greedy predictions) for one batch."""
        # All landmarks side by side along the landmark axis.
        x = torch.cat([batch['left_hand'],batch['right_hand'],batch['left_pose'],batch['right_pose']],dim=-2)
        labels = batch['phrase_tokens']
        mask = batch['data_mask'].long()

        #maybe normalize
        x_lhand = self.feature_extractor_lhand(batch['left_hand'].clone(), mask)
        x_lpose = self.feature_extractor_lpose(batch['left_pose'].clone(), mask)
        x_rhand = self.feature_extractor_rhand(batch['right_hand'].clone(), mask)
        x_rpose = self.feature_extractor_rpose(batch['right_pose'].clone(), mask)

        x1 = torch.cat([x_lhand,x_rhand,x_lpose,x_rpose],dim=-1)
        x = self.feature_extractor(x, mask)
        x = x + x1
        x = self.encoder(x, mask)
        decoder_labels = labels.clone()

        #??
        # if self.training:
        #     m = torch.rand(labels.shape) < self.decoder_mask_aug
        #     decoder_labels[m] = 62

        logits = self.decoder(x,
                            labels=decoder_labels,
                            encoder_attention_mask=mask.long()
                            )

        loss = self.loss_fn(logits.view(-1, self.decoder.config.vocab_size), labels.view(-1))

        output = {'loss':loss}

        if not self.training:
            pad_id = self.decoder.decoder_pad_token_id
            generated_ids_padded = torch.full((x.shape[0], self.max_phrase), pad_id, dtype=torch.long, device=x.device)

            if self.val_mode == 'padded':
                generated_ids = self.decoder.generate(x,max_new_tokens=self.max_phrase + 1, encoder_attention_mask=mask.long())

            elif self.val_mode == 'cutted':
                generated_ids = torch.full((x.shape[0], self.max_phrase + 1), pad_id, dtype=torch.long, device=x.device)
                mask_lens = mask.sum(1)
                # decode samples of identical true length together, padding removed
                for lidx in mask_lens.unique():
                    liddx = lidx == mask_lens
                    preds = self.decoder.generate(x[liddx, :lidx],max_new_tokens=self.max_phrase + 1)
                    generated_ids[liddx, :preds.shape[1]] = preds

            else:
                raise ValueError(f"unsupported val_mode: {self.val_mode!r}")

            # Keep everything before the first end token; the rest stays pad.
            # NOTE(review): a row without any end token gets cutoff 0 and is
            # therefore returned as all-pad - confirm this is intended.
            cutoffs = (generated_ids==self.decoder.decoder_end_token_id).float().argmax(1).clamp(0,self.max_phrase)
            for i, c in enumerate(cutoffs):
                generated_ids_padded[i,:c] = generated_ids[i,:c]
            output['generated_ids'] = generated_ids_padded
            # The data pipeline in this notebook emits no 'seq_len' entry;
            # fall back to the number of unpadded frames per sample.
            output['seq_len'] = batch['seq_len'] if 'seq_len' in batch else mask.sum(1)
        return output

In [8]:
# input_dict = {'lhand': torch.rand(1,380,21,3),
#             'rhand':torch.rand(1,380,21,3),
#             'lpose':torch.rand(1,380,5,3),
#             'rpose':torch.rand(1,380,5,3),
#             'token_ids':(torch.rand(1,32)*52).long(),
#             'input_mask': torch.ones_like(torch.rand(1,380)),
#             'attention_mask': torch.ones_like(torch.rand(1,32))         
# }
# tr_it = iter(train_loader)
# model = Net()

# print(model(next(tr_it)))

TODO: The cells below this point (training, validation and test loops) are still work in progress and need to be fixed.


In [9]:

def train(model, train_loader, optimizer,scheduler,epoch):
    """Run one training epoch and return the mean batch loss.

    For every batch all tensor entries are moved to ``DEVICE``, the model's
    ``'loss'`` output is back-propagated and both the optimizer and the
    scheduler are stepped once.

    Args:
        model: Module called with the batch dict; must return a dict with ``'loss'``.
        train_loader: Loader yielding dict batches (the ``'phrase'`` strings are dropped).
        optimizer: Optimizer updating the model parameters.
        scheduler: Learning-rate scheduler, stepped per batch.
        epoch: Epoch number shown in the progress bar.

    Returns:
        Sum of the batch losses divided by ``len(train_loader)``.
    """
    model = model.to(DEVICE)
    # set model to training mode
    model.train()

    running_loss = 0
    for data in tqdm(train_loader, desc=f'Epoch {epoch} Progress '):
        optimizer.zero_grad()

        # everything except the raw phrase strings goes to the device
        batch = {name: value.to(DEVICE) for name, value in data.items() if name != 'phrase'}

        loss = model(batch)['loss']
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

        # track some values to compute statistics
        running_loss += loss.item()

    final_loss = running_loss / len(train_loader)

    # print the current learning rate and the average loss
    print(f"learning Rate = {optimizer.param_groups[0]['lr']}. average train loss = {final_loss:.2f}")
    return final_loss

def _strip_special_tokens(token_ids, start_id, pad_id, end_id):
    """Return the tokens before the first pad/end token, with start tokens removed."""
    kept = []
    for token in token_ids:
        if token == pad_id or token == end_id:
            break
        if token != start_id:
            kept.append(token)
    return kept


def _evaluate(model, loader):
    """Run the model in eval mode over ``loader``; return ``(accuracy, mean_loss)``.

    Batches are dicts (as in ``train``) and the model is expected to return
    a dict with ``'loss'`` and ``'generated_ids'``. Accuracy is the fraction
    of samples whose generated token sequence equals the target sequence
    once special tokens are stripped.
    """
    model = model.to(DEVICE)
    # set model to evaluation mode
    model.eval()

    # NOTE(review): assumes the model exposes `.decoder` with these token-id
    # attributes, as Net does - confirm for other models.
    start_id = model.decoder.decoder_start_token_id
    pad_id = model.decoder.decoder_pad_token_id
    end_id = model.decoder.decoder_end_token_id

    total_loss = 0.0
    n_correct = 0
    n_samples = 0
    with torch.no_grad():
        for data in loader:
            batch = {key: data[key].to(DEVICE) for key in data if key != 'phrase'}
            output = model(batch)

            # Track some values to compute statistics
            total_loss += output['loss'].item()
            predictions = output['generated_ids'].cpu().tolist()
            targets = batch['phrase_tokens'].cpu().tolist()
            for predicted, target in zip(predictions, targets):
                same = (_strip_special_tokens(predicted, start_id, pad_id, end_id)
                        == _strip_special_tokens(target, start_id, pad_id, end_id))
                n_correct += int(same)
                n_samples += 1

    # max(..., 1) keeps an empty loader from dividing by zero
    acc = n_correct / max(n_samples, 1)
    final_loss = total_loss / max(len(loader), 1)
    return acc, final_loss


def validation(model, val_loader, loss_fn=None, epoch=0):
    """Evaluate on the validation loader and print a summary.

    Args:
        model: Model to evaluate.
        val_loader: Loader yielding dict batches.
        loss_fn: Unused - the model computes its own loss; kept so existing
            call sites keep working.
        epoch: Zero-based epoch index used in the printed message.

    Returns:
        ``(accuracy, mean_loss)``.
    """
    acc, final_loss = _evaluate(model, val_loader)
    # Print average loss and accuracy
    print(f"Epoch {epoch + 1} done. average validation loss = {final_loss:.2f}, average validation accuracy = {acc * 100:.3f}%")
    return acc, final_loss

def test(model, test_loader, loss_fn=None):
    """Evaluate on the test loader and print a summary.

    Args:
        model: Model to evaluate.
        test_loader: Loader yielding dict batches.
        loss_fn: Unused - the model computes its own loss; kept so existing
            call sites keep working.

    Returns:
        ``(accuracy, mean_loss)``.
    """
    acc, final_loss = _evaluate(model, test_loader)
    # Print average loss and accuracy
    print(f'average test loss = {final_loss:.2f}, average test accuracy = {acc * 100:.3f}%')
    return acc, final_loss

In [10]:
def train_model(
                data_loaders: tuple = None,
                learning_rate: float = 4.5e-3, 
                weight_decay: float = 0.08, 
                features: dict = None,
                warmup_steps=1,
                training_steps = 10,
                max_epochs: int = 16
               ):
    """Build the model, optimizer and LR schedule, then train for ``max_epochs`` epochs.

    Args:
        data_loaders: Optional ``(train_loader, val_loader, test_loader)``;
            built with ``get_data_loaders`` when omitted.
        learning_rate: Peak learning rate for AdamW.
        weight_decay: Weight decay for AdamW.
        features: Landmark selection; defaults to ``FEATURES``.
        warmup_steps: Warm-up length in epochs (multiplied by the number of
            batches per epoch).
        training_steps: Schedule length in epochs (multiplied by the number
            of batches per epoch).
        max_epochs: Number of epochs actually trained.
    """
    if not features:
        features = FEATURES

    if not data_loaders:
        train_loader, val_loader, test_loader = get_data_loaders(features, VAL_SPLIT, TEST_SPLIT, BATCH_SIZE)
    else:
        train_loader, val_loader, test_loader = data_loaders

    model = Net().to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # The scheduler is stepped once per batch, so both lengths are given in batches.
    # NOTE(review): with the defaults the schedule spans 10 epochs while 16
    # are trained; past its end the cosine factor is expected to rise again
    # rather than stay at zero - confirm and align training_steps / max_epochs.
    scheduler = transformers.get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps*len(train_loader),num_training_steps=training_steps*len(train_loader),num_cycles=0.5
    )

    optimizer.zero_grad()

    for e in range(max_epochs):
        gc.collect()
        train_loss = train(model, train_loader, optimizer,scheduler,e)

    # TODO: re-enable validation-based early stopping and a final validation
    # run (returning its accuracy) once the evaluation loop works with this model.


In [11]:
# Launch training with all default arguments (data loaders are built on the fly).
train_model()

n_params: 4058336


Epoch 0 Progress :   0%|          | 0/1051 [00:00<?, ?it/s]

torch.Size([64, 384, 52, 3])
torch.Size([64, 384, 52, 3])
torch.Size([64, 384, 52, 3])
torch.Size([64, 384, 52, 3])
torch.Size([64, 384, 52, 3])
torch.Size([64, 384, 52, 3])
torch.Size([64, 384, 52, 3])
torch.Size([64, 384, 52, 3])
torch.Size([64, 384, 52, 3])


KeyboardInterrupt: 