In [6]:
batch_size=2

spectrogran_size=[2,1,256,256]
video_size=[2,int(30*5),1,16,256]

# Define Model

In [35]:
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import numpy as np

from transformers.models.vivit.modeling_vivit import VivitModel, VivitConfig, VivitLayer, VivitEncoder
from transformers.models.speecht5.modeling_speecht5 import SpeechT5Decoder, SpeechT5Config

class VivitTubeletEmbeddings(nn.Module):
    """
    Construct Vivit Tubelet embeddings.

    This module turns a batch of videos of shape (batch_size, num_frames, num_channels, height, width) into a tensor of
    shape (batch_size, seq_len, hidden_size) to be consumed by a Transformer encoder.

    The seq_len (the number of patches) equals (number of frames // tubelet_size[0]) * (height // tubelet_size[1]) *
    (width // tubelet_size[2]).
    """

    def __init__(self, config):
        super().__init__()
        self.num_frames = config.num_frames
        self.image_size = config.image_size
        self.patch_size = config.tubelet_size
        self.num_patches = (
            (self.image_size[1] // self.patch_size[2]) # 256/16
            * (self.image_size[0] // self.patch_size[1]) # 16/4
            * (self.num_frames // self.patch_size[0]) # 32/2 
        )
        self.embed_dim = config.hidden_size

        self.projection = nn.Conv3d(
            config.num_channels, config.hidden_size, kernel_size=config.tubelet_size, stride=config.tubelet_size
        )

    def forward(self, pixel_values):
        batch_size, num_frames, num_channels, height, width = pixel_values.shape
        if height != self.image_size[0] or width != self.image_size[1]:
            raise ValueError(
                f"Input image size ({height},{width}) doesn't match model ({self.image_size},{self.image_size})."
            )

        # permute to (batch_size, num_channels, num_frames, height, width)
        pixel_values = pixel_values.permute(0, 2, 1, 3, 4)

        x = self.projection(pixel_values)
        # out_batch_size, out_num_channels, out_num_frames, out_height, out_width = x.shape
        x = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return x


class VivitEmbeddings(nn.Module):
    """
    Vivit Embeddings.

    Creates embeddings from a video using VivitTubeletEmbeddings, adds CLS token and positional embeddings.
    """

    def __init__(self, config):
        super().__init__()

        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.patch_embeddings = VivitTubeletEmbeddings(config)

        self.position_embeddings = nn.Parameter(
            torch.zeros(1, self.patch_embeddings.num_patches + 1, config.hidden_size)
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.config = config

    def forward(self, pixel_values):
        batch_size = pixel_values.shape[0]
        embeddings = self.patch_embeddings(pixel_values)

        cls_tokens = self.cls_token.tile([batch_size, 1, 1])

        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # add positional encoding to each token
        embeddings = embeddings + self.position_embeddings

        embeddings = self.dropout(embeddings)

        return embeddings

class VivitPooler(nn.Module): #dont simply pool like this?
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output 
  
class AiSynthModel(nn.Module):
    def __init__(self, config, decoder, image_size=(16,256), tubelet_size=[2,4,16], num_frames = 32, dim = 192, num_layers=4, pool = 'mean', in_channels = 1, dim_head = 64, heads=4, dropout = 0.,
                 emb_dropout = 0., scale_dim = 4, ):
        super().__init__()
        
        config.hidden_size = dim
        config.num_channels = in_channels
        config.num_frames = num_frames
        config.tubelet_size = tubelet_size
        config.image_size = image_size
        config.num_attention_heads = heads
        config.num_hidden_layers = num_layers
        self.config = config

        self.vivit_embeddings = VivitEmbeddings(config)

        self.vivit_encoder = VivitEncoder(config)

        self.speech_t5_decoder = decoder

        
        #nn.Sequential
        #self.pooling = VivitPooler(config)

        #define an operation to get from shape 1024x128 to 


    def forward(self, x):
        x = self.vivit_embeddings(x)
        #from second last dimension, drop the class tokens bcs we aren't doing classification
        x = self.vivit_encoder(x).last_hidden_state[:,1:]
        #x = self.pooling(x)
        #downsample to size of spectrogram
        x = F.avg_pool2d(x, 2)
        #run through speech t5 decoder
        x = self.speech_t5_decoder(x).last_hidden_state

        return x

    

if __name__ == "__main__":
    
    img = torch.ones([1, 64, 1, 16, 256]).cuda()

    config_vivit = VivitConfig()
    conf_dict = {
    "activation_dropout": 0.1,
    "attention_dropout": 0.1,
    "decoder_attention_heads": 8,
    "decoder_ffn_dim": 3072,
    "decoder_layerdrop": 0.1,
    "decoder_layers": 6,
    "decoder_start_token_id": 2,
    "hidden_act": "gelu",
    "hidden_dropout": 0.1,
    "hidden_size": 256,
    "is_encoder_decoder": True,
    "layer_norm_eps": 1e-05,
    "mask_feature_length": 4,
    "mask_feature_min_masks": 0,
    "mask_feature_prob": 0.0,
    "mask_time_length": 4,
    "mask_time_min_masks": 2,
    "positional_dropout": 0.1,
    "transformers_version": "4.40.1",
    "use_guided_attention_loss": True,
    }
    config_speecht5 = SpeechT5Config(**conf_dict)

    decoder = SpeechT5Decoder(config_speecht5)    
    model = AiSynthModel(config_vivit, decoder, image_size=(16,256),tubelet_size=[2,8,32], num_frames = 64, dim = 512, num_layers=4).cuda()

    parameters = filter(lambda p: p.requires_grad, model.parameters())
    parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
    print('Trainable Parameters: %.3fM' % parameters)
    
    out = model(img)
    
    print("Shape of out :", out.shape)      # [B, num_classes]
    print("dtype of out :", out.dtype)      # float32

    

Trainable Parameters: 29.958M
Shape of out : torch.Size([1, 256, 256])
dtype of out : torch.float32


In [36]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torchvision
from tqdm import tqdm
import torchaudio
from torchvision import transforms
#import wandb
import sys,os

device = 'cuda' if torch.cuda.is_available() else 'cpu'

class CustomDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        # self.max_waveform_length = max_waveform_length
        # self.max_num_frames = max_num_frames
        self.frames = [file for file in os.listdir(os.path.join(root_dir, 'frames_pt')) if file.endswith('.pt')]
        self.frames.sort()
        self.graytransform = torchvision.transforms.Grayscale()

    def __len__(self):
        return len(self.frames)


    def __getitem__(self, idx):
        wav_file = os.path.join(self.root_dir, 'wavs', self.frames[idx].replace('.pt', '.wav'))
        frames_path = os.path.join(self.root_dir, 'frames_pt', self.frames[idx])
        spectrogram_path = os.path.join(self.root_dir, 'spectrograms_pt', self.frames[idx])

        waveform, sample_rate = torchaudio.load(wav_file)

        frames = torch.load(frames_path).unsqueeze(1).unsqueeze(0)
        frames = self.graytransform(frames).float() / 127.5 - 1 # 0-1
        spectrogram = torch.load(spectrogram_path).unsqueeze(0)

        name = self.frames[idx].replace('.pt', '')
        return frames, spectrogram, waveform, name

def collate_fn(batch):
    frames, spectrogram, waveform, name = zip(*batch)
    dat = {'frames':torch.vstack(frames).to(device=device), 'spectrogram':  torch.vstack(spectrogram).to(device=device), 'wav': torch.vstack(waveform).to(device=device), 'name':name}
    return dat

#init model
batch_size=8
train_dataset = CustomDataset(root_dir='data/processed/train')
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
val_dataset = CustomDataset(root_dir='data/processed/val')
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)


In [24]:
from transformers import VivitConfig
#model_enc = VivitModel.from_pretrained("google/vivit-b-16x2-kinetics400")

In [34]:
from tqdm import tqdm
import librosa, wandb

config_vivit = VivitConfig()
conf_dict = {
"activation_dropout": 0.1,
"attention_dropout": 0.1,
"decoder_attention_heads": 8,
"decoder_ffn_dim": 3072,
"decoder_layerdrop": 0.1,
"decoder_layers": 3,
"decoder_start_token_id": 2,
"hidden_act": "gelu",
"hidden_dropout": 0.1,
"hidden_size": 256,
"is_encoder_decoder": True,
"layer_norm_eps": 1e-05,
"mask_feature_length": 4,
"mask_feature_min_masks": 0,
"mask_feature_prob": 0.0,
"mask_time_length": 4,
"mask_time_min_masks": 2,
"positional_dropout": 0.1,
"transformers_version": "4.40.1",
"use_guided_attention_loss": True,
}
config_speecht5 = SpeechT5Config(**conf_dict)

decoder = SpeechT5Decoder(config_speecht5)    
model = AiSynthModel(config_vivit, decoder, image_size=(16,256),tubelet_size=[2,8,32], num_frames = 64, dim = 512).cuda()

parameters = filter(lambda p: p.requires_grad, model.parameters())
parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
print('Trainable Parameters: %.3fM' % parameters)

#opt
optimizer = torch.optim.SGD(model.parameters(),lr=1e-2, momentum=0.9, weight_decay=1e-4)

#loss fun
criterion = nn.CrossEntropyLoss()


# wandb.init(
#     project="test_aisynth",
#     name="vid2audio",
#     job_type="training",
#     reinit=True)

# %% Fit the model
# Number of epochs
epochs = 10
train_losses = []
val_losses = []
step = 0
val_interval=50
#use tqdm to print train loss and val loss as updating instead of constantly printing


tqdm 

val_loss = 10000
for epoch in range(epochs):

    with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{epochs} - Train") as pbar:

        for dat in train_loader:

            model.train()

            out = model(dat['frames'])

            loss = criterion(out, dat['spectrogram'])
            #wandb.log({"train_loss": loss.detach().cpu().item()}, step=step)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pbar.set_postfix({'train_loss': f'{loss:.4f}','val_loss': f'{val_loss:.4f}'})
            pbar.update()
            if step % val_interval == 0:
                with torch.no_grad():
                    model.eval()

                    val_dat = next(iter(val_loader))

                    val_out = model(val_dat['frames'])

                    val_loss = criterion(val_out, val_dat['spectrogram'])
                    
                    #wandb.log({"val_loss": val_loss.detach().cpu().item()}, step=step)
                    resized_wavs = F.interpolate(val_out.unsqueeze(1), size=(122,122), mode='bilinear', align_corners=False)

                    out_wavs = librosa.feature.inverse.mel_to_audio(resized_wavs.cpu().squeeze().numpy(), 
                                            sr=16000, 
                                            n_fft=2048, 
                                            hop_length=512, 
                                            win_length=None, 
                                            window='hann', 
                                            center=False, 
                                            pad_mode='constant', 
                                            power=2.0, 
                                            n_iter=32)
                    #use first channel of wav

                    #cross entropy loss for reconstructed wav and "ground truth" wav
                    val_loss_wav = criterion(torch.tensor(out_wavs), val_dat['wav'].detach().cpu())
                    # wandb.log({"val_loss_wav": val_loss_wav.detach().cpu().item()}, step=step)


            step += 1





Trainable Parameters: 23.646M


Epoch 1/10 - Train: 100%|██████████| 22/22 [00:13<00:00,  1.68it/s, train_loss=44.9503, val_loss=175.4739]  
Epoch 2/10 - Train: 100%|██████████| 22/22 [00:03<00:00,  6.05it/s, train_loss=44.9324, val_loss=175.4739] 
Epoch 3/10 - Train: 100%|██████████| 22/22 [00:11<00:00,  1.85it/s, train_loss=44.8930, val_loss=175.4096] 
Epoch 4/10 - Train: 100%|██████████| 22/22 [00:03<00:00,  6.14it/s, train_loss=44.8543, val_loss=175.4096] 
Epoch 5/10 - Train: 100%|██████████| 22/22 [00:12<00:00,  1.80it/s, train_loss=44.8699, val_loss=175.4067] 
Epoch 6/10 - Train: 100%|██████████| 22/22 [00:03<00:00,  6.16it/s, train_loss=44.8648, val_loss=175.4067] 
Epoch 7/10 - Train: 100%|██████████| 22/22 [00:12<00:00,  1.71it/s, train_loss=44.8516, val_loss=175.3966] 
Epoch 8/10 - Train: 100%|██████████| 22/22 [00:03<00:00,  6.04it/s, train_loss=44.8425, val_loss=175.3966] 
Epoch 9/10 - Train: 100%|██████████| 22/22 [00:03<00:00,  5.70it/s, train_loss=44.8778, val_loss=175.3966] 
Epoch 10/10 - Train:  14%|█

KeyboardInterrupt: 

In [26]:


# %% Fit the model
# Number of epochs
epochs = 3
train_losses = []
val_losses = []
step = 0
for epoch in range(epochs):

    for dat in tqdm(train_loader):


        #wandb.log({"loss": loss.detach().cpu().item()}, step=step)
        train_losses.append(loss.detach().cpu().item())


        optimizer.zero_grad()
        loss.backward()
        optimizer.step()




  0%|          | 0/22 [00:00<?, ?it/s]


RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

# Define model


Trainable Parameters: 29.958M
Shape of out : torch.Size([1, 256, 256])
dtype of out : torch.float32


In [68]:
SpeechT5Decoder(config_speecht5)

SpeechT5Decoder(
  (layers): ModuleList(
    (0-5): 6 x SpeechT5DecoderLayer(
      (self_attn): SpeechT5Attention(
        (k_proj): Linear(in_features=256, out_features=256, bias=True)
        (v_proj): Linear(in_features=256, out_features=256, bias=True)
        (q_proj): Linear(in_features=256, out_features=256, bias=True)
        (out_proj): Linear(in_features=256, out_features=256, bias=True)
      )
      (dropout): Dropout(p=0.1, inplace=False)
      (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (encoder_attn): SpeechT5Attention(
        (k_proj): Linear(in_features=256, out_features=256, bias=True)
        (v_proj): Linear(in_features=256, out_features=256, bias=True)
        (q_proj): Linear(in_features=256, out_features=256, bias=True)
        (out_proj): Linear(in_features=256, out_features=256, bias=True)
      )
      (encoder_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (feed_forward): SpeechT5FeedFor

VivitConfig {
  "attention_probs_dropout_prob": 0.0,
  "hidden_act": "gelu_fast",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": 224,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-06,
  "model_type": "vivit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_frames": 32,
  "num_hidden_layers": 12,
  "qkv_bias": true,
  "transformers_version": "4.40.1",
  "tubelet_size": [
    2,
    16,
    16
  ]
}