In [1]:
!pip install torchaudio

Collecting torchaudio
[?25l  Downloading https://files.pythonhosted.org/packages/aa/55/01ad9244bcd595e39cea5ce30726a7fe02fd963d07daeb136bfe7e23f0a5/torchaudio-0.8.1-cp37-cp37m-manylinux1_x86_64.whl (1.9MB)
[K     |████████████████████████████████| 1.9MB 11.1MB/s 
Installing collected packages: torchaudio
Successfully installed torchaudio-0.8.1


In [2]:
import torchaudio
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import IterableDataset, DataLoader
from torch.backends import cudnn
import numpy as np
import glob
import os
from sklearn.metrics import confusion_matrix


In [3]:
from google.colab import drive
drive.mount(mountpoint='./drive')  # Colab only: exposes Google Drive under ./drive

Mounted at ./drive


In [4]:
# Run on the GPU when Colab provides one, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')

In [5]:
DATASET_PATH = "./drive/MyDrive/cs753 dataset/stft/"
MODEL_SAVE_PATH = "./drive/MyDrive/cs753 dataset/"
MODEL_FILENAME = "stargan_stft.pt"

INSTRUMENTS = [
    "Bansuri",
    "Shehnai",
    "Santoor",
    "Sarod",
    "Violin"
]

# Label index of each instrument is its position in INSTRUMENTS. Deriving it keeps
# the two definitions from drifting apart.
INSTRUMENT_LABELS = {name: index for index, name in enumerate(INSTRUMENTS)}

# Presumably the number of examples per instrument, in INSTRUMENTS order (TODO confirm).
# The class weights used by the classification loss are the inverse class
# frequencies, i.e. total / count.
CLASS_COUNTS = np.array([
    891,
    1664,
    1122,
    765,
    1193
])
assert len(CLASS_COUNTS) == len(INSTRUMENTS), "one count per instrument expected"
WEIGHT = torch.tensor((CLASS_COUNTS / CLASS_COUNTS.sum()) ** -1).float().to(device)

files_path = os.path.join(DATASET_PATH, "*.pt")
FILES = sorted(glob.glob(files_path))
NUM_FILES = len(FILES)
# Fail fast when the path is wrong or Drive is not mounted. Otherwise training
# would only die later with an unexplained StopIteration from the data loader.
assert NUM_FILES > 0, f"no .pt files found under {DATASET_PATH!r}"

 


In [6]:
print(f"{NUM_FILES}")  # number of .pt files found under DATASET_PATH

11274


In [None]:
def normalize(data):
  """Standardise ``data`` to zero mean / unit std along one axis.

  A leading axis of size 1 is added first, and statistics are taken over axis 2 of
  that batched tensor (axis 1 of ``data``). Zero standard deviations are replaced by
  1 so that constant rows map to zeros instead of NaNs.

  Returns:
    ``(normalized, mean, std)``. The tensors are float32 on the CPU (because of
    ``torch.FloatTensor``); ``mean`` and ``std`` keep the reduced axis as size 1.
  """
  batched = data.type(torch.FloatTensor).unsqueeze(0)
  mu = batched.mean(dim=2, keepdim=True)
  sigma = batched.std(dim=2, keepdim=True)
  sigma[sigma == 0] = 1
  return (batched - mu) / sigma, mu, sigma

In [6]:
def label2onehot(labels, dim):
  """Turn a 1-D tensor of class indices into a (len(labels), dim) one-hot matrix.

  The result is created with ``torch.zeros`` (default floating dtype, CPU).
  """
  n = labels.size(0)
  onehot = torch.zeros(n, dim)
  onehot[torch.arange(n), labels.long()] = 1
  return onehot


In [7]:
def gradient_penalty(y, x):
  """WGAN-GP gradient penalty: batch mean of ``(||dy/dx||_2 - 1) ** 2``.

  Args:
    y: critic output computed from ``x``. Every element is back-propagated with
      weight 1.
    x: tensor with ``requires_grad=True``. Its first axis is treated as the batch
      axis and the norm is taken over all remaining axes.

  Returns:
    A 0-dim tensor. ``create_graph=True`` keeps the graph so that the penalty can
    itself be back-propagated into the critic's parameters.
  """
  # ones_like puts the seed gradient on y's own device/dtype instead of relying on
  # the global `device`.
  weight = torch.ones_like(y)
  dydx = torch.autograd.grad(outputs=y,
                             inputs=x,
                             grad_outputs=weight,
                             retain_graph=True,
                             create_graph=True)[0]

  dydx = dydx.reshape(dydx.size(0), -1)
  # The epsilon keeps the backward pass finite when a row of the gradient is exactly
  # zero: d sqrt(s)/ds is infinite at s = 0, which would otherwise produce NaNs.
  dydx_l2norm = torch.sqrt(torch.sum(dydx ** 2, dim=1) + 1e-12)
  return torch.mean((dydx_l2norm - 1) ** 2)


In [8]:
class DataSource:
  """Iterator over shuffled mini-batches of the ``.pt`` files listed in ``FILES``.

  Each ``__next__`` loads up to ``batch_size`` files and returns ``(X, y, y_perm)``:
  - ``X``: the files' ``d['x']`` tensors stacked with ``torch.vstack``;
  - ``y``: their ``d['y']`` labels;
  - ``y_perm``: one target label per sample, drawn uniformly from
    ``range(len(INSTRUMENTS))``. It may coincide with the source label.
  After a full pass over all files it reshuffles, rewinds and raises ``StopIteration``.
  """

  def __init__(self, batch_size=4):
    self.batch_size = batch_size
    self.reset()

  def reset(self):
    """Draw a fresh file order and rewind to the first batch."""
    # torch's RNG is used because DataLoader re-seeds it in every worker it starts,
    # so each new worker (i.e. each epoch) gets a different order.
    self.order = torch.randperm(NUM_FILES).numpy()
    self.counter = 0

  def __iter__(self):
    return self

  def __next__(self):
    if self.counter >= NUM_FILES:
      self.reset()
      raise StopIteration()
    file_indices = self.order[self.counter:self.counter + self.batch_size]
    self.counter += self.batch_size
    x_tensor_list = []
    y_tensor_list = []
    for index in file_indices:
      d = torch.load(FILES[index])
      x_tensor_list.append(d['x'])
      y_tensor_list.append(d['y'])
    X = torch.vstack(x_tensor_list)
    y = torch.tensor(y_tensor_list)
    y_perm = torch.randint(0, len(INSTRUMENTS), size=[y.shape[0]])
    return X, y, y_perm


class MyIterableDataset(IterableDataset):
  """IterableDataset that hands a freshly shuffled DataSource to the DataLoader."""

  def __init__(self, batch_size=4):
    self.source = DataSource(batch_size)

  def __iter__(self):
    # With num_workers > 0 this runs in a newly started worker process. That process
    # holds a copy of the parent's DataSource, whose order never changes in the
    # parent. Reshuffling here is what makes each `iter(dataloader)` (each epoch)
    # see a new order; the reshuffle at the end of a pass only happened inside the
    # discarded worker.
    self.source.reset()
    return self.source


In [9]:
class Backbone(nn.Module):
  """Convolutional feature extractor shared by the discriminator heads.

  ``conv1`` is a Conv3d whose depth kernel (2, stride 1, no padding) spans a
  last axis of size 2, as in the (…, 2)-shaped inputs used in this notebook. That
  leaves a depth of 1, which is squeezed away so the rest of the network runs as
  2-D convolutions. Attribute names and Sequential indices are part of the
  state_dict layout used by saved checkpoints.
  """

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv3d(2, 32, [5, 33, 2], stride=[2, 16, 1], padding=[2, 16, 0])
    layers = [
        nn.BatchNorm2d(32),
        nn.LeakyReLU(0.01),
        nn.Conv2d(32, 64, 5, stride=2, padding=2),
        nn.BatchNorm2d(64),
        nn.LeakyReLU(0.01),
        nn.MaxPool2d(3, 3),
        nn.Conv2d(64, 128, 5, stride=2, padding=2),
        nn.BatchNorm2d(128),
        nn.LeakyReLU(0.01),
    ]
    self.main = nn.Sequential(*layers)

  def forward(self, x):
    # Drop the size-1 trailing depth axis produced by conv1 before the 2-D stack.
    return self.main(self.conv1(x).squeeze(-1))


class Discriminator(nn.Module):
  """Critic plus auxiliary classifier on top of a shared Backbone.

  ``disc`` produces one realness score per sample. ``conv`` produces 5 class logits
  per sample. Both outputs are reshaped to (N, channels), which assumes the heads
  reduce the spatial size to 1x1 (true for the input size used in this notebook).
  """

  def __init__(self):
    super().__init__()
    self.bb = Backbone()
    self.conv = nn.Conv2d(128, 5, 9, stride=1)
    self.disc = nn.Conv2d(128, 1, 9, stride=1)

  def forward(self, x, classify=True):
    features = self.bb(x)
    realness = self.disc(features)
    realness = realness.view(realness.shape[0], realness.shape[1])
    # The classifier head is skipped when only the realness score is needed.
    if classify == False:
      return realness, None
    logits = self.conv(features)
    return realness, logits.view(logits.shape[0], logits.shape[1])


class ResidualBlock(nn.Module):
  """Residual block without any normalization layer.

  Two bias-free 3x3 convolutions with a LeakyReLU between them are added to the
  input, and the sum goes through a final LeakyReLU. ``dim_in`` must equal
  ``dim_out`` for the skip connection's addition to be shape-compatible.
  """

  def __init__(self, dim_in, dim_out):
    super().__init__()
    conv_a = nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False)
    conv_b = nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False)
    self.main = nn.Sequential(conv_a, nn.LeakyReLU(0.01), conv_b)
    self.relu = nn.LeakyReLU(0.01)

  def forward(self, x):
    residual = self.main(x)
    return self.relu(x + residual)


class Generator(nn.Module):
  """Conditional encoder/decoder generator.

  The target-domain code ``c`` is tiled over every position of ``x`` and
  concatenated as extra channels. ``conv1`` expects 7 input channels, i.e. a
  2-channel ``x`` plus a 5-entry code. For the input size used in this notebook the
  output has the same shape as ``x``.
  """

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv3d(7, 32, [5, 33, 2], stride=[2, 16, 1], padding=[2, 16, 0])
    self.main = nn.Sequential(
        # encoder
        nn.LeakyReLU(0.01),
        nn.Conv2d(32, 64, 5, 2, padding=2),
        nn.LeakyReLU(0.01),
        nn.Conv2d(64, 128, 5, 2, padding=2),
        nn.LeakyReLU(0.01),
        # bottleneck
        ResidualBlock(128, 128),
        ResidualBlock(128, 128),
        # decoder
        nn.ConvTranspose2d(128, 64, 5, 2, padding=2),
        nn.LeakyReLU(0.01),
        nn.ConvTranspose2d(64, 32, 5, 2, padding=2),
        nn.LeakyReLU(0.01),
    )
    self.deconv1 = nn.ConvTranspose3d(32, 2, [5, 33, 2], stride=[2, 16, 1], padding=[2, 16, 0])

  def forward(self, x, c):
    n, k = c.size(0), c.size(1)
    depth, height, width = x.size(2), x.size(3), x.size(4)
    code = c.view(n, k, 1, 1, 1).repeat(1, 1, depth, height, width)
    h = torch.cat([x, code], dim=1)
    h = self.conv1(h).squeeze(-1)   # 3-D -> 2-D: drop the size-1 depth axis
    h = self.main(h).unsqueeze(-1)  # 2-D -> 3-D for the final transposed conv
    return self.deconv1(h)


In [10]:
LEARNING_RATE = 0.001  # shared by both Adam optimizers

D = Discriminator()
G = Generator()
D_optim = torch.optim.Adam(D.parameters(), lr=LEARNING_RATE)
G_optim = torch.optim.Adam(G.parameters(), lr=LEARNING_RATE)

In [29]:
# Smoke test: push a random input of the expected shape through D and G.
# D is put in eval mode first. In training mode BatchNorm updates its running
# statistics on every forward pass, even under torch.no_grad(), so the random
# tensor would otherwise leak into D's stored statistics.
smoke_device = next(D.parameters()).device
temp = torch.randn([1, 2, 201, 1601, 2], device=smoke_device)
d_was_training, g_was_training = D.training, G.training
D.eval()
G.eval()
try:
  with torch.no_grad():
    r1, r2 = D(temp)
    # Float one-hot code for class 0, i.e. [[1, 0, 0, 0, 0]].
    g1 = G(temp, label2onehot(torch.tensor([0]), len(INSTRUMENTS)).to(smoke_device))
finally:
  D.train(d_was_training)
  G.train(g_was_training)
assert g1.shape == temp.shape, g1.shape

In [31]:
print(f"{g1.shape}")  # generator output shape from the smoke test

torch.Size([1, 2, 201, 1601, 2])


In [32]:
# Checkpoint dict of the pretrained instrument classifier.
# map_location='cpu' lets a file saved from a CUDA session load on any machine;
# load_state_dict later copies the tensors onto the model's own device.
# NOTE: torch.load unpickles arbitrary objects, so only load trusted files.
model = torch.load(os.path.join(MODEL_SAVE_PATH, "classifier_stft.pt"), map_location='cpu')

In [33]:
# strict=False is needed because the classifier has no `disc` head. Check that
# nothing else was skipped: with strict=False a renamed layer would otherwise be
# ignored silently and leave D partly randomly initialised.
load_result = D.load_state_dict(model['C-model'], strict=False)
assert set(load_result.missing_keys) <= {'disc.weight', 'disc.bias'}, load_result
assert not load_result.unexpected_keys, load_result
load_result

_IncompatibleKeys(missing_keys=['disc.weight', 'disc.bias'], unexpected_keys=[])

In [11]:
def classification_loss(logit, target, weight=None):
  """Class-weighted softmax cross-entropy, summed over the batch and divided by N.

  This deliberately divides by the number of samples. ``reduction='mean'`` would
  instead divide a weighted loss by the sum of the selected class weights.

  Args:
    logit: (N, C) unnormalised class scores.
    target: (N,) integer class indices.
    weight: optional (C,) per-class weights. Defaults to the notebook's ``WEIGHT``.

  Returns:
    0-dim tensor.
  """
  if weight is None:
    weight = WEIGHT
  # reduction='sum' is the exact replacement for the deprecated size_average=False.
  return F.cross_entropy(logit, target, weight=weight, reduction='sum') / logit.size(0)

In [12]:
EPOCH = 1  # starting epoch when no checkpoint exists; replaced when resuming
model_path = os.path.join(MODEL_SAVE_PATH, MODEL_FILENAME)
l = glob.glob(model_path)  # empty list when no checkpoint has been saved yet

In [13]:
if l:
  # Resume G, D and the epoch counter from a previous run.
  # map_location='cpu' lets a checkpoint written from a GPU session load anywhere;
  # load_state_dict copies the tensors onto each model's current device.
  # NOTE: torch.load unpickles arbitrary objects, so only load trusted files.
  checkpoints = torch.load(model_path, map_location='cpu')
  G.load_state_dict(checkpoints['G-model'])
  D.load_state_dict(checkpoints['D-model'])
  EPOCH = checkpoints['epoch']
  print("Model loaded")

Model loaded


In [None]:
print(f"{EPOCH}")  # epoch the next training run starts from

5


In [28]:
def _save_checkpoint(epoch):
  """Save the current G/D weights and the epoch counter to ``model_path``."""
  print("Saving model")
  torch.save({
      'epoch': epoch,
      'G-model': G.state_dict(),
      'D-model': D.state_dict()
  }, model_path)


def _apply_accumulated_gradients():
  """Step both optimizers on the gradients accumulated so far, then clear them."""
  print("Applying gradients")
  D_optim.step()
  G_optim.step()
  D_optim.zero_grad()
  G_optim.zero_grad()


def train(batch_size=4, lambda_gp=10, lambda_cls=1, lambda_recon=10, accumulate=1,
          num_epochs=None):
  """Train generator G and critic/classifier D with WGAN-GP + classification losses.

  Each iteration performs either a discriminator step or a generator step, chosen
  from the sign of the previous adversarial loss (``mode_flag``). Gradients from
  ``accumulate`` consecutive backward passes are summed (not averaged) before both
  optimizers step. A checkpoint is written to ``model_path`` every 20 iterations.

  Args:
    batch_size: number of files loaded per iteration.
    lambda_gp: weight of the gradient penalty in the discriminator loss.
    lambda_cls: weight of the domain-classification loss for both D and G.
    lambda_recon: currently unused; the reconstruction loss is disabled.
    accumulate: backward passes per optimizer step (must be >= 1).
    num_epochs: if given, save and return once this many passes over the data have
      completed. ``None`` (default) trains until interrupted.
  """
  if accumulate < 1:
    raise ValueError(f"accumulate must be >= 1, got {accumulate}")

  cudnn.benchmark = True
  G.to(device)
  D.to(device)

  dataloader = DataLoader(MyIterableDataset(batch_size=batch_size), num_workers=1)
  data_iter = iter(dataloader)
  epoch = EPOCH          # epoch counter stored in checkpoints
  save_count = 1         # a checkpoint is written when this wraps to 0
  backward_passes = 0    # used to decide when to step the optimizers

  d_loss_real, d_loss_fake, d_loss_cls = 0, 0, 0
  g_loss_fake, g_loss_cls, g_loss_recon = 0, 0, 0
  mode_flag = True       # True: next iteration trains D; False: trains G

  G_optim.zero_grad()
  D_optim.zero_grad()

  while True:
    if save_count == 0:
      _save_checkpoint(epoch)
    try:
      x_real, label_org, label_trg = next(data_iter)
    except StopIteration:
      epoch += 1
      if num_epochs is not None and epoch - EPOCH >= num_epochs:
        _save_checkpoint(epoch)
        return
      data_iter = iter(dataloader)
      x_real, label_org, label_trg = next(data_iter)

    # The DataLoader's default batch_size=1 wraps each DataSource batch in an extra
    # leading axis of size 1; drop it.
    x_real = x_real.squeeze(0).to(device)
    label_org = label_org.squeeze(0).to(device)
    label_trg = label_trg.squeeze(0)
    c_trg = label2onehot(label_trg, len(INSTRUMENTS)).to(device)
    label_trg = label_trg.to(device)

    if mode_flag:
      #### Discriminator step
      # Real samples: high critic score and correct source-class prediction.
      out_src, out_cls = D(x_real)
      d_loss_real = -torch.mean(out_src)
      d_loss_cls = classification_loss(out_cls, label_org)

      # Generated samples: G is not trained here, so no graph is kept for it.
      with torch.no_grad():
        x_fake = G(x_real, c_trg)
      out_src, _ = D(x_fake, classify=False)
      d_loss_fake = torch.mean(out_src)

      # Gradient penalty on random interpolates between real and fake samples.
      alpha = torch.rand(x_real.size(0), 1, 1, 1, 1, device=device)
      x_hat = (alpha * x_real.detach() + (1 - alpha) * x_fake.detach()).requires_grad_(True)
      out_src, _ = D(x_hat, classify=False)
      d_loss_gp = gradient_penalty(out_src, x_hat)

      d_loss = d_loss_real + d_loss_fake + lambda_gp * d_loss_gp + lambda_cls * d_loss_cls
      d_loss.backward()

      tag = "D-step"
      mode_flag = bool(d_loss_fake >= 0)

    else:
      #### Generator step
      # Freeze D's parameters so the generator loss does not add gradients to them.
      # Those gradients would be applied by D_optim.step() together with D's own
      # and pull D toward the generator's objective. Gradients still flow through D
      # into x_fake and hence into G.
      D.requires_grad_(False)
      try:
        x_fake = G(x_real, c_trg)
        out_src, out_cls = D(x_fake)
        g_loss_fake = -torch.mean(out_src)
        g_loss_cls = classification_loss(out_cls, label_trg)
        g_loss = g_loss_fake + lambda_cls * g_loss_cls
        g_loss.backward()
      finally:
        D.requires_grad_(True)

      tag = "G-step"
      mode_flag = bool(g_loss < 0)

    backward_passes += 1
    if backward_passes % accumulate == 0:
      _apply_accumulated_gradients()

    print(f"{tag}  " + 
          f"D_loss_real: {d_loss_real:.4f}, " +
          f"D_loss_fake: {d_loss_fake:.4f}, " + 
          f"D_loss_cls: {d_loss_cls:.4f}, " + 
          f"G_loss_fake: {g_loss_fake:.4f}, " +
          f"G_loss_cls:  {g_loss_cls:.4f},  " +
          f"G_loss_recon: {g_loss_recon:.4f}")

    save_count = (save_count + 1) % 20

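## How to interpret the loss

Each term is written so that the network it trains tries to make it as small as possible:

- `D_loss_real` $= -\mathbb{E}[D(x_{\text{real}})]$: the discriminator wants this $\ll 0$ (high scores for real samples).
- `D_loss_fake` $= \mathbb{E}[D(x_{\text{fake}})]$: the discriminator wants this $\ll 0$ (low scores for generated samples).
- `D_loss_cls`: the discriminator's classification loss on real samples. It should be small.
- `G_loss_fake` $= -\mathbb{E}[D(x_{\text{fake}})]$: the generator wants this $\ll 0$. It is the negative of the quantity in `D_loss_fake`, so the two objectives pull in opposite directions and cannot both be strongly negative at the same time.
- `G_loss_recon`: should be small. It is currently always 0 because the reconstruction loss is disabled in `train`.
- `G_loss_cls`: the classification loss of generated samples against their target instrument. It should be small.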

In [None]:
# Continue training from EPOCH; runs until interrupted.
train(
    batch_size=64,
    accumulate=5,
    lambda_cls=10,
    lambda_recon=10,
)



D-step  D_loss_real: -43.5539, D_loss_fake: -5.7336, D_loss_cls: 0.4982, G_loss_fake: 0.0000, G_loss_cls:  0.0000,  G_loss_recon: 0.0000
G-step  D_loss_real: -43.5539, D_loss_fake: -5.7336, D_loss_cls: 0.4982, G_loss_fake: 4.7542, G_loss_cls:  0.0254,  G_loss_recon: 0.0000
G-step  D_loss_real: -43.5539, D_loss_fake: -5.7336, D_loss_cls: 0.4982, G_loss_fake: -0.7236, G_loss_cls:  0.0242,  G_loss_recon: 0.0000
D-step  D_loss_real: -43.6462, D_loss_fake: -0.0915, D_loss_cls: 0.5127, G_loss_fake: -0.7236, G_loss_cls:  0.0242,  G_loss_recon: 0.0000
Applying gradients
G-step  D_loss_real: -43.6462, D_loss_fake: -0.0915, D_loss_cls: 0.5127, G_loss_fake: 0.0846, G_loss_cls:  0.0347,  G_loss_recon: 0.0000
G-step  D_loss_real: -43.6462, D_loss_fake: -0.0915, D_loss_cls: 0.5127, G_loss_fake: -3.3558, G_loss_cls:  0.0691,  G_loss_recon: 0.0000
D-step  D_loss_real: -45.8130, D_loss_fake: 3.2365, D_loss_cls: 0.2203, G_loss_fake: -3.3558, G_loss_cls:  0.0691,  G_loss_recon: 0.0000
D-step  D_loss_real