In [1]:
import torch

from Modules import LoadingModule
from Modules import Features_encoder
from Modules import quantizationModule
from Modules import wav2vec_transformer
from Modules import ContrastiveLoss

from Modules import TempLibriSpeech

In [2]:
"""

#data loader module init
StandardScalerTransform = LoadingModule.StandardScalerTransform
LargeDataModule = LoadingModule.LargeDataModule("./data/Librispeech", batch_size=16, num_workers=1, transform=StandardScalerTransform)
"""

'\n\n#data loader module init\nStandardScalerTransform = LoadingModule.StandardScalerTransform\nLargeDataModule = LoadingModule.LargeDataModule("./data/Librispeech", batch_size=16, num_workers=1, transform=StandardScalerTransform)\n'

In [2]:
#Temp import dataloader ### rendre compatible PLightning quand on aura le GPU
# en attendant import manuel
from torch.utils.data import DataLoader



dataset = TempLibriSpeech.LibriSpeech(split="train-clean-100", target_length=48000, device='cuda')
data_loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)

In [4]:
for i, (audio, text) in enumerate(data_loader):
    print(f"Exemple {i+1}")
    print(f"Audio shape: {audio.shape}")
    print(f"Texte: {text}")
    print("-" * 50)
    if i == 1: 
        break

Exemple 1
Audio shape: torch.Size([8, 48000])
Texte: ('INTO THESE THE BLOOD OF NECESSITY FLOWS FROM THE HOLLOW VEIN INTO THE RIGHT AND FROM THE VENOUS ARTERY INTO THE LEFT BECAUSE THESE TWO VESSELS ARE ALWAYS FULL OF BLOOD AND THEIR ORIFICES WHICH ARE TURNED TOWARDS THE HEART CANNOT THEN BE CLOSED', 'OTHERS SINGING OTHERS PLAYING THE VARIOUS INSTRUMENTS ALREADY MENTIONED', 'STILL I WAS SORRY FOR THE POOR LITTLE OLD LADY I WISH SOMEHOW SHE COULD HAVE THAT HUNDRED DOLLARS IT WAS THE MAN WHO SAID THIS NOT THE COLLECTOR SO DO I REJOINED BILLY DOLEFULLY', "WE OUGHT TO BE ABLE TO GET A GOOD DINNER ZVERKOV OF COURSE WON'T PAY OF COURSE NOT SINCE WE ARE INVITING HIM SIMONOV DECIDED CAN YOU IMAGINE FERFITCHKIN INTERRUPTED HOTLY AND CONCEITEDLY", 'ONE CARRIAGE DELAYED SUFFICED TO PARALYZE THE WHOLE LINE THEN THEY SET OUT AGAIN ON THE MARCH THE WEDDING CARRIAGES WERE IN THE FILE PROCEEDING TOWARDS THE BASTILLE AND SKIRTING THE RIGHT SIDE OF THE BOULEVARD', 'AND FORGOT HER OWN GRIEF IN SOLACING TH

In [5]:
### Model dev ###

In [6]:
import torch
import torch.nn as nn

class Model_W2V(nn.Module):
    def __init__(self, embed_size, num_heads, dropout, forward_expansion, kernel_size, groups, d_model, num_layers, max_relative_position):

        #EAB
        self.batch_size = batch_size
        #self.seq_length = seq_length
        self.embed_size = embed_size
        self.mask_prob = 0.50
        self.mask_length = 1
        self.num_heads = num_heads
        self.dropout = dropout
        self.forward_expansion = forward_expansion
        self.kernel_size = kernel_size
        self.groups = groups
        self.d_model = d_model
        self.num_layers = num_layers

        self.num_codebooks = 2
        self.num_codes = 320
        
        self.code_dim = 256
        self.output_dim = 512
        self.temperature= 0.5

        self.max_relative_position = max_relative_position

        super(Model_W2V, self).__init__()

        

        self.FeaturesEncoder = Features_encoder.FeatureEncoder(input_channels=1, feature_dim=512) #1501 ?
        self.masking = wav2vec_transformer.MaskingWithLearnableEmbedding()
        # d_model, num_heads, dropout, forward_expansion):
        self.TranformerBlock = wav2vec_transformer.TransformerBlockW(self.d_model, self.num_heads, self.dropout, self.forward_expansion)   #(self.embed_size, self.num_heads, self.dropout, self.forward_expansion, self.kernel_size, self.groups, self.d_model, self.max_relative_position)
        self.quantization = quantizationModule.QuantizationModule(
            input_dim=512,  # Should match feature_dim from FeatureEncoder
            codebook_size=self.num_codes,
            num_codebooks=self.num_codebooks,
            output_dim=self.output_dim,
            temperature=self.temperature
        )
        self.LossItem = ContrastiveLoss.LossW2V(20,self.temperature)
#embed_size, num_heads, dropout, forward_expansion, kernel_size, groups, d_model):
    def forward(self, x):


       # print("ORIGINAL , ", x.shape)
        x = x.to(next(self.parameters()).device)
        x = x.unsqueeze(1)

        x = self.FeaturesEncoder(x)
        
       #
        
        
        # print("q",x.shape)
        
        quantized_repr = self.quantization(x)
        
        masked_reps, mask = self.masking(x, self.mask_prob, self.mask_length) #(self, x, mask_prob, mask_length)
        
        contextualized_reps = self.TranformerBlock(masked_reps, masked_reps, masked_reps, mask)
                                                # value, key, query, mask=None
        

        #print("Debug", contextualized_reps.shape, quantized_repr.shape, mask.shape)
        loss = self.LossItem.compute_loss(contextualized_reps, quantized_repr, mask, self.batch_size)
        # print("Context Representation shape:", contextualized_reps.shape)
        # print("Quantized Representation shape:", quantized_repr.shape)
        # print("Mask Indices shape:", masked_reps.shape)
        # mask = torch.tensor(mask)
        # print("Unique Mask Indices values:", mask.unique())

        
   # embed_size, num_heads, dropout, forward_expansion, kernel_size, groups,d_model
        
        return x, contextualized_reps, loss
    

In [8]:
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm

def train_model(model, dataset, epochs, learning_rate, device):

    dataloader = DataLoader(dataset, batch_size=model.batch_size, shuffle=True)
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    model.train()
    
    for epoch in range(epochs):
        
        epoch_loss = 0
        total_loss = 0.0

        num_batches = len(dataloader) - 1

        for batch_idx, (inputs, _) in enumerate(tqdm( data_loader, desc=f"Epoch {epoch+1}/{epochs}")):
            if batch_idx >= num_batches:
                break  # S'arrêter avant la dernière itération
                
            optimizer.zero_grad()
            
            _,_, loss = model(inputs)
            
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
            total_loss += loss.item()

            if batch_idx % 100 == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Step [{batch_idx+1}/{num_batches}], Loss: {loss.item():.4f}")

        avg_loss = total_loss / num_batches
        print(f"Epoch [{epoch+1}/{epochs}] Average Loss: {avg_loss:.4f}")


In [9]:
batch_size = 8
seq_length = 151
embed_size = 512
num_heads = 8
dropout = 0.1
forward_expansion = 4
kernel_size = 7
groups = 2
d_model = 512
num_layers = 12

max_relative_position=128
torch.autograd.set_detect_anomaly(True)

device = 'cuda'#torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Model_W2V(embed_size, num_heads, dropout, forward_expansion, kernel_size, groups, d_model, num_layers, max_relative_position).to(device)


train_model(model, dataset, epochs=1, learning_rate=1e-4, device=device)


Epoch 1/1:   0%|          | 1/3568 [00:00<49:42,  1.20it/s]

Epoch [1/1], Step [1/3567], Loss: 3.0481


Epoch 1/1:   3%|▎         | 101/3568 [00:57<32:34,  1.77it/s]

Epoch [1/1], Step [101/3567], Loss: 3.0419


Epoch 1/1:   6%|▌         | 201/3568 [01:53<31:35,  1.78it/s]

Epoch [1/1], Step [201/3567], Loss: 2.9518


Epoch 1/1:   8%|▊         | 301/3568 [02:51<31:11,  1.75it/s]

Epoch [1/1], Step [301/3567], Loss: 2.9485


Epoch 1/1:  11%|█         | 401/3568 [03:49<30:01,  1.76it/s]

Epoch [1/1], Step [401/3567], Loss: 2.9473


Epoch 1/1:  14%|█▍        | 501/3568 [04:45<28:47,  1.77it/s]

Epoch [1/1], Step [501/3567], Loss: 2.9472


Epoch 1/1:  17%|█▋        | 601/3568 [05:42<28:16,  1.75it/s]

Epoch [1/1], Step [601/3567], Loss: 2.9466


Epoch 1/1:  20%|█▉        | 701/3568 [06:38<27:19,  1.75it/s]

Epoch [1/1], Step [701/3567], Loss: 2.9458


Epoch 1/1:  22%|██▏       | 801/3568 [07:35<25:58,  1.78it/s]

Epoch [1/1], Step [801/3567], Loss: 2.9460


Epoch 1/1:  25%|██▌       | 901/3568 [08:32<25:08,  1.77it/s]

Epoch [1/1], Step [901/3567], Loss: 2.9457


Epoch 1/1:  28%|██▊       | 1001/3568 [09:29<25:04,  1.71it/s]

Epoch [1/1], Step [1001/3567], Loss: 2.9453


Epoch 1/1:  31%|███       | 1101/3568 [10:26<23:02,  1.78it/s]

Epoch [1/1], Step [1101/3567], Loss: 2.9458


Epoch 1/1:  34%|███▎      | 1201/3568 [11:23<22:04,  1.79it/s]

Epoch [1/1], Step [1201/3567], Loss: 2.9453


Epoch 1/1:  36%|███▋      | 1301/3568 [12:20<21:22,  1.77it/s]

Epoch [1/1], Step [1301/3567], Loss: 2.9452


Epoch 1/1:  39%|███▉      | 1401/3568 [13:16<20:18,  1.78it/s]

Epoch [1/1], Step [1401/3567], Loss: 2.9453


Epoch 1/1:  42%|████▏     | 1501/3568 [14:13<19:57,  1.73it/s]

Epoch [1/1], Step [1501/3567], Loss: 2.9452


Epoch 1/1:  45%|████▍     | 1601/3568 [15:10<18:31,  1.77it/s]

Epoch [1/1], Step [1601/3567], Loss: 2.9450


Epoch 1/1:  48%|████▊     | 1701/3568 [16:07<17:40,  1.76it/s]

Epoch [1/1], Step [1701/3567], Loss: 2.9450


Epoch 1/1:  50%|█████     | 1801/3568 [17:04<16:31,  1.78it/s]

Epoch [1/1], Step [1801/3567], Loss: 2.9452


Epoch 1/1:  53%|█████▎    | 1901/3568 [18:01<15:49,  1.76it/s]

Epoch [1/1], Step [1901/3567], Loss: 2.9450


Epoch 1/1:  56%|█████▌    | 2001/3568 [18:58<14:25,  1.81it/s]

Epoch [1/1], Step [2001/3567], Loss: 2.9451


Epoch 1/1:  59%|█████▉    | 2101/3568 [19:55<14:30,  1.68it/s]

Epoch [1/1], Step [2101/3567], Loss: 2.9450


Epoch 1/1:  62%|██████▏   | 2201/3568 [20:52<12:36,  1.81it/s]

Epoch [1/1], Step [2201/3567], Loss: 2.9448


Epoch 1/1:  64%|██████▍   | 2301/3568 [21:49<11:56,  1.77it/s]

Epoch [1/1], Step [2301/3567], Loss: 2.9449


Epoch 1/1:  67%|██████▋   | 2401/3568 [22:46<10:36,  1.83it/s]

Epoch [1/1], Step [2401/3567], Loss: 2.9450


Epoch 1/1:  70%|███████   | 2501/3568 [23:43<10:06,  1.76it/s]

Epoch [1/1], Step [2501/3567], Loss: 2.9450


Epoch 1/1:  73%|███████▎  | 2601/3568 [24:39<08:58,  1.79it/s]

Epoch [1/1], Step [2601/3567], Loss: 2.9448


Epoch 1/1:  76%|███████▌  | 2701/3568 [25:36<08:05,  1.79it/s]

Epoch [1/1], Step [2701/3567], Loss: 2.9448


Epoch 1/1:  79%|███████▊  | 2801/3568 [26:32<08:14,  1.55it/s]

Epoch [1/1], Step [2801/3567], Loss: 2.9448


Epoch 1/1:  81%|████████▏ | 2901/3568 [27:29<06:17,  1.77it/s]

Epoch [1/1], Step [2901/3567], Loss: 2.9448


Epoch 1/1:  84%|████████▍ | 3001/3568 [28:26<05:19,  1.77it/s]

Epoch [1/1], Step [3001/3567], Loss: 2.9449


Epoch 1/1:  87%|████████▋ | 3101/3568 [29:24<04:23,  1.77it/s]

Epoch [1/1], Step [3101/3567], Loss: 2.9449


Epoch 1/1:  90%|████████▉ | 3201/3568 [30:20<03:31,  1.73it/s]

Epoch [1/1], Step [3201/3567], Loss: 2.9448


Epoch 1/1:  93%|█████████▎| 3301/3568 [31:16<02:31,  1.76it/s]

Epoch [1/1], Step [3301/3567], Loss: 2.9448


Epoch 1/1:  95%|█████████▌| 3401/3568 [32:13<01:33,  1.78it/s]

Epoch [1/1], Step [3401/3567], Loss: 2.9448


Epoch 1/1:  98%|█████████▊| 3501/3568 [33:10<00:38,  1.72it/s]

Epoch [1/1], Step [3501/3567], Loss: 2.9447


Epoch 1/1: 100%|█████████▉| 3567/3568 [33:48<00:00,  1.76it/s]

Epoch [1/1] Average Loss: 2.9492





In [14]:
import matplotlib.pyplot as plt
plt.figure()
plt.plot(avg_loss)

NameError: name 'avg_loss' is not defined

<Figure size 640x480 with 0 Axes>

In [None]:
from sklearn.metrics import accuracy_score, f1_score

In [None]:
model.eval()

# Créez un DataLoader pour le jeu de données de test
dataset = TempLibriSpeech.LibriSpeech(split="test-clean", target_length=480000, device='cuda')
data_loader = DataLoader(dataset, batch_size=16, shuffle=False)

all_quantized_reps = []
all_contextualized_reps = []

# Calcul de la perte de reconstruction ou de quantification
total_loss = 0
with torch.no_grad():
    for batch_idx, batch in enumerate(data_loader):
        inputs, _ = batch  # Ici, _ signifie qu'il n'y a pas de labels
        
        inputs = inputs.to('cuda')  # Assurez-vous que les inputs sont sur le bon device

        # Passe avant
        quantized_repr, contextualized_reps, loss = model(inputs)

        # Ajout des représentations à la liste
        all_quantized_reps.append(quantized_repr.cpu().numpy())
        all_contextualized_reps.append(contextualized_reps.cpu().numpy())

        # Accumuler la perte
        total_loss += loss.item()

# Calcul de la moyenne de la perte sur l'ensemble du dataset
average_loss = total_loss / len(data_loader)

# Affichage des résultats
print(f"Average Loss: {average_loss:.4f}")


In [None]:
average_loss = total_loss / len(data_loader)

# Affichage des résultats
print(f"Average Loss: {average_loss:.4f}")

In [None]:
inputs

In [16]:
torch.cuda.empty_cache()