# Libraries

In [1]:
cd /workdir/github

/workdir/github


In [2]:
from speaker_verification.transforms import Audio_Transforms
from speaker_verification.transforms import Image_Transforms
from speaker_verification.models import Model
from speaker_verification.dataset import SpeakingFacesDataset
from speaker_verification.dataset import ValidDataset
from speaker_verification.sampler import ProtoSampler
from speaker_verification.loss import PrototypicalLoss
from speaker_verification.train import train_model

In [3]:
import torch
import torchaudio
from torch.utils.data import DataLoader
import numpy as np
import os

In [4]:
# Limit OpenMP / MKL and PyTorch CPU threading to a single thread.
# NOTE(review): torch and numpy are imported in an earlier cell, so these
# environment variables may be read too late to affect them — confirm, or set
# them before the imports.
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
torch.set_num_threads(1)

# Check pipeline

In [6]:
# Preferred CUDA device index for these checks. When fewer GPUs are visible the
# last available one is used instead, and the CPU is used when CUDA is unavailable,
# so the notebook does not fail later with an invalid device ordinal.
PREFERRED_GPU_INDEX = 2
if torch.cuda.is_available():
    device = torch.device(f"cuda:{min(PREFERRED_GPU_INDEX, torch.cuda.device_count() - 1)}")
else:
    device = torch.device("cpu")

## Check model

### Model 

In [4]:
from transformers import AutoModelForAudioClassification
from transformers import WavLMForXVector
import torch.nn as nn
import torchvision

class Model(nn.Module):
    """
        Embedding network built on a backbone from HuggingFace, torchvision or timm.

        The last layer of the chosen backbone is replaced (or configured) so that
        the network outputs a vector of size `embedding_size`.

        Parameters
        ----------
        library : str
            Source of the backbone: "huggingface", "pytorch" (torchvision) or "timm".

        pretrained_weights : bool
            Ways of weights initialization for the "pytorch" and "timm" backbones.
            If False, it means random initialization and no pretrained weights,
            If True, pretrained weights are used.
            The HuggingFace backbones are always loaded from their checkpoints.

        fine_tune: bool
            Allows to choose between two types of transfer learning: fine tuning and feature extraction.
            If True, every parameter is trainable; if False, only the last layer is trainable.
            For more details of the description of each mode,
            read https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html

        embedding_size: int
            Size of the embedding of the last layer

        model_name : str
            "WavLM" or "AST" for "huggingface", "resnet34" for "pytorch",
            or the name passed to `timm.create_model` for "timm".

        pool : str or None
            "SAP" replaces the pooling layer of the "pytorch"/"timm" backbone with
            SelfAttentivePool2d; any other value keeps the backbone's own pooling.

        data_type : list of str
            Exactly one modality: "wav", "rgb" or "thr".

        Raises
        ------
        ValueError
            If `data_type`, `library` or `model_name` is not supported.
        NotImplementedError
            If a HuggingFace backbone is requested for image data.

    """

    def __init__(self, 
                library,
                pretrained_weights,
                fine_tune,
                embedding_size,
                model_name,
                pool, 
                data_type):

        super(Model, self).__init__()

        self.library = library
        self.pretrained_weights = pretrained_weights
        self.fine_tune = fine_tune
        self.embedding_size = embedding_size
        self.model_name = model_name
        self.pool = pool
        self.data_type = data_type

        # Fail early: previously an unsupported data_type produced a module
        # without `self.model`, which only surfaced later inside forward().
        if len(data_type) != 1:
            raise ValueError(
                f"data_type must contain exactly one modality, got {data_type!r}"
            )

        modality = data_type[0]
        if modality == "wav":
            print("wav data type")
            self.model = self.wav_model()
        elif modality == "rgb":
            print("rgb data type")
            self.model = self.image_model()
        elif modality == "thr":
            print("thr data type")
            self.model = self.image_model()
        else:
            raise ValueError(
                f"Unsupported data type {modality!r}; expected 'wav', 'rgb' or 'thr'"
            )

    def forward(self, x):
        """Return the embedding for `x`; HuggingFace outputs are unwrapped via `.logits`."""
        if self.library == "huggingface":
            x = self.model(x).logits
        else:
            x = self.model(x)
        return x

    def _set_trainable(self, model, head):
        """Apply the transfer-learning mode: train everything, or only `head`."""
        trainable = bool(self.fine_tune)
        for param in model.parameters():
            param.requires_grad = trainable

        # The freshly created last layer is always trained, also in
        # feature-extraction mode.
        for param in head.parameters():
            param.requires_grad = True

    def image_model(self):
        """Build the backbone for 3-channel image input ("rgb" / "thr")."""

        if self.library == "huggingface":
            print("HuggingFace model is used.")
            # No HuggingFace image backbone exists here; raise a clear error
            # instead of failing on an unbound local variable.
            raise NotImplementedError(
                "HuggingFace backbones are not implemented for image data"
            )
        elif self.library == "pytorch":
            print("pytorch model is used.")
            model = self.pytorch_model(in_channels = 3)
        elif self.library == "timm":
            print("timm model is used.")
            model = self.timm_model(in_channels = 3)
        else:
            raise ValueError(f"Unsupported library {self.library!r}")

        return model
    
    def wav_model(self):
        """Build the backbone for audio input (raw waveform or 1-channel spectrogram)."""

        if self.library == "huggingface":
            print("HuggingFace model is used.")
            model = self.huggingface_model()
        elif self.library == "pytorch":
            print("pytorch model is used.")
            model = self.pytorch_model(in_channels = 1)
        elif self.library == "timm":
            print("timm model is used.")
            model = self.timm_model(in_channels = 1)
        else:
            raise ValueError(f"Unsupported library {self.library!r}")

        return model
    
    def huggingface_model(self):
        """Load a pretrained HuggingFace audio model and replace its last linear layer."""
        if self.model_name == "WavLM":
            print("WavLM model is used.")
            model = WavLMForXVector.from_pretrained('microsoft/wavlm-base-sv')
            model.classifier = nn.Linear(model.classifier.in_features, self.embedding_size)
            head = model.classifier
        elif self.model_name == "AST":
            print("AST model is used.")
            model = AutoModelForAudioClassification.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
            model.classifier.dense = nn.Linear(model.classifier.dense.in_features, self.embedding_size)
            head = model.classifier.dense
        else:
            raise ValueError(
                f"Unsupported HuggingFace model {self.model_name!r}; expected 'WavLM' or 'AST'"
            )

        self._set_trainable(model, head)
        return model

    def timm_model(self, in_channels):
        """Create a timm backbone with `in_channels` inputs and an `embedding_size` head."""
        # Imported locally because the visible cells never bind the name `timm`
        # (only `timm.data` members are imported).
        import timm

        model = timm.create_model(self.model_name, pretrained=self.pretrained_weights, num_classes=self.embedding_size, in_chans=in_channels)

        if self.pool == "SAP":
            # NOTE(review): SelfAttentivePool2d is not defined or imported in the
            # visible cells, and it is called without arguments here but with
            # `in_features` in pytorch_model — confirm the intended signature.
            # NOTE(review): for timm ViT models `global_pool` is presumably a
            # string flag rather than a module — confirm this has any effect.
            model.global_pool = SelfAttentivePool2d()

        self._set_trainable(model, model.get_classifier())
        return model

    def pytorch_model(self, in_channels):
        """Create a torchvision ResNet-34 with `in_channels` inputs and an `embedding_size` head."""
        if self.model_name != "resnet34":
            raise ValueError(
                f"Unsupported torchvision model {self.model_name!r}; expected 'resnet34'"
            )

        if self.pretrained_weights:
            weights = torchvision.models.ResNet34_Weights.DEFAULT
        else:
            weights = None

        model = torchvision.models.resnet34(weights=weights)

        # Only replace the first convolution when the channel count differs, so
        # that 3-channel inputs keep the pretrained filters.
        # NOTE(review): for other channel counts the new conv1 is randomly
        # initialised and stays frozen when fine_tune is False.
        if in_channels != model.conv1.in_channels:
            model.conv1 = nn.Conv2d(in_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        model.fc = nn.Linear(model.fc.in_features, self.embedding_size)

        if self.pool == "SAP":
            model.avgpool = SelfAttentivePool2d(model.fc.in_features)

        self._set_trainable(model, model.fc)
        return model

### Check wav models

#### check huggingface models

##### check AST

In [5]:
# Backbone selection for this check
library, model_name = "huggingface", "AST"
data_type = ["wav"]
# Transfer-learning switches and output size passed to Model
pretrained_weights, fine_tune = True, True
embedding_size = 128
pool = None

In [6]:
# Build the backbone from the configuration variables set in the previous cell
model = Model(
    library=library,
    model_name=model_name,
    data_type=data_type,
    pretrained_weights=pretrained_weights,
    fine_tune=fine_tune,
    embedding_size=embedding_size,
    pool=pool,
)
# Move the parameters to the selected device
model = model.to(device)

wav data type
HuggingFace model is used.
AST model is used.


In [13]:
from transformers import ASTFeatureExtractor

# Random 1-D tensor of 45789 values standing in for an audio clip
input = torch.rand(45789)
# Feature extractor created with its default configuration
feature_extractor = ASTFeatureExtractor()
input = feature_extractor(input, sampling_rate=16000, padding=True, return_tensors="pt")
# Drop the size-1 dimensions of the extracted features; the batch dimension is
# added back with unsqueeze in the forward-pass cell.
input = input.input_values.squeeze()

In [14]:
out = model(input.unsqueeze(dim=0).to(device))

In [15]:
out.shape

torch.Size([1, 128])

##### check WavLM

In [16]:
# Backbone selection for this check
library, model_name = "huggingface", "WavLM"
data_type = ["wav"]
# Transfer-learning switches and output size passed to Model
pretrained_weights, fine_tune = True, True
embedding_size = 128
pool = None

In [17]:
# Build the backbone from the configuration variables set in the previous cell
model = Model(
    library=library,
    model_name=model_name,
    data_type=data_type,
    pretrained_weights=pretrained_weights,
    fine_tune=fine_tune,
    embedding_size=embedding_size,
    pool=pool,
)
# Move the parameters to the selected device
model = model.to(device)

wav data type
HuggingFace model is used.
WavLM model is used.


In [18]:
from transformers import Wav2Vec2FeatureExtractor

# Random 1-D tensor of 45789 values standing in for an audio clip
input = torch.rand(45789)
# Load the extractor configuration published with the checkpoint, the same way
# Audio_Transforms.huggingface_init does for WavLM, instead of relying on the
# class defaults.
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained('microsoft/wavlm-base-sv')
input = feature_extractor(input, sampling_rate=16000, padding=True, return_tensors="pt")
# Drop the size-1 dimensions; the batch dimension is added back with unsqueeze
# in the forward-pass cell.
input = input.input_values.squeeze()

In [19]:
out = model(input.unsqueeze(dim=0).to(device))

In [20]:
out.shape

torch.Size([1, 128])

#### check timm

##### check resnet34

In [21]:
# Backbone selection for this check
library, model_name = "timm", "resnet34"
data_type = ["wav"]
# Transfer-learning switches and output size passed to Model
pretrained_weights, fine_tune = True, True
embedding_size = 128
pool = None

# audio
sample_rate = 16000
sample_duration = 2  # seconds
n_fft, win_length, hop_length = 512, 400, 160  # n_fft: from Korean code
window_fn = torch.hamming_window
n_mels = 40

In [22]:
# Build the backbone from the configuration variables set in the previous cell
model = Model(
    library=library,
    model_name=model_name,
    data_type=data_type,
    pretrained_weights=pretrained_weights,
    fine_tune=fine_tune,
    embedding_size=embedding_size,
    pool=pool,
)
# Move the parameters to the selected device
model = model.to(device)

wav data type
timm model is used.


In [23]:
import torchaudio

# Random single-channel waveform standing in for an audio clip
input = torch.rand((1,45789))

# Use the `window_fn` chosen in the configuration cell rather than a second
# hard-coded copy, so changing the configuration actually takes effect.
to_MelSpectrogram = torchaudio.transforms.MelSpectrogram(
    sample_rate=sample_rate,
    n_fft=n_fft,
    win_length=win_length,
    hop_length=hop_length,
    window_fn=window_fn,
    n_mels=n_mels,
)

input = to_MelSpectrogram(input)

In [24]:
input.shape

torch.Size([1, 40, 287])

In [25]:
out = model(input.unsqueeze(dim=0).to(device))

In [26]:
out.shape

torch.Size([1, 128])

##### check resnet34 + SAP

In [27]:
# Backbone selection for this check (self-attentive pooling enabled)
library, model_name = "timm", "resnet34"
data_type = ["wav"]
pool = "SAP"
# Transfer-learning switches and output size passed to Model
pretrained_weights, fine_tune = True, True
embedding_size = 128

# audio
sample_rate = 16000
sample_duration = 2  # seconds
n_fft, win_length, hop_length = 512, 400, 160  # n_fft: from Korean code
window_fn = torch.hamming_window
n_mels = 40

In [28]:
# Build the backbone from the configuration variables set in the previous cell
model = Model(
    library=library,
    model_name=model_name,
    data_type=data_type,
    pretrained_weights=pretrained_weights,
    fine_tune=fine_tune,
    embedding_size=embedding_size,
    pool=pool,
)
# Move the parameters to the selected device
model = model.to(device)

wav data type
timm model is used.


In [29]:
import torchaudio

# Random single-channel waveform standing in for an audio clip
input = torch.rand((1,45789))

# Use the `window_fn` chosen in the configuration cell rather than a second
# hard-coded copy, so changing the configuration actually takes effect.
to_MelSpectrogram = torchaudio.transforms.MelSpectrogram(
    sample_rate=sample_rate,
    n_fft=n_fft,
    win_length=win_length,
    hop_length=hop_length,
    window_fn=window_fn,
    n_mels=n_mels,
)

input = to_MelSpectrogram(input)

In [30]:
out = model(input.unsqueeze(dim=0).to(device))

In [31]:
out.shape

torch.Size([1, 128])

##### check vit_base_patch16_224

In [33]:
# Backbone selection for this check
library, model_name = "timm", "vit_base_patch16_224"
data_type = ["wav"]
# Transfer-learning switches and output size passed to Model
pretrained_weights, fine_tune = True, True
embedding_size = 128
pool = None

# audio
sample_rate = 16000
sample_duration = 3  # seconds
n_fft, win_length, hop_length = 512, 400, 160  # n_fft: from Korean code
window_fn = torch.hamming_window
n_mels = 40

In [34]:
# Build the backbone from the configuration variables set in the previous cell
model = Model(
    library=library,
    model_name=model_name,
    data_type=data_type,
    pretrained_weights=pretrained_weights,
    fine_tune=fine_tune,
    embedding_size=embedding_size,
    pool=pool,
)
# Move the parameters to the selected device
model = model.to(device)

wav data type
timm model is used.


In [35]:
import torchaudio
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from PIL import Image
import torchvision.transforms as T
import cv2

# Random single-channel waveform standing in for an audio clip
input = torch.rand((1,45789))

# Use the `window_fn` chosen in the configuration cell rather than a second
# hard-coded copy, so changing the configuration actually takes effect.
to_MelSpectrogram = torchaudio.transforms.MelSpectrogram(
    sample_rate=sample_rate,
    n_fft=n_fft,
    win_length=win_length,
    hop_length=hop_length,
    window_fn=window_fn,
    n_mels=n_mels,
)

# Resize / crop the spectrogram to the 224x224 input of the ViT backbone.
# NOTE(review): ToPILImage presumably expects float tensors in [0, 1]; mel
# values are not bounded to that range — confirm the image is meaningful.
transform=T.Compose([
    T.ToPILImage(),
    T.Resize(size=256, interpolation=T.InterpolationMode.BICUBIC, max_size=None, antialias=None),
    T.CenterCrop(size=(224, 224)),
    T.ToTensor(),
    # T.Normalize(mean=torch.tensor([0.4850]), std=torch.tensor([0.2290]))
])

input = to_MelSpectrogram(input)
# input = input.repeat(3, 1, 1)
input = transform(input)

In [36]:
out = model(input.unsqueeze(dim=0).to(device))

In [37]:
out.shape

torch.Size([1, 128])

##### check vit + SAP

In [38]:
# Backbone selection for this check (self-attentive pooling enabled)
library, model_name = "timm", "vit_base_patch16_224"
data_type = ["wav"]
pool = "SAP"
# Transfer-learning switches and output size passed to Model
pretrained_weights, fine_tune = True, True
embedding_size = 128

# audio
sample_rate = 16000
sample_duration = 3  # seconds
n_fft, win_length, hop_length = 512, 400, 160  # n_fft: from Korean code
window_fn = torch.hamming_window
n_mels = 40

In [39]:
# Build the backbone from the configuration variables set in the previous cell
model = Model(
    library=library,
    model_name=model_name,
    data_type=data_type,
    pretrained_weights=pretrained_weights,
    fine_tune=fine_tune,
    embedding_size=embedding_size,
    pool=pool,
)
# Move the parameters to the selected device
model = model.to(device)

wav data type
timm model is used.


In [40]:
import torchaudio
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from PIL import Image
import torchvision.transforms as T
import cv2

to_PILImage = T.ToPILImage()

# Random single-channel waveform standing in for an audio clip
input = torch.rand((1,45789))

# Use the `window_fn` chosen in the configuration cell rather than a second
# hard-coded copy, so changing the configuration actually takes effect.
to_MelSpectrogram = torchaudio.transforms.MelSpectrogram(
    sample_rate=sample_rate,
    n_fft=n_fft,
    win_length=win_length,
    hop_length=hop_length,
    window_fn=window_fn,
    n_mels=n_mels,
)

# Resize / crop the spectrogram to the 224x224 input of the ViT backbone.
# NOTE(review): ToPILImage presumably expects float tensors in [0, 1]; mel
# values are not bounded to that range — confirm the image is meaningful.
transform=T.Compose([
    T.ToPILImage(),
    T.Resize(size=256, interpolation=T.InterpolationMode.BICUBIC, max_size=None, antialias=None),
    T.CenterCrop(size=(224, 224)),
    T.ToTensor(),
    # T.Normalize(mean=torch.tensor([0.4850]), std=torch.tensor([0.2290]))
])

input = to_MelSpectrogram(input)
# input = input.repeat(3, 1, 1)
input = transform(input)

In [41]:
out = model(input.unsqueeze(dim=0).to(device))

In [42]:
out.shape

torch.Size([1, 128])

#### check pytorch

##### check resnet34

In [43]:
# Backbone selection for this check
library, model_name = "pytorch", "resnet34"
data_type = ["wav"]
# Transfer-learning switches and output size passed to Model
pretrained_weights, fine_tune = True, True
embedding_size = 128
pool = None

# audio
sample_rate = 16000
sample_duration = 3  # seconds
n_fft, win_length, hop_length = 512, 400, 160  # n_fft: from Korean code
window_fn = torch.hamming_window
n_mels = 40

In [44]:
# Build the backbone from the configuration variables set in the previous cell
model = Model(
    library=library,
    model_name=model_name,
    data_type=data_type,
    pretrained_weights=pretrained_weights,
    fine_tune=fine_tune,
    embedding_size=embedding_size,
    pool=pool,
)
# Move the parameters to the selected device
model = model.to(device)

wav data type


In [45]:
import torchaudio

# Random single-channel waveform standing in for an audio clip
input = torch.rand((1,45789))

# Use the `window_fn` chosen in the configuration cell rather than a second
# hard-coded copy, so changing the configuration actually takes effect.
to_MelSpectrogram = torchaudio.transforms.MelSpectrogram(
    sample_rate=sample_rate,
    n_fft=n_fft,
    win_length=win_length,
    hop_length=hop_length,
    window_fn=window_fn,
    n_mels=n_mels,
)

input = to_MelSpectrogram(input)

In [46]:
out = model(input.unsqueeze(dim=0).to(device))

In [47]:
out.shape

torch.Size([1, 128])

### check image model

In [50]:
from skimage import io

#### check timm

##### check resnet34

In [48]:
# Backbone selection for this check
library, model_name = "timm", "resnet34"
data_type = ["rgb"]
# Transfer-learning switches and output size passed to Model
pretrained_weights, fine_tune = True, True
embedding_size = 128
pool = None

In [49]:
# Build the backbone from the configuration variables set in the previous cell
model = Model(
    library=library,
    model_name=model_name,
    data_type=data_type,
    pretrained_weights=pretrained_weights,
    fine_tune=fine_tune,
    embedding_size=embedding_size,
    pool=pool,
)
# Move the parameters to the selected device
model = model.to(device)

rgb data type
timm model is used.


In [80]:
# Random uint8 image of height 175, width 130 and 3 channels. Sampling integers
# directly is required: casting np.random.rand (values in [0, 1)) to uint8
# truncates every pixel to 0 and yields an all-black image.
input = np.random.randint(0, 256, size=(175, 130, 3), dtype=np.uint8)

In [81]:
# array -> PIL image -> 128x128 resize -> tensor
transform_image = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToPILImage(),
        torchvision.transforms.Resize((128, 128)),
        torchvision.transforms.ToTensor(),
    ]
)
input = transform_image(input)

In [82]:
input.shape

torch.Size([3, 128, 128])

In [83]:
out = model(input.unsqueeze(dim=0).to(device))

In [84]:
out.shape

torch.Size([1, 128])

##### check resnet + SAP

In [89]:
# Backbone selection for this check (self-attentive pooling enabled)
library, model_name = "timm", "resnet34"
data_type = ["rgb"]
pool = "SAP"
# Transfer-learning switches and output size passed to Model
pretrained_weights, fine_tune = True, True
embedding_size = 128

In [90]:
# Build the backbone from the configuration variables set in the previous cell
model = Model(
    library=library,
    model_name=model_name,
    data_type=data_type,
    pretrained_weights=pretrained_weights,
    fine_tune=fine_tune,
    embedding_size=embedding_size,
    pool=pool,
)
# Move the parameters to the selected device
model = model.to(device)

rgb data type
timm model is used.


In [91]:
# Random uint8 image of height 175, width 130 and 3 channels. Sampling integers
# directly is required: casting np.random.rand (values in [0, 1)) to uint8
# truncates every pixel to 0 and yields an all-black image.
input = np.random.randint(0, 256, size=(175, 130, 3), dtype=np.uint8)

In [92]:
# array -> PIL image -> 128x128 resize -> tensor
transform_image = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToPILImage(),
        torchvision.transforms.Resize((128, 128)),
        torchvision.transforms.ToTensor(),
    ]
)
input = transform_image(input)

In [93]:
input.shape

torch.Size([3, 128, 128])

In [94]:
out = model(input.unsqueeze(dim=0).to(device))

In [95]:
out.shape

torch.Size([1, 128])

##### check vit

In [96]:
# Backbone selection for this check
library, model_name = "timm", "vit_base_patch16_224"
data_type = ["rgb"]
# Transfer-learning switches and output size passed to Model
pretrained_weights, fine_tune = True, True
embedding_size = 128
pool = None

In [97]:
# Build the backbone from the configuration variables set in the previous cell
model = Model(
    library=library,
    model_name=model_name,
    data_type=data_type,
    pretrained_weights=pretrained_weights,
    fine_tune=fine_tune,
    embedding_size=embedding_size,
    pool=pool,
)
# Move the parameters to the selected device
model = model.to(device)

rgb data type
timm model is used.


In [98]:
# Random uint8 image of height 175, width 130 and 3 channels. Sampling integers
# directly is required: casting np.random.rand (values in [0, 1)) to uint8
# truncates every pixel to 0 and yields an all-black image.
input = np.random.randint(0, 256, size=(175, 130, 3), dtype=np.uint8)

In [99]:
# Per-channel mean/std, the same values Image_Transforms.timm_init uses for this
# backbone. The previous single-value mean/std applied one statistic to all
# three channels.
transform_image=torchvision.transforms.Compose([
    torchvision.transforms.ToPILImage(),
    torchvision.transforms.Resize(size=256, interpolation=torchvision.transforms.InterpolationMode.BICUBIC, max_size=None, antialias=None),
    torchvision.transforms.CenterCrop(size=(224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225]))
])

input = transform_image(input)

In [100]:
input.shape

torch.Size([3, 224, 224])

In [101]:
out = model(input.unsqueeze(dim=0).to(device))

In [102]:
out.shape

torch.Size([1, 128])

##### check vit + SAP

In [103]:
# Backbone selection for this check (self-attentive pooling enabled)
library, model_name = "timm", "vit_base_patch16_224"
data_type = ["rgb"]
pool = "SAP"
# Transfer-learning switches and output size passed to Model
pretrained_weights, fine_tune = True, True
embedding_size = 128

In [104]:
# Build the backbone from the configuration variables set in the previous cell
model = Model(
    library=library,
    model_name=model_name,
    data_type=data_type,
    pretrained_weights=pretrained_weights,
    fine_tune=fine_tune,
    embedding_size=embedding_size,
    pool=pool,
)
# Move the parameters to the selected device
model = model.to(device)

rgb data type
timm model is used.


In [105]:
# Random uint8 image of height 175, width 130 and 3 channels. Sampling integers
# directly is required: casting np.random.rand (values in [0, 1)) to uint8
# truncates every pixel to 0 and yields an all-black image.
input = np.random.randint(0, 256, size=(175, 130, 3), dtype=np.uint8)

In [106]:
# Per-channel mean/std, the same values Image_Transforms.timm_init uses for this
# backbone. The previous single-value mean/std applied one statistic to all
# three channels.
transform_image=torchvision.transforms.Compose([
    torchvision.transforms.ToPILImage(),
    torchvision.transforms.Resize(size=256, interpolation=torchvision.transforms.InterpolationMode.BICUBIC, max_size=None, antialias=None),
    torchvision.transforms.CenterCrop(size=(224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225]))
])

input = transform_image(input)

In [107]:
input.shape

torch.Size([3, 224, 224])

In [108]:
out = model(input.unsqueeze(dim=0).to(device))

In [109]:
out.shape

torch.Size([1, 128])

#### check pytorch

##### check resnet

In [111]:
# Backbone selection for this check
library, model_name = "pytorch", "resnet34"
data_type = ["rgb"]
# Transfer-learning switches and output size passed to Model
pretrained_weights, fine_tune = True, True
embedding_size = 128
pool = None

In [112]:
# Build the backbone from the configuration variables set in the previous cell
model = Model(
    library=library,
    model_name=model_name,
    data_type=data_type,
    pretrained_weights=pretrained_weights,
    fine_tune=fine_tune,
    embedding_size=embedding_size,
    pool=pool,
)
# Move the parameters to the selected device
model = model.to(device)

rgb data type


In [113]:
# Random uint8 image of height 175, width 130 and 3 channels. Sampling integers
# directly is required: casting np.random.rand (values in [0, 1)) to uint8
# truncates every pixel to 0 and yields an all-black image.
input = np.random.randint(0, 256, size=(175, 130, 3), dtype=np.uint8)

In [114]:
# array -> PIL image -> 128x128 resize -> tensor
transform_image = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToPILImage(),
        torchvision.transforms.Resize((128, 128)),
        torchvision.transforms.ToTensor(),
    ]
)
input = transform_image(input)

In [115]:
input.shape

torch.Size([3, 128, 128])

In [116]:
out = model(input.unsqueeze(dim=0).to(device))

In [117]:
out.shape

torch.Size([1, 128])

##### check resnet + SAP

In [124]:
# Backbone selection for this check (self-attentive pooling enabled)
library, model_name = "pytorch", "resnet34"
data_type = ["rgb"]
pool = "SAP"
# Transfer-learning switches and output size passed to Model
pretrained_weights, fine_tune = True, True
embedding_size = 128

In [125]:
# Build the backbone from the configuration variables set in the previous cell
model = Model(
    library=library,
    model_name=model_name,
    data_type=data_type,
    pretrained_weights=pretrained_weights,
    fine_tune=fine_tune,
    embedding_size=embedding_size,
    pool=pool,
)
# Move the parameters to the selected device
model = model.to(device)

rgb data type
pytorch model is used.


In [126]:
# Random uint8 image of height 175, width 130 and 3 channels. Sampling integers
# directly is required: casting np.random.rand (values in [0, 1)) to uint8
# truncates every pixel to 0 and yields an all-black image.
input = np.random.randint(0, 256, size=(175, 130, 3), dtype=np.uint8)

In [127]:
# array -> PIL image -> 128x128 resize -> tensor
transform_image = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToPILImage(),
        torchvision.transforms.Resize((128, 128)),
        torchvision.transforms.ToTensor(),
    ]
)
input = transform_image(input)

In [128]:
input.shape

torch.Size([3, 128, 128])

In [129]:
out = model(input.unsqueeze(dim=0).to(device))

In [130]:
out.shape

torch.Size([1, 128])

## Check Transforms

### Audio Transform

In [14]:
import torchvision.transforms as T
from transformers import ASTFeatureExtractor
from transformers import Wav2Vec2FeatureExtractor

class Audio_Transforms:
    """
    Converts a raw waveform into the input expected by the selected backbone.

    Every signal is first mixed down to one channel, resampled to `sample_rate`
    and padded / centre-cropped to `sample_duration` seconds. It is then passed
    through a HuggingFace feature extractor (library "huggingface") or a mel
    spectrogram (libraries "timm" and "pytorch").

    Parameters
    ----------
    sample_rate : int
        Target sample rate of the signal.
    sample_duration : int or float
        Target duration of the signal, in seconds.
    n_fft, win_length, hop_length, window_fn, n_mels
        Passed unchanged to torchaudio.transforms.MelSpectrogram.
    model_name : str
        "WavLM" or "AST" for "huggingface"; for "timm" the name
        "vit_base_patch16_224" additionally enables the 224x224 resize pipeline.
    library : str
        "huggingface", "timm" or "pytorch".
    """

    def __init__(self, 
                sample_rate,
                sample_duration, # seconds
                n_fft, # from Korean code
                win_length,
                hop_length,
                window_fn,
                n_mels,
                model_name, 
                library
                ):

        self.sample_rate = sample_rate
        self.sample_duration = sample_duration
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.window_fn = window_fn
        self.n_mels = n_mels
        self.model_name = model_name
        self.library = library

        # Filled in by the *_init helpers below. They stay None when the
        # library / model combination has no pipeline, so that the error is
        # reported clearly when a transform is actually requested.
        self.feature_extractor = None
        self.to_MelSpectrogram = None
        self.vit_transform = None

        if self.library == "huggingface":
            self.huggingface_init()
        elif self.library == "timm":
            self.timm_init()
        elif self.library == "pytorch":
            self.pytorch_init()

    def _build_mel_spectrogram(self):
        """Create the mel-spectrogram transform shared by the timm and pytorch pipelines."""
        return torchaudio.transforms.MelSpectrogram(
                sample_rate=self.sample_rate,
                n_fft=self.n_fft,
                win_length=self.win_length,
                hop_length=self.hop_length,
                window_fn=self.window_fn,
                n_mels=self.n_mels
            )

    def huggingface_init(self):
        """Create the feature extractor that matches the HuggingFace model."""
        if self.model_name == "WavLM":
            self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained('microsoft/wavlm-base-sv')
        elif self.model_name == "AST":
            self.feature_extractor = ASTFeatureExtractor()

    def timm_init(self):
        """Create the mel-spectrogram transform and, for the ViT backbone, the resize pipeline."""
        self.to_MelSpectrogram = self._build_mel_spectrogram()

        if self.model_name == "vit_base_patch16_224":
            # n_channels = 1
            # NOTE(review): ToPILImage presumably expects float tensors in
            # [0, 1]; mel values are not bounded to that range — confirm.
            self.vit_transform=T.Compose([
                T.ToPILImage(),
                T.Resize(size=256, interpolation=T.InterpolationMode.BICUBIC, max_size=None, antialias=None),
                T.CenterCrop(size=(224, 224)),
                T.ToTensor(),
                # T.Normalize(mean=torch.tensor([0.4850]), std=torch.tensor([0.2290]))
            ])

    def pytorch_init(self):
        """Create the mel-spectrogram transform."""
        self.to_MelSpectrogram = self._build_mel_spectrogram()

    # MAIN TRANSFORM FUNCTION
    def transform(self, signal, sample_rate):
        """
        Normalise `signal` and convert it to the model input.

        Parameters
        ----------
        signal : torch.Tensor
            Waveform indexed as (channels, samples).
        sample_rate : int
            Sample rate of `signal`.

        Raises
        ------
        ValueError
            If the library (or the HuggingFace model) is not supported.
        """

        signal = self.basic_transform(signal, sample_rate)

        if self.library == "huggingface":
            return self.huggingface_transform(signal)
        if self.library == "timm":
            return self.timm_transform(signal)
        if self.library == "pytorch":
            return self.pytorch_transform(signal)

        raise ValueError(f"Unsupported library {self.library!r}")
    
    def basic_transform(self, signal, sample_rate):
        """Mix down to one channel, resample to `self.sample_rate` and fix the length."""

        # stereo --> mono
        if signal.shape[0] > 1:
            signal = torch.mean(signal, dim=0, keepdim=True)
        
        # sample_rate --> self.sample_rate
        if sample_rate != self.sample_rate:
            resampler = torchaudio.transforms.Resample(sample_rate, self.sample_rate)
            signal = resampler(signal)

        # normalize duration --> self.sample_duration seconds
        target_length = int(self.sample_duration * self.sample_rate)
        current_length = signal.shape[1]
        if current_length < target_length:
            # (left_pad, right_pad): zeros are appended on the right only
            # ex: dim_padding = (1,2) --> [1,1,1] -> [0,1,1,1,0,0]
            dim_padding = (0, target_length - current_length)
            signal = torch.nn.functional.pad(signal, dim_padding)
        elif current_length > target_length:
            # Centre crop. The right edge is derived from the left edge so that
            # exactly target_length samples are kept, also for odd lengths.
            left_edge = current_length // 2 - target_length // 2
            signal = signal[:, left_edge:left_edge + target_length]

        return signal
    
    def huggingface_transform(self, audio):
        """Run the HuggingFace feature extractor and drop the size-1 dimensions."""
        if self.feature_extractor is None:
            raise ValueError(
                f"No feature extractor is available for model {self.model_name!r}"
            )

        waveform = audio.squeeze()
        features = self.feature_extractor(waveform, sampling_rate=self.sample_rate, padding=True, return_tensors="pt")
        return features.input_values.squeeze()

    def timm_transform(self, audio):
        """Compute the mel spectrogram; the ViT backbone additionally gets the resize pipeline."""
        spectrogram = self.to_MelSpectrogram(audio)
        if self.model_name == "vit_base_patch16_224":
            # input = input.repeat(3, 1, 1)
            spectrogram = self.vit_transform(spectrogram)
        return spectrogram

    def pytorch_transform(self, audio):
        """Compute the mel spectrogram."""
        return self.to_MelSpectrogram(audio)


### Check audio

In [8]:
# Load the reference clip when it is available on this machine; otherwise fall
# back to a random single-channel signal so the checks below can still run.
path2wav = "/workdir/sf_pv/data_v2/sub_1/11/wav/574.wav"
if os.path.exists(path2wav):
    audio, sample_rate = torchaudio.load(path2wav)
else:
    audio, sample_rate = torch.rand((1,45789)), 16000

In [9]:
audio.shape

torch.Size([1, 66834])

In [10]:
sample_rate

16000

#### check huggingface model

In [11]:
model_names = ["AST", "WavLM"]

In [12]:
# Settings shared by every model in this check are defined once, outside the loop.
library = "huggingface"

# audio transform params
# Target rate handed to Audio_Transforms. It is kept separate from `sample_rate`,
# which holds the rate of the loaded `audio` clip and must not be overwritten:
# otherwise the transform never resamples a clip recorded at another rate.
target_sample_rate = 16000
sample_duration = 2 # seconds
n_fft = 512 # from Korean code
win_length = 400
hop_length = 160
window_fn = torch.hamming_window
n_mels = 40

# model params
pretrained_weights = True
fine_tune = True
embedding_size = 128
pool = None
data_type = ["wav"]

for model_name in model_names:
    audio_T = Audio_Transforms(sample_rate=target_sample_rate,
                               sample_duration=sample_duration, # seconds
                               n_fft=n_fft, # from Korean code
                               win_length=win_length,
                               hop_length=hop_length,
                               window_fn=window_fn,
                               n_mels=n_mels,
                               model_name=model_name,
                               library=library)

    model = Model(library=library,
                  pretrained_weights=pretrained_weights,
                  fine_tune=fine_tune,
                  embedding_size=embedding_size,
                  model_name=model_name,
                  pool=pool,
                  data_type=data_type)

    model = model.to(device)

    # `sample_rate` is the rate of `audio` set by the audio-loading cell
    input = audio_T.transform(audio, sample_rate)
    out = model(input.unsqueeze(dim=0).to(device))
    print(out.shape)

wav data type
HuggingFace model is used.
AST model is used.
torch.Size([1, 128])
wav data type
HuggingFace model is used.
WavLM model is used.
torch.Size([1, 128])


#### check timm model

In [13]:
model_names = ['resnet34', 'vit_base_patch16_224']

In [14]:
# Settings shared by every model in this check are defined once, outside the loop.
library = "timm"

# audio transform params
# Target rate handed to Audio_Transforms. It is kept separate from `sample_rate`,
# which holds the rate of the `audio` clip and must not be overwritten:
# otherwise the transform never resamples a clip recorded at another rate.
target_sample_rate = 16000
sample_duration = 2 # seconds
n_fft = 512 # from Korean code
win_length = 400
hop_length = 160
window_fn = torch.hamming_window
n_mels = 40

# model params
pretrained_weights = True
fine_tune = True
embedding_size = 128
pool = None
data_type = ["wav"]

for model_name in model_names:
    audio_T = Audio_Transforms(sample_rate=target_sample_rate,
                               sample_duration=sample_duration, # seconds
                               n_fft=n_fft, # from Korean code
                               win_length=win_length,
                               hop_length=hop_length,
                               window_fn=window_fn,
                               n_mels=n_mels,
                               model_name=model_name,
                               library=library)

    model = Model(library=library,
                  pretrained_weights=pretrained_weights,
                  fine_tune=fine_tune,
                  embedding_size=embedding_size,
                  model_name=model_name,
                  pool=pool,
                  data_type=data_type)

    model = model.to(device)

    # `sample_rate` is the rate of `audio` set by an earlier cell
    input = audio_T.transform(audio, sample_rate)
    out = model(input.unsqueeze(dim=0).to(device))
    print(out.shape)

wav data type
timm model is used.
torch.Size([1, 128])
wav data type
timm model is used.
torch.Size([1, 128])


#### check pytorch model

In [15]:
library = "pytorch"
model_name = 'resnet34'

# audio transform params
# Target rate handed to Audio_Transforms. It is kept separate from `sample_rate`,
# which holds the rate of the `audio` clip and must not be overwritten:
# otherwise the transform never resamples a clip recorded at another rate.
target_sample_rate = 16000
sample_duration = 2 # seconds
n_fft = 512 # from Korean code
win_length = 400
hop_length = 160
window_fn = torch.hamming_window
n_mels = 40

# model params
pretrained_weights = True
fine_tune = True
embedding_size = 128
pool = None
data_type = ["wav"]

audio_T = Audio_Transforms(sample_rate=target_sample_rate,
                           sample_duration=sample_duration, # seconds
                           n_fft=n_fft, # from Korean code
                           win_length=win_length,
                           hop_length=hop_length,
                           window_fn=window_fn,
                           n_mels=n_mels,
                           model_name=model_name,
                           library=library)

model = Model(library=library,
              pretrained_weights=pretrained_weights,
              fine_tune=fine_tune,
              embedding_size=embedding_size,
              model_name=model_name,
              pool=pool,
              data_type=data_type)

model = model.to(device)

# `sample_rate` is the rate of `audio` set by an earlier cell
input = audio_T.transform(audio, sample_rate)
out = model(input.unsqueeze(dim=0).to(device))
print(out.shape)

wav data type
pytorch model is used.
torch.Size([1, 128])


### Image Transform

In [16]:
class Image_Transforms:
    """Select the image preprocessing pipeline for a given model library.

    Parameters
    ----------
    library : str
        Model library name. "timm" and "pytorch" build a pipeline;
        "huggingface" (and any other value) builds none, in which case
        ``transform`` raises ``NotImplementedError``.

    model_name : str
        Model identifier. Only consulted for "timm": "vit_base_patch16_224"
        gets a resize(256, bicubic) -> center-crop(224) -> normalize pipeline,
        every other name gets the plain 128x128 resize pipeline.
    """

    def __init__(self, 
                 library,
                 model_name):

        self.library = library
        self.model_name = model_name
        # Stays None when no pipeline exists for `library`, so that
        # transform() can fail with an explicit message instead of an
        # AttributeError on a missing attribute.
        self.transform_image = None

        if self.library == "huggingface":
            # No image pipeline is defined for this library.
            pass
        elif self.library == "timm":
            self.timm_init()
        elif self.library == "pytorch":
            self.pytorch_init()

    @staticmethod
    def _basic_transform():
        """Return the shared pipeline: to PIL, resize to 128x128, to tensor."""
        return T.Compose([
            T.ToPILImage(),
            T.Resize((128,128)),
            T.ToTensor(), 
        ])

    def timm_init(self):
        """Set ``transform_image`` for timm models."""
        if self.model_name == "vit_base_patch16_224":
            # n_channels = 3
            self.transform_image=T.Compose([
                T.ToPILImage(),
                T.Resize(size=256, interpolation=T.InterpolationMode.BICUBIC, max_size=None, antialias=None),
                T.CenterCrop(size=(224, 224)),
                T.ToTensor(),
                # Per-channel mean / std normalization (three channels).
                T.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225]))
            ]) 
        else:
            self.transform_image = self._basic_transform()
        
    def pytorch_init(self):
        """Set ``transform_image`` for "pytorch"-library models (128x128 resize)."""
        self.transform_image = self._basic_transform()

    def transform(self, image):
        """Apply the configured pipeline to ``image`` and return the result.

        Raises
        ------
        NotImplementedError
            If no pipeline was built for the configured library.
        """
        if self.transform_image is None:
            raise NotImplementedError(
                f"No image transform is defined for library={self.library!r}"
            )
        return self.transform_image(image)

### Check image

In [5]:
# Random uint8 test image of shape (175, 130, 3).
# np.random.rand() yields floats in [0, 1), so casting them straight to uint8
# truncates every pixel to 0 (an all-black image). Sample integer intensities
# over the full 0..255 range instead.
image = np.random.randint(0, 256, size=(175, 130, 3), dtype=np.uint8)

#### check timm

In [6]:
# Model names iterated over by the timm image check in the next cell.
model_names = ['resnet34', 'vit_base_patch16_224']

In [7]:
# Smoke test: build each timm model and embed one RGB image.
# The settings below are identical for every model, so they are set once
# instead of being reassigned on each loop iteration.
library = "timm"

# model params
pretrained_weights = True
fine_tune = True
embedding_size = 128
pool = None
data_type = ["rgb"]

for model_name in model_names:
    image_T = Image_Transforms(model_name=model_name,
                               library=library)

    model = Model(library=library,
                  pretrained_weights=pretrained_weights,
                  fine_tune=fine_tune,
                  embedding_size=embedding_size,
                  model_name=model_name,
                  pool=pool,
                  data_type=data_type)

    model = model.to(device)

    # Named `image_tensor` (not `input`) so the built-in input() is not shadowed.
    image_tensor = image_T.transform(image)

    # Only the output shape is inspected, so no autograd graph is needed.
    with torch.no_grad():
        out = model(image_tensor.unsqueeze(dim=0).to(device))
    print(out.shape)

rgb data type
timm model is used.
torch.Size([1, 128])
rgb data type
timm model is used.
torch.Size([1, 128])


#### check pytorch

In [8]:
# Smoke test: build the "pytorch"-library resnet34 and embed one RGB image.
library = "pytorch"
model_name = "resnet34"

# model params
pretrained_weights = True
fine_tune = True
embedding_size = 128
pool = None
data_type = ["rgb"]

image_T = Image_Transforms(model_name=model_name,
                           library=library)

model = Model(library=library,
              pretrained_weights=pretrained_weights,
              fine_tune=fine_tune,
              embedding_size=embedding_size,
              model_name=model_name,
              pool=pool,
              data_type=data_type)

model = model.to(device)

# Named `image_tensor` (not `input`) so the built-in input() is not shadowed.
image_tensor = image_T.transform(image)

# Only the output shape is inspected, so no autograd graph is needed.
with torch.no_grad():
    out = model(image_tensor.unsqueeze(dim=0).to(device))
print(out.shape)

rgb data type
pytorch model is used.
torch.Size([1, 128])


## Check Dataset

### check wav

In [24]:
# dataset
annotations_file = "/workdir/github/annotations_file_short_SF.csv"
dataset_dir = "/workdir/sf_pv/data_v2"
train_type = 'train'
data_type = ['wav']

# model
library = "huggingface"
model_name = "AST"
pretrained_weights = True
fine_tune = True
embedding_size = 128
pool = None

# A transform stays None when its modality is not listed in `data_type`.
audio_T = None
image_T = None

if 'wav' in data_type:
    # audio transform params
    sample_rate = 16000
    sample_duration = 2  # seconds
    n_fft = 512  # from Korean code
    win_length = 400
    hop_length = 160
    window_fn = torch.hamming_window
    n_mels = 40

    # Bind the name directly to the transform callable, so `audio_T` never
    # holds two different kinds of object within this cell.
    audio_T = Audio_Transforms(sample_rate=sample_rate,
                               sample_duration=sample_duration,  # seconds
                               n_fft=n_fft,  # from Korean code
                               win_length=win_length,
                               hop_length=hop_length,
                               # Pass the configured variable so that editing
                               # `window_fn` above actually takes effect.
                               window_fn=window_fn,
                               n_mels=n_mels,
                               model_name=model_name,
                               library=library).transform

if 'rgb' in data_type or 'thr' in data_type:
    image_T = Image_Transforms(model_name=model_name,
                               library=library).transform

In [25]:
# Dataset
# The split comes from `train_type` (set in the configuration cell above)
# rather than a repeated 'train' literal, so the config value is the single
# source of truth.
train_dataset = SpeakingFacesDataset(annotations_file,
                                     dataset_dir,
                                     train_type,
                                     data_type,
                                     image_transform=image_T, 
                                     audio_transform=audio_T)

In [26]:
# Fetch the first item of the dataset; it unpacks into a (data, label) pair.
data, label = train_dataset[0]

In [27]:
# Show the shape of the sample returned by the dataset.
data.shape

torch.Size([1024, 128])

In [28]:
# Build the model from the current configuration, move it to the target
# device, and embed the single dataset sample as a batch of one.
model_kwargs = dict(
    library=library,
    pretrained_weights=pretrained_weights,
    fine_tune=fine_tune,
    embedding_size=embedding_size,
    model_name=model_name,
    pool=pool,
    data_type=data_type,
)
model = Model(**model_kwargs).to(device)

out = model(data.unsqueeze(dim=0).to(device))
print(out.shape)

wav data type
HuggingFace model is used.
AST model is used.
torch.Size([1, 128])


### check image

In [29]:
# dataset
annotations_file = "/workdir/github/annotations_file_short_SF.csv"
dataset_dir = "/workdir/sf_pv/data_v2"
train_type = 'train'
data_type = ['rgb']

# model
library = "timm"
model_name = "resnet34"
pretrained_weights = True
fine_tune = True
embedding_size = 128
pool = None

# A transform stays None when its modality is not listed in `data_type`.
audio_T = None
image_T = None

if 'wav' in data_type:
    # audio transform params
    sample_rate = 16000
    sample_duration = 2  # seconds
    n_fft = 512  # from Korean code
    win_length = 400
    hop_length = 160
    window_fn = torch.hamming_window
    n_mels = 40

    # Bind the name directly to the transform callable, so `audio_T` never
    # holds two different kinds of object within this cell.
    audio_T = Audio_Transforms(sample_rate=sample_rate,
                               sample_duration=sample_duration,  # seconds
                               n_fft=n_fft,  # from Korean code
                               win_length=win_length,
                               hop_length=hop_length,
                               # Pass the configured variable so that editing
                               # `window_fn` above actually takes effect.
                               window_fn=window_fn,
                               n_mels=n_mels,
                               model_name=model_name,
                               library=library).transform

if 'rgb' in data_type or 'thr' in data_type:
    image_T = Image_Transforms(model_name=model_name,
                               library=library).transform

In [30]:
# Dataset
# The split comes from `train_type` (set in the configuration cell above)
# rather than a repeated 'train' literal, so the config value is the single
# source of truth.
train_dataset = SpeakingFacesDataset(annotations_file,
                                     dataset_dir,
                                     train_type,
                                     data_type,
                                     image_transform=image_T, 
                                     audio_transform=audio_T)

In [31]:
# Fetch the first item of the dataset; it unpacks into a (data, label) pair.
data, label = train_dataset[0]

In [32]:
# Show the shape of the sample returned by the dataset.
data.shape

torch.Size([3, 128, 128])

In [33]:
# Build the model from the current configuration, move it to the target
# device, and embed the single dataset sample as a batch of one.
model_kwargs = dict(
    library=library,
    pretrained_weights=pretrained_weights,
    fine_tune=fine_tune,
    embedding_size=embedding_size,
    model_name=model_name,
    pool=pool,
    data_type=data_type,
)
model = Model(**model_kwargs).to(device)

out = model(data.unsqueeze(dim=0).to(device))
print(out.shape)

rgb data type
timm model is used.
torch.Size([1, 128])


## Check Dataset + Dataloader

In [54]:
# --- dataset -----------------------------------------------------------
annotations_file = "/workdir/github/annotations_file_short_SF.csv"
path2datasets = "/workdir/sf_pv"
dataset_dir = f"{path2datasets}/data_v2"
train_type = 'train'
data_type = ['rgb']

# --- model -------------------------------------------------------------
library = "timm"
model_name = "vit_base_patch16_224"
pretrained_weights = True
fine_tune = True
embedding_size = 128
pool = None

# --- transforms (filled in later for the modalities in `data_type`) -----
audio_T = None
image_T = None

# --- sampler -----------------------------------------------------------
n_batch = 10
n_ways = 2
n_support = 1
n_query = 1

In [55]:
# Build the transform(s) for the configured modalities, then the model.
if 'wav' in data_type:
    # audio transform params
    sample_rate = 16000
    sample_duration = 2  # seconds
    n_fft = 512  # from Korean code
    win_length = 400
    hop_length = 160
    window_fn = torch.hamming_window
    n_mels = 40

    # Bind the name directly to the transform callable, so `audio_T` never
    # holds two different kinds of object within this cell.
    audio_T = Audio_Transforms(sample_rate=sample_rate,
                               sample_duration=sample_duration,  # seconds
                               n_fft=n_fft,  # from Korean code
                               win_length=win_length,
                               hop_length=hop_length,
                               # Pass the configured variable so that editing
                               # `window_fn` above actually takes effect.
                               window_fn=window_fn,
                               n_mels=n_mels,
                               model_name=model_name,
                               library=library).transform

if 'rgb' in data_type or 'thr' in data_type:
    image_T = Image_Transforms(model_name=model_name,
                               library=library).transform

model = Model(library=library,
              pretrained_weights=pretrained_weights,
              fine_tune=fine_tune,
              embedding_size=embedding_size,
              model_name=model_name,
              pool=pool,
              data_type=data_type)

model = model.to(device)

rgb data type
timm model is used.


In [56]:
# Dataset
# The split comes from `train_type` (set in the configuration cell above)
# rather than a repeated 'train' literal, so the config value is the single
# source of truth.
train_dataset = SpeakingFacesDataset(annotations_file, dataset_dir, train_type,
                                     image_transform=image_T,
                                     audio_transform=audio_T,
                                     data_type=data_type)
# sampler
train_sampler = ProtoSampler(train_dataset.labels,
                             n_batch,
                             n_ways,  # n_way
                             n_support,  # n_shots
                             n_query)
# dataloader: the sampler yields whole batches of indices, hence batch_sampler
train_dataloader = DataLoader(dataset=train_dataset,
                              batch_sampler=train_sampler)

In [57]:
# Switch to training mode and push only the first batch from the train
# loader through the model, printing the shape of the resulting embeddings.
model.train()
for batch in train_dataloader:
    data_type = sorted(data_type)

    if len(data_type) != 1:
        # Only the single-modality case is handled by this check.
        break

    data, label = batch
    data = data.to(device)
    out = model(data)
    print(out.shape)
    break

torch.Size([4, 128])


In [58]:
# Validation dataset for the 'valid' split, sharing the transforms and
# modality list configured above.
valid_dataset = ValidDataset(
    path2datasets,
    'valid',
    data_type=data_type,
    audio_transform=audio_T,
    image_transform=image_T,
)

In [59]:
# Shuffled validation loader with a batch size of 64.
valid_dataloader = DataLoader(valid_dataset, batch_size=64, shuffle=True)

In [60]:
# Switch to evaluation mode and take only the first validation batch,
# leaving both sides of it on the target device for the next cell.
model.eval()
for batch in valid_dataloader:
    data_type = sorted(data_type)

    # Each batch unpacks into two entries plus the labels.
    id1, id2, labels = batch

    if len(data_type) == 1:
        # Keep the first element of each entry; the second is discarded.
        (data_id1, _), (data_id2, _) = id1, id2

        data_id1 = data_id1.to(device)
        data_id2 = data_id2.to(device)

    break

In [61]:
# Embed both sides of the validation batch with gradient tracking disabled.
with torch.no_grad():
    id1_out, id2_out = model(data_id1), model(data_id2)

## Check Train module

In [7]:
# --- dataset -----------------------------------------------------------
annotations_file = "/workdir/github/annotations_file_short_SF.csv"
path2datasets = "/workdir/sf_pv"
dataset_dir = f"{path2datasets}/data_v2"
data_type = ['wav']

# --- model -------------------------------------------------------------
library = "huggingface"
model_name = "AST"
pretrained_weights = True
fine_tune = True
embedding_size = 128
pool = None

# --- transforms (filled in later for the modalities in `data_type`) -----
audio_T = None
image_T = None

# --- sampler -----------------------------------------------------------
n_batch = 10
n_ways = 2
n_support = 1
n_query = 1

# --- loss --------------------------------------------------------------
dist_type = 'squared_euclidean'

# --- train -------------------------------------------------------------
num_epochs = 1
save_dir = '/workdir/results'
exp_name = 'chern'
wandb = None

In [8]:
# Build the transform(s) for the configured modalities, then the model.
if 'wav' in data_type:
    # audio transform params
    sample_rate = 16000
    sample_duration = 2  # seconds
    n_fft = 512  # from Korean code
    win_length = 400
    hop_length = 160
    window_fn = torch.hamming_window
    n_mels = 40

    # Bind the name directly to the transform callable, so `audio_T` never
    # holds two different kinds of object within this cell.
    audio_T = Audio_Transforms(sample_rate=sample_rate,
                               sample_duration=sample_duration,  # seconds
                               n_fft=n_fft,  # from Korean code
                               win_length=win_length,
                               hop_length=hop_length,
                               # Pass the configured variable so that editing
                               # `window_fn` above actually takes effect.
                               window_fn=window_fn,
                               n_mels=n_mels,
                               model_name=model_name,
                               library=library).transform

if 'rgb' in data_type or 'thr' in data_type:
    image_T = Image_Transforms(model_name=model_name,
                               library=library).transform

model = Model(library=library,
              pretrained_weights=pretrained_weights,
              fine_tune=fine_tune,
              embedding_size=embedding_size,
              model_name=model_name,
              pool=pool,
              data_type=data_type)

model = model.to(device)

wav data type
HuggingFace model is used.
AST model is used.


In [9]:
# Arguments shared by the train and validation datasets.
shared_dataset_kwargs = dict(
    image_transform=image_T,
    audio_transform=audio_T,
    data_type=data_type,
)

# Datasets
train_dataset = SpeakingFacesDataset(annotations_file, dataset_dir, 'train',
                                     **shared_dataset_kwargs)
valid_dataset = ValidDataset(path2datasets, 'valid',
                             **shared_dataset_kwargs)

# Sampler: built from the train labels and the sampler settings.
train_sampler = ProtoSampler(
    train_dataset.labels,
    n_batch,
    n_ways,     # n_way
    n_support,  # n_shots
    n_query,
)

# Dataloaders: the train loader is driven by the batch sampler, while the
# validation loader uses shuffled fixed-size batches.
train_dataloader = DataLoader(train_dataset, batch_sampler=train_sampler)
valid_dataloader = DataLoader(valid_dataset, batch_size=64, shuffle=True)

In [10]:
# Loss, created with the configured distance type and moved to the device.
criterion = PrototypicalLoss(dist_type=dist_type).to(device)

# Optimizer: AdamW over all model parameters, with weight decay set to 0.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=0,
)

# Scheduler: multiply the learning rate by 0.95 every 10 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=10,
    gamma=0.95,
)

In [11]:
# Collect every argument of the training run in one place, then launch it.
# train_model's return value replaces `model`.
train_kwargs = dict(
    model=model,
    train_dataloader=train_dataloader,
    valid_dataloader=valid_dataloader,
    train_sampler=train_sampler,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    num_epochs=num_epochs,
    save_dir=save_dir,
    exp_name=exp_name,
    data_type=data_type,
    wandb=wandb,
)
model = train_model(**train_kwargs)

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]


Average train loss: 0.7720251381397247
Average train accuracy: 70.0


