In [1]:
import torch
from models.ast_model import ASTModel
import numpy as np
import os
import torchaudio



In [2]:
def train_one_epoch(model,optimizer,training_loader,loss_fn,report_every=1000):
    """Run one optimisation pass over ``training_loader``.

    For every batch the gradients are zeroed, the model is evaluated, the
    loss is back-propagated and the optimizer takes one step.

    Args:
        model: callable mapping a batch of inputs to predictions.
        optimizer: object exposing ``zero_grad()`` and ``step()``; both are
            called once per batch.
        training_loader: iterable yielding ``(inputs, labels)`` pairs.
        loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar
            tensor that supports ``backward()`` and ``item()``.
        report_every: number of batches per progress report. Defaults to
            1000, the value that used to be hard-coded.

    Returns:
        float: mean per-batch loss of the most recently completed reporting
        window. When the epoch is shorter than one window (so no report was
        ever printed), the mean over all batches seen is returned instead of
        a misleading 0.0. An empty loader yields 0.0.

    Raises:
        ValueError: if ``report_every`` is smaller than 1.
    """
    if report_every < 1:
        raise ValueError('report_every must be >= 1, got {}'.format(report_every))

    running_loss = 0.   # loss accumulated inside the current reporting window
    epoch_loss = 0.     # loss accumulated over the whole epoch
    batches_seen = 0
    last_loss = None    # stays None until a full window has been reported

    # enumerate() gives the batch index needed for intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        batch_loss = loss.item()
        running_loss += batch_loss
        epoch_loss += batch_loss
        batches_seen += 1
        if (i + 1) % report_every == 0:
            last_loss = running_loss / report_every # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            running_loss = 0.

    if last_loss is None:
        # No full window completed: report the epoch mean rather than 0.
        last_loss = epoch_loss / batches_seen if batches_seen else 0.

    return last_loss

In [3]:
def train(epoch_nb, training_loader, validation_loader, model, optimizer, loss_fn, scheduler):
    """Train ``model`` for ``epoch_nb`` epochs, validating and checkpointing after each one.

    Each epoch runs ``train_one_epoch`` on ``training_loader``, then evaluates
    ``loss_fn`` on every batch of ``validation_loader`` with gradients
    disabled, prints both losses, and saves ``model.state_dict()`` to a file
    named ``model_<timestamp>_<epoch>`` in the current working directory
    (``<epoch>`` is zero-based, ``<timestamp>`` is taken once when training
    starts).

    Args:
        epoch_nb: number of epochs to run.
        training_loader: iterable of ``(inputs, labels)`` training batches.
        validation_loader: iterable of ``(inputs, labels)`` validation batches.
        model: ``torch.nn.Module`` to optimise.
        optimizer: optimizer forwarded to ``train_one_epoch``.
        loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar tensor.
        scheduler: optional learning-rate scheduler, stepped once per epoch
            after validation; pass ``None`` to keep the learning rate fixed.

    Returns:
        None.
    """
    from datetime import datetime

    # A single run identifier shared by every checkpoint written by this call.
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    best_vloss = float('inf')

    for epoch in range(epoch_nb):
        print('EPOCH {}:'.format(epoch + 1))

        # Make sure gradient tracking is on, and do a pass over the data
        model.train(True)
        avg_loss = train_one_epoch(model,optimizer,training_loader,loss_fn)

        # Set the model to evaluation mode, disabling dropout and using population
        # statistics for batch normalization.
        model.eval()

        running_vloss = 0.0
        vbatch_count = 0
        # Disable gradient computation and reduce memory consumption.
        with torch.no_grad():
            for vinputs, vlabels in validation_loader:
                voutputs = model(vinputs)
                # .item() keeps the accumulator a plain float instead of a tensor
                running_vloss += loss_fn(voutputs, vlabels).item()
                vbatch_count += 1

        # An empty validation loader gives NaN instead of an unbound-variable error;
        # NaN never compares as "better", so no epoch is flagged as best in that case.
        avg_vloss = running_vloss / vbatch_count if vbatch_count else float('nan')
        print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

        if scheduler is not None:
            # ReduceLROnPlateau needs the monitored metric; other schedulers take no argument.
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(avg_vloss)
            else:
                scheduler.step()

        # Track best performance
        if avg_vloss < best_vloss:
            best_vloss = avg_vloss
            print('Best model found at epoch: ' + str(epoch))

        # Save the model's state after every epoch (not only the best one)
        model_path = 'model_{}_{}'.format(timestamp, epoch)
        torch.save(model.state_dict(), model_path)
            
    

In [4]:
class AST_multichannel(torch.nn.Module):
    """Four-channel regressor built on a single shared, truncated ``ASTModel``.

    The same AST backbone is applied to each of the first four channels of
    the input; the four outputs are concatenated and passed through a
    three-layer MLP whose two outputs are squashed by a sigmoid.
    """

    def __init__(self):
        super().__init__()
        self.ast = ASTModel(audioset_pretrain=True)

        # A single shared no-op module stands in for transformer blocks 5..11
        # and for the backbone's own classification head.
        passthrough = torch.nn.Identity()
        for block_idx in range(5, 12):
            self.ast.v.blocks[block_idx] = passthrough
        self.ast.v.mlp_head = passthrough

        # Regression head: lin1 expects 4 * 768 features, i.e. presumably one
        # 768-dim embedding per channel -- TODO confirm against ASTModel.
        self.lin1 = torch.nn.Linear(768*4, 768)
        self.lin2 = torch.nn.Linear(768, 128)
        self.lin3 = torch.nn.Linear(128,2)
        self.act = torch.nn.ReLU()
        # NOTE(review): sigmoid limits outputs to (0, 1); presumably the
        # targets are normalised to that range -- confirm.
        self.sig = torch.nn.Sigmoid()

    def forward(self, x):
        """Return the two sigmoid-activated outputs for a batch ``x``.

        ``x`` is indexed as ``x[:, ch]`` for ``ch`` in 0..3, so axis 1 must
        hold at least four channels; any further channels are ignored.
        """
        per_channel = [self.ast(x[:, ch]) for ch in range(4)]
        hidden = torch.cat(per_channel, dim=1)
        for layer in (self.lin1, self.lin2):
            hidden = self.act(layer(hidden))
        return self.sig(self.lin3(hidden))

In [5]:
DATASET_PATH = "LivingRoom_preprocessed_hack/Human1"

# Centroid array stored in the dataset folder (used as regression targets later on).
centroid = np.load(os.path.join(DATASET_PATH, "centroid.npy"))
print("Shape of Centroid:", centroid.shape, sep="\n")

# Room impulse responses, memory-mapped read-only so the file is not read eagerly.
# Presumably laid out as (human locations, microphones, time samples). The earlier
# note mentioned 10 microphones, but the recorded output shows 4 on that axis -- confirm.
RIRs = np.load(os.path.join(DATASET_PATH, "deconvoled_trim.npy"), mmap_mode='r')
print("Shape of RIRs:", RIRs.shape, sep="\n")

Shape of Centroid:
(1000, 2)
Shape of RIRs:
(1000, 4, 667200)


In [6]:
# Mel-spectrogram transform; 48000 is the sample rate, every other setting is the library default.
spec = torchaudio.transforms.MelSpectrogram(sample_rate=48000)

  "At least one mel filterbank has all zero values. "


In [7]:
# Copy the memory-mapped RIRs into a tensor, then compute one mel spectrogram
# per channel (axis 1, first four entries) and stack them back along axis 1.
audio = torch.tensor(RIRs)
X_all = torch.stack([spec(audio[:, ch]) for ch in range(4)], dim=1)

In [8]:
from torch.utils.data import Dataset, DataLoader

In [9]:
class RIRDataset(Dataset):
    """Dataset pairing ``specs[idx]`` with ``centroids[idx]``.

    Both containers are stored as given; the dataset length is the size of
    the first axis of ``specs``.
    """

    def __init__(self, specs, centroids):
        self.specs, self.centroids = specs, centroids

    def __len__(self):
        n_items = self.specs.shape[0]
        return n_items

    def __getitem__(self,idx):
        features = self.specs[idx]
        target = self.centroids[idx]
        return features, target

In [10]:
# Targets as a tensor; the dtype follows the NumPy array.
# NOTE(review): if centroid is float64 this may not match the model's output dtype -- confirm.
Y_all = torch.tensor(centroid)
print(X_all.shape, Y_all.shape, sep="\n")

torch.Size([1000, 4, 128, 3337])
torch.Size([1000, 2])


In [11]:
model = AST_multichannel()

# Adam on every model parameter, mean-squared-error objective, no LR scheduler.
optim = torch.optim.Adam(params=model.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()
scheduler = None

# The first 800 examples are used for training, the rest for validation;
# the split is positional (no shuffling before slicing).
train_set, valid_set = (
    RIRDataset(X_all[rows], Y_all[rows])
    for rows in (slice(0, 800), slice(800, None))
)

# Only the training batches are shuffled.
train_loader = DataLoader(dataset=train_set, batch_size=32, shuffle=True)
valid_loader = DataLoader(dataset=valid_set, batch_size=32, shuffle=False)

---------------AST Model Summary---------------
ImageNet pretraining: True, AudioSet pretraining: True
frequncey stride=10, time stride=10
number of patches=3996


  "See the documentation of nn.Upsample for details.".format(mode)


In [None]:
# Launch the 15-epoch training run with the objects configured above.
train(
    epoch_nb=15,
    training_loader=train_loader,
    validation_loader=valid_loader,
    model=model,
    optimizer=optim,
    loss_fn=loss_fn,
    scheduler=scheduler,
)

EPOCH 1:
