In [3]:
!accelerate launch run_sign2vec_pretraining.py \
						--model_name_or_path="patrickvonplaten/wav2vec2-base-v2" \
						--output_dir="./sign2vec" \
						--max_train_steps="2" \
						--num_warmup_steps="3" \
						--gradient_accumulation_steps="4" \
						--learning_rate="0.001" \
						--weight_decay="0.01" \
						--max_duration_in_seconds="20.0" \
						--min_duration_in_seconds="2.0" \
						--logging_steps="1" \
						--saving_steps="10000" \
						--per_device_train_batch_size="8" \
						--per_device_eval_batch_size="8" \
						--adam_beta1="0.9" \
						--adam_beta2="0.98" \
						--adam_epsilon="1e-06" \
						--gradient_checkpointing \
						--mask_time_prob="0.65" \
						--mask_time_length="10" \
						--use_face \
						--use_hands \
						--use_pose \
						--train_info_path="../sign2vec/config/info.json" \
						--train_data_path="../sign2vec/features" \
						--validation_info_path="../sign2vec/config/info.json" \
						--validation_data_path="../sign2vec/features" \
						--config_name="config.json"

The following values were not passed to `accelerate launch` and had defaults used instead:
	`--num_cpu_threads_per_process` was set to `12` to improve out-of-box performance when training on CPUs
None
[34m[1mwandb[0m: Currently logged in as: [33mkarahan-sahin[0m ([33mboun-pilab[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: wandb version 0.17.2 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade
[34m[1mwandb[0m: Tracking run with wandb version 0.17.0
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/Users/karahansahin/Documents/Research/sign2vec/pretraining/wandb/run-20240619_094634-lagvyvl8[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mprime-plasma-20[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/boun-pilab/sign2vec[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/boun-pilab/sign2vec/runs/lagvyvl8

In [2]:
from utils.modeling_sign2vec import (
    Sign2VecNoLayerNormConvLayer,
    Sign2VecGroupNormConvLayer,
    Sign2VecLayerNormConvLayer
)

from torch import nn

class Sign2VecFeatureEncoder(nn.Module):
    """Construct the features from raw audio waveform"""

    def __init__(self, config):
        super().__init__()

        if config.feat_extract_norm == "group":
            conv_layers = [Sign2VecGroupNormConvLayer(config, layer_id=0)] + [
                Sign2VecNoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1)
            ]
        elif config.feat_extract_norm == "layer":
            conv_layers = [
                Sign2VecLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)
            ]
        else:
            raise ValueError(
                f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
            )
        self.conv_layers = nn.ModuleList(conv_layers)
        self.gradient_checkpointing = False
        self._requires_grad = True

    def _freeze_parameters(self):
        for param in self.parameters():
            param.requires_grad = False
        self._requires_grad = False

    def forward(self, input_values):
        hidden_states = input_values[None, :]

        # make sure hidden_states require grad for gradient_checkpointing
        if self._requires_grad and self.training:
            hidden_states.requires_grad = True

        for conv_layer in self.conv_layers:
            print('Layer input:', hidden_states.shape)
            print('--->')
            print(conv_layer)
            print('--->')

            if self._requires_grad and self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    conv_layer.__call__,
                    hidden_states,
                )
            else:
                hidden_states = conv_layer(hidden_states)

            print('Layer output:', hidden_states.shape)
            print('======')
        return hidden_states

In [3]:
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2NoLayerNormConvLayer,
    Wav2Vec2GroupNormConvLayer,
    Wav2Vec2LayerNormConvLayer
)

from torch import nn

class Wav2Vec2FeatureEncoder(nn.Module):
    """Construct the features from raw audio waveform"""

    def __init__(self, config):
        super().__init__()

        if config.feat_extract_norm == "group":
            conv_layers = [Wav2Vec2GroupNormConvLayer(config, layer_id=0)] + [
                Wav2Vec2NoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1)
            ]
        elif config.feat_extract_norm == "layer":
            conv_layers = [
                Wav2Vec2LayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)
            ]
        else:
            raise ValueError(
                f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
            )
        self.conv_layers = nn.ModuleList(conv_layers)
        self.gradient_checkpointing = False
        self._requires_grad = True

    def _freeze_parameters(self):
        for param in self.parameters():
            param.requires_grad = False
        self._requires_grad = False

    def forward(self, input_values):
        hidden_states = input_values[None, :]

        # make sure hidden_states require grad for gradient_checkpointing
        if self._requires_grad and self.training:
            hidden_states.requires_grad = True

        for conv_layer in self.conv_layers:
            print(hidden_states.shape, '--->')
            print(conv_layer)
            if self._requires_grad and self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    conv_layer.__call__,
                    hidden_states,
                )
            else:
                hidden_states = conv_layer(hidden_states)

        return hidden_states

In [1]:
import torch
from transformers import AutoFeatureExtractor, Wav2Vec2ForPreTraining
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
from datasets import load_dataset

feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values  # Batch size 1

# compute masked indices
batch_size, raw_sequence_length = input_values.shape
sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
mask_time_indices = _compute_mask_indices(
    shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
)
sampled_negative_indices = _sample_negative_indices(
    features_shape=(batch_size, sequence_length),
    num_negatives=model.config.num_negatives,
    mask_time_indices=mask_time_indices,
)
mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long)
sampled_negative_indices = torch.tensor(
    data=sampled_negative_indices, device=input_values.device, dtype=torch.long
)
# for contrastive loss training model should be put into train mode

print(input_values.shape)
model = model.train()
loss = model(
    input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
).loss

Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2ForPreTraining: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing Wav2Vec2ForPreTraining from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForPreTraining from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForPreTraining were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should prob

torch.Size([1, 93680])
Layer out -> torch.Size([1, 512, 18735])
Layer out -> torch.Size([1, 512, 9367])
Layer out -> torch.Size([1, 512, 4683])
Layer out -> torch.Size([1, 512, 2341])
Layer out -> torch.Size([1, 512, 1170])
Layer out -> torch.Size([1, 512, 585])
Layer out -> torch.Size([1, 512, 292])


In [9]:
import json
import torch
from utils.config import Sign2VecConfig
from utils.modeling_sign2vec import Sign2VecFeatureEncoder

config_dict = json.load(open("config.json", "r"))

config = Sign2VecConfig(**config_dict)

In [10]:
from utils.modeling_sign2vec import Sign2VecForPreTraining
from utils.bobsl import BOBSLDataset

model = Sign2VecForPreTraining(config)

In [11]:

channel_size = {
    'face_keypoints_2d': 70,
    'hand_left_keypoints_2d': 21,
    'hand_right_keypoints_2d': 21,
    'pose_keypoints_2d': 25,
}


# 1. Set the correct target sampling rate
sampling_rate = int(
    channel_size['pose_keypoints_2d'] * 2 +
    channel_size['face_keypoints_2d'] * 2 + 
    channel_size['hand_left_keypoints_2d'] * 2  +
    channel_size['hand_right_keypoints_2d'] * 2 
) * 25

In [12]:
from torch.utils.data import DataLoader
from utils.feature_extraction_wav2vec2 import Sign2VecFeatureExtractor
from utils.collator import DataCollatorForWav2Vec2Pretraining


feature_extractor = Sign2VecFeatureExtractor(
    feature_size=config.input_dim,
    sampling_rate=sampling_rate,
    padding_value=0,
    do_normalize=True,
    return_tensors="pt",
)

vectorized_datasets = {
    'train': BOBSLDataset(
        data_path="../sign2vec/features" ,
        info_path="../sign2vec/config/info.json",
        use_face=True,
        use_hands=True,
        use_pose=True,
        stride=20,
        max_length=int(20) * sampling_rate,
        sampling_rate=sampling_rate,
        feature_extractor=feature_extractor,
    )
}

# Activate gradient checkpointing if needed
if True:
    model.gradient_checkpointing_enable()

# 4. Define data collator, optimizer and scheduler

mask_time_prob = config.mask_time_prob if 0.65 is None else 0.65
mask_time_length = config.mask_time_length if 10 is None else 10

data_collator = DataCollatorForWav2Vec2Pretraining(
    model=model,
    feature_extractor=feature_extractor,
    pad_to_multiple_of=None,
    mask_time_prob=mask_time_prob,
    mask_time_length=mask_time_length,
)
train_dataloader = DataLoader(
    vectorized_datasets["train"],
    shuffle=True,
    collate_fn=data_collator,
    batch_size=8,
)

BOBSLDataset
Sampling rate: 6850
Max frame diff: 10
Max length: 137000
FPS: 25
Max Frames: 500
Info path: ../sign2vec/config/info.json
Data path: ../sign2vec/features
Loaded 9150 training samples


  mask_time_prob = config.mask_time_prob if 0.65 is None else 0.65
  mask_time_length = config.mask_time_length if 10 is None else 10


In [13]:
v = next(iter(vectorized_datasets["train"]))

In [14]:
v['input_values'].shape

(274, 500)

In [15]:
for batch in train_dataloader:
    print(batch.input_values.shape)
    break

torch.Size([8, 274, 500])


In [17]:
out = model(
    **batch
)

[sign2vec-pretrain] input_values torch.Size([8, 274, 500])
[sign2vec] input_values torch.Size([8, 274, 500])
[Sign2VecFeatureEncoder]: Forward
Input values ->  torch.Size([8, 274, 500])
Layer out ->  torch.Size([8, 512, 498])
Layer out ->  torch.Size([8, 512, 166])
Layer out ->  torch.Size([8, 512, 55])
[sign2vec] extract_features torch.Size([8, 55, 512])


In [21]:
out.keys()

odict_keys(['loss', 'projected_states', 'projected_quantized_states', 'codevector_perplexity', 'contrastive_loss', 'diversity_loss'])

In [23]:
out.projected_states.shape

torch.Size([8, 55, 256])