In [87]:
!pip3 install torchinfo



In [1]:
import torch 
from torch.nn import functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader

import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import json
from tqdm.notebook import tqdm
import math
from sklearn.metrics import accuracy_score

from transformers.models.speech_to_text import Speech2TextConfig
from transformers.models.speech_to_text.modeling_speech_to_text import shift_tokens_right, Speech2TextDecoder
from Squeezeformer import SqueezeformerEncoder
from timm.layers.norm_act import BatchNormAct2d
from torchinfo import summary


Dataset

In [2]:
# Constants shared by the dataset, data-loader and training cells.

# Landmark indices to load, grouped by body part and side. FingerspellingDataset
# expands them into parquet column names such as `x_right_hand_0` or `y_pose_13`.
FEATURES = {
    'hand': {
        'left': range(21),
        'right': range(21)
    },
    'pose': {
        # NOTE(review): these look like MediaPipe pose indices for the arm/hand
        # joints of each side — confirm against the landmark specification.
        'left': [13, 15, 17, 19, 21],
        'right': [14, 16, 18, 20, 22]
    },
    # Head/face landmarks are currently disabled; uncomment to include them.
    #'head': range(468)
}

# NOTE(review): not referenced anywhere in the visible cells — presumably the
# target sequence length for padding/cropping; confirm before relying on it.
FRAME_LEN = 128
# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Fractions of the supplemental dataset used for validation and testing
# (passed to get_data_loaders by train_model).
VAL_SPLIT = 0.8
TEST_SPLIT = 0.2
# Mini-batch size used for the training and validation loaders.
BATCH_SIZE = 128

In [3]:
class FingerspellingDataset(Dataset):
    """Dataset of fingerspelling landmark sequences stored as parquet files.

    Each item is a dict with:
        'data':   numpy array holding the rows of one sequence, one column per
                  selected landmark coordinate (see `feature_columns`).
        'phrase': the phrase string taken from the metadata CSV.

    Args:
        features: Mapping of body part ('hand', 'pose', 'head') to the landmark
            indices to load. 'hand' and 'pose' map side ('left'/'right') to an
            iterable of indices; 'head' maps directly to an iterable of indices.
        train: If True, read `train.csv` / `train_landmarks`; otherwise read
            `supplemental_metadata.csv` / `supplemental_landmarks`.
        transform: Optional callable applied to the sample dict before it is
            returned.
        data_root: Directory that contains the CSV files and landmark folders.
            Defaults to the original local path; on Kaggle pass
            '/kaggle/input/asl-fingerspelling'.
    """

    def __init__(self,
                 features: dict = FEATURES,
                 train: bool = True,
                 transform = None,
                 data_root: str = 'C:/Users/kevin/OneDrive/Desktop/MASTERS/SYDE Project/asl-fingerspelling'
                ):
        sides = ['right', 'left']
        axes = ['x', 'y', 'z']

        hand_cols = []
        if 'hand' in features:
            hand_cols = [
                f'{d}_{o}_hand_{i}'
                for o in sides
                for i in features['hand'][o]
                for d in axes
            ]
        pose_cols = []
        if 'pose' in features:
            # Pose column names carry no side; the side only selects the indices.
            pose_cols = [
                f'{d}_pose_{i}'
                for o in sides
                for i in features['pose'][o]
                for d in axes
            ]
        head_cols = []
        if 'head' in features:
            head_cols = [
                f'{d}_head_{i}'
                for i in features['head']
                for d in axes
            ]

        if train:
            landmarks_dir, metadata_file = 'train_landmarks', 'train.csv'
        else:
            landmarks_dir, metadata_file = 'supplemental_landmarks', 'supplemental_metadata.csv'

        self.dataset_path = f'{data_root}/{landmarks_dir}'
        self.dataset_df = pd.read_csv(f'{data_root}/{metadata_file}')
        self.feature_columns = hand_cols + pose_cols + head_cols
        self.transform = transform

    def __len__(self):
        """Number of sequences listed in the metadata CSV."""
        return len(self.dataset_df)

    def __getitem__(self, index):
        """Load the landmark rows and phrase for the sequence at position `index`."""
        # convert to a plain Python index (if index is a tensor)
        if torch.is_tensor(index):
            index = index.tolist()
        # locate sample in dataset dataframe
        sequence_id, file_id, phrase = self.dataset_df.iloc[index][['sequence_id', 'file_id', 'phrase']]
        # read only the columns that are needed from the .parquet file
        parquet_df = pq.read_table(
                        f"{self.dataset_path}/{str(file_id)}.parquet",
                        columns=['sequence_id'] + self.feature_columns
                     ).to_pandas()
        # keep the rows of this sequence; the filter is on the DataFrame index,
        # so this relies on sequence_id being the index after to_pandas()
        # NOTE(review): confirm the parquet files store sequence_id as the index.
        frames = parquet_df.loc[parquet_df.index == sequence_id].to_numpy()

        sample = {
            'data': frames, # numpy.ndarray
            'phrase': phrase, # string
        }
        # apply transformation(s)
        if self.transform:
            sample = self.transform(sample)
        return sample

In [4]:
class ToTensor:
    """Transform that turns the numpy landmark array of a sample into a tensor.

    The returned dict has the same two keys as the input: 'data' becomes a
    `torch.Tensor` created with `torch.from_numpy`, 'phrase' is passed through
    unchanged.
    """

    def __call__(self, sample):
        landmark_array = sample['data']
        text = sample['phrase']
        converted = {'data': torch.from_numpy(landmark_array)}
        converted['phrase'] = text
        return converted

def get_data_loaders(features: dict = FEATURES, val_split: float = 0.8, test_split: float = 0.2, batch_size: int = 16):
    """Build the train, validation and test data loaders.

    The training loader wraps the training dataset. The supplemental dataset
    is randomly split into a test part and a validation part.

    Args:
        features: Landmark selection forwarded to FingerspellingDataset.
        val_split: Fraction of the supplemental data intended for validation.
            Kept for interface compatibility: the validation set receives
            whatever is left after the test split, which equals
            floor(val_split * n) whenever the two fractions sum to 1.
        test_split: Fraction of the supplemental data used for testing
            (rounded up to a whole number of samples).
        batch_size: Batch size of the training and validation loaders.

    Returns:
        Tuple `(train_loader, val_loader, test_loader)`.
    """
    # A single transform needs no Compose wrapper (and `transforms` was never
    # imported in this notebook, so the old call raised a NameError).
    transform = ToTensor()
    train_data = FingerspellingDataset(features=features, train=True, transform=transform)
    test_val_data = FingerspellingDataset(features=features, train=False, transform=transform)

    # Derive the validation size from the test size so the two always add up
    # to the dataset length, which random_split requires.
    dataset_size = len(test_val_data)
    test_size = math.ceil(test_split * dataset_size)
    val_size = dataset_size - test_size
    test_data, val_data = torch.utils.data.random_split(test_val_data, [test_size, val_size])

    # setup data loaders
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=1000)
    return train_loader, val_loader, test_loader

In [5]:
# Smoke test: build the training dataset without a transform (samples stay as
# numpy arrays) and display one sample.
train_data = FingerspellingDataset() # no transform passed, so 'data' is a numpy.ndarray
train_data[3]

{'data': array([[ 2.3356757e-01,  8.7151617e-01,  1.4675981e-06, ...,
          1.2006173e+00,  1.3686731e+00, -1.2188929e+00],
        [ 2.5071117e-01,  8.7850654e-01,  1.1639925e-06, ...,
          1.1771251e+00,  1.4146559e+00, -1.3279152e+00],
        [ 2.5689089e-01,  8.8893843e-01,  8.6225799e-07, ...,
          1.1739606e+00,  1.4577438e+00, -1.3690206e+00],
        ...,
        [ 3.0872941e-01,  8.0841637e-01,  7.0709848e-07, ...,
          1.1138570e+00,  1.4244299e+00, -8.6565697e-01],
        [ 2.8616741e-01,  8.2798451e-01,  1.3166692e-07, ...,
          1.1295004e+00,  1.4395330e+00, -8.9529443e-01],
        [ 2.7528158e-01,  8.1241024e-01, -1.1747585e-07, ...,
          1.1178130e+00,  1.4387276e+00, -7.9332572e-01]], dtype=float32),
 'phrase': '988 franklin lane'}

Model Architecture

In [6]:
class FeatureExtractor(nn.Module):
    """Convolutional stem followed by a linear projection and masked batch norm.

    `forward` moves the last axis of `data` to the channel position, applies a
    3x3 convolution with stride 2 along the landmark axis, flattens the
    per-frame result and projects it to `out_dim`. Batch normalisation is
    computed only over positions where `mask == 1`; all other positions are
    set to zero in the output.

    Args:
        n_landmarks: Number of landmarks along the third axis of the input.
        out_dim: Size of the feature vector produced for every frame.
        conv_ch: Number of input channels of the convolution (size of the last
            axis of `data`).
    """

    def __init__(self,
                 n_landmarks,out_dim, conv_ch = 3):
        super().__init__()

        # The stride-2 convolution halves the landmark axis (rounded up) and
        # emits 32 channels, which fixes the width of the flattened features.
        flattened_width = 32 * math.ceil(n_landmarks / 2)
        self.in_channels = flattened_width
        self.stem_linear = nn.Linear(flattened_width, out_dim, bias=False)
        self.stem_bn = nn.BatchNorm1d(out_dim, momentum=0.95)
        self.conv_stem = nn.Conv2d(conv_ch, 32, kernel_size=(3, 3), stride=(1, 2), padding=(1, 1), bias=False)
        self.bn_conv = BatchNormAct2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True,act_layer = nn.SiLU,drop_layer=None)

    def forward(self, data, mask):
        # assumes data is (batch, frames, landmarks, coords) and mask is
        # (batch, frames) with 1 marking real frames — TODO confirm with caller
        conv_out = self.bn_conv(self.conv_stem(data.permute(0, 3, 1, 2)))
        per_frame = conv_out.permute(0, 2, 3, 1).reshape(*data.shape[:2], -1)

        projected = self.stem_linear(per_frame)

        # Batchnorm without pads: normalise only the rows selected by the mask
        # and write them back in place.
        bs, slen, nfeat = projected.shape
        rows = projected.view(-1, nfeat)
        keep = mask.view(-1) == 1
        normed = self.stem_bn(rows[keep].unsqueeze(0).permute(0, 2, 1)).permute(0, 2, 1)
        rows[keep] = normed[0]

        # Padding mask: zero every position that is not marked as valid.
        return rows.view(bs, slen, nfeat).masked_fill(~mask.bool().unsqueeze(-1), 0.0)

    
class Decoder(nn.Module):
    """Speech2Text transformer decoder with a linear language-model head.

    Args:
        decoder_config: Configuration object passed to `Speech2TextDecoder`.
            Its `d_model`, `vocab_size`, `decoder_start_token_id`,
            `pad_token_id` and `eos_token_id` fields are read here.
    """

    def __init__(self, decoder_config):
        super(Decoder, self).__init__()

        self.config = decoder_config
        self.decoder = Speech2TextDecoder(decoder_config)
        self.lm_head = nn.Linear(decoder_config.d_model, decoder_config.vocab_size, bias=False)

        self.decoder_start_token_id = decoder_config.decoder_start_token_id
        self.decoder_pad_token_id = decoder_config.pad_token_id
        self.decoder_end_token_id= decoder_config.eos_token_id  # used for early stopping in generate()

    def forward(self,x, labels=None, attention_mask = None, encoder_attention_mask = None):
        """Teacher-forced pass: return logits for every position of `labels`.

        Args:
            x: Encoder hidden states passed as `encoder_hidden_states`.
            labels: Target token ids. They are shifted right to form the
                decoder input. Required.
            attention_mask: Optional mask forwarded to the decoder.
            encoder_attention_mask: Optional mask over the encoder positions.

        Returns:
            The output of the LM head applied to the decoder's last hidden state.

        Raises:
            ValueError: If `labels` is None (there would be no decoder input).
        """
        if labels is None:
            raise ValueError("Decoder.forward requires `labels`; use generate() for inference.")

        decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

        decoder_outputs = self.decoder(input_ids=decoder_input_ids,
                                       encoder_hidden_states=x,
                                       attention_mask = attention_mask,
                                       encoder_attention_mask = encoder_attention_mask)
        lm_logits = self.lm_head(decoder_outputs.last_hidden_state)
        return lm_logits

    def generate(self, x, max_new_tokens=33, encoder_attention_mask=None):
        """Greedy decoding starting from the decoder start token.

        Args:
            x: Encoder hidden states, one row per sequence to decode.
            max_new_tokens: Maximum length of the returned id sequences,
                including the leading start token.
            encoder_attention_mask: Optional mask over the encoder positions.

        Returns:
            Long tensor of token ids with `x.shape[0]` rows. Decoding stops
            early once every row contains the end token.
        """
        decoder_input_ids = torch.full((x.shape[0], 1), self.decoder_start_token_id,
                                       dtype=torch.long, device=x.device)
        for _ in range(max_new_tokens-1):
            decoder_outputs = self.decoder(input_ids=decoder_input_ids,encoder_hidden_states=x, encoder_attention_mask=encoder_attention_mask)
            logits = self.lm_head(decoder_outputs.last_hidden_state)
            # Only the last position is needed to pick the next token.
            next_ids = logits[:, -1:].argmax(-1)
            decoder_input_ids = torch.cat([decoder_input_ids, next_ids], dim=1)

            if torch.all((decoder_input_ids == self.decoder_end_token_id).any(-1)):
                break

        return decoder_input_ids
    
def count_parameters(model):
    """Return the total number of elements over all trainable parameters of `model`."""
    trainable_elements = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_elements += param.numel()
    return trainable_elements

class Net(nn.Module):
    """Landmark-to-text model: feature extractors, Squeezeformer encoder, decoder.

    `forward` expects a batch dict with the keys 'lhand', 'rhand', 'lpose',
    'rpose' (landmark tensors), 'token_ids' (target ids) and 'input_mask'
    (1 for valid frames). It returns a dict with 'loss'. In eval mode it also
    returns 'generated_ids', plus 'seq_len' when the batch provides it.

    Args:
        max_phrase: Width of the padded 'generated_ids' output; generation
            produces at most `max_phrase + 1` ids including the start token.
            NOTE(review): the default of 32 was chosen to match the default
            `max_new_tokens=33` of `Decoder.generate` — confirm.
        val_mode: How sequences are decoded in eval mode. 'padded' decodes the
            whole batch at once using the input mask. 'cutted' decodes groups
            of equal valid length with the padding sliced off.

    Raises:
        ValueError: If `val_mode` is not 'padded' or 'cutted'.
    """

    def __init__(self, max_phrase: int = 32, val_mode: str = 'padded'):
        super(Net,self).__init__()

        if val_mode not in ('padded', 'cutted'):
            raise ValueError(f"val_mode must be 'padded' or 'cutted', got {val_mode!r}")
        # Both attributes are read by forward() in eval mode.
        self.max_phrase = max_phrase
        self.val_mode = val_mode

        dim=208
        n_handlandmarks = 21
        n_poselandmarks = 5
        n_landmarks = 2*n_handlandmarks+2*n_poselandmarks

        d_cfg = Speech2TextConfig.from_pretrained("facebook/s2t-small-librispeech-asr")
        d_cfg.encoder_layers = 0
        d_cfg.decoder_layers = 2
        d_cfg.d_model = dim
        d_cfg.max_target_positions = 1024 #?
        d_cfg.num_hidden_layers = 1
        d_cfg.vocab_size = 63
        d_cfg.bos_token_id = 59
        d_cfg.eos_token_id = 60
        d_cfg.decoder_start_token_id = 59
        d_cfg.pad_token_id = 61
        d_cfg.num_conv_layers = 0
        d_cfg.conv_kernel_sizes = []
        d_cfg.max_length = dim
        d_cfg.input_feat_per_channel = dim
        d_cfg.num_beams = 1
        d_cfg.attention_dropout = 0.2
        d_cfg.decoder_ffn_dim = 512
        d_cfg.init_std = 0.02

        # One extractor over all landmarks plus one per body part; the four
        # part-wise outputs (dim//4 each) are concatenated back to `dim`.
        self.feature_extractor = FeatureExtractor(n_landmarks=n_landmarks,out_dim=dim)
        self.feature_extractor_lhand = FeatureExtractor(n_handlandmarks,out_dim=dim//4)
        self.feature_extractor_rhand = FeatureExtractor(n_handlandmarks,out_dim=dim//4)
        self.feature_extractor_lpose = FeatureExtractor(n_poselandmarks,out_dim=dim//4)
        self.feature_extractor_rpose = FeatureExtractor(n_poselandmarks,out_dim=dim//4)

        self.encoder = SqueezeformerEncoder(
                      input_dim=dim,
                      encoder_dim=dim,
                      num_layers=5,
                      num_attention_heads= 4,
                      feed_forward_expansion_factor=1,
                      conv_expansion_factor= 2,
                      input_dropout_p=0.1,
                      feed_forward_dropout_p= 0.1,
                      attention_dropout_p= 0.1,
                      conv_dropout_p= 0.1,
                      conv_kernel_size= 51,)
        self.decoder = Decoder(d_cfg)
        self.loss_fn = nn.CrossEntropyLoss()
        print('n_params:',count_parameters(self))

    def forward(self, batch):
        # Stack all body parts along the landmark axis for the joint extractor.
        x = torch.cat([batch['lhand'],batch['rhand'],batch['lpose'],batch['rpose']],dim=-2)
        labels = batch['token_ids']
        mask = batch['input_mask'].long()

        # TODO: inputs are not normalised yet.
        x_lhand = self.feature_extractor_lhand(batch['lhand'].clone(), mask)
        x_lpose = self.feature_extractor_lpose(batch['lpose'].clone(), mask)
        x_rhand = self.feature_extractor_rhand(batch['rhand'].clone(), mask)
        x_rpose = self.feature_extractor_rpose(batch['rpose'].clone(), mask)

        x1 = torch.cat([x_lhand,x_rhand,x_lpose,x_rpose],dim=-1)
        x = self.feature_extractor(x, mask)
        x = x + x1
        x = self.encoder(x, mask)

        # TODO: label-masking augmentation (decoder_mask_aug) is not
        # implemented; the decoder currently receives an unmodified copy.
        decoder_labels = labels.clone()

        logits = self.decoder(x,
                            labels=decoder_labels,
                            encoder_attention_mask=mask
                            )

        loss = self.loss_fn(logits.view(-1, self.decoder.config.vocab_size), labels.view(-1))

        output = {'loss':loss}

        if not self.training:
            fill_id = self.decoder.decoder_start_token_id
            max_len = self.max_phrase
            generated_ids_padded = torch.full((x.shape[0], max_len), fill_id, dtype=torch.long, device=x.device)

            if self.val_mode == 'padded':
                generated_ids = self.decoder.generate(x,max_new_tokens=max_len + 1, encoder_attention_mask=mask)

            elif self.val_mode == 'cutted':
                generated_ids = torch.full((x.shape[0], max_len + 1), fill_id, dtype=torch.long, device=x.device)
                mask_lens = mask.sum(1)
                # Decode sequences of equal valid length together, with the
                # padded tail of the encoder output sliced off.
                for lidx in mask_lens.unique():
                    liddx = lidx == mask_lens
                    preds = self.decoder.generate(x[liddx, :lidx],max_new_tokens=max_len + 1)
                    generated_ids[liddx, :preds.shape[1]] = preds

            else:
                raise ValueError(f"val_mode must be 'padded' or 'cutted', got {self.val_mode!r}")

            # Position of the first end token in every row.
            # NOTE(review): a row with no end token yields cutoff 0 (argmax of
            # an all-zero row), so its padded prediction stays entirely filler
            # — confirm this is intended.
            cutoffs = (generated_ids==self.decoder.decoder_end_token_id).float().argmax(1).clamp(0,max_len)
            for i, c in enumerate(cutoffs):
                generated_ids_padded[i,:c] = generated_ids[i,:c]
            output['generated_ids'] = generated_ids_padded
            if 'seq_len' in batch:
                output['seq_len'] = batch['seq_len']
        return output

In [89]:
# Smoke test: run one forward pass of Net on random inputs.
# Every landmark tensor is (1, 380, n_landmarks, 3) with 21 landmarks per hand
# and 5 per pose side; the masks are all ones.
input_dict = {'lhand': torch.rand(1,380,21,3),
            'rhand':torch.rand(1,380,21,3),
            'lpose':torch.rand(1,380,5,3),
            'rpose':torch.rand(1,380,5,3),
            # random integer token ids in [0, 52), shape (1, 32)
            'token_ids':(torch.rand(1,32)*52).long(),
            'input_mask': torch.ones_like(torch.rand(1,380)),
            'attention_mask': torch.ones_like(torch.rand(1,32))         
}
model = Net()

# A freshly constructed module is in training mode, so only the loss is returned.
# NOTE(review): the recorded output of this cell is an AttributeError raised
# for SqueezeformerEncoder ('blocks' attribute missing).
print(model(input_dict))

n_params: 4050016


AttributeError: 'SqueezeformerEncoder' object has no attribute 'blocks'

In [76]:
def train(model, train_loader, loss_fn, optimizer):
    """Run one training epoch and report the mean loss and accuracy.

    Args:
        model: Module called as `model(inputs)`; its output is passed to
            `loss_fn` and arg-maxed over the last axis for the accuracy.
        train_loader: Iterable yielding `(inputs, targets)` pairs.
        loss_fn: Callable `loss_fn(outputs, targets)` returning a scalar loss.
        optimizer: Optimizer stepped once per batch.

    Returns:
        Tuple `(acc, final_loss)`: accuracy over all batches and the loss
        averaged over the number of batches.
    """
    model = model.to(DEVICE)
    # set model to training mode
    model.train()

    running_loss = 0
    predicted, expected = [], []

    for inputs, targets in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs.to(DEVICE))
        loss = loss_fn(outputs, targets.to(DEVICE))
        loss.backward()
        optimizer.step()

        # accumulate statistics for the epoch summary
        running_loss += loss.item()
        batch_preds = torch.argmax(outputs, dim=-1)
        predicted.extend(batch_preds.detach().cpu().tolist())
        expected.extend(targets.cpu().tolist())

    acc = accuracy_score(expected, predicted)
    final_loss = running_loss / len(train_loader)

    # print average loss and accuracy
    print(f"learning Rate = {optimizer.param_groups[0]['lr']}. average train loss = {final_loss:.2f}, average train accuracy = {acc * 100:.3f}%")
    return acc, final_loss

def _evaluate(model, loader, loss_fn):
    """Compute accuracy and mean per-batch loss of `model` over `loader` without gradients."""
    model = model.to(DEVICE)
    # set model to evaluation mode
    model.eval()

    total_loss = 0
    all_predictions = []
    all_targets = []

    # The model is moved and gradients are disabled once, not once per batch.
    with torch.no_grad():
        for inputs, targets in loader:
            outputs = model(inputs.to(DEVICE))
            loss = loss_fn(outputs, targets.to(DEVICE))

            # Track some values to compute statistics
            total_loss += loss.item()
            preds = torch.argmax(outputs, dim=-1)
            all_predictions.extend(preds.cpu().tolist())
            all_targets.extend(targets.cpu().tolist())

    acc = accuracy_score(all_targets, all_predictions)
    final_loss = total_loss / len(loader)
    return acc, final_loss


def validation(model, val_loader, loss_fn, epoch=None):
    """Evaluate `model` on the validation loader and print a summary.

    Args:
        model: Module called as `model(inputs)`.
        val_loader: Iterable yielding `(inputs, targets)` pairs.
        loss_fn: Callable `loss_fn(outputs, targets)` returning a scalar loss.
        epoch: Optional zero-based epoch index; when given, the summary line
            is prefixed with the one-based epoch number.

    Returns:
        Tuple `(acc, final_loss)`.
    """
    acc, final_loss = _evaluate(model, val_loader, loss_fn)
    # Print average loss and accuracy
    if epoch is None:
        print(f"average validation loss = {final_loss:.2f}, average validation accuracy = {acc * 100:.3f}%")
    else:
        print(f"Epoch {epoch + 1} done. average validation loss = {final_loss:.2f}, average validation accuracy = {acc * 100:.3f}%")
    return acc, final_loss


def test(model, test_loader, loss_fn):
    """Evaluate `model` on the test loader and print a summary.

    Args:
        model: Module called as `model(inputs)`.
        test_loader: Iterable yielding `(inputs, targets)` pairs.
        loss_fn: Callable `loss_fn(outputs, targets)` returning a scalar loss.

    Returns:
        Tuple `(acc, final_loss)`.
    """
    acc, final_loss = _evaluate(model, test_loader, loss_fn)
    # Print average loss and accuracy
    print(f'average test loss = {final_loss:.2f}, average test accuracy = {acc * 100:.3f}%')
    return acc, final_loss

In [77]:
def train_model(
                data_loaders: tuple = None,
                learning_rate: float = 1e-4, 
                weight_decay: float = 1e-4, 
                features: dict = None,
                model = None,
                optimizer = None,
                scheduler = None,
                loss_fn = None
               ):
    """Train a model with periodic validation and early stopping.

    Validation runs every second epoch. Training stops once the validation
    accuracy has dropped below the best value seen so far on two validation
    runs in a row, or after 16 epochs.

    Args:
        data_loaders: Optional `(train_loader, val_loader, test_loader)` tuple.
            Built with `get_data_loaders` when omitted.
        learning_rate: Learning rate of the default optimizer.
        weight_decay: Weight decay of the default optimizer.
        features: Landmark selection used when the loaders are built here.
            Defaults to `FEATURES`.
        model: Model to train. Defaults to a new `Net()`.
        optimizer: Optimizer to use. Defaults to AdamW over the model's
            parameters with `learning_rate` and `weight_decay`.
        scheduler: Optional learning-rate scheduler, stepped once per epoch.
            No scheduling is done when omitted.
        loss_fn: Loss callable. Defaults to `nn.CrossEntropyLoss()`.

    Returns:
        The validation accuracy measured after training finished.

    NOTE(review): `train`/`validation` call `model(inputs)` on `(inputs, targets)`
    pairs and expect logits back, whereas `Net.forward` takes a batch dict and
    returns a dict, and `FingerspellingDataset` yields dict samples. The
    default pipeline needs a collate function or adapter before it can run
    end to end.
    """
    if not features:
        features = FEATURES

    if not data_loaders:
        train_loader, val_loader, test_loader = get_data_loaders(features, VAL_SPLIT, TEST_SPLIT, BATCH_SIZE)
    else:
        train_loader, val_loader, test_loader = data_loaders

    if model is None:
        model = Net()
    # Move the model first so the optimizer is built over the final parameters.
    model = model.to(DEVICE)
    if optimizer is None:
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    if loss_fn is None:
        loss_fn = nn.CrossEntropyLoss()

    max_epochs = 16
    best_acc = -1
    dip_count = 0
    for e in range(max_epochs):
        print(f'EPOCH {e}:')
        train_acc, train_loss = train(model, train_loader, loss_fn, optimizer)
        if scheduler is not None:
            scheduler.step()

        # early stopping based on validation accuracy, checked every other epoch
        if e%2 == 0:
            val_acc, val_loss = validation(model, val_loader, loss_fn, e)
            if val_acc >= best_acc:
                best_acc = val_acc
                dip_count = 0
            else:
                dip_count +=1

        if dip_count >1:
            break

    val_acc, val_loss = validation(model, val_loader, loss_fn, e)
    print(f'Final Accuracy: {val_acc}')
    return val_acc