In [1]:
import torch
import os

os.environ['CUDA_VISIBLE_DEVICES']='0,1,2'
available_gpus = [torch.device(i) for i in range(torch.cuda.device_count())]
print(len(available_gpus))

import random
import numpy as np
from model_multigpu import SALMONN
from collections import Counter
from pprint import pprint
import time
import librosa
import soundfile as sf

3

Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link
CUDA SETUP: CUDA runtime path found: /mnt/cs/voice/malykh-s/conda/miniconda3/envs/salmonn/lib/libcudart.so
CUDA SETUP: Highest compute capability among GPUs detected: 6.1
CUDA SETUP: Detected CUDA version 118
CUDA SETUP: Loading binary /mnt/cs/voice/malykh-s/conda/miniconda3/envs/salmonn/lib/python3.9/site-packages/bitsandbytes/libbitsandbytes_cuda118_nocublaslt.so...


In [2]:
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
device = available_gpus[0]
delim = '# ' * 90

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def count_parameters_all(model):
    return sum(p.numel() for p in model.parameters())

def get_llama_map(layers, zero_layers, n_gpus):
    n_layers = len(layers)
    n_layers_dist = n_layers - zero_layers
    needed_count = n_layers_dist / (n_gpus - 1)
    curr_gpu, curr_count = 1, 0
    llama_map = dict()
    for i in range(len(layers)):
        layer = layers[i]
        if i < zero_layers:
            llama_map[layer] = 0
            continue
        if curr_count < needed_count:
            llama_map[layer] = curr_gpu
            curr_count += 1
        else:
            curr_count = 0
            curr_gpu = min(n_gpus, curr_gpu + 1)
            llama_map[layer] = curr_gpu
            curr_count += 1
    return llama_map

In [3]:
ckpt_path = '/mnt/cs/voice/malykh-s/PycharmProjects/SALMONN/SALMONN-7B/salmonn_7b_v0.pth'
whisper_path = '/mnt/cs/voice/malykh-s/PycharmProjects/SALMONN/whisper-large-v2/'
beats_path = '/mnt/cs/voice/malykh-s/PycharmProjects/SALMONN/BEATs/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt'
vicuna_path = '/mnt/cs/voice/malykh-s/PycharmProjects/SALMONN/vicuna-7b-v1.5/'
lora_alpha = 32
low_resource = False
llama_layers = ['model.embed_tokens'] + [f'model.layers.{i}' for i in range(0, 32)] + ['model.norm', 'lm_head']
llama_map = get_llama_map(llama_layers, zero_layers=1, n_gpus=len(available_gpus))

In [4]:
model = SALMONN(
    ckpt=ckpt_path,
    whisper_path=whisper_path,
    beats_path=beats_path,
    vicuna_path=vicuna_path,
    lora_alpha=lora_alpha,
    low_resource=low_resource,
    encoder_device=available_gpus[0],
    llama_device_map=llama_map,
    llama_dtype=torch.bfloat16
)
print(count_parameters(model))
print(count_parameters_all(model))
model.eval()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

696681274
7530008249


SALMONN(
  (speech_encoder): WhisperEncoder(
    (conv1): Conv1d(80, 1280, kernel_size=(3,), stride=(1,), padding=(1,))
    (conv2): Conv1d(1280, 1280, kernel_size=(3,), stride=(2,), padding=(1,))
    (embed_positions): Embedding(1500, 1280)
    (layers): ModuleList(
      (0-31): 32 x WhisperEncoderLayer(
        (self_attn): WhisperAttention(
          (k_proj): Linear(in_features=1280, out_features=1280, bias=False)
          (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
          (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
          (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (activation_fn): GELUActivation()
        (fc1): Linear(in_features=1280, out_features=5120, bias=True)
        (fc2): Linear(in_features=5120, out_features=1280, bias=True)
        (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_a

In [5]:
total_bits = 0
for key in model._modules:
    print(key.upper())
    print('Trainable parameter count: ', f'{count_parameters(model._modules[key]):_}')
    print('All parameter count: ', f'{count_parameters_all(model._modules[key]):_}')
    dtypes = []
    bits = 0
    for elem in model._modules[key].parameters():
        dtypes.append(elem.dtype)
        bits_sub = 1
        for part in elem.shape:
            bits_sub *= part
        if elem.dtype == torch.int8:
            bits += bits_sub * 8
        elif elem.dtype == torch.float16:
            bits += bits_sub * 16
        elif elem.dtype == torch.bfloat16:
            bits += bits_sub * 16
        elif elem.dtype == torch.float32:
            bits += bits_sub * 32
        else:
            print('bad')
    print(Counter(dtypes).most_common())
    print(f'Size: {bits/(2**23):.2f}MiB')
    print(delim)
    total_bits += bits
print(f'Size: {total_bits/(2**23):.2f}MiB')

SPEECH_ENCODER
Trainable parameter count:  636_784_640
All parameter count:  636_784_640
[(torch.float32, 487)]
Size: 2429.14MiB
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 
LN_SPEECH
Trainable parameter count:  2_560
All parameter count:  2_560
[(torch.float32, 2)]
Size: 0.01MiB
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 
BEATS
Trainable parameter count:  0
All parameter count:  90_717_055
[(torch.float32, 241)]
Size: 346.06MiB
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 
LN_AUDIO
Trainable parameter count:  1_536
All parameter count:  1_536
[(torch.float32, 2)]
Size: 0.01MiB
# # # #

In [6]:
model.llama_model.hf_device_map

{'model.embed_tokens': 0,
 'model.layers.0': 1,
 'model.layers.1': 1,
 'model.layers.2': 1,
 'model.layers.3': 1,
 'model.layers.4': 1,
 'model.layers.5': 1,
 'model.layers.6': 1,
 'model.layers.7': 1,
 'model.layers.8': 1,
 'model.layers.9': 1,
 'model.layers.10': 1,
 'model.layers.11': 1,
 'model.layers.12': 1,
 'model.layers.13': 1,
 'model.layers.14': 1,
 'model.layers.15': 1,
 'model.layers.16': 1,
 'model.layers.17': 2,
 'model.layers.18': 2,
 'model.layers.19': 2,
 'model.layers.20': 2,
 'model.layers.21': 2,
 'model.layers.22': 2,
 'model.layers.23': 2,
 'model.layers.24': 2,
 'model.layers.25': 2,
 'model.layers.26': 2,
 'model.layers.27': 2,
 'model.layers.28': 2,
 'model.layers.29': 2,
 'model.layers.30': 2,
 'model.layers.31': 2,
 'model.norm': 2,
 'lm_head': 2}

In [7]:
wav_paths = ['/mnt/cs/voice/malykh-s/PycharmProjects/SALMONN/SALMONN-7B/wav_examples/tada.wav', 
             '/mnt/cs/voice/malykh-s/PycharmProjects/SALMONN/SALMONN-7B/wav_examples/cartoon_monkey.wav',
             '/mnt/cs/voice/malykh-s/PycharmProjects/SALMONN/SALMONN-7B/wav_examples/machine_gun.wav']
prompts = ['Describe the following sound in great detail', 
           'What animal is laughing like this? Provide datailed reasoning for you answer', 
           'What gun is this? Tell me about its characteristics']
wavs = []
for wav_path in wav_paths:
    wav, sr = sf.read(wav_path)
    if len(wav.shape) == 2:
        wav = wav[:, 0]
    if len(wav) > 30 * sr:
        wav = wav[: 30 * sr]
    if sr != 16000:
        wav = librosa.resample(wav, orig_sr=sr, target_sr=16000, res_type="fft")
    wavs.append(wav)

In [8]:
total_time = 0
for i in range(len(wavs)):
    wav, prompt = wavs[i], prompts[i]
    t0 = time.time()
    with torch.no_grad():
        generation = model.generate(wav, prompt=prompt)
    t1 = time.time()
    total = t1-t0
    total_time += total
    pprint(generation[0])
    print(f'{total:.2f} seconds')
    print(delim)
print(f'{total_time:.2f} total seconds')

('The sound is a brass fanfare. The fanfare is played by a brass ensemble '
 'consisting of trumpets, trombones, and French horns. The sound is loud and '
 'brassy. The fanfare is played to announce the arrival of an important '
 'person, such as a king or queen, or to celebrate a victory. The fanfare can '
 'also be used to signal the start of a parade or a military march. The sound '
 'is energetic and uplifting. The fanfare can also be used to create a sense '
 'of excitement and anticipation. The sound can also be used to create a sense '
 'of triumph and victory. The sound can also be used to create a sense of '
 'grandeur and')
46.22 seconds
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 
('The animal that is laughing like this is a hyena. The hyena is known for its '
 'high-pitched laughter, which it uses to communicate with other hyenas and to '
 

In [9]:
total_time = 0
for i in range(len(wavs)):
    wav, prompt = wavs[i], prompts[i]
    t0 = time.time()
    with torch.no_grad():
        generation = model.generate(wav, prompt=prompt)
    t1 = time.time()
    total = t1-t0
    total_time += total
    pprint(generation[0])
    print(f'{total:.2f} seconds')
    print(delim)
print(f'{total_time:.2f} total seconds')

('The sound is a trumpet fanfare. The trumpet is a brass instrument that '
 'produces sound by blowing air into a mouthpiece that is attached to a metal '
 'tube. The trumpet fanfare is a piece of music that is played by the trumpet. '
 'The trumpet fanfare is a piece of music that is played by the trumpet. The '
 'trumpet fanfare is a piece of music that is played by the trumpet. The '
 'trumpet fanfare is a piece of music that is played by the trumpet. The '
 'trumpet fanfare is a piece of music that is played by the trumpet. The '
 'trumpet fanfare is a piece')
44.91 seconds
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 
('The animal that is laughing like this is a hyena. Hyenas are known for their '
 'high-pitched laughter, which they use to communicate with each other and to '
 'attract the attention of other animals. They are also known for their s

In [10]:
total_time = 0
for i in range(len(wavs)):
    wav, prompt = wavs[i], prompts[i]
    t0 = time.time()
    with torch.no_grad():
        generation = model.generate(wav, prompt=prompt)
    t1 = time.time()
    total = t1-t0
    total_time += total
    pprint(generation[0])
    print(f'{total:.2f} seconds')
    print(delim)
print(f'{total_time:.2f} total seconds')

('The sound is a trumpet fanfare. The trumpet is a brass instrument that '
 'produces sound by blowing air through a mouthpiece. The fanfare is a musical '
 'piece that is typically used to announce the arrival of an important person, '
 'such as a king or queen. The trumpet fanfare is often used in military '
 'ceremonies to announce the arrival of a high-ranking officer. The sound of '
 'the trumpet fanfare is loud and brassy, with a clear and crisp tone. The '
 'trumpet fanfare is typically played by a single trumpet, but it can also be '
 'played by a group of trumpets. The sound of the trumpet fan')
44.99 seconds
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 
('The animal that is laughing like this is a hyena. The hyena is known for its '
 'high-pitched laughter and is often referred to as the "laughing hyena." The '
 'hyena is a social animal that 

In [11]:
t0 = time.time()
with torch.no_grad():
    output = model.generate(wavs, prompts)
t1 = time.time()
total = t1-t0
print(f'{total:.2f} seconds')
for answer in output:
    pprint(answer)
    print(delim)

73.00 seconds
('The sound is of a brass instrument playing a loud and triumphant fanfare. '
 'The instrument is likely a bugle or trumpet, and the sound is brassy and '
 'bright. The player is likely a military musician or a brass band member. The '
 'sound is likely being played in a military setting, such as a military '
 'parade or a victory celebration. The sound is brassy and bright, and the '
 'player is likely a military musician or a brass band member. The sound is '
 'likely being played in a military setting, such as a military parade or a '
 'victory celebration. The sound is brassy and bright, and the player is '
 'likely a military musician or a brass band member')
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 
('The animal that is laughing like this is a hyena. Hyenas are known for their '
 'distinctive, high-pitched laughter, which they us

In [12]:
t0 = time.time()
with torch.no_grad():
    output = model.generate(wavs, prompts)
t1 = time.time()
total = t1-t0
print(f'{total:.2f} seconds')
for answer in output:
    pprint(answer)
    print(delim)

72.38 seconds
('The sound is a brass fanfare. The brass section is playing in unison, with '
 'the trumpets and trombones playing in harmony. The sound is loud and '
 'boisterous, with a sense of triumph and celebration. The brass section is '
 'playing with a lot of energy and enthusiasm. The sound is loud and '
 'boisterous, with a sense of triumph and celebration. The brass section is '
 'playing with a lot of energy and enthusiasm. The sound is loud and '
 'boisterous, with a sense of triumph and celebration. The brass section is '
 'playing with a lot of energy and enthusiasm. The sound is loud and '
 'boisterous, with a sense of triumph')
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 
('The animal that is laughing like this is a hyena. The hyena is known for its '
 'distinctive laughing call, which is often described as sounding like a human '
 'la

In [13]:
t0 = time.time()
with torch.no_grad():
    output = model.generate(wavs, prompts)
t1 = time.time()
total = t1-t0
print(f'{total:.2f} seconds')
for answer in output:
    pprint(answer)
    print(delim)

72.49 seconds
('The sound is a trumpet fanfare. The trumpet is a brass instrument that '
 'produces sound by blowing air through a mouthpiece into a long tube. The '
 'trumpet fanfare is a musical piece that is typically used to announce the '
 'arrival of an important person, such as a king or queen. The trumpet fanfare '
 'is often used in military ceremonies to announce the arrival of a general or '
 'other high-ranking officer. The trumpet fanfare is also used in sports '
 'events, such as football games, to announce the arrival of the home team. '
 'The trumpet fanfare is typically played by a single trumpet player, but it '
 'can also be played')
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 
('The animal that is laughing like this is a hyena. Hyenas are known for their '
 'distinctive, high-pitched laughter, which they use to communicate with each