# Synthesize data

In [2]:
#uncomment if not installed already:
!pip install einops
!!apt-get install timidity
!pip install pretty_midi


Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl (43 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 43.2/43.2 kB 741.0 kB/s eta 0:00:00
Installing collected packages: einops
Successfully installed einops-0.8.0
Collecting pretty_midi
  Downloading pretty_midi-0.2.10.tar.gz (5.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 5.6/5.6 MB 14.0 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
Collecting mido>=1.1.16 (from pretty_midi)
  Downloading mido-1.3.2-py3-none-any.whl (54 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 54.6/54.6 kB 5.8 MB/s eta 0:00:00
Collecting packaging~=23.1 (from mido>=1.1.16->pretty_midi)
  Downloading packaging-23.2-py3-none-any.whl (53 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 53.0/53.0 kB 4.3 MB/s eta 0:00:00
Building wheels for colle… (output truncated)

In [11]:
#### Functions to synthesize data...



import argparse
import os
import subprocess
import matplotlib.pyplot as plt
import pretty_midi
import PIL
import PIL.Image
import numpy as np
import tqdm
import torch
from pathlib import Path
import torchaudio
import librosa
import shutil

# Spelling used for black-key note names: "flat" or "sharp".
accidentals = "flat"

# Pitch class (MIDI note % 12) -> note name, split by key colour.
white_notes = {0: "C", 2: "D", 4: "E", 5: "F", 7: "G", 9: "A", 11: "B"}
sharp_notes = {1: "C#", 3: "D#", 6: "F#", 8: "G#", 10: "A#"}
# Pitch class 1 is C#/Db (it was previously mislabelled "Bb", the name of class 10).
flat_notes = {1: "Db", 3: "Eb", 6: "Gb", 8: "Ab", 10: "Bb"}

# Pitch class of a white key -> its index within the octave (C=0 ... B=6).
white_notes_scale = {0: 0, 2: 1, 4: 2, 5: 3, 7: 4, 9: 5, 11: 6}

# Not read anywhere in this file; kept for compatibility.
print_notes_to_stdout = False


def note_breakdown(midi_note):
    """Split a MIDI note number into ``[pitch_class, octave]``.

    ``pitch_class`` is ``midi_note % 12`` (0 = C). The octave number makes
    MIDI 60 come out as octave 4 (i.e. C4).
    """
    pitch_class = midi_note % 12
    octave_number = round(midi_note // 12) - 1
    return [pitch_class, octave_number]

def is_white_key(note):
    """Return True when MIDI note ``note`` lands on a white piano key."""
    pitch_class = note % 12
    return pitch_class in white_notes

def pixel_range(midi_note, image_width):
    """Return ``[start_px, end_px]``, the horizontal pixel span of a piano key.

    The keyboard is divided into 52 white-key slots across ``image_width``
    pixels. A white key fills its slot minus one pixel at each edge. A black
    key is anchored half a slot to the right of the white key just below it
    (so it straddles their boundary) and is narrowed to half a slot wide.
    """
    slot_width = image_width / 52

    black = not is_white_key(midi_note)
    # Black keys are positioned relative to the white key one semitone below.
    anchor_note = midi_note - 1 if black else midi_note
    pitch_class, octave = note_breakdown(anchor_note)
    shift = 0.5 if black else 0

    # Slot index of the anchor white key, counted from A0 (MIDI 21) = 0.
    slot_index = white_notes_scale[pitch_class] + 7 * octave - 5

    left = round(slot_width * (slot_index + shift)) + 1
    right = round(slot_width * (slot_index + 1 + shift)) - 1

    if black:
        # Re-centre the key and shrink it to half a slot.
        centre = round(0.5 * (left + right))
        half_span = 0.5 * slot_width * 0.5
        left = round(centre - half_span)
        right = round(centre + half_span)

    return [left, right]

def create_video(
    input_midi: str,
    image_width=360,
    img_resize=224,
    image_height=360,
    black_key_height=2/3,
    piano_height=64,
    falling_note_color=(75, 105, 177),
    pressed_key_color=(220, 10, 10),
    vertical_speed=1/4,
    fps=16,
    end_t=0.0,
    silence=0.0,
    sample_rate=16000,
    split="train",
):
    """Render a MIDI file into keyboard frames, a WAV, a mel spectrogram and an MP4.

    All outputs live under ``data/processed/{split}/`` and are named after the
    MIDI file's base name (text before the first ``.``):

    * ``frames/<name>/frameNNNNN.png``: one RGB frame per video frame, showing
      the keyboard with currently sounding keys in ``pressed_key_color``.
    * ``wavs/<name>.wav``: Timidity rendering, zero-padded/truncated to ``end_t`` s.
    * ``spectrograms_pt/<name>.pt``: 88-bin mel spectrogram (a NumPy array).
    * ``videos/<name>.mp4``: frames muxed with the audio by ffmpeg.

    The frames directory is created here. The ``wavs``, ``spectrograms_pt``
    and ``videos`` directories are not, and must already exist.

    Args:
        input_midi: path to the ``.mid`` file; only the first instrument is used.
        image_width, image_height: frame size in pixels.
        img_resize: unused (kept for interface compatibility).
        black_key_height: fraction of ``piano_height`` covered by black keys.
        piano_height: height in pixels of the keyboard strip at the bottom.
        falling_note_color: unused (falling notes are not drawn).
        pressed_key_color: RGB colour of sounding keys.
        vertical_speed: scroll speed of the area above the keyboard, in
            main-area heights per second.
        fps: video frame rate.
        end_t: clip length in seconds; ``0.0`` means "until the last note ends".
        silence: unused.
        sample_rate: sample rate passed to Timidity.
        split: dataset split name used in the output paths.

    Returns:
        None. Returns early without writing anything when the MIDI has no
        instruments, or has no notes while ``end_t`` is ``0.0``.
    """
    midi_save_name = input_midi.split('/')[-1].split('.')[0]
    frames_folder = os.path.join(f"data/processed/{split}/frames", midi_save_name)
    main_height = image_height - piano_height
    # (pixels / image) * (images / s) / (frames / s) = pixels / frame
    pixels_per_frame = main_height * vertical_speed / fps

    midi_data = pretty_midi.PrettyMIDI(input_midi)
    if not midi_data.instruments:
        return
    notes = [
        {"note": n.pitch, "start": n.start, "end": n.end}
        for n in midi_data.instruments[0].notes
    ]
    if not notes and end_t == 0.0:
        # Without notes the clip length cannot be derived (max() would fail).
        return

    os.makedirs(frames_folder, exist_ok=True)
    # Delete all previous image frames:
    for f in os.listdir(frames_folder):
        os.remove(os.path.join(frames_folder, f))

    im_base = np.zeros((image_height, image_width, 3), dtype=np.uint8)

    # Draw the piano, and the grey lines next to the C's for the main area:
    key_start = image_height - piano_height
    white_key_end = image_height - 1
    black_key_end = round(image_height - (1 - black_key_height) * piano_height)

    im_lines = im_base.copy()

    for i in range(21, 109):
        if is_white_key(i):
            [x0, x1] = pixel_range(i, image_width)
            im_base[key_start:white_key_end, x0:x1] = [255, 255, 255]

        # Octave guide line just left of every C (C is white, so x0 is set above).
        if i % 12 == 0:
            im_lines[0:(key_start - 1), (x0 - 2):(x0 - 1)] = [20, 20, 20]

    # Black keys are drawn after all white keys so they sit on top.
    for i in range(21, 109):
        if not is_white_key(i):
            [x0, x1] = pixel_range(i, image_width)
            im_base[key_start:black_key_end, x0:x1] = [0, 0, 0]

    im_piano = im_base[key_start:white_key_end, :]

    im_frame = im_base.copy()
    im_frame += im_lines

    if end_t == 0.0:
        end_t = max(note["end"] for note in notes)

    frame_ct = 0
    pixel_start_rounded = 0
    while True:
        frame_ct += 1
        # Derive time and scroll offset from the frame index instead of
        # accumulating 1/fps, so float error cannot add or drop a frame.
        frame_time = frame_ct / fps
        prev_pixel_start_rounded = pixel_start_rounded
        pixel_start_rounded = round(frame_ct * pixels_per_frame)
        pixel_increment = pixel_start_rounded - prev_pixel_start_rounded

        # Scroll the main area down and refill the exposed top rows.
        im_frame[pixel_increment:main_height, :] = im_frame[0:(main_height - pixel_increment), :]
        im_frame[0:pixel_increment, :] = im_lines[0:pixel_increment, :]
        im_frame[key_start:white_key_end, :] = im_piano

        keys_to_color = [n["note"] for n in notes if n["start"] <= frame_time <= n["end"]]

        # First color the white keys (this covers some black-key pixels),
        # then re-draw the neighbouring black keys, then color black keys.
        for note in keys_to_color:
            if is_white_key(note):
                [x0, x1] = pixel_range(note, image_width)
                im_frame[key_start:white_key_end, x0:x1] = pressed_key_color

        for note in keys_to_color:
            if is_white_key(note):
                if (not is_white_key(note - 1)) and (note > 21):
                    [x0, x1] = pixel_range(note - 1, image_width)
                    im_frame[key_start:black_key_end, x0:x1] = [0, 0, 0]

                if (not is_white_key(note + 1)) and (note < 108):
                    [x0, x1] = pixel_range(note + 1, image_width)
                    im_frame[key_start:black_key_end, x0:x1] = [0, 0, 0]

        for note in keys_to_color:
            if not is_white_key(note):
                [x0, x1] = pixel_range(note, image_width)
                im_frame[key_start:black_key_end, x0:x1] = pressed_key_color

        PIL.Image.fromarray(im_frame).save(
            os.path.join(frames_folder, f"frame{frame_ct:05d}.png")
        )

        if frame_time >= end_t:
            break

    # Render audio. An argument list (not a whitespace-split string) keeps
    # paths containing spaces intact.
    wav_path = f'data/processed/{split}/wavs/' + midi_save_name + '.wav'
    subprocess.call([
        "timidity", input_midi, "-OwM", "--preserve-silence",
        "-s", str(sample_rate), "-A120", "-o", wav_path,
    ])

    # Force the WAV to exactly end_t seconds (pad with zeros or truncate).
    wav, sr = torchaudio.load(wav_path)
    target_len = int(sr * end_t)
    wav = pad_tensor(wav, (1, target_len))
    torchaudio.save(wav_path, wav, sr)

    spec = librosa.feature.melspectrogram(
        y=wav.numpy(),
        sr=sr,
        n_fft=4096,
        hop_length=512,
        win_length=None,
        window='hann',
        center=False,
        pad_mode='constant',
        power=2.0,
        fmin=1,
        # The highest piano key (MIDI 108) is ~4186 Hz, so 5 kHz covers the keyboard.
        fmax=5000,
        n_mels=88,  # one mel bin per piano key
    )
    spec = spec.squeeze()
    torch.save(spec, f'data/processed/{split}/spectrograms_pt/{midi_save_name}.pt')

    mp4_path = os.path.join(f"data/processed/{split}/videos", midi_save_name)
    ffmpeg_cmd = [
        "ffmpeg", "-loglevel", "quiet",
        "-framerate", str(fps), "-i", f"{frames_folder}/frame%05d.png",
        "-i", wav_path,
        "-f", "lavfi", "-t", str(end_t), "-i", "anullsrc",
        "-filter_complex", "[1]adelay=0|0[aud];[2][aud]amix",
        "-c:v", "libx264", "-vf", f"fps={fps}", "-pix_fmt", "yuv420p",
        "-y", "-strict", "-2", f"{mp4_path}.mp4",
    ]
    subprocess.call(ffmpeg_cmd)


def pad_tensor(tensor, target_shape):
    """Zero-pad or truncate ``tensor`` along its last axis to a target length.

    Args:
        tensor: a ``torch.Tensor``, e.g. a ``(channels, samples)`` waveform.
            Padding uses ``torch.nn.functional.pad``, so NumPy arrays are not
            accepted.
        target_shape: sequence whose last entry is the desired length of the
            last axis; the other entries are ignored.

    Returns:
        A tensor whose last axis has length ``target_shape[-1]``. The input
        object itself is returned when it already has that length.
    """
    current_len = tensor.shape[-1]
    target_len = target_shape[-1]
    if current_len < target_len:
        # Zeros are appended at the end of the last axis only.
        tensor = torch.nn.functional.pad(tensor, (0, target_len - current_len))
    elif current_len > target_len:
        tensor = tensor[..., :target_len]
    return tensor



# Collect each clip's PNG frames into one .pt tensor per clip, shaped (num_frames, height, width, 3).

def save_to_pt(path_to_midi, split="train"):
    """Stack each clip's PNG frames into ``data/processed/{split}/frames_pt/<name>.pt``.

    For every ``.mid`` file found (recursively) under ``path_to_midi``, the
    PNGs in ``data/processed/{split}/frames/<name>/`` are loaded in filename
    order and saved as one tensor of shape ``(num_frames, height, width, 3)``
    for the RGB frames written by ``create_video``.

    Clips whose frame folder is missing or holds no PNGs (e.g. because
    ``create_video`` returned early for them) are skipped.
    """
    for _root, _dirs, files in os.walk(path_to_midi):
        for file in tqdm.tqdm(files):
            if not file.endswith('.mid'):
                continue

            name = file.split('.')[0]
            frames_folder = f'data/processed/{split}/frames/' + name
            if not os.path.isdir(frames_folder):
                continue

            frames = []
            # os.listdir order is arbitrary; the zero-padded frame names make
            # lexicographic order equal to temporal order.
            for img_file in sorted(os.listdir(frames_folder)):
                if img_file.endswith('.png'):
                    img_path = os.path.join(frames_folder, img_file)
                    # np.array copies the pixels, so the file can be closed right away.
                    with PIL.Image.open(img_path) as img:
                        frames.append(torch.from_numpy(np.array(img)))

            if not frames:
                continue

            torch.save(torch.stack(frames), f'data/processed/{split}/frames_pt/' + name + '.pt')

def make_synthetic(N, split, max_notes=3, clip_length=2.0):
    """Write ``N`` random piano MIDI clips to ``data/processed/{split}/midi``.

    Also creates the output directories used by the later pipeline stages.
    Each clip ``syn_XXXXX.mid`` holds 1..``max_notes`` notes on program 0 at
    velocity 100. Each note starts at ``uniform(0.1, 0.7 * clip_length)`` s
    and ends at ``uniform(start + 0.15, clip_length - 0.1)`` s. The defaults
    keep that interval non-empty; very short ``clip_length`` values may not.
    Pitches lie in MIDI 21..108 (the 88 piano keys).

    Random draws come from NumPy's global RNG and ``torch.randperm`` in a
    fixed order, so seeding both makes the output reproducible.
    """
    for sub in ("midi", "videos", "wavs", "spectrograms", "spectrograms_pt", "frames_pt"):
        os.makedirs(f'data/processed/{split}/{sub}', exist_ok=True)

    note_count = np.random.randint(1, max_notes + 1, N)
    tot_notes = np.sum(note_count)
    # 88-key piano: MIDI 21 (A0) up to, but excluding, 109 (so C8 = 108 is the top).
    lowest_pitch = 21
    highest_pitch = 109
    pitches = (torch.randperm(tot_notes) % (highest_pitch - lowest_pitch)).numpy() + lowest_pitch

    velocity = 100
    min_duration = 0.15
    note_index = 0
    for i in range(N):
        midi = pretty_midi.PrettyMIDI()
        instrument = pretty_midi.Instrument(0)

        for _ in range(note_count[i]):
            start = np.random.uniform(0.1, 0.7 * clip_length)
            end = np.random.uniform(start + min_duration, clip_length - 0.1)
            instrument.notes.append(
                pretty_midi.Note(velocity=velocity, pitch=pitches[note_index], start=start, end=end)
            )
            note_index += 1

        midi.instruments.append(instrument)
        midi.write(f'data/processed/{split}/midi/syn_{i:05d}.mid')



import os
import tqdm
import threading

def synth_create_vids_worker(split, midi_files, image_width, image_height, piano_height, fps, segm_length, sample_rate, pbar):
    """Render every MIDI in ``midi_files`` that has no spectrogram yet.

    Intended to run in a thread started by ``synth_create_vids``. ``pbar``
    advances once per file, including files skipped because
    ``spectrograms_pt/<name>.pt`` already exists, so the bar reaches its total.
    """
    spec_dir = f'data/processed/{split}/spectrograms_pt'
    midi_dir = f'data/processed/{split}/midi'
    for midi_file in midi_files:
        spec_path = os.path.join(spec_dir, midi_file.replace('.mid', '.pt'))
        if not os.path.exists(spec_path):
            create_video(
                input_midi=os.path.join(midi_dir, midi_file),
                image_width=image_width,
                image_height=image_height,
                piano_height=piano_height,
                fps=fps,
                end_t=segm_length,
                sample_rate=sample_rate,
                split=split,
            )
        pbar.update(1)

def synth_create_vids(split, image_width=512, image_height=16, piano_height=16, segm_length=2, sample_rate=16000, fps=16, num_threads=8):
    """Render all ``.mid`` files of ``split`` with ``create_video`` across threads.

    Files are dealt round-robin to at most ``num_threads`` worker threads.
    A shared tqdm bar tracks progress over all files.
    """
    midi_files = sorted(
        file for file in os.listdir(f'data/processed/{split}/midi') if file.endswith('.mid')
    )

    # Round-robin split. The old len//num_threads chunk size was 0 (and range()
    # raised ValueError) whenever there were fewer files than threads.
    num_threads = max(1, num_threads)
    chunks = [midi_files[k::num_threads] for k in range(num_threads)]
    chunks = [chunk for chunk in chunks if chunk]

    with tqdm.tqdm(total=len(midi_files)) as pbar:
        threads = [
            threading.Thread(
                target=synth_create_vids_worker,
                args=(split, chunk, image_width, image_height, piano_height, fps, segm_length, sample_rate, pbar),
            )
            for chunk in chunks
        ]

        for thread in threads:
            thread.start()

        for thread in threads:
            thread.join()



In [None]:

if __name__ == '__main__':
    # Seed every RNG that make_synthetic draws from so the data is reproducible.
    torch.manual_seed(0)
    import random
    random.seed(0)
    import numpy as np
    np.random.seed(0)

    segm_length = 2  # clip length in seconds
    N_train, N_val, N_test = 200, 100, 100  # the test split is not generated yet

    for split, n_clips in (("train", N_train), ("val", N_val)):
        make_synthetic(n_clips, split)
        # With the renderer defaults: 32 RGB frames of 16x512 per 2 s clip at 16 fps.
        synth_create_vids(split, segm_length=segm_length)
        save_to_pt(f'data/processed/{split}/midi', split=split)

#DataLoader

In [None]:
import os
import os
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torchvision
from tqdm import tqdm
import torchaudio
from torchvision import transforms
#import wandb
import sys,os
import soundfile
# Device that collate_fn moves every batch onto ('cuda' when available, else 'cpu').
device = 'cuda' if torch.cuda.is_available() else 'cpu'

from transformers import SpeechT5FeatureExtractor


class CustomDataset(Dataset):
    """(frames, spectrogram, name) items for one split directory.

    ``root_dir`` must contain ``frames_pt/<name>.pt`` and
    ``spectrograms_pt/<name>.pt`` with matching file names. Items are
    ordered by file name.
    """

    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.frames = sorted(
            file for file in os.listdir(os.path.join(root_dir, 'frames_pt')) if file.endswith('.pt')
        )
        # Not applied in __getitem__; kept as an attribute for existing callers.
        self.transf = transforms.Compose([
            transforms.Resize((88, 55)),
        ])

    def __len__(self):
        return len(self.frames)

    def __getitem__(self, idx):
        """Return ``(frames, spectrogram, name)`` for item ``idx``.

        ``frames`` is scaled by 1/255 and permuted from the stored
        ``(F, H, W, C)`` layout to ``(1, F, C, H, W)``. ``spectrogram`` gets a
        leading axis of size 1 and is divided by its maximum. An all-zero
        spectrogram is returned unchanged instead of becoming NaN.
        """
        file_name = self.frames[idx]
        frames_path = os.path.join(self.root_dir, 'frames_pt', file_name)
        spectrogram_path = os.path.join(self.root_dir, 'spectrograms_pt', file_name)

        frames = torch.load(frames_path).unsqueeze(0)
        frames = frames.float() / 255
        # (1, F, H, W, C) -> (1, F, C, H, W)
        frames = frames.permute(0, 1, 4, 2, 3)

        # create_video stores a NumPy array; as_tensor also accepts a tensor.
        spectrogram = torch.as_tensor(torch.load(spectrogram_path)).unsqueeze(0)
        peak = torch.max(spectrogram)
        if peak > 0:
            spectrogram = spectrogram / peak

        name = file_name.replace('.pt', '')
        return frames, spectrogram, name
def collate_fn(batch):
    """Merge ``(frames, spectrogram, name)`` items into one batch dict.

    Tensors are stacked along dim 0 with ``torch.vstack`` and moved to the
    module-level ``device``; names are kept as a tuple.
    """
    frame_seq, spec_seq, names = zip(*batch)
    stacked_frames = torch.vstack(frame_seq).to(device=device)
    stacked_specs = torch.vstack(spec_seq).to(device=device)
    return {'frames': stacked_frames, 'spectrogram': stacked_specs, 'name': names}


# Build the train/val datasets and loaders.
batch_size = 8
train_dataset = CustomDataset(root_dir='data/processed/train')
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_dataset = CustomDataset(root_dir='data/processed/val')
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
print('Train data N:', len(train_dataset))
print('Val data N:', len(val_dataset))

# Sanity-check one sample. Items are (frames, spectrogram, name); unpacking a
# fourth value here previously raised ValueError.
fram, spe, name = train_dataset[0]

print(fram.shape)
print(spe.shape)
print(name)


# Model

In [None]:
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import numpy as np

from transformers.models.vivit.modeling_vivit import VivitModel, VivitConfig, VivitLayer, VivitEncoder
from transformers.models.speecht5.modeling_speecht5 import SpeechT5Decoder, SpeechT5Config, SpeechT5SpeechDecoderPostnet

class VivitTubeletEmbeddings(nn.Module):
    """
    Construct Vivit Tubelet embeddings.

    Turns a batch of videos of shape (batch_size, num_frames, num_channels, height, width)
    into a tensor of shape (batch_size, seq_len, hidden_size) using a Conv3d whose kernel
    and stride both equal ``config.tubelet_size``.

    ``config.image_size`` is ``(height, width)`` and ``config.tubelet_size`` is ``(t, h, w)``;
    seq_len = (num_frames // t) * (height // h) * (width // w).
    """

    def __init__(self, config):
        super().__init__()
        self.num_frames = config.num_frames
        self.image_size = config.image_size
        self.patch_size = config.tubelet_size
        self.num_patches = (
            (self.image_size[1] // self.patch_size[2])  # patches along width
            * (self.image_size[0] // self.patch_size[1])  # patches along height
            * (self.num_frames // self.patch_size[0])  # patches along time
        )
        self.embed_dim = config.hidden_size

        self.projection = nn.Conv3d(
            config.num_channels, config.hidden_size, kernel_size=config.tubelet_size, stride=config.tubelet_size
        )

    def forward(self, pixel_values):
        batch_size, num_frames, num_channels, height, width = pixel_values.shape
        if height != self.image_size[0] or width != self.image_size[1]:
            raise ValueError(
                f"Input image size ({height},{width}) doesn't match model "
                f"({self.image_size[0]},{self.image_size[1]})."
            )

        # permute to (batch_size, num_channels, num_frames, height, width)
        pixel_values = pixel_values.permute(0, 2, 1, 3, 4)

        # One convolution, then (B, hidden, T', H', W') -> (B, T'*H'*W', hidden).
        return self.projection(pixel_values).flatten(2).transpose(1, 2)


class VivitEmbeddings(nn.Module):
    """
    Vivit embeddings without a CLS token.

    Tubelet embeddings plus a learned, zero-initialised positional embedding,
    followed by dropout. ``cls_token`` is still created, so the parameter set
    and state-dict keys are unchanged, but ``forward`` never uses it: the
    model regresses a spectrogram rather than predicting a class.
    """

    def __init__(self, config):
        super().__init__()

        hidden = config.hidden_size
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden))
        self.patch_embeddings = VivitTubeletEmbeddings(config)
        n_tokens = self.patch_embeddings.num_patches
        # One positional vector per patch token (no extra slot for a CLS token).
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_tokens, hidden))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.config = config

    def forward(self, pixel_values):
        tokens = self.patch_embeddings(pixel_values)
        return self.dropout(tokens + self.position_embeddings)

class VivitPooler(nn.Module):
    """Pool a token sequence by passing its first token through Linear + Tanh.

    Not used by ``SimpleModel`` in this file.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # Only the token at position 0 of the sequence axis contributes.
        return self.activation(self.dense(hidden_states[:, 0]))

class SimpleModel(nn.Module):
    """ViViT encoder followed by a linear head that predicts a spectrogram.

    ``forward`` runs ``embeddings`` then ``encoder``. It max-pools the
    encoder's ``last_hidden_state`` by 4 along the hidden axis, flattens it,
    and maps it with one linear layer to ``out_shape``. ``decoder`` and
    ``decoder_postnet`` are registered as submodules, so they appear in
    ``parameters()`` and the state dict, but ``forward`` does not use them.

    Args:
        embeddings, encoder, decoder, decoder_postnet: submodules; ``encoder``
            must return an object with a ``last_hidden_state`` attribute.
        in_features: flattened size after pooling, i.e.
            ``seq_len * hidden_size // 4`` (default 256 * 64 = 16384,
            e.g. 128 tokens x 512 hidden).
        out_shape: ``(n_mels, n_frames)`` of the predicted spectrogram.
    """

    def __init__(self, embeddings, encoder, decoder, decoder_postnet, in_features=256 * 64, out_shape=(88, 55)):
        super().__init__()

        self.embeddings = embeddings
        self.encoder = encoder
        self.decoder = decoder
        self.decoder_postnet = decoder_postnet
        self.out_shape = tuple(out_shape)

        # On a 3-D (B, seq_len, hidden) tensor MaxPool2d treats B as channels,
        # so kernel (1, 4) pools only the hidden axis.
        self.pool1 = torch.nn.MaxPool2d((1, 4))
        self.mlp1 = torch.nn.Linear(in_features, self.out_shape[0] * self.out_shape[1])

        # Not applied in forward; kept for interface compatibility.
        self.activation = nn.Sigmoid()

    def forward(self, x):
        """Map a (B, frames, C, H, W) video batch to a (B, *out_shape) spectrogram."""
        batch_size = x.shape[0]
        x = self.embeddings(x)
        # No CLS token is prepended, so every encoder token is kept.
        x = self.encoder(x)
        x = self.pool1(x.last_hidden_state).flatten(1)
        return self.mlp1(x).reshape(batch_size, *self.out_shape)



if __name__ == "__main__":
    # Fall back to CPU when no GPU is present (.cuda() was hard-coded before).
    model_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Dummy batch: (batch, frames, channels, height, width).
    img = torch.ones([2, 32, 3, 16, 512], device=model_device)

    vivit_dict = {
        "image_size": (16, 512),
        "num_frames": 32,
        "tubelet_size": [4, 16, 32],
        "num_channels": 3,
        "hidden_size": 512,
        "num_hidden_layers": 4,
        "num_attention_heads": 256,
        "intermediate_size": 512,
        "hidden_act": 'gelu_fast',
        "hidden_dropout_prob": 0.0,
        "attention_probs_dropout_prob": 0.0,
        "initializer_range": 0.02,
        "layer_norm_eps": 1e-06,
        "qkv_bias": True,
    }
    config_vivit = VivitConfig(**vivit_dict)

    conf_dict = {
        "num_mel_bins": 88,
        "activation_dropout": 0.1,
        "attention_dropout": 0.1,
        "decoder_attention_heads": 32,
        "decoder_ffn_dim": 512,
        "decoder_layerdrop": 0.1,
        "decoder_layers": 2,
        "decoder_start_token_id": 2,
        "hidden_act": "gelu",
        "hidden_dropout": 0.1,
        "hidden_size": 512,
        "is_encoder_decoder": True,
        "layer_norm_eps": 1e-05,
        "mask_feature_length": 4,
        "mask_feature_min_masks": 0,
        "mask_feature_prob": 0.0,
        "mask_time_length": 4,
        "mask_time_min_masks": 2,
        "positional_dropout": 0.1,
        "transformers_version": "4.40.1",
        "use_guided_attention_loss": True,
    }
    embeddings = VivitEmbeddings(config_vivit)
    encoder = VivitEncoder(config_vivit)
    config_speecht5 = SpeechT5Config(**conf_dict)

    decoder = SpeechT5Decoder(config_speecht5)
    decoder_postnet = SpeechT5SpeechDecoderPostnet(config_speecht5)

    model = SimpleModel(embeddings, encoder, decoder, decoder_postnet).to(model_device)

    parameters = filter(lambda p: p.requires_grad, model.parameters())
    parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
    print('Trainable Parameters: %.3fM' % parameters)

    # Shape check only, so no autograd graph is needed.
    with torch.no_grad():
        out = model(img)

    print("Shape of out :", out.shape)      # [B, 88, 55]
    print("dtype of out :", out.dtype)      # float32



#Training

In [None]:
from tqdm import tqdm
import librosa
import numpy as np
import matplotlib.pyplot as plt
import wandb
#wandb.login(key='<yourapikeyhere>') # or disable wandb if not configured.



# Optimiser and LR schedule (gamma applied every `lr_update` steps below).
optimizer = torch.optim.AdamW(model.parameters(), weight_decay=1e-6, lr=1e-4)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000
print('Trainable Parameters: %.3fM' % n_params)

# Regression loss on spectrogram values.
criterion = nn.L1Loss()

wandb.init(
    project="test_aisynth_v2",
    name="vid2audio",
    job_type="training",
    reinit=True)

# %% Fit the model
epochs = 200
train_losses = []
val_losses = []
step = 0
val_interval = 100
save_interval = 1000
lr_update = 1000
plot_interval = 100


def _plot_pair(pred, ref, titles):
    """Show a predicted and a reference spectrogram side by side, then close the figure."""
    fig, axes = plt.subplots(1, 2)
    # Arrays are (n_mels, n_frames); extent is (left, right, bottom, top).
    h, w = pred.shape[1], pred.shape[0]
    for ax, arr, title in zip(axes, (pred, ref), titles):
        ax.imshow(arr, aspect='auto', extent=(0, h, 0, w))
        ax.set_title(title)
    plt.show()
    # Close explicitly so figures do not pile up over thousands of steps.
    plt.close(fig)


val_loss = 10000.0
for epoch in range(epochs):
    train_epoch_loss = []
    val_epoch_loss = []
    with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{epochs} - Train") as pbar:
        for dat in train_loader:
            model.train()
            optimizer.zero_grad()
            out = model(dat['frames'])
            loss = criterion(out, dat['spectrogram'])
            loss_value = loss.item()
            wandb.log({"train_loss": loss_value}, step=step)
            loss.backward()
            optimizer.step()
            train_epoch_loss.append(loss_value)
            pbar.set_postfix({'train_loss': f'{loss_value:.4f}', 'val_loss': f'{val_loss:.4f}'})
            pbar.update()

            if step % plot_interval == 0:
                _plot_pair(
                    out[0].detach().cpu().numpy(),
                    dat['spectrogram'][0].detach().cpu().numpy(),
                    ('predicted', 'GT'),
                )

            if step % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    # A fresh iterator yields one randomly shuffled validation batch.
                    val_dat = next(iter(val_loader))
                    val_out = model(val_dat['frames'])
                    val_loss = criterion(val_out, val_dat['spectrogram']).item()
                val_epoch_loss.append(val_loss)

                _plot_pair(
                    val_out[0].cpu().numpy(),
                    val_dat['spectrogram'][0].cpu().numpy(),
                    ('validation predicted', 'validation GT'),
                )

                wandb.log({"val_loss": val_loss}, step=step)

            step += 1
            if step % save_interval == 0:
                torch.save(model.state_dict(), f'ai_synth_model_works_{step}.pth')

            if step % lr_update == 0:
                scheduler.step()

    # Log epoch averages at an explicit step. Step-less wandb.log calls advance
    # wandb's internal counter past the next `step` used above, which makes
    # wandb drop the following training logs. Validation does not run every
    # epoch, so skip empty averages instead of logging NaN.
    epoch_log = {}
    if train_epoch_loss:
        epoch_log["avg_train_loss"] = float(np.mean(train_epoch_loss))
    if val_epoch_loss:
        epoch_log["avg_val_loss"] = float(np.mean(val_epoch_loss))
    if epoch_log:
        wandb.log(epoch_log, step=step)

#1epoch train mean loss:6.293105602264404
#1epoch val mean loss:8.45799922943115

