In [1]:
# %pip install --upgrade transformers
# %pip install datasets
# %pip install huggingface_hub

In [2]:
# %pip install --upgrade pip
# %pip install boltons

In [3]:
import json
import os
import torch
from tqdm import tqdm
import scipy.io.wavfile as wav

from torch import nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader, Sampler
from torch.utils.tensorboard import SummaryWriter

import numpy as np
import torchaudio

In [4]:
import transformers
import json
from torch.utils.data import Dataset, DataLoader, Sampler
import os

In [5]:
from models import ConvFeatureEncoder, SegmentsRepr, SegmentsEncoder, NegativeSampler, SegmentPredictor, FinModel
from utils import ConstrativeLoss, sample_negatives
# from trainer import Trainer

In [6]:
from model_transformers import SegmentTransformer
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
# Share tensors between processes through the file system rather than file
# descriptors (PyTorch's documented workaround for "too many open files").
# NOTE(review): the DataLoaders in this notebook do not set num_workers (default 0),
# so this presumably has no effect here — confirm before relying on it.
torch.multiprocessing.set_sharing_strategy('file_system')
from tqdm import tqdm
import numpy as np
import os
from os.path import join, basename
from boltons.fileutils import iter_find_files
import soundfile as sf
import librosa
import pickle
from multiprocessing import Pool
import random
import torchaudio
import math
from torchaudio.datasets import LIBRISPEECH

In [8]:
# Display the installed transformers version so results can be tied to it.
transformers.__version__

'4.14.0'

In [9]:
# %pip list

In [10]:
# Seed every RNG this notebook imports (Python `random`, NumPy, PyTorch) so
# runs are repeatable. The trailing ';' suppresses the Generator repr.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED);

<torch._C.Generator at 0x7f6b30405150>

In [11]:
# This class is based on https://github.com/felixkreuk/UnsupSeg/blob/master/dataloader.py

class WavPhnDataset(Dataset):
    """Utterance-level dataset of ``*.wav`` files paired with boundary labels.

    Every ``*.wav`` file found under ``path`` is paired with the ``*.txt`` file
    that has the same name. Each non-blank line of that label file must start
    with a boundary time (presumably in seconds). The time on the last line is
    dropped.

    Items are dicts with keys ``'sample'`` (first audio channel), ``'length'``
    (number of samples in it) and ``'boundaries'`` (float32 boundary times).
    """

    def __init__(self, path):
        self.path = path
        self.data = list(iter_find_files(self.path, "*.wav"))
        super(WavPhnDataset, self).__init__()

    @staticmethod
    def get_datasets(path):
        raise NotImplementedError

    @staticmethod
    def _label_path(wav_path):
        """Return ``wav_path`` with its final extension replaced by ``.txt``."""
        # splitext touches only the last extension; str.replace(".wav", ...)
        # would also rewrite a ".wav" that appears in a directory name.
        return os.path.splitext(wav_path)[0] + ".txt"

    @staticmethod
    def _parse_boundaries(lines, sample_rate=16000):
        """Parse label-file lines into boundary times and sample positions.

        Args:
            lines: iterable of label lines. Blank lines are skipped, and only the
                first whitespace-separated token of each other line is read.
            sample_rate: factor that converts each time into a sample position.

        Returns:
            ``(times, positions)`` as float32 tensors. Both exclude the last entry.
        """
        # float() instead of eval(): label files are external data and must
        # never be executed as Python code.
        starts = [float(line.split()[0]) for line in lines if line.strip()]
        # The last entry is dropped. NOTE(review): presumably the utterance end
        # time rather than a boundary — confirm against the label format.
        starts = starts[:-1]
        times = torch.tensor(starts, dtype=torch.float32)
        positions = torch.tensor([t * sample_rate for t in starts], dtype=torch.float32)
        return times, positions

    def process_file(self, wav_path):
        """Load one utterance and its boundary labels.

        Returns:
            ``(audio, times, positions, wav_path)``: channel 0 of the waveform,
            the boundary times from the label file, those times multiplied by the
            file's sample rate, and the input path.
        """
        audio, sr = torchaudio.load(wav_path)
        audio = audio[0]  # keep only the first channel

        with open(self._label_path(wav_path), "r") as f:
            times, positions = self._parse_boundaries(f, sample_rate=sr)

        return audio, times, positions, wav_path

    def __getitem__(self, idx):
        signal, seg, _positions, _path = self.process_file(self.data[idx])
        return {'sample': signal, 'length': len(signal), 'boundaries': seg}

    def __len__(self):
        return len(self.data)

In [13]:
def collate_fn(samples):
    """Right-pad a list of dataset items with zeros into a single batch.

    Args:
        samples: non-empty list of dicts with keys 'sample' (1-D waveform whose
            length equals 'length'), 'length' and 'boundaries'.

    Returns:
        dict with
          batch:          (B, T_max) waveforms, zero-padded on the right;
          lengths:        (B,) float tensor of the original lengths;
          attention_mask: (B, T_max), 1 over real samples and 0 over padding;
          boundaries:     list of each item's 'boundaries', passed through as-is.

    Raises:
        ValueError: if ``samples`` is empty.
    """
    if not samples:
        raise ValueError("collate_fn received an empty batch")

    max_length = max(sample['length'] for sample in samples)
    n_items = len(samples)
    batch = torch.zeros(n_items, max_length)
    attention_mask = torch.zeros(n_items, max_length)
    for row, sample in enumerate(samples):
        n = sample['length']
        # One tensor copy per item instead of building a Python list with one
        # element per audio sample (slow for long waveforms).
        batch[row, :n] = torch.as_tensor(sample['sample'])
        attention_mask[row, :n] = 1.0

    lengths = torch.tensor([sample['length'] for sample in samples], dtype=batch.dtype)
    boundaries = [sample['boundaries'] for sample in samples]

    return dict(batch=batch, lengths=lengths, attention_mask=attention_mask, boundaries=boundaries)

In [12]:
# One WavPhnDataset per Buckeye split.
# NOTE(review): the training-setup cell re-creates all three datasets, so these
# instances are superseded there.
train_dataset, val_dataset, test_dataset = (
    WavPhnDataset('Buckeye_fin/Train'),
    WavPhnDataset('Buckeye_fin/Valid'),
    WavPhnDataset('Buckeye_fin/Test'),
)

Buckeye_fin/Train
38656
Buckeye_fin/Valid
4267
Buckeye_fin/Test
4015


In [14]:
# Shuffled training loader and deterministic-order validation loader, both
# padded by collate_fn.
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=collate_fn,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=8,
    shuffle=False,
    collate_fn=collate_fn,
)

In [30]:
# Number of batches whose gradients are accumulated per optimizer step. It is
# passed to the Lightning Trainer and also scales cfg['learning_rate'].
accumulate_grad_batches = 10

In [31]:
# Hyper-parameters for SegmentTransformer. Conf(cfg) wraps them into an
# attribute-style `config` object.
# NOTE(review): 'epochs' (20) and 'num_epoch' (2) disagree, and 'batch_size' (8)
# differs from the batch_size=4 that the training-setup cell gives its
# DataLoaders — confirm which of these fields SegmentTransformer actually reads.
cfg = {
    'model_path': "facebook/wav2vec2-base-960h",  # Hugging Face checkpoint id
    'mask': False,
    'optimizer': "adam",
    'momentum': 0.9,
    # Base LR of 1e-4, scaled linearly by the gradient-accumulation factor.
    'learning_rate': 0.0001 * accumulate_grad_batches,
    'lr_anneal_gamma': 1.0,
    'lr_anneal_step': 1000,
    'epochs': 20,
    'grad_clip': 2,
    'batch_size': 8,

    'conv_args': {},
    'mask_args': {},
    'segm_enc_args': {},
    'segm_predictor_args': {},
    'loss_args': {"n_negatives": 10, "loss_args": {"reduction": "mean"}},
    'num_epoch': 2,
}

In [32]:
class Conf:
    """Expose the keys of a configuration dict as instance attributes.

    The conversion is shallow: nested dicts (such as cfg['loss_args']) remain
    plain dicts on the resulting object.
    """

    def __init__(self, my_dict):
        for name in my_dict:
            setattr(self, name, my_dict[name])
            
# Attribute-style view of `cfg`, handed to SegmentTransformer.
config = Conf(cfg)

In [33]:
import warnings
# NOTE(review): this silences *all* warnings for the rest of the kernel session,
# including deprecation warnings from transformers / pytorch_lightning. Consider
# narrowing it with `category=` or `module=` filters.
warnings.filterwarnings("ignore")
from collections import OrderedDict

In [34]:
AVAIL_GPUS = min(1, torch.cuda.device_count())

# NOTE(review): this cell re-creates the datasets and loaders built earlier and
# overrides their batch size. cfg['batch_size'] is 8, but 4 is used here.
TRAIN_BATCH_SIZE = 4

train_dataset = WavPhnDataset('Buckeye_fin/Train')
val_dataset = WavPhnDataset('Buckeye_fin/Valid')
test_dataset = WavPhnDataset('Buckeye_fin/Test')

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=TRAIN_BATCH_SIZE, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=TRAIN_BATCH_SIZE, collate_fn=collate_fn)

model = SegmentTransformer(config)

logger = TensorBoardLogger("buckeye_r_val", name="my_model")

# Keep the 3 best checkpoints by "val_r_metr" (higher is better). The metric is
# presumably logged by SegmentTransformer's validation step — confirm the name.
checkpoint_callback = ModelCheckpoint(
    monitor="val_r_metr",
    dirpath="./",
    filename="buckeye_r_val",
    save_top_k=3,
    mode="max",
)
# Epoch count and clipping value come from the shared config (20 and 2) instead
# of being duplicated as literals here.
trainer = Trainer(max_epochs=config.epochs,
                  gpus=AVAIL_GPUS,
                  progress_bar_refresh_rate=1,
                  logger=logger,
                  accumulate_grad_batches=accumulate_grad_batches,
                  gradient_clip_val=config.grad_clip,
                  callbacks=[checkpoint_callback])


Buckeye_fin/Train
38656
Buckeye_fin/Valid
4267
Buckeye_fin/Test
4015
38656
4267


Some weights of the model checkpoint at facebook/wav2vec2-base-960h were not used when initializing Wav2Vec2ModelForSegmentation: ['lm_head.bias', 'lm_head.weight']
- This IS expected if you are initializing Wav2Vec2ModelForSegmentation from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ModelForSegmentation from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ModelForSegmentation were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
GPU available: True, used: True
TPU available: False

In [39]:
%tensorboard --logdir ./buckeye_r_val

Launching TensorBoard...


In [37]:
# Train using the loaders and Trainer from the setup cell. val_loader provides
# the validation metric ("val_r_metr") that drives ModelCheckpoint.
trainer.fit(model, train_loader, val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type                         | Params
--------------------------------------------------------------
0 | wav2vec_segm | Wav2Vec2ModelForSegmentation | 94.4 M
--------------------------------------------------------------
94.4 M    Trainable params
0         Non-trainable params
94.4 M    Total params
377.487   Total estimated model params size (MB)


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

optimizer: Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.001
    weight_decay: 0.0005
)

