# Import packages needed

In [17]:
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2Model, BertTokenizer, BertModel
from torch import nn
from torch.utils.data import DataLoader
from datasets import load_dataset
import librosa


In [18]:
# Clear GPU memory
torch.cuda.empty_cache()

import gc

# Garbage collect to remove any other lingering objects
gc.collect()

# Verify if GPU memory is cleared
!nvidia-smi

Sun May 26 14:36:52 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.171.04             Driver Version: 535.171.04   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA L4                      Off | 00000000:04:01.0 Off |                    0 |
| N/A   61C    P0              30W /  72W |   2398MiB / 23034MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [19]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available. Using GPU:", torch.cuda.get_device_name(device))
else:
    device = torch.device("cpu")
    print("GPU is not available. Using CPU.")
# device = torch.device("cpu")

GPU is available. Using GPU: NVIDIA L4


# Load data and define collate function

In [20]:
def load_audio(file_path, target_sr=16000):
    # Load the audio file
    audio, sr = librosa.load(file_path, sr=target_sr)
    return audio

from torch.nn.utils.rnn import pad_sequence

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
speech_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")

# Load text tokenizer and text model
text_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
text_model = BertModel.from_pretrained("bert-base-uncased")
# speech_model.to("cuda")

def collate_batch(batch):
    speech_model.to("cuda")
    audio_inputs = []
    transcriptions = []

    for item in batch:
        audio = load_audio(item['audio']['path'])
        audio_tensor = torch.tensor(audio).float()  # Ensure data is float

        # Ensure the tensor is 1D
        if audio_tensor.dim() != 1:
            audio_tensor = audio_tensor.squeeze()  # Remove any singleton dimensions
        if audio_tensor.dim() == 0:
            audio_tensor = audio_tensor.unsqueeze(0)  # Handle rare case of a single sample

        audio_inputs.append(audio_tensor)
        transcriptions.append(item['transcription'])

        # Process transcriptions
        #print("transcriptions collate", transcriptions)
    transcription_inputs = text_tokenizer(transcriptions, return_tensors="pt", padding=True, truncation=True)
    audio_inputs_padded = pad_sequence(audio_inputs, batch_first=True, padding_value=0.0)
    audio_inputs_processed = processor(audio_inputs_padded, sampling_rate=16000, return_tensors="pt", padding=True).input_values.to("cuda")
    # print(audio_inputs_processed.shape)
    if audio_inputs_processed.dim() == 3 and audio_inputs_processed.shape[0] == 1:
        audio_inputs_processed = audio_inputs_processed.squeeze(0)
    speech_embeddings = speech_model(input_values=audio_inputs_processed).last_hidden_state.mean(dim=1)
    text_embeddings = text_model(**transcription_inputs).last_hidden_state[:, 0, :]
    del audio_inputs_processed
    torch.cuda.empty_cache()
    return {
        'audio': speech_embeddings,  # Now properly padded
        'transcription': text_embeddings
    }


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [21]:


# Define your DataLoader

# Load dataset and split
dataset = load_dataset("PolyAI/minds14", name="en-US", split="train")
split_dataset = dataset.train_test_split(test_size=0.20)
train_dataset = split_dataset['train']
valid_dataset = split_dataset['test']

train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(valid_dataset, batch_size=2, shuffle=False, collate_fn=collate_batch)


In [22]:
class TransformationModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(TransformationModel, self).__init__()
        self.linear = nn.Linear(input_dim, 128)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(128, output_dim)
    def forward(self, x):
        out = self.linear(x)
        out = self.relu(out)
        out = self.linear2(out)
        return out

# Update the initialization of the transformation model
transform_model = TransformationModel(768, 768)  # Adjust input_dim to 1024
optimizer = torch.optim.Adam(transform_model.parameters(), lr=1e-4)

In [23]:
def evaluate(model, dataloader):
    model.eval()
    total_loss = 0
    total_batches = 0

    with torch.no_grad():
        for batch in dataloader:

            speech_embeddings = batch['audio'].to(device)
            text_embeddings = batch['transcription'].to(device)
            # Compute loss
            transformed_speech_emb = model(speech_embeddings)
            loss = nn.MSELoss()(transformed_speech_emb, text_embeddings)

            total_loss += loss.item()
            total_batches += 1

            del speech_embeddings, text_embeddings
    average_loss = total_loss / total_batches
    return average_loss


In [24]:
loss_fn = nn.MSELoss()
transform_model.to(device)

for epoch in range(10):
    torch.cuda.empty_cache()
    transform_model.train()
    for step, batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        speech_embeddings = batch['audio'].to(device)
        text_embeddings = batch['transcription'].to(device)
        transformed_speech_emb = transform_model(speech_embeddings)
        loss = loss_fn(transformed_speech_emb, text_embeddings)
        loss.backward()
        optimizer.step()

        del speech_embeddings, text_embeddings
        
        if step % 100 == 0:
            val_loss = evaluate(transform_model, valid_dataloader)
            print(f"Validation Loss after {step+1} steps: {val_loss}")
            torch.save(transform_model.state_dict(), "out/transform_model.pth")



Validation Loss after 1 steps: 0.27566812069792496
Validation Loss after 101 steps: 0.16079395408170266
Validation Loss after 201 steps: 0.06162018877895255
Validation Loss after 1 steps: 0.05297326211605156
Validation Loss after 101 steps: 0.04130614235212928
Validation Loss after 201 steps: 0.03869942601835519
Validation Loss after 1 steps: 0.038414167305618
Validation Loss after 101 steps: 0.03788424949896963
Validation Loss after 201 steps: 0.036764477134535185
Validation Loss after 1 steps: 0.036831275674334744
Validation Loss after 101 steps: 0.035892260702032795
Validation Loss after 201 steps: 0.03555959668990813
Validation Loss after 1 steps: 0.03589983359632785
Validation Loss after 101 steps: 0.03463627244427539
Validation Loss after 201 steps: 0.034224107175281175
Validation Loss after 1 steps: 0.03424113302638656
Validation Loss after 101 steps: 0.03366027628643471
Validation Loss after 201 steps: 0.03340135825177034
Validation Loss after 1 steps: 0.03354299192627271


KeyboardInterrupt: 

In [None]:
# Print the first batch to see the data structure
first_batch = next(iter(train_dataloader))
print(first_batch.keys())  # Check top-level keys
#print(first_batch['audio'].keys())  # Check audio data structure
#print(first_batch['text'].keys())  # Replace 'text' with the correct key if different