In [None]:
!pip install torchaudio

Collecting torchaudio
[?25l  Downloading https://files.pythonhosted.org/packages/aa/55/01ad9244bcd595e39cea5ce30726a7fe02fd963d07daeb136bfe7e23f0a5/torchaudio-0.8.1-cp37-cp37m-manylinux1_x86_64.whl (1.9MB)
[K     |████████████████████████████████| 1.9MB 29.4MB/s 
Installing collected packages: torchaudio
Successfully installed torchaudio-0.8.1


In [None]:
import torchaudio
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import IterableDataset, DataLoader
from torch.backends import cudnn
import numpy as np
import glob
import os
from sklearn.metrics import confusion_matrix


In [None]:
# Mount Google Drive at ./drive (Colab only). The dataset and checkpoint paths
# configured later in this notebook are located under this mount point.
from google.colab import drive
drive.mount('./drive')

Mounted at ./drive


In [None]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# --- Paths (all under the Drive mount) ---------------------------------------
DATASET_PATH = "./drive/MyDrive/cs753 dataset/stft/"
MODEL_SAVE_PATH = "./drive/MyDrive/cs753 dataset/"
MODEL_FILENAME = "stargan_stft.pt"

# --- Classes ------------------------------------------------------------------
# The position of a name in this list is its integer class label.
INSTRUMENTS = ["Bansuri", "Shehnai", "Santoor", "Sarod", "Violin"]
INSTRUMENT_LABELS = {name: label for label, name in enumerate(INSTRUMENTS)}

# --- Per-class loss weights ---------------------------------------------------
# One entry per class, in INSTRUMENTS order.
# NOTE(review): presumably per-class sample counts -- confirm where they come from.
_class_counts = np.array([891, 1664, 1122, 765, 1193])
# Weight of a class = inverse of its share of the total, so rarer classes weigh more.
_inverse_share = (_class_counts / (_class_counts.sum())) ** -1
WEIGHT = torch.tensor(_inverse_share).float().to(device)

# --- Dataset files --------------------------------------------------------------
# Every .pt file directly under DATASET_PATH, in sorted order.
files_path = os.path.join(DATASET_PATH, "*.pt")
FILES = sorted(glob.glob(files_path))
NUM_FILES = len(FILES)

 


In [None]:
# Sanity check: how many .pt files were found under DATASET_PATH.
print(NUM_FILES)

11274


In [None]:
class DataSource:
  """Iterator yielding shuffled mini-batches read from saved ``.pt`` files.

  Every file is read with ``torch.load`` and must hold a dict with the keys
  ``'x'`` (a tensor) and ``'y'`` (its label). One pass visits each file once
  in random order; the final batch of a pass may hold fewer than
  ``batch_size`` files. When a pass ends, the order is reshuffled and the
  position rewound so the same object can be iterated again.

  Args:
    batch_size: Number of files combined into one batch.
    files: Optional sequence of file paths to draw from. Defaults to the
      notebook-level ``FILES`` list.
  """

  def __init__(self, batch_size=4, files=None):
    self.files = FILES if files is None else list(files)
    self.order = np.random.permutation(len(self.files))
    self.batch_size = batch_size
    self.counter = 0

  def __iter__(self):
    # An iterator returns itself from __iter__, which makes the object usable
    # directly in ``for`` loops and with ``iter()``.
    return self

  def __next__(self):
    """Return the next ``(X, y)`` batch, or raise StopIteration at end of pass.

    Returns:
      X: The ``'x'`` tensors of the batch stacked along dim 0 (``torch.vstack``).
      y: 1-D tensor built from the ``'y'`` labels, in the same order.
    """
    if self.counter >= len(self.files):
      # End of pass: reshuffle and rewind before signalling exhaustion.
      self.order = np.random.permutation(len(self.files))
      self.counter = 0
      raise StopIteration()
    file_indices = self.order[self.counter:self.counter + self.batch_size]
    self.counter += self.batch_size
    x_tensor_list = []
    y_tensor_list = []
    for index in file_indices:
      # NOTE: torch.load unpickles its input -- only point this at trusted files.
      d = torch.load(self.files[index])
      x_tensor_list.append(d['x'])
      y_tensor_list.append(d['y'])
    X = torch.vstack(x_tensor_list)
    y = torch.tensor(y_tensor_list)
    return X, y

class MyIterableDataset(IterableDataset):
  """``IterableDataset`` wrapper around a single ``DataSource``.

  The source is created once, at construction time; every call to
  ``__iter__`` hands back that same object.

  NOTE(review): because one source instance is shared across passes, its
  shuffle state lives on that object -- when this dataset is consumed through
  DataLoader worker processes, confirm the per-epoch reshuffle behaves as
  intended.
  """

  def __init__(self, batch_size=4):
    batch_source = DataSource(batch_size=batch_size)
    self.source = batch_source

  def __iter__(self):
    batch_source = self.source
    return batch_source


In [None]:
class Backbone(nn.Module):
  """Convolutional feature extractor ending in 128 channels.

  A strided 3-D convolution is applied first; its output's last axis is
  squeezed away and the result is fed through a 2-D stack of
  BatchNorm / LeakyReLU / Conv / MaxPool layers.

  NOTE(review): the squeeze only removes the last axis when it has size 1,
  which (kernel extent 2, no padding on that axis) assumes the input's last
  axis has size 2 -- confirm against the dataset.
  """

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv3d(
        2, 32,
        kernel_size=(5, 33, 2),
        stride=(2, 16, 1),
        padding=(2, 16, 0),
    )
    # Layer order is kept fixed: the positions are part of the state_dict keys.
    layers = [
        nn.BatchNorm2d(32),
        nn.LeakyReLU(0.01),
        nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=2),
        nn.BatchNorm2d(64),
        nn.LeakyReLU(0.01),
        nn.MaxPool2d(kernel_size=3, stride=3),
        nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
        nn.BatchNorm2d(128),
        nn.LeakyReLU(0.01),
    ]
    self.main = nn.Sequential(*layers)

  def forward(self, x):
    features = self.conv1(x).squeeze(-1)
    return self.main(features)


class Classifier(nn.Module):
  """Backbone followed by a 9x9 convolution that emits 5 class scores."""

  def __init__(self):
    super().__init__()
    self.bb = Backbone()
    self.conv = nn.Conv2d(128, 5, kernel_size=9, stride=1)

  def forward(self, x):
    scores = self.conv(self.bb(x))
    n_samples, n_classes = scores.shape[0], scores.shape[1]
    # view() succeeds only when the remaining spatial extent is exactly 1x1.
    return scores.view(n_samples, n_classes)


In [None]:
# Instantiate the classifier and an Adam optimiser over all of its parameters.
C = Classifier()
optim = torch.optim.Adam(C.parameters(), lr=0.001)

In [None]:
def classification_loss(logit, target, weight=None):
  """Class-weighted softmax cross-entropy, summed and divided by batch size.

  Args:
    logit: Float tensor of raw class scores, shape (batch, num_classes).
    target: Integer tensor of class indices, shape (batch,).
    weight: Optional 1-D tensor of per-class weights. Defaults to the
      notebook-level ``WEIGHT``.

  Returns:
    Scalar tensor: the sum of the weighted per-sample losses divided by the
    number of samples. This is not a weighted mean; the divisor is the batch
    size, not the sum of the weights.
  """
  if weight is None:
    weight = WEIGHT
  # reduction='sum' is the supported spelling of the deprecated size_average=False.
  summed = F.cross_entropy(logit, target, weight=weight, reduction='sum')
  return summed / logit.size(0)

In [None]:
# Full path of the checkpoint file.
model_path = os.path.join(MODEL_SAVE_PATH, MODEL_FILENAME)
# Non-empty list when a checkpoint already exists at model_path.
l = glob.glob(model_path)
# Default starting epoch; replaced by the stored value if a checkpoint is loaded.
EPOCH = 1

In [None]:
# Resume from an existing checkpoint, if one was found.
if l:
  # map_location remaps stored tensors onto the current device, so a checkpoint
  # written on a GPU runtime still loads on a CPU-only one.
  checkpoints = torch.load(model_path, map_location=device)
  C.load_state_dict(checkpoints['C-model'])
  EPOCH = checkpoints['epoch']
  print("Model loaded")

In [None]:
# Epoch counter that training will start counting from.
print(EPOCH)

5


In [None]:
def train(batch_size=4):
  """Train the global classifier ``C`` until interrupted.

  Each iteration draws one batch, computes the weighted cross-entropy loss,
  takes one optimiser step and prints the loss. A confusion matrix for the
  current batch is printed every 10 iterations (first after the 4th), and the
  model weights plus the epoch counter are saved to ``model_path`` every 20
  iterations (first after the 19th). The loop has no exit condition.

  Args:
    batch_size: Number of files per batch, forwarded to ``MyIterableDataset``.
  """
  cudnn.benchmark = True

  C.to(device)
  # Make sure BatchNorm layers use batch statistics while training.
  C.train()

  dataloader = DataLoader(MyIterableDataset(batch_size=batch_size), num_workers=1)
  save_count = 1
  log_count = 7
  count = EPOCH
  data_iter = iter(dataloader)

  while True:
    if save_count == 0:
      print("Saving model")
      torch.save({
          'epoch': count,
          'C-model': C.state_dict(),
      }, model_path)
    try:
      x_real, label_org = next(data_iter)
    except StopIteration:
      # Dataset exhausted: count one more epoch and start a fresh pass.
      count += 1
      data_iter = iter(dataloader)
      x_real, label_org = next(data_iter)

    # The dataset already yields whole batches; the DataLoader (default
    # batch_size=1) wraps each one in an extra leading axis, removed here.
    x_real = x_real.squeeze(0).to(device)
    label_org = label_org.squeeze(0).to(device)

    out_cls = C(x_real)
    loss = classification_loss(out_cls, label_org)

    optim.zero_grad()
    # The loss is scaled by 10 for the backward pass; gradients are then
    # clipped to a total norm of 1 before the optimiser step.
    (loss*10).backward()
    torch.nn.utils.clip_grad_norm_(C.parameters(), 1)
    optim.step()

    print(f"Loss: {loss.item()}")
    if log_count == 0:
      y_true = label_org.cpu().numpy()
      y_pred = torch.argmax(out_cls, dim=1).cpu().numpy()
      # sklearn expects (y_true, y_pred): rows are true classes, columns are
      # predictions. Explicit labels keep the matrix num_classes x num_classes
      # even when a class is absent from this batch.
      class_ids = list(range(out_cls.size(1)))
      print(confusion_matrix(y_true, y_pred, labels=class_ids))

    log_count = (log_count + 1) % 10
    save_count = (save_count + 1) % 20


In [None]:
# Start training with 64 files per batch; runs until the cell is interrupted.
train(batch_size=64)



Loss: 6.059098720550537
Loss: 8.655223846435547
Loss: 7.61798620223999
Loss: 5.800319194793701
[[ 8  0  0  0  0]
 [ 0  3  0  1  0]
 [ 0  1  2  1  0]
 [ 3  0  6 11  7]
 [ 1  3  4  2 11]]
Loss: 7.318748474121094
Loss: 6.686604976654053
Loss: 6.638411998748779
Loss: 6.321233749389648
Loss: 5.254118919372559
Loss: 5.765625953674316
Loss: 5.0959553718566895
Loss: 5.871968746185303
Loss: 7.6342267990112305
Loss: 5.960012435913086
[[ 6  1  7  3  1]
 [ 0  9  1  0  0]
 [ 0  0  4  0  0]
 [ 1  0 10  8  6]
 [ 0  1  0  1  5]]
Loss: 5.677911281585693
Loss: 4.986918926239014
Loss: 6.0295281410217285
Loss: 4.89791202545166
Loss: 5.679818630218506
Saving model
Loss: 5.841158866882324
Loss: 4.188082695007324
Loss: 5.898998260498047
Loss: 7.735911846160889
Loss: 4.721609592437744
[[ 5  0  0  0  0]
 [ 0  7  2  0  0]
 [ 0  3  9  2  0]
 [ 4  2  6 12  5]
 [ 0  0  0  0  7]]
Loss: 4.335073471069336
Loss: 4.655589580535889
Loss: 5.398458957672119
Loss: 6.236058235168457
Loss: 3.789640426635742
Loss: 5.349390029