In [1]:
import torch

from Modules import LoadingModule
from Modules import Features_encoder
from Modules import quantizationModule
from Modules import wav2vec_transformer
from Modules import ContrastiveLoss

from Modules import TempLibriSpeech

In [2]:
"""

#data loader module init
StandardScalerTransform = LoadingModule.StandardScalerTransform
LargeDataModule = LoadingModule.LargeDataModule("./data/Librispeech", batch_size=16, num_workers=1, transform=StandardScalerTransform)
"""

'\n\n#data loader module init\nStandardScalerTransform = LoadingModule.StandardScalerTransform\nLargeDataModule = LoadingModule.LargeDataModule("./data/Librispeech", batch_size=16, num_workers=1, transform=StandardScalerTransform)\n'

In [3]:
#Temp import dataloader ### rendre compatible PLightning quand on aura le GPU
# en attendant import manuel
from torch.utils.data import DataLoader



dataset = TempLibriSpeech.LibriSpeech(split="train-clean-100", target_length=16000, device='cuda')
data_loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)

In [4]:
for i, (audio, text) in enumerate(data_loader):
    print(f"Exemple {i+1}")
    print(f"Audio shape: {audio.shape}")
    print(f"Texte: {text}")
    print("-" * 50)
    if i == 1: 
        break

Exemple 1
Audio shape: torch.Size([16, 16000])
Texte: ('THE FACTOR THAT UNDERLIES ALL THE PERPLEXITIES AND MOST OF THE CONTENTMENT OF MARRIAGE IS ITS UNIQUE DEGREE OF CONCENTRATED INTIMACY HERE THE SUPREME TESTING ALWAYS COMES EACH MEANS SO MUCH TO THE OTHER', "YOUR HONOUR REPLIED THE CORPORAL KNOWS OF TOM'S MISFORTUNES BUT THIS AFFAIR HAS NOTHING TO DO WITH THEM ANY FURTHER THAN THIS", "YES I SUPPOSE IT'S SO WELL SENATOR BALLOON PUT FIFTEEN CENTS WORTH OF STAMPS ON EACH OF THOSE SEVEN HUGE BOXES OF OLD CLOTHES AND SHIPPED THAT TON OF SECOND HAND RUBBISH OLD BOOTS AND PANTALOONS AND WHAT NOT THROUGH THE MAILS AS REGISTERED MATTER", 'I IMAGINED THAT THEY WOULD BE DISGUSTED UNTIL', "BUT PIERCED BY ONE ONLY THE MOTTO BEING SHE ALONE THE HEART WAS MADE OF A SINGLE RUBY AS BIG AS AN OSTRICH'S EGG", 'STAY COUNT HE ADDED YOU WHO MAY BE CALLED THE EMPEROR IF I CLAIM THE TITLE OF KING OF FINANCE HAVE YOU MANY PIECES OF PAPER OF THIS SIZE EACH WORTH A MILLION', 'MISTER WILMINGTON AND MISSUS MUNG

In [5]:
### Model dev ###

In [6]:
import torch
import torch.nn as nn

class Model_W2V(nn.Module):
    def __init__(self, embed_size, num_heads, dropout, forward_expansion, kernel_size, groups, d_model, num_layers, max_relative_position):

        #EAB
        self.batch_size = batch_size
        #self.seq_length = seq_length
        self.embed_size = embed_size
        self.mask_prob = 0.50
        self.mask_length = 10
        self.num_heads = num_heads
        self.dropout = dropout
        self.forward_expansion = forward_expansion
        self.kernel_size = kernel_size
        self.groups = groups
        self.d_model = d_model
        self.num_layers = num_layers

        self.num_codebooks = 2
        self.num_codes = 320
        
        self.code_dim = 256
        self.output_dim = 512
        self.temperature= 0.05

        self.max_relative_position = max_relative_position

        super(Model_W2V, self).__init__()

        

        self.FeaturesEncoder = Features_encoder.FeatureEncoder(input_channels=1, feature_dim=512) #1501 ?
        self.masking = wav2vec_transformer.MaskingWithLearnableEmbedding()
        # d_model, num_heads, dropout, forward_expansion):

        #embed_size, num_heads, dropout, forward_expansion,max_relative_position):
        self.TranformerBlock = wav2vec_transformer.TransformerBlockW(self.embed_size, self.num_heads, self.dropout, self.forward_expansion, self.max_relative_position)   #(self.embed_size, self.num_heads, self.dropout, self.forward_expansion, self.kernel_size, self.groups, self.d_model, self.max_relative_position)
        self.quantization = quantizationModule.QuantizationModule(
            input_dim=512,  # Should match feature_dim from FeatureEncoder
            codebook_size=self.num_codes,
            num_codebooks=self.num_codebooks,
            output_dim=self.output_dim,
            temperature=self.temperature
        )

        # (K , k temp, G codevectorgroup, Vcodevectorpergroup, a 0,05)
        
        self.LossItem = ContrastiveLoss.Wav2vec2Loss(K=100,k=self.temperature, G=self.num_codebooks, V = self.num_codes, a=0.05)
        

    
    def forward(self, x):


       # print("ORIGINAL , ", x.shape)
        x = x.to(next(self.parameters()).device)
        x = x.unsqueeze(1)

        x = self.FeaturesEncoder(x)
        
       #
        
        
        # print("q",x.shape)
        
        quantized_repr, diversity_loss = self.quantization(x)
        
        masked_reps, mask = self.masking(x, self.mask_prob, self.mask_length) #(self, x, mask_prob, mask_length)
        
        contextualized_reps = self.TranformerBlock(masked_reps, masked_reps, masked_reps, mask)
                                                # value, key, query, mask=None
        

        #print("Debug", contextualized_reps.shape, quantized_repr.shape, mask.shape)
        
        
        # context_repr, quantized_repr, perplexity, time_mask_indices):

        
        loss = self.LossItem(contextualized_reps, quantized_repr, diversity_loss, mask)
        
        # print("Context Representation shape:", contextualized_reps.shape)
        # print("Quantized Representation shape:", quantized_repr.shape)
        # print("Mask Indices shape:", masked_reps.shape)
        # mask = torch.tensor(mask)
        # print("Unique Mask Indices values:", mask.unique())

        
   # embed_size, num_heads, dropout, forward_expansion, kernel_size, groups,d_model
        
        return x, contextualized_reps, loss
    

In [7]:
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm


loss_123 = []

def train_model(model, dataset, epochs, learning_rate, device):

    dataloader = DataLoader(dataset, batch_size=model.batch_size, shuffle=True)
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    model.train()
    
    for epoch in range(epochs):
        
        epoch_loss = 0
        total_loss = 0.0

        num_batches = len(dataloader) - 1

        for batch_idx, (inputs, _) in enumerate(tqdm( data_loader, desc=f"Epoch {epoch+1}/{epochs}")):
            
            if batch_idx >= num_batches:
                print("break")
                break  # S'arrêter avant la dernière itération

            optimizer.zero_grad()
            
            _,_, loss = model(inputs)
            
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
            total_loss += loss.item()

            if batch_idx % 100 == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Step [{batch_idx+1}/{num_batches}], Loss: {loss.item():.4f}")
                loss_123.append(loss.item())
                
        avg_loss = total_loss / num_batches
        print(f"Epoch [{epoch+1}/{epochs}] Average Loss: {avg_loss:.4f}")

    return loss_123


In [None]:
batch_size = 16
seq_length = 151
embed_size = 512
num_heads = 8
dropout = 0.0
forward_expansion = 4
kernel_size = 7
groups = 2
d_model = 512
num_layers = 12

max_relative_position=128
torch.autograd.set_detect_anomaly(True)

device = 'cuda'#torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Model_W2V(embed_size, num_heads, dropout, forward_expansion, kernel_size, groups, embed_size, num_layers, max_relative_position).to(device)


loss_123 = train_model(model, dataset, epochs=7, learning_rate=5e-4, device=device)


Epoch 1/7:   0%|          | 1/1784 [00:03<1:50:40,  3.72s/it]

Epoch [1/7], Step [1/1783], Loss: 5.9177


Epoch 1/7:   6%|▌         | 101/1784 [02:03<32:46,  1.17s/it]

Epoch [1/7], Step [101/1783], Loss: 5.7112


Epoch 1/7:  11%|█▏        | 201/1784 [04:02<32:11,  1.22s/it]

Epoch [1/7], Step [201/1783], Loss: 5.7115


Epoch 1/7:  17%|█▋        | 301/1784 [06:01<29:41,  1.20s/it]

Epoch [1/7], Step [301/1783], Loss: 5.7094


Epoch 1/7:  22%|██▏       | 401/1784 [07:58<28:04,  1.22s/it]

Epoch [1/7], Step [401/1783], Loss: 5.7102


Epoch 1/7:  28%|██▊       | 501/1784 [09:57<25:44,  1.20s/it]

Epoch [1/7], Step [501/1783], Loss: 5.7145


Epoch 1/7:  34%|███▎      | 601/1784 [11:56<24:26,  1.24s/it]

Epoch [1/7], Step [601/1783], Loss: 5.7132


Epoch 1/7:  39%|███▉      | 701/1784 [13:55<21:41,  1.20s/it]

Epoch [1/7], Step [701/1783], Loss: 5.7066


Epoch 1/7:  45%|████▍     | 801/1784 [15:53<20:28,  1.25s/it]

Epoch [1/7], Step [801/1783], Loss: 5.7186


Epoch 1/7:  51%|█████     | 901/1784 [17:52<17:57,  1.22s/it]

Epoch [1/7], Step [901/1783], Loss: 5.7084


Epoch 1/7:  56%|█████▌    | 1001/1784 [19:50<15:31,  1.19s/it]

Epoch [1/7], Step [1001/1783], Loss: 5.7098


Epoch 1/7:  62%|██████▏   | 1101/1784 [21:49<13:31,  1.19s/it]

Epoch [1/7], Step [1101/1783], Loss: 5.7120


Epoch 1/7:  67%|██████▋   | 1201/1784 [23:50<11:35,  1.19s/it]

Epoch [1/7], Step [1201/1783], Loss: 5.7106


Epoch 1/7:  73%|███████▎  | 1301/1784 [25:50<09:39,  1.20s/it]

Epoch [1/7], Step [1301/1783], Loss: 5.7095


Epoch 1/7:  79%|███████▊  | 1401/1784 [27:50<07:29,  1.17s/it]

Epoch [1/7], Step [1401/1783], Loss: 5.7153


Epoch 1/7:  84%|████████▍ | 1501/1784 [29:48<05:16,  1.12s/it]

Epoch [1/7], Step [1501/1783], Loss: 5.7122


Epoch 1/7:  90%|████████▉ | 1601/1784 [31:44<03:33,  1.17s/it]

Epoch [1/7], Step [1601/1783], Loss: 5.7124


Epoch 1/7:  95%|█████████▌| 1701/1784 [33:44<01:39,  1.20s/it]

Epoch [1/7], Step [1701/1783], Loss: 5.7151


Epoch 1/7: 100%|█████████▉| 1783/1784 [35:24<00:01,  1.19s/it]


break
Epoch [1/7] Average Loss: 5.7145


Epoch 2/7:   0%|          | 1/1784 [00:01<35:36,  1.20s/it]

Epoch [2/7], Step [1/1783], Loss: 5.7157


Epoch 2/7:   6%|▌         | 101/1784 [01:59<33:50,  1.21s/it]

Epoch [2/7], Step [101/1783], Loss: 5.7162


Epoch 2/7:  11%|█▏        | 201/1784 [03:58<32:30,  1.23s/it]

Epoch [2/7], Step [201/1783], Loss: 5.7072


Epoch 2/7:  17%|█▋        | 301/1784 [06:00<29:37,  1.20s/it]

Epoch [2/7], Step [301/1783], Loss: 5.7144


Epoch 2/7:  22%|██▏       | 401/1784 [08:01<27:28,  1.19s/it]

Epoch [2/7], Step [401/1783], Loss: 5.7176


Epoch 2/7:  28%|██▊       | 501/1784 [09:59<25:06,  1.17s/it]

Epoch [2/7], Step [501/1783], Loss: 5.7099


Epoch 2/7:  34%|███▎      | 601/1784 [12:00<24:10,  1.23s/it]

Epoch [2/7], Step [601/1783], Loss: 5.7099


Epoch 2/7:  39%|███▉      | 701/1784 [14:00<21:52,  1.21s/it]

Epoch [2/7], Step [701/1783], Loss: 5.7151


Epoch 2/7:  45%|████▍     | 801/1784 [15:59<18:51,  1.15s/it]

Epoch [2/7], Step [801/1783], Loss: 5.7090


Epoch 2/7:  51%|█████     | 901/1784 [17:58<17:31,  1.19s/it]

Epoch [2/7], Step [901/1783], Loss: 5.7078


Epoch 2/7:  56%|█████▌    | 1001/1784 [19:58<15:30,  1.19s/it]

Epoch [2/7], Step [1001/1783], Loss: 5.7120


Epoch 2/7:  62%|██████▏   | 1101/1784 [22:00<13:48,  1.21s/it]

Epoch [2/7], Step [1101/1783], Loss: 5.7142


Epoch 2/7:  67%|██████▋   | 1201/1784 [24:01<12:14,  1.26s/it]

Epoch [2/7], Step [1201/1783], Loss: 5.7046


Epoch 2/7:  73%|███████▎  | 1301/1784 [26:02<09:27,  1.17s/it]

Epoch [2/7], Step [1301/1783], Loss: 5.7082


Epoch 2/7:  79%|███████▊  | 1401/1784 [28:02<07:35,  1.19s/it]

Epoch [2/7], Step [1401/1783], Loss: 5.7151


Epoch 2/7:  84%|████████▍ | 1501/1784 [30:01<05:40,  1.20s/it]

Epoch [2/7], Step [1501/1783], Loss: 5.7160


Epoch 2/7:  90%|████████▉ | 1601/1784 [32:00<03:24,  1.12s/it]

Epoch [2/7], Step [1601/1783], Loss: 5.7147


Epoch 2/7:  95%|█████████▌| 1701/1784 [34:00<01:36,  1.16s/it]

Epoch [2/7], Step [1701/1783], Loss: 5.7162


Epoch 2/7: 100%|█████████▉| 1783/1784 [35:37<00:01,  1.20s/it]


break
Epoch [2/7] Average Loss: 5.7121


Epoch 3/7:   0%|          | 1/1784 [00:01<36:57,  1.24s/it]

Epoch [3/7], Step [1/1783], Loss: 5.7163


Epoch 3/7:   6%|▌         | 101/1784 [02:01<33:39,  1.20s/it]

Epoch [3/7], Step [101/1783], Loss: 5.7118


Epoch 3/7:  11%|█▏        | 201/1784 [04:00<31:59,  1.21s/it]

Epoch [3/7], Step [201/1783], Loss: 5.7163


Epoch 3/7:  17%|█▋        | 301/1784 [06:01<29:03,  1.18s/it]

Epoch [3/7], Step [301/1783], Loss: 5.7121


Epoch 3/7:  22%|██▏       | 401/1784 [08:03<26:48,  1.16s/it]

Epoch [3/7], Step [401/1783], Loss: 5.7102


Epoch 3/7:  28%|██▊       | 501/1784 [10:04<26:19,  1.23s/it]

Epoch [3/7], Step [501/1783], Loss: 5.7119


Epoch 3/7:  34%|███▎      | 601/1784 [12:06<24:06,  1.22s/it]

Epoch [3/7], Step [601/1783], Loss: 5.7136


Epoch 3/7:  39%|███▉      | 701/1784 [14:07<22:33,  1.25s/it]

Epoch [3/7], Step [701/1783], Loss: 5.7111


Epoch 3/7:  45%|████▍     | 801/1784 [16:08<19:11,  1.17s/it]

Epoch [3/7], Step [801/1783], Loss: 5.7150


Epoch 3/7:  51%|█████     | 901/1784 [18:07<17:59,  1.22s/it]

Epoch [3/7], Step [901/1783], Loss: 5.7135


Epoch 3/7:  56%|█████▌    | 1001/1784 [20:08<15:29,  1.19s/it]

Epoch [3/7], Step [1001/1783], Loss: 5.7149


Epoch 3/7:  62%|██████▏   | 1101/1784 [22:08<13:54,  1.22s/it]

Epoch [3/7], Step [1101/1783], Loss: 5.7147


Epoch 3/7:  67%|██████▋   | 1201/1784 [24:08<11:37,  1.20s/it]

Epoch [3/7], Step [1201/1783], Loss: 5.7066


Epoch 3/7:  73%|███████▎  | 1301/1784 [26:08<09:36,  1.19s/it]

Epoch [3/7], Step [1301/1783], Loss: 5.7112


Epoch 3/7:  79%|███████▊  | 1401/1784 [28:10<07:46,  1.22s/it]

Epoch [3/7], Step [1401/1783], Loss: 5.7115


Epoch 3/7:  84%|████████▍ | 1501/1784 [30:11<05:59,  1.27s/it]

Epoch [3/7], Step [1501/1783], Loss: 5.7147


Epoch 3/7:  90%|████████▉ | 1601/1784 [32:12<03:51,  1.27s/it]

Epoch [3/7], Step [1601/1783], Loss: 5.7116


Epoch 3/7:  95%|█████████▌| 1701/1784 [34:13<01:47,  1.29s/it]

Epoch [3/7], Step [1701/1783], Loss: 5.7100


Epoch 3/7: 100%|█████████▉| 1783/1784 [35:51<00:01,  1.21s/it]


break
Epoch [3/7] Average Loss: 5.7120


Epoch 4/7:   0%|          | 1/1784 [00:01<36:25,  1.23s/it]

Epoch [4/7], Step [1/1783], Loss: 5.7142


Epoch 4/7:   6%|▌         | 101/1784 [02:02<35:15,  1.26s/it]

Epoch [4/7], Step [101/1783], Loss: 5.7109


Epoch 4/7:  11%|█▏        | 201/1784 [04:03<31:08,  1.18s/it]

Epoch [4/7], Step [201/1783], Loss: 5.7069


Epoch 4/7:  17%|█▋        | 301/1784 [06:03<29:30,  1.19s/it]

Epoch [4/7], Step [301/1783], Loss: 5.7148


Epoch 4/7:  22%|██▏       | 401/1784 [08:04<27:57,  1.21s/it]

Epoch [4/7], Step [401/1783], Loss: 5.7114


Epoch 4/7:  28%|██▊       | 501/1784 [10:14<26:18,  1.23s/it]

Epoch [4/7], Step [501/1783], Loss: 5.7095


Epoch 4/7:  34%|███▎      | 601/1784 [12:14<24:08,  1.22s/it]

Epoch [4/7], Step [601/1783], Loss: 5.7121


Epoch 4/7:  39%|███▉      | 701/1784 [14:16<22:44,  1.26s/it]

Epoch [4/7], Step [701/1783], Loss: 5.7142


Epoch 4/7:  45%|████▍     | 801/1784 [16:18<19:27,  1.19s/it]

Epoch [4/7], Step [801/1783], Loss: 5.7122


Epoch 4/7:  51%|█████     | 901/1784 [18:18<17:31,  1.19s/it]

Epoch [4/7], Step [901/1783], Loss: 5.7140


Epoch 4/7:  56%|█████▌    | 1001/1784 [20:19<15:25,  1.18s/it]

Epoch [4/7], Step [1001/1783], Loss: 5.7059


Epoch 4/7:  62%|██████▏   | 1101/1784 [22:20<13:26,  1.18s/it]

Epoch [4/7], Step [1101/1783], Loss: 5.7076


Epoch 4/7:  67%|██████▋   | 1201/1784 [24:21<11:45,  1.21s/it]

Epoch [4/7], Step [1201/1783], Loss: 5.7111


Epoch 4/7:  73%|███████▎  | 1301/1784 [26:20<09:56,  1.23s/it]

Epoch [4/7], Step [1301/1783], Loss: 5.7184


Epoch 4/7:  79%|███████▊  | 1401/1784 [28:21<07:40,  1.20s/it]

Epoch [4/7], Step [1401/1783], Loss: 5.7107


Epoch 4/7:  84%|████████▍ | 1501/1784 [30:23<05:45,  1.22s/it]

Epoch [4/7], Step [1501/1783], Loss: 5.7134


Epoch 4/7:  90%|████████▉ | 1601/1784 [32:23<03:42,  1.22s/it]

Epoch [4/7], Step [1601/1783], Loss: 5.7096


Epoch 4/7:  95%|█████████▌| 1701/1784 [34:24<01:46,  1.28s/it]

Epoch [4/7], Step [1701/1783], Loss: 5.7133


Epoch 4/7: 100%|█████████▉| 1783/1784 [36:04<00:01,  1.21s/it]


break
Epoch [4/7] Average Loss: 5.7120


Epoch 5/7:   0%|          | 1/1784 [00:01<35:25,  1.19s/it]

Epoch [5/7], Step [1/1783], Loss: 5.7101


Epoch 5/7:   6%|▌         | 101/1784 [02:02<34:56,  1.25s/it]

Epoch [5/7], Step [101/1783], Loss: 5.7140


Epoch 5/7:  11%|█▏        | 201/1784 [04:02<31:25,  1.19s/it]

Epoch [5/7], Step [201/1783], Loss: 5.7127


Epoch 5/7:  17%|█▋        | 301/1784 [06:03<28:31,  1.15s/it]

Epoch [5/7], Step [301/1783], Loss: 5.7085


Epoch 5/7:  22%|██▏       | 401/1784 [08:03<27:16,  1.18s/it]

Epoch [5/7], Step [401/1783], Loss: 5.7079


Epoch 5/7:  28%|██▊       | 501/1784 [10:06<27:02,  1.26s/it]

Epoch [5/7], Step [501/1783], Loss: 5.7114


Epoch 5/7:  34%|███▎      | 601/1784 [12:08<23:57,  1.22s/it]

Epoch [5/7], Step [601/1783], Loss: 5.7133


Epoch 5/7:  39%|███▉      | 701/1784 [14:08<22:21,  1.24s/it]

Epoch [5/7], Step [701/1783], Loss: 5.7070


Epoch 5/7:  45%|████▍     | 801/1784 [16:12<20:33,  1.26s/it]

Epoch [5/7], Step [801/1783], Loss: 5.7102


Epoch 5/7:  51%|█████     | 901/1784 [18:14<17:03,  1.16s/it]

Epoch [5/7], Step [901/1783], Loss: 5.7091


Epoch 5/7:  56%|█████▌    | 1001/1784 [20:14<15:15,  1.17s/it]

Epoch [5/7], Step [1001/1783], Loss: 5.7054


Epoch 5/7:  62%|██████▏   | 1101/1784 [22:14<13:45,  1.21s/it]

Epoch [5/7], Step [1101/1783], Loss: 5.7164


Epoch 5/7:  67%|██████▋   | 1201/1784 [24:16<11:51,  1.22s/it]

Epoch [5/7], Step [1201/1783], Loss: 5.7098


Epoch 5/7:  73%|███████▎  | 1301/1784 [26:16<09:50,  1.22s/it]

Epoch [5/7], Step [1301/1783], Loss: 5.7184


Epoch 5/7:  79%|███████▊  | 1401/1784 [28:17<07:56,  1.24s/it]

Epoch [5/7], Step [1401/1783], Loss: 5.7110


Epoch 5/7:  84%|████████▍ | 1501/1784 [30:18<05:31,  1.17s/it]

Epoch [5/7], Step [1501/1783], Loss: 5.7057


Epoch 5/7:  90%|████████▉ | 1601/1784 [32:21<03:44,  1.23s/it]

Epoch [5/7], Step [1601/1783], Loss: 5.7129


Epoch 5/7:  95%|█████████▌| 1701/1784 [34:22<01:34,  1.14s/it]

Epoch [5/7], Step [1701/1783], Loss: 5.7083


Epoch 5/7: 100%|█████████▉| 1783/1784 [36:02<00:01,  1.21s/it]


break
Epoch [5/7] Average Loss: 5.7118


Epoch 6/7:   0%|          | 1/1784 [00:01<35:29,  1.19s/it]

Epoch [6/7], Step [1/1783], Loss: 5.7078


Epoch 6/7:   6%|▌         | 101/1784 [02:02<34:38,  1.24s/it]

Epoch [6/7], Step [101/1783], Loss: 5.7125


Epoch 6/7:  11%|█▏        | 201/1784 [04:05<32:05,  1.22s/it]

Epoch [6/7], Step [201/1783], Loss: 5.7141


Epoch 6/7:  17%|█▋        | 301/1784 [06:09<30:22,  1.23s/it]

Epoch [6/7], Step [301/1783], Loss: 5.7093


Epoch 6/7:  22%|██▏       | 401/1784 [08:11<27:36,  1.20s/it]

Epoch [6/7], Step [401/1783], Loss: 5.7106


Epoch 6/7:  28%|██▊       | 501/1784 [10:14<26:12,  1.23s/it]

Epoch [6/7], Step [501/1783], Loss: 5.7130


Epoch 6/7:  34%|███▎      | 601/1784 [12:17<25:37,  1.30s/it]

Epoch [6/7], Step [601/1783], Loss: 5.7094


Epoch 6/7:  39%|███▉      | 701/1784 [14:19<21:38,  1.20s/it]

Epoch [6/7], Step [701/1783], Loss: 5.7119


Epoch 6/7:  45%|████▍     | 801/1784 [16:19<19:53,  1.21s/it]

Epoch [6/7], Step [801/1783], Loss: 5.7096


Epoch 6/7:  51%|█████     | 901/1784 [18:22<18:43,  1.27s/it]

Epoch [6/7], Step [901/1783], Loss: 5.7150


Epoch 6/7:  56%|█████▌    | 1001/1784 [20:24<15:56,  1.22s/it]

Epoch [6/7], Step [1001/1783], Loss: 5.7050


Epoch 6/7:  62%|██████▏   | 1101/1784 [22:26<13:53,  1.22s/it]

Epoch [6/7], Step [1101/1783], Loss: 5.7167


Epoch 6/7:  67%|██████▋   | 1201/1784 [24:28<11:53,  1.22s/it]

Epoch [6/7], Step [1201/1783], Loss: 5.7179


Epoch 6/7:  73%|███████▎  | 1301/1784 [26:28<09:41,  1.20s/it]

Epoch [6/7], Step [1301/1783], Loss: 5.7127


Epoch 6/7:  79%|███████▊  | 1401/1784 [28:31<07:59,  1.25s/it]

Epoch [6/7], Step [1401/1783], Loss: 5.7161


Epoch 6/7:  84%|████████▍ | 1501/1784 [30:32<05:43,  1.22s/it]

Epoch [6/7], Step [1501/1783], Loss: 5.7113


Epoch 6/7:  90%|████████▉ | 1601/1784 [32:34<03:48,  1.25s/it]

Epoch [6/7], Step [1601/1783], Loss: 5.7095


Epoch 6/7:  95%|█████████▌| 1701/1784 [34:39<01:39,  1.20s/it]

Epoch [6/7], Step [1701/1783], Loss: 5.7101


Epoch 6/7: 100%|█████████▉| 1783/1784 [36:19<00:01,  1.22s/it]


break
Epoch [6/7] Average Loss: 5.7117


Epoch 7/7:   0%|          | 1/1784 [00:01<35:54,  1.21s/it]

Epoch [7/7], Step [1/1783], Loss: 5.7095


Epoch 7/7:   6%|▌         | 101/1784 [02:03<34:12,  1.22s/it]

Epoch [7/7], Step [101/1783], Loss: 5.7107


Epoch 7/7:  11%|█▏        | 201/1784 [04:05<31:40,  1.20s/it]

Epoch [7/7], Step [201/1783], Loss: 5.7138


Epoch 7/7:  17%|█▋        | 301/1784 [06:06<30:23,  1.23s/it]

Epoch [7/7], Step [301/1783], Loss: 5.7161


Epoch 7/7:  22%|██▏       | 401/1784 [08:07<27:00,  1.17s/it]

Epoch [7/7], Step [401/1783], Loss: 5.7149


Epoch 7/7:  26%|██▋       | 469/1784 [09:29<26:49,  1.22s/it]

In [None]:
import matplotlib.pyplot as plt
plt.figure()
plt.plot(loss_123)

In [None]:
1

In [None]:
from sklearn.metrics import accuracy_score, f1_score

In [27]:
len(dataset)

28539

In [None]:
model.eval()

# Créez un DataLoader pour le jeu de données de test
dataset = TempLibriSpeech.LibriSpeech(split="test-clean", target_length=480000, device='cuda')
data_loader = DataLoader(dataset, batch_size=16, shuffle=False)

all_quantized_reps = []
all_contextualized_reps = []

# Calcul de la perte de reconstruction ou de quantification
total_loss = 0
with torch.no_grad():
    for batch_idx, batch in enumerate(data_loader):
        inputs, _ = batch  # Ici, _ signifie qu'il n'y a pas de labels
        
        inputs = inputs.to('cuda')  # Assurez-vous que les inputs sont sur le bon device

        # Passe avant
        quantized_repr, contextualized_reps, loss = model(inputs)

        # Ajout des représentations à la liste
        all_quantized_reps.append(quantized_repr.cpu().numpy())
        all_contextualized_reps.append(contextualized_reps.cpu().numpy())

        # Accumuler la perte
        total_loss += loss.item()

# Calcul de la moyenne de la perte sur l'ensemble du dataset
average_loss = total_loss / len(data_loader)

# Affichage des résultats
print(f"Average Loss: {average_loss:.4f}")


In [None]:
average_loss = total_loss / len(data_loader)

# Affichage des résultats
print(f"Average Loss: {average_loss:.4f}")

In [None]:
inputs

In [16]:
torch.cuda.empty_cache()