In [None]:
!pip -q install bitsandbytes

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m119.8/119.8 MB[0m [31m12.6 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import torch

# Use the GPU when one is available; fall back to CPU so the notebook can still be
# imported/inspected (slowly) on a machine without CUDA instead of failing later on .to(device).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; later cells read the dataset zip, the sample list and
# checkpoints from paths under /content/drive/MyDrive.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!unzip /content/drive/MyDrive/SpeechDiffusion/speechcoco-data.zip

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Space-separated, headerless mapping file. Column 0 holds wav filenames (read from data/wavs by the
# dataset below) and column 1 holds the matching image filenames (read from data/images).
wav2cap = pd.read_csv('data/wav2capt.txt', sep = ' ', header=None)

In [None]:
# wav filename -> image filename (a later duplicate wav name would overwrite an earlier one).
wav2cap_dict = {wav_name: image_name for wav_name, image_name in zip(wav2cap[0], wav2cap[1])}

In [None]:
# x, _ = sf.read(f'data/wavs/{wav2cap[0].values[0]}', samplerate=None)
# print(len(processor(x).input_values[0]))

# x, _ = sf.read(f'data/wavs/{wav2cap[0].values[1]}', samplerate=None)
# print(len(processor(x).input_values[0]))

In [None]:
# raw_lengths = []
# lengths = []
# processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
# print(len(wav2cap[0].values))

# for i in wav2cap[0].values:
#   print(i)
#   raw, _ = sf.read(f"data/wavs/{i}")
#   raw_lengths.append(len(raw))

#   processed = processor(raw, return_tensors="pt").input_values[0]
#   lengths.append(len(processed))

In [None]:
# The 500 image names were presumably first drawn with `list(set(wav2cap[1]))[:500]`. String-set
# iteration order can change between Python sessions, so the fixed selection is read back from
# Drive, one image filename per line, with surrounding whitespace stripped.
with open('/content/drive/MyDrive/SpeechDiffusion/samples500.txt', 'r') as samples_file:
  samples_500 = [entry.strip() for entry in samples_file]

In [None]:
# Rows whose image (column 1) is one of the sampled images, and their wav -> image mapping.
# NOTE(review): the dataset cell below builds from the full `wav2cap_dict`, not this subset.
subset500 = wav2cap.loc[wav2cap[1].isin(samples_500)]
wav2cap_subset500_dict = {wav_name: image_name for wav_name, image_name in zip(subset500[0], subset500[1])}

In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import soundfile as sf

from torchvision import transforms
from transformers import AutoProcessor, HubertModel
from torch.nn.utils.rnn import pad_sequence

processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")

# Number of DataLoader worker processes. os.cpu_count() returns None when the count cannot be
# determined, which DataLoader rejects; fall back to 0 (load batches in the main process).
num_processors = os.cpu_count() or 0

batch_size = 8

def collate_batch(batch):
    """Turn a list of (image, speech) pairs into batched tensors.

    Images are stacked along a new leading dimension. Speech sequences can differ in
    length, so they are right-padded with zeros to the longest one (batch-first).

    Returns:
        (images, audio_sequences)
    """
    image_batch, speech_batch = zip(*batch)
    return (
        torch.stack(image_batch, dim=0),
        pad_sequence(speech_batch, batch_first=True),
    )

# Dataset pairing each spoken caption (wav) with the image it describes.
class ImageTextDataset(Dataset):
    """Yields ``(image, speech)`` pairs for speech-image contrastive training.

    Args:
        wav2cap_dict: mapping wav filename (under ``<data_path>/wavs``) -> image filename
            (under ``<data_path>/images``).
        speech_processor: Hugging Face-style processor, called with the raw waveform,
            ``return_tensors="pt"`` and ``sampling_rate``. Its ``input_values`` has a leading
            batch dimension, which is squeezed away here.
        image_size: images are resized to ``image_size x image_size``.
        data_path: dataset root directory.
    """

    def __init__(self, wav2cap_dict, speech_processor, image_size=224, data_path='data'):
        self.caption_filenames = list(wav2cap_dict.keys())
        # NOTE(review): only Resize + ToTensor, with no mean/std normalization. ViT checkpoints are
        # usually fed normalized pixels; confirm against the model's image processor before
        # changing, because the trained projection heads depend on the current input scale.
        self.transform = transforms.Compose([
                    transforms.Resize((image_size, image_size)),
                    transforms.ToTensor()])

        self.wav2cap_dict = wav2cap_dict
        self.data_path = data_path
        self.speech_processor = speech_processor

    def __len__(self):
        return len(self.caption_filenames)

    def __getitem__(self, index):
        item = self.caption_filenames[index]

        # IMAGE
        image_path = os.path.join(self.data_path, 'images', self.wav2cap_dict[item])
        # Force 3 channels: grayscale/palette/RGBA images would otherwise give tensors with a
        # different channel count, breaking torch.stack in collate_batch and the ViT input.
        # The context manager closes the file handle once the pixels are loaded.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        image = self.transform(image)

        # SPEECH
        raw_speech, sampling_rate = sf.read(os.path.join(self.data_path, 'wavs', item))
        speech_output = self.speech_processor(raw_speech, return_tensors="pt", sampling_rate=sampling_rate).input_values.squeeze(0)

        return image, speech_output

# Training data: every wav -> image pair in `wav2cap_dict`, shuffled each epoch.
dataset = ImageTextDataset(wav2cap_dict=wav2cap_dict, speech_processor=processor)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_batch,
    num_workers=num_processors,
    pin_memory=True,
)

### **Defining Speech Model**

In [None]:
import torch.nn as nn

class SpeechProjection(nn.Module):
    """Residual projection head mapping pooled speech features into the shared space.

    Computes ``LayerNorm(P(x) + Dropout(FC(GELU(P(x)))))``, where ``P`` is a linear map
    from ``speech_embedding_size`` to ``shared_embedding_size``.
    """

    def __init__(self, speech_embedding_size, shared_embedding_size, dropout=0.1):
        super().__init__()
        # Submodule names and creation order are kept stable: they define the state_dict keys
        # used by saved checkpoints and the order in which random init draws from the RNG.
        self.speech_projection = nn.Linear(speech_embedding_size, shared_embedding_size)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(shared_embedding_size, shared_embedding_size)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(shared_embedding_size)

    def forward(self, text_embeddings):
        # `text_embeddings` is the pooled speech feature; the name is kept for call compatibility.
        base = self.speech_projection(text_embeddings)
        residual = self.dropout(self.fc(self.gelu(base)))
        return self.layer_norm(residual + base)

### **Defining Image Model**

In [None]:
from transformers import AutoImageProcessor, ViTModel

# Loads the ViT backbone (on CPU) for exploration.
# NOTE(review): the training cell reloads this same checkpoint onto `device` and rebinds `vit`,
# so this copy is redundant there and only costs load time and host memory.
# image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

# inputs = image_processor(image, return_tensors="pt")

# with torch.no_grad():
#     outputs = model(**inputs)

# last_hidden_states = outputs.last_hidden_state
# list(last_hidden_states.shape)

In [None]:
class ImageProjection(nn.Module):
    """Residual projection head mapping pooled ViT features into the shared space.

    Computes ``LayerNorm(P(x) + Dropout(FC(GELU(P(x)))))``, where ``P`` is a linear map
    from ``image_embedding_size`` to ``shared_embedding_size``.
    """

    def __init__(self, image_embedding_size, shared_embedding_size, dropout=0.1):
        super().__init__()
        # Submodule names and creation order are kept stable: they define the state_dict keys
        # used by saved checkpoints and the order in which random init draws from the RNG.
        self.image_projection = nn.Linear(image_embedding_size, shared_embedding_size)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(shared_embedding_size, shared_embedding_size)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(shared_embedding_size)

    def forward(self, image_embeddings):
        base = self.image_projection(image_embeddings)
        residual = self.dropout(self.fc(self.gelu(base)))
        return self.layer_norm(residual + base)

### **CLIP**

In [None]:
import torch.nn.functional as F

def cross_entropy(preds, targets, reduction='none'):
    """Cross-entropy between logits and soft (per-row probability) targets.

    Args:
        preds: logits, shape (batch, classes).
        targets: target probabilities with the same shape as ``preds``.
        reduction: "none" -> per-row loss of shape (batch,); "mean" or "sum" -> scalar.

    Raises:
        ValueError: for any other ``reduction`` (previously this silently returned None).
    """
    loss = (-targets * F.log_softmax(preds, dim=-1)).sum(1)
    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    raise ValueError(f"reduction must be 'none', 'mean' or 'sum', got {reduction!r}")

temperature_value = 1

def contrastive_clip_loss_function(text_projection,  image_projection, mode="eval", temperature=None):
    """CLIP-style contrastive objective between two aligned batches of embeddings.

    Row i of ``text_projection`` (here: speech embeddings) is paired with row i of
    ``image_projection``; both are (batch, dim).

    Args:
        mode: "eval" returns the (batch, batch) similarity logits; "train" returns the scalar
            loss, averaged over both directions and the batch.
        temperature: logit temperature; defaults to the module-level ``temperature_value``
            (read at call time).

    Raises:
        ValueError: for an unknown ``mode`` (previously this printed and returned None).
    """
    if temperature is None:
        temperature = temperature_value
    logits = (text_projection @ image_projection.T) / temperature
    if mode == "eval":
        return logits
    if mode != "train":
        raise ValueError(f"mode must be 'train' or 'eval', got {mode!r}")

    # Soft targets: pairs that are similar within each modality share probability mass,
    # so near-duplicate captions/images in a batch are not pushed apart as hard negatives.
    images_similarity = image_projection @ image_projection.T
    texts_similarity = text_projection @ text_projection.T
    # Divide by the temperature, as for the logits; the old `* temperature` was inconsistent
    # whenever temperature != 1 (identical at the current value of 1).
    targets = F.softmax((images_similarity + texts_similarity) / 2 / temperature, dim=-1)
    texts_loss = cross_entropy(logits, targets, reduction='none')
    images_loss = cross_entropy(logits.T, targets.T, reduction='none')
    loss = (images_loss + texts_loss) / 2.0  # shape: (batch_size)
    return loss.mean()

### **Training**

In [None]:
import gc
# Run Python garbage collection and release PyTorch's cached-but-unused CUDA memory before the
# training cell loads the models onto the GPU.
gc.collect()
torch.cuda.empty_cache()

In [None]:
import torch.optim as optim
import itertools
import bitsandbytes as bnb

from transformers import Wav2Vec2Model
from torch.optim import AdamW

PATH = '/content/drive/MyDrive/checkpoints3.pt'

shared_embedding_size = 512
num_epochs = 15

# scaler = torch.cuda.amp.GradScaler()

# IMAGE
# ViT is a frozen feature extractor: its parameters are not passed to the optimizer below, so the
# loop runs it in eval mode under torch.no_grad(). Recording its activations for a backward pass
# whose gradients are never applied only costs GPU memory. To fine-tune it, add vit.parameters()
# to `params` and remove the no_grad block.
vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k").to(device)
vit_output_shape = vit.config.hidden_size  # 768 for ViT-Base

image_projection = ImageProjection(image_embedding_size = vit_output_shape,
                                   shared_embedding_size = shared_embedding_size).to(device)

# SPEECH

# hubert = HubertModel.from_pretrained("facebook/hubert-base-ls960").to(device)
# hubert_output_shape = 768

wav2vec2 = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h").to(device)
wav2vec2_output_shape = wav2vec2.config.hidden_size  # 768 for wav2vec2-base

# for param in hubert.feature_extractor.parameters():
#     param.requires_grad = False

# for param in hubert.feature_projection.parameters():
#     param.requires_grad = False

speech_projection = SpeechProjection(speech_embedding_size = wav2vec2_output_shape,
                                     shared_embedding_size = shared_embedding_size).to(device)

# Parameters

# params = [{'params': vit.parameters(), 'lr':1e-4},
#           {'params': itertools.chain(image_projection.parameters(),
#                                      speech_projection.parameters()), 'lr':1e-3, 'weight_decay':1e-3},
#           {'params': hubert.parameters(), 'lr':1e-3}]

# NOTE(review): the pretrained wav2vec2 gets the same 3e-3 rate as the freshly initialised
# projection heads; that may be too high for fine-tuning (see the loss note below).
params = [{'params': itertools.chain(image_projection.parameters(),
                                     speech_projection.parameters()), 'lr':3e-3, 'weight_decay':1e-3},
          {'params': wav2vec2.parameters(), 'lr':3e-3}]

# optimizer = bnb.optim.AdamW8bit(params)
optimizer = AdamW(params)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=1 , factor=0.8, )

# - - - - - - - - -  Training loop  - - - - - - - - -
# Exactly one forward/backward/optimizer step per batch. `loss` is already the batch mean, so the
# running and epoch averages divide by the number of batches, not batches * batch_size.
# NOTE(review): the run logged below used the old batches*batch_size normalisation. Its plateau of
# ~0.26 * 8 ≈ 2.08 ≈ ln(8) equals the cross-entropy of uniform logits over a batch of 8, which
# suggests the speech/image embeddings collapsed. Worth investigating before training longer.

for epoch in range(num_epochs):
  print(f"Epoch: {epoch+1}")
  vit.eval()  # frozen backbone, see note above
  wav2vec2.train()
  image_projection.train()
  speech_projection.train()
  total_loss = 0.0

  for batch_idx, (images, speech) in enumerate(dataloader):
    optimizer.zero_grad()

    # - - - -  Speech  - - - -
    speech = speech.to(device, non_blocking=True)
    speech_outputs = wav2vec2(speech)
    # Mean-pool over time frames (zero-padded frames from collate_batch included), then project.
    speech_emb = speech_projection(speech_outputs.last_hidden_state.mean(dim=1))

    # - - - -  Image  - - - -
    images = images.to(device, non_blocking=True)
    with torch.no_grad():
      image_outputs = vit(images)
    image_emb = image_projection(image_outputs.last_hidden_state.mean(dim=1))

    # - - - -  Loss + backpropagation  - - - -
    loss = contrastive_clip_loss_function(speech_emb, image_emb, mode="train")
    loss.backward()
    optimizer.step()
    # scaler.scale(loss).backward()
    # scaler.step(optimizer)
    # scaler.update()

    # - - - - Loss print - - - -
    total_loss += loss.item()
    if batch_idx % 100 == 0:
      print(f"Batch {batch_idx}/{len(dataloader)}, Loss:{total_loss/(batch_idx+1)}")

  avg_loss = total_loss / len(dataloader)
  lr_scheduler.step(avg_loss)
  print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}")

  # The key stays 'hubert_state_dict' although it now holds the wav2vec2 weights, so code that
  # reads existing checkpoints by that key keeps working.
  torch.save({
        'epoch': epoch,
        'hubert_state_dict': wav2vec2.state_dict(),
        'image_projection_state_dict': image_projection.state_dict(),
        'speech_projection_state_dict': speech_projection.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_loss,
        }, PATH)

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch: 1


  self.pid = os.fork()


Batch 0/313, Loss:8.154058456420898
Batch 100/313, Loss:1.7144919129291383
Batch 200/313, Loss:1.0697100924615244
Batch 300/313, Loss:0.8163890529708608
Epoch [1/15], Average Loss: 0.7957
Epoch: 2
Batch 0/313, Loss:0.3013402223587036
Batch 100/313, Loss:0.2774483220707072
Batch 200/313, Loss:0.27793322221853245
Batch 300/313, Loss:0.27582898330054806
Epoch [2/15], Average Loss: 0.2757
Epoch: 3
Batch 0/313, Loss:0.28349706530570984
Batch 100/313, Loss:0.2710456057350234
Batch 200/313, Loss:0.2716118312297176
Batch 300/313, Loss:0.270365630075385
Epoch [3/15], Average Loss: 0.2698
Epoch: 4
Batch 0/313, Loss:0.2640957832336426
Batch 100/313, Loss:0.2659936424824271
Batch 200/313, Loss:0.2658328067752259
Batch 300/313, Loss:0.2656818027314158
Epoch [4/15], Average Loss: 0.2653
Epoch: 5
Batch 0/313, Loss:0.25839895009994507
Batch 100/313, Loss:0.2646292191330749
Batch 200/313, Loss:0.2640813586130664
Batch 300/313, Loss:0.2636996157640635
Epoch [5/15], Average Loss: 0.2634
Epoch: 6
Batch 0/

In [None]:
# Persist the speech encoder and both projection heads to Drive, one state_dict per file.
for trained_module, save_path in (
    (wav2vec2, '/content/drive/MyDrive/speechmodel-wav2vec2.pt'),
    (image_projection, '/content/drive/MyDrive/image_projection-w2v.pt'),
    (speech_projection, '/content/drive/MyDrive/speech_projection3-w2v.pt'),
):
    torch.save(trained_module.state_dict(), save_path)

In [None]:
import numpy as np

vit.eval()
image_projection.eval()
wav2vec2.eval()
speech_projection.eval()

def create_image_embeddings(images):
    """Project a batch of preprocessed images (already on `device`) into the shared space.

    Uses the same pooling as training: mean over the ViT token sequence, then `image_projection`.
    Returns a tensor of shape (batch, shared_embedding_size).
    """
    with torch.no_grad():
        vit_outputs = vit(images)
        # Distinct local name: assigning to `image_projection` here would shadow the global
        # module and raise UnboundLocalError.
        return image_projection(vit_outputs.last_hidden_state.mean(dim=1))

# Embed every dataset image in dataset order (shuffle=False), so list position k == dataset index k.
# NOTE(review): several captions can point at the same image, so an image may appear more than once.
ordered_image_loader = DataLoader(dataset,
                                  batch_size=batch_size,
                                  shuffle=False,
                                  collate_fn=collate_batch,
                                  num_workers=num_processors)

image_embeddings_list_train = []
for batch_images, _ in ordered_image_loader:
    # Iterating the (batch, dim) result appends one (dim,) embedding per image.
    image_embeddings_list_train.extend(create_image_embeddings(batch_images.to(device)))


def image_retrieval_function(input_query, n , display=False, sampling_rate=16000): # n --> number of images
    """Retrieve the `n` dataset images whose embeddings best match a spoken query.

    Args:
        input_query: path to a wav file, or a 1-D waveform sampled at `sampling_rate`.
        n: number of images to retrieve.
        display: if True, plot the images and return None; otherwise return their indices.
        sampling_rate: rate of a waveform `input_query` (a file's own rate is used for paths).

    Returns:
        numpy array of up to `n` dataset indices, most similar first, or None when `display`.
    """
    if isinstance(input_query, (str, os.PathLike)):
        input_query, sampling_rate = sf.read(input_query)

    with torch.no_grad():
        input_values = processor(input_query, return_tensors="pt", sampling_rate=sampling_rate).input_values.to(device)
        speech_outputs = wav2vec2(input_values)
        speech_embedding = speech_projection(speech_outputs.last_hidden_state.mean(dim=1))[0]
        # One matrix-vector product instead of a Python loop of dot products.
        image_matrix = torch.stack(image_embeddings_list_train)
        similarity_scores = (image_matrix @ speech_embedding).cpu().numpy()

    # Highest scores first; slicing [:n] also handles n == 0 (empty result).
    max_indexes = np.argsort(-similarity_scores, kind='stable')[:n]
    if display:
        for index in max_indexes:
            image_tensor = dataset[index][0]
            plt.imshow(torch.moveaxis(image_tensor, 0, 2))
            plt.show()
        return None
    return max_indexes

NameError: name 'train_dataset' is not defined

In [None]:
vit.eval()
image_projection.eval()

# Iterate in dataset order: the training `dataloader` shuffles, so its batches cannot be mapped
# back to dataset indices.
ordered_loader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            collate_fn=collate_batch,
                            num_workers=num_processors)

embedding_batches = []
with torch.no_grad():
  for i, (image, speech) in enumerate(ordered_loader):
    outputs = vit(image.to(device))
    emb = image_projection(outputs.last_hidden_state.mean(dim=1)).cpu()
    # torch.concat returns a new tensor and is not in-place, so collect the batches and
    # concatenate them once after the loop.
    embedding_batches.append(emb)
    if i % 100 == 0:
      print(f"{i}/{len(ordered_loader)}")

# Row k holds the projected embedding of dataset[k]'s image.
images_embeddings = torch.concat(embedding_batches) if embedding_batches else torch.empty(0)

In [None]:
# Scratch check: torch.concat of an empty tensor with the last embedding batch `emb` from the
# previous cell. The result is only displayed, not stored. Exploratory.
torch.concat((torch.Tensor(), emb))

tensor([[-0.0999, -0.0604, -0.1155,  ...,  0.5536, -1.2522, -0.3032],
        [-0.1201, -0.0741, -0.1117,  ...,  0.5133, -1.2332, -0.2707],
        [-0.0889, -0.0658, -0.0913,  ...,  0.5349, -1.2376, -0.2617],
        [-0.1038, -0.1076, -0.1018,  ...,  0.5128, -1.2385, -0.2707]])

In [None]:
# Scratch check of an empty tensor. NOTE(review): the recorded output for this cell is a
# TypeError; the cell is exploratory and nothing else depends on it.
torch.Tensor()

TypeError: new(): data must be a sequence (got NoneType)

In [None]:
# Save HuBERT encoder weights (presumably from an earlier 3-epoch run).
# NOTE(review): `hubert` is only created in commented-out code in the training cell (the current
# speech encoder is `wav2vec2`), so on a fresh kernel this line raises NameError.
torch.save(hubert.state_dict(), 'hubert_3epochs.pt')

In [None]:
!cp /content/hubert_3epochs.pt /content/drive/MyDrive/hubert_3epochs.pt

In [None]:
# map_location='cpu' loads the tensors into host memory. Without it they are restored onto the
# device they were saved from (the GPU), which ran out of memory here.
hubert_state_dict = torch.load('/content/hubert_3epochs.pt', map_location='cpu')

OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 14.75 GiB of which 1.06 MiB is free. Process 416813 has 14.74 GiB memory in use. Of the allocated memory 14.40 GiB is allocated by PyTorch, and 194.71 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# load_state_dict expects the mapping of tensors itself, not a file path: load the file first
# (on CPU, to avoid the GPU OOM seen when loading onto the saved device).
# NOTE(review): `hubert` is only defined in commented-out code in the training cell.
hubert.load_state_dict(torch.load('/content/hubert_3epochs.pt', map_location='cpu'))

TypeError: Expected state_dict to be dict-like, got <class 'str'>.

In [None]:
!sudo kill -9 416813

kill: (416813): No such process


In [None]:
# Human-readable table of the CUDA caching allocator's usage. torch.cuda.memory_stats() returns the
# raw counters as one very long dict, which is hard to read when diagnosing the OOM above.
print(torch.cuda.memory_summary())

OrderedDict([('active.all.allocated', 13536346), ('active.all.current', 1498), ('active.all.freed', 13534848), ('active.all.peak', 1901), ('active.large_pool.allocated', 8861357), ('active.large_pool.current', 503), ('active.large_pool.freed', 8860854), ('active.large_pool.peak', 709), ('active.small_pool.allocated', 4674989), ('active.small_pool.current', 995), ('active.small_pool.freed', 4673994), ('active.small_pool.peak', 1223), ('active_bytes.all.allocated', 57968531498496), ('active_bytes.all.current', 15175254016), ('active_bytes.all.freed', 57953356244480), ('active_bytes.all.peak', 15463658496), ('active_bytes.large_pool.allocated', 56703870142976), ('active_bytes.large_pool.current', 15108582912), ('active_bytes.large_pool.freed', 56688761560064), ('active_bytes.large_pool.peak', 15396592640), ('active_bytes.small_pool.allocated', 1264661355520), ('active_bytes.small_pool.current', 66671104), ('active_bytes.small_pool.freed', 1264594684416), ('active_bytes.small_pool.peak', 9