## **Initialization**

In [None]:
!pip install better_exceptions
!pip install av
!pip install soundfile



In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

Sun Feb  4 14:35:36 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   50C    P8              10W /  70W |      0MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [None]:
import fnmatch
import os
import random
import re
import threading
import librosa,librosa.display
import numpy as npx
import tensorflow as tf
from six.moves import xrange
import better_exceptions
import tensorflow as tf
import numpy as np
import os
import time
import json
import numpy as np
import av
import torch as t
import tqdm
import soundfile
import matplotlib.pyplot as plt
from torch.utils.data import Dataset,DataLoader
from torch.utils.data.distributed import DistributedSampler
from time import sleep
import torch.nn as nn
import torch.nn.functional as F
import math
import pickle

In [None]:
# Colab only: mount Google Drive at /content/drive so the dataset folder used below is readable.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# **Dataset**

In [None]:
def audio_preprocess(x, hps):
    """Collapse an (N, T, C) audio batch to one channel, returning shape (N, T, 1).

    Stereo input (C == 2) is mixed as ``mix * ch0 + (1 - mix) * ch1``. When
    ``hps.aug_blend`` is truthy, ``mix`` is drawn per batch item from U[0, 1);
    otherwise it is 0.5. Mono input (C == 1) passes through. Any other channel
    count fails an assertion.

    Args:
        x: tensor shaped (N, T, C); converted with ``.float()``.
        hps: object exposing ``aug_blend``. ``channels`` is read only for the
            error message.
    """
    x = x.float()
    n_channels = x.shape[-1]
    if n_channels == 2:
        mix = t.rand((x.shape[0], 1), device=x.device) if hps.aug_blend else 0.5
        x = mix * x[:, :, 0] + (1 - mix) * x[:, :, 1]
    elif n_channels == 1:
        x = x[:, :, 0]
    else:
        assert False, f'Expected channels {hps.channels}. Got unknown {x.shape[-1]} channels'
    # (N, T) -> (N, T, 1): restore a singleton channel axis.
    return x.unsqueeze(2)

def audio_postprocess(x, hps):
    """Counterpart hook to ``audio_preprocess``; currently the identity (``hps`` is unused)."""
    del hps  # accepted only for signature symmetry with audio_preprocess
    return x

In [None]:
def get_duration_sec(file, cache=False):
    """Return the duration, in seconds, of the first audio stream of ``file``.

    A sidecar text file ``<file>.dur`` is used as a cache when it exists and
    parses as a float. Otherwise the file is opened with PyAV and the duration
    is computed as ``stream.duration * stream.time_base``. With ``cache=True``
    that value is then written to the sidecar file.

    Args:
        file: path to an audio/video file.
        cache: if True, write a freshly computed duration to ``<file>.dur``.

    Returns:
        Duration in seconds as a float.
    """
    cache_path = file + '.dur'
    try:
        with open(cache_path, 'r') as f:
            return float(f.readline().strip('\n'))
    except (OSError, ValueError):
        # Missing/unreadable cache or unparseable contents: probe the media file instead.
        # (Previously a bare `except:` that also swallowed e.g. KeyboardInterrupt.)
        pass
    container = av.open(file)
    try:
        audio = container.streams.get(audio=0)[0]
        duration = audio.duration * float(audio.time_base)
    finally:
        # Release the demuxer / file handle even if probing fails.
        container.close()
    if cache:
        with open(cache_path, 'w') as f:
            f.write(str(duration) + '\n')
    return duration

def load_audio(file, sr, offset, duration, resample=True, approx=False, time_base='samples', check_duration=True):
    """Decode a window of ``duration`` samples from ``file`` starting at ``offset``.

    Only the first audio stream is decoded. When ``resample`` is set, it is
    resampled with ``av.AudioResampler`` to planar-float stereo at ``sr``. The
    result is copied into a zero-initialised float32 array of shape
    (2, int(duration)); if the stream ends early, the tail stays zero.

    Args:
        file: path passed to ``av.open``.
        sr: target sample rate. When ``resample`` is False it must equal the
            stream's native ``sample_rate``.
        offset, duration: window start and length, in samples by default, or in
            seconds when ``time_base == 'sec'``.
        resample: resample to stereo ``fltp`` at ``sr``.
        approx: if the window runs past the end of the stream, move its start to
            the earlier of "one window back" and "window ending at stream end",
            instead of asserting.
        check_duration: when not ``approx``, assert the window fits in the stream.

    Returns:
        ``(sig, sr)``, where ``sig`` is a float32 array of shape (2, int(duration)).
    """
    if time_base == 'sec':
        offset = offset * sr
        duration = duration * sr
    container = av.open(file)
    try:
        audio = container.streams.get(audio=0)[0]  # Only first audio stream
        audio_duration = audio.duration * float(audio.time_base)
        if approx:
            if offset + duration > audio_duration * sr:
                # Move back one window, capped at the end of the audio.
                # Builtin min: np.min's 2nd positional argument is `axis`, not a value.
                offset = min(audio_duration * sr - duration, offset - duration)
        elif check_duration:
            assert offset + duration <= audio_duration*sr, f'End {offset + duration} beyond duration {audio_duration*sr}'
        if resample:
            resampler = av.AudioResampler(format='fltp', layout='stereo', rate=sr)
        else:
            assert sr == audio.sample_rate
        offset = int(offset / sr / float(audio.time_base))  # seek position in units of the stream's time_base
        duration = int(duration)  # number of output samples at rate sr
        sig = np.zeros((2, duration), dtype=np.float32)
        container.seek(offset, stream=audio)
        total_read = 0
        for frame in container.decode(audio=0):  # Only first audio stream
            if resample:
                frame.pts = None
                resampled = resampler.resample(frame)
                # PyAV >= 9 returns a list of frames; older releases return a single
                # frame (or None while the resampler is still buffering).
                if resampled is None:
                    decoded = []
                elif isinstance(resampled, (list, tuple)):
                    decoded = resampled
                else:
                    decoded = [resampled]
            else:
                decoded = [frame]
            for out_frame in decoded:
                arr = out_frame.to_ndarray(format='fltp')  # Convert to floats and not int16
                read = min(arr.shape[-1], duration - total_read)
                sig[:, total_read:total_read + read] = arr[:, :read]
                total_read += read
                if total_read == duration:
                    break
            if total_read == duration:
                break
    finally:
        # Always release the demuxer / file handle.
        container.close()
    assert total_read <= duration, f'Expected {duration} frames, got {total_read}'
    return sig, sr


## **Train Loader**

In [None]:
class AudioDataset(Dataset):
  """Fixed-length windows over every audio file found under a directory.

  The files are treated as one long concatenation whose per-file durations are
  truncated to whole seconds. Item ``i`` is the ``sample_length``-second window
  starting near ``i * sample_length`` seconds, moved so that it lies inside a
  single file. Each item is a float32 array of shape
  (int(sample_length * sr), channels).
  """

  # Added path variable for passing generalized path
  def __init__(self, path, sr=11000, channels=2, min_duration=17.0, max_duration=30.0,
               sample_length=24.0, aug_shift=False):
    """Scan ``path`` for audio files and index their durations.

    Args:
      path: directory searched by ``librosa.util.find_files``.
      sr: sample rate (Hz) that windows are decoded/resampled to.
      channels: channel count expected from ``load_audio``.
      min_duration, max_duration: stored only. NOTE(review): not used to filter files.
      sample_length: window length in seconds.
      aug_shift: randomly shift each window's start by up to +/- sample_length/2 seconds.
    """
    self.sr = sr
    self.channels = channels
    self.min_duration = min_duration
    self.max_duration = max_duration
    self.sample_length = sample_length
    self.aug_shift = aug_shift
    self.init_dataset(path)

  def get_index_offset(self, item):
    """Map dataset index ``item`` to ``(song_index, offset_in_seconds_within_song)``."""
    half_interval = self.sample_length // 2
    # randint needs integer bounds; sample_length is typically a float (e.g. 24.0).
    shift = np.random.randint(-int(half_interval), int(half_interval)) if self.aug_shift else 0
    offset = item * self.sample_length + shift  # Shifts are centred, so adding now
    midpoint = offset + half_interval
    assert 0 <= midpoint < self.cumsum[-1], f'Midpoint {midpoint} of item beyond total length {self.cumsum[-1]}'
    index = np.searchsorted(self.cumsum, midpoint)  # song whose span contains the window midpoint
    start = self.cumsum[index - 1] if index > 0 else 0.0  # global start of that song
    end = self.cumsum[index]  # global end of that song
    assert start <= midpoint <= end, f"Midpoint {midpoint} not inside interval [{start}, {end}] for index {index}"
    if offset > end - self.sample_length:  # Going over song
      offset = max(start, offset - half_interval)  # Now should fit
    elif offset < start:  # Going under song
      offset = min(end - self.sample_length, offset + half_interval)  # Now should fit
    if not (start <= offset <= end - self.sample_length):
      # Still does not fit: align the window to end at the song end.
      # NOTE(review): for songs shorter than sample_length this gives a negative offset.
      offset = end - self.sample_length
    return index, offset - start

  def init_dataset(self, path):
    """Find mp3/m4a/opus/wav files under ``path`` and record whole-second durations and their running sum."""
    files = librosa.util.find_files(path, ext = ['mp3', 'm4a', 'opus','wav'])
    print(f'Found {len(files)} files!')

    self.files = files
    self.durations = [int(get_duration_sec(file)) for file in tqdm.tqdm(files)]
    self.cumsum = np.cumsum(self.durations)

  def __getitem__(self, item):
    """Decode window ``item`` and return it as a (samples, channels) array."""
    index, offset = self.get_index_offset(item)
    filename = self.files[index]
    data, sr = load_audio(filename, sr=self.sr, offset=offset, duration=self.sample_length, time_base='sec')
    # load_audio truncates the sample count with int(), so compare against the integer length.
    expected_shape = (self.channels, int(self.sample_length * self.sr))
    assert data.shape == expected_shape, f'Expected {expected_shape}, got {data.shape}'
    return data.T

  def __len__(self):
    """Number of whole ``sample_length`` windows in the total duration (0 when no files were found)."""
    if len(self.cumsum) == 0:
      return 0
    return int(np.floor(self.cumsum[-1] / self.sample_length))

In [None]:
def collate_fn(batch):
    """Stack a list of equally shaped numpy arrays into one tensor along a new leading batch axis."""
    tensors = [t.from_numpy(sample) for sample in batch]
    return t.stack(tensors, dim=0)

# Index every audio file in the Drive dataset folder (requires Google Drive mounted at /content/drive).
dataset = AudioDataset('/content/drive/MyDrive/Indian_Classical_Music_Generation/datasets/flute_dataset')

Found 25 files!


100%|██████████| 25/25 [00:00<00:00, 40.10it/s]


In [None]:
print("Length of dataset is ", len(dataset))

# One window per batch, reshuffled each epoch; an incomplete final batch is dropped.
train_loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    drop_last=True,
    num_workers=2,
    pin_memory=False,
    collate_fn=collate_fn,
)

Length of dataset is  3123


# **Encoding**

## **VQ-VAE**

In [None]:
import torch.distributed as dist
from enum import Enum

class ReduceOp(Enum):
    """Reduction operation, convertible to the matching ``torch.distributed.ReduceOp``."""
    # Plain ints: the previous trailing commas (`SUM = 0,`) silently made the values tuples.
    SUM = 0
    PRODUCT = 1
    MIN = 2
    MAX = 3

    def ToDistOp(self):
        """Return the ``torch.distributed.ReduceOp`` member with the same name."""
        # Look up by name rather than via member-to-member attribute access (`self.SUM`).
        return getattr(dist.ReduceOp, self.name)

def is_available():
    """Whether collective ops can run: torch.distributed is built in AND a process group is initialised.

    ``dist.get_rank()``, ``dist.barrier()`` etc. raise when no default process
    group exists. Plain single-process runs (like this notebook) must therefore
    take the local fallbacks in the wrappers below.
    """
    return dist.is_available() and dist.is_initialized()

def get_rank():
    """Rank of this process, or 0 when not running distributed."""
    if is_available():
        return _get_rank()
    else:
        return 0

def get_world_size():
    """Number of processes, or 1 when not running distributed."""
    if is_available():
        return _get_world_size()
    else:
        return 1

def barrier():
    """Synchronise all processes; a no-op when not running distributed."""
    if is_available():
        return _barrier()
    #else: do nothing

def all_gather(tensor_list, tensor):
    """Gather ``tensor`` from every rank into ``tensor_list``; locally just stores it at index 0."""
    if is_available():
        return _all_gather(tensor_list, tensor)
    else:
        tensor_list[0] = tensor

def all_reduce(tensor, op=ReduceOp.SUM):
    """In-place all-reduce of ``tensor``; a no-op when not running distributed."""
    if is_available():
        return _all_reduce(tensor, op)
    #else: do nothing

def reduce(tensor, dst, op=ReduceOp.SUM):
    """Reduce ``tensor`` onto rank ``dst``; a no-op when not running distributed."""
    if is_available():
        return _reduce(tensor, dst, op)
    #else: do nothing

def broadcast(tensor, src):
    """Broadcast ``tensor`` from rank ``src``; a no-op when not running distributed."""
    if is_available():
        return _broadcast(tensor, src)
    #else: do nothing

def init_process_group(backend, init_method):
    """Create the default process group if torch.distributed is built in."""
    # Must not use is_available(): it is False until a group has been initialised.
    if dist.is_available():
        return _init_process_group(backend, init_method)
    #else: do nothing

def _get_rank():
    return dist.get_rank()

def _barrier():
    return dist.barrier()

def _get_world_size():
    return dist.get_world_size()

def _all_gather(tensor_list, tensor):
    return dist.all_gather(tensor_list, tensor)

def _all_reduce(tensor, op):
    return dist.all_reduce(tensor, op.ToDistOp())

def _reduce(tensor, dst, op):
    return dist.reduce(tensor, dst, op.ToDistOp())

def _broadcast(tensor, src):
    return dist.broadcast(tensor, src)

def _init_process_group(backend, init_method):
    return dist.init_process_group(backend, init_method)


import gc

def freeze_model(model):
    """Switch ``model`` to eval mode and disable gradient tracking on all of its parameters."""
    model.eval()
    for param in model.parameters():
        param.requires_grad_(False)


def unfreeze_model(model):
    """Switch ``model`` to train mode and re-enable gradient tracking on all of its parameters."""
    model.train()
    for param in model.parameters():
        param.requires_grad_(True)

def zero_grad(model):
    """Drop accumulated gradients (set ``.grad = None``) on parameters that require grad."""
    stale = [p for p in model.parameters() if p.requires_grad and p.grad is not None]
    for param in stale:
        param.grad = None

def empty_cache():
    """Run Python garbage collection, then release PyTorch's cached CUDA memory blocks."""
    for release in (gc.collect, t.cuda.empty_cache):
        release()

def assert_shape(x, exp_shape):
    """Intentionally a no-op: the shape check is disabled.

    Re-enable by asserting ``x.shape == exp_shape``. Always returns None.
    """
    return None

def count_parameters(model):
    """Total number of scalar elements across parameters with ``requires_grad=True``."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def count_state(model):
    """Total element count of every tensor in ``model.state_dict()`` (parameters and buffers)."""
    return sum(map(lambda tensor: tensor.numel(), model.state_dict().values()))

In [None]:
# Simple gradient checkpointing. Works with distributed data parallel

def checkpoint(func, inputs, params, flag):
    """Call ``func(*inputs)``, optionally through gradient checkpointing.

    When ``flag`` is truthy, the call goes through ``CheckpointFunction``, with
    the tensors in ``params`` appended after ``inputs`` so they receive
    gradients in backward. ``inputs`` must be a tuple. When ``flag`` is falsy,
    ``func`` is called directly and ``params`` is ignored.
    """
    if not flag:
        return func(*inputs)
    flat_args = inputs + tuple(params)
    return CheckpointFunction.apply(func, len(inputs), *flat_args)

class CheckpointFunction(t.autograd.Function):
    """Autograd function that trades compute for memory by recomputing in backward.

    ``forward`` runs ``run_function`` under ``no_grad``, so no graph or
    intermediate activations are stored. ``backward`` re-runs it with grad
    enabled on detached copies of the inputs, then differentiates with respect
    to the inputs and the extra params.
    """
    @staticmethod
    def forward(ctx, run_function, length, *args):
        # args = (*input_tensors, *params); the first `length` entries are the inputs.
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        with t.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # Detach the saved inputs so the recomputation builds a fresh graph rooted
        # at them, while preserving each input's original requires_grad flag.
        for i in range(len(ctx.input_tensors)):
            temp = ctx.input_tensors[i]
            ctx.input_tensors[i] = temp.detach()
            ctx.input_tensors[i].requires_grad = temp.requires_grad
        with t.enable_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        # NOTE(review): t.autograd.grad raises if any tensor in input_tensors/input_params
        # does not require grad (e.g. frozen params); allow_unused only covers tensors
        # that the recomputed graph does not reach.
        input_grads = t.autograd.grad(output_tensors, ctx.input_tensors + ctx.input_params, output_grads, allow_unused=True)
        del ctx.input_tensors
        del output_tensors
        # No gradients for the non-tensor arguments run_function and length.
        return (None, None) + input_grads

In [None]:
import numpy as np
import torch as t
import torch.nn as nn
import torch.nn.functional as F

class BottleneckBlock(nn.Module):
    """Single-level vector quantiser with an exponential-moving-average (EMA) codebook.

    The codebook buffer ``k`` has shape (k_bins, emb_width). In training
    (``update_k=True``), the first batch seeds ``k`` with random encoder
    vectors. After that, ``k_sum``/``k_elem`` are EMAs (decay ``mu``) of the
    summed vectors and counts assigned to each code. Codes whose EMA count is
    below ``threshold`` are replaced with random (jittered) encoder vectors.
    """
    def __init__(self, k_bins, emb_width, mu):
        super().__init__()
        self.k_bins = k_bins
        self.emb_width = emb_width
        self.mu = mu
        self.reset_k()
        self.threshold = 1.0

    def reset_k(self):
        """Mark the codebook as uninitialised and (re)register ``k`` as a zero buffer."""
        self.init = False
        self.k_sum = None
        self.k_elem = None
        # Previously allocated with an unconditional `.cuda()`, which crashed on CPU-only
        # machines. Keep GPU placement when available; Module.to()/.cuda()/.cpu() move it.
        device = 'cuda' if t.cuda.is_available() else 'cpu'
        self.register_buffer('k', t.zeros(self.k_bins, self.emb_width, device=device))

    def _tile(self, x):
        """Ensure ``x`` (d, ew) has at least ``k_bins`` rows.

        If d < k_bins, rows are repeated and every row gets Gaussian noise with
        std 0.01/sqrt(ew). Otherwise ``x`` is returned unchanged.
        """
        d, ew = x.shape
        if d < self.k_bins:
            n_repeats = (self.k_bins + d - 1) // d
            std = 0.01 / np.sqrt(ew)
            x = x.repeat(n_repeats, 1)
            x = x + t.randn_like(x) * std
        return x

    def init_k(self, x):
        """Seed the codebook with ``k_bins`` randomly chosen rows of (tiled) ``x``."""
        mu, emb_width, k_bins = self.mu, self.emb_width, self.k_bins
        self.init = True
        # init k_w using random vectors from x
        y = self._tile(x)
        _k_rand = y[t.randperm(y.shape[0])][:k_bins]
        #dist.broadcast(_k_rand, 0)
        self.k = _k_rand
        assert self.k.shape == (k_bins, emb_width)
        self.k_sum = self.k
        self.k_elem = t.ones(k_bins, device=self.k.device)

    def restore_k(self, num_tokens=None, threshold=1.0):
        """Rebuild EMA state from a loaded codebook ``k``.

        When ``num_tokens`` is given, each code is assumed to have been used
        ``num_tokens / k_bins`` times.
        """
        mu, emb_width, k_bins = self.mu, self.emb_width, self.k_bins
        self.init = True
        assert self.k.shape == (k_bins, emb_width)
        self.k_sum = self.k.clone()
        self.k_elem = t.ones(k_bins, device=self.k.device)
        if num_tokens is not None:
            expected_usage = num_tokens / k_bins
            self.k_elem.data.mul_(expected_usage)
            self.k_sum.data.mul_(expected_usage)
        self.threshold = threshold

    def update_k(self, x, x_l):
        """EMA-update the codebook from vectors ``x`` (N*T, w) and their code indices ``x_l``.

        Returns a dict with ``entropy`` (of this batch's code distribution),
        ``used_curr`` (codes used at least ``threshold`` times in this batch),
        ``usage`` (codes whose EMA count is at least ``threshold``), and ``dk``
        (RMS change of the codebook).
        """
        mu, emb_width, k_bins = self.mu, self.emb_width, self.k_bins
        with t.no_grad():
            # Calculate new centres
            x_l_onehot = t.zeros(k_bins, x.shape[0], device=x.device)  # k_bins, N * L
            x_l_onehot.scatter_(0, x_l.view(1, x.shape[0]), 1)

            _k_sum = t.matmul(x_l_onehot, x)  # k_bins, w
            _k_elem = x_l_onehot.sum(dim=-1)  # k_bins
            y = self._tile(x)
            _k_rand = y[t.randperm(y.shape[0])][:k_bins]

            # dist.broadcast(_k_rand, 0)
            # dist.all_reduce(_k_sum)
            # dist.all_reduce(_k_elem)

            # Update centres
            old_k = self.k
            self.k_sum = mu * self.k_sum + (1. - mu) * _k_sum  # w, k_bins
            self.k_elem = mu * self.k_elem + (1. - mu) * _k_elem  # k_bins
            usage = (self.k_elem.view(k_bins, 1) >= self.threshold).float()
            # Used codes take their EMA mean; dead codes are replaced with random encoder vectors.
            self.k = usage * (self.k_sum.view(k_bins, emb_width) / self.k_elem.view(k_bins, 1)) \
                     + (1 - usage) * _k_rand
            _k_prob = _k_elem / t.sum(_k_elem)  # x_l_onehot.mean(dim=-1)  # prob of each bin
            entropy = -t.sum(_k_prob * t.log(_k_prob + 1e-8))  # entropy ie how diverse
            used_curr = (_k_elem >= self.threshold).sum()
            usage = t.sum(usage)
            dk = t.norm(self.k - old_k) / np.sqrt(np.prod(old_k.shape))
        return dict(entropy=entropy,
                    used_curr=used_curr,
                    usage=usage,
                    dk=dk)

    def preprocess(self, x):
        """Flatten (N, C, T) to (N*T, C) and compute a pre-quantisation norm statistic.

        If C == 2 * emb_width, the two halves are summed into one emb_width vector.
        """
        # NCT -> NTC -> [NT, C]
        x = x.permute(0, 2, 1).contiguous()
        x = x.view(-1, x.shape[-1])  # x_en = (N * L, w), k_j = (w, k_bins)

        if x.shape[-1] == self.emb_width:
            prenorm = t.norm(x - t.mean(x)) / np.sqrt(np.prod(x.shape))
        elif x.shape[-1] == 2 * self.emb_width:
            x1, x2 = x[...,:self.emb_width], x[...,self.emb_width:]
            prenorm = (t.norm(x1 - t.mean(x1)) / np.sqrt(np.prod(x1.shape))) + (t.norm(x2 - t.mean(x2)) / np.sqrt(np.prod(x2.shape)))

            # Combine the two halves by summation
            x = x1 + x2
        else:
            assert False, f"Expected {x.shape[-1]} to be (1 or 2) * {self.emb_width}"
        return x, prenorm

    def postprocess(self, x_l, x_d, x_shape):
        """Restore codes to (N, T) and dequantised vectors to (N, C, T)."""
        # [NT, C] -> NTC -> NCT
        N, T = x_shape
        x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous()
        x_l = x_l.view(N, T)
        return x_l, x_d

    def quantise(self, x):
        """Nearest-code assignment by squared Euclidean distance; returns (codes, mean min distance)."""
        # Calculate latent code x_l
        k_w = self.k.t()
        distance = t.sum(x ** 2, dim=-1, keepdim=True) - 2 * t.matmul(x, k_w) + t.sum(k_w ** 2, dim=0,
                                                                                            keepdim=True)  # (N * L, b)
        min_distance, x_l = t.min(distance, dim=-1)
        fit = t.mean(min_distance)
        return x_l, fit

    def dequantise(self, x_l):
        """Look up code vectors for integer codes ``x_l``."""
        x = F.embedding(x_l, self.k)
        return x

    def encode(self, x):
        """Quantise (N, C, T) input to integer codes of shape (N, T) without updating the codebook."""
        N, width, T = x.shape

        # Preprocess.
        x, prenorm = self.preprocess(x)

        # Quantise
        x_l, fit = self.quantise(x)

        # Postprocess.
        x_l = x_l.view(N, T)
        return x_l

    def decode(self, x_l):
        """Map (N, T) codes back to (N, emb_width, T) vectors."""
        N, T = x_l.shape
        width = self.emb_width

        # Dequantise
        x_d = self.dequantise(x_l)

        # Postprocess
        x_d = x_d.view(N, T, width).permute(0, 2, 1).contiguous()
        return x_d

    def forward(self, x, update_k=True):
        """Quantise (N, C, T) input with a straight-through estimator.

        Returns ``(codes (N, T), quantised (N, C, T), commit_loss, metrics)``.
        The codebook is initialised and updated only when ``update_k`` is True.
        """
        N, width, T = x.shape

        # Preprocess
        x, prenorm = self.preprocess(x)

        # Init k if not inited
        if update_k and not self.init:
            self.init_k(x)

        # Quantise and dequantise through bottleneck
        x_l, fit = self.quantise(x)
        x_d = self.dequantise(x_l)

        # Update embeddings
        if update_k:
            update_metrics = self.update_k(x, x_l)
        else:
            update_metrics = {}

        # Loss
        commit_loss = t.norm(x_d.detach() - x) ** 2 / np.prod(x.shape)

        # Passthrough (straight-through estimator): forward value is x_d, gradient flows to x.
        x_d = x + (x_d - x).detach()

        # Postprocess
        x_l, x_d = self.postprocess(x_l, x_d, (N,T))
        return x_l, x_d, commit_loss, dict(fit=fit,
                                           pn=prenorm,
                                           **update_metrics)


class Bottleneck(nn.Module):
    """Stack of independent per-level quantisers (one ``BottleneckBlock`` per level).

    Element ``i`` of every list argument or result belongs to level ``i``.
    """
    def __init__(self, l_bins, emb_width, mu, levels):
        super().__init__()
        self.levels = levels
        self.level_blocks = nn.ModuleList()
        for _ in range(self.levels):
            self.level_blocks.append(BottleneckBlock(l_bins, emb_width, mu))

    def encode(self, xs):
        """Quantise each level's input to integer codes; returns one code tensor per level."""
        return [block.encode(x) for block, x in zip(self.level_blocks, xs)]

    def decode(self, zs, start_level=0, end_level=None):
        """Dequantise ``zs`` with the blocks of levels [start_level, end_level)."""
        stop = self.levels if end_level is None else end_level
        blocks = self.level_blocks[start_level:stop]
        return [block.decode(z) for block, z in zip(blocks, zs)]

    def forward(self, xs):
        """Quantise every level.

        Returns ``(zs, xs_quantised, commit_losses, metrics)`` as per-level
        lists. Codebooks are updated only in training mode, and ``metrics`` is
        empty in eval mode.
        """
        zs, xs_quantised, commit_losses, metrics = [], [], [], []
        training = self.training
        for level in range(self.levels):
            block, x = self.level_blocks[level], xs[level]
            z, x_quantised, commit_loss, metric = block(x, update_k=training)
            zs.append(z)
            # In eval, detach so the straight-through estimator can never push
            # gradients back into the encoder weights.
            xs_quantised.append(x_quantised if training else x_quantised.detach())
            commit_losses.append(commit_loss)
            if training:
                metrics.append(metric)
        return zs, xs_quantised, commit_losses, metrics

class NoBottleneckBlock(nn.Module):
    """Placeholder level block used by ``NoBottleneck``; it has no codebook."""

    def restore_k(self):
        """No-op counterpart of ``BottleneckBlock.restore_k``."""
        return None

class NoBottleneck(nn.Module):
    """Identity 'bottleneck' with ``Bottleneck``'s interface: inputs pass through unquantised.

    Commit losses and metrics are all zero scalars.
    """
    def __init__(self, levels):
        super().__init__()
        self.level_blocks = nn.ModuleList()
        self.levels = levels
        for level in range(levels):
            self.level_blocks.append(NoBottleneckBlock())

    def encode(self, xs):
        """Identity: the 'codes' are the inputs themselves."""
        return xs

    def decode(self, zs, start_level=0, end_level=None):
        """Identity: returns ``zs`` unchanged (the level arguments are accepted but ignored)."""
        if end_level is None:
            end_level = self.levels
        return zs

    def forward(self, xs):
        """Return ``(xs, xs, commit_losses, metrics)`` with zero-valued losses/metrics per level."""
        # Create the zero scalar on the inputs' device rather than unconditionally on CUDA
        # (which crashed on CPU-only machines). Fall back to CUDA-if-available for empty xs.
        if len(xs) > 0:
            device = xs[0].device
        else:
            device = 'cuda' if t.cuda.is_available() else 'cpu'
        zero = t.zeros((), device=device)
        commit_losses = [zero for _ in range(self.levels)]
        metrics = [dict(entropy=zero, usage=zero, used_curr=zero, pn=zero, dk=zero) for _ in range(self.levels)]
        return xs, xs, commit_losses, metrics

In [None]:
import math
import torch.nn as nn

class ResConvBlock(nn.Module):
    """2-D residual block: ``x + Conv1x1(ReLU(Conv3x3(ReLU(x))))``; spatial size is preserved."""
    def __init__(self, n_in, n_state):
        super().__init__()
        layers = [
            nn.ReLU(),
            nn.Conv2d(n_in, n_state, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(n_state, n_in, kernel_size=1, stride=1, padding=0),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.model(x)
        return x + residual

class Resnet(nn.Module):
    """``n_depth`` stacked 2-D ResConvBlocks, each with hidden width ``int(m_conv * n_in)``."""
    def __init__(self, n_in, n_depth, m_conv=1.0):
        super().__init__()
        hidden = int(m_conv * n_in)
        blocks = []
        for _ in range(n_depth):
            blocks.append(ResConvBlock(n_in, hidden))
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        return self.model(x)

class ResConv1DBlock(nn.Module):
    """1-D residual block: ``x + res_scale * Conv1x1(ReLU(DilatedConv3(ReLU(x))))``.

    ``padding == dilation`` keeps the time length unchanged for the kernel-3
    conv. With ``zero_out``, the last conv's weight and bias start at zero, so
    the block is initially the identity.
    """
    def __init__(self, n_in, n_state, dilation=1, zero_out=False, res_scale=1.0):
        super().__init__()
        dilated = nn.Conv1d(n_in, n_state, kernel_size=3, stride=1, padding=dilation, dilation=dilation)
        project = nn.Conv1d(n_state, n_in, kernel_size=1, stride=1, padding=0)
        if zero_out:
            nn.init.zeros_(project.weight)
            nn.init.zeros_(project.bias)
        self.model = nn.Sequential(nn.ReLU(), dilated, nn.ReLU(), project)
        self.res_scale = res_scale

    def forward(self, x):
        update = self.model(x)
        return x + self.res_scale * update

class Resnet1D(nn.Module):
    """Stack of ``n_depth`` ResConv1DBlocks with (optionally cyclic) exponential dilation.

    Block ``d`` uses dilation ``dilation_growth_rate ** d``, or
    ``dilation_growth_rate ** (d % dilation_cycle)`` when ``dilation_cycle`` is
    set. With ``res_scale`` truthy, each residual branch is scaled by
    1/sqrt(n_depth). ``reverse_dilation`` reverses the block order. With
    ``checkpoint_res == 1``, the blocks are kept in a ModuleList and each one
    runs through ``checkpoint``.
    """
    def __init__(self, n_in, n_depth, m_conv=1.0, dilation_growth_rate=1, dilation_cycle=None, zero_out=False, res_scale=False, reverse_dilation=False, checkpoint_res=False):
        super().__init__()
        def _get_depth(depth):
            if dilation_cycle is None:
                return depth
            else:
                return depth % dilation_cycle
        blocks = [ResConv1DBlock(n_in, int(m_conv * n_in),
                                 dilation=dilation_growth_rate ** _get_depth(depth),
                                 zero_out=zero_out,
                                 res_scale=1.0 if not res_scale else 1.0 / math.sqrt(n_depth))
                  for depth in range(n_depth)]
        if reverse_dilation:
            blocks = blocks[::-1]
        self.checkpoint_res = checkpoint_res
        if self.checkpoint_res == 1:
            # Log only on rank 0, without requiring an initialised process group:
            # dist.get_rank() raises in plain single-process runs.
            if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
                print("Checkpointing convs")
            self.blocks = nn.ModuleList(blocks)
        else:
            self.model = nn.Sequential(*blocks)

    def forward(self, x):
        if self.checkpoint_res == 1:
            for block in self.blocks:
                x = checkpoint(block, (x, ), block.parameters(), True)
            return x
        else:
            return self.model(x)

In [None]:
class EncoderConvBlock(nn.Module):
    """Downsampling stage of the encoder.

    For ``down_t > 0``: ``down_t`` stages of
    [Conv1d(kernel=2*stride_t, stride=stride_t, padding=stride_t//2) -> Resnet1D],
    followed by a kernel-3 conv from ``width`` to ``output_emb_width``. For
    ``down_t <= 0`` the block is an empty Sequential (identity).
    """
    def __init__(self, input_emb_width, output_emb_width, down_t,
                 stride_t, width, depth, m_conv,
                 dilation_growth_rate=1, dilation_cycle=None, zero_out=False,
                 res_scale=False):
        super().__init__()
        kernel_t, pad_t = stride_t * 2, stride_t // 2
        blocks = []
        if down_t > 0:
            for stage in range(down_t):
                in_width = input_emb_width if stage == 0 else width
                blocks.append(nn.Sequential(
                    nn.Conv1d(in_width, width, kernel_t, stride_t, pad_t),
                    Resnet1D(width, depth, m_conv, dilation_growth_rate, dilation_cycle, zero_out, res_scale),
                ))
            blocks.append(nn.Conv1d(width, output_emb_width, 3, 1, 1))
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        return self.model(x)

class DecoderConvBock(nn.Module):
    """Upsampling stage of the decoder, mirroring EncoderConvBlock.

    For ``down_t > 0``: a kernel-3 conv from ``output_emb_width`` to ``width``,
    then ``down_t`` stages of [Resnet1D -> ConvTranspose1d(kernel=2*stride_t,
    stride=stride_t, padding=stride_t//2)]. Only the last transposed conv
    outputs ``input_emb_width`` channels. Empty (identity) for ``down_t <= 0``.
    The class-name spelling is kept because callers reference it.
    """
    def __init__(self, input_emb_width, output_emb_width, down_t,
                 stride_t, width, depth, m_conv, dilation_growth_rate=1, dilation_cycle=None, zero_out=False, res_scale=False, reverse_decoder_dilation=False, checkpoint_res=False):
        super().__init__()
        blocks = []
        if down_t > 0:
            kernel_t, pad_t = stride_t * 2, stride_t // 2
            blocks.append(nn.Conv1d(output_emb_width, width, 3, 1, 1))
            last_stage = down_t - 1
            for stage in range(down_t):
                out_width = input_emb_width if stage == last_stage else width
                blocks.append(nn.Sequential(
                    Resnet1D(width, depth, m_conv, dilation_growth_rate, dilation_cycle, zero_out=zero_out, res_scale=res_scale, reverse_dilation=reverse_decoder_dilation, checkpoint_res=checkpoint_res),
                    nn.ConvTranspose1d(width, out_width, kernel_t, stride_t, pad_t),
                ))
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        return self.model(x)

class Encoder(nn.Module):
    """Multi-level 1-D encoder.

    Level ``i`` is an EncoderConvBlock applied to the output of level ``i - 1``
    (level 0 consumes the raw input). ``forward`` returns the per-level
    outputs, level 0 first. ``reverse_decoder_dilation`` is removed from
    ``block_kwargs`` because EncoderConvBlock does not accept it.
    """
    def __init__(self, input_emb_width, output_emb_width, levels, downs_t,
                 strides_t, **block_kwargs):
        super().__init__()
        self.input_emb_width = input_emb_width
        self.output_emb_width = output_emb_width
        self.levels = levels
        self.downs_t = downs_t
        self.strides_t = strides_t

        encoder_kwargs = {name: value for name, value in block_kwargs.items()
                          if name != 'reverse_decoder_dilation'}
        self.level_blocks = nn.ModuleList()
        for level, down_t, stride_t in zip(range(self.levels), downs_t, strides_t):
            in_width = input_emb_width if level == 0 else output_emb_width
            self.level_blocks.append(
                EncoderConvBlock(in_width, output_emb_width, down_t, stride_t, **encoder_kwargs))

    def forward(self, x):
        N, T = x.shape[0], x.shape[-1]
        assert_shape(x, (N, self.input_emb_width, T))
        outputs = []
        # 64, 32, ...
        for level, down_t, stride_t in zip(range(self.levels), self.downs_t, self.strides_t):
            x = self.level_blocks[level](x)
            T = T // (stride_t ** down_t)
            assert_shape(x, (N, self.output_emb_width, T))
            outputs.append(x)
        return outputs

class Decoder(nn.Module):
    """Multi-level 1-D decoder.

    Starts from the last entry of ``xs`` and runs the level blocks from the
    highest level down to 0. When ``all_levels`` is set, after every level
    except 0 it adds ``xs[level - 1]`` as a skip input. A final kernel-3 conv
    maps to ``input_emb_width`` channels.
    """
    def __init__(self, input_emb_width, output_emb_width, levels, downs_t,
                 strides_t, **block_kwargs):
        super().__init__()
        self.input_emb_width = input_emb_width
        self.output_emb_width = output_emb_width
        self.levels = levels
        self.downs_t = downs_t
        self.strides_t = strides_t

        self.level_blocks = nn.ModuleList()
        for _, down_t, stride_t in zip(range(self.levels), downs_t, strides_t):
            self.level_blocks.append(
                DecoderConvBock(output_emb_width, output_emb_width, down_t, stride_t, **block_kwargs))

        self.out = nn.Conv1d(output_emb_width, input_emb_width, 3, 1, 1)

    def forward(self, xs, all_levels=True):
        expected_inputs = self.levels if all_levels else 1
        assert len(xs) == expected_inputs
        x = xs[-1]
        N, T = x.shape[0], x.shape[-1]
        assert_shape(x, (N, self.output_emb_width, T))

        # 32, 64 ...
        schedule = list(zip(range(self.levels), self.downs_t, self.strides_t))
        for level, down_t, stride_t in reversed(schedule):
            x = self.level_blocks[level](x)
            T = T * (stride_t ** down_t)
            assert_shape(x, (N, self.output_emb_width, T))
            if level != 0 and all_levels:
                x = x + xs[level - 1]

        return self.out(x)

In [None]:
class DefaultSTFTValues:
    """Default STFT settings for the spectral losses; the ``hps`` argument is ignored.

    n_fft=2048, hop_length=256, window_size=6*hop_length=1536.
    NOTE(review): sr is fixed at 44100, while AudioDataset decodes at 11000.
    ``stft`` does not read sr, but confirm before relying on it.
    """
    def __init__(self, hps):
        hop = 256
        self.sr, self.n_fft, self.hop_length = 44100, 2048, hop
        self.window_size = 6 * hop

class STFTValues:
    """Explicit STFT settings (used per resolution in ``multispectral_loss``).

    The ``hps`` argument is ignored and ``sr`` is fixed at 44100.
    """
    def __init__(self, hps, n_fft, hop_length, window_size):
        self.sr, self.n_fft = 44100, n_fft
        self.hop_length, self.window_size = hop_length, window_size

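# NOTE(review): calculate_bandwidth estimates dataset statistics (l2 = variance, l1 = mean |x|, spec = mean STFT-magnitude norm); presumably used to normalise the reconstruction losses — confirm where the result is consumed.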
def calculate_bandwidth(dataset, hps, duration=600):
    """Estimate signal statistics over roughly ``duration`` seconds of ``dataset``.

    Items are read in order from index 0, stopping early if the dataset runs
    out. Each item is expected to be a (T, C) numpy array, or an ``(x, y)`` pair
    whose first element is that array. The STFT uses the default settings.

    Returns (and prints) a dict with:
      - ``l2``: variance of all sample values,
      - ``l1``: mean absolute sample value,
      - ``spec``: mean Frobenius norm of the STFT magnitude of the channel-averaged items.

    Raises:
        ValueError: if no samples were read.

    NOTE(review): ``n_seen`` counts T*C values, so with stereo items only about
    ``duration / 2`` seconds of audio are actually read.
    """
    hps = DefaultSTFTValues(hps)
    n_samples = int(dataset.sr * duration)
    l1, total, total_sq, n_seen = 0.0, 0.0, 0.0, 0.0
    spec_norm_total, spec_nelem = 0.0, 0.0
    # Stop when the dataset is exhausted instead of indexing past its end.
    n_items = len(dataset) if hasattr(dataset, '__len__') else float('inf')
    idx = 0
    while n_seen < n_samples and idx < n_items:
        x = dataset[idx]
        if isinstance(x, (tuple, list)):
            x, _ = x
        samples = x.astype(np.float64)
        # librosa >= 0.10 makes every stft argument after `y` keyword-only.
        stft_matrix = librosa.core.stft(np.mean(samples, axis=1), n_fft=hps.n_fft,
                                        hop_length=hps.hop_length, win_length=hps.window_size)
        spec_norm_total += np.linalg.norm(np.absolute(stft_matrix))
        spec_nelem += 1
        n_seen += int(np.prod(samples.shape))
        l1 += np.sum(np.abs(samples))
        total += np.sum(samples)
        total_sq += np.sum(samples ** 2)
        idx += 1

    if n_seen == 0:
        raise ValueError('calculate_bandwidth: no samples were read from the dataset')

    # NOTE(review): a distributed all-reduce of these running sums was disabled here
    # (single-process use); re-add it if this runs under multiple ranks.
    mean = total / n_seen
    bandwidth = dict(l2 = total_sq / n_seen - mean ** 2,
                     l1 = l1 / n_seen,
                     spec = spec_norm_total / spec_nelem)
    print(bandwidth)
    return bandwidth

# ---- The STFT settings above (DefaultSTFTValues / STFTValues) drive the spectral losses below; check them first when debugging loss values. ----

def stft(sig, hps):
    """STFT of real signal(s) ``sig`` as a real tensor with a trailing (real, imag) axis of size 2.

    Uses ``hps.n_fft``, ``hps.hop_length`` and a Hann window of length
    ``hps.window_size``. PyTorch >= 2.0 requires ``return_complex`` for real
    input, so the complex result is converted back to the (..., 2) real layout
    that ``spec`` reduces over.
    """
    window = t.hann_window(hps.window_size, device=sig.device)
    spectrum = t.stft(sig, hps.n_fft, hps.hop_length, win_length=hps.window_size,
                      window=window, return_complex=True)
    return t.view_as_real(spectrum)

def spec(x, hps):
    """STFT magnitude: L2 norm over the trailing (real, imag) axis of ``stft(x, hps)``."""
    return stft(x, hps).norm(p=2, dim=-1)

def norm(x):
    """Per-item L2 norm: flatten all but the leading axis and return a tensor of shape (N,)."""
    flat = x.view(x.shape[0], -1)
    return flat.pow(2).sum(dim=-1).sqrt()

def squeeze(x):
    """Reduce (N, T, C) with C in {1, 2} to (N, T) by averaging channels; pass (N, T) through.

    Any other rank raises ValueError; a rank-3 input with C not in {1, 2} fails an assertion.
    """
    ndim = len(x.shape)
    if ndim == 3:
        assert x.shape[-1] in [1,2]
        x = t.mean(x, -1)
        ndim = len(x.shape)
    if ndim != 2:
        raise ValueError(f'Unknown input shape {x.shape}')
    return x

def spectral_loss(x_in, x_out, hps):
    """Per-item L2 distance between the STFT magnitudes of ``x_in`` and ``x_out`` (default STFT settings)."""
    stft_hps = DefaultSTFTValues(hps)
    mag_in, mag_out = (spec(squeeze(sig.float()), stft_hps) for sig in (x_in, x_out))
    return norm(mag_in - mag_out)

def multispectral_loss(x_in, x_out, hps):
    """Average, over several STFT resolutions, of the per-item spectral L2 distance.

    Resolutions come from the parallel lists ``hps.multispec_loss_n_fft``,
    ``hps.multispec_loss_hop_length`` and ``hps.multispec_loss_window_size``,
    which must have equal length.
    """
    losses = []
    n_ffts = hps.multispec_loss_n_fft
    hops = hps.multispec_loss_hop_length
    windows = hps.multispec_loss_window_size
    assert len(n_ffts) == len(hops) == len(windows)
    for n_fft, hop_length, window_size in zip(n_ffts, hops, windows):
        res_hps = STFTValues(hps, n_fft, hop_length, window_size)
        mag_in = spec(squeeze(x_in.float()), res_hps)
        mag_out = spec(squeeze(x_out.float()), res_hps)
        losses.append(norm(mag_in - mag_out))
    return sum(losses) / len(losses)

def spectral_convergence(x_in, x_out, hps, epsilon=2e-3):
    """Relative spectral error ||S_in - S_out|| / ||S_in||, per example.

    Examples whose reference spectral norm is not above `epsilon` contribute 0. The
    denominator is clamped to at least `epsilon` to avoid division by ~0.
    """
    stft_hps = DefaultSTFTValues(hps)
    ref_mag = spec(squeeze(x_in.float()), stft_hps)
    est_mag = spec(squeeze(x_out.float()), stft_hps)

    ref_energy = norm(ref_mag)
    err_energy = norm(ref_mag - est_mag)
    valid = (ref_energy > epsilon).float()
    denominator = t.clamp(ref_energy, min=epsilon)
    return (err_energy * valid) / denominator

def log_magnitude_loss(x_in, x_out, hps, epsilon=1e-4):
    """Mean absolute difference between the log-magnitude spectrograms of two signals.

    `epsilon` is added before the log so silent bins stay finite.
    """
    stft_hps = DefaultSTFTValues(hps)

    def _log_mag(sig):
        return t.log(spec(squeeze(sig.float()), stft_hps) + epsilon)

    return t.mean(t.abs(_log_mag(x_in) - _log_mag(x_out)))

# def load_audio(file, sr, offset, duration, mono=False):
#     # Librosa loads more filetypes than soundfile
#     x, _ = librosa.load(file, sr=sr, mono=mono, offset=offset/sr, duration=duration/sr)
#     if len(x.shape) == 1:
#         x = x.reshape((1, -1))
#     return x

def save_wav(fname, aud, sr):
    """Write each batch item of `aud` to `<fname>/item_<i>.wav` at sample rate `sr`.

    Samples are clipped to [-1, 1] first. `fname` is used as a directory and must already
    exist.
    """
    clipped = t.clamp(aud, -1, 1).cpu().numpy()
    n_items = clipped.shape[0]
    for i in range(n_items):
        soundfile.write(f'{fname}/item_{i}.wav', clipped[i], samplerate=sr, format='wav')


In [None]:
import sys

def def_tqdm(x):
    """Wrap iterable `x` in a tqdm progress bar that writes to stdout and is kept on completion.

    The notebook's import cell does `import tqdm`, which binds the *module*. Calling
    `tqdm(...)` on it raises "TypeError: 'module' object is not callable". Import the
    progress-bar class explicitly so this works however the global name is bound.
    """
    from tqdm import tqdm as _tqdm_bar
    return _tqdm_bar(x, leave=True, file=sys.stdout, bar_format="{n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]")

def get_range(x):
    """Identity hook: hand back the given iterable unchanged (no progress bar)."""
    iterable = x
    return iterable

def init_logging(hps, local_rank, rank):
    """Prepare `<hps.local_logdir>/<hps.name>` and return a `(Logger, Metrics)` pair.

    Only the process with `local_rank == 0` creates the directory and records `hps.argv`
    in `argv.txt` inside it. Every process builds its own Logger (rank-gated) and Metrics,
    and the hps repr is logged as text.
    """
    logdir = f"{hps.local_logdir}/{hps.name}"
    if local_rank == 0:
        os.makedirs(logdir, exist_ok=True)
        # Join with a separator: plain `logdir + 'argv.txt'` would create a sibling
        # file named `<name>argv.txt` next to the log directory instead of inside it.
        with open(os.path.join(logdir, 'argv.txt'), 'w') as f:
            f.write(hps.argv + '\n')
        print("Logging to", logdir)
    logger = Logger(logdir, rank)
    metrics = Metrics()
    logger.add_text('hps', str(hps))
    return logger, metrics

def get_name(hps):
    """Concatenate `key_value_` for every item of the mapping `hps`, in iteration order."""
    return "".join(f"{key}_{value}_" for key, value in hps.items())

def average_metrics(_metrics):
    """Average each metric over a list of metric dicts.

    Keys may appear in only some of the dicts. Each key is averaged over the dicts that
    contain it. Values can be Python numbers or tensors (anything supporting `+` and `/`).
    """
    grouped = {}
    for _metric in _metrics:
        for key, val in _metric.items():
            grouped.setdefault(key, []).append(val)
    # True division: floor division (`//`) would truncate fractional metrics
    # such as codebook usage/fit averages toward 0.
    return {key: sum(vals) / len(vals) for key, vals in grouped.items()}

class Metrics:
    """Running per-tag (sum, count) accumulator for batch-averaged metrics.

    If a torch.distributed process group is initialized, each update is all-reduced
    across processes before being accumulated. Otherwise (e.g. a single Colab GPU)
    the local values are used directly.
    """
    def __init__(self):
        self.sum = {}
        self.n = {}

    def update(self, tag, val, batch):
        """Record `val`, an average over `batch` items, under `tag`.

        Returns the average for this call only (across processes when distributed).
        """
        import torch.distributed as dist
        # Store the weighted total so averages combine correctly across batch sizes.
        total = t.as_tensor(val * batch).float()
        count = t.as_tensor(batch).float()
        if dist.is_available() and dist.is_initialized():
            # Collective backends such as NCCL require CUDA tensors.
            if t.cuda.is_available():
                total, count = total.cuda(), count.cuda()
            dist.all_reduce(total)
            dist.all_reduce(count)
        total = total.item()
        count = count.item()
        self.sum[tag] = self.sum.get(tag, 0.0) + total
        self.n[tag] = self.n.get(tag, 0.0) + count
        return total / count

    def avg(self, tag):
        """Accumulated average for `tag`, or 0.0 if it was never updated."""
        if tag in self.sum:
            return self.sum[tag] / self.n[tag]
        else:
            return 0.0

    def reset(self):
        """Forget all accumulated values."""
        self.sum = {}
        self.n = {}

class Logger:
    """Rank-aware TensorBoard logger.

    Only rank 0 creates a tensorboardX `SummaryWriter` (under `<logdir>/logs`) and writes
    anything. On every other rank the logging calls are no-ops.
    """
    def __init__(self, logdir, rank):
        if rank == 0:
            from tensorboardX import SummaryWriter
            self.sw = SummaryWriter(f"{logdir}/logs")
        self.iters = 0
        self.rank = rank
        self.works = []  # pending (tag, layer, val, work) entries queued by add_reduce_scalar
        self.logdir = logdir

    def step(self):
        """Advance the global step used by all subsequent summaries."""
        self.iters += 1

    def flush(self):
        if self.rank == 0:
            self.sw.flush()

    def add_text(self, tag, text):
        if self.rank == 0:
            self.sw.add_text(tag, text, self.iters)

    def add_audios(self, tag, auds, sample_rate=22050, max_len=None, max_log=8):
        """Log up to `max_log` clips as `<i>/<tag>`.

        If `max_len` is set, each clip is cut to `max_len * sample_rate` samples.
        """
        if self.rank == 0:
            for i in range(min(len(auds), max_log)):
                if max_len:
                    self.sw.add_audio(f"{i}/{tag}", auds[i][:max_len * sample_rate], self.iters, sample_rate)
                else:
                    self.sw.add_audio(f"{i}/{tag}", auds[i], self.iters, sample_rate)

    def add_audio(self, tag, aud, sample_rate=22050):
        if self.rank == 0:
            self.sw.add_audio(tag, aud, self.iters, sample_rate)

    def add_images(self, tag, img, dataformats="NHWC"):
        if self.rank == 0:
            self.sw.add_images(tag, img, self.iters, dataformats=dataformats)

    def add_image(self, tag, img):
        if self.rank == 0:
            self.sw.add_image(tag, img, self.iters)

    def add_scalar(self, tag, val):
        if self.rank == 0:
            self.sw.add_scalar(tag, val, self.iters)

    def get_range(self, loader):
        """Return `enumerate(loader)`; on rank 0 the loader is wrapped in a progress bar."""
        if self.rank == 0:
            self.trange = def_tqdm(loader)
        else:
            self.trange = loader
        return enumerate(self.trange)

    def close_range(self):
        if self.rank == 0:
            self.trange.close()

    def set_postfix(self, *args, **kwargs):
        if self.rank == 0:
            self.trange.set_postfix(*args, **kwargs)

    # For logging summaries of various graph ops
    def add_reduce_scalar(self, tag, layer, val):
        """Every 100 steps, queue norm(val) / numel(val) for logging as `<layer>/<tag>`.

        With an initialized process group the value is reduced onto rank 0 asynchronously.
        Otherwise the local value is queued as-is.
        """
        if self.iters % 100 == 0:
            import torch.distributed as dist
            with t.no_grad():
                val = val.float().norm()/float(val.numel())
            if dist.is_available() and dist.is_initialized():
                work = dist.reduce(val, 0, async_op=True)
            else:
                work = None  # single process: nothing to reduce
            self.works.append((tag, layer, val, work))

    def finish_reduce(self):
        """Wait for the queued reductions and, on rank 0, log their world-size average."""
        import torch.distributed as dist
        distributed = dist.is_available() and dist.is_initialized()
        world_size = dist.get_world_size() if distributed else 1
        for tag, layer, val, work in self.works:
            if work is not None:
                work.wait()
            if self.rank == 0:
                val = val.item()/world_size
                # There is no per-layer writer; namespace the scalar by layer on the single writer.
                self.sw.add_scalar(f"{layer}/{tag}", val, self.iters)
        self.works = []

In [None]:
import torch as t
import torch.nn as nn

def dont_update(params):
    """Freeze every tensor in `params` by clearing its `requires_grad` flag."""
    for tensor in params:
        tensor.requires_grad = False

def update(params):
    """Unfreeze every tensor in `params` by setting its `requires_grad` flag."""
    for tensor in params:
        tensor.requires_grad = True

def calculate_strides(strides, downs):
    """Per-level downsampling factor: `stride ** down` for each paired (stride, down)."""
    factors = []
    for stride, down in zip(strides, downs):
        factors.append(stride ** down)
    return factors

def _loss_fn(loss_fn, x_target, x_pred, hps):
    if loss_fn == 'l1':
        return t.mean(t.abs(x_pred - x_target)) / hps.bandwidth['l1']
    elif loss_fn == 'l2':
        return t.mean((x_pred - x_target) ** 2) / hps.bandwidth['l2']
    elif loss_fn == 'linf':
        residual = ((x_pred - x_target) ** 2).reshape(x_target.shape[0], -1)
        values, _ = t.topk(residual, hps.linf_k, dim=1)
        return t.mean(values) / hps.bandwidth['l2']
    elif loss_fn == 'lmix':
        loss = 0.0
        if hps.lmix_l1:
            loss += hps.lmix_l1 * _loss_fn('l1', x_target, x_pred, hps)
        if hps.lmix_l2:
            loss += hps.lmix_l2 * _loss_fn('l2', x_target, x_pred, hps)
        if hps.lmix_linf:
            loss += hps.lmix_linf * _loss_fn('linf', x_target, x_pred, hps)
        return loss
    else:
        assert False, f"Unknown loss_fn {loss_fn}"

class VQVAE(nn.Module):
    """Multi-level VQ-VAE for audio: one encoder/decoder pair per level plus a shared bottleneck.

    Level `l` uses the first `l + 1` entries of `strides_t` / `downs_t`, so its codes are
    downsampled by `self.hop_lengths[l]`. `forward` returns the level-0 reconstruction,
    the total training loss, and a dict of detached metrics.
    """
    def __init__(self, input_shape, levels, downs_t, strides_t,
                 emb_width, l_bins, mu, commit, spectral, multispectral,
                 multipliers=None, use_bottleneck=True, **block_kwargs):
        """Build the encoders, decoders and bottleneck.

        Args:
            input_shape: per-example shape; `input_shape[0]` is the sample length and
                `input_shape[-1]` the channel count.
            levels: number of VQ levels.
            downs_t, strides_t: per-level downsampling repeats and strides.
            emb_width: latent width. l_bins: codebook size. mu: passed to `Bottleneck`.
            commit, spectral, multispectral: loss weights used in `forward`.
            multipliers: optional per-level multipliers applied to block_kwargs' `width`/`depth`.
            use_bottleneck: when False, `NoBottleneck` replaces the VQ `Bottleneck`.
            **block_kwargs: forwarded to Encoder/Decoder; must contain `width` and `depth`.
        """
        super().__init__()

        self.sample_length = input_shape[0]
        x_shape, x_channels = input_shape[:-1], input_shape[-1]
        self.x_shape = x_shape

        self.downsamples = calculate_strides(strides_t, downs_t)
        # hop_lengths[l]: cumulative downsampling from audio samples to level-l codes.
        self.hop_lengths = np.cumprod(self.downsamples)
        self.z_shapes = [(x_shape[0] // self.hop_lengths[level],) for level in range(levels)]
        self.levels = levels

        if multipliers is None:
            self.multipliers = [1] * levels
        else:
            assert len(multipliers) == levels, "Invalid number of multipliers"
            self.multipliers = multipliers
        def _block_kwargs(level):
            # Copy so the caller's block_kwargs are not scaled in place.
            this_block_kwargs = dict(block_kwargs)
            this_block_kwargs["width"] *= self.multipliers[level]
            this_block_kwargs["depth"] *= self.multipliers[level]
            return this_block_kwargs

        encoder = lambda level: Encoder(x_channels, emb_width, level + 1,
                                        downs_t[:level+1], strides_t[:level+1], **_block_kwargs(level))
        decoder = lambda level: Decoder(x_channels, emb_width, level + 1,
                                        downs_t[:level+1], strides_t[:level+1], **_block_kwargs(level))
        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        for level in range(levels):
            self.encoders.append(encoder(level))
            self.decoders.append(decoder(level))

        if use_bottleneck:
            self.bottleneck = Bottleneck(l_bins, emb_width, mu, levels)
        else:
            self.bottleneck = NoBottleneck(levels)

        self.downs_t = downs_t
        self.strides_t = strides_t
        self.l_bins = l_bins
        self.commit = commit
        self.spectral = spectral
        self.multispectral = multispectral

    def preprocess(self, x):
        # x: NTC [-1,1] -> NCT [-1,1]
        assert len(x.shape) == 3
        x = x.permute(0,2,1).float()
        return x

    def postprocess(self, x):
        # x: NTC [-1,1] <- NCT [-1,1]
        x = x.permute(0,2,1)
        return x

    def _decode(self, zs, start_level=0, end_level=None):
        """Decode codes for levels [start_level, end_level) with the decoder of `start_level`."""
        if end_level is None:
            end_level = self.levels
        assert len(zs) == end_level - start_level
        xs_quantised = self.bottleneck.decode(zs, start_level=start_level, end_level=end_level)
        assert len(xs_quantised) == end_level - start_level

        # Use only lowest level
        decoder, x_quantised = self.decoders[start_level], xs_quantised[0:1]
        x_out = decoder(x_quantised, all_levels=False)
        x_out = self.postprocess(x_out)
        return x_out

    def decode(self, zs, start_level=0, end_level=None, bs_chunks=1):
        """Decode in `bs_chunks` batch chunks (to bound memory) and concatenate along batch."""
        z_chunks = [t.chunk(z, bs_chunks, dim=0) for z in zs]
        x_outs = []
        for i in range(bs_chunks):
            zs_i = [z_chunk[i] for z_chunk in z_chunks]
            x_out = self._decode(zs_i, start_level=start_level, end_level=end_level)
            x_outs.append(x_out)
        return t.cat(x_outs, dim=0)

    def _encode(self, x, start_level=0, end_level=None):
        """Encode `x` (NTC) at every level and return the codes for [start_level, end_level)."""
        if end_level is None:
            end_level = self.levels
        x_in = self.preprocess(x)
        xs = []
        for level in range(self.levels):
            encoder = self.encoders[level]
            x_out = encoder(x_in)
            xs.append(x_out[-1])
        zs = self.bottleneck.encode(xs)
        return zs[start_level:end_level]

    def encode(self, x, start_level=0, end_level=None, bs_chunks=1):
        """Encode in `bs_chunks` batch chunks and concatenate each level's codes along batch."""
        x_chunks = t.chunk(x, bs_chunks, dim=0)
        zs_list = []
        for x_i in x_chunks:
            zs_i = self._encode(x_i, start_level=start_level, end_level=end_level)
            zs_list.append(zs_i)
        zs = [t.cat(zs_level_list, dim=0) for zs_level_list in zip(*zs_list)]
        return zs

    def sample(self, n_samples):
        """Decode `n_samples` uniformly random code sequences for every level."""
        # Draw the codes on the model's own device instead of a hard-coded 'cuda', so
        # sampling also works for CPU models and models on a non-default GPU.
        device = next(self.parameters()).device
        zs = [t.randint(0, self.l_bins, size=(n_samples, *z_shape), device=device) for z_shape in self.z_shapes]
        return self.decode(zs)

    def forward(self, x, hps, loss_fn='l1'):
        """Reconstruct `x` (NTC) at every level and compute the training loss.

        Returns:
            (x_out, loss, metrics): `x_out` is the level-0 reconstruction (NTC, after
            `audio_postprocess`). `loss` combines reconstruction, spectral, multispectral
            and commitment losses. `metrics` holds detached per-level and aggregate values.
        """
        metrics = {}

        # Encode/Decode
        x_in = self.preprocess(x)
        xs = []
        for level in range(self.levels):
            encoder = self.encoders[level]
            x_out = encoder(x_in)
            xs.append(x_out[-1])

        zs, xs_quantised, commit_losses, quantiser_metrics = self.bottleneck(xs)
        x_outs = []
        for level in range(self.levels):
            decoder = self.decoders[level]
            x_out = decoder(xs_quantised[level:level+1], all_levels=False)
            assert_shape(x_out, x_in.shape)
            x_outs.append(x_out)

        # Loss
        def _spectral_loss(x_target, x_out, hps):
            if hps.use_nonrelative_specloss:
                sl = spectral_loss(x_target, x_out, hps) / hps.bandwidth['spec']
            else:
                sl = spectral_convergence(x_target, x_out, hps)
            sl = t.mean(sl)
            return sl

        def _multispectral_loss(x_target, x_out, hps):
            sl = multispectral_loss(x_target, x_out, hps) / hps.bandwidth['spec']
            sl = t.mean(sl)
            return sl

        recons_loss = t.zeros(()).to(x.device)
        spec_loss = t.zeros(()).to(x.device)
        multispec_loss = t.zeros(()).to(x.device)
        x_target = audio_postprocess(x.float(), hps)

        # Iterate from the coarsest level down, so after this loop `x_out`
        # holds the level-0 (finest) reconstruction used below and returned.
        for level in reversed(range(self.levels)):
            x_out = self.postprocess(x_outs[level])
            x_out = audio_postprocess(x_out, hps)
            this_recons_loss = _loss_fn(loss_fn, x_target, x_out, hps)
            this_spec_loss = _spectral_loss(x_target, x_out, hps)
            this_multispec_loss = _multispectral_loss(x_target, x_out, hps)
            metrics[f'recons_loss_l{level + 1}'] = this_recons_loss
            metrics[f'spectral_loss_l{level + 1}'] = this_spec_loss
            metrics[f'multispectral_loss_l{level + 1}'] = this_multispec_loss
            recons_loss += this_recons_loss
            spec_loss += this_spec_loss
            multispec_loss += this_multispec_loss

        commit_loss = sum(commit_losses)
        loss = recons_loss + self.spectral * spec_loss + self.multispectral * multispec_loss + self.commit * commit_loss

        # Extra diagnostics on the level-0 reconstruction (no gradient needed).
        with t.no_grad():
            sc = t.mean(spectral_convergence(x_target, x_out, hps))
            l2_loss = _loss_fn("l2", x_target, x_out, hps)
            l1_loss = _loss_fn("l1", x_target, x_out, hps)
            linf_loss = _loss_fn("linf", x_target, x_out, hps)

        quantiser_metrics = average_metrics(quantiser_metrics)

        metrics.update(dict(
            recons_loss=recons_loss,
            spectral_loss=spec_loss,
            multispectral_loss=multispec_loss,
            spectral_convergence=sc,
            l2_loss=l2_loss,
            l1_loss=l1_loss,
            linf_loss=linf_loss,
            commit_loss=commit_loss,
            **quantiser_metrics))

        for key, val in metrics.items():
            metrics[key] = val.detach()

        return x_out, loss, metrics

# **Generation**

## **Scalable (Sparse) Transformer**

In [None]:
def print_once(msg):
    """Print `msg` from rank 0 only, or unconditionally when not running distributed.

    `dist.is_available()` only reports that torch was built with distributed support.
    `dist.get_rank()` raises unless a process group has been initialized, so also check
    `dist.is_initialized()`.
    """
    import torch.distributed as dist
    if not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0:
        print(msg)

In [None]:
# Source: jukebox/jukebox/transformer/factored_attention.py

def repeat(x, n, dim):
    """Repeat every slice of `x` along `dim` `n` times in a row (like `repeat_interleave`).

    Only `dim == -1` is translated to the last axis. `x` must be viewable, i.e.
    `view`-compatible / contiguous.
    """
    axis = len(x.shape) - 1 if dim == -1 else dim
    shape = x.shape
    outer = int(np.prod(shape[:axis + 1]))
    inner = int(np.prod(shape[axis + 1:]))
    tiled = x.view(outer, 1, inner).repeat(1, n, 1)
    return tiled.view(*shape[:axis], n * shape[axis], *shape[axis + 1:])

def get_mask(mask, q_l, kv_l, blocks, spread, device, sample, sample_t):
    """Build a (1, 1, q_l, kv_l) attention mask (1 = may attend), or None when not needed.

    No mask is produced for `mask is None` or single-query steps (q_l == 1).
    'autoregressive' and 'prime' give a lower-triangular mask shifted right by
    `sample_t - q_l` when sampling, else by `max(kv_l - q_l, 0)`. 'summary' masks the
    per-block summary keys. `spread` is accepted for interface symmetry but unused.
    """
    if mask is None or q_l == 1:
        return None
    if sample:
        offset = sample_t - q_l
    else:
        offset = max(kv_l - q_l, 0)
    if mask in ('autoregressive', 'prime'):
        mask = t.ones(q_l, kv_l, device=device).tril(offset)
    elif mask == 'summary':
        causal = t.ones(q_l, q_l, device=device).tril()
        prev_block_tails = causal.view(q_l, blocks, q_l // blocks)[:, :-1, -kv_l // blocks:]
        mask = t.nn.functional.pad(prev_block_tails, (0, 0, 1, 0), value=1).contiguous().view(q_l, kv_l)
    return mask.view(1, 1, q_l, kv_l)

class FactoredAttention(nn.Module):
    def __init__(self, n_in, n_ctx, n_state, n_head,
                 attn_dropout=0.0, resid_dropout=0.0,
                 scale=True, mask=False,
                 zero_out=False, init_scale=1.0,
                 checkpoint_attn=0,
                 attn_func=0, blocks=None, spread=None,
                 encoder_dims=None, prime_len=None):
        """Multi-head attention whose (sparse) pattern is selected by `attn_func`.

        Args:
            n_in: input/output feature width.
            n_ctx: full context length. Must be divisible by `blocks` when `blocks` is given.
            n_state: total attention width. Must be divisible by `n_head`.
            n_head: number of heads.
            attn_dropout, resid_dropout: dropout probabilities; 0.0 means identity.
            scale: stored as `self.scale`; not read by the methods visible here.
            mask: if truthy, `_attn` applies the `get_mask` mask for this pattern.
            zero_out, init_scale: forwarded to the `Conv1D` projections (defined elsewhere).
            checkpoint_attn: 0 = none, 1 = checkpoint `_attn` inside `dense_attn`,
                2 = checkpoint the whole pattern function in `forward`.
            attn_func: key into the (qkv, attn, mask-type) table below.
            blocks: number of blocks the context is factored into (block_ctx = n_ctx // blocks).
            spread: trailing positions per block used by the summary-spread pattern (5).
            encoder_dims: expected key/value length for encoder-decoder attention (6).
            prime_len: prime length for prime attention (7).
        """
        super().__init__()
        self.n_in = n_in
        self.n_ctx = n_ctx # NOTE: n_ctx could be different within operations. This is complete n_ctx
        self.n_state = n_state
        assert n_state % n_head == 0
        self.n_head = n_head
        self.scale = scale
        self.mask = mask
        if attn_func == 6:
            # Encoder-decoder attention: query projection here; keys/values projected from encoder_kv.
            self.c_attn = Conv1D(n_in, n_state, init_scale=init_scale)
            self.c_enc_kv = Conv1D(n_in, n_state * 2, init_scale=init_scale)
        else:
            # Fused q/k/v projection, split later with chunk(3).
            self.c_attn = Conv1D(n_in, n_state * 3, init_scale=init_scale)
        self.c_proj = Conv1D(n_state, n_in, zero_out, init_scale=init_scale)
        self.attn_dropout = nn.Dropout(attn_dropout) if attn_dropout > 0.0 else lambda x: x
        self.resid_dropout = nn.Dropout(resid_dropout) if resid_dropout > 0.0 else lambda x: x

        # Sequence of length l is factored as [blocks, l // blocks]
        self.attn_func = attn_func
        self.qkv, self.attn, self.attn_mask = {
            0: (self.factored_qkv, self.dense_attn, 'autoregressive'),              # Attend to all positions
            1: (self.factored_qkv, self.block_attn, 'autoregressive'),              # Attend to your block
            2: (self.factored_qkv, self.transpose_block_attn, 'autoregressive'),    # Attend to transpose block
            3: (self.factored_qkv, self.prev_block_attn, None),                     # Attend to previous block
            4: (self.factored_qkv, self.summary_attn, 'summary'),                   # Attend to last position of each block
            5: (self.factored_qkv, self.summary_spread_attn, 'summary'),            # Attend to last `spread` positions of each previous block
            6: (self.decode_qkv, self.decode_attn, None),                           # Encoder-decoder attention over encoder_kv
            7: (self.prime_qkv, self.prime_attn, 'prime')                           # Attend to the prime segment
        }[attn_func] # select the (qkv, attn, mask type) entry for attn_func

        self.blocks = blocks
        self.spread = spread
        if blocks is not None:
            assert n_ctx % blocks == 0
            self.block_ctx = n_ctx // blocks
        self.checkpoint_attn = checkpoint_attn # 0: None, 1: Attn after heads split, 2: Attn

        # Sampling state: number of positions seen so far and the key/value cache.
        self.sample_t = 0
        self.cache = {}
        self.encoder_dims = encoder_dims
        self.prime_len = prime_len
        # When record_attn is True, _attn stores the attention weights in self.w.
        self.record_attn = False
        self.w = None

    def _attn(self, q, k, v, sample):
        """Scaled dot-product attention on head-split tensors.

        Expects `q` as (..., q_len, d_head) and `k` pre-transposed to (..., d_head, k_len),
        as produced by `split_heads(..., k=True)`. The 1/sqrt(d_head) scale is split
        between q and k in training (presumably to limit fp16 overflow; confirm) and is
        applied in place to the product in eval. Softmax runs in float32 and the result is
        cast back to the input dtype.
        """
        scale = 1. / math.sqrt(math.sqrt(self.n_state // self.n_head))
        if self.training:
            w = t.matmul(q * scale, k * scale)
        else:
            w = t.matmul(q, k)
            w.mul_(scale*scale)
        wtype = w.dtype
        w = w.float()
        if self.mask:
            # Mask out positions a query may not attend to (e.g. future positions for 'autoregressive')
            # Might take up lot of memory for dense, so can cache it
            mask = get_mask(self.attn_mask, q.size(-2), k.size(-1), self.blocks, self.spread, w.device, sample, self.sample_t)
            if mask is not None:
                # Disallowed scores become -1e9, so softmax gives them ~0 weight.
                w = w * mask + -1e9 * (1 - mask)
            w = F.softmax(w, dim=-1).type(wtype)
        else:
            w = F.softmax(w, dim=-1).type(wtype)
        if self.record_attn:
            self.w = w #.float().cpu().numpy()
            if self.attn_func == 7:
                # only keep music queries and lyrics keys/values
                self.w = self.w[:,:,self.prime_len:,:self.prime_len]
        w = self.attn_dropout(w)
        a = t.matmul(w, v)
        return a

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = (*x.size()[:-2], x.size(-2) * x.size(-1))
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        new_x_shape = (*x.size()[:-1], self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)
        else:
            return x.permute(0, 2, 1, 3)

    def dense_attn(self, query, key, value, sample):
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        if self.checkpoint_attn == 1 and not sample:
            a = checkpoint(lambda q,k,v,s=sample: self._attn(q,k,v,s), (query, key, value),
                       (), True)
        else:
            a = self._attn(query,key,value,sample)
        a = self.merge_heads(a)
        return a

    def block_attn(self, q, k, v, sample):
        """Local attention: queries attend only within their own block of `block_ctx` positions.

        Sampling: the key/value cache must hold exactly the current block's positions
        (`_suff_cache_len`), so dense attention over it is equivalent. Output is (bs, 1, d).
        Otherwise q/k/v are reshaped into (bs * n_blocks, block_ctx, d) chunks and attended
        independently. If there are fewer queries than keys, only the last `ql` key/value
        positions are used. Output is (bs, ql, d).
        """
        blocks, block_ctx = self.blocks, self.block_ctx # block_ctx is l // blocks for complete l ie l = n_ctx. Sampling has less l
        bs, l, d = v.shape # For sample, q_l = 1, k_l = v_l = cached length of the current block
        if sample:
            assert l == self._suff_cache_len(), f"{l} != {self._suff_cache_len()}"
            return self.dense_attn(q, k, v, sample).view(bs, 1, d)
        else:
            ql = q.shape[1]
            q = q.view(bs * ql // block_ctx, block_ctx, d)
            if ql < l:
                # Align keys/values with the queries: keep only the trailing ql positions.
                l = ql
                k = k[:, -l:].contiguous()
                v = v[:, -l:].contiguous()
            k = k.view(bs * l // block_ctx, block_ctx, d)
            v = v.view(bs * l // block_ctx, block_ctx, d)
            return self.dense_attn(q, k, v, sample).view(bs, l, d)

    def transpose_block_attn(self, q, k, v, sample):
        """Strided attention: each position attends to positions with the same within-block offset.

        The (bs, n_blocks, block_ctx, d) layout is transposed so that every group holds one
        within-block offset across all blocks. Sampling selects the cached positions that share
        the newest position's offset `(l - 1) % block_ctx`. Output is (bs, 1, d) when sampling,
        else (bs, ql, d).
        """
        blocks, block_ctx = self.blocks, self.block_ctx # block_ctx is l // blocks for complete l ie l = n_ctx. Sampling has less l
        bs, l, d = v.shape # For sample, q_l = 1, k_l = v_l = sample_t
        if sample:
            block_l = (l - 1) % block_ctx
            k = k[:,block_l::block_ctx,:]
            v = v[:,block_l::block_ctx,:]
            return self.dense_attn(q, k, v, sample).view(bs, 1, d)
        else:
            ql = q.shape[1]
            # (bs, n_blocks, block_ctx, d) -> (bs * block_ctx, n_blocks, d)
            q = q.view(bs, ql // block_ctx, block_ctx, d).transpose(1,2).contiguous().view(bs * block_ctx, ql // block_ctx, d)
            k = k.view(bs,  l // block_ctx, block_ctx, d).transpose(1,2).contiguous().view(bs * block_ctx,  l // block_ctx, d)
            v = v.view(bs,  l // block_ctx, block_ctx, d).transpose(1,2).contiguous().view(bs * block_ctx,  l // block_ctx, d)
            # Undo the transpose to restore the original (bs, ql, d) position order.
            return self.dense_attn(q, k, v, sample).view(bs, block_ctx, ql // block_ctx, d).transpose(1,2).contiguous().view(bs, ql, d)

    def prev_block_attn(self, q, k, v, sample):
        """Each block's queries attend to the previous block's keys/values; the first block sees zeros.

        Sampling: the cache holds at most the previous block plus the current one (see
        `_suff_cache_len` for attn_func 3), so the previous block, if any, starts at cache
        index 0. Otherwise k/v blocks are shifted right by one block with zero padding, and
        when there are fewer queries than keys only the last `ql // block_ctx` blocks are kept.
        """
        blocks, block_ctx = self.blocks, self.block_ctx # block_ctx is l // blocks for complete l ie l = n_ctx. Sampling has less l
        bs, l, d = v.shape # For sample, q_l = 1, k_l = v_l = sample_t
        if sample:
            assert l == self._suff_cache_len(), f"{l} != {self._suff_cache_len()}"
            # Index (0 or 1) of the block holding the newest cached position.
            block = (l - 1) // block_ctx
            prev_l = (block - 1) * block_ctx
            if block > 0:
                assert prev_l == 0
                k = k[:, prev_l:prev_l + block_ctx, :]
                v = v[:, prev_l:prev_l + block_ctx, :]
            else:
                # Still in the first block: there is no previous block, attend to zeros.
                k = t.zeros(bs, block_ctx, d, device=q.device, dtype=q.dtype)
                v = t.zeros(bs, block_ctx, d, device=q.device, dtype=q.dtype)
            return self.dense_attn(q, k, v, sample).view(bs, 1, d)
        else:
            ql = q.shape[1]
            q = q.view(bs * ql // block_ctx, block_ctx, d)
            # Drop the last block and prepend a zero block, so block i sees block i - 1.
            k = t.nn.functional.pad(k.view(bs, l // block_ctx, block_ctx, d)[:, :-1, :, :], (0,0,0,0,1,0)).view(bs * l // block_ctx, block_ctx, d)
            v = t.nn.functional.pad(v.view(bs, l // block_ctx, block_ctx, d)[:, :-1, :, :], (0,0,0,0,1,0)).view(bs * l // block_ctx, block_ctx, d)
            if ql < l:
                qb = ql // block_ctx
                kb =  l // block_ctx
                l = ql
                k = k.view(bs, kb, block_ctx, d)[:, -qb:].contiguous().view(bs * qb, block_ctx, d)
                v = v.view(bs, kb, block_ctx, d)[:, -qb:].contiguous().view(bs * qb, block_ctx, d)
            return self.dense_attn(q, k, v, sample).view(bs, l, d)

    def summary_attn(self, q, k, v, sample):
        """Attend to the last position of every previous block, plus a leading zero slot.

        NOTE(review): `_suff_cache_len` raises NotImplementedError for attn_func 4, so the
        sampling branch appears unreachable through `factored_qkv`; confirm before relying on it.
        """
        blocks, block_ctx = self.blocks, self.block_ctx # block_ctx is l // blocks for complete l ie l = n_ctx. Sampling has less l
        bs, l, d = v.shape # For sample, q_l = 1, k_l = v_l = sample_t
        if sample:
            # Positions block_ctx-1, 2*block_ctx-1, ... (last of each block except the final one), zero slot prepended.
            k = t.nn.functional.pad(k[:, block_ctx-1:blocks*block_ctx-1:block_ctx, :],(0,0,1,0))
            v = t.nn.functional.pad(v[:, block_ctx-1:blocks*block_ctx-1:block_ctx, :],(0,0,1,0))
            return self.dense_attn(q, k, v, sample).view(bs, 1, d)
        else:
            k = t.nn.functional.pad(k.view(bs, blocks, l // blocks, d)[:, :-1, -1, :],(0,0,1,0)) # bs, blocks, d
            v = t.nn.functional.pad(v.view(bs, blocks, l // blocks, d)[:, :-1, -1, :],(0,0,1,0)) # bs, blocks, d
            return self.dense_attn(q, k, v, sample).view(bs, l, d)

    def summary_spread_attn(self, q, k, v, sample):
        """Like `summary_attn`, but uses the last `spread` positions of each previous block.

        Keys/values become (bs, blocks * spread, d): a zero block followed by the trailing
        `spread` positions of blocks 0..blocks-2. Sampling is not implemented (asserts).
        """
        blocks, block_ctx, spread = self.blocks, self.block_ctx, self.spread # block_ctx is l // blocks for complete l ie l = n_ctx. Sampling has less l
        bs, l, d = v.shape # For sample, q_l = 1, k_l = v_l = sample_t
        if sample:
            assert False, "Not yet implemented"
            # k = t.nn.functional.pad(k,(0,0,block_ctx,(-l)%block_ctx)).view(bs, -1, block_ctx, d)[:,:-1,-spread:,:].contiguous().view(bs, -1, d)
            # v = t.nn.functional.pad(v,(0,0,block_ctx,(-l)%block_ctx)).view(bs, -1, block_ctx, d)[:,:-1,-spread:,:].contiguous().view(bs, -1, d)
            # return self.dense_attn(q, k, v, sample).view(bs, 1, d)
        else:
            k = t.nn.functional.pad(k.view(bs, blocks, l // blocks, d)[:, :-1, -spread:, :],(0,0,0,0,1,0)).contiguous().view(bs, blocks * spread, d)  # bs, blocks * spread, d
            v = t.nn.functional.pad(v.view(bs, blocks, l // blocks, d)[:, :-1, -spread:, :],(0,0,0,0,1,0)).contiguous().view(bs, blocks * spread, d)  # bs, blocks * spread, d
            return self.dense_attn(q, k, v, sample).view(bs, l, d)

    def prime_attn(self, q, k, v, sample):
        prime_len = self._prime_len
        k = k[:, :prime_len]
        v = v[:, :prime_len]
        return self.dense_attn(q, k, v, sample)

    def decode_attn(self, q, k, v, sample):
        assert k.shape[1] == v.shape[1] == self.encoder_dims, f'k: {k.shape}, v: {v.shape}, enc_dims: {self.encoder_dims}'
        return self.dense_attn(q, k, v, sample)

    def factored_qkv(self, x, encoder_kv=None, sample=False):
        """Split the fused `c_attn` output into query/key/value and maintain the sampling cache.

        When sampling, `sample_t` advances by the number of new positions, the new keys/values
        are appended to the cache, and the cache is trimmed to `_suff_cache_len()`.
        - Several new positions at once: for block-sparse patterns, q/k/v are zero-padded to
          block boundaries, and the step is handled as non-sampling (returned flag is False).
          The keys/values used are the full cache as appended, before trimming.
        - A single new position: keys/values are taken from the trimmed cache.

        Returns:
            (query, key, value, sample)
        """
        curr_ctx = x.shape[1]
        assert encoder_kv is None
        query, key, value = x.chunk(3, dim=2)
        if sample:
            self.sample_t += curr_ctx
            key, value = self._append_cache(key, value)
            l_cache = self._suff_cache_len()
            if self._cache_len() > l_cache:
                self._slice_cache(-l_cache)
            if curr_ctx > 1:
                if self.attn_func != 0:
                    query = self._pad_to_block_ctx(query, query=True)
                    key = self._pad_to_block_ctx(key)
                    value = self._pad_to_block_ctx(value)
                    assert key.shape[1] % self.block_ctx == 0
                    assert query.shape[1] % self.block_ctx == 0
                assert key.shape[1] == value.shape[1]
                assert query.shape[1] <= key.shape[1]
                sample = False
            else:
                key = self.cache['key']
                value = self.cache['value']
        return query, key, value, sample

    def prime_qkv(self, x, encoder_kv=None, sample=False):
        """Query/key/value for prime attention.

        When sampling, keys/values are appended to the cache only until it reaches
        `_prime_len`, and it is trimmed to the first `_prime_len` positions. Keys/values
        always come from the cache. Without sampling, keys/values must span `n_ctx` positions.

        Returns:
            (query, key, value, sample)
        """
        curr_ctx = x.shape[1]
        assert encoder_kv is None
        query, key, value = x.chunk(3, dim=2)
        if sample:
            if self._cache_len() < self._prime_len:
                self._append_cache(key, value)
            if self._cache_len() > self._prime_len:
                self._slice_cache(0, self._prime_len)
            key, value = self.cache['key'], self.cache['value']
            self.sample_t += curr_ctx
            assert key.shape[1] == value.shape[1] == self._suff_cache_len(), f'k: {key.shape}, v: {value.shape}, prime_dims: {self._suff_cache_len()}'
        else:
            assert key.shape[1] == value.shape[1] == self.n_ctx, f'k: {key.shape}, v: {value.shape}, prime_dims: {self.n_ctx}'
        assert key.shape[0] == value.shape[0] == query.shape[0], f'k: {key.shape}, v: {value.shape}, q: {query.shape}'
        assert key.shape[2] == value.shape[2] == query.shape[2], f'k: {key.shape}, v: {value.shape}, q: {query.shape}'
        return query, key, value, sample

    def decode_qkv(self, x, encoder_kv=None, sample=False):
        """Query/key/value for encoder-decoder attention.

        The query is `x` itself (the `c_attn` output). Keys/values are projected from
        `encoder_kv` by `c_enc_kv`. When sampling they are computed once, at the first
        step (`sample_t == 0`), and then reused from the cache.

        Returns:
            (query, key, value, sample)
        """
        curr_ctx = x.shape[1]
        assert encoder_kv is not None
        query = x
        if sample:
            if self.sample_t == 0:
                self.cache['key'], self.cache['value'] = self.c_enc_kv(encoder_kv.type_as(x)).chunk(2, dim=2)
            key, value = self.cache['key'], self.cache['value']
            self.sample_t += curr_ctx
        else:
            key, value = self.c_enc_kv(encoder_kv.type_as(x)).chunk(2, dim=2)
        assert key.shape[0] == value.shape[0] == query.shape[0], f'k: {key.shape}, v: {value.shape}, q: {query.shape}'
        assert key.shape[1] == value.shape[1] == self.encoder_dims, f'k: {key.shape}, v: {value.shape}, enc_dims: {self.encoder_dims}'
        assert key.shape[2] == value.shape[2] == query.shape[2], f'k: {key.shape}, v: {value.shape}, q: {query.shape}'
        return query, key, value, sample

    def forward(self, x, encoder_kv=None, sample=False):
        """Project `x`, attend with this layer's pattern, and project back.

        If the pattern returns more positions than were fed in (e.g. after block padding
        during sampling), only the `curr_ctx` rows starting at `_offset(curr_ctx)` are kept.
        The result goes through `c_proj` and the residual dropout.
        """
        curr_ctx = x.shape[1]
        x = self.c_attn(x)
        query, key, value, sample = self.qkv(x, encoder_kv=encoder_kv, sample=sample)
        if self.checkpoint_attn == 2 and not sample:
            a = checkpoint(lambda q,k,v,s=sample: self.attn(q,k,v,s), (query, key, value), (), True)
        else:
            a = self.attn(query,key,value,sample)
        if a.shape[1] != curr_ctx:
            offset = self._offset(curr_ctx)
            a = a[:,offset:offset + curr_ctx,:].contiguous()
        a = self.c_proj(a)
        return self.resid_dropout(a)

    @property
    def _prime_len(self):
        prime_len = self.prime_len
        assert prime_len is not None
        prime_blocks = (prime_len // self.blocks) + 1
        return prime_blocks * self.blocks

    def _offset(self, curr_ctx):
        if self.attn_func == 0:
            return 0
        return (self.sample_t - curr_ctx) % self.block_ctx

    def _pad_to_block_ctx(self, x, query=False):
        l = x.shape[1]
        offset = self._offset(l) if query else 0
        n_blocks = (l + offset + self.block_ctx - 1) // self.block_ctx
        pad = n_blocks * self.block_ctx - l - offset
        if pad == 0 and offset == 0:
            return x
        else:
            return F.pad(x, (0, 0, offset, pad))

    def _cache_len(self):
        return 0 if 'key' not in self.cache else self.cache['key'].shape[1]

    def _suff_cache_len(self):
        """
        Precondition:
            key and value are appended with the current context and
            self.sample_t reflects the 1-indexed sample location in the
            context.
        """
        if self.attn_func == 0:
            return self.sample_t
        elif self.attn_func == 1:
            return (self.sample_t - 1) % self.block_ctx + 1
        elif self.attn_func == 2:
            return self.sample_t
        elif self.attn_func == 3:
            if self.sample_t <= self.block_ctx:
                return self.sample_t
            else:
                curr_block = (self.sample_t - 1) % self.block_ctx + 1
                prev_block = self.block_ctx
                return curr_block + prev_block
        elif self.attn_func == 6:
            return self.encoder_dims
        elif self.attn_func == 7:
            return min(self.sample_t, self._prime_len)
        else:
            raise NotImplementedError()

    def _slice_cache(self, start, end=None):
        self.cache['key'] = self.cache['key'][:, start:end]
        self.cache['value'] = self.cache['value'][:, start:end]

    def _append_cache(self, key, value):
        if 'key' not in self.cache:
            self.cache['key'] = key
            self.cache['value'] = value
        else:
            old_key, old_value = key, value
            key = t.cat([self.cache['key'], key], dim=1)
            value = t.cat([self.cache['value'], value], dim=1)
            del self.cache['key']
            del self.cache['value']
            del old_key
            del old_value
            self.cache['key'] = key
            self.cache['value'] = value
        return self.cache['key'], self.cache['value']

    def del_cache(self):
        """Reset the sampling position to 0 and drop all cached key/value tensors."""
        self.sample_t = 0
        # Remove the tensors from the existing dict first, then install a fresh one.
        for name in ('key', 'value'):
            self.cache.pop(name, None)
        self.cache = {}

    def check(self):
        """
        Debug self-test of the attention connectivity pattern, using gradients (requires CUDA).

        Runs forward on a random (4, n_ctx, n_in) input, differentiates the mean
        output at batch element 2, position 60 w.r.t. the input, and asserts:
          * only batch element 2 receives gradient,
          * no position after 60 receives gradient (causality),
          * the set of positions with non-zero gradient equals the expected
            pattern for self.attn_func (plus position 60 itself).
        Indexing loss[2, 60] requires n_ctx > 60.

        NOTE(review): the expected-pattern dict only has keys 0-5, so
        attn_func 6 or 7 raises KeyError here.
        """
        # blocks / spread fall back to 1 when unset (None or 0).
        blocks = self.blocks or 1
        spread = self.spread or 1
        bs, l, d = (4, self.n_ctx, self.n_in)
        x = t.randn(bs, l, d).cuda()
        x.requires_grad = True
        x_out = self.forward(x) # bs, l, d
        loss = x_out.mean(dim = -1) # bs, l
        pos = 60
        grad = t.autograd.grad(loss[2, pos], x)[0]

        # Gradient must be confined to batch element 2 and to positions <= pos.
        assert grad.shape == (bs, l, d)
        assert (grad[:2] == 0).all()
        assert (grad[3:] == 0).all()
        assert (grad[2, (pos + 1):] == 0).all()
        # Positions of batch element 2 that received any gradient.
        pos_grad = (t.sum(grad[2] ** 2, dim=-1) > 0).nonzero().view(-1).cpu()

        # Start index of the block (length l // blocks) that contains pos.
        block_pos = pos - (pos % (l // blocks))
        # Expected attended positions before pos, per attn_func:
        #   0: every earlier position
        #   1: earlier positions in pos's own block
        #   2: positions in the same column (same offset mod block length)
        #   3: the whole previous block
        #   4: the last position of each block up to pos
        #   5: the last `spread` positions of each block before pos's block
        # All entries are evaluated eagerly before indexing.
        exp_pos_grad = {0: t.arange(pos),
                        1: t.arange(block_pos, pos),
                        2: t.arange(pos % (l // blocks), pos, l // blocks),
                        3: t.arange(block_pos - l // blocks, block_pos),
                        4: t.arange(l // blocks - 1, pos, l // blocks),
                        5: ((t.arange(pos) % (l // blocks) >= (l // blocks - spread)) & (t.arange(pos) < block_pos)).nonzero().view(-1)}[self.attn_func]
        # pos always sees itself.
        exp_pos_grad = t.cat([exp_pos_grad, t.tensor([pos])], dim=-1)

        assert (len(pos_grad) == len(exp_pos_grad)) and (pos_grad == exp_pos_grad).all(), \
            f"Expected pos grad {exp_pos_grad} got {pos_grad} for attn_func {self.attn_func} pos {pos} l {l} blocks {blocks}"

    def check_cache(self, n_samples, sample_t, fp16):
        """
        Assert that the key/value cache matches the expected sampling state.

        Args:
            n_samples: expected batch dimension of the cached tensors.
            sample_t: expected value of self.sample_t.
            fp16: True if the cache should hold float16, False for float32.
        """
        assert self.sample_t == sample_t, f"{self.sample_t} != {sample_t}"
        if sample_t == 0:
            assert self.cache == {}
            return
        expected_dtype = {False: t.float32, True: t.float16}[fp16]
        expected_shape = (n_samples, self._suff_cache_len(), self.n_state)
        assert self.cache['key'].shape == expected_shape
        assert self.cache['value'].shape == expected_shape
        for name in ('key', 'value'):
            actual_dtype = self.cache[name].dtype
            assert actual_dtype == expected_dtype, f"Expected {expected_dtype}, got {actual_dtype}"

    def check_sample(self):
        """
        Debug self-test (requires CUDA): incremental sampling must match a full forward pass.

        1. Runs forward on a random (4, n_ctx, n_in) input with sample=False.
        2. Feeds the same input one timestep at a time with sample=True and
           asserts the concatenated outputs match within 1e-8.
        3. Clears the cache, feeds the first `prime` (=5) timesteps in one
           sample=True call, checks the cache state, and asserts those outputs
           match the first 5 positions of step 1 within 1e-8.
        Must be called on a module with an empty cache and sample_t == 0.
        """
        t.manual_seed(42)
        bs, l, d = (4, self.n_ctx, self.n_in)
        prime = 5
        x = t.randn(bs, l, d).cuda()
        # One (bs, 1, d) slice per timestep for the step-by-step path.
        xs = t.chunk(x, l, dim=1)
        assert self.sample_t == 0
        assert self.cache == {}

        with t.no_grad():
            enc_l = self.encoder_dims
            encoder_kv = None
            # Only attn_func 6 consumes encoder keys/values.
            if self.attn_func == 6:
                encoder_kv = t.randn(bs, enc_l, d).cuda()

            # Normal path
            x_out_normal = self.forward(x, encoder_kv=encoder_kv)

            # Sampling path
            x_out_sample = t.cat([self.forward(xs[i], encoder_kv=encoder_kv, sample=True) for i in range(l)],dim=1)
        max_err = t.max(t.abs(x_out_sample - x_out_normal))
        # The message lists the timesteps whose error exceeds the tolerance.
        assert max_err < 1e-8, f"Max sampling err is {max_err} {[i for i in range(l) if t.max(t.abs(x_out_sample - x_out_normal)[:,i,:]) > 1e-8]}"

        with t.no_grad():
            x_out_normal = x_out_normal[:,:prime,:]
            # Prime sampling path
            self.del_cache()
            x_out_sample = self.forward(x[:,:prime,:].contiguous(), encoder_kv=encoder_kv, sample=True)
            self.check_cache(bs, prime, False)

        max_err = t.max(t.abs(x_out_sample - x_out_normal))
        assert max_err < 1e-8, f"Max prime sampling err is {max_err} {[i for i in range(prime) if t.max(t.abs(x_out_sample - x_out_normal)[:,i,:]) > 1e-8]}"

    def check_chunks(self, chunk_size):
        """
        Debug self-test (requires CUDA): chunked sampling must match a full forward pass.

        On a random (4, n_ctx, n_in) input, asserts within 1e-6 that:
          * forward with sample=False equals forward with sample=True on a fresh cache;
          * feeding the input in consecutive chunks of `chunk_size` timesteps with
            sample=True reproduces the same output. The cache is verified with
            check_cache after every chunk.

        Args:
            chunk_size: timesteps per chunk; n_ctx must be divisible by it.
        """
        t.manual_seed(42)
        bs, l, d = (4, self.n_ctx, self.n_in)
        enc_l = self.encoder_dims
        assert l % chunk_size == 0
        n_chunks = l // chunk_size
        with t.no_grad():
            encoder_kv = None
            x = t.randn(bs, l, d).cuda()
            # Only attn_func 6 consumes encoder keys/values.
            if self.attn_func == 6:
                encoder_kv = t.randn(bs, enc_l, d).cuda()

            # Full-sequence pass, non-sampling vs sampling mode on an empty cache.
            self.del_cache()
            y_forw = self.forward(x, encoder_kv=encoder_kv, sample=False)
            self.del_cache()
            y_forw_sample = self.forward(x, encoder_kv=encoder_kv, sample=True)
            max_err = t.max(t.abs(y_forw - y_forw_sample))
            assert max_err <= 1e-6, f"Max err is {max_err} {[i for i in range(l) if t.max(t.abs(y_forw - y_forw_sample)[:, i, :]) > 1e-6]}"

            # Chunked sampling pass; the cache must grow by each chunk's length.
            self.del_cache()
            x_chunks = t.chunk(x, n_chunks, dim=1)
            y_chunks = []
            total_len = 0
            for x_chunk in x_chunks:
                y_chunk = self.forward(x_chunk.contiguous(), encoder_kv=encoder_kv, sample=True)
                total_len += x_chunk.shape[1]
                self.check_cache(bs, total_len, False)
                y_chunks.append(y_chunk)
            y_forw_in_chunks = t.cat(y_chunks, dim=1)

            max_err = t.max(t.abs(y_forw - y_forw_in_chunks))
            assert max_err <= 1e-6, f"Max err is {max_err} {[i for i in range(l) if t.max(t.abs(y_forw - y_forw_in_chunks)[:, i, :]) > 1e-6]}"


#if __name__ == '__main__':
#    from jukebox.utils.dist_utils import setup_dist_from_mpi
#    setup_dist_from_mpi(port=29600)
#    n_in = 16
#    n_state = n_in * 2
#    n_ctx = 6144
#    n_head = 4
#    n_depth = 12
#    blocks = 64
#    chunk_size = 8
#    for attn_func in [0, 1, 2, 3, 6, 7]:
#        encoder_dims = {0: 0, 1: 0, 2: 0, 3: 0, 6: 64, 7: 0}[attn_func]
#        prime_len = {0: 0, 1: 0, 2: 0, 3: 0, 6: 0, 7: 384}[attn_func]
#        attn = FactoredAttention(n_in, n_ctx + prime_len, n_state, n_head, mask=True,
#                                 attn_func=attn_func, blocks=blocks,
#                                 encoder_dims=encoder_dims, prime_len=prime_len)
#        attn.training = False
#        attn.check_sample()
#        attn.check_chunks(chunk_size)
#        print(f"Checked attn_func: {attn_func}")

In [None]:
# jukebox/jukebox/transformer/ops.py /

# Import FusedLayerNorm if we have apex, otherwise use regular LayerNorm
try:
    from apex.normalization import FusedLayerNorm
    print("Using apex FusedLayerNorm")
except ImportError:
    from torch.nn import LayerNorm as FusedLayerNorm

class LayerNorm(FusedLayerNorm):
    """
    Layer normalisation computed in float32, with the result cast back to the input dtype.

    Inputs with more than 65535 * width elements bypass the parent class's
    forward and use torch.nn.functional.layer_norm. This is presumably a size
    limit of the apex fused kernel (TODO confirm).
    """
    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super().__init__(normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
        self.width = np.prod(normalized_shape)
        self.max_numel = self.width * 65535

    def forward(self, input):
        as_float = input.float()
        if input.numel() <= self.max_numel:
            normed = super(LayerNorm, self).forward(as_float)
        else:
            normed = F.layer_norm(as_float, self.normalized_shape, self.weight, self.bias, self.eps)
        return normed.type_as(input)

def gelu(x):
    """Tanh approximation of GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))."""
    inner = math.sqrt(2 / math.pi) * (x + 0.044715 * t.pow(x, 3))
    return 0.5 * x * (1 + t.tanh(inner))


def swish(x):
    """Swish (SiLU) activation: the input gated by its own sigmoid."""
    gate = t.sigmoid(x)
    return x * gate

@t.jit.script
def quick_gelu(x):
    """Sigmoid approximation of GELU: x * sigmoid(1.702 * x)."""
    gate = t.sigmoid(1.702 * x)
    return x * gate

@t.jit.script
def quick_gelu_bwd(x, grad_output):
    """Backward of quick_gelu: grad_output * sig * (1 + 1.702 * x * (1 - sig)), sig = sigmoid(1.702 * x)."""
    s = t.sigmoid(1.702 * x)
    slope = 1.702 * x * (1 - s) + 1.
    return grad_output * s * slope

class QuickGelu(t.autograd.Function):
    """
    Autograd function for quick_gelu.

    Only the input is saved; the gradient is recomputed analytically by
    quick_gelu_bwd in backward.
    """
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return quick_gelu(x)

    @staticmethod
    def backward(ctx, grad_output):
        (saved_x,) = ctx.saved_tensors
        return quick_gelu_bwd(saved_x, grad_output)

def memory_efficient_quick_gelu(x):
    """Apply quick_gelu through the QuickGelu autograd function, which saves only the input for backward."""
    apply_quick_gelu = QuickGelu.apply
    return apply_quick_gelu(x)

# Activation name -> callable; MLP looks up its `afn` argument here.
ACT_FNS = dict(
    relu=t.nn.functional.relu,
    swish=swish,
    gelu=gelu,
    quick_gelu=memory_efficient_quick_gelu,  # autograd wrapper; the plain quick_gelu is an alternative
)

def _move_to_gpu_and_convert_conv_weights_to_fp16(l):
    """Move module `l` to the GPU; if it is a Conv1D, also cast its weight `w` to float16 (bias untouched)."""
    l.cuda()
    if not isinstance(l, Conv1D):
        return
    l.w.data = l.w.data.half()

def _convert_conv_weights_to_fp32(l):
    """Cast a Conv1D's weight `w` to float32 in place; any other module is left alone."""
    if not isinstance(l, Conv1D):
        return
    l.w.data = l.w.data.float()

def _convert_conv_weights_to_fp16(l):
    """Cast a Conv1D's weight `w` to float16 in place; any other module is left alone."""
    if not isinstance(l, Conv1D):
        return
    l.w.data = l.w.data.half()

def _convert_embedding_weights_to_fp16(l):
    if isinstance(l, t.nn.Embedding):
        l.weight.data = l.weight.data.half()

def _convert_embedding_weights_to_fp32(l):
    if isinstance(l, t.nn.Embedding):
        l.weight.data = l.weight.data.float()

class Conv1D(nn.Module):
    def __init__(self, n_in, n_out, zero_out=False, init_scale=1.0):
        super(Conv1D, self).__init__()
        self.n_in = n_in
        self.n_out = n_out
        if zero_out:
            w = t.zeros(n_in, n_out)
        else:
            w = t.empty(n_in, n_out)
            nn.init.normal_(w, std=0.02 * init_scale)
        b = t.zeros(n_out)
        self.w = nn.Parameter(w)
        self.b = nn.Parameter(b)

    def forward(self, x):
        size_out = (*x.size()[:-1], self.n_out)
        x = t.addmm(self.b.type_as(x), x.view(-1, x.size(-1)), self.w.type_as(x)) # If x if float then float else half
        x = x.view(*size_out)
        return x

# For large contexts, mask's can take up memory, so you can make a single saved mask for all layers
class Mask(nn.Module):
    """
    Causal (lower-triangular) attention mask that several layers can share to save memory.

    The mask is registered as buffer `b` with shape (1, 1, n_ctx, n_ctx).
    """
    def __init__(self, n_ctx):
        super().__init__()
        causal = t.ones(n_ctx, n_ctx).tril()
        self.register_buffer('b', causal.view(1, 1, n_ctx, n_ctx))

    def forward(self, w):
        # Keep scores where b == 1; replace the rest with -1e9 so softmax ignores them.
        # For fp16, filling the b == 0 positions with -inf via masked_fill on a float copy is the usual alternative.
        keep = self.b
        return w * keep + -1e9 * (1 - keep)

def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """
    Filter a distribution of logits using top-k or nucleus (top-p) filtering.

    Filtering is applied along the last dimension, so any leading (batch)
    dimensions are supported. The input tensor is not modified.

    Args:
        logits: logits tensor; the vocabulary is the last dimension.
        top_k: if > 0, keep only the top_k highest logits of each row.
        top_p: if > 0.0, keep the smallest set of highest logits whose
            cumulative probability exceeds top_p (the token that crosses
            the threshold is kept).
        filter_value: value written into the filtered positions.

    Returns:
        A filtered copy of `logits`.

    Raises:
        AssertionError: if both top_k and top_p are enabled.
    """
    logits = logits.clone()
    top_k = min(top_k, logits.size(-1))  # Safety check
    assert (top_k == 0) or (top_p == 0.0)
    if top_k > 0:
        # Remove every logit smaller than the k-th largest of its row.
        kth_largest = t.topk(logits, top_k, dim=-1)[0][..., -1:]
        logits[logits < kth_largest] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = t.sort(logits, descending=True, dim=-1)
        cumulative_probs = t.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask right so the first token crossing the threshold is kept too.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Map the mask from sorted order back to vocabulary order. Use a bool
        # destination: scatter_ requires self and src to share a dtype, and
        # uint8 index masks are deprecated.
        indices_to_remove = t.zeros_like(logits, dtype=t.bool).scatter_(
            dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits

In [None]:
# jukebox/jukebox/transformer/transformer.py /

def _convert_mlp_traced(l):
    """Replace a ResAttnBlock's MLP with a TorchScript trace built from a random (1, 1, n_in) CUDA float32 input."""
    if not isinstance(l, ResAttnBlock):
        return
    example = t.randn(1, 1, l.n_in).cuda()
    l.mlp = t.jit.trace(l.mlp, example)

def _convert_mlp_traced_fp16(l):
    """Replace a ResAttnBlock's MLP with a TorchScript trace built from a random (1, 1, n_in) CUDA float16 input."""
    if not isinstance(l, ResAttnBlock):
        return
    example = t.randn(1, 1, l.n_in).cuda().half()
    l.mlp = t.jit.trace(l.mlp, example)

class MLP(nn.Module):
    """
    Position-wise feed-forward block: Conv1D(n_in -> n_state), activation, Conv1D(n_state -> n_in).

    Residual dropout is applied to the output. Only the output projection
    honours `zero_out`; the activation is looked up by name in ACT_FNS.
    """
    def __init__(self, n_in, n_state, resid_dropout=0.0, afn='quick_gelu', zero_out=False, init_scale=1.0):
        super().__init__()
        self.c_fc = Conv1D(n_in, n_state, init_scale=init_scale)
        self.c_proj = Conv1D(n_state, n_in, zero_out, init_scale=init_scale)
        self.act = ACT_FNS[afn]
        if resid_dropout > 0.0:
            self.resid_dropout = nn.Dropout(resid_dropout)
        else:
            # No dropout: a plain identity function (not a registered submodule).
            self.resid_dropout = lambda x: x

    def forward(self, x):
        hidden = self.act(self.c_fc(x))
        projected = self.c_proj(hidden)
        return self.resid_dropout(projected)

class ResAttnBlock(nn.Module):
    """
    Pre-norm residual transformer block.

    Computes a = attn(ln_0(x)) and m = mlp(ln_1(x + a)), then returns
    x + a + m, or x + res_scale * (a + m) when res_scale != 1.0.
    The attention width is int(m_attn * n_in) and the MLP width is int(m_mlp * n_in).
    """
    def __init__(self, n_in, n_ctx, n_head,
                 attn_dropout=0.0, resid_dropout=0.0,
                 afn='quick_gelu', scale=True, mask=False,
                 zero_out=False, init_scale=1.0, res_scale=1.0,
                 m_attn=0.25, m_mlp=1.,
                 checkpoint_attn=0, checkpoint_mlp=0,
                 attn_func=0, blocks=None, spread=None,
                 encoder_dims=None, prime_len=None):
        super().__init__()
        # Submodules are built in a fixed order (attn, ln_0, mlp, ln_1), which
        # fixes both the RNG draws used for initialisation and the state_dict order.
        self.attn = FactoredAttention(n_in=n_in, n_ctx=n_ctx, n_state=int(m_attn * n_in), n_head=n_head,
                                      attn_dropout=attn_dropout, resid_dropout=resid_dropout,
                                      scale=scale, mask=mask,
                                      zero_out=zero_out, init_scale=init_scale,
                                      checkpoint_attn=checkpoint_attn,
                                      attn_func=attn_func, blocks=blocks, spread=spread,
                                      encoder_dims=encoder_dims, prime_len=prime_len)
        self.ln_0 = LayerNorm(n_in)
        self.mlp = MLP(n_in=n_in, n_state=int(m_mlp * n_in),
                       resid_dropout=resid_dropout,
                       afn=afn,
                       zero_out=zero_out, init_scale=init_scale)
        self.ln_1 = LayerNorm(n_in)
        self.res_scale = res_scale

        self.checkpoint_attn = checkpoint_attn
        self.checkpoint_mlp = checkpoint_mlp
        self.n_in = n_in
        self.attn_func = attn_func

    def forward(self, x, encoder_kv, sample=False):
        """
        Args:
            x: block input.
            encoder_kv: encoder keys/values. Outside sampling it must be
                non-None for attn_func 6 and None for every other attn_func.
            sample: incremental (cached) sampling mode; checkpointing is skipped.
        """
        if sample:
            a = self.attn(self.ln_0(x), encoder_kv, sample)
            m = self.mlp(self.ln_1(x + a))
        else:
            if self.attn_func == 6:
                assert encoder_kv is not None
                attn_inputs = (x, encoder_kv)
            else:
                assert encoder_kv is None
                attn_inputs = (x,)
            attn_fn = lambda _x, _enc_kv=None, _s=sample: self.attn(self.ln_0(_x), _enc_kv, _s)
            # checkpoint_attn == 3 recomputes the whole attention call here. Per the
            # original note, value 2 recomputes after the projections and 1 after
            # head splitting (presumably handled inside FactoredAttention).
            a = checkpoint(attn_fn,
                           attn_inputs,
                           (*self.attn.parameters(), *self.ln_0.parameters()),
                           self.checkpoint_attn == 3)
            mlp_fn = lambda _x: self.mlp(self.ln_1(_x))
            m = checkpoint(mlp_fn, (x + a,),
                           (*self.mlp.parameters(), *self.ln_1.parameters()),
                           self.checkpoint_mlp == 1)
        if self.res_scale == 1.0:
            return x + a + m
        return x + self.res_scale * (a + m)

class Transformer(nn.Module):
    """
    Stack of `n_depth` ResAttnBlocks with shared width `n_in` and context `n_ctx`.

    `attn_order` selects, for each layer index d, which attention pattern
    (attn_func) that layer uses; see the table in __init__. Layers with
    attn_func 6 receive `encoder_kv`; all other layers are called with
    encoder_kv=None.
    """
    def __init__(self, n_in, n_ctx, n_head, n_depth,
                 attn_dropout=0.0, resid_dropout=0.0,
                 afn='quick_gelu', scale=True, mask=False,
                 zero_out=False, init_scale=1.0, res_scale=False,
                 m_attn=0.25, m_mlp=1.,
                 checkpoint_attn=0, checkpoint_mlp=0, checkpoint_res=0,
                 attn_order=0, blocks=None, spread=None,
                 encoder_dims=None, prime_len=None):
        super().__init__()
        self.n_in = n_in
        self.n_ctx = n_ctx
        self.encoder_dims = encoder_dims
        self.blocks = blocks
        if blocks is not None:
            assert n_ctx % blocks == 0
            self.block_ctx = n_ctx // blocks
        self.prime_len = prime_len
        self.n_head = n_head

        # res_scale is a flag here: when truthy, every block scales its residual branch by 1 / n_depth.
        res_scale = 1.0 / n_depth if res_scale else 1.0

        # Orders of attn_func: layer index d -> attention pattern used by that layer
        attn_func = {0: lambda d: 0,                    # Complete dense attn
                     1: lambda d: [1,2][d%2],           # Alternate row and column attn
                     2: lambda d: [1,2,3][d % 3],       # Alternate row, column and previous row attn
                     3: lambda d: [1,4][d % 2],         # Alternate row and last column
                     4: lambda d: [1,5][d % 2],         # Alternate row and last k columns
                     5: lambda d: [1,4,1,1][d % 4],      # Alternate row, last column, row, row
                     6: lambda d: [1,2,3,6][d % 4],
                     7: lambda d: [*[1,2,3]*5,6][d%16],
                     8: lambda d: [1,2,3,1,2,3,1,2,3,6][d%10], # Used by separated_enc_dec model with lyrics
                     9: lambda d: [1,2,3,0][d % 4],
                     10: lambda d: [*[1,2,3,1,2,3,1,2,3],*[1,2,3,1,2,3,1,2,3,6]*7][d%79], # Used by large separated_enc_dec model with lyrics
                     11: lambda d: [6,6,0][d%3] if d%16 == 15 else [1,2,3][d%3],
                     12: lambda d: [7,7,0][d%3] if d%16 == 15 else [1,2,3][d%3], # Used by single_enc_dec model with lyrics
                     }[attn_order]

        # Period of each pattern above; only used by the (disabled) divisibility check below.
        attn_cycle = {0:1, 1:2, 2:3, 3:2, 4:2, 5:4, 6:4, 7:16, 8:10, 9:4, 10:79, 11:16, 12:16}[attn_order]
        #assert n_depth % attn_cycle == 0, f'Depth {n_depth} not a multiple of cycle {attn_cycle} for attn_order {attn_order}'

        # Layers with attn_func 6 are always created with zero_out=True.
        attn_block = lambda d: ResAttnBlock(n_in=n_in, n_ctx=n_ctx, n_head=n_head,
                                  attn_dropout=attn_dropout, resid_dropout=resid_dropout,
                                  afn=afn, scale=scale, mask=mask,
                                  zero_out=zero_out if attn_func(d) !=6 else True,
                                  init_scale=init_scale, res_scale=res_scale,
                                  m_attn=m_attn, m_mlp=m_mlp,
                                  checkpoint_attn=checkpoint_attn, checkpoint_mlp=checkpoint_mlp,
                                  attn_func=attn_func(d), blocks=blocks, spread=spread,
                                  encoder_dims=encoder_dims, prime_len=prime_len)

        self.checkpoint_res = checkpoint_res
        self._attn_mods = nn.ModuleList()
        for d in range(n_depth):
            self._attn_mods.append(attn_block(d))
        # Attention weights collected by forward() from layers with record_attn enabled.
        self.ws = []


    def set_record_attn(self, record_attn):
        """
        Arguments:
            record_attn (bool or set): Makes forward prop dump self-attention
                softmaxes to self.ws. Either a set of layer indices indicating
                which layers to store, or a boolean value indicating whether to
                dump all.
        """
        def _should_record_attn(layer_idx):
            if isinstance(record_attn, bool):
                return record_attn
            return layer_idx in record_attn
        for i, l in enumerate(self._attn_mods):
            l.attn.record_attn = _should_record_attn(i)
        if record_attn:
            # Recording must start from a clean state.
            assert self.ws == []
            for l in self._attn_mods:
                assert l.attn.w is None
        else:
            self.ws = []
            for l in self._attn_mods:
                l.attn.w = None

    def forward(self, x, encoder_kv=None, sample=False, fp16=False, fp16_out=False):
        """
        Run x through every block.

        Args:
            x: input activations.
            encoder_kv: passed only to layers with attn_func 6.
            sample: incremental sampling mode (residual checkpointing is skipped).
            fp16: cast x to half precision before the blocks.
            fp16_out: if False, the output is cast back to float32.
        """
        # Imported locally: functools is not among the notebook's top-level imports,
        # and it is needed for the checkpoint_res == 1 path below.
        import functools

        if fp16:
            x = x.half()

        # Blocks
        for l in self._attn_mods:
            if self.checkpoint_res == 1 and not sample:
                if l.attn_func == 6:
                    assert encoder_kv is not None
                    f = functools.partial(l, sample=sample)
                    x = checkpoint(f, (x, encoder_kv), l.parameters(), True)
                else:
                    f = functools.partial(l, encoder_kv=None, sample=sample)
                    x = checkpoint(f, (x,), l.parameters(), True)
            else:
                if l.attn_func == 6:
                    x = l(x, encoder_kv=encoder_kv, sample=sample)
                else:
                    x = l(x, encoder_kv=None, sample=sample)
            if l.attn.record_attn:
                self.ws.append(l.attn.w)
        if not fp16_out:
            x = x.float()
        return x

    def check_cache(self, n_samples, sample_t, fp16):
        """Assert every layer's attention cache matches the given sampling state."""
        for l in self._attn_mods:
            l.attn.check_cache(n_samples, sample_t, fp16)

    def del_cache(self):
        """Clear the key/value cache and sampling position of every layer."""
        for l in self._attn_mods:
            l.attn.del_cache()

    def check_sample(self):
        """
        Debug self-test (requires CUDA).

        Asserts that sampling the full (4, n_ctx, n_in) input in one call
        matches sampling it in 4 chunks, within 1e-6. Cache state is checked
        before and after each chunk.
        """
        bs, l, s, d = (4, self.n_ctx, self.encoder_dims, self.n_in)
        with t.no_grad():
            encoder_kv = t.randn(bs, s, d).cuda()
            x = t.randn(bs, l, d).cuda()
            y_forw = self.forward(x, encoder_kv=encoder_kv, sample=True)

            self.del_cache()
            x_chunks = t.chunk(x, 4, dim=1)
            y_chunks = []
            n = 0
            for x_chunk in x_chunks:
                self.check_cache(bs, n, False)
                y_chunk = self.forward(x_chunk, encoder_kv=encoder_kv, sample=True)
                y_chunks.append(y_chunk)
                n += x_chunk.shape[1]
            self.check_cache(bs, n, False)
            y_forw_in_chunks = t.cat(y_chunks, dim=1)

            max_err = t.max(t.abs(y_forw - y_forw_in_chunks))
            assert max_err <= 1e-6, f"Max err is {max_err} {[i for i in range(l) if t.max(t.abs(y_forw - y_forw_in_chunks)[:, i, :]) > 1e-6]}"


#if __name__ == '__main__':
#    from jukebox.utils.dist_utils import setup_dist_from_mpi
#    setup_dist_from_mpi(port=29600)
#    n_in = 16
#    n_ctx = 192
#    n_head = 4
#    n_depth = 12
#    blocks = 16
#    for attn_order in [0,2,6]:
#        encoder_dims = {0: 0, 2: 0, 6: 64}[attn_order]
#        prior = Transformer(n_in, n_ctx, n_head, n_depth, mask=True, attn_order=attn_order, encoder_dims=encoder_dims, blocks=blocks).cuda()
#        prior.training = False
#        prior.check_sample()
#        print(f"Checked attn_order: {attn_order}")

In [None]:
# jukebox/jukebox/data/labels.py /

class EmptyLabeller():
    """Labeller for unconditional models: every label is an empty int64 array with placeholder info."""
    def get_label(self, artist=None, genre=None, lyrics=None, total_length=None, offset=None):
        """Return an empty label; all arguments are accepted for interface compatibility and ignored."""
        info = dict(artist="n/a", genre="n/a", lyrics=[], full_tokens=[])
        return dict(y=np.array([], dtype=np.int64), info=info)

    def get_batch_labels(self, metas, device='cpu'):
        """
        Build one empty label per entry of `metas`.

        Returns:
            dict with 'y': int64 tensor of shape (len(metas), 0) on `device`,
            and 'info': list of per-item info dicts.
        """
        labels = [self.get_label() for _ in metas]
        infos = [label['info'] for label in labels]
        ys = t.stack([t.from_numpy(label['y']) for label in labels], dim=0).to(device).long()
        assert ys.shape[0] == len(metas)
        assert len(infos) == len(metas)
        return dict(y=ys, info=infos)

## **ConditionalAutoregressive2D - The main prior**

In [None]:
# jukebox/jukebox/prior/autoregressive.py /

def get_normal(*shape, std=0.01):
    """Return a new float tensor of the given shape filled from N(0, std**2)."""
    return nn.init.normal_(t.empty(shape), std=std)

def roll(x, n):
    """
    Rotate x along dim 1 so the last n entries move to the front.

    Unlike torch.roll, n is not reduced modulo the length: n == 0 or n >= length returns x unchanged.
    """
    front = x[:, -n:]
    back = x[:, :-n]
    return t.cat([front, back], dim=1)

def split_chunks(length, chunk_size):
    """
    Split `length` items into consecutive chunk sizes of at most `chunk_size`.

    All chunks have size chunk_size except possibly the last one.
    A length of 0 yields an empty list.

    Raises:
        AssertionError: if the sizes do not add up to `length` (e.g. a negative length).
    """
    if length == 0:
        # Without this guard the formula below yields [chunk_size] and the assert fires.
        return []
    n_passes = (length + chunk_size - 1) // chunk_size
    chunk_sizes = [*[chunk_size] * (n_passes - 1), (length - 1) % chunk_size + 1]
    assert sum(chunk_sizes) == length
    return chunk_sizes

class PositionEmbedding(nn.Module):
    """
    Learned positional embedding for a flattened input of shape `input_shape`.

    With pos_init=False, a single (prod(input_shape), width) parameter is
    learned, initialised from N(0, (0.01 * init_scale)**2). With pos_init=True,
    one nn.Embedding per input axis is summed. These are indexed by the `pos`
    buffer built from get_pos_idx (defined elsewhere; presumably per-axis
    coordinates of each flattened position — TODO confirm).
    """
    def __init__(self, input_shape, width, init_scale=1.0, pos_init=False):
        super().__init__()
        self.input_shape = input_shape
        self.input_dims = input_dims = np.prod(input_shape)
        self.pos_init = pos_init
        if not pos_init:
            self.pos_emb = nn.Parameter(get_normal(input_dims, width, std=0.01 * init_scale))
            return
        self.register_buffer('pos', t.tensor(get_pos_idx(input_shape)).long())
        self._pos_embs = nn.ModuleList()
        for axis_len in input_shape:
            emb = nn.Embedding(axis_len, width)
            nn.init.normal_(emb.weight, std=0.02)
            self._pos_embs.append(emb)

    def forward(self):
        if not self.pos_init:
            return self.pos_emb
        n_axes = len(self.input_shape)
        return sum(self._pos_embs[axis](self.pos[:, axis]) for axis in range(n_axes))

class ConditionalAutoregressive2D(nn.Module):
    def __init__(self, input_shape, bins,
                 width=128, depth=2, heads=1,
                 attn_dropout=0.0, resid_dropout=0.0, emb_dropout=0.0, mask=True,
                 zero_out=False, init_scale=1.0, res_scale=False, pos_init=False,
                 m_attn=0.25, m_mlp=1,
                 checkpoint_res=0, checkpoint_attn=0, checkpoint_mlp=0,
                 attn_order=0, blocks=None, spread=None, x_cond=False, y_cond=False,
                 encoder_dims=0, only_encode=False, merged_decoder=False, prime_len=None):
        """
        Autoregressive transformer prior over flattened discrete tokens.

        Args:
            input_shape: shape of one token sequence; the context length is prod(input_shape).
            bins: vocabulary size; tokens must lie in [0, bins).
            width, depth, heads: transformer width, number of layers and attention heads.
            emb_dropout: dropout applied to the token and position embeddings.
            pos_init: use per-axis positional embeddings (see PositionEmbedding).
            x_cond: forward/sample require an x_cond tensor added to every position.
            y_cond: forward/sample require a y_cond tensor that replaces the start token.
            only_encode: forward returns the transformer activations; no output head is built.
            merged_decoder: x_cond is not re-added after the transformer, and x_out is not tied to x_emb.
            prime_len: stored for get_sep_loss and forwarded to the transformer.
            All other arguments are forwarded unchanged to Transformer.
        """
        super().__init__()
        self.input_shape = input_shape
        self.input_dims = input_dims = np.prod(input_shape)
        self.encoder_dims = encoder_dims
        self.bins = bins
        self.width = width
        self.depth = depth

        # Token embedding: maps each of the `bins` token ids to a `width`-dim vector.
        self.x_emb = nn.Embedding(bins, width)
        nn.init.normal_(self.x_emb.weight, std=0.02 * init_scale)
        self.x_emb_dropout = nn.Dropout(emb_dropout)

        # Without y_cond, a learned start token fills position 0 of the shifted input (see forward/get_emb).
        self.y_cond = y_cond
        self.x_cond = x_cond
        if not y_cond:
            self.start_token = nn.Parameter(get_normal(1, width, std=0.01 * init_scale))

        # Learned positional embedding over the input_dims positions.
        self.pos_emb = PositionEmbedding(input_shape=input_shape, width=width, init_scale=init_scale, pos_init=pos_init)
        self.pos_emb_dropout = nn.Dropout(emb_dropout)

        # Core transformer; most constructor arguments are passed straight through.
        self.transformer = Transformer(n_in=width, n_ctx=input_dims, n_head=heads, n_depth=depth,
                                       attn_dropout=attn_dropout, resid_dropout=resid_dropout,
                                       afn='quick_gelu', scale=True, mask=mask,
                                       zero_out=zero_out, init_scale=init_scale, res_scale=res_scale,
                                       m_attn=m_attn, m_mlp=m_mlp,
                                       checkpoint_attn=checkpoint_attn, checkpoint_mlp=checkpoint_mlp, checkpoint_res=checkpoint_res,
                                       attn_order=attn_order, blocks=blocks, spread=spread,
                                       encoder_dims=encoder_dims, prime_len=prime_len)

        # only_encode: forward() returns the transformer activations instead of a loss.
        self.only_encode = only_encode
        self.prime_len = prime_len
        if merged_decoder:
            # Merged piped model uses this setup
            self.add_cond_after_transformer = False
            self.share_x_emb_x_out = False
        else:
            self.add_cond_after_transformer = True
            self.share_x_emb_x_out = True

        # Output head: width -> bins logits (not built when only_encode).
        # Unless merged_decoder, its weight is tied to the token embedding.
        if not only_encode:
            self.x_out = nn.Linear(width, bins, bias=False)
            if self.share_x_emb_x_out:
                self.x_out.weight = self.x_emb.weight
            # NOTE(review): self.loss is not used by forward(), which calls F.cross_entropy directly.
            self.loss = t.nn.CrossEntropyLoss()

    def preprocess(self, x):
        """
        Flatten x to (N, L) int64 tokens, keeping the leading batch dimension.

        Per the original note the input is NHWC uint8; this is also the place
        for bitpacking or reordering.
        """
        flat = x.view(x.shape[0], -1)
        return flat.long()

    def postprocess(self, x, sample_tokens=None):
        """
        Validate sampled tokens and reshape them from (N, L) back to (N, *input_shape).

        Tokens must lie in [0, bins). When only part of the sequence was
        sampled (sample_tokens given and != input_dims), the result stays (N, L).
        """
        batch = x.shape[0]
        assert (0 <= x).all() and (x < self.bins).all()
        is_full_length = sample_tokens is None or sample_tokens == self.input_dims
        target_shape = (batch, *self.input_shape) if is_full_length else (batch, -1)
        return x.view(*target_shape)

    def forward(self, x, x_cond=None, y_cond=None, encoder_kv=None, fp16=False, loss_full=False,
                encode=False, get_preds=False, get_acts=False, get_sep_loss=False):
        """
        Teacher-forced pass over a full token sequence.

        Args:
            x: tokens; after preprocess they must be a CUDA LongTensor (N, D) with values in [0, bins).
            x_cond: required iff self.x_cond, shape (N, D, width) or (N, 1, width); zeros otherwise.
            y_cond: required iff self.y_cond, shape (N, 1, width); replaces the start token.
            encoder_kv: forwarded to the transformer.
            fp16: run the transformer in half precision.
            loss_full, encode: accepted but not used here.
            get_preds / get_acts: also return the logits / the pre-head activations.
            get_sep_loss: return (prime_loss, gen_loss), split at self.prime_len.

        Returns:
            The activations if only_encode. Otherwise (loss, extra): loss is the
            cross-entropy in bits (or the prime/gen pair), and extra is the logits,
            the activations or None.
        """
        with t.no_grad():
            x = self.preprocess(x)

        N, D = x.shape
        assert isinstance(x, t.cuda.LongTensor)
        assert (0 <= x).all() and (x < self.bins).all()

        if not self.y_cond:
            assert y_cond is None
        else:
            assert y_cond is not None
            assert y_cond.shape == (N, 1, self.width)

        if not self.x_cond:
            assert x_cond is None
            x_cond = t.zeros((N, 1, self.width), device=x.device, dtype=t.float)
        else:
            assert x_cond is not None
            assert x_cond.shape == (N, D, self.width) or x_cond.shape == (N, 1, self.width), f"{x_cond.shape} != {(N, D, self.width)} nor {(N, 1, self.width)}. Did you pass the correct --sample_length?"

        targets = x
        # Embed and shift right by one so position i is predicted from tokens < i;
        # position 0 receives y_cond or the learned start token.
        h = roll(self.x_emb(x), 1)
        h[:, 0] = y_cond.view(N, self.width) if self.y_cond else self.start_token

        h = self.x_emb_dropout(h) + self.pos_emb_dropout(self.pos_emb()) + x_cond
        h = self.transformer(h, encoder_kv=encoder_kv, fp16=fp16)
        if self.add_cond_after_transformer:  # the merged (piped) decoder skips this
            h = h + x_cond

        acts = h
        if self.only_encode:
            return h
        logits = self.x_out(h)

        if get_sep_loss:
            assert self.prime_len is not None
            prime_logits = logits[:, :self.prime_len].reshape(-1, self.bins)
            gen_logits = logits[:, self.prime_len:].reshape(-1, self.bins)

            prime_loss = F.cross_entropy(prime_logits, targets[:, :self.prime_len].reshape(-1)) / np.log(2.)
            gen_loss = F.cross_entropy(gen_logits, targets[:, self.prime_len:].reshape(-1)) / np.log(2.)

            loss = (prime_loss, gen_loss)  # Note order! Prime is first
        else:
            loss = F.cross_entropy(logits.view(-1, self.bins), targets.view(-1)) / np.log(2.)

        if get_preds:
            return loss, logits
        if get_acts:
            return loss, acts
        return loss, None

    def get_emb(self, sample_t, n_samples, x, x_cond, y_cond):
        """
        Build the (n_samples, 1, width) transformer input for sampling step `sample_t`.

        At step 0 the input is the start token (or y_cond). At later steps x,
        the previously sampled tokens of shape (n_samples, 1), is embedded.
        The position embedding for sample_t and the matching conditioning
        slice are then added (no dropout here; sampling runs in eval mode).

        Returns:
            (x, cond): the step input and the conditioning slice that was added.
        """
        N, D = n_samples, self.input_dims
        if sample_t != 0:
            assert isinstance(x, t.cuda.LongTensor)
            assert (0 <= x).all() and (x < self.bins).all()
            x = self.x_emb(x)
        else:
            # Fill in start token
            x = t.empty(n_samples, 1, self.width).cuda()
            x[:, 0] = y_cond.view(N, self.width) if self.y_cond else self.start_token
        assert x.shape == (n_samples, 1, self.width)
        per_position = x_cond.shape == (N, D, self.width)
        cond = x_cond[:, sample_t:sample_t + 1, :] if per_position else x_cond
        x = x + self.pos_emb()[sample_t:sample_t + 1] + cond
        assert x.shape == (n_samples, 1, self.width)
        return x, cond

    def sample(self, n_samples, x_cond=None, y_cond=None, encoder_kv=None, fp16=False, temp=1.0, top_k=0, top_p=0.0,
               get_preds=False, sample_tokens=None):
        """
        Autoregressively sample token sequences one position at a time (eval mode only).

        Args:
            n_samples: number of sequences to sample.
            x_cond: required iff self.x_cond, shape (N, D, width) or (N, 1, width).
            y_cond: required iff self.y_cond, shape (N, 1, width).
            encoder_kv: forwarded to the transformer at every step.
            fp16: run the transformer in half precision.
            temp: logits are divided by this before filtering.
            top_k, top_p: passed to filter_logits (only one may be enabled).
            get_preds: also return the unfiltered logits of every step.
            sample_tokens: number of tokens to sample; defaults to input_dims.

        Returns:
            The sampled tokens after postprocess, plus the concatenated logits if get_preds.
        """
        assert self.training == False

        if sample_tokens is None:
            sample_tokens = self.input_dims
        N, D = n_samples, self.input_dims
        if self.y_cond:
            assert y_cond is not None
            assert y_cond.shape == (N, 1, self.width)
        else:
            assert y_cond is None

        if self.x_cond:
            assert x_cond is not None
            assert x_cond.shape == (N, D, self.width) or x_cond.shape == (N, 1, self.width), f"Got {x_cond.shape}, expected ({N}, {D}/{1}, {self.width})"
        else:
            assert x_cond is None
            x_cond = t.zeros((N, 1, self.width), dtype=t.float).cuda()

        with t.no_grad():
            xs, x = [], None
            if get_preds:
                preds = []
            for sample_t in get_range(range(0, sample_tokens)):
                x, cond = self.get_emb(sample_t, n_samples, x, x_cond, y_cond)
                # Every layer's cache must hold exactly sample_t previous steps.
                self.transformer.check_cache(n_samples, sample_t, fp16)
                x = self.transformer(x, encoder_kv=encoder_kv, sample=True, fp16=fp16) # Transformer
                if self.add_cond_after_transformer:
                    x = x + cond
                assert x.shape == (n_samples, 1, self.width)
                x = self.x_out(x) # Predictions
                if get_preds:
                    preds.append(x.clone())
                # Adjust logits
                x = x / temp
                x = filter_logits(x, top_k=top_k, top_p=top_p)
                x = t.distributions.Categorical(logits=x).sample() # Sample and replace x
                assert x.shape == (n_samples, 1)
                xs.append(x.clone())

            del x
            # Drop every layer's key/value cache built during sampling.
            self.transformer.del_cache()

            x = t.cat(xs, dim=1)
            if get_preds:
                preds = t.cat(preds, dim=1)
            x = self.postprocess(x, sample_tokens)
        if get_preds:
            return x, preds
        else:
            return x

    def primed_sample(self, n_samples, x, x_cond=None, y_cond=None, encoder_kv=None, fp16=False, temp=1.0, top_k=0,
                      top_p=0.0, get_preds=False, chunk_size=None, sample_tokens=None):
        """Continue a given token prefix ``x`` autoregressively up to ``sample_tokens`` tokens.

        First the transformer's key/value cache is filled by running the prefix through the
        transformer, optionally in chunks of ``chunk_size`` positions to bound peak memory.
        Then new tokens are sampled one step at a time, exactly as in ``sample``.

        Args:
            n_samples: batch size N; must equal ``x.shape[0]``.
            x: prefix tokens. After ``self.preprocess`` they must form a CUDA LongTensor
                with values in [0, bins) and fewer than ``sample_tokens`` positions.
            x_cond, y_cond, encoder_kv, fp16, temp, top_k, top_p: as in ``sample``.
            get_preds: if True, also return logits for every position, prefix included.
            chunk_size: number of prefix positions per cache-filling forward pass
                (defaults to the whole prefix at once).
            sample_tokens: total sequence length to produce; defaults to ``self.input_dims``.

        Returns:
            ``self.postprocess`` of the (N, sample_tokens) token ids (the prefix followed by
            the sampled continuation), plus the concatenated logits if ``get_preds``.
        """
        assert self.training == False

        if sample_tokens is None: sample_tokens=self.input_dims
        # Preprocess.
        with t.no_grad():
            x = self.preprocess(x)
        assert isinstance(x, t.cuda.LongTensor)
        assert (0 <= x).all() and (x < self.bins).all()
        assert x.shape[0] == n_samples
        # One (N, 1) tensor per prefix position; sampled tokens are appended to this list later.
        xs = t.split(x, 1, dim=1)
        xs = list(xs)

        # The prefix must leave at least one position to sample.
        assert len(xs) < sample_tokens

        N, D = n_samples, self.input_dims
        if self.y_cond:
            assert y_cond is not None
            assert y_cond.shape == (N, 1, self.width)
        else:
            assert y_cond is None

        if self.x_cond:
            assert x_cond is not None
            assert x_cond.shape == (N, D, self.width) or x_cond.shape == (N, 1, self.width), f"Got {x_cond.shape}, expected ({N}, {D}/{1}, {self.width})"
        else:
            assert x_cond is None
            x_cond = t.zeros((N, 1, self.width), dtype=t.float).cuda()

        with t.no_grad():
            if get_preds:
                preds = []

            # Fill up key/value cache for past context by running forward pass.
            # We do so in chunks instead of doing the whole past in one forward pass to reduce max memory usage.
            if chunk_size is None:
                chunk_size = len(xs)
            # Chunks need not divide len(xs) evenly; split_chunks is expected to return sizes that
            # sum to len(xs) (presumably with a smaller final chunk; confirm against its definition).
            chunk_sizes = split_chunks(len(xs), chunk_size)
            x_primes = []
            start = 0
            x = None
            for current_chunk_size in get_range(chunk_sizes):
                xs_prime, conds_prime = [], []
                for sample_t in range(start, start + current_chunk_size):
                    # The input at position sample_t is the embedding of token sample_t - 1
                    # (start token at 0), so `x` lags one position behind sample_t.
                    x_prime, cond_prime = self.get_emb(sample_t, n_samples, x, x_cond, y_cond)
                    x = xs[sample_t]
                    xs_prime.append(x_prime)
                    conds_prime.append(cond_prime)
                start = start + current_chunk_size

                x_prime, cond_prime = t.cat(xs_prime, dim=1), t.cat(conds_prime, dim=1)
                assert x_prime.shape == (n_samples, current_chunk_size, self.width)
                assert cond_prime.shape == (n_samples, current_chunk_size, self.width)
                del xs_prime
                del conds_prime
                if not get_preds:
                    del cond_prime
                # Run with sample=True so the transformer records keys/values for these positions.
                x_prime = self.transformer(x_prime, encoder_kv=encoder_kv, sample=True, fp16=fp16)

                if get_preds:
                    if self.add_cond_after_transformer:
                        x_prime = x_prime + cond_prime
                    assert x_prime.shape == (n_samples, current_chunk_size, self.width)
                    del cond_prime
                    x_primes.append(x_prime)
                else:
                    del x_prime

            if get_preds:
                # Logits for the prefix positions, computed from the cache-filling activations.
                x_prime = t.cat(x_primes, dim=1)
                assert x_prime.shape == (n_samples, len(xs), self.width)
                x_prime = self.x_out(x_prime)  # Predictions
                preds.append(x_prime)

            empty_cache()
            self.transformer.check_cache(n_samples, len(xs), fp16)

            # Continue sampling from the last prefix token.
            x = xs[-1]
            assert x.shape == (n_samples, 1)
            empty_cache()
            for sample_t in get_range(range(len(xs), sample_tokens)):
                x, cond = self.get_emb(sample_t, n_samples, x, x_cond, y_cond)
                self.transformer.check_cache(n_samples, sample_t, fp16)
                x = self.transformer(x, encoder_kv=encoder_kv, sample=True, fp16=fp16) # Transformer
                if self.add_cond_after_transformer:
                    x = x + cond
                assert x.shape == (n_samples, 1, self.width)
                x = self.x_out(x) # Predictions
                if get_preds:
                    preds.append(x)
                # Adjust logits
                x = x / temp
                x = filter_logits(x, top_k=top_k, top_p=top_p)
                x = t.distributions.Categorical(logits=x).sample() # Sample and replace x
                assert x.shape == (n_samples, 1)
                xs.append(x.clone())

            del x
            self.transformer.del_cache()

            x = t.cat(xs, dim=1)
            if get_preds:
                preds = t.cat(preds, dim=1)
            x = self.postprocess(x, sample_tokens)
        if get_preds:
            return x, preds
        else:
            return x

    def check_sample(self, chunk_size, tol=1e-6):
        """Self-test: incremental (KV-cached) sampling must reproduce the full forward pass.

        Runs three checks with random conditioning (batch of 4):
          1. ``sample`` vs ``forward`` on the sampled tokens;
          2. unchunked ``primed_sample`` primed with the first 7/8 of those tokens;
          3. the same primed sampling with the cache filled in chunks of ``chunk_size``.
        In each check, the logits recorded while sampling must match those from ``forward``
        within ``tol``. The primed runs must also leave the prime tokens unchanged.

        Args:
            chunk_size: chunk size passed to ``primed_sample`` for the chunked check.
            tol: maximum allowed absolute logit difference (default 1e-6).

        Raises:
            AssertionError: if any check fails.
        """
        bs, l, d = (4, self.input_dims, self.width)
        prime = int(self.input_dims // 8 * 7)
        enc_l = self.encoder_dims
        with t.no_grad():
            y_cond = t.randn(bs, 1, d).cuda() if self.y_cond else None
            x_cond = t.randn(bs, l, d).cuda() if self.x_cond else None
            encoder_kv = t.randn(bs, enc_l, d).cuda()

            x, preds_sample = self.sample(bs, x_cond, y_cond, encoder_kv, get_preds=True)
            self._assert_preds_match(x, preds_sample, x_cond, y_cond, encoder_kv, tol)

            x_prime = x.view(bs, -1)[:, :prime]
            # None -> unchunked cache fill, then the requested chunk size.
            for primed_chunk_size in (None, chunk_size):
                x, preds_sample = self.primed_sample(bs, x_prime.clone(), x_cond, y_cond, encoder_kv,
                                                     get_preds=True, chunk_size=primed_chunk_size)
                assert (x.view(bs, -1)[:, :prime] == x_prime).all(), "Priming samples don't match"
                self._assert_preds_match(x, preds_sample, x_cond, y_cond, encoder_kv, tol)

    def _assert_preds_match(self, x, preds_sample, x_cond, y_cond, encoder_kv, tol):
        """Assert logits recorded while sampling ``x`` match a teacher-forced ``forward`` on ``x``."""
        _, preds_forw = self.forward(x, x_cond, y_cond, encoder_kv, get_preds=True)
        abs_err = t.abs(preds_sample - preds_forw)
        max_err = t.max(abs_err)
        assert max_err <= tol, f"Max err is {max_err} {[i for i in range(self.input_dims) if t.max(abs_err[:, i, :]) > tol]}"


def test_prior(input_shape, encoder_dims, blocks, heads, chunk_size):
    """GPU smoke test: sampling consistency of small ConditionalAutoregressive2D priors.

    Builds a prior for every combination of x_cond / y_cond and attention orders
    0, 2, 6 and 12, then runs ``check_sample(chunk_size)`` on each.

    Args:
        input_shape: token-sequence shape of the prior, e.g. ``(6144,)``.
        encoder_dims: encoder context length; also used as ``prime_len``.
        blocks: forwarded as ``blocks`` to the prior.
        heads: number of attention heads.
        chunk_size: chunk size for the chunked primed-sampling check.
    """
    bins = 512
    width = 32
    depth = 2
    prime_len = encoder_dims
    for x_cond in [True, False]:
        for y_cond in [True, False]:
            for attn_order in [0, 2, 6, 12]:
                prior = ConditionalAutoregressive2D(input_shape, bins,
                                                    width=width, depth=depth, heads=heads,
                                                    attn_order=attn_order, blocks=blocks,
                                                    x_cond=x_cond, y_cond=y_cond,
                                                    encoder_dims=encoder_dims, prime_len=prime_len).cuda()
                # eval() sets training=False on every submodule (e.g. dropout layers), whereas
                # assigning `prior.training = False` only changed the root module.
                prior.eval()
                prior.check_sample(chunk_size)
                print(f"Checked x_cond: {x_cond}, y_cond: {y_cond}, attn_order: {attn_order}")


#if __name__ == '__main__':
#    from jukebox.utils.dist_utils import setup_dist_from_mpi
#    setup_dist_from_mpi(port=29600)
#    test_cases = [
#        ((6144,), 384, 64, 2, 23),
#        ((6144,), 384, 64, 2, 8),
#        ((8192,), 512, 128, 2, 16),
#    ]
#    for test_case in test_cases:
#        test_prior(*test_case)

In [None]:
# jukebox/jukebox/prior/conditioners.py /

class Conditioner(nn.Module):
    """Maps discrete codes from the level above to a dense conditioning sequence.

    The codes are embedded and optionally summed with an incoming ``x_cond``. The result
    is run through a ``DecoderConvBock`` in channels-first layout and layer-normalised.
    """

    def __init__(self, input_shape, bins, down_t, stride_t, out_width, init_scale, zero_out, res_scale, **block_kwargs):
        super().__init__()
        self.x_shape = input_shape
        self.width = out_width

        # Code embedding table: one `out_width` vector per codebook entry.
        self.x_emb = nn.Embedding(bins, out_width)
        nn.init.normal_(self.x_emb.weight, std=0.02 * init_scale)

        # Conv stack over the embeddings, then LayerNorm over the width.
        self.cond = DecoderConvBock(out_width, out_width, down_t, stride_t, **block_kwargs,
                                    zero_out=zero_out, res_scale=res_scale)
        self.ln = LayerNorm(out_width)

    def preprocess(self, x):
        """NTC -> NCT (channels-first) for the conv stack."""
        return x.permute(0, 2, 1)

    def postprocess(self, x):
        """NCT -> NTC (back to channels-last)."""
        return x.permute(0, 2, 1)

    def forward(self, x, x_cond=None):
        """Embed codes ``x``, add optional ``x_cond``, run the conv stack and LayerNorm."""
        expected_shape = (x.shape[0], *self.x_shape, self.width)
        if x_cond is None:
            x_cond = 0.0
        else:
            assert_shape(x_cond, expected_shape)

        hidden = self.x_emb(x.long())
        assert_shape(hidden, expected_shape)

        hidden = self.preprocess(hidden + x_cond)
        hidden = self.postprocess(self.cond(hidden))
        return self.ln(hidden)

def flip(x):
    """Swap axes 1 and 2 of a 3-D tensor (NTC <-> NCT), recursing into lists/tuples.

    Tensors are returned contiguous; any list or tuple input comes back as a list.
    """
    if not isinstance(x, (list, tuple)):
        return x.permute(0, 2, 1).contiguous()
    return list(map(flip, x))

class SimpleEmbedding(nn.Module):
    """Lookup-table embedding for integer labels in [0, bins)."""

    def __init__(self, bins, out_width, init_scale):
        super().__init__()
        self.bins = bins
        table = nn.Embedding(bins, out_width)
        self.emb = table
        nn.init.normal_(table.weight, std=0.01 * init_scale)

    def forward(self, y):
        """Embed a 2-D CUDA LongTensor of labels; every label must lie in [0, bins)."""
        assert y.dim() == 2, f"Expected shape with 2 dims, got {y.shape}"
        assert isinstance(y, t.cuda.LongTensor), f"Expected dtype {t.cuda.LongTensor}, got {y.dtype}"
        in_range = (0 <= y).all() and (y < self.bins).all()
        assert in_range, f"Bins {self.bins}, got label {y}"
        return self.emb(y)

class RangeEmbedding(nn.Module):
    # Interpolating
    # Interpolate so that [pos_start, pos_end] <-> position tensor of length n_ctx
    #
    # Binning
    # For each pos in position tensor, find its bin
    # [start,end) mapped to [0,1,...,bins-1]
    # [start,end) -> [0,1) -> [0, bins) -> floor -> [0,...,bins-1]
    # NOTE: Open ended interval on right, so start <= pos < end, not <= end
    def __init__(self, n_time, bins, range, out_width, init_scale, clamp=False):
        """
        Args:
            n_time: number of interpolated positions between pos_start and pos_end
                (1 means only pos_start is embedded).
            bins: number of bins / embedding rows.
            range: (pos_min, pos_max) of valid positions.
            out_width: embedding width.
            init_scale: scales the normal init std (0.01 * init_scale).
            clamp: if True, pos_end is clamped into [pos_min, pos_max] before checking.
        """
        super().__init__()
        self.n_time = n_time
        self.bins = bins
        self.emb = nn.Embedding(bins, out_width)
        nn.init.normal_(self.emb.weight, std=0.01 * init_scale)
        self.pos_min, self.pos_max = range
        self.clamp = clamp

    def forward(self, pos_start, pos_end=None):
        """Embed ``n_time`` evenly spaced positions starting at ``pos_start`` towards ``pos_end``.

        Args:
            pos_start: 2-D tensor (presumably (N, 1)) with values in [pos_min, pos_max).
            pos_end: 2-D tensor with values in [pos_min, pos_max]; required when n_time != 1.

        Returns:
            Embeddings of shape (*position.shape, out_width), where the interpolated
            positions have shape broadcast(pos_start, (1, n_time)).
        """
        # Check if [pos_start,pos_end] in [pos_min, pos_max)
        assert len(pos_start.shape) == 2, f"Expected shape with 2 dims, got {pos_start.shape}"
        assert (self.pos_min <= pos_start).all() and (pos_start < self.pos_max).all(), f"Range is [{self.pos_min},{self.pos_max}), got {pos_start}"
        pos_start = pos_start.float()
        if pos_end is not None:
            assert len(pos_end.shape) == 2, f"Expected shape with 2 dims, got {pos_end.shape}"
            if self.clamp:
                pos_end = pos_end.clamp(self.pos_min, self.pos_max)
            # pos_end may equal pos_max (closed interval), hence the ']' in the message.
            assert (self.pos_min <= pos_end).all() and (pos_end <= self.pos_max).all(), f"Range is [{self.pos_min},{self.pos_max}], got {pos_end}"
            pos_end = pos_end.float()
        # Interpolate so that [pos_start, ..., pos_end] <-> position tensor of length n_ctx
        n_time = self.n_time
        if n_time != 1:
            assert pos_end is not None
            # Build the grid on the inputs' device rather than a hard-coded 'cuda', which
            # fails on CPU and can pick the wrong GPU when inputs live on another device.
            interpolation = t.arange(0, n_time, dtype=t.float, device=pos_start.device).view(1, n_time) / n_time
            position = pos_start + (pos_end - pos_start) * interpolation
        else:
            position = pos_start

        # Bin each value to bins
        normalised_position = (position - self.pos_min) / (self.pos_max - self.pos_min) # [0,1)
        # floor(bins * [0,1)) is already in [0, bins-1]; the clamp only guards against float32
        # rounding pushing a value to exactly `bins`, which would index past the table.
        bins = (self.bins * normalised_position).floor().long().clamp(0, self.bins - 1).detach()
        return self.emb(bins)

# class LabelConditioner(nn.Module):
#    def __init__(self, y_bins, t_bins, sr, min_duration, max_duration, n_time, out_width, init_scale, max_bow_genre_size, include_time_signal):
#        super().__init__()
#        self.n_time = n_time
#        self.out_width = out_width
#        assert len(y_bins) == 2, f"Expecting (genre, artist) bins, got {y_bins}"
#        bow_genre_bins, artist_bins = y_bins
#        self.max_bow_genre_size = max_bow_genre_size
#        self.bow_genre_emb = SimpleEmbedding(bow_genre_bins, out_width, init_scale)
#        self.artist_emb = SimpleEmbedding(artist_bins, out_width, init_scale)
#        self.include_time_signal = include_time_signal
#        if self.include_time_signal:
#            t_ranges = ((min_duration * sr, max_duration * sr),  # Total length
#                        (0.0, max_duration * sr),                # Absolute pos
#                        (0.0, 1.0))                              # Relative pos
#            assert len(t_ranges) == 3, f"Expecting (total, absolute, relative) ranges, got {t_ranges}"
#            total_length_range, absolute_pos_range, relative_pos_range = t_ranges
#            self.total_length_emb = RangeEmbedding(1, t_bins, total_length_range, out_width, init_scale)
#            self.absolute_pos_emb = RangeEmbedding(n_time, t_bins, absolute_pos_range, out_width, init_scale)
#            self.relative_pos_emb = RangeEmbedding(n_time, t_bins, relative_pos_range, out_width, init_scale, clamp=True)#

#    def forward(self, y):
#        assert len(y.shape) == 2, f"Expected shape with 2 dims, got {y.shape}"
#       assert y.shape[-1] == 4 + self.max_bow_genre_size, f"Expected shape (N,{4 + self.max_bow_genre_size}), got {y.shape}"
#        assert isinstance(y, t.cuda.LongTensor), f"Expected dtype {t.cuda.LongTensor}, got {y.dtype}"
#        N = y.shape[0]
#        total_length, offset, length, artist, genre = y[:,0:1], y[:,1:2], y[:,2:3], y[:,3:4], y[:,4:]

#        # Start embedding of length 1
#        artist_emb = self.artist_emb(artist)
#        # Empty genre slots are denoted by -1. We mask these out.
#        mask = (genre >= 0).float().unsqueeze(2)
#        genre_emb = (self.bow_genre_emb(genre.clamp(0)) * mask).sum(dim=1, keepdim=True)
#        start_emb = genre_emb + artist_emb
#        assert_shape(start_emb, (N, 1, self.out_width))

#        # Pos embedding of length n_ctx
#        if self.include_time_signal:
#            start, end = offset, offset + length
#            total_length, start, end = total_length.float(), start.float(), end.float()
#            pos_emb = self.total_length_emb(total_length) + self.absolute_pos_emb(start, end) + self.relative_pos_emb(start/total_length, end/total_length)
#            assert_shape(pos_emb, (N, self.n_time, self.out_width))
#        else:
#            pos_emb = None
#        return start_emb, pos_emb

## **Prior**

In [None]:
# jukebox/jukebox/prior/prior.py /

"""
Model the prior on vq codes conditioned on timing, artist, genre, lyrics and codes from levels above.
To condition on the timing, genre and artist, we use the LabelConditioner class
To condition on the codes from the level above, we use the Conditioner class
To condition on lyrics, we allow two types of priors:
- Separate Encoder Decoder: This is the usual encoder-decoder style transformer. The encoder transformer autoregressively
models the lyrics, and we use its last layer to produce keys/values that are attended to by the decoder transformer
- Single Encoder Decoder: This is a simplification where we combine them into a single model. We merge the text vocab
and VQ vocab into a single large vocab, and the lyric tokens and VQ tokens into a single longer sequence of tokens which
we autoregressively model together.
"""
class SimplePrior(nn.Module):
    """Autoregressive prior over the VQ-VAE codes of a single level.

    Wraps a ``ConditionalAutoregressive2D`` transformer and builds the optional
    conditioning paths:

    * x-conditioning: enabled for every level except the top one; codes of
      ``level + 1`` go through a ``Conditioner`` block.
    * y-conditioning: enabled when ``labels`` is truthy; labels are embedded by
      ``LabelConditioner``.
    * prime (lyric) tokens: either a separate encoder transformer (``prime_prior``) whose
      activations become ``encoder_kv``, or, with ``single_enc_dec``, one merged
      vocabulary and sequence.

    ``prime_kwargs`` and ``prior_kwargs`` are copied before keys are popped from them,
    so the caller's dicts are left untouched.
    """
    def __init__(self, z_shapes, l_bins, encoder, decoder, level,
                 downs_t, strides_t, labels, prior_kwargs, x_cond_kwargs, y_cond_kwargs,
                 prime_kwargs, copy_input, labels_v3=False,
                 merged_decoder=False, single_enc_dec=False):
        super().__init__()

        # Work on shallow copies: keys are popped below, and popping from the caller's dicts
        # would make building a second prior from the same kwargs fail with KeyError.
        prior_kwargs = dict(prior_kwargs)
        prime_kwargs = dict(prime_kwargs)

        # Prime (lyric) token settings. With use_tokens=False / n_tokens=0 (the current
        # configuration) no lyric conditioning is built.
        self.use_tokens = prime_kwargs.pop('use_tokens')
        self.n_tokens = prime_kwargs.pop('n_tokens')
        self.prime_loss_fraction = prime_kwargs.pop('prime_loss_fraction')
        self.copy_input = copy_input

        if self.copy_input:
            prime_kwargs['bins'] = l_bins

        # Encoded (token) shape of the input at each VQ-VAE level; one entry per level.
        self.z_shapes = z_shapes
        self.levels = len(self.z_shapes)

        # Encoded shape at the level this prior models.
        self.z_shape = self.z_shapes[level]

        self.level = level
        assert level < self.levels, f"Total levels {self.levels}, got level {level}"

        # VQ-VAE codebook size.
        self.l_bins = l_bins

        # VQ-VAE encoder/decoder.
        # Passing functions instead of the vqvae module to avoid getting params
        self.encoder = encoder
        self.decoder = decoder

        # X conditioning: True for every level except the top one (levels - 1).
        self.x_cond = (level != (self.levels - 1))
        self.cond_level = level + 1

        # Y conditioning
        self.y_cond = labels

        self.single_enc_dec = single_enc_dec
        # X conditioning
        if self.x_cond:
            print('x conditoning')
            self.conditioner_blocks = nn.ModuleList()
            conditioner_block = lambda _level: Conditioner(input_shape=z_shapes[_level],
                                                          bins=l_bins,
                                                          down_t=downs_t[_level],
                                                          stride_t=strides_t[_level],
                                                          **x_cond_kwargs)
            self.conditioner_blocks.append(conditioner_block(self.cond_level))

        # Y conditioning
        if self.y_cond:
            print('y conditoning')
            self.n_time = self.z_shape[0] # Assuming STFT=TF order and raw=T1 order, so T is first dim
            # NOTE(review): the LabelConditioner definition visible in this notebook is commented
            # out; labels=True requires it to be defined elsewhere. Confirm before enabling labels.
            self.y_emb = LabelConditioner(n_time=self.n_time,include_time_signal=not self.x_cond,**y_cond_kwargs)

        # Lyric conditioning
        if single_enc_dec:
            # Single encoder-decoder transformer: prime tokens and VQ tokens share one
            # sequence, and their vocabularies are stacked (VQ ids shifted by the prime bins).
            self.prior_shapes = [(self.n_tokens,), prior_kwargs.pop('input_shape')]
            self.prior_bins = [prime_kwargs['bins'], prior_kwargs.pop('bins')]
            self.prior_dims = [np.prod(shape) for shape in self.prior_shapes]
            self.prior_bins_shift = np.cumsum([0, *self.prior_bins])[:-1]
            self.prior_width = prior_kwargs['width']
            print_once(f'Creating cond. autoregress with prior bins {self.prior_bins}, ')
            print_once(f'dims {self.prior_dims}, ')
            print_once(f'shift {self.prior_bins_shift}')
            print_once(f'input shape {sum(self.prior_dims)}')
            print_once(f'input bins {sum(self.prior_bins)}')
            print_once(f'Self copy is {self.copy_input}')

            self.prime_loss_dims, self.gen_loss_dims = self.prior_dims[0], self.prior_dims[1]
            self.total_loss_dims = self.prime_loss_dims + self.gen_loss_dims
            self.prior = ConditionalAutoregressive2D(input_shape=(sum(self.prior_dims),),
                                                     bins=sum(self.prior_bins),
                                                     x_cond=(self.x_cond or self.y_cond), y_cond=True,
                                                     prime_len=self.prime_loss_dims,
                                                     **prior_kwargs)

        else:
            # Separate encoder-decoder transformer
            if self.n_tokens != 0 and self.use_tokens:
                # NOTE(review): imports from the installed jukebox package, not from this notebook;
                # only reached when prime tokens are enabled.
                from jukebox.transformer.ops import Conv1D
                prime_input_shape = (self.n_tokens,)
                self.prime_loss_dims = np.prod(prime_input_shape)
                self.prime_acts_width, self.prime_state_width = prime_kwargs['width'], prior_kwargs['width']
                self.prime_prior = ConditionalAutoregressive2D(input_shape=prime_input_shape, x_cond=False, y_cond=False,
                                                               only_encode=True,
                                                               **prime_kwargs)
                self.prime_state_proj = Conv1D(self.prime_acts_width, self.prime_state_width, init_scale=prime_kwargs['init_scale'])
                self.prime_state_ln = LayerNorm(self.prime_state_width)
                self.prime_bins = prime_kwargs['bins']
                self.prime_x_out = nn.Linear(self.prime_state_width, self.prime_bins, bias=False)
                nn.init.normal_(self.prime_x_out.weight, std=0.02 * prior_kwargs['init_scale'])
            else:
                # Current configuration: no lyric conditioning, so prime_loss_dims is zero.
                self.prime_loss_dims = 0
            self.gen_loss_dims = np.prod(self.z_shape)
            self.total_loss_dims = self.prime_loss_dims + self.gen_loss_dims
            self.prior = ConditionalAutoregressive2D(x_cond=(self.x_cond or self.y_cond), y_cond=self.y_cond,
                                                     encoder_dims = self.prime_loss_dims, merged_decoder=merged_decoder,
                                                     **prior_kwargs)

        # n_ctx: number of tokens modelled at this level.
        # raw_to_tokens: product of the downsampling factors up to this level, i.e. raw
        # samples per token. sample_length: raw audio length covered by n_ctx tokens.
        self.n_ctx = self.gen_loss_dims
        self.downsamples = calculate_strides(strides_t, downs_t)
        self.cond_downsample = self.downsamples[level+1] if level != self.levels - 1 else None
        self.raw_to_tokens = np.prod(self.downsamples[:level+1])
        self.sample_length = self.n_ctx*self.raw_to_tokens

        if labels:
            self.labels_v3 = labels_v3
            self.labeller = Labeller(self.y_emb.max_bow_genre_size, self.n_tokens, self.sample_length, v3=self.labels_v3)
        else:
            self.labeller = EmptyLabeller()

        print(f"Level:{level}, Cond downsample:{self.cond_downsample}, Raw to tokens:{self.raw_to_tokens}, Sample length:{self.sample_length}")


    def get_y(self, labels, start, get_indices=False):
        """Adapt ``labels['y']`` to this level's sample length and offset ``start`` (in tokens).

        Returns None when no labeller is in use; otherwise ``y`` (and the lyric-token
        indices when ``get_indices``).
        """
        if isinstance(self.labeller, EmptyLabeller):
            return None
        y = labels['y'].clone()

        # Set sample_length to match this level
        y[:, 2] = int(self.sample_length)

        # Set offset
        y[:, 1:2] = y[:, 1:2] + int(start * self.raw_to_tokens)

        # Set lyric tokens
        indices = self.labeller.set_y_lyric_tokens(y, labels)
        if get_indices:
            return y, indices
        else:
            return y

    def get_z_conds(self, zs, start, end):
        """Slice the codes of the level above that cover tokens [start, end) of this level.

        Returns a one-element list, or None at the top level.
        """
        if self.level != self.levels - 1:
            assert start % self.cond_downsample == end % self.cond_downsample == 0
            z_cond = zs[self.level + 1][:,start//self.cond_downsample:end//self.cond_downsample]
            assert z_cond.shape[1] == self.n_ctx//self.cond_downsample
            z_conds = [z_cond]
        else:
            z_conds = None
        return z_conds

    def prior_preprocess(self, xs, conds):
        """single_enc_dec only: shift each token group into the merged vocab and concatenate.

        Missing conditionings (None) are replaced with zeros of shape (N, dims, prior_width).
        """
        N = xs[0].shape[0]
        for i in range(len(xs)):
            x, shape, dims = xs[i], self.prior_shapes[i], self.prior_dims[i]
            bins, bins_shift = int(self.prior_bins[i]), int(self.prior_bins_shift[i])
            assert isinstance(x, t.cuda.LongTensor), x
            assert (0 <= x).all() and (x < bins).all()
            xs[i] = (xs[i] + bins_shift).view(N, -1)

        for i in range(len(conds)):
            cond, shape, dims = conds[i], self.prior_shapes[i], self.prior_dims[i]
            if cond is not None:
                assert_shape(cond, (N, dims, self.prior_width))
            else:
                conds[i] = t.zeros((N, dims, self.prior_width), dtype=t.float, device='cuda')

        return t.cat(xs, dim=1), t.cat(conds, dim=1)

    def prior_postprocess(self, z):
        """single_enc_dec only: split off the prime tokens and un-shift the VQ tokens.

        Returns only the last (VQ) group.
        """
        N = z.shape[0]
        # The VQ part may be shorter than prior_dims[1] when fewer tokens were sampled.
        dims = (self.prior_dims[0], z.shape[1] - self.prior_dims[0])
        xs = list(t.split(z, dims, dim=1))

        for i in range(len(xs)):
            shape = self.prior_shapes[i]
            bins, bins_shift = int(self.prior_bins[i]), int(self.prior_bins_shift[i])
            xs[i] = (xs[i] - bins_shift).view(N, -1, *shape[1:])
            xs[i] = t.clamp(xs[i], min=0)  # If not masking loss, model may have generated lyric/midi tokens which are now shifted <0 by bin_shift
            assert (xs[i] < bins).all(), f'rank: , bins: {bins}, dims {dims}, shape {shape}, prior_shape {self.prior_shapes}, bins_shift {bins_shift}, xs[i]: {xs[i]}'

        return xs[-1]

    def x_emb(self, z_conds):
        """Run the conditioner block(s) over the codes of the level(s) above."""
        z_conds = z_conds[:self.cond_level - self.level]
        assert len(z_conds) == len(self.conditioner_blocks) == self.cond_level - self.level, f"Expected {len(z_conds)} == {len(self.conditioner_blocks)} == {self.cond_level} - {self.level}"
        x_cond = None
        for z_cond, conditioner_block in reversed(list(zip(z_conds, self.conditioner_blocks))):
            x_cond = conditioner_block(z_cond, x_cond)
        return x_cond

    def encode(self, x, start_level=None, end_level=None, bs_chunks=1):
        """Encode raw input with the VQ-VAE encoder for levels [start_level, end_level), without grad."""
        if start_level is None:
            start_level = self.level
        if end_level is None:
            end_level = self.levels
        # Get latents
        with t.no_grad():
            zs = self.encoder(x, start_level=start_level, end_level=end_level, bs_chunks=bs_chunks)
        return zs

    def decode(self, zs, start_level=None, end_level=None, bs_chunks=1):
        """Decode codes ``zs`` (one entry per level in [start_level, end_level)) without grad."""
        if start_level is None:
            start_level = self.level
        if end_level is None:
            end_level = self.levels

        assert len(zs) == end_level - start_level
        with t.no_grad():
            x_out = self.decoder(zs, start_level=start_level, end_level=end_level, bs_chunks=bs_chunks)
        return x_out

    def get_cond(self, z_conds, y):
        """Split ``y`` into labels and prime tokens and build (x_cond, y_cond, prime).

        Without x-conditioning, the label positional embedding (if any) is used as x_cond.
        """
        if y is not None:
            assert y.shape[1] == 4 + self.y_emb.max_bow_genre_size + self.n_tokens, f"Expected {4} + {self.y_emb.max_bow_genre_size} + {self.n_tokens}, got {y.shape[1]}"
            n_labels = y.shape[1] - self.n_tokens
            y, prime = y[:,:n_labels], y[:,n_labels:]
        else:
            y, prime = None, None
        y_cond, y_pos = self.y_emb(y) if self.y_cond else (None, None)
        x_cond = self.x_emb(z_conds) if self.x_cond else y_pos
        return x_cond, y_cond, prime

    def sample(self, n_samples, z=None, z_conds=None, y=None, fp16=False, temp=1.0, top_k=0, top_p=0.0,
               chunk_size=None, sample_tokens=None):
        """Sample codes for this level.

        With no past context (``z`` None or empty) this is ancestral sampling. Otherwise
        ``z`` is continued via ``primed_sample``.
        """
        N = n_samples
        if z is not None: assert z.shape[0] == N, f"Expected shape ({N},**), got shape {z.shape}"
        if y is not None: assert y.shape[0] == N, f"Expected shape ({N},**), got shape {y.shape}"
        if z_conds is not None:
            for z_cond in z_conds:
                assert z_cond.shape[0] == N,  f"Expected shape ({N},**), got shape {z_cond.shape}"


        no_past_context = (z is None or z.shape[1] == 0)

        with t.no_grad():
            # Currently x_cond only uses immediately above layer
            x_cond, y_cond, prime = self.get_cond(z_conds, y)
            if self.single_enc_dec:
                # assert chunk_size % self.prime_loss_dims == 0. TODO: Check if needed
                if no_past_context:
                    z, x_cond = self.prior_preprocess([prime], [None, x_cond])
                else:
                    z, x_cond = self.prior_preprocess([prime, z], [None, x_cond])
                if sample_tokens is not None:
                    sample_tokens += self.n_tokens
                z = self.prior.primed_sample(n_samples, z, x_cond, y_cond, fp16=fp16, temp=temp,
                                             top_k=top_k, top_p=top_p, chunk_size=chunk_size, sample_tokens=sample_tokens)
                z = self.prior_postprocess(z)
            else:
                # Separate encoder-decoder path (the one used here); z decides primed vs. ancestral.
                encoder_kv = self.get_encoder_kv(prime, fp16=fp16, sample=True)
                if no_past_context:
                    z = self.prior.sample(n_samples, x_cond, y_cond, encoder_kv, fp16=fp16, temp=temp, top_k=top_k,
                                          top_p=top_p, sample_tokens=sample_tokens)
                else:
                    z = self.prior.primed_sample(n_samples, z, x_cond, y_cond, encoder_kv, fp16=fp16, temp=temp,
                                             top_k=top_k, top_p=top_p, chunk_size=chunk_size, sample_tokens=sample_tokens)
            if sample_tokens is None:
                assert_shape(z, (N, *self.z_shape))
        return z

    def get_encoder_kv(self, prime, fp16=False, sample=False):
        """Encode the prime tokens into keys/values for the decoder, or None without prime tokens.

        When ``sample`` is True, ``prime_prior`` is moved to the GPU for the call and back to
        the CPU afterwards; with ``fp16`` the result is then cast to half.
        """
        if self.n_tokens != 0 and self.use_tokens:
            if sample:
                self.prime_prior.cuda()
            N = prime.shape[0]
            prime_acts = self.prime_prior(prime, None, None, None, fp16=fp16)
            assert_shape(prime_acts, (N, self.prime_loss_dims, self.prime_acts_width))
            assert prime_acts.dtype == t.float, f'Expected t.float, got {prime_acts.dtype}'
            encoder_kv = self.prime_state_ln(self.prime_state_proj(prime_acts))
            assert encoder_kv.dtype == t.float, f'Expected t.float, got {encoder_kv.dtype}'
            if sample:
                self.prime_prior.cpu()
                if fp16:
                    encoder_kv = encoder_kv.half()
        else:
            encoder_kv = None
        return encoder_kv

    def get_prime_loss(self, encoder_kv, prime_t):
        """Cross-entropy (bits) of predicting prime tokens from ``encoder_kv``; 0 without a prime encoder."""
        # Same condition as __init__/get_encoder_kv: prime_x_out exists (and encoder_kv is not
        # None) only when there are prime tokens AND use_tokens is set. Checking use_tokens
        # alone crashed on `None.float()` when n_tokens == 0.
        if self.n_tokens != 0 and self.use_tokens:
            encoder_kv = encoder_kv.float()
            encoder_kv = self.prime_x_out(encoder_kv)
            prime_loss = nn.functional.cross_entropy(encoder_kv.view(-1, self.prime_bins), prime_t.view(-1)) / np.log(2.)
        else:
            prime_loss = t.tensor(0.0, device='cuda')
        return prime_loss

    def z_forward(self, z, z_conds=None, y=None, fp16=False, get_preds=False, get_attn_weights=False):
        """
        Arguments:
            z_conds (list or None): codes of the level(s) above; None means no codes (``[]``).
            get_attn_weights (bool or set): Makes forward prop dump
                self-attention softmaxes to self.prior.transformer.ws. Either a
                set of layer indices indicating which layers to store, or a
                boolean value indicating whether to dump all.

        Returns:
            ``(loss, metrics)``, or the recorded attention weights when ``get_attn_weights``.
        """
        if z_conds is None:
            z_conds = []  # avoid a shared mutable default argument
        assert isinstance(get_attn_weights, (bool, set))
        if get_attn_weights:
            self.prior.transformer.set_record_attn(get_attn_weights)
        x_cond, y_cond, prime = self.get_cond(z_conds, y)
        if self.copy_input:
            prime = z[:,:self.n_tokens]
        if self.single_enc_dec:
            z, x_cond = self.prior_preprocess([prime, z], [None, x_cond])
            (prime_loss, gen_loss), preds = self.prior(z, x_cond, y_cond, fp16=fp16, get_sep_loss=True, get_preds=get_preds)
        else:
            # Separate encoder-decoder path; without prime tokens encoder_kv is None and prime_loss is 0.
            encoder_kv = self.get_encoder_kv(prime, fp16=fp16)
            prime_loss = self.get_prime_loss(encoder_kv, prime)
            gen_loss, preds = self.prior(z, x_cond, y_cond, encoder_kv, fp16=fp16, get_preds=get_preds)
        # Token-count-weighted mix of prime and generation losses.
        loss = (self.prime_loss_fraction*prime_loss*self.prime_loss_dims/self.total_loss_dims) + \
                   (gen_loss*self.gen_loss_dims/self.total_loss_dims)
        metrics=dict(bpd=gen_loss.clone().detach(), prime_loss=prime_loss.clone().detach(),
                     gen_loss=gen_loss.clone().detach())
        if get_preds:
            metrics["preds"] = preds.clone().detach()
        if get_attn_weights:
            ws = self.prior.transformer.ws
            self.prior.transformer.set_record_attn(False)
            return ws
        else:
            return loss, metrics

    def forward(self, x, y=None, fp16=False, decode=False, get_preds=False):
        """Encode raw ``x``, compute the prior loss and optionally decode the codes back."""
        bs = x.shape[0]
        z, *z_conds = self.encode(x, bs_chunks=bs)
        loss, metrics = self.z_forward(z=z, z_conds=z_conds, y=y, fp16=fp16, get_preds=get_preds)
        if decode:
            x_out = self.decode([z, *z_conds])
        else:
            x_out = None
        return x_out, loss, metrics


## **hps and checkpoints**

In [None]:
# Named hyperparameter presets looked up by setup_hparams.
HPARAMS_REGISTRY = {}

class Hyperparams(dict):
    """A ``dict`` whose keys can also be read and written as attributes (``H.lr``)."""

    def __getattr__(self, attr):
        # Called only when normal attribute lookup fails. A missing key must surface as
        # AttributeError (not KeyError) so hasattr(), getattr(obj, name, default) and
        # copy.deepcopy, which probe for optional dunder attributes, keep working.
        try:
            return self[attr]
        except KeyError:
            raise AttributeError(attr) from None

    def __setattr__(self, attr, value):
        # Attribute assignment writes a dict entry, so H.x = 1 is H['x'] = 1.
        self[attr] = value

def setup_hparams(hparam_set_names, kwargs):
    """Build a ``Hyperparams`` from ``DEFAULTS``, named registry presets and overrides.

    Merge order: every group in ``DEFAULTS``, then each named preset from
    ``HPARAMS_REGISTRY`` in order, then ``kwargs``.

    Args:
        hparam_set_names: a comma-separated string of ``HPARAMS_REGISTRY`` keys, or a
            tuple/list of such keys. Empty names are skipped.
        kwargs: explicit overrides, applied last.

    Returns:
        The merged ``Hyperparams``.

    Raises:
        KeyError: a name is not in ``HPARAMS_REGISTRY``.
        ValueError: a preset or ``kwargs`` sets a key no ``DEFAULTS`` group defines.
    """
    # NOTE(review): DEFAULTS is not defined in the visible notebook cells; confirm it
    # is defined (a mapping of group name -> dict of defaults) before calling this.
    H = Hyperparams()
    if isinstance(hparam_set_names, str):
        hparam_set_names = hparam_set_names.split(",")
    hparam_sets = [HPARAMS_REGISTRY[name.strip()] for name in hparam_set_names if name]
    hparam_sets.append(kwargs)  # overrides go last; no separate update needed afterwards
    for defaults in DEFAULTS.values():
        H.update(defaults)
    for hps in hparam_sets:
        unknown = [k for k in hps if k not in H]
        if unknown:
            raise ValueError(f"{unknown[0]} not in default args")
        H.update(**hps)
    return H

In [None]:
# Model hps
# Sample rate (Hz) and window length (s). sample_length is built from integers so it
# stays an int: it is used as the VQ-VAE input length (24.0*11000 made it a float).
SAMPLE_RATE = 11000
SAMPLE_SECONDS = 24

pr = Hyperparams(
    # Prior
    prior = False,
    n_vocab=1024,
    restore_prior='/content/drive/MyDrive/Indian_Classical_Music_Generation/checkpoints/prior/flute_checkpoint_latest.pth.tar',
    restore_prior_ddp=False,
    max_bow_genre_size=None,
    y_bins=0,
    level=0,
    cond_levels=None,
    t_bins=64,
    y_cond_as_bias=False,
    copy_input=False,
    merged_decoder=True,
    single_enc_dec=False,
    alignment_layer=None,
    alignment_head=None,

    # context length: n_ctx = sample_length / raw_to_tokens = 264000 / 32
    n_ctx=8250,
    prior_depth=16,         # prior_depth=3
    prior_width=2048,      # prior_width=128
    heads=8,               # heads = 1
    attn_order=2,
    blocks=110,            #blocks = 15
    spread=None,
    attn_dropout=0.0,
    resid_dropout=0.0,
    emb_dropout=0.0,
    zero_out=False,
    res_scale=False,
    pos_init=False,
    init_scale=0.7,
    m_attn=0.25,
    m_mlp=1.0,
    c_res=0,
    c_attn=0,
    c_mlp=0,
    cond_depth=3,
    cond_width=128,
    cond_m_conv=1.0,
    cond_zero_out=False,
    cond_res_scale=False,
    cond_dilation_growth_rate=1,
    cond_dilation_cycle=None,
    cond_c_res=0,
    min_duration = 23,
    max_duration = 600,
    use_tokens=False,
    n_tokens=0,
    prime_loss_fraction=0.0,
    fp16_params=False,
    labels = False,
    labels_v3 = False,
    iters_before_update=1,
    bs_sample=1,

    # VQVAE
    sr = SAMPLE_RATE,
    levels = 1,
    # With levels=1 presumably only the first downs_t/strides_t entry is used
    # (the prior reports raw_to_tokens = 32 = 2**5).
    downs_t = (5, 5),
    strides_t = (2, 2),
    emb_width = 64,
    l_bins = 1024,
    l_mu = 0.99,
    commit = 0.02,
    spectral = 0.0,
    multispectral = 1.0,
    loss_fn = 'l2',
    width = 32,
    depth = 4,
    m_conv = 1.0,
    dilation_growth_rate = 3,
    revival_threshold=1.0,
    hvqvae_multipliers = None,
    lmix_l1=0.0,
    lmix_l2 = 1.0,
    lmix_linf=0.02,
    linf_k=2,
    use_bottleneck=True,
    dilation_cycle=None,
    vqvae_reverse_decoder_dilation=True,
    sample_length = SAMPLE_SECONDS * SAMPLE_RATE,
    restore_vqvae='/content/drive/MyDrive/Indian_Classical_Music_Generation/checkpoints/vq_vae/vqvae-flute/1/checkpoint_step_1.pth.tar',
    lr=0.000007,
    clip=1.0,
    beta1=0.9,
    beta2=0.999,
    ignore_grad_norm=0,
    weight_decay=0.0,
    eps=1e-08,
    lr_warmup=100.0,
    lr_decay=10000000000.0,
    lr_gamma=1.0,
    lr_scale=1.0,
    lr_use_linear_decay=False,
    lr_start_linear_decay=0,
    lr_use_cosine_decay=False,
    save_iters = 1000,
    save = True,
    aug_blend=False,
    n_inps=1,
    n_hops=2,
    n_segment=1,
    n_total_segment=1,
    n_segment_each=1,
    prime_chunks=4,
    sample_hop_length=30000,
    max_silence_pad_length=0,
    ignore_boundaries=False,
    use_nonrelative_specloss=True,
    multispec_loss_n_fft=(2048,1024,512),
    multispec_loss_hop_length=(240,120,50),
    multispec_loss_window_size=(1200,600,240),
)

In [None]:
# 44100 Hz x 24 s = 1,058,400 samples
# 12 swaras spanning 3 octaves
# 21 x 3

# 11200 Hz x 30 s ->
#   -> downsample by 128 -> (11200 / 128) x 30 s
#   -> or downsample by 32

# x_1 ... x_(11200 x 30 s) -> amplitude samples

# x_1 ... x_(11200 / 128)





In [None]:
def load_checkpoint(path):
    """Load a ``torch.save`` checkpoint onto the CPU and return it.

    ``gs://`` paths are mirrored to ``~/.cache/<bucket path>`` first (downloaded only by
    ranks divisible by 8, or always in a single, non-distributed process); any other
    path is read directly.

    Args:
        path: local file path or ``gs://`` URI.

    Returns:
        The deserialized object (the savers in this notebook write a dict).

    Note:
        ``torch.load`` unpickles the file; only load checkpoints from trusted sources.
    """
    restore = path
    gs_prefix = 'gs://'
    if restore.startswith(gs_prefix):
        local_path = os.path.join(os.path.expanduser("~/.cache"), restore[len(gs_prefix):])
        # `dist` was never imported in this notebook; query torch.distributed directly
        # and treat an uninitialised process group as rank 0.
        distributed = t.distributed.is_available() and t.distributed.is_initialized()
        rank = t.distributed.get_rank() if distributed else 0
        if rank % 8 == 0:
            print("Downloading from gce")
            os.makedirs(os.path.dirname(local_path), exist_ok=True)
            if not os.path.exists(local_path):
                # NOTE(review): `download` is not defined in the visible cells; confirm a
                # gs:// download helper is defined before using gs:// checkpoints.
                download(restore, local_path)
        restore = local_path
    checkpoint = t.load(restore, map_location=t.device('cpu'))
    print("Restored from {}".format(restore))
    return checkpoint

def save_checkpoint(i, name, model, opt, metrics, hps,
                    prefix='/content/drive/My Drive/audio_VAE/checkpoints_prior2/top_level'):
    """Write ``<prefix>/checkpoint_<name>.pth.tar`` with model/optimizer state and hps.

    Args:
        i: step stored under ``'step'``.
        name: suffix of the file name.
        model: module whose ``state_dict()`` is stored under ``'model'``.
        opt: optimizer whose ``state_dict()`` is stored under ``'opt'``, or None (stored as None).
        metrics: mapping merged into the top level of the saved dict.
        hps: hyperparameter mapping; the metadata/alignment/processor keys listed below
            are dropped before saving.
        prefix: output directory, created if missing. Defaults to the previously
            hard-coded location.
    """
    excluded = {'metadata_v2', 'metadata_v3', 'alignments', 'lyric_processor', 'midi_processor'}
    save_hps = {k: v for k, v in hps.items() if k not in excluded}
    os.makedirs(prefix, exist_ok=True)
    with t.no_grad():
        t.save({'hps': save_hps,
                'model': model.state_dict(),  # should also save bottleneck k's as buffers
                'opt': opt.state_dict() if opt is not None else None,
                'step': i,
                **metrics}, os.path.join(prefix, f'checkpoint_{name}.pth.tar'))

def restore_model(hps, model, checkpoint_path):
    """Load ``checkpoint['model']`` into ``model`` and restore ``model.step``.

    ``model.step`` is reset to 0 first and overwritten with ``checkpoint['step']`` when
    present. An empty or None path leaves the weights untouched (same test as
    ``restore_opt``). A leading ``'module.'`` on state-dict keys is stripped
    (presumably written by a DataParallel/DDP-wrapped model).

    Args:
        hps: unused; kept for call-site compatibility.
        model: module to load into.
        checkpoint_path: path accepted by ``load_checkpoint``, or ''/None to skip.
    """
    model.step = 0
    if not checkpoint_path:
        return
    checkpoint = load_checkpoint(checkpoint_path)
    wrapper_prefix = 'module.'
    checkpoint['model'] = {(k[len(wrapper_prefix):] if k.startswith(wrapper_prefix) else k): v
                           for k, v in checkpoint['model'].items()}
    model.load_state_dict(checkpoint['model'])
    if 'step' in checkpoint:
        model.step = checkpoint['step']

def restore_opt(opt, shd, checkpoint_path):
    """Restore optimizer and scheduler state from ``checkpoint_path``.

    No-op for an empty/None path. The optimizer state is loaded only when the
    checkpoint holds one: ``save_checkpoint`` stores ``'opt': None`` when called
    without an optimizer, and ``load_state_dict(None)`` would fail.

    Args:
        opt: optimizer exposing ``load_state_dict``.
        shd: scheduler; ``shd.step(step)`` is called with the saved step if present.
        checkpoint_path: path accepted by ``load_checkpoint``, or ''/None to skip.
    """
    if not checkpoint_path:
        return
    checkpoint = load_checkpoint(checkpoint_path)
    if checkpoint.get('opt') is not None:
        opt.load_state_dict(checkpoint['opt'])
    if "step" in checkpoint:
        shd.step(checkpoint['step'])

## **Training**

In [None]:
def make_prior(hps, vqvae, train, device='cuda'):
    """Construct a ``SimplePrior`` over ``vqvae``'s codes, restore its weights and set its mode.

    Args:
        hps: Hyperparams with the prior (``prior_*``, ``heads``, ``blocks``, ...),
            x-conditioner (``cond_*``) and y-conditioner (``y_bins``, ``t_bins``, ...)
            settings; ``prime_*`` settings are read only when ``use_tokens`` is set and
            ``single_enc_dec`` is not.
        vqvae: VQ-VAE whose ``encode``/``decode``, ``l_bins``, ``z_shapes``, ``downs_t``
            and ``strides_t`` parameterise the prior.
        train: when True, ``c_res``/``c_attn``/``c_mlp`` are passed through as the
            ``checkpoint_*`` settings and the prior stays trainable; when False they
            are 0 and the prior is put in ``eval()`` mode and given to ``freeze_model``.
        device: device the prior is moved to before ``hps.restore_prior`` is loaded.

    Returns:
        The prior module.
    """
    prior_kwargs = {
        'input_shape': (hps.n_ctx,),
        'bins': vqvae.l_bins,
        'width': hps.prior_width,
        'depth': hps.prior_depth,
        'heads': hps.heads,
        'attn_order': hps.attn_order,
        'blocks': hps.blocks,
        'spread': hps.spread,
        'attn_dropout': hps.attn_dropout,
        'resid_dropout': hps.resid_dropout,
        'emb_dropout': hps.emb_dropout,
        'zero_out': hps.zero_out,
        'res_scale': hps.res_scale,
        'pos_init': hps.pos_init,
        'init_scale': hps.init_scale,
        'm_attn': hps.m_attn,
        'm_mlp': hps.m_mlp,
        'checkpoint_res': hps.c_res if train else 0,
        'checkpoint_attn': hps.c_attn if train else 0,
        'checkpoint_mlp': hps.c_mlp if train else 0,
    }

    # x-conditioner settings come from the cond_* hps.
    x_cond_kwargs = {
        'out_width': hps.prior_width,
        'init_scale': hps.init_scale,
        'width': hps.cond_width,
        'depth': hps.cond_depth,
        'm_conv': hps.cond_m_conv,
        'dilation_growth_rate': hps.cond_dilation_growth_rate,
        'dilation_cycle': hps.cond_dilation_cycle,
        'zero_out': hps.cond_zero_out,
        'res_scale': hps.cond_res_scale,
        'checkpoint_res': hps.cond_c_res,  # have to keep this else names wrong
    }

    y_cond_kwargs = {
        'out_width': hps.prior_width,
        'init_scale': hps.init_scale,
        'y_bins': hps.y_bins,
        't_bins': hps.t_bins,
        'sr': hps.sr,
        'min_duration': hps.min_duration,
        'max_duration': hps.max_duration,
        'max_bow_genre_size': hps.max_bow_genre_size,
    }

    # The prime_* architecture settings are only needed for a separate prime encoder.
    separate_prime_encoder = hps.use_tokens and not hps.single_enc_dec
    prime_kwargs = {
        'use_tokens': hps.use_tokens,
        'prime_loss_fraction': hps.prime_loss_fraction,
        'n_tokens': hps.n_tokens,
        'bins': hps.n_vocab,
    }
    if separate_prime_encoder:
        prime_kwargs.update({
            'width': hps.prime_width,
            'depth': hps.prime_depth,
            'heads': hps.prime_heads,
            'attn_order': hps.prime_attn_order,
            'blocks': hps.prime_blocks,
            'spread': hps.prime_spread,
            'attn_dropout': hps.prime_attn_dropout,
            'resid_dropout': hps.prime_resid_dropout,
            'emb_dropout': hps.prime_emb_dropout,
            'zero_out': hps.prime_zero_out,
            'res_scale': hps.prime_res_scale,
            'pos_init': hps.prime_pos_init,
            'init_scale': hps.prime_init_scale,
            'm_attn': hps.prime_m_attn,
            'm_mlp': hps.prime_m_mlp,
            'checkpoint_res': hps.prime_c_res if train else 0,
            'checkpoint_attn': hps.prime_c_attn if train else 0,
            'checkpoint_mlp': hps.prime_c_mlp if train else 0,
        })

    # Rescale every level's code length so that level `hps.level` gets exactly n_ctx codes.
    z_shapes = [(z_shape[0] * hps.n_ctx // vqvae.z_shapes[hps.level][0],)
                for z_shape in vqvae.z_shapes]

    prior = SimplePrior(z_shapes=z_shapes,
                        l_bins=vqvae.l_bins,
                        encoder=vqvae.encode,
                        decoder=vqvae.decode,
                        level=hps.level,
                        downs_t=vqvae.downs_t,
                        strides_t=vqvae.strides_t,
                        labels=hps.labels,
                        prior_kwargs=prior_kwargs,
                        x_cond_kwargs=x_cond_kwargs,
                        y_cond_kwargs=y_cond_kwargs,
                        prime_kwargs=prime_kwargs,
                        copy_input=hps.copy_input,
                        labels_v3=hps.labels_v3,
                        merged_decoder=hps.merged_decoder,
                        single_enc_dec=hps.single_enc_dec)

    prior.alignment_head = hps.get('alignment_head', None)
    prior.alignment_layer = hps.get('alignment_layer', None)

    if hps.fp16_params:
        print("Converting to fp16 params")
        from jukebox.transformer.ops import _convert_conv_weights_to_fp16
        prior.apply(_convert_conv_weights_to_fp16)
    prior = prior.to(device)
    restore_model(hps, prior, hps.restore_prior)
    if not train:
        print("Loading prior in eval mode")
        prior.eval()
        freeze_model(prior)
    else:
        print("Loading prior in train mode")
    return prior

In [None]:
def make_vqvae(hps, device='cuda', train=False):
    """Construct the VQ-VAE described by ``hps``, restore its weights and set its mode.

    When ``hps.sample_length`` is falsy it is derived from ``hps.sample_length_in_seconds``,
    rounded down to a multiple of the total downsampling factor, and written back into
    ``hps``.

    Args:
        hps: Hyperparams with the VQ-VAE settings (``levels``, ``downs_t``, ``strides_t``,
            ``emb_width``, ``l_bins``, ``width``, ``depth``, ...), ``sample_length`` and
            ``restore_vqvae``.
        device: device the model is moved to before weights are restored.
        train: when True the model stays trainable and, if ``hps.restore_vqvae`` is set,
            each bottleneck level's ``restore_k`` is called (logged as resetting the
            bottleneck EMAs); when False the model is put in ``eval()`` mode and
            passed to ``freeze_model``.

    Returns:
        The VQ-VAE module.
    """
    block_kwargs = dict(width=hps.width, depth=hps.depth, m_conv=hps.m_conv,
                        dilation_growth_rate=hps.dilation_growth_rate,
                        dilation_cycle=hps.dilation_cycle,
                        reverse_decoder_dilation=hps.vqvae_reverse_decoder_dilation)

    if not hps.sample_length:
        seconds = hps.get('sample_length_in_seconds', 0)
        assert seconds, "Set hps.sample_length or hps.sample_length_in_seconds"
        downsamples = calculate_strides(hps.strides_t, hps.downs_t)
        top_raw_to_tokens = int(np.prod(downsamples))
        # Keep it an integer: it is used as the model's input length below.
        hps.sample_length = int(seconds * hps.sr // top_raw_to_tokens) * top_raw_to_tokens
        print(f"Setting sample length to {hps.sample_length} (i.e. {hps.sample_length/hps.sr} seconds) to be multiple of {top_raw_to_tokens}")

    vqvae = VQVAE(input_shape=(hps.sample_length, 1), levels=hps.levels,
                  downs_t=hps.downs_t, strides_t=hps.strides_t,
                  emb_width=hps.emb_width, l_bins=hps.l_bins,
                  mu=hps.l_mu, commit=hps.commit,
                  spectral=hps.spectral, multispectral=hps.multispectral,
                  multipliers=hps.hvqvae_multipliers, use_bottleneck=hps.use_bottleneck,
                  **block_kwargs)

    vqvae = vqvae.to(device)
    restore_model(hps, vqvae, hps.restore_vqvae)
    if not train:
        print("Loading vqvae in eval mode")
        vqvae.eval()
        freeze_model(vqvae)
        return vqvae

    print("Loading vqvae in train mode")
    # Only reset when a checkpoint path was given.
    if hps.restore_vqvae:
        print("Reseting bottleneck emas")
        # Loop-invariant: the stride list does not depend on the level.
        downsamples = calculate_strides(hps.strides_t, hps.downs_t)
        for level, bottleneck in enumerate(vqvae.bottleneck.level_blocks):
            raw_to_tokens = np.prod(downsamples[:level + 1])
            num_tokens = int(hps.sample_length // raw_to_tokens)
            bottleneck.restore_k(num_tokens=num_tokens, threshold=hps.revival_threshold)
    return vqvae

In [None]:
def count_parameters(model, trainable_only=True):
    """Return the number of scalar parameters in ``model``.

    Args:
        model: a ``torch.nn.Module``.
        trainable_only: count only parameters with ``requires_grad`` (the default,
            unchanged behaviour). Pass False to count every parameter; frozen models
            otherwise report 0.
    """
    return sum(p.numel() for p in model.parameters()
               if p.requires_grad or not trainable_only)

In [None]:
# Build the VQ-VAE in eval mode: weights come from pr.restore_vqvae and make_vqvae
# calls freeze_model on it when train=False.
vqvae = make_vqvae(pr, device = t.device("cuda"), train=False)
#print_once(f"Parameters VQVAE:{count_parameters(vqvae)}")

Restored from /content/drive/MyDrive/Indian_Classical_Music_Generation/checkpoints/vq_vae/vqvae-flute/1/checkpoint_step_1.pth.tar
Loading vqvae in eval mode


In [None]:
# Build the prior on top of the VQ-VAE in eval mode (weights from pr.restore_prior).
# `model` is the name the inspection and sampling cells below refer to.
prior = make_prior(pr, vqvae, train = False)
#print_once(f"Parameters Prior:{count_parameters(prior)}")
model = prior

Level:0, Cond downsample:None, Raw to tokens:32, Sample length:264000.0
Restored from /content/drive/MyDrive/Indian_Classical_Music_Generation/checkpoints/prior/flute_checkpoint_latest.pth.tar
Loading prior in eval mode


In [None]:
# Show the module structure of the prior.
print(model)

SimplePrior(
  (prior): ConditionalAutoregressive2D(
    (x_emb): Embedding(1024, 2048)
    (x_emb_dropout): Dropout(p=0.0, inplace=False)
    (pos_emb): PositionEmbedding()
    (pos_emb_dropout): Dropout(p=0.0, inplace=False)
    (transformer): Transformer(
      (_attn_mods): ModuleList(
        (0-15): 16 x ResAttnBlock(
          (attn): FactoredAttention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
          )
          (ln_0): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (mlp): MLP(
            (c_fc): Conv1D()
            (c_proj): Conv1D()
          )
          (ln_1): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
    (x_out): Linear(in_features=2048, out_features=1024, bias=False)
    (loss): CrossEntropyLoss()
  )
)


In [None]:
# make_vqvae(train=False) freezes the VQ-VAE, so a trainable-only count prints 0;
# count all of its parameters instead and label the output accurately.
print(f"VQ-VAE parameters: {sum(p.numel() for p in vqvae.parameters())}")

model parameters :  0


# **Inference**

In [None]:
def sample_prior(orig_model, x_in, y, hps, i, n_prime_tokens=512,
                 out_dir='/content/drive/My Drive/audio_VAE/final_sample/primed_classical/'):
    """Continue a real clip with the prior; save the sample and the input reconstruction.

    The first ``n_prime_tokens`` codes of the encoded input are passed to
    ``orig_model.sample`` as ``z`` (the prompt) at temperature 1.0.

    Args:
        orig_model: prior exposing ``encode``/``decode``/``sample`` (e.g. from ``make_prior``).
        x_in: raw audio batch; passed through ``audio_preprocess`` and cut to ``hps.bs_sample``.
        y: conditioning labels; ignored unless ``hps.labels`` is set.
        hps: Hyperparams (uses ``bs_sample``, ``levels``, ``labels`` and ``sr``).
        i: index used in the output file names.
        n_prime_tokens: number of leading codes kept as the prompt (previously fixed at 512).
        out_dir: directory both wav files are written to (previous hard-coded default).

    Returns:
        The decoded sample.
    """
    orig_model.eval()

    x_in = audio_preprocess(x_in, hps)[:hps.bs_sample]
    bs = x_in.shape[0]

    # Sanity check: one code sequence per VQ-VAE level when encoding from level 0.
    zs_in = orig_model.encode(x_in, start_level=0, bs_chunks=bs)
    assert len(zs_in) == hps.levels

    # Labels are only used when enabled; both former branches sliced to bs_sample.
    y = y[:hps.bs_sample] if hps.labels else None

    z_enc, *z_conds = orig_model.encode(x_in, bs_chunks=bs)
    z = orig_model.sample(hps.bs_sample, z=z_enc[:, :n_prime_tokens], z_conds=z_conds,
                          y=y, fp16=False, temp=1.0)

    x_sample = orig_model.decode([z, *z_conds], bs_chunks=bs)
    x_out = orig_model.decode([z_enc, *z_conds], bs_chunks=bs)

    # Save at the configured sample rate rather than a hard-coded one.
    save_to_wav(out_dir, x_sample, hps.sr, i)
    save_to_wav(out_dir, x_out, hps.sr, i, input=True)
    return x_sample

In [None]:
import torch as t

def split_batch(obj, n_samples, split_size):
    """Split ``obj`` along dim 0 into chunks of at most ``split_size``.

    A tensor gives a tuple of chunks; a list of tensors gives a list of tuples, one
    chunk from each tensor; ``None`` gives ``ceil(n_samples / split_size)`` ``None``
    placeholders so it can be zipped with the other splits.

    Raises:
        TypeError: for any other type of ``obj``.
    """
    pass_count = (n_samples + split_size - 1) // split_size
    if obj is None:
        return [None] * pass_count
    if isinstance(obj, t.Tensor):
        return t.split(obj, split_size, dim=0)
    if isinstance(obj, list):
        per_item_chunks = [t.split(item, split_size, dim=0) for item in obj]
        return list(zip(*per_item_chunks))
    raise TypeError('Unknown input type')

# Break total_length into hops/windows of size n_ctx separated by hop_length
def get_starts(total_length, n_ctx, hop_length):
    """Return the start offsets of ``n_ctx``-long windows spaced ``hop_length`` apart.

    A start that would make its window reach or pass ``total_length`` is pulled back to
    ``total_length - n_ctx``, so the last window is a full ``n_ctx`` long and ends
    exactly at ``total_length``.
    """
    n_ctx = int(n_ctx)
    last_start = total_length - n_ctx
    return [min(start, last_start)
            for start in range(0, last_start + hop_length, hop_length)]

In [None]:
# Sample a partial window of length<n_ctx with tokens_to_sample new tokens on level=level
def sample_partial_window(zs, labels, sampling_kwargs, level, prior, tokens_to_sample, hps):
    """Sample ``tokens_to_sample`` more tokens at ``level`` with a window shorter than ``n_ctx``.

    Sets ``sampling_kwargs['sample_tokens']`` (the dict is modified) to the window
    length: ``current + tokens_to_sample`` when that fits in fewer than ``n_ctx`` tokens,
    otherwise ``n_ctx`` with the start moved so the window ends ``tokens_to_sample``
    past the current end. Delegates to ``sample_single_window``.
    """
    current_tokens = zs[level].shape[1]
    n_ctx = prior.n_ctx
    if current_tokens >= n_ctx - tokens_to_sample:
        window_len, start = n_ctx, current_tokens - n_ctx + tokens_to_sample
    else:
        window_len, start = current_tokens + tokens_to_sample, 0
    sampling_kwargs['sample_tokens'] = window_len
    return sample_single_window(zs, labels, sampling_kwargs, level, prior, start, hps)

# Sample a single window of length=n_ctx at position=start on level=level
def sample_single_window(zs, labels, sampling_kwargs, level, prior, start, hps):
    """Sample the tokens still missing from window ``[start, start + n_ctx)`` at ``level``.

    Tokens of ``zs[level]`` already inside the window are the prompt; newly sampled
    tokens are appended to ``zs[level]``.

    Args:
        zs: list of code tensors, one per level, indexed ``[batch, time]``.
        labels: labels passed to ``prior.get_y`` (may be None).
        sampling_kwargs: kwargs for ``prior.sample``; must contain ``max_batch_size``
            and may contain ``sample_tokens`` to sample fewer than ``n_ctx`` tokens.
            The dict is not modified.
        level: level being sampled.
        prior: the prior for ``level``.
        start: window start (cast to int).
        hps: unused here; kept for signature compatibility.

    Returns:
        ``zs`` (the list is updated in place), unchanged if nothing new is needed.
    """
    n_ctx = prior.n_ctx
    end = start + n_ctx
    start, end = int(start), int(end)

    # Sample every batch row. n_samples used to be hard-coded to 1, so when z_conds/y
    # were None, zip() kept only the first max_batch_size rows of a larger batch.
    n_samples = zs[level].shape[0]

    # z already sampled at the current level inside this window
    z = zs[level][:, start:end]
    sample_tokens = sampling_kwargs.get('sample_tokens', end - start)
    conditioning_tokens = z.shape[1]
    new_tokens = sample_tokens - conditioning_tokens

    print(f'new_tokens : {new_tokens}')
    print(f"Sampling {sample_tokens} tokens for [{start},{start+sample_tokens}]. Conditioning on {conditioning_tokens} tokens")

    if new_tokens <= 0:
        # Nothing new to sample
        return zs

    # z_conds from the level above (None when there is none)
    z_conds = prior.get_z_conds(zs, start, end)
    if z_conds is None:
        print(f'z_conds at level : {level} : {z_conds}')
    else:
        print(f'z_conds at level : {level} : {z_conds[0].shape}')

    y = prior.get_y(labels, start)

    empty_cache()

    # max_batch_size is only for split_batch; build a copy without it instead of
    # deleting and re-adding the key, which lost it if prior.sample raised.
    max_batch_size = sampling_kwargs['max_batch_size']
    sample_kwargs = {k: v for k, v in sampling_kwargs.items() if k != 'max_batch_size'}

    chunks = zip(split_batch(z, n_samples, max_batch_size),
                 split_batch(z_conds, n_samples, max_batch_size),
                 split_batch(y, n_samples, max_batch_size))
    z_samples = [prior.sample(n_samples=z_i.shape[0], z=z_i, z_conds=z_conds_i, y=y_i, **sample_kwargs)
                 for z_i, z_conds_i, y_i in chunks]
    z = t.cat(z_samples, dim=0)

    # Append only the newly sampled tokens to this level's codes.
    z_new = z[:, -new_tokens:]
    zs[level] = t.cat([zs[level], z_new], dim=1)
    return zs

# Sample total_length tokens at level=level with hop_length=hop_length
def sample_level(zs, labels, sampling_kwargs, level, prior, total_length, hop_length, hps):
    """Sample ``total_length`` tokens at ``level``.

    Lengths shorter than ``prior.n_ctx`` use one partial window; otherwise a full window
    is sampled at every offset from ``get_starts``.
    """
    print(f"Sampling level {level}")
    if total_length < prior.n_ctx:
        return sample_partial_window(zs, labels, sampling_kwargs, level, prior, total_length, hps)
    for window_start in get_starts(total_length, prior.n_ctx, hop_length):
        zs = sample_single_window(zs, labels, sampling_kwargs, level, prior, window_start, hps)
    return zs

# Sample multiple levels
def _sample(zs, labels, sampling_kwargs, priors, sample_levels, hps):
    """Sample codes for each level in ``sample_levels``, highest level first.

    For each level the prior is moved to CUDA, ``hps.sample_length // prior.raw_to_tokens``
    tokens are sampled with a hop of ``prior.n_ctx``, and the prior is moved back to CPU.

    Args:
        zs: list of code tensors per level (extended by the window samplers).
        labels: per-level labels indexed by level, or None for unconditional sampling.
        sampling_kwargs: per-level ``prior.sample`` kwargs, indexed by level.
        priors: per-level priors, indexed by level.
        sample_levels: levels to sample.
        hps: Hyperparams; ``sample_length`` must be a multiple of each ``raw_to_tokens``.

    Returns:
        The updated ``zs``.
    """
    for level in reversed(sample_levels):
        prior = priors[level]
        prior.cuda()
        empty_cache()

        # Set correct total_length, hop_length, labels and sampling_kwargs for level
        assert hps.sample_length % prior.raw_to_tokens == 0, f"Expected sample_length {hps.sample_length} to be multiple of {prior.raw_to_tokens}"
        total_length = int(hps.sample_length // prior.raw_to_tokens)

        # NOTE(review): with hop_length == n_ctx, a window starting at n_ctx holds no
        # existing tokens yet, so it is sampled without a prompt from the previous
        # window; consider a smaller hop if total_length ever exceeds n_ctx.
        hop_length = int(prior.n_ctx)
        p_labels = None if labels is None else labels[level]

        zs = sample_level(zs, p_labels, sampling_kwargs[level], level, prior, total_length, hop_length, hps)

        # Move the prior off the GPU again (presumably to free memory between levels).
        prior.cpu()
        empty_cache()
    return zs

In [None]:
# Generate ancestral samples (labels may be None for unconditional sampling)
def ancestral_sample(labels, sampling_kwargs, priors, hps):
    """Sample every level from scratch.

    Each level starts from an empty ``(1, 0)`` long tensor on CUDA, i.e. one sample.
    """
    all_levels = list(range(len(priors)))
    empty_codes = [t.zeros(1, 0, dtype=t.long, device='cuda') for _level in range(len(priors))]
    return _sample(empty_codes, labels, sampling_kwargs, priors, all_levels, hps)

# Continue ancestral sampling from previously saved codes
def continue_sample(zs, labels, sampling_kwargs, priors, hps):
    """Resume sampling on every level from the existing codes ``zs``."""
    return _sample(zs, labels, sampling_kwargs, priors, list(range(len(priors))), hps)

# Upsample given already generated upper-level codes
def upsample(zs, labels, sampling_kwargs, priors, hps):
    """Sample every level except the last prior's, reusing the codes already in ``zs``."""
    levels_below_top = list(range(len(priors) - 1))
    return _sample(zs, labels, sampling_kwargs, priors, levels_below_top, hps)

# Prompt the model with raw audio input (dimension: NTC) and generate continuations
def primed_sample(x, labels, sampling_kwargs, priors, hps, n_prime_tokens=1024):
    """Encode ``x``, keep the first ``n_prime_tokens`` level-0 codes as a prompt, then sample.

    Args:
        x: raw audio batch used as the prompt; ``x.shape[0]`` is the encode chunk size.
        labels: per-level labels or None.
        sampling_kwargs: per-level ``prior.sample`` kwargs.
        priors: per-level priors; the last one does the encoding.
        hps: Hyperparams forwarded to ``_sample``.
        n_prime_tokens: number of leading level-0 codes kept (default 1024, the value
            that was previously hard-coded).

    Returns:
        The sampled codes per level.
    """
    zs = priors[-1].encode(x, start_level=0, end_level=len(priors), bs_chunks=x.shape[0])
    # Only level 0 is truncated; with a single prior that is the level being sampled.
    zs[0] = zs[0][:, :n_prime_tokens]
    return _sample(zs, labels, sampling_kwargs, priors, list(range(len(priors))), hps)

In [None]:
# Chunk / batch sizes for sampling: the lower_level_* pair feeds the lower-level
# entries of sampling_kwargs, the plain pair feeds the top-level entry.
lower_level_chunk_size, lower_level_max_batch_size = 16, 1
chunk_size, max_batch_size = 16, 3

In [None]:
# One dict of prior.sample() kwargs per level, indexed by level: two lower-level entries
# and one top-level entry. With the single prior in model_list only index 0 is read.
sampling_kwargs = [
    dict(temp=1.5, fp16=False, chunk_size=lower_level_chunk_size, max_batch_size=lower_level_max_batch_size)
    for _level in range(2)
] + [dict(temp=1.5, fp16=False, chunk_size=chunk_size, max_batch_size=max_batch_size)]

In [None]:
def primed_sample_new(x, labels, priors, hps, fp16=False, temp=1.0, chunk_size=None, sample_tokens=None):
    """Work in progress: encode ``x`` with the last prior, print the codes and return them.

    Only ``x`` and ``priors`` are used so far; ``labels``, ``hps``, ``fp16``, ``temp``,
    ``chunk_size`` and ``sample_tokens`` are accepted but ignored.
    """
    zs = priors[-1].encode(x, bs_chunks=x.shape[0])
    print(zs)
    return zs

In [None]:
# The sampling helpers index priors by level; this notebook has one prior (level 0).
model_list = [model]

In [None]:
def save_to_wav(fname, aud, sr, j, input = False):
    """Clamp ``aud`` to [-1, 1] and write each batch item as a wav file into directory ``fname``.

    Files are named ``item_{j}.wav``, or ``item_input_{j}.wav`` when ``input`` is True.
    If the batch holds more than one item, ``_{i}`` is appended per item. Previously
    every item was written to the same name, so only the last one survived.

    Args:
        fname: output directory (created if missing).
        aud: tensor whose first dimension is the batch; each ``aud[i]`` is passed to
            ``soundfile.write`` as-is (printed as ``(1, 264000, 1)`` in the sampling output).
        sr: sample rate written to the file.
        j: sample index used in the file name.
        input: use the ``item_input_`` prefix (used for reconstructions of the prompt).
    """
    print('converting the audio to wav and saving...')
    aud = t.clamp(aud.detach(), -1, 1).cpu().numpy()
    print(aud.shape)
    os.makedirs(fname, exist_ok=True)
    stem = f'item_input_{j}' if input else f'item_{j}'
    n_items = aud.shape[0]
    for i in range(n_items):
        suffix = f'_{i}' if n_items > 1 else ''
        soundfile.write(os.path.join(fname, f'{stem}{suffix}.wav'), aud[i], samplerate=sr, format='wav')

## **Ancestral Sampling**

In [None]:
save_path = '/content/drive/MyDrive/Indian_Classical_Music_Generation/samples/ancestral'
n_ancestral_samples = 5

# Ancestral (unconditional, labels=None) sampling. Files are written at the model's
# sample rate pr.sr; the previous hard-coded 11200 did not match pr.sr (11000), which
# changes playback speed and pitch.
for i in range(n_ancestral_samples):
    z = ancestral_sample(None, sampling_kwargs, model_list, pr)
    x_sample = model_list[0].decode(z, bs_chunks=1)
    save_to_wav(save_path, x_sample, pr.sr, i)
    print()

Sampling level 0
new_tokens : 8250
Sampling 8250 tokens for [0,8250]. Conditioning on 0 tokens
z_conds at level : 0 : None
8250
converting the audio to wav and saving...
(1, 264000, 1)

Sampling level 0
new_tokens : 8250
Sampling 8250 tokens for [0,8250]. Conditioning on 0 tokens
z_conds at level : 0 : None
8250
converting the audio to wav and saving...
(1, 264000, 1)

Sampling level 0
new_tokens : 8250
Sampling 8250 tokens for [0,8250]. Conditioning on 0 tokens
z_conds at level : 0 : None
8250
converting the audio to wav and saving...
(1, 264000, 1)

Sampling level 0
new_tokens : 8250
Sampling 8250 tokens for [0,8250]. Conditioning on 0 tokens
z_conds at level : 0 : None
8250
converting the audio to wav and saving...
(1, 264000, 1)

Sampling level 0
new_tokens : 8250
Sampling 8250 tokens for [0,8250]. Conditioning on 0 tokens
z_conds at level : 0 : None
8250
converting the audio to wav and saving...
(1, 264000, 1)



## **Primed Sampling**

In [None]:
# # Primed Sampling

# count = 0
# for i, x in enumerate(train_loader):
#   if(count >= 5):
#     break

#   x = x.to('cuda', non_blocking=True)
#   x_in = audio_preprocess(x, pr)

#   z_enc, *z_conds = model_list[0].encode(x_in, bs_chunks=1)
#   x_out = model_list[0].decode([z_enc,*z_conds],bs_chunks = 1)
#   save_to_wav('/content/drive/My Drive/audio_VAE/final_sample/',x_out,pr.sr,i + 5,input = True)

#   # main sampling


#   z = primed_sample(x_in,None,sampling_kwargs,model_list,pr)
#   x_sample = model_list[0].decode(z, bs_chunks=1)
#   save_to_wav('/content/drive/My Drive/audio_VAE/final_sample/',x_sample,pr.sr,i + 5)
#   count += 1