In [1]:
# %pip install --upgrade transformers
# %pip install datasets
# %pip install huggingface_hub

In [2]:
# %pip install --upgrade pip
# %pip install boltons

In [3]:
import json
import os
import torch
from tqdm import tqdm
import scipy.io.wavfile as wav

from torch import nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader, Sampler
from torch.utils.tensorboard import SummaryWriter

import numpy as np
import torchaudio

In [4]:
import transformers
import json
from torch.utils.data import Dataset, DataLoader, Sampler
import os

In [5]:
from models import ConvFeatureEncoder, SegmentsRepr, SegmentsEncoder, NegativeSampler, SegmentPredictor, FinModel
from utils import ConstrativeLoss, sample_negatives

In [6]:
from model_transformers import SegmentTransformer
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
# Share tensors between processes through the file system instead of file
# descriptors (PyTorch's 'file_system' sharing strategy).
torch.multiprocessing.set_sharing_strategy('file_system')
from tqdm import tqdm
import numpy as np
import os
from os.path import join, basename
from boltons.fileutils import iter_find_files
import soundfile as sf
import librosa
import pickle
from multiprocessing import Pool
import random
import torchaudio
import math
from torchaudio.datasets import LIBRISPEECH

In [8]:
# Display the installed transformers version (the saved output was produced with 4.14.0).
transformers.__version__

'4.14.0'

In [9]:
# Seed every RNG the notebook imports (torch, numpy, random) so that
# Restart & Run All is reproducible. The trailing ';' suppresses the
# Generator repr that torch.manual_seed returns.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED);

<torch._C.Generator at 0x7febb14cf150>

In [10]:
# This class is based on https://github.com/felixkreuk/UnsupSeg/blob/master/dataloader.py

class WavPhnDataset(Dataset):
    """Dataset of ``*.wav`` files paired with TIMIT-style ``.PHN`` transcriptions.

    Every ``*.wav`` file found under ``path`` (via ``iter_find_files``) becomes one
    item. Its transcription is located by replacing ``"WAV.wav"`` in the audio path
    with ``"PHN"``.
    """

    def __init__(self, path):
        self.path = path
        self.data = list(iter_find_files(self.path, "*.wav"))
        print(self.path)
        print(len(self.data))
        super(WavPhnDataset, self).__init__()

    @staticmethod
    def get_datasets(path):
        """Build the dataset splits; must be implemented by subclasses."""
        raise NotImplementedError

    @staticmethod
    def _parse_phn_lines(lines, sample_rate):
        """Parse the text lines of a ``.PHN`` file.

        Each non-blank line is expected to look like ``"<start> <end> <phoneme>"``
        with start/end given as sample indices. Blank lines are skipped.

        Args:
            lines: iterable of lines read from the ``.PHN`` file.
            sample_rate: sampling rate used to convert sample indices to seconds.

        Returns:
            ``(times, phonemes)``. ``times`` holds each segment's end time in
            seconds with the last one dropped (presumably because the final end
            is the end of the recording, not a boundary between segments).
            ``phonemes`` holds the label of every segment.
        """
        rows = [line.split(" ") for line in lines if line.strip()]
        # Go through a float32 tensor so the values keep the same precision as before.
        times = torch.FloatTensor([int(row[1]) / sample_rate for row in rows])[:-1]
        phonemes = [row[2].strip() for row in rows]
        return times.tolist(), phonemes

    def process_file(self, wav_path):
        """Load one utterance and its phoneme segmentation.

        Returns:
            ``(audio, times, phonemes, wav_path, phn_path, filetext_id)``:
            ``audio`` is the first channel of the loaded waveform, ``times`` the
            boundary times in seconds, ``phonemes`` the segment labels, and
            ``filetext_id`` an identifier derived from the transcription path.
        """
        phn_path = wav_path.replace("WAV.wav", "PHN")

        audio, sr = torchaudio.load(wav_path)
        audio = audio[0]  # keep only the first channel

        filetext_id = phn_path.replace('.PHN', '.txt').replace('./', '').replace('/', '_')

        with open(phn_path, "r") as f:
            lines = f.readlines()
        # Convert sample indices with the file's own rate rather than a hard-coded 16 kHz.
        times, phonemes = self._parse_phn_lines(lines, sr)

        return audio, times, phonemes, wav_path, phn_path, filetext_id

    def __getitem__(self, idx):
        """Return item ``idx`` as a dict consumed by ``collate_fn``."""
        signal, seg, phonemes, fname, segment_file, filetext_id = self.process_file(self.data[idx])

        return {'audio_file': fname,
                'segment_file': segment_file,
                'id': filetext_id,
                'sample': signal,
                'length': len(signal), 'boundaries': seg}

    def __len__(self):
        return len(self.data)

In [11]:
# This class is based on https://github.com/felixkreuk/UnsupSeg/blob/master/dataloader.py

class TrainTestDataset(WavPhnDataset):
    """Dataset rooted at a TIMIT-style directory with ``TRAIN`` and ``TEST`` sub-directories."""

    def __init__(self, path):
        super(TrainTestDataset, self).__init__(path)

    @staticmethod
    def get_datasets(path, val_ratio=0.1, seed=None):
        """Build train / validation / test datasets.

        ``<path>/TRAIN`` is randomly split into train and validation subsets;
        ``<path>/TEST`` is returned whole.

        Args:
            path: root directory containing ``TRAIN`` and ``TEST``.
            val_ratio: fraction of the ``TRAIN`` items used for validation.
            seed: optional seed for the split. ``None`` (default) draws from
                torch's global RNG, so repeated calls produce different splits.

        Returns:
            ``(train_dataset, val_dataset, test_dataset)``; the first two are
            the subsets returned by ``torch.utils.data.random_split``.
        """
        train_dir = os.path.join(path, 'TRAIN')
        train_dataset = TrainTestDataset(train_dir)
        test_dataset = TrainTestDataset(os.path.join(path, 'TEST'))

        train_len = len(train_dataset)
        train_split = int(train_len * (1 - val_ratio))
        val_split = train_len - train_split

        split_kwargs = {}
        if seed is not None:
            split_kwargs['generator'] = torch.Generator().manual_seed(seed)
        train_dataset, val_dataset = torch.utils.data.random_split(
            train_dataset, [train_split, val_split], **split_kwargs)

        # Record the directory that was actually read ('TRAIN'; case matters on
        # case-sensitive filesystems).
        train_dataset.path = train_dir
        val_dataset.path = train_dir

        return train_dataset, val_dataset, test_dataset

In [12]:
def collate_fn(samples):
    """Zero-pad a list of dataset items into a single batch.

    Args:
        samples: list of dicts as produced by ``WavPhnDataset.__getitem__``, each
            with a 1-D ``'sample'`` waveform, its ``'length'`` and a
            ``'boundaries'`` list.

    Returns:
        dict with
          ``batch``: (B, max_length) tensor, each waveform left-aligned and zero-padded;
          ``lengths``: (B,) float tensor of the original lengths;
          ``attention_mask``: (B, max_length) tensor, 1.0 on real samples, 0.0 on padding;
          ``boundaries``: the per-item boundary lists, unchanged.
    """
    lengths_list = [sample['length'] for sample in samples]
    max_length = max(lengths_list)
    boundaries = [sample['boundaries'] for sample in samples]

    # Copy each waveform into a preallocated zero tensor instead of building
    # Python lists element by element (very slow for audio-length inputs).
    batch = torch.zeros(len(samples), max_length)
    for row, sample in enumerate(samples):
        n = sample['length']
        batch[row, :n] = torch.as_tensor(sample['sample'])[:n]

    lengths = torch.Tensor(lengths_list)
    # Position j of row i is real (1.0) iff j < length_i, padding (0.0) otherwise.
    positions = torch.arange(max_length).unsqueeze(0)
    attention_mask = (positions < torch.tensor(lengths_list).unsqueeze(1)).to(batch.dtype)

    return dict(batch=batch, lengths=lengths, attention_mask=attention_mask, boundaries=boundaries)

In [13]:
# Root of the TIMIT corpus; TrainTestDataset.get_datasets expects TRAIN/ and TEST/ inside it.
timit_path = './timit/data'

In [14]:
# Build the train/val/test splits (get_datasets prints each directory and its file count).
# NOTE(review): cell In [44] rebuilds these datasets and draws a new random split,
# so the objects created here are not the ones used for training.
train_dataset, val_dataset, test_dataset = TrainTestDataset.get_datasets(path=timit_path)

./timit/data/TRAIN
4620
./timit/data/TEST
1680


In [15]:
# Batch the splits with zero-padding via collate_fn; only the training loader is shuffled.
# NOTE(review): cell In [44] recreates both loaders, so these are redundant for training.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=8, collate_fn = collate_fn)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=8, collate_fn = collate_fn)

In [40]:
# Number of batches whose gradients are accumulated per optimizer step; passed to the
# Trainer and also used to scale the learning rate in cfg.
accumulate_grad_batches = 10

In [41]:
# Hyper-parameters handed to SegmentTransformer (wrapped as attributes by Conf).
cfg = dict(
    model_path="facebook/wav2vec2-base-960h",
    mask=False,
    optimizer="adam",
    momentum=0.9,
    # Learning rate scaled linearly with the number of accumulated batches.
    learning_rate=0.0001 * accumulate_grad_batches,
    lr_anneal_gamma=1.0,
    lr_anneal_step=1000,
    epochs=20,
    grad_clip=2,
    batch_size=8,

    conv_args=dict(),
    mask_args=dict(),
    segm_enc_args=dict(),
    segm_predictor_args=dict(),
    loss_args=dict(n_negatives=10, loss_args=dict(reduction="mean")),
    # NOTE(review): 'num_epoch' (2) disagrees with 'epochs' (20) and with the
    # Trainer's max_epochs; confirm which key SegmentTransformer actually reads.
    num_epoch=2,
)

In [42]:
class Conf:
    """Expose the entries of a dict as attributes, e.g. ``Conf({'a': 1}).a == 1``.

    Values are stored as-is; nested dicts remain plain dicts.
    """

    def __init__(self, my_dict):
        for attr_name, attr_value in my_dict.items():
            setattr(self, attr_name, attr_value)
            
# Attribute-style view of cfg (config.learning_rate, config.batch_size, ...) passed to the model.
config = Conf(cfg)

In [43]:
import warnings
# NOTE(review): this silences *all* warnings for the rest of the kernel session, which
# can hide deprecation or data problems; consider narrowing it with category=/module=.
warnings.filterwarnings("ignore")
from collections import OrderedDict

In [44]:
AVAIL_GPUS = min(1, torch.cuda.device_count())  # use at most one GPU

# (Re)build the splits and loaders used for training. Batch size, epoch count and
# gradient clipping are read from `config` so they cannot drift from cfg.
train_dataset, val_dataset, test_dataset = TrainTestDataset.get_datasets(path=timit_path)

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=config.batch_size, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=config.batch_size, collate_fn=collate_fn)

print(len(train_dataset))
print(len(val_dataset))

model = SegmentTransformer(config)

logger = TensorBoardLogger("timit_r_val", name="my_model")

# Keep the 3 checkpoints with the highest "val_r_metr"
# (presumably logged by SegmentTransformer during validation - confirm).
checkpoint_callback = ModelCheckpoint(
    monitor="val_r_metr",
    dirpath="./",
    filename="timit_r_val",
    save_top_k=3,
    mode="max",
)
trainer = Trainer(max_epochs=config.epochs,
                  gpus=AVAIL_GPUS,
                  progress_bar_refresh_rate=1,
                  logger=logger,
                  accumulate_grad_batches=accumulate_grad_batches,
                  gradient_clip_val=config.grad_clip,
                  callbacks=[checkpoint_callback])


./timit/data/TRAIN
4620
./timit/data/TEST
1680
4158
462


Some weights of the model checkpoint at facebook/wav2vec2-base-960h were not used when initializing Wav2Vec2ModelForSegmentation: ['lm_head.bias', 'lm_head.weight']
- This IS expected if you are initializing Wav2Vec2ModelForSegmentation from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ModelForSegmentation from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ModelForSegmentation were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
GPU available: True, used: True
TPU available: False

In [46]:
# The %tensorboard magic is only available after loading the TensorBoard notebook
# extension; view the logs written by the TensorBoardLogger in cell In [44].
%load_ext tensorboard
%tensorboard --logdir ./timit_r_val

Launching TensorBoard...


In [47]:
# Train the model; val_loader is used for validation, whose "val_r_metr" drives ModelCheckpoint.
trainer.fit(model, train_loader, val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type                         | Params
--------------------------------------------------------------
0 | wav2vec_segm | Wav2Vec2ModelForSegmentation | 94.4 M
--------------------------------------------------------------
94.4 M    Trainable params
0         Non-trainable params
94.4 M    Total params
377.487   Total estimated model params size (MB)


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

optimizer: Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.001
    weight_decay: 0.0005
)

