In [18]:
import sys
import torch
import clip
import warnings
import torchaudio

from IPython.display import Audio
from pathlib import Path

warnings.filterwarnings("ignore")

In [2]:
# Project root is taken to be the parent of the current working directory
# (the directory the notebook kernel was started in).
ROOT = Path(".").resolve().parents[0]

# Register the root on sys.path only once, so re-running this cell does not
# pile up duplicate entries; this is what lets the `src.*` imports resolve.
if not any(entry == str(ROOT) for entry in sys.path):
    sys.path.append(str(ROOT))

In [3]:
# Sampling rate (Hz) shared by dataset loading, playback and export.
SAMPLE_RATE = 44100

# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [4]:
from src.dataset.base.audio_and_caption import AudioTextGenerationDataset
from src.model.clip_audio_text_generator import AudioCLIPTextGenerator

## Загрузим набор данных

In [5]:
# One-item evaluation set: a single MP3 under <ROOT>/data/test paired with an
# empty caption string.
dataset_root = ROOT / 'data' / 'test'
test_audio_file = dataset_root / 'John_Towner_Williams_London_Symphony_Orchestra_-_The_Imperial_March_Darth_Vaders_Theme_51136533.mp3'

# NOTE(review): `duration` is presumably in seconds and `channel` presumably
# selects which audio channel to keep -- confirm against AudioTextGenerationDataset.
dataset = AudioTextGenerationDataset(
    [test_audio_file],
    [''],
    channel=0,
    duration=10,
    sample_rate=SAMPLE_RATE,
)

In [6]:
# Take the only item of the dataset; indexing yields an (audio, caption) pair.
# Later cells call len(), .numpy() and .unsqueeze(0) on `test_music`, so it is
# presumably a 1-D waveform tensor -- confirm against the dataset class.
test_music, test_caption = dataset[0]

Послушаем полученный кусочек аудио

In [7]:
# Report the clip length (number of samples divided by the sampling rate) ...
print(f'Test music duration {len(test_music) / SAMPLE_RATE:0.2f} sec')
# ... and leave the player widget as the cell's displayed value.
Audio(data=test_music.numpy(), rate=SAMPLE_RATE)

Test music duration 10.00 sec


In [20]:
# Write the clip to the working directory; a leading axis is added before
# saving (presumably the channel axis torchaudio expects -- confirm).
torchaudio.save(
    'test_audio.mp3',
    test_music[None],
    SAMPLE_RATE,
)

## Сгенерируем текст

In [9]:
def run(audio, cond_text="Music is about", beam_size=5, **kwargs):
    """Generate captions for ``audio`` and print the one CLIP scores highest.

    A new ``AudioCLIPTextGenerator`` is constructed from ``kwargs`` on every
    call, moved to the notebook-level ``device`` and put in eval mode. The
    audio is embedded once, candidate captions are generated from that
    embedding, and each caption is then re-encoded with the generator's CLIP
    text encoder and ranked by its dot product with the audio embedding.

    Args:
        audio: Audio tensor handed to ``get_audio_feature`` after being moved
            to ``device``. The notebook passes ``test_music.unsqueeze(0)``
            (presumably a batch holding one waveform -- confirm against the
            generator's expected input shape).
        cond_text: Prompt prefix forwarded to the generator; it is also
            prepended to the winning caption when printing.
        beam_size: Beam width forwarded to ``text_generator.run``.
        **kwargs: Forwarded unchanged to ``AudioCLIPTextGenerator``.

    Returns:
        None. The candidate captions and the best-ranked caption are printed.
    """
    # NOTE(review): the model is rebuilt (and its checkpoint presumably
    # reloaded) on every call; cache it outside if this becomes a bottleneck.
    text_generator = AudioCLIPTextGenerator(**kwargs)
    text_generator = text_generator.to(device)
    text_generator.eval()

    with torch.no_grad():
        audio_features = text_generator.get_audio_feature(audio.to(device))

    # Generation deliberately stays outside no_grad, exactly as before.
    # NOTE(review): the generator may rely on autograd internally -- confirm
    # before wrapping this call.
    captions = text_generator.run(audio_features, cond_text, beam_size=beam_size)

    # Show the candidates first so they stay visible even if ranking fails.
    print(captions)
    if len(captions) == 0:
        # Nothing to rank; torch.cat on an empty list would raise.
        return

    # Ranking only needs an argmax, so no autograd graph should be recorded
    # for the text encoder forward passes.
    with torch.no_grad():
        encoded_captions = [
            text_generator.clip.encode_text(clip.tokenize(c).to(text_generator.device))
            for c in captions
        ]
        # L2-normalise each caption embedding so the score does not depend on
        # the embedding's magnitude.
        encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions]
        similarity = torch.cat(encoded_captions) @ audio_features.t()
        best_clip_idx = similarity.squeeze().argmax().item()

    print('best clip:', cond_text + captions[best_clip_idx])

In [10]:
# Caption the test clip using run()'s default prompt.
run(
    audio=test_music.unsqueeze(0),
    beam_size=3,
    clip_checkpoints=(ROOT / "data" / "models" / "AudioCLIP-Full-Training.pt").resolve(),
    forbidden_tokens_file_path=(ROOT / "data" / "forbidden_tokens.npy").resolve(),
)

18/06/2024 02:07:35 | [' celebration %% -3.4961832', ' prayer %% -3.6967826', ' courage %% -3.7132688']
18/06/2024 02:07:46 | [' courage. %% -3.1188593', ' courage and %% -3.846805', ' celebration. %% -4.1435647']
18/06/2024 02:07:57 | [' courage.! %% -3.1188593', ' courage and fear %% -4.0269346', ' celebration.! %% -4.1435647']
18/06/2024 02:08:09 | [' courage.!! %% -3.1188593', ' courage and fear. %% -3.9670525', ' courage and fear is %% -3.9976895']
18/06/2024 02:08:20 | [' courage.!!! %% -3.1188593', ' courage and fear is about %% -3.9168065', ' courage and fear.! %% -3.9670525']
18/06/2024 02:08:32 | [' courage.!!!! %% -3.1188593', ' courage and fear.!! %% -3.9670525', ' courage and fear is about weakness %% -4.18147']
18/06/2024 02:08:43 | [' courage.!!!!! %% -3.1188593', ' courage and fear.!!! %% -3.9670525', ' courage and fear is about weakness. %% -4.177792']
[' courage.', ' courage and fear.', ' courage and fear is about weakness.']
best clip: Music is about courage and fear

In [11]:
# Same clip, prompt variant "Music of a".
run(
    audio=test_music.unsqueeze(0),
    cond_text="Music of a",
    beam_size=3,
    clip_checkpoints=(ROOT / "data" / "models" / "AudioCLIP-Full-Training.pt").resolve(),
    forbidden_tokens_file_path=(ROOT / "data" / "forbidden_tokens.npy").resolve(),
)

18/06/2024 02:09:57 | [' bird %% -4.173117', ' tiger %% -4.276628', ' band %% -4.3400946']
18/06/2024 02:10:08 | [' band rehearsal %% -4.273383', ' tiger rescued %% -4.346455', ' tiger escaped %% -4.3499928']
18/06/2024 02:10:20 | [' band rehearsal DVD %% -3.8559265', ' band rehearsal tape %% -4.2036576', ' band rehearsal CD %% -4.298081']
18/06/2024 02:10:31 | [' band rehearsal DVD - %% -4.2174225', ' band rehearsal DVD DVD %% -4.3128004', ' band rehearsal DVD 1 %% -4.3162885']
18/06/2024 02:10:43 | [' band rehearsal DVD - Radio %% -4.3328576', ' band rehearsal DVD - The %% -4.387536', ' band rehearsal DVD - Black %% -4.400156']
18/06/2024 02:10:54 | [' band rehearsal DVD - Radio City %% -4.343035', ' band rehearsal DVD - Radio Australia %% -4.3759', ' band rehearsal DVD - The Beast %% -4.4126134']
18/06/2024 02:11:06 | [' band rehearsal DVD - Radio City Fire %% -4.159396', ' band rehearsal DVD - Radio City Rock %% -4.38075', ' band rehearsal DVD - Radio City Girls %% -4.4476247']
18/

In [12]:
# Same clip, prompt variant "Music feelings is a".
run(
    audio=test_music.unsqueeze(0),
    cond_text="Music feelings is a",
    beam_size=3,
    clip_checkpoints=(ROOT / "data" / "models" / "AudioCLIP-Full-Training.pt").resolve(),
    forbidden_tokens_file_path=(ROOT / "data" / "forbidden_tokens.npy").resolve(),
)

18/06/2024 02:13:19 | [' beautiful %% -3.8995214', ' gorgeous %% -4.2534127', ' dark %% -4.4035306']
18/06/2024 02:13:30 | [' gorgeous thing %% -4.131052', ' beautiful euph %% -4.1349516', ' dark horse %% -4.194648']
18/06/2024 02:13:42 | [' beautiful euphoric %% -3.822873', ' beautiful euphor %% -4.086774', ' gorgeous thing in %% -4.213611']
18/06/2024 02:13:53 | [' beautiful euphoric euph %% -3.7051103', ' gorgeous thing in Guitar %% -3.982268', ' beautiful euphorbia %% -3.9830523']
18/06/2024 02:14:05 | [' beautiful euphoric euph. %% -3.9332318', ' beautiful euphoric euph- %% -3.9447262', ' beautiful euphoric euphoria %% -3.9972591']
18/06/2024 02:14:18 | [' beautiful euphoric euph.! %% -3.9332318', ' beautiful euphoric euph-inducing %% -3.9769704', ' beautiful euphoric euph-like %% -4.018782']
18/06/2024 02:14:30 | [' beautiful euphoric euph-inducing synth %% -3.8927255', ' beautiful euphoric euph.!! %% -3.9332318', ' beautiful euphoric euph-inducing pedal %% -3.9697692']
18/06/202

In [13]:
# Same clip, prompt variant "Music genre of a".
run(
    audio=test_music.unsqueeze(0),
    cond_text="Music genre of a",
    beam_size=3,
    clip_checkpoints=(ROOT / "data" / "models" / "AudioCLIP-Full-Training.pt").resolve(),
    forbidden_tokens_file_path=(ROOT / "data" / "forbidden_tokens.npy").resolve(),
)

18/06/2024 02:16:17 | [' band %% -3.88393', ' country %% -4.149632', ' musical %% -4.2330146']
18/06/2024 02:16:30 | [' band Country %% -4.1342993', ' band American %% -4.4776273', ' band B %% -4.4825993']
18/06/2024 02:16:42 | [' band American Football %% -4.426455', ' band Burden %% -4.5843377', ' band B. %% -4.6668324']
18/06/2024 02:16:54 | [' band American Football Panthers %% -4.304871', ' band American Football Buffalo %% -4.4195013', ' band American Football Steelers %% -4.5569205']
18/06/2024 02:17:06 | [' band American Football Buffalo Bills %% -4.2074537', ' band American Football Buffalo Bears %% -4.3054986', ' band American Football Buffalo Bulls %% -4.394666']
18/06/2024 02:17:18 | [' band American Football Buffalo Bills 49 %% -4.3069572', ' band American Football Buffalo Bears 49 %% -4.309338', ' band American Football Buffalo Bills Patriots %% -4.343117']
18/06/2024 02:17:31 | [' band American Football Buffalo Bills 49 49 %% -4.2441406', ' band American Football Buffalo

In [21]:
# Same clip, prompt variant "Music genre is a".
run(
    audio=test_music.unsqueeze(0),
    cond_text="Music genre is a",
    beam_size=3,
    clip_checkpoints=(ROOT / "data" / "models" / "AudioCLIP-Full-Training.pt").resolve(),
    forbidden_tokens_file_path=(ROOT / "data" / "forbidden_tokens.npy").resolve(),
)

18/06/2024 02:33:38 | [' strange %% -4.2177277', ' vast %% -4.458797', ' funny %% -4.6704655']
18/06/2024 02:33:50 | [' funny purse %% -3.8416684', ' funny dog %% -4.4628477', ' strange bag %% -4.5309205']
18/06/2024 02:34:02 | [' funny purse. %% -3.868821', ' funny dog. %% -4.031333', ' funny purse- %% -4.1967435']
18/06/2024 02:34:14 | [' funny purse.! %% -3.868821', ' funny dog.! %% -4.031333', ' funny purse-draw %% -4.471611']
18/06/2024 02:34:26 | [' funny purse.!! %% -3.868821', ' funny dog.!! %% -4.031333', ' funny purse-drawing %% -4.5101423']
18/06/2024 02:34:38 | [' funny purse.!!! %% -3.868821', ' funny dog.!!! %% -4.031333', ' funny purse-drawing exercise %% -4.473829']
18/06/2024 02:34:51 | [' funny purse.!!!! %% -3.868821', ' funny dog.!!!! %% -4.031333', ' funny purse-drawing exercise. %% -4.4295206']
[' funny purse.', ' funny dog.', ' funny purse-drawing exercise.']
best clip: Music genre is a funny purse.


In [15]:
# Same clip, prompt variant "Audio is about".
run(
    audio=test_music.unsqueeze(0),
    cond_text="Audio is about",
    beam_size=3,
    clip_checkpoints=(ROOT / "data" / "models" / "AudioCLIP-Full-Training.pt").resolve(),
    forbidden_tokens_file_path=(ROOT / "data" / "forbidden_tokens.npy").resolve(),
)

18/06/2024 02:20:06 | [' the %% -3.5416498', ' emotion %% -4.161348', ' power %% -4.511094']
18/06/2024 02:20:31 | [' the cru %% -2.926023', ' emotion and %% -3.9370146', ' emotion. %% -4.0997877']
18/06/2024 02:20:57 | [' the crux %% -3.396806', ' the cru ship %% -3.551814', ' the cru. %% -3.6555882']
18/06/2024 02:21:23 | [' the cru shipwreck %% -3.3051543', ' the crux of %% -3.3481336', ' the crux. %% -3.4974809']
18/06/2024 02:21:48 | [' the crux of of %% -2.9628615', ' the crux.! %% -3.4974809', ' the cru shipwreck discovery %% -3.4995675']
18/06/2024 02:22:14 | [' the crux of of David %% -2.8040333', ' the crux of of Samuel %% -2.8908498', ' the crux of of the %% -3.13797']
18/06/2024 02:22:40 | [" the crux of of Samuel's %% -3.0633729", ' the crux of of the Torah %% -3.0991924', " the crux of of David's %% -3.136517"]
18/06/2024 02:23:05 | [' the crux of of the Torah. %% -3.1681175', " the crux of of David's death %% -3.2768812", ' the crux of of the Torah story %% -3.295949']
1

In [16]:
# Same clip, prompt variant "Audio of a".
run(
    audio=test_music.unsqueeze(0),
    cond_text="Audio of a",
    beam_size=3,
    clip_checkpoints=(ROOT / "data" / "models" / "AudioCLIP-Full-Training.pt").resolve(),
    forbidden_tokens_file_path=(ROOT / "data" / "forbidden_tokens.npy").resolve(),
)

18/06/2024 02:24:09 | [' firefighter %% -2.9142501', ' helicopter %% -3.067083', ' Halifax %% -3.0941613']
18/06/2024 02:24:35 | [' Halifax Dartmouth %% -3.4644177', ' helicopter rescue %% -3.6259506', ' firefighter removing %% -3.857686']
18/06/2024 02:25:01 | [' Halifax Dartmouth Israel %% -3.2367735', ' Halifax Dartmouth Israeli %% -3.2450016', ' Halifax Dartmouth Indian %% -3.6673706']
18/06/2024 02:25:27 | [' Halifax Dartmouth Israeli community %% -3.6403575', ' Halifax Dartmouth Israeli Jewish %% -3.668097', ' Halifax Dartmouth Israeli Religious %% -3.6701286']
18/06/2024 02:25:54 | [' Halifax Dartmouth Israeli Jewish Police %% -3.7589269', ' Halifax Dartmouth Israeli community meeting %% -3.816853', ' Halifax Dartmouth Israeli Jewish Agency %% -3.9244378']
18/06/2024 02:26:20 | [' Halifax Dartmouth Israeli community meeting. %% -3.8707902', ' Halifax Dartmouth Israeli Jewish Police officer %% -3.9115229', ' Halifax Dartmouth Israeli community meeting recorded %% -3.948745']
18/0