In [20]:
import time
import copy
import numpy as np
import pandas as pd
import seaborn as sn
from tqdm import tqdm
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings('ignore')

import sklearn
from sklearn.manifold import TSNE

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function

from models import *
from utils import *
from datasets import *

In [41]:
# Assemble the DAB from its two components: a ScagEstimator approximator and
# the ScagModule hard layer.
# NOTE(review): size_in=1024 presumably has to match the number of points fed
# to the forward pass (the notebook later slices z[:1024]) and size_out=9 the
# nine values the hard layer returns -- confirm against the DAB/ScagEstimator
# definitions.
dab = DAB(approximator=ScagEstimator(size_in=1024, size_hidden=100, size_out=9),
          hard_layer=ScagModule())

# Sum of element counts over everything `dab.parameters()` yields.
sum(p.numel() for p in dab.parameters())

206818

In [42]:
# Load the precomputed embedding from disk and convert it to a float tensor
# (the recorded output reports a shape of 70000 x 3).
z = torch.Tensor(np.load("./tsne_data_reducted_normalised.npy"))

# The original cell had a bare `z.requires_grad_` here. Without parentheses it
# only looks up the method and does nothing, so it has been dropped.
# Do NOT "fix" it by calling it: the forward cell sets `requires_grad` on the
# slice z[:1024], which is only allowed on a leaf tensor. If `z` itself
# tracked gradients the slice would be a non-leaf and that assignment would
# raise (and `.grad` would not be populated on the slice).
z.shape

torch.Size([70000, 3])

# forward

In [63]:
# Take the first 1024 points as one input batch and flag it to track
# gradients, so that backward() can populate `z_in.grad`.
z_in = z[:1024].requires_grad_(True)

# Forward pass through the DAB; the result is shown as the cell output.
out = dab(z_in)
out

tensor([[0.0042, 0.6789, 0.0377, 0.0245, 0.0765, 0.4221, 0.8106, 0.3769, 0.0036]],
       grad_fn=<ScagModuleBackward>)

In [65]:
# Reference values computed directly by scagnostics on the first two columns
# of the batch.
# NOTE(review): `scagnostics` is not imported explicitly in this notebook; it
# presumably arrives through one of the star imports -- confirm.
scag_measures = scagnostics.compute(z_in[:, 0], z_in[:, 1])

# Pack the measures, in the order `.values()` yields them, into a single row.
gt = torch.Tensor(list(scag_measures.values())).view(1, -1)
gt.requires_grad_()

tensor([[0.0042, 0.6789, 0.0377, 0.0245, 0.0765, 0.4221, 0.8106, 0.3769, 0.0036]],
       requires_grad=True)

In [66]:
# Use a scaled copy of the reference values as the regression target.
# NOTE(review): presumably doubled so that the loss against `out` is non-zero
# (the recorded `out` and `gt` are identical) -- confirm intent.
labels = 2 * gt
labels

tensor([[0.0084, 1.3578, 0.0755, 0.0489, 0.1530, 0.8441, 1.6213, 0.7538, 0.0071]],
       grad_fn=<MulBackward0>)

# dab loss

In [67]:
# Evaluate the DAB's own loss. The check cell further down reproduces the
# recorded value as the summed element-wise MSE between
# `dab.approximator_output` and `gt`.
dab.loss_function()

tensor([1.1802], grad_fn=<SumBackward1>)

In [68]:
# Inspect the approximator output held on `dab`.
# NOTE(review): presumably cached by the forward pass above -- confirm in DAB.
dab.approximator_output

tensor([[0.1903, 0.0866, 0.0286, 0.0000, 0.0326, 0.1485, 0.0000, 0.1315, 0.0000]],
       grad_fn=<ViewBackward0>)

In [69]:
# Recompute the loss by hand: element-wise squared error between the
# approximator output and the reference values, summed over all entries.
F.mse_loss(dab.approximator_output, gt, reduction='none').sum()

tensor(1.1802, grad_fn=<SumBackward0>)

# backward

In [70]:
# Sanity check before backward: both the target and the hard-layer output
# should be attached to the autograd graph.
labels.requires_grad, out.requires_grad

(True, True)

In [71]:
# The input batch must track gradients for `z_in.grad` to be filled in by
# backward().
z_in.requires_grad

True

In [72]:
# Mean-squared error between the hard-layer output and the target, followed
# by backpropagation through the DAB down to `z_in`.
# NOTE(review): the two "torch.Size(...) True" lines in the recorded output are
# not printed by this cell; presumably debug prints inside the backward
# implementation -- confirm.
loss = F.mse_loss(out, labels)
print(loss)
loss.backward()

tensor(0.1607, grad_fn=<MseLossBackward0>)
torch.Size([1024, 3]) True
torch.Size([1, 9]) True


In [75]:
# Gradient of the loss with respect to the input batch, together with its shape.
# NOTE(review): in the recorded output only column 0 carries non-zero
# gradient, although the reference values were computed from columns 0 and 1
# -- check whether the backward path is expected to ignore column 1.
z_in.grad.shape, z_in.grad

(torch.Size([1024, 3]),
 tensor([[ 0.0007,  0.0000,  0.0000],
         [-0.0019,  0.0000,  0.0000],
         [-0.0002,  0.0000,  0.0000],
         ...,
         [-0.0023,  0.0000,  0.0000],
         [-0.0007,  0.0000,  0.0000],
         [-0.0007,  0.0000,  0.0000]]))

In [77]:
# Total gradient received by the second input column.
# Uses the tensor's own reduction rather than Python's built-in sum(), which
# would iterate over the rows one element at a time.
# NOTE(review): a signed sum of zero does not by itself prove every entry is
# zero; use `.abs().sum()` if that is what needs checking.
z_in.grad[:, 1].sum()

tensor(0.)