In [1]:
import torch

In [47]:
import diffusers

In [None]:
# `DownEnc` is not exported by diffusers (the import raised ImportError), so only
# the model class actually used in this notebook is imported.
from diffusers import VQModel

# Test VQModel

In [15]:
# Small single-channel VQ autoencoder, used only for the shape/quantizer checks below.
vq = VQModel(
    in_channels=1,
    out_channels=1,
    latent_channels=1,
    vq_embed_dim=512,
    num_vq_embeddings=1000,
    layers_per_block=3,
)

In [6]:
# Check which device the test model's parameters are on (CPU here).
vq.device

device(type='cpu')

In [16]:
# Random stand-in batch with the same (batch, 1, 129, 16) layout as the spectrogram
# snippets used later (presumably channel, frequency bins, time frames).
x = torch.randn((32, 1, 129, 16))
x.shape

torch.Size([32, 1, 129, 16])

In [39]:
# Full encode -> decode round trip on the random batch (autograd is left enabled).
h = vq.encode(x)
z = vq.decode(h["latents"])["sample"]

In [28]:
# The reconstruction has the same shape as the input batch.
z.shape

torch.Size([32, 1, 129, 16])

In [41]:
# The encoder output only carries `latents` (quantized separately in the next cell).
h.keys()

odict_keys(['latents'])

In [33]:
# Quantize the encoder latents manually; the result is inspected element by element below.
# NOTE(review): going by the outputs below this looks like
# (quantized latents, quantizer loss, (_, _, code indices)) — confirm against the diffusers docs.
e = vq.quantize(h.latents)

In [42]:
# Quantized latents: the channel dim equals vq_embed_dim (512).
e[0].shape

torch.Size([32, 512, 129, 16])

In [43]:
# dtype of the quantized latents.
e[0].dtype

torch.float32

In [46]:
# Scalar quantizer loss; it carries a grad_fn because autograd was not disabled.
e[1]

tensor(0.4473, grad_fn=<AddBackward0>)

In [45]:
# Last element; its third entry looks like the selected codebook indices.
# NOTE(review): every visible index is 68 — possibly all inputs map to one code; worth checking.
e[-1]

(None, None, tensor([68, 68, 68,  ..., 68, 68, 68]))

In [51]:
# Look up the codebook vectors for a few hand-picked code indices.
c = vq.quantize.get_codebook_entry(
    indices=torch.tensor([68, 102, 220], dtype=torch.long),
    shape=None,
)

In [52]:
# One 512-d codebook vector per requested index.
c.shape

torch.Size([3, 512])

In [98]:
# Switch the test model's quantizer to its non-legacy loss and set its beta weight.
# NOTE(review): 95 / 5 == 19.0, which seems unusually large for a VQ beta — confirm intent.
# NOTE(review): this only changes `vq`; the `model` trained later is constructed separately
# and these attributes are never set on it.
vq.quantize.legacy = False
vq.quantize.beta = 95 / 5

# Load dataset

In [60]:
from datasets import load_from_disk
import numpy as np
import matplotlib.pyplot as plt

In [54]:
import os

# Dataset location: override with the SONGBIRD_DATASET_DIR environment variable so the
# notebook runs on other machines; falls back to the original external-drive path.
DATASET_DIR = os.environ.get(
    "SONGBIRD_DATASET_DIR",
    "/media/gagan/Gagan_external/songbird_data/age_resampled_hfdataset/",
)
ds = load_from_disk(DATASET_DIR)

In [55]:
# Number of examples in the dataset.
len(ds)

514229

In [57]:
# Inspect one raw spectrogram, converted to a 2-D numpy array.
x = np.array(ds[0]["spectrogram"])
x.shape

(129, 250)

In [58]:
from birdsong_gan.utils.audio_utils import random_time_crop_spectrogram

In [59]:
# Random 16-column crop along the second axis: (129, 250) -> (129, 16).
y = random_time_crop_spectrogram(x, crop_length=16)
y.shape

(129, 16)

In [81]:
class SpectrogramSnippets(torch.utils.data.Dataset):
    """Map-style dataset yielding fixed-width, single-channel spectrogram crops.

    Each item is a random ``ntimeframes``-wide crop (via
    ``random_time_crop_spectrogram``) of one example's ``"spectrogram"`` field,
    optionally compressed with ``np.log1p``, returned as a tensor of shape
    ``(1, n_rows, ntimeframes)``.

    Args:
        ds: Indexable dataset whose items are mappings with a ``"spectrogram"``
            entry convertible to a 2-D numpy array.
        ntimeframes: Width of the random crop along the second axis.
        spec_dtype: ``"float16"`` for half precision; any other value gives float32.
        log_scale: If True, apply ``np.log1p`` before converting to a tensor.
    """

    def __init__(self, ds, ntimeframes: int = 16, spec_dtype: str = "float32", log_scale: bool = True):
        super().__init__()
        self.ds = ds
        self.ntimeframes = ntimeframes
        self.log_scale = log_scale
        # Anything other than "float16" silently falls back to float32.
        self.spec_dtype = torch.float16 if spec_dtype == "float16" else torch.float32

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, index: int):
        # Read from the wrapped dataset, not a notebook-global `ds`.
        x = np.array(self.ds[index]["spectrogram"])
        x = random_time_crop_spectrogram(x, self.ntimeframes)
        if self.log_scale:
            x = np.log1p(x)
        x = torch.from_numpy(x).to(self.spec_dtype)
        # Prepend a channel dimension: (H, W) -> (1, H, W). unsqueeze also works
        # on non-contiguous crops, unlike view.
        return x.unsqueeze(0)

In [88]:
# Wrap the HF dataset with the default snippet settings spelled out explicitly.
dataset = SpectrogramSnippets(ds, ntimeframes=16, spec_dtype="float32", log_scale=True)

In [89]:
# One snippet dataset item per underlying example.
len(dataset)

514229

In [96]:
# Each item is a single-channel crop of shape (1, 129, 16).
xx = dataset[124124]
xx.shape

torch.Size([1, 129, 16])

# Training function setup

In [112]:
from dataclasses import dataclass

@dataclass
class TrainingConfig:
    """Hyperparameters for the VQ autoencoder training loop.

    Read by the optimizer and dataloader cells, the epoch loop, and ``train``.
    """

    # AdamW learning rate.
    lr: float = 1e-4
    # DataLoader batch size.
    batch_size: int = 128
    # Number of batches whose gradients are accumulated before each optimizer step.
    gradient_accumulation_steps: int = 1
    # Weight on the loss returned by `model.quantize`, relative to the L2 reconstruction loss.
    alpha: float = 10.0
    # AdamW weight decay.
    weight_decay: float = 0.0
    # Print losses every `log_every` batches.
    log_every: int = 50
    # Number of passes over the training data.
    num_epochs: int = 5


config = TrainingConfig()

In [110]:
# Display the active training configuration.
config

TrainingConfig(lr=0.0001, batch_size=128, gradient_accumulation_steps=1, alpha=10.0, weight_decay=0.0, log_every=50)

In [101]:
# Model to be trained: same architecture as the test `vq`, but with a 2000-entry codebook.
model = VQModel(
    in_channels=1,
    out_channels=1,
    latent_channels=1,
    vq_embed_dim=512,
    num_vq_embeddings=2000,
    layers_per_block=3,
)

In [102]:
# Parameter count as reported by diffusers' `num_parameters()`.
model.num_parameters()

1876995

In [103]:
# Use the GPU when one is available; otherwise fall back to CPU so the notebook
# still runs end-to-end (slowly) on machines without CUDA.
model = model.to("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Split parameters into two AdamW groups: biases and normalization weights get no
# weight decay, everything else uses `config.weight_decay`.
# NOTE(review): diffusers' VQModel appears to use GroupNorm modules named like `norm1` /
# `conv_norm_out` rather than `LayerNorm`, so the lowercase "norm" pattern is added to
# catch them — verify with `[n for n, _ in model.named_parameters()]`.
# NOTE: these groups only take effect if passed to the optimizer instead of `model.parameters()`.
no_decay = ["bias", "LayerNorm.weight", "norm"]
named_params = list(model.named_parameters())
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in named_params if not any(nd in n for nd in no_decay)],
        "weight_decay": config.weight_decay,
    },
    {
        "params": [p for n, p in named_params if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]

In [111]:
# AdamW over all model parameters with a single weight-decay setting.
# NOTE(review): `optimizer_grouped_parameters` from the cell above is not used here.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=config.lr,
    weight_decay=config.weight_decay,
)

In [113]:
# Shuffle so each epoch visits the examples in a fresh random order rather than
# the fixed on-disk order of the dataset.
train_dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=config.batch_size,
    shuffle=True,
)

In [108]:
l2loss = torch.nn.MSELoss()


def train(dataloader, model, optimizer, epoch, config):
    """Run one epoch of VQ autoencoder training and return the model.

    For every batch: encode, quantize, decode the quantized latents, and minimize
    ``MSE(xhat, x) + config.alpha * quantizer_loss``. Gradients are accumulated over
    ``config.gradient_accumulation_steps`` batches, clipped to a global norm of 1.0,
    then applied. The current batch's losses are printed every ``config.log_every``
    batches.

    Args:
        dataloader: Iterable of input tensors (each moved to ``model.device``).
        model: Autoencoder exposing ``encode``, ``quantize``, ``decode`` and ``device``
            (e.g. diffusers' ``VQModel``).
        optimizer: Optimizer over ``model``'s parameters.
        epoch: Epoch index, used only in log messages.
        config: Object providing ``alpha``, ``gradient_accumulation_steps`` and ``log_every``.

    Returns:
        The same ``model`` instance, updated in place.
    """
    model.train()
    accum_steps = config.gradient_accumulation_steps
    pending = 0  # batches whose gradients are accumulated but not yet applied

    def _apply_gradients():
        # Clip, step, and reset the accumulated gradients.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()

    for i, x in enumerate(dataloader):
        x = x.to(model.device)

        # Encode, then quantize; `commloss` is the loss term returned by the quantizer.
        ze = model.encode(x)
        zq, commloss, _ = model.quantize(ze.latents)

        # Reconstruction. `zq` is already quantized, so skip decode()'s own
        # quantization pass (which would redo the codebook lookup and discard its loss).
        xhat = model.decode(zq, force_not_quantize=True).sample

        l2 = l2loss(xhat, x)
        total_loss = (l2 + config.alpha * commloss) / accum_steps
        total_loss.backward()
        pending += 1

        if pending == accum_steps:
            _apply_gradients()
            pending = 0

        if (i + 1) % config.log_every == 0:
            print(f"L2 loss at epoch={epoch}, batch={i} is = {l2.item()}")
            print(f"Comm loss at epoch={epoch}, batch={i} is = {commloss.item()}")

    # Apply leftover gradients when the batch count is not a multiple of
    # `accum_steps`, so they do not leak into the next epoch's first update.
    if pending:
        _apply_gradients()

    return model

In [None]:
# Train for `config.num_epochs` epochs; `train` returns the same model object it updated.
for epoch in range(config.num_epochs):
    model = train(train_dataloader, model, optimizer, epoch, config)

In [None]:
# TODO(review): this cell held an unfinished expression (`model.`), which raised a
# SyntaxError and broke "Restart & Run All". Complete it (e.g. save or evaluate the
# trained model) or delete the cell.

# Discriminator-based (adversarial) training

In [None]:
from birdsong_gan.models.nets_16col_residual import _netD


class Discriminator(torch.nn.Module):

    def __init__(self, num_discriminator_multiplier: int = 32):
        super().__init__()

        self.convs = nn.ModuleList([
            nn.Conv2d(nc, ndf, kernel_size=4, stride=(2,2), padding=(1,1), bias=False),
            # size H = (129 +2 -4)/2 + 1 = 64, W = (16 +2 -4)/2 + 1 = 8
            nn.Conv2d(ndf, ndf * 2, kernel_size=(4, 3), stride=(2, 1), padding=(1, 0), bias=False),
            # size H = (64 +2 -4)/2 + 1 = 32, W = (8 -3) + 1 = 6
            nn.Conv2d(ndf * 2, ndf * 4, kernel_size=(4,3), stride=(2,1), padding=(2,0), bias=False),
            # size H = (32 +4 -4)/2 + 1 = 17, W = (6 -3) + 1 = 4
            nn.Conv2d(ndf * 4, ndf * 8, kernel_size=(8,4), stride=(2,1), padding=1, bias=False),
            # H = (17 +2 -8)/2 + 1 = 6, W = (4 +2 -4) + 1 = 3
            nn.Conv2d(ndf * 8, 1, kernel_size=(6,3), stride=1, padding=0, bias=False),
            # H = 6-6 + 1 =1, W = (4 - 3) +1 =  
        ])
        self.lns = nn.ModuleList([nn.LayerNorm([ndf, 64, 8]),
                                  nn.LayerNorm([ndf * 2, 32, 6]),
                                  nn.LayerNorm([ndf * 4, 17, 4]),
                                  nn.LayerNorm([ndf * 8, 6, 3])])
        self.activations = nn.ModuleList([nn.LeakyReLU(0.2),nn.LeakyReLU(0.2),
                                    nn.LeakyReLU(0.2),nn.LeakyReLU(0.2)])
        
        self.loss = torch.nn.BCEWithLogitsLoss()
        
    def true_wp(self, prob, size):
        # generate a uniform random number
        p = torch.rand(size).to(self.device).float()
        # if prob = 0.9, most of the time, this will be True
        p = (p < prob).float()
        return p 

    def forward(self, x: torch.Tensor, label: str = "real") -> torch.Tensor:
        for i in range(4):
            x = self.convs[i](x)
            x = self.lns[i](x)
            x = self.activations[i](x)
            x = self.convs[-1](x)
        x = x.view(-1, 1) # flatten

        if label == "real":
            
        return self.loss(