In [None]:
# If U are using SageMaker Prepare for the dataset
!pip install awscli
!aws s3 cp s3://handata/ref_youtube_audio/ ref_youtube_audio/ --recursive

In [None]:
!pip install transformers
!pip install -U openai-whisper
!pip install librosa

In [None]:
from transformers import AutoFeatureExtractor, WhisperForAudioClassification
import torch
import torch.nn as nn
import whisper
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import whisper
import pandas as pd
from categories import ytvos_category_dict
import numpy as np
from util import read_aws_json,read_aws_wav,read_local_json,read_local_wav
import logging
from torch import optim
from losses import get_loss_func
from utils.evaluate import Evaluator
from util import infoNCE_loss
# Storage-backend switch: get_datalist() checks this flag before loading the
# task-split JSON files through read_aws_json.
SageMaker = True
# NOTE(review): Local is not read by any code visible in this part of the notebook —
# presumably the counterpart flag for the read_local_* helpers; confirm.
Local = False

In [None]:
class Audio_Encoder(nn.Module):
    """Classification head on top of a frozen, pre-trained audio encoder.

    Raw waveforms are converted to input features by ``feature_extractor``,
    passed through ``model.encoder`` (whose parameters are frozen), average-pooled
    along the frame axis, projected to 256 channels, flattened and fed to a
    three-layer MLP.

    The layer sizes hard-code an encoder output of 1500 frames x 768 channels
    (``projector`` expects 768 inputs, ``fc1`` expects ``1500 // pool_num * 256``).

    Args:
        feature_extractor: Callable invoked as
            ``feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")``
            whose result exposes an ``input_features`` tensor.
        model: Object exposing an ``encoder`` module. The encoder may return a
            tensor directly, or a tuple / model-output object whose first item is
            the hidden-state tensor.
        num_class: Output size of the (currently unused) ``classifier`` layer.
        dropout_prob: Dropout probability applied between the MLP layers.
        pool_num: Number of consecutive encoder frames averaged together.
        bias: Whether the projector layer has a bias term.
    """

    def __init__(self, feature_extractor, model, num_class=66, dropout_prob=0.2, pool_num=100, bias=True):
        super().__init__()
        self.num_class = num_class
        # Kept for callers that read it; forward() uses the live device of the weights instead.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.feature_extractor = feature_extractor
        self.encoder = model.encoder
        # Freeze the pre-trained encoder: only the layers defined below are trained.
        for param in self.encoder.parameters():
            param.requires_grad = False

        # Fix: honour the ``bias`` argument (it used to be ignored in favour of a literal True).
        self.projector = nn.Linear(in_features=768, out_features=256, bias=bias)

        # ``classifier``, ``batchnorm`` and ``dropout2`` are not used in forward();
        # they stay registered so existing checkpoints keep the same state-dict keys.
        self.classifier = nn.Linear(256, num_class)
        self.batchnorm = nn.BatchNorm1d(2048, affine=False)
        self.dropout2 = nn.Dropout(0.5)

        # Averages ``pool_num`` consecutive frames; the channel axis is left untouched.
        self.avg_pool = nn.AvgPool2d(kernel_size=(pool_num, 1), stride=(pool_num, 1))
        self.dropout = nn.Dropout(p=dropout_prob)

        self.fc1 = nn.Linear(1500 // pool_num * 256, 2048)
        self.fc2 = nn.Linear(2048, 256)
        # NOTE(review): the output layer is fixed at 65 logits and ignores ``num_class``
        # (default 66). Left unchanged because ytvos_Dataset builds 65-wide targets by default.
        self.fc3 = nn.Linear(256, 65)

    def forward(self, audios):
        """Classify a batch of waveforms.

        Args:
            audios: Iterable of 1-D waveform tensors (lengths may differ); each one
                is passed to the feature extractor with ``sampling_rate=16000``.

        Returns:
            dict with
                ``clipwise_output``: output of the last linear layer, ``(batch, 65)``;
                ``feature``: flattened pooled + projected features (before dropout);
                ``embedding``: hidden states produced by the frozen encoder.
        """
        # The feature extractor runs on CPU, one waveform at a time.
        input_features = [
            self.feature_extractor(audio.cpu(), sampling_rate=16000, return_tensors="pt").input_features
            for audio in audios
        ]

        # Follow the module's current device rather than the one guessed in __init__,
        # so the model keeps working after ``.to(...)`` / ``.cpu()``.
        device = self.projector.weight.device
        input_features = torch.cat(input_features, dim=0).to(device)

        hidden_states = self.encoder(input_features)
        if not torch.is_tensor(hidden_states):
            # Encoders that return a tuple / model-output object put the hidden states first.
            hidden_states = hidden_states[0]

        x = self.avg_pool(hidden_states)
        x = self.projector(x)
        feature = x.reshape(x.shape[0], -1)

        x = self.dropout(feature)
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        x = self.fc3(x)

        return {
            'clipwise_output': x,
            'feature': feature,
            'embedding': hidden_states,
        }


In [29]:
class ytvos_Dataset(Dataset):
    """Dataset yielding one audio clip and its one-hot category label per metadata row.

    Each row of ``data_frame`` must provide the columns ``video``, ``audio``,
    ``exp`` and ``category``.

    Args:
        data_frame: Table of clip metadata, one clip per row.
        sr: Stored on the instance; not used by this class.
        num_class: Length of the one-hot target vector.
    """

    def __init__(self, data_frame: pd.DataFrame, sr=44100, num_class=65):
        self.data_frame = data_frame
        self.sr = sr
        self.num_class = num_class
        # Stored on the instance; not used by this class.
        self.data_root = '/home/user/SED_Adaptation_Classifier-main/data/ref_youtube_audio/audio'

    def __len__(self):
        """Return the number of clips (rows of the metadata table)."""
        return len(self.data_frame)

    def __getitem__(self, index):
        """Load the clip at ``index``.

        Returns:
            dict with ``audio_name`` (video id + expression), ``waveform``
            (whatever the wav reader returns), ``target`` (one-hot float vector of
            length ``num_class``) and ``tag`` (category name).
        """
        if torch.is_tensor(index):
            index = index.tolist()

        row = self.data_frame.iloc[index]
        audio_name = row["video"]
        audio_id = row["audio"]
        audio_path = 'ref_youtube_audio/audio' + '/' + audio_name + '/' + audio_id + '.wav'
        name = audio_name + row["exp"]

        # Fix: the original called the undefined ``read_wav(file_key)`` and raised NameError
        # on every item. Use the imported readers with the path built above.
        # NOTE(review): assumes read_aws_wav / read_local_wav accept a single key/path
        # argument, like read_aws_json does — confirm against util.
        reader = read_aws_wav if SageMaker else read_local_wav
        waveform = reader(audio_path)

        tag = row["category"]
        class_index = ytvos_category_dict[tag]
        # One-hot encode without materialising a full identity matrix per item.
        target = np.zeros(self.num_class)
        target[class_index] = 1.0

        return {'audio_name': name, 'waveform': waveform, 'target': target, 'tag': tag}
def get_datalist(cur_iter):
    """Collect the train and test metadata records belonging to one task.

    Reads ``task_split_1/metas.json`` (the full list of records under ``'metas'``)
    and ``task_split_1/task<cur_iter>.json`` (per-category ``train`` / ``test``
    index lists, stored under the key ``str(cur_iter)``).

    Args:
        cur_iter: Task id; selects which task-split file is read.

    Returns:
        ``(task_train_metas, task_test_metas)``: two lists of records from ``metas``,
        in the order the categories and indices appear in the task file.
    """
    task_id = cur_iter

    # Fix: ``metas`` / ``tasks` were only assigned when SageMaker was true, so any other
    # configuration crashed with UnboundLocalError. Fall back to the local reader.
    # NOTE(review): assumes read_local_json takes the same relative path as
    # read_aws_json — confirm against util.
    read_json = read_aws_json if SageMaker else read_local_json
    metas = read_json('task_split_1/metas.json')['metas']
    tasks = read_json('task_split_1/task{}.json'.format(task_id))[str(task_id)]

    task_train_metas = []
    task_test_metas = []
    for task_metas_dict in tasks.values():
        train_ids = task_metas_dict['train']
        test_ids = task_metas_dict['test']
        task_train_metas.extend(metas[train_id] for train_id in train_ids)
        task_test_metas.extend(metas[test_id] for test_id in test_ids)

    return task_train_metas, task_test_metas
    
def default_collate_fn(batch):
    """Merge dataset items into a batch without padding the audio.

    Args:
        batch: List of dicts, each with ``audio_name``, ``waveform`` (NumPy array)
            and ``target`` (equal-length numeric vector).

    Returns:
        dict with
            ``audio_name``: list of names;
            ``waveform``: list of tensors, one per clip, so clips of different
                lengths can share a batch;
            ``target``: float32 tensor with one row per item.
    """
    audio_name = [item['audio_name'] for item in batch]
    waveform = [torch.from_numpy(item['waveform']) for item in batch]

    # Stack the label vectors in NumPy first: building a tensor straight from a
    # Python list of ndarrays is the slow path and triggers a PyTorch warning.
    target = torch.as_tensor(np.asarray([item['target'] for item in batch]), dtype=torch.float32)

    return {'audio_name': audio_name, 'waveform': waveform, 'target': target}

def get_dataloader(data_frame, dataset, split, batch_size, num_class, num_workers=8):
    """Build a DataLoader over the clips listed in ``data_frame``.

    Args:
        data_frame: Metadata table handed to ``ytvos_Dataset``.
        dataset: Dataset name; only ``"ref_youtube_audio"`` is supported.
        split: ``'train'`` enables shuffling; any other value keeps the row order.
        batch_size: Number of clips per batch.
        num_class: Accepted for interface compatibility but not forwarded:
            the dataset keeps its own default label width.
        num_workers: Number of DataLoader worker processes.

    Returns:
        A ``DataLoader`` that collates batches with ``default_collate_fn`` and
        keeps the final partial batch.
    """
    assert dataset == "ref_youtube_audio"
    clips = ytvos_Dataset(data_frame=data_frame)
    # Fix: ``split`` used to be ignored, so evaluation data was shuffled as well.
    # Only the training split needs a random order.
    return DataLoader(dataset=clips, batch_size=batch_size,
                      shuffle=(split == 'train'), drop_last=False,
                      num_workers=num_workers, collate_fn=default_collate_fn)

def get_train_test_dataloader(batch_size, n_worker, train_list, test_list):
    """Wrap the train and test record lists in DataFrames and build a loader for each.

    Args:
        batch_size: Batch size used for both loaders.
        n_worker: Number of worker processes used for both loaders.
        train_list: Records of the training split.
        test_list: Records of the test split.

    Returns:
        ``(train_loader, test_loader)`` as produced by ``get_dataloader``.
    """
    # Settings shared by both splits.
    shared = dict(batch_size=batch_size, num_class=66, num_workers=n_worker)

    train_frame = pd.DataFrame(train_list)
    train_loader = get_dataloader(train_frame, 'ref_youtube_audio', split='train', **shared)

    test_frame = pd.DataFrame(test_list)
    test_loader = get_dataloader(test_frame, 'ref_youtube_audio', split='test', **shared)

    return train_loader, test_loader




In [1]:
!pip install audiomentations
from audiomentations import Compose, Gain, AddGaussianNoise, PitchShift,TimeStretch,Shift


Collecting audiomentations
  Downloading audiomentations-0.36.0-py3-none-any.whl.metadata (10 kB)
Collecting librosa!=0.10.0,<0.11.0,>=0.8.0 (from audiomentations)
  Downloading librosa-0.10.2.post1-py3-none-any.whl.metadata (8.6 kB)
Collecting soxr<1.0.0,>=0.3.2 (from audiomentations)
  Downloading soxr-0.3.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.5 kB)
Collecting audioread>=2.1.9 (from librosa!=0.10.0,<0.11.0,>=0.8.0->audiomentations)
  Downloading audioread-3.0.1-py3-none-any.whl.metadata (8.4 kB)
Collecting soundfile>=0.12.1 (from librosa!=0.10.0,<0.11.0,>=0.8.0->audiomentations)
  Downloading soundfile-0.12.1-py2.py3-none-manylinux_2_17_x86_64.whl.metadata (14 kB)
Collecting pooch>=1.1 (from librosa!=0.10.0,<0.11.0,>=0.8.0->audiomentations)
  Downloading pooch-1.8.2-py3-none-any.whl.metadata (10 kB)
Collecting lazy-loader>=0.1 (from librosa!=0.10.0,<0.11.0,>=0.8.0->audiomentations)
  Downloading lazy_loader-0.4-py3-none-any.whl.metadata (7.6 kB)
Col