In [1]:
%load_ext autoreload
%autoreload 2
import torch
import torchaudio
from torch.utils.data import DataLoader, random_split
import transformers
from torchinfo import summary
from ibm_dataset import IBMDebater
# Only let `transformers` emit ERROR-level log messages, so its warnings do not clutter cell output.
transformers.logging.set_verbosity_error()

In [2]:
# Pretrained wav2vec 2.0 (base) pipeline. The same bundle builds the model below and is
# handed to the dataset (presumably so audio matches the bundle's expectations — confirm in ibm_dataset).
bundle = torchaudio.pipelines.WAV2VEC2_BASE

data_path = 'data/ibm_debater/full'
TRAIN_FRACTION = 0.7  # share of the 'train' split used for training; the remainder is validation
SEED = 42             # fixes the train/val split and the global torch RNG so re-runs are reproducible

# Seed the global torch RNG (also used by DataLoader shuffling when no generator is passed).
torch.manual_seed(SEED)

# Audio only: text loading is disabled for this experiment.
data = IBMDebater(data_path, 'train', audio_bundle=bundle, load_text=False)
train_len = int(len(data) * TRAIN_FRACTION)
# A dedicated seeded generator keeps the split identical across runs, independent of other RNG use.
split_generator = torch.Generator().manual_seed(SEED)
data_train, data_val = random_split(
    data, [train_len, len(data) - train_len], generator=split_generator
)

In [3]:
batch_size = 8
# Training: shuffle and drop the ragged final batch so every step sees a full batch.
loader_train = DataLoader(data_train,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True)
# Validation: fixed order and keep every sample, so metrics cover the whole held-out set
# (drop_last=True would silently skip up to batch_size - 1 validation samples).
loader_val = DataLoader(data_val,
                        batch_size=batch_size,
                        shuffle=False,
                        drop_last=False)

In [4]:
# Use the GPU when one is available so the notebook still runs on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = bundle.get_model().to(device)

# Freeze every parameter: no gradients are tracked for the pretrained weights.
model.requires_grad_(False)

summary(model)

Layer (type:depth-idx)                                  Param #
Wav2Vec2Model                                           --
├─FeatureExtractor: 1-1                                 --
│    └─ModuleList: 2-1                                  --
│    │    └─ConvLayerBlock: 3-1                         (6,144)
│    │    └─ConvLayerBlock: 3-2                         (786,432)
│    │    └─ConvLayerBlock: 3-3                         (786,432)
│    │    └─ConvLayerBlock: 3-4                         (786,432)
│    │    └─ConvLayerBlock: 3-5                         (786,432)
│    │    └─ConvLayerBlock: 3-6                         (524,288)
│    │    └─ConvLayerBlock: 3-7                         (524,288)
├─Encoder: 1-2                                          --
│    └─FeatureProjection: 2-2                           --
│    │    └─LayerNorm: 3-8                              (1,024)
│    │    └─Linear: 3-9                                 (393,984)
│    │    └─Dropout: 3-10                          

In [None]:
# Sanity-check the encoder output shape on a single training batch. A full pass over
# loader_train would only print the same shape once per batch and discard every output.
model.eval()
# Follow whatever device the model lives on instead of hard-coding 'cuda'.
device = next(model.parameters()).device
# Named `batch` (not `data`) so the dataset bound to `data` earlier is not shadowed.
batch = next(iter(loader_train))
# NOTE(review): assumes each batch is exactly (waveforms, labels) — confirm against IBMDebater.
wave, labels = [x.to(device) for x in batch]
with torch.inference_mode():
    emission, _ = model(wave)
emission.shape