In [2]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

from tqdm.notebook import tqdm

import librosa
import librosa.display

import torch
import torch.nn as nn
from torchsummary import summary

from common_audio import *
from audio_dataset import *

# Fix the NumPy and PyTorch RNG seeds so reruns of the notebook are repeatable.
np.random.seed(0)
torch.manual_seed(0)

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Default floating-point type used by the conversion helpers in this notebook.
dtype = torch.float

print(device)


cpu


In [80]:
# Mean-reduced squared error: returns a single scalar.
mse = nn.MSELoss(reduction='mean')
# Unreduced squared error: returns one value per element, so callers can
# apply their own weighting.
mse_no_reduction = nn.MSELoss(reduction='none')

def projection(X, B):
    """Least-squares projection of ``X`` onto the span of the vectors in ``B``.

    Args:
        X: tensor to project. As used in this notebook it is a 1-D tensor.
        B: sequence of tensors (the basis vectors); they are stacked and used
            as the columns of the design matrix. As used in this notebook each
            one is 1-D with the same length as ``X``.

    Returns:
        ``(pred_coefs, residual_mag)`` where ``pred_coefs`` holds one
        coefficient per basis vector and ``residual_mag`` is the norm of
        ``X`` minus its reconstruction from those coefficients.
    """
    # Stack the basis vectors as rows, then flip so each one is a column.
    basis = torch.stack(B, axis=0).transpose(0, 1)
    basis_t = basis.transpose(0, 1)

    # Normal equations: coefs = (B^T B)^-1 B^T X
    gram_inverse = torch.inverse(basis_t @ basis)
    pred_coefs = gram_inverse @ basis_t @ X

    # Rebuild X from the basis by scaling each column by its coefficient.
    reconstruction = (basis * pred_coefs[None, :]).sum(axis=1)
    residual_mag = torch.norm(X - reconstruction)  # want this to be 0
    return pred_coefs, residual_mag

def projection_loss(X, B, true_coefs, weights, res_weight):
    """Weighted loss on the projection coefficients of ``X`` plus a residual penalty.

    Args:
        X: tensor to project; passed straight to ``projection``.
        B: sequence of basis tensors; passed straight to ``projection``.
        true_coefs: target coefficients, compared element-wise with the
            coefficients returned by ``projection``.
        weights: 1-D tensor of per-coefficient weights; it is dotted with the
            per-coefficient squared errors.
        res_weight: multiplier applied to the residual term.

    Returns:
        ``dot(weights, (pred_coefs - true_coefs)**2) + res_weight * residual_mag**2``
        as a scalar tensor.
    """
    pred_coefs, residual_mag = projection(X, B)

    # Per-coefficient squared error, combined with the caller-chosen weights.
    loss1 = torch.dot(weights, mse_no_reduction(pred_coefs, true_coefs))
    # Build the zero target from residual_mag itself so it always shares its
    # device and dtype (a bare torch.tensor(0.) is always CPU / float32).
    loss2 = res_weight * mse(residual_mag, torch.zeros_like(residual_mag))

    return loss1 + loss2


In [103]:
# Experiment: pull X towards B[0] by minimising plain MSE, and report the
# projection metrics before and after the optimisation.
torch.manual_seed(0)

z_dim = 1000
B = (10*torch.randn(z_dim), 10*torch.randn(z_dim))

# Target: X should be exactly B[0], i.e. coefficients [1, 0] on (B[0], B[1]).
true_coefs = torch.tensor([1., 0.])
weights = torch.tensor([1., 1.])
res_weight = torch.tensor(1.)

# Start from an equal mix of both basis vectors.
X = 1*B[0] + 1*B[1]

print('Before: ')
print('MSE: ', mse(X, B[0]))
print('Projection: ',projection(X, B))
print('Projection loss: ', projection_loss(X, B, true_coefs, weights, res_weight))

# Optimise X directly: it is the only parameter handed to Adam.
X.requires_grad_(True)
opt = torch.optim.Adam([X], lr=0.01)
for steps in range(1000):
    opt.zero_grad()
    # Use the notebook's own `mse` criterion (the same one used for the
    # reports). `mse_loss` is not defined in the visible cells, so relying on
    # it would depend on a star import or on stale kernel state.
    loss = mse(X, B[0])
    loss.backward()
    opt.step()

# Stop tracking gradients before computing the final report.
X.requires_grad_(False)
print('After: ')
print('MSE: ', mse(X, B[0]))
print('Projection: ',projection(X, B))
print('Projection loss: ', projection_loss(X, B, true_coefs, weights, res_weight))


Before: 
MSE:  tensor(94.3321)
Projection:  (tensor([1.0000, 1.0000]), tensor(7.9897e-05))
Projection loss:  tensor(1.)
After: 
MSE:  tensor(3.7114e-06)
Projection:  (tensor([1.0000e+00, 3.1829e-05]), tensor(0.0601))
Projection loss:  tensor(0.0036)


In [102]:
# Experiment: optimise X against the projection loss alone. Only the
# coefficient on B[1] is penalised (weights = [0, 1]) and the residual term is
# switched off (res_weight = 0).
torch.manual_seed(0)

z_dim = 1000
B = tuple(10 * torch.randn(z_dim) for _ in range(2))

true_coefs = torch.tensor([1., 0.])
weights = torch.tensor([0., 1.])
res_weight = torch.tensor(0.)

# Start from an equal mix of both basis vectors.
X = B[0] + B[1]


def report_metrics(header):
    """Print the header followed by the three metrics for the current X."""
    print(header)
    print('MSE: ', mse(X, B[0]))
    print('Projection: ', projection(X, B))
    print('Projection loss: ', projection_loss(X, B, true_coefs, weights, res_weight))


report_metrics('Before: ')

# X is the only parameter handed to the optimiser.
X.requires_grad_(True)
opt = torch.optim.Adam([X], lr=0.01)
for steps in range(1000):
    opt.zero_grad()
    loss = projection_loss(X, B, true_coefs, weights, res_weight)
    loss.backward()
    opt.step()
X.requires_grad_(False)

report_metrics('After: ')


Before: 
MSE:  tensor(94.3321)
Projection:  (tensor([1.0000, 1.0000]), tensor(7.9897e-05))
Projection loss:  tensor(1.)
After: 
MSE:  tensor(56.2494)
Projection:  (tensor([9.8108e-01, 4.3306e-08]), tensor(237.0898))
Projection loss:  tensor(1.8755e-15)


In [56]:
# utility
def to_torch(*arrays, device=device, dtype=dtype):
    """Move one or more arrays/tensors to ``device`` with the given ``dtype``.

    NumPy arrays are first wrapped with ``torch.from_numpy``; every other
    input is used as-is and must provide ``.to(device, dtype)``.

    Args:
        *arrays: the objects to convert.
        device: target device; defaults to the notebook-level ``device``.
        dtype: target dtype; defaults to the notebook-level ``dtype``.

    Returns:
        The converted tensor when exactly one argument is given, otherwise a
        tuple of converted tensors in argument order (empty for no arguments).
    """
    def _convert(item):
        tensor = torch.from_numpy(item) if isinstance(item, np.ndarray) else item
        return tensor.to(device, dtype)

    if len(arrays) == 1:
        return _convert(arrays[0])
    return tuple(_convert(item) for item in arrays)

def to_np(*arrays):
    """Convert one or more tensors to NumPy arrays.

    Each tensor is detached from the autograd graph and moved to the CPU
    before ``.numpy()`` is called.

    Args:
        *arrays: the tensors to convert.

    Returns:
        A single NumPy array when exactly one argument is given, otherwise a
        tuple of arrays in argument order (empty for no arguments).
    """
    converted = tuple(tensor.detach().cpu().numpy() for tensor in arrays)
    if len(converted) == 1:
        return converted[0]
    return converted