In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.nn as nn 

from tqdm import tqdm

from model.SpeechLP import SLP
from model.NeuralAudioCodec import NAC

from utils.Config import ConfigSLP, ConfigNAC
from utils.MLS import MLSDataset
from utils.Trainer import Trainer
from utils.Processing import Processing

from torch.utils.data import DataLoader

In [None]:
# Display the ConfigSLP settings before any dataset or model is built from them
# (what exactly is shown is defined by ConfigSLP.display).
ConfigSLP.display()

In [4]:
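# One-off preprocessing step, intentionally left disabled. Each call is given a
# split's "audio" folder and an "audio_clean" folder (presumably the source and
# the destination -- confirm in Processing.remove_metadata_from_audio_folder).
# Uncomment to run it again.
# Processing.remove_metadata_from_audio_folder(ConfigSLP.TRAIN_PATH+"/"+"audio", ConfigSLP.TRAIN_PATH+"/"+"audio_clean",)
# Processing.remove_metadata_from_audio_folder(ConfigSLP.TEST_PATH+"/"+"audio", ConfigSLP.TEST_PATH+"/"+"audio_clean",)
# Processing.remove_metadata_from_audio_folder(ConfigSLP.DEV_PATH+"/"+"audio", ConfigSLP.DEV_PATH+"/"+"audio_clean",)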

In [None]:
# Datasets for the SLP model, one per MLS split (train / dev / test).
# Only the training split receives `nb_samples` (presumably a cap on the number
# of examples -- confirm in MLSDataset); dev and test are built without it.
train_set = MLSDataset(
    data_dir=ConfigSLP.TRAIN_PATH,
    max_text_token_length=ConfigSLP.MAX_TOKEN_LENGTH,
    sampling_rate=ConfigSLP.SAMPLE_RATE,
    nb_samples=ConfigSLP.NB_SAMPLES,
)

val_set = MLSDataset(
    data_dir=ConfigSLP.DEV_PATH,
    max_text_token_length=ConfigSLP.MAX_TOKEN_LENGTH,
    sampling_rate=ConfigSLP.SAMPLE_RATE,
)

test_set = MLSDataset(
    data_dir=ConfigSLP.TEST_PATH,
    max_text_token_length=ConfigSLP.MAX_TOKEN_LENGTH,
    sampling_rate=ConfigSLP.SAMPLE_RATE,
)

# Only the training batches are shuffled. The validation and test loaders keep
# the dataset order so that evaluation runs are repeatable from one epoch (and
# one kernel restart) to the next.
train_loader = DataLoader(
    train_set,
    batch_size=ConfigSLP.BATCH_SIZE,
    shuffle=True,
    collate_fn=MLSDataset.collate_fn,
)
val_loader = DataLoader(
    val_set,
    batch_size=ConfigSLP.BATCH_SIZE,
    shuffle=False,
    collate_fn=MLSDataset.collate_fn,
)
test_loader = DataLoader(
    test_set,
    batch_size=ConfigSLP.BATCH_SIZE,
    shuffle=False,
    collate_fn=MLSDataset.collate_fn,
)

In [None]:
# Build the SLP model and place it on the configured device.
model_slp = SLP(
    ConfigSLP.NB_CLASSES,
    ConfigSLP.NHEAD,
    ConfigSLP.NUM_LAYERS,
).to(ConfigSLP.DEVICE)

# The loss is passed as an instance; the optimizer is passed as a class
# (not instantiated here), alongside a learning rate given to fit().
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW

# Wire everything into the Trainer and start training. The chained call is kept
# as the final expression of the cell, exactly as before.
trainer = Trainer()
(
    trainer
    .set_model(model_slp, name=ConfigSLP.MODEL_NAME)
    .set_criterion(criterion)
    .set_optimizer(optimizer)
    .fit(
        train_data=train_loader,
        validation_data=val_loader,
        epochs=ConfigSLP.EPOCHS,
        learning_rate=ConfigSLP.LEARNING_RATE,
        checkpoint_interval=1,
    )
)

In [None]:
# Datasets for the NAC model, one per MLS split, using the "gpt2" tokenizer.
# NOTE: these assignments rebind train_set / val_set / test_set and the three
# loaders that the SLP cells created earlier in this notebook; after this cell
# runs, those names refer to the NAC data.
# Only the training split receives `nb_samples` (presumably a cap on the number
# of examples -- confirm in MLSDataset).
train_set = MLSDataset(
    data_dir=ConfigNAC.TRAIN_PATH,
    max_text_token_length=ConfigNAC.MAX_TOKEN_LENGTH,
    sampling_rate=ConfigNAC.SAMPLE_RATE,
    nb_samples=ConfigNAC.NB_SAMPLES,
    tokenizer_model="gpt2",
)

val_set = MLSDataset(
    data_dir=ConfigNAC.DEV_PATH,
    max_text_token_length=ConfigNAC.MAX_TOKEN_LENGTH,
    sampling_rate=ConfigNAC.SAMPLE_RATE,
    tokenizer_model="gpt2",
)

test_set = MLSDataset(
    data_dir=ConfigNAC.TEST_PATH,
    max_text_token_length=ConfigNAC.MAX_TOKEN_LENGTH,
    sampling_rate=ConfigNAC.SAMPLE_RATE,
    tokenizer_model="gpt2",
)

# Only the training batches are shuffled. The validation and test loaders keep
# the dataset order so that evaluation is repeatable; the custom validation
# function in this notebook reports component losses from the last batch it
# sees, so a fixed order also makes those numbers comparable across epochs.
train_loader = DataLoader(
    train_set,
    batch_size=ConfigNAC.BATCH_SIZE,
    shuffle=True,
    collate_fn=MLSDataset.collate_fn,
)
val_loader = DataLoader(
    val_set,
    batch_size=ConfigNAC.BATCH_SIZE,
    shuffle=False,
    collate_fn=MLSDataset.collate_fn,
)
test_loader = DataLoader(
    test_set,
    batch_size=ConfigNAC.BATCH_SIZE,
    shuffle=False,
    collate_fn=MLSDataset.collate_fn,
)

In [None]:
# Build the NAC model and place it on the configured device.
model_nac = NAC(ConfigNAC.LAMBDA_FACTOR).to(ConfigNAC.DEVICE)

# Optimizer class (not an instance); it is handed to the Trainer as-is.
optimizer = torch.optim.AdamW


def train(self, train_loader):
    """Run one optimisation pass over ``train_loader``.

    Written to be plugged into a trainer object: ``self`` must expose
    ``model``, ``optimizer`` and ``device``.

    Each batch is a mapping with the keys ``"text"`` (itself holding
    ``"input_ids"`` and ``"attention_mask"``), ``"audio"`` and
    ``"padding_mask_audio"``. The model is called as
    ``model(text, audio, padding_mask_audio)`` and must return a mapping with
    ``"total_loss"``, ``"lm_loss"`` and ``"reconstruction_loss"``.

    Returns:
        A pair ``(mean_total_loss, components)`` where ``mean_total_loss`` is
        the sum of the per-batch ``total_loss`` values divided by
        ``len(train_loader)``, and ``components`` holds ``"lm_loss"`` and
        ``"reconstruction_loss"`` exactly as returned by the model for the
        LAST batch only (they are not averaged over the epoch).
    """
    self.model.train()
    running_loss = 0

    for batch in tqdm(train_loader):
        # The text tensors are moved to the device in place, inside the batch's
        # own "text" mapping, which is then passed to the model as a whole.
        tokens = batch["text"]
        for key in ("input_ids", "attention_mask"):
            tokens[key] = tokens[key].to(self.device)

        waveform = batch["audio"].to(self.device)
        audio_mask = batch["padding_mask_audio"].to(self.device)

        output = self.model(tokens, waveform, audio_mask)
        total = output["total_loss"]
        running_loss += total.item()

        # Standard update: clear stale gradients, backpropagate, step.
        self.optimizer.zero_grad()
        total.backward()
        self.optimizer.step()

    mean_loss = running_loss / len(train_loader)
    last_batch_components = {
        "lm_loss": output["lm_loss"],
        "reconstruction_loss": output["reconstruction_loss"],
    }
    return mean_loss, last_batch_components


def validation(self, validation_loader):
    """Evaluate the model over ``validation_loader`` without gradient tracking.

    Written to be plugged into a trainer object: ``self`` must expose
    ``model`` and ``device``.

    Each batch is a mapping with the keys ``"text"`` (itself holding
    ``"input_ids"`` and ``"attention_mask"``), ``"audio"`` and
    ``"padding_mask_audio"``. The model is called as
    ``model(text, audio, padding_mask_audio)`` and must return a mapping with
    ``"total_loss"``, ``"lm_loss"`` and ``"reconstruction_loss"``.

    Returns:
        A pair ``(mean_total_loss, components)`` where ``mean_total_loss`` is
        the sum of the per-batch ``total_loss`` values divided by
        ``len(validation_loader)``, and ``components`` holds ``"lm_loss"`` and
        ``"reconstruction_loss"`` exactly as returned by the model for the
        LAST batch only (they are not averaged over the loader).
    """
    self.model.eval()
    running_loss = 0

    with torch.no_grad():
        for batch in tqdm(validation_loader):
            # The text tensors are moved to the device in place, inside the
            # batch's own "text" mapping, which is passed to the model whole.
            tokens = batch["text"]
            for key in ("input_ids", "attention_mask"):
                tokens[key] = tokens[key].to(self.device)

            waveform = batch["audio"].to(self.device)
            audio_mask = batch["padding_mask_audio"].to(self.device)

            output = self.model(tokens, waveform, audio_mask)
            running_loss += output["total_loss"].item()

    mean_loss = running_loss / len(validation_loader)
    last_batch_components = {
        "lm_loss": output["lm_loss"],
        "reconstruction_loss": output["reconstruction_loss"],
    }
    return mean_loss, last_batch_components


In [None]:
# Train the NAC model using the custom `train` / `validation` functions defined
# in this notebook. Those functions take the loss from the model's own output
# (output["total_loss"]) and never call the criterion themselves.
trainer = Trainer()
(
    trainer
    .set_model(model_nac, name=ConfigNAC.MODEL_NAME)
    # Pass a loss *instance*, matching how the SLP trainer is given
    # nn.CrossEntropyLoss(); previously the bare class torch.nn.MSELoss was
    # passed here, which any code calling criterion(pred, target) would have
    # treated as a constructor call.
    .set_criterion(torch.nn.MSELoss())
    .set_optimizer(optimizer)
    .set_custom_functions(train_func=train, validation_func=validation)
    .fit(
        train_data=train_loader,
        validation_data=val_loader,
        epochs=ConfigNAC.EPOCHS,
        learning_rate=ConfigNAC.LEARNING_RATE,
        checkpoint_interval=1,
    )
)

: 