In [None]:
!pip install torchaudio

Collecting torchaudio
[?25l  Downloading https://files.pythonhosted.org/packages/aa/55/01ad9244bcd595e39cea5ce30726a7fe02fd963d07daeb136bfe7e23f0a5/torchaudio-0.8.1-cp37-cp37m-manylinux1_x86_64.whl (1.9MB)
[K     |████████████████████████████████| 1.9MB 18.9MB/s 
Installing collected packages: torchaudio
Successfully installed torchaudio-0.8.1


In [None]:
import torchaudio
import torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import IterableDataset, DataLoader
from torch.backends import cudnn
import numpy as np
import glob
import os
from sklearn.metrics import confusion_matrix


In [None]:
# Mount Google Drive at ./drive; the dataset and checkpoint paths used by this
# notebook all live under ./drive/MyDrive.
from google.colab import drive
drive.mount('./drive')

Mounted at ./drive


In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Locations on the mounted Drive: input tensors (*.pt) and the checkpoint file.
SPEC_DATASET_PATH = "./drive/MyDrive/cs753 dataset/spectrograms/"
AUDIO_DATASET_PATH = "./drive/MyDrive/cs753 dataset/audio/"
MODEL_SAVE_PATH = "./drive/MyDrive/cs753 dataset/"
MODEL_FILENAME = "stargan_spectrograms2.pt"

# Class names; a name's position in this list is its integer class id.
INSTRUMENTS = [
    "Bansuri",
    "Shehnai",
    "Santoor",
    "Sarod",
    "Violin"
]

# Name -> class id, derived from INSTRUMENTS so the two can never drift apart.
INSTRUMENT_LABELS = {name: idx for idx, name in enumerate(INSTRUMENTS)}

# NOTE(review): presumably the number of training clips per class, listed in
# INSTRUMENTS order -- confirm against the dataset.
CLASS_COUNTS = np.array([
    891,
    1664,
    1122,
    765,
    1193
])
# Loss weight per class = 1 / (class share of the total), so classes with a
# smaller count receive a larger weight.
WEIGHT = torch.tensor((CLASS_COUNTS / CLASS_COUNTS.sum()) ** -1)
WEIGHT = WEIGHT.float().to(device)

spec_files_path = os.path.join(SPEC_DATASET_PATH, "*.pt")
audio_files_path = os.path.join(AUDIO_DATASET_PATH, "*.pt")
# Both lists are sorted so that index i in one refers to the same clip as
# index i in the other (this relies on the two folders using matching names).
SPEC_FILES = sorted(glob.glob(spec_files_path))
AUDIO_FILES = sorted(glob.glob(audio_files_path))
# Files are paired purely by position, so a length mismatch would silently
# attach labels to the wrong audio (or raise IndexError mid-epoch).
assert len(AUDIO_FILES) == len(SPEC_FILES), (
    f"found {len(AUDIO_FILES)} audio files but {len(SPEC_FILES)} spectrogram files"
)
NUM_FILES = len(SPEC_FILES)

 


In [None]:
# Sanity check: how many spectrogram files the glob above found.
print(NUM_FILES)

5637


In [None]:
def normalize(data):
  """Standardise a tensor to zero mean / unit std along its dimension 1.

  A leading dimension of size 1 is added first, so the statistics are taken
  over dimension 2 of the returned tensors (dimension 1 of the input).

  Args:
    data: tensor with at least two dimensions; it is cast to float32 on
      whatever device it already lives on.

  Returns:
    A tuple ``(normalized, mean, std)``. ``mean`` and ``std`` keep the reduced
    dimension with size 1. Slices with zero spread (including the case where
    the reduced dimension has a single element) are only mean-centred, and
    their ``std`` entry is reported as 0.
  """
  # .float() keeps the tensor on its current device; casting through
  # torch.FloatTensor would always move it to the CPU.
  data = data.float().unsqueeze(0)
  mean = data.mean(dim=2, keepdim=True)
  if data.size(2) > 1:
    std = data.std(dim=2, keepdim=True)
  else:
    # The sample std of a single element is undefined (NaN); treat the slice
    # as constant instead of propagating NaNs into the output.
    std = torch.zeros_like(mean)
  # Divide constant slices by 1 so they do not produce a division by zero,
  # while the returned std still reports 0 for them.
  divisor = torch.where(std == 0, torch.ones_like(std), std)
  data = (data - mean) / divisor
  return data, mean, std

In [None]:
def label2onehot(labels, dim):
  """Turn a 1-D tensor of class indices into a (len(labels), dim) one-hot matrix.

  The result is a float32 tensor of zeros with a single 1 per row, placed in
  the column given by the corresponding label.
  """
  num_rows = labels.size(0)
  onehot = torch.zeros(num_rows, dim)
  row_ids = torch.arange(num_rows)
  # Pair every row with its label's column and switch that entry on.
  onehot[row_ids, labels.long()] = 1
  return onehot


In [None]:
def gradient_penalty(y, x):
  """Compute the gradient penalty mean((||dy/dx||_2 - 1)**2) over the batch.

  Args:
    y: tensor computed from ``x`` (e.g. a critic output).
    x: input tensor with ``requires_grad=True``; dimension 0 is the batch.

  Returns:
    A scalar tensor. The graph of the gradient is kept (``create_graph=True``)
    so the penalty itself can be back-propagated.
  """
  # Match y's device and dtype instead of relying on a global device, so the
  # function works for any tensor it is handed.
  grad_outputs = torch.ones_like(y)
  dydx = torch.autograd.grad(outputs=y,
                             inputs=x,
                             grad_outputs=grad_outputs,
                             retain_graph=True,
                             create_graph=True)[0]

  # One row of flattened gradients per sample; reshape also copes with a
  # non-contiguous gradient, which view would reject.
  dydx = dydx.reshape(dydx.size(0), -1)
  dydx_l2norm = torch.sqrt(torch.sum(dydx ** 2, dim=1))
  return torch.mean((dydx_l2norm - 1) ** 2)


In [None]:
class DataSource:
  """Iterator that yields shuffled mini-batches read from AUDIO_FILES / SPEC_FILES.

  Each call to ``__next__`` returns ``(X, y, y_perm, placeholder)``:
    X           -- the loaded audio tensors stacked with ``torch.vstack``
    y           -- the labels stored under key 'y' of the matching spectrogram files
    y_perm      -- target labels: ``y`` shifted by a random offset in 1..4, modulo 5,
                   so a target never equals its original label
    placeholder -- ``torch.tensor([1])``; the spectrogram itself is not loaded here
  """

  def __init__(self, batch_size=4):
    self.order = np.random.permutation(NUM_FILES)
    self.batch_size = batch_size
    self.counter = 0

  def __next__(self):
    if self.counter >= NUM_FILES:
      # Epoch finished: prepare a new random order for the next pass, then
      # signal exhaustion to the consumer.
      self.order = np.random.permutation(NUM_FILES)
      self.counter = 0
      raise StopIteration()

    first = self.counter
    last = first + self.batch_size
    self.counter = last

    waveforms = []
    labels = []
    # The final slice of an epoch may hold fewer than batch_size indices.
    for file_idx in self.order[first:last]:
      waveforms.append(torch.load(AUDIO_FILES[file_idx]))
      # Only the label is taken from the spectrogram file.
      labels.append(torch.load(SPEC_FILES[file_idx])['y'])

    X = torch.vstack(waveforms)
    y = torch.tensor(labels)
    offsets = torch.randint(1, 5, size=[y.shape[0]])
    y_perm = (y.clone() + offsets) % 5
    return X, y, y_perm, torch.tensor([1])

class MyIterableDataset(IterableDataset):
  """IterableDataset wrapper that hands out one shared DataSource as its iterator."""

  def __init__(self, batch_size=4):
    # The source is created once; every iter() call returns this same object.
    self.source = DataSource(batch_size=batch_size)

  def __iter__(self):
    """Return the underlying DataSource, which implements ``__next__``."""
    return self.source


In [None]:
class Backbone(nn.Module):
  """Feature extraction network: three Conv2d -> InstanceNorm2d -> LeakyReLU stages.

  Takes a 4-D input with 2 channels and returns a 128-channel feature map.
  The first convolution uses a wide, strongly strided kernel ([5, 41] with
  stride [3, 21]); the following two use square kernels.
  """
  def __init__(self):
    super(Backbone, self).__init__()
    layers = []
    layers.append(nn.Conv2d(2, 32, [5,41], [3,21]))
    layers.append(nn.InstanceNorm2d(32))
    layers.append(nn.LeakyReLU(0.01))
    layers.append(nn.Conv2d(32, 64, 5, 3))
    layers.append(nn.InstanceNorm2d(64))
    layers.append(nn.LeakyReLU(0.01))
    layers.append(nn.Conv2d(64, 128, 4, 2))
    # num_features must match the 128 channels produced by the conv above
    # (it was declared as 32). The layer has no parameters or buffers with the
    # default settings, so existing checkpoints load unchanged.
    layers.append(nn.InstanceNorm2d(128))
    layers.append(nn.LeakyReLU(0.01))
    self.main = nn.Sequential(*layers)

  def forward(self, x):
    """Apply the convolutional stack to ``x`` and return the feature map."""
    return self.main(x)


class Classifier(nn.Module):
  """Backbone followed by a convolution that emits one score per class (5 classes).

  The head uses a [9, 11] kernel, so the backbone output must be exactly
  9 x 11 for the result to collapse to shape (batch, 5).
  """
  def __init__(self):
    super().__init__()
    self.bb = Backbone()
    self.conv = nn.Conv2d(128, 5, kernel_size=(9, 11), stride=1)

  def forward(self, x):
    scores = self.conv(self.bb(x))
    batch, channels = scores.shape[0], scores.shape[1]
    # Drop the two spatial dimensions (they must both be 1 at this point).
    return scores.view(batch, channels)


class Discriminator(nn.Module):
  """Backbone followed by a convolution that emits a single score per sample.

  The head uses a [9, 11] kernel, so the backbone output must be exactly
  9 x 11 for the result to collapse to shape (batch, 1).
  """
  def __init__(self):
    super().__init__()
    self.bb = Backbone()
    self.disc = nn.Conv2d(128, 1, kernel_size=(9, 11), stride=1)

  def forward(self, x):
    score = self.disc(self.bb(x))
    batch, channels = score.shape[0], score.shape[1]
    # Drop the two spatial dimensions (they must both be 1 at this point).
    return score.view(batch, channels)


class ResidualBlock(nn.Module):
  """1-D residual block: conv -> LeakyReLU -> conv, added to the input, then LeakyReLU.

  Both convolutions use kernel size 3 with padding 1 and no bias, so the
  sequence length is preserved. The skip connection adds the input directly,
  which requires ``dim_in == dim_out``. The block contains no normalisation
  layers.
  """
  def __init__(self, dim_in, dim_out):
    super().__init__()
    conv_options = dict(kernel_size=3, stride=1, padding=1, bias=False)
    self.main = nn.Sequential(
      nn.Conv1d(dim_in, dim_out, **conv_options),
      nn.LeakyReLU(0.01),
      nn.Conv1d(dim_out, dim_out, **conv_options),
    )
    self.relu = nn.LeakyReLU(0.01)

  def forward(self, x):
    residual = self.main(x)
    return self.relu(x + residual)


class Generator(nn.Module):
  """1-D convolutional encoder / residual bottleneck / transposed-conv decoder.

  ``forward(x, c)`` tiles the condition vector ``c`` along the time axis,
  concatenates it to ``x`` on the channel axis (7 channels in total are
  expected by the first convolution) and returns a 2-channel signal.
  """
  def __init__(self):
    super().__init__()
    slope = 0.01
    # Down-sampling path: overall stride 50 * 2 * 2.
    encoder = [
      nn.Conv1d(7, 32, kernel_size=400, stride=50, padding=200),
      nn.LeakyReLU(slope),
      nn.Conv1d(32, 64, kernel_size=5, stride=2, padding=2),
      nn.LeakyReLU(slope),
      nn.Conv1d(64, 128, kernel_size=5, stride=2, padding=2),
      nn.LeakyReLU(slope),
    ]
    bottleneck = [ResidualBlock(128, 128) for _ in range(4)]
    # Up-sampling path mirroring the encoder.
    decoder = [
      nn.ConvTranspose1d(128, 64, kernel_size=5, stride=2, padding=2),
      nn.LeakyReLU(slope),
      nn.ConvTranspose1d(64, 32, kernel_size=5, stride=2, padding=2),
      nn.LeakyReLU(slope),
      nn.ConvTranspose1d(32, 2, kernel_size=400, stride=50, padding=200),
    ]
    self.main = nn.Sequential(*encoder, *bottleneck, *decoder)

  def forward(self, x, c):
    # Broadcast the (batch, classes) condition to one copy per time step.
    cond = c.view(c.size(0), c.size(1), 1).expand(-1, -1, x.size(2))
    return self.main(torch.cat((x, cond), dim=1))


In [None]:
# Instantiate the three networks: generator, discriminator and classifier.
G = Generator()
D = Discriminator()
C = Classifier()
# One Adam optimiser per network, all with the same learning rate.
C_optim = torch.optim.Adam(C.parameters(), lr=0.001)
G_optim = torch.optim.Adam(G.parameters(), lr=0.001)
D_optim = torch.optim.Adam(D.parameters(), lr=0.001)

In [None]:
# map_location lets a checkpoint that was written on a GPU be restored on a
# CPU-only runtime (device falls back to 'cpu' when CUDA is unavailable).
model = torch.load(os.path.join(MODEL_SAVE_PATH, "stargan_spectrograms.pt"), map_location=device)

In [None]:
# Initialise the classifier from the 'C-model' entry of the checkpoint dict.
C.load_state_dict(model['C-model'])

<All keys matched successfully>

In [None]:
def classification_loss(logit, target, weight=None):
  """Class-weighted softmax cross entropy, summed and divided by the batch size.

  Args:
    logit: (batch, num_classes) tensor of unnormalised scores.
    target: (batch,) tensor of integer class indices.
    weight: optional per-class weight tensor; defaults to the module-level
      WEIGHT when omitted.

  Returns:
    A scalar tensor: the weighted per-sample losses summed, then divided by
    ``logit.size(0)`` (not by the sum of the weights).
  """
  if weight is None:
    weight = WEIGHT
  # reduction='sum' is the supported spelling of the deprecated size_average=False.
  return F.cross_entropy(logit, target, weight=weight, reduction='sum') / logit.size(0)

In [None]:
# Full path of the checkpoint this notebook resumes from and saves to.
model_path = os.path.join(MODEL_SAVE_PATH, MODEL_FILENAME)
# Non-empty only when that checkpoint file already exists.
l = glob.glob(model_path)
# Default epoch counter; overwritten with the stored value if a checkpoint is loaded.
EPOCH = 1

In [None]:
# Resume all three networks and the epoch counter when a checkpoint was found.
if l:
  # map_location lets a checkpoint written on a GPU be restored on a CPU-only
  # runtime (device falls back to 'cpu' when CUDA is unavailable).
  checkpoints = torch.load(model_path, map_location=device)
  C.load_state_dict(checkpoints['C-model'])
  G.load_state_dict(checkpoints['G-model'])
  D.load_state_dict(checkpoints['D-model'])
  EPOCH = checkpoints['epoch']
  print("Model loaded")

Model loaded


In [None]:
# Show the epoch counter training will start from (1 unless a checkpoint was loaded).
print(EPOCH)

5


In [None]:
def train(batch_size=4, lambda_gp=10, lambda_cls=1, lambda_recon=10):
  """Train the module-level networks G, D and C in an endless loop.

  The loop never returns (``while True`` without break); stop it by
  interrupting the kernel. Every 20th iteration the state dicts of C, G and D
  plus the current pass counter are written to ``model_path``.

  Args:
    batch_size: number of files per batch requested from MyIterableDataset;
      also the divisor of the reconstruction loss.
    lambda_gp: weight of the gradient penalty in the discriminator loss.
    lambda_cls: only referenced by commented-out code at the moment.
    lambda_recon: weight of the reconstruction loss in the generator loss.

  NOTE(review): as written, only the generator branch can run. ``mode_flag``
  starts False and the G-step sets it to ``g_loss < 0``, while ``g_loss`` is a
  scaled sum of absolute values (non-negative for ``lambda_recon >= 0``). The
  discriminator branch also calls ``spec_xform`` and ``denormalize``, which
  are not defined inside this function (the ``spec_xform`` assignment below
  is commented out) -- confirm they exist elsewhere before enabling it.
  """
  # Let cuDNN benchmark and pick convolution algorithms.
  cudnn.benchmark = True

  C.to(device)
  G.to(device)
  D.to(device)

  # Max gradient norm; only used by the discriminator branch.
  clipping_value = 1
  dataloader = DataLoader(MyIterableDataset(batch_size=batch_size), num_workers=1)
  # file_count and gen_count are updated below but never read by active code.
  file_count = 0
  gen_count = 0
  # Starts at 1 so the first checkpoint is written on the 20th iteration.
  save_count = 1
  # Counts passes over the data, continuing from the loaded EPOCH.
  count = EPOCH
  data_iter = iter(dataloader)

  # Last observed loss values, kept across iterations for the log line.
  d_loss_real, d_loss_fake = 0, 0
  g_loss_fake, g_loss_cls, g_loss_recon = 0, 0, 0
  # False -> generator step, True -> discriminator step.
  mode_flag = False

  # spec_xform = torchaudio.transforms.Spectrogram().to(device)
  # normalize = torchvision.transforms.Normalize(0,1).to(device)

  while True:
    if save_count == 0:
      print("Saving model")
      torch.save({
          'epoch': count,
          'C-model': C.state_dict(),
          'G-model': G.state_dict(),
          'D-model': D.state_dict()
          # 'C-optim': C_optim.state_dict(),
      }, model_path)
    try:
      x_real, label_org, label_trg, spec = next(data_iter)
    except StopIteration:
      # The loader is exhausted: count one more pass and start a new one.
      count += 1
      file_count = 0
      data_iter = iter(dataloader)
      x_real, label_org, label_trg, spec = next(data_iter)

    # Remove the leading dimension of size 1 added by the DataLoader's
    # default batching (no batch_size is passed to it above).
    x_real = x_real.squeeze(0)
    label_org = label_org.squeeze(0)
    label_trg = label_trg.squeeze(0)
    spec = spec.squeeze(0)
    c_org = label2onehot(label_org, 5)
    c_trg = label2onehot(label_trg, 5)

    x_real = x_real.to(device)
    spec = spec.to(device)
    label_org = label_org.to(device)
    label_trg = label_trg.to(device)
    c_org = c_org.to(device)
    c_trg = c_trg.to(device)

    if mode_flag:
      #### TRAINING THE DISCRIMINATOR
      ## Real data points
      out_src = D(spec)
      d_loss_real = - torch.mean(out_src)
      # Classifier is already trained, uncomment later
      # out_cls = C(x_real)
      # d_loss_cls  = classification_loss(out_cls, label_org)

      ## Generated data points
      with torch.no_grad():
        x_fake = G(x_real, c_trg)
        # NOTE(review): spec_xform and denormalize are not defined in this
        # function, and the module-level normalize returns a 3-tuple rather
        # than a tensor -- verify before this branch is enabled.
        temp = spec_xform(denormalize(x_fake, ))
        temp = normalize(temp)
      out_src = D(temp)
      d_loss_fake = torch.mean(out_src)

      ## Gradient Penalty
      # Random per-sample interpolation between real and generated inputs.
      alpha = torch.rand(x_real.size(0), 1, 1, device=device)
      x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
      x_hat = spec_xform(x_hat)
      x_hat = normalize(x_hat)
      out_src = D(x_hat)
      d_loss_gp = gradient_penalty(out_src, x_hat)

      d_loss = d_loss_real + d_loss_fake + lambda_gp * d_loss_gp
      # c_loss = lambda_cls * d_loss_cls

      D_optim.zero_grad()
      C_optim.zero_grad()

      d_loss.backward()
      torch.nn.utils.clip_grad_norm_(D.parameters(), clipping_value)
      D_optim.step()

      tag = "D-step"
      # Keep training D while its mean score on generated data is non-negative.
      mode_flag = (d_loss_fake >= 0)
      # mode_flag = (gen_count < 10)
      # c_loss.backward()
      # C_optim.step()

    else:
      #### TRAINING THE GENERATOR
      ## Fooling the Discriminator loss
      x_fake = G(x_real, c_trg)
      # d_in = spec_xform(x_fake)
      # d_in = normalize(d_in)
      # out_src = D(d_in)
      # out_cls = C(d_in)
      # g_loss_fake = - torch.mean(out_src)
      # g_loss_cls = classification_loss(out_cls, label_trg)

      ## Reconstruction loss
      # Cycle back to the original label and compare with the real input (L1).
      x_recon = G(x_fake, c_org)
      # Divides by the requested batch_size, even if this batch holds fewer samples.
      g_loss_recon = torch.sum(torch.abs(x_real - x_recon)) / batch_size

      # g_loss = g_loss_recon * lambda_recon + g_loss_fake + lambda_cls * g_loss_cls
      g_loss = g_loss_recon * lambda_recon

      G_optim.zero_grad()
      D_optim.zero_grad()
      C_optim.zero_grad()

      g_loss.backward()
      # torch.nn.utils.clip_grad_norm_(G.parameters(), clipping_value)
      G_optim.step()

      tag = "G-step"
      # Only the reconstruction term is active, so g_loss is non-negative for
      # lambda_recon >= 0 and this stays False (the loop keeps doing G-steps).
      mode_flag = (g_loss < 0)
      # mode_flag = (gen_count < 10)

    # One log line per iteration; losses not updated in this step show their
    # previous value (initially 0).
    print(f"{tag}  " + 
          f"D_loss_real: {d_loss_real:.4f}, " +
          f"D_loss_fake: {d_loss_fake:.4f}, " + 
          f"G_loss_fake: {g_loss_fake:.4f}, " +
          f"G_loss_cls:  {g_loss_cls:.4f},  " +
          f"G_loss_recon: {g_loss_recon:.4f}")

    file_count += batch_size
    gen_count = (gen_count + 1) % 20
    # Wraps to 0 every 20 iterations, which triggers the checkpoint at the loop top.
    save_count = (save_count + 1) % 20


## How to interpret the loss

`D_loss_real` $\ll 0$ (ideally)

`D_loss_fake` $\ll 0$ (ideally)

`G_loss_fake` $\ll 0$ (ideally)

`G_loss_recon` must be small

`G_loss_cls` must be small

In [None]:
# Start training; this call does not return until the kernel is interrupted.
# NOTE(review): lambda_cls is currently only referenced by commented-out code in train().
train(batch_size=64, lambda_cls=10, lambda_recon=0.1)

G-step  D_loss_real: 0.0000, D_loss_fake: 0.0000, G_loss_fake: 0.0000, G_loss_cls:  0.0000,  G_loss_recon: 44205.3203


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f790ace2e60>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1324, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1316, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f790ace2e60>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1324, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1316, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/pytho

G-step  D_loss_real: 0.0000, D_loss_fake: 0.0000, G_loss_fake: 0.0000, G_loss_cls:  0.0000,  G_loss_recon: 35041.8516
G-step  D_loss_real: 0.0000, D_loss_fake: 0.0000, G_loss_fake: 0.0000, G_loss_cls:  0.0000,  G_loss_recon: 37419.9219
G-step  D_loss_real: 0.0000, D_loss_fake: 0.0000, G_loss_fake: 0.0000, G_loss_cls:  0.0000,  G_loss_recon: 37034.9062
G-step  D_loss_real: 0.0000, D_loss_fake: 0.0000, G_loss_fake: 0.0000, G_loss_cls:  0.0000,  G_loss_recon: 41765.1328
G-step  D_loss_real: 0.0000, D_loss_fake: 0.0000, G_loss_fake: 0.0000, G_loss_cls:  0.0000,  G_loss_recon: 35421.6797
G-step  D_loss_real: 0.0000, D_loss_fake: 0.0000, G_loss_fake: 0.0000, G_loss_cls:  0.0000,  G_loss_recon: 38014.3711
G-step  D_loss_real: 0.0000, D_loss_fake: 0.0000, G_loss_fake: 0.0000, G_loss_cls:  0.0000,  G_loss_recon: 35879.1055
G-step  D_loss_real: 0.0000, D_loss_fake: 0.0000, G_loss_fake: 0.0000, G_loss_cls:  0.0000,  G_loss_recon: 47717.0078
G-step  D_loss_real: 0.0000, D_loss_fake: 0.0000, G_loss