# September 30 - M2 Implementation

In [193]:
import torch

### Goal : Write a dummy implementation which simulates the process of the semi-supervised VAE to ensure the correctness of the tensor operations to be performed.

In [194]:
# Helper functions
def print_tensor(in_tensor):
    assert in_tensor is not None
    print(in_tensor.size(), "\n", in_tensor)

---
### Unlabelled case
---

### Assume `batch_size = 2`, `num_latent_dims=4` and `num_classes=3`

In [207]:
batch_size = 2
num_latent_dims = 4
num_classes = 3

In [208]:
z_prime = torch.randn(batch_size, num_latent_dims)
print(z_prime.size(), "\n", z_prime)

torch.Size([2, 4]) 
 tensor([[ 0.1447, -0.2225,  1.3976,  1.7152],
        [-0.7281, -0.3237,  1.0932, -0.1912]])


In [209]:
z_prime = z_prime.view(-1, 1, z_prime.size(1))
print(z_prime.size(), "\n", z_prime)

torch.Size([2, 1, 4]) 
 tensor([[[ 0.1447, -0.2225,  1.3976,  1.7152]],

        [[-0.7281, -0.3237,  1.0932, -0.1912]]])


In [210]:
z_prime = z_prime + torch.zeros(batch_size, num_classes, num_latent_dims)
print(z_prime.size(), "\n", z_prime)

torch.Size([2, 3, 4]) 
 tensor([[[ 0.1447, -0.2225,  1.3976,  1.7152],
         [ 0.1447, -0.2225,  1.3976,  1.7152],
         [ 0.1447, -0.2225,  1.3976,  1.7152]],

        [[-0.7281, -0.3237,  1.0932, -0.1912],
         [-0.7281, -0.3237,  1.0932, -0.1912],
         [-0.7281, -0.3237,  1.0932, -0.1912]]])


### Finished reshaping and making 3 copies of the deterministic latent vector

In [211]:
i_3 = torch.eye(num_classes)
print(i_3.size(), "\n", i_3)

torch.Size([3, 3]) 
 tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])


In [212]:
i_3 = i_3.view(-1, num_classes, num_classes)
print(i_3.size(), "\n", i_3)

torch.Size([1, 3, 3]) 
 tensor([[[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]]])


In [213]:
i_3 = i_3 + torch.zeros(batch_size, num_classes, num_classes)
print(i_3.size(), "\n", i_3)

torch.Size([2, 3, 3]) 
 tensor([[[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]],

        [[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]]])


### Finished initializing the one-hot vectors for each of the three classes as an identity and reshaping

### Concatenate the two tensor at the latest dimension to construct the z_prime + y tensor

In [214]:
z_prime_y = torch.cat((z_prime, i_3), dim=2)
print(z_prime_y.size(), "\n", z_prime_y)

torch.Size([2, 3, 7]) 
 tensor([[[ 0.1447, -0.2225,  1.3976,  1.7152,  1.0000,  0.0000,  0.0000],
         [ 0.1447, -0.2225,  1.3976,  1.7152,  0.0000,  1.0000,  0.0000],
         [ 0.1447, -0.2225,  1.3976,  1.7152,  0.0000,  0.0000,  1.0000]],

        [[-0.7281, -0.3237,  1.0932, -0.1912,  1.0000,  0.0000,  0.0000],
         [-0.7281, -0.3237,  1.0932, -0.1912,  0.0000,  1.0000,  0.0000],
         [-0.7281, -0.3237,  1.0932, -0.1912,  0.0000,  0.0000,  1.0000]]])


### Now to use the `torch.nn.Linear` layers, collapse the 3-D tensors onto 2-D tensors where the first dimension resembles the batch dimension

In [215]:
z_prime_y = z_prime_y.view(-1, num_latent_dims+num_classes)
print(z_prime_y.size(), "\n", z_prime_y)

torch.Size([6, 7]) 
 tensor([[ 0.1447, -0.2225,  1.3976,  1.7152,  1.0000,  0.0000,  0.0000],
        [ 0.1447, -0.2225,  1.3976,  1.7152,  0.0000,  1.0000,  0.0000],
        [ 0.1447, -0.2225,  1.3976,  1.7152,  0.0000,  0.0000,  1.0000],
        [-0.7281, -0.3237,  1.0932, -0.1912,  1.0000,  0.0000,  0.0000],
        [-0.7281, -0.3237,  1.0932, -0.1912,  0.0000,  1.0000,  0.0000],
        [-0.7281, -0.3237,  1.0932, -0.1912,  0.0000,  0.0000,  1.0000]])


In [216]:
z_prime_collapsed = z_prime.view(-1, num_latent_dims)
print(z_prime_collapsed.size(), "\n", z_prime_collapsed)

torch.Size([6, 4]) 
 tensor([[ 0.1447, -0.2225,  1.3976,  1.7152],
        [ 0.1447, -0.2225,  1.3976,  1.7152],
        [ 0.1447, -0.2225,  1.3976,  1.7152],
        [-0.7281, -0.3237,  1.0932, -0.1912],
        [-0.7281, -0.3237,  1.0932, -0.1912],
        [-0.7281, -0.3237,  1.0932, -0.1912]])


### Define and apply the reparameterization layers corresponding to the new dimensions

In [217]:
mu = torch.nn.Linear(num_latent_dims+num_classes, num_latent_dims)
logvar = torch.nn.Linear(num_latent_dims, num_latent_dims)

In [218]:
z_mu = mu(z_prime_y)
z_logvar = logvar(z_prime_collapsed)

print(z_mu.size(), "\n", z_mu)
print(z_logvar.size(), "\n", z_logvar)

torch.Size([6, 4]) 
 tensor([[-0.0693,  0.0369,  0.8561,  1.0006],
        [-0.4374,  0.0777,  0.7816,  1.0024],
        [-0.3745, -0.0354,  0.6796,  1.0796],
        [ 0.2304, -0.1905,  0.4715,  0.6405],
        [-0.1378, -0.1498,  0.3970,  0.6424],
        [-0.0749, -0.2629,  0.2950,  0.7195]], grad_fn=<AddmmBackward>)
torch.Size([6, 4]) 
 tensor([[-0.9302, -1.0325, -0.3666, -0.8098],
        [-0.9302, -1.0325, -0.3666, -0.8098],
        [-0.9302, -1.0325, -0.3666, -0.8098],
        [-0.5283, -0.4938, -0.7802, -0.6109],
        [-0.5283, -0.4938, -0.7802, -0.6109],
        [-0.5283, -0.4938, -0.7802, -0.6109]], grad_fn=<AddmmBackward>)


In [219]:
# Reparameterization trick
std = z_logvar.mul(0.5).exp()
eps = std.new(std.size()).normal_()
z = eps.mul(std).add(z_mu)

In [220]:
print(z.size(), "\n", z)

torch.Size([6, 4]) 
 tensor([[-0.4760,  0.8756,  1.1344,  0.6945],
        [-1.1555,  1.3650, -0.4837,  0.8919],
        [-0.2572, -0.3965,  1.1423,  0.7819],
        [ 0.8803,  0.2928,  0.3984,  0.2667],
        [ 0.4242,  0.7377,  1.1044,  0.3319],
        [ 0.5081,  0.0090,  0.9278,  0.8481]], grad_fn=<AddBackward0>)


In [221]:
print_tensor(i_3)

torch.Size([2, 3, 3]) 
 tensor([[[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]],

        [[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]]])


In [222]:
print_tensor(i_3.view(-1, num_classes))

torch.Size([6, 3]) 
 tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])


In [224]:
z = torch.cat((z, i_3.view(-1, num_classes)), dim=1)
print_tensor(z)

torch.Size([6, 7]) 
 tensor([[-0.4760,  0.8756,  1.1344,  0.6945,  1.0000,  0.0000,  0.0000],
        [-1.1555,  1.3650, -0.4837,  0.8919,  0.0000,  1.0000,  0.0000],
        [-0.2572, -0.3965,  1.1423,  0.7819,  0.0000,  0.0000,  1.0000],
        [ 0.8803,  0.2928,  0.3984,  0.2667,  1.0000,  0.0000,  0.0000],
        [ 0.4242,  0.7377,  1.1044,  0.3319,  0.0000,  1.0000,  0.0000],
        [ 0.5081,  0.0090,  0.9278,  0.8481,  0.0000,  0.0000,  1.0000]],
       grad_fn=<CatBackward>)


### Assue for simplicity that the decoder is just an identity, then implement the correct evaluation of the MSE loss

In [109]:
x_prime = z
print(x_prime.size(), "\n", x_prime)

torch.Size([6, 4]) 
 tensor([[ 1.3055,  0.9925,  0.2280, -0.0165],
        [ 0.2937,  0.5772,  1.2061,  0.9281],
        [ 0.5168, -0.0194,  1.0562,  0.6417],
        [-1.2945, -0.3507, -0.9659,  1.9311],
        [-0.5934, -1.3755, -1.9839,  1.1759],
        [-0.2052, -0.7946, -0.0683,  1.1974]], grad_fn=<AddBackward0>)


In [110]:
x = torch.randn(2,4)
print(x.size(), "\n", x)

torch.Size([2, 4]) 
 tensor([[-1.0631,  0.9716, -0.3089,  2.3596],
        [ 1.6066,  1.3053,  0.0722, -1.4586]])


In [111]:
x = x.view(2,1,4) + torch.zeros(2,3,4)
print(x.size(), "\n", x)

torch.Size([2, 3, 4]) 
 tensor([[[-1.0631,  0.9716, -0.3089,  2.3596],
         [-1.0631,  0.9716, -0.3089,  2.3596],
         [-1.0631,  0.9716, -0.3089,  2.3596]],

        [[ 1.6066,  1.3053,  0.0722, -1.4586],
         [ 1.6066,  1.3053,  0.0722, -1.4586],
         [ 1.6066,  1.3053,  0.0722, -1.4586]]])


In [112]:
x = x.view(6,4)
print(x.size(), "\n", x)

torch.Size([6, 4]) 
 tensor([[-1.0631,  0.9716, -0.3089,  2.3596],
        [-1.0631,  0.9716, -0.3089,  2.3596],
        [-1.0631,  0.9716, -0.3089,  2.3596],
        [ 1.6066,  1.3053,  0.0722, -1.4586],
        [ 1.6066,  1.3053,  0.0722, -1.4586],
        [ 1.6066,  1.3053,  0.0722, -1.4586]])


In [113]:
mse_loss = torch.nn.MSELoss(reduction='none')

In [114]:
mse_x_x_prime = mse_loss(x, x_prime)
print(mse_x_x_prime.size(), "\n", mse_x_x_prime)

torch.Size([6, 4]) 
 tensor([[5.6102e+00, 4.3468e-04, 2.8829e-01, 5.6457e+00],
        [1.8407e+00, 1.5558e-01, 2.2951e+00, 2.0491e+00],
        [2.4961e+00, 9.8208e-01, 1.8635e+00, 2.9509e+00],
        [8.4165e+00, 2.7423e+00, 1.0778e+00, 1.1490e+01],
        [4.8401e+00, 7.1869e+00, 4.2278e+00, 6.9407e+00],
        [3.2826e+00, 4.4096e+00, 1.9733e-02, 7.0544e+00]],
       grad_fn=<PowBackward0>)


In [115]:
mse_x_x_prime = torch.sum(mse_x_x_prime, dim=1)
print(mse_x_x_prime.size(), "\n", mse_x_x_prime)

torch.Size([6]) 
 tensor([11.5446,  6.3405,  8.2926, 23.7269, 23.1955, 14.7664],
       grad_fn=<SumBackward2>)


In [116]:
mse_x_x_prime = mse_x_x_prime.view(batch_size, num_classes, 1)
print(mse_x_x_prime.size(), "\n", mse_x_x_prime)

torch.Size([2, 3, 1]) 
 tensor([[[11.5446],
         [ 6.3405],
         [ 8.2926]],

        [[23.7269],
         [23.1955],
         [14.7664]]], grad_fn=<ViewBackward>)


In [117]:
batch_mse_loss = mse_x_x_prime

### Similarility computing the KL divergence loss using the mu and logvar vectors

In [118]:
print(z_mu.size(), "\n", z_mu)
print(z_logvar.size(), "\n", z_logvar)

torch.Size([6, 4]) 
 tensor([[ 0.1198, -0.7514, -0.3353,  0.5805],
        [-0.3165, -0.4047, -0.2002,  0.9048],
        [-0.0194, -0.1035, -0.4782,  0.7672],
        [-0.2469, -0.3395, -0.6040,  0.7591],
        [-0.6832,  0.0071, -0.4688,  1.0834],
        [-0.3861,  0.3083, -0.7468,  0.9458]], grad_fn=<AddmmBackward>)
torch.Size([6, 4]) 
 tensor([[ 0.1019,  0.5624,  0.4668, -0.3781],
        [ 0.1019,  0.5624,  0.4668, -0.3781],
        [ 0.1019,  0.5624,  0.4668, -0.3781],
        [ 0.4031,  0.0436,  0.0706,  0.0051],
        [ 0.4031,  0.0436,  0.0706,  0.0051],
        [ 0.4031,  0.0436,  0.0706,  0.0051]], grad_fn=<AddmmBackward>)


In [120]:
batch_kl_loss = -0.5 * torch.sum(1 + z_logvar - z_mu.pow(2) - z_logvar.exp(), dim=1)

In [121]:
batch_kl_loss = batch_kl_loss.view(batch_size, num_classes, 1)
print(batch_kl_loss.size(), "\n", batch_kl_loss)

torch.Size([2, 3, 1]) 
 tensor([[[0.7088],
         [0.7559],
         [0.6088]],

        [[0.6071],
         [0.9786],
         [0.8966]]], grad_fn=<ViewBackward>)


### Add the batch MSE loss and batch KL loss

In [122]:
batch_total_loss = batch_mse_loss + batch_kl_loss

In [123]:
print(batch_total_loss.size(), "\n", batch_total_loss)

torch.Size([2, 3, 1]) 
 tensor([[[12.2534],
         [ 7.0964],
         [ 8.9014]],

        [[24.3340],
         [24.1741],
         [15.6630]]], grad_fn=<AddBackward0>)


### Weight the loss terms by the corresponding $q_{\phi}(y|x)$ for each $y \in Y$. Assume the $\pi_{\phi}(x)$ vector is passed as batch_size $\times$ num_classes

In [147]:
pi_x = torch.randn(batch_size, num_classes)
print(pi_x.size(), "\n", pi_x)

torch.Size([2, 3]) 
 tensor([[ 1.1621,  0.5065, -0.5819],
        [-0.8934,  0.9607,  1.1394]])


In [148]:
softmax = torch.nn.Softmax(dim=1)
pi_x = softmax(pi_x)
print(pi_x.size(), "\n", pi_x)

torch.Size([2, 3]) 
 tensor([[0.5903, 0.3065, 0.1032],
        [0.0666, 0.4251, 0.5083]])


In [149]:
pi_x = pi_x.view(batch_size, 1, num_classes)
print(pi_x.size(), "\n", pi_x)

torch.Size([2, 1, 3]) 
 tensor([[[0.5903, 0.3065, 0.1032]],

        [[0.0666, 0.4251, 0.5083]]])


In [151]:
weighted_loss = torch.bmm(pi_x, batch_total_loss).view(-1)
print(weighted_loss.size(), "\n", weighted_loss)

torch.Size([2]) 
 tensor([10.3271, 19.8585], grad_fn=<ViewBackward>)


In [152]:
final_weighted_loss = torch.mean(weighted_loss)
print(final_weighted_loss.size(), "\n", final_weighted_loss)

torch.Size([]) 
 tensor(15.0928, grad_fn=<MeanBackward0>)


### For the unlabelled case, need to calculate the entropy $$\mathcal{H}(q_{\phi}(y|x) = \sum_{y}q_{\phi}(y|x)logq_{\phi}(y|x)$$

In [128]:
print(pi_x.size(), "\n", pi_x)

torch.Size([2, 1, 3]) 
 tensor([[[0.1962, 0.2579, 0.5459]],

        [[0.3447, 0.4228, 0.2325]]])


### Using Entropy loss from https://discuss.pytorch.org/t/calculating-the-entropy-loss/14510

In [132]:
class HLoss(torch.nn.Module):
    def __init__(self):
        super(HLoss, self).__init__()

    def forward(self, x):
        b = torch.nn.functional.softmax(x, dim=1) * torch.nn.functional.log_softmax(x, dim=1)
        print(b.size())
        b = -1.0 * torch.sum(b, dim=1)
        return b

In [133]:
pi_x = torch.randn(batch_size, num_classes)
print(pi_x.size(), "\n", pi_x)

torch.Size([2, 3]) 
 tensor([[-0.3424, -0.1135, -0.8758],
        [-0.2473,  0.7513,  1.0534]])


In [134]:
hloss = HLoss()
batch_h = hloss(pi_x)

print(batch_h.size(), "\n", batch_h)

torch.Size([2, 3])
torch.Size([2]) 
 tensor([1.0540, 0.9860])


In [137]:
softmax_pi = torch.nn.functional.softmax(pi_x, dim=1)
print(softmax_pi.size(), "\n", softmax_pi)

torch.Size([2, 3]) 
 tensor([[0.3516, 0.4421, 0.2063],
        [0.1354, 0.3675, 0.4971]])


In [138]:
log_softmax_pi = torch.nn.functional.log_softmax(pi_x, dim=1)
print(log_softmax_pi.size(), "\n", log_softmax_pi)

torch.Size([2, 3]) 
 tensor([[-1.0452, -0.8162, -1.5785],
        [-1.9996, -1.0011, -0.6989]])


In [144]:
h = softmax_pi * log_softmax_pi
print(h.size(), "n", h)

torch.Size([2, 3]) n tensor([[-0.3675, -0.3609, -0.3256],
        [-0.2707, -0.3679, -0.3474]])


In [145]:
h = torch.sum(h, dim=1)
print(h.size(), "n", h)

torch.Size([2]) n tensor([-1.0540, -0.9860])


In [153]:
final_h_loss = torch.mean(h)
print(final_h_loss.size(), "\n", final_h_loss)

torch.Size([]) 
 tensor(-1.0200)


In [189]:
unlabelled_loss = final_weighted_loss + final_h_loss
print_tensor(unlabelled_loss)

torch.Size([]) 
 tensor(14.0728, grad_fn=<AddBackward0>)


---
### Labelled case
---

### Assume `batch_size = 2`, `num_latent_dims=4` and `num_classes=3`

In [238]:
batch_size = 2
num_latent_dims = 4
num_classes = 3

In [239]:
z_prime = torch.randn(batch_size, num_latent_dims)
print_tensor(z_prime)

torch.Size([2, 4]) 
 tensor([[-1.0374,  1.0577,  1.8441, -1.4078],
        [ 0.5333, -1.4268, -0.3109,  0.9253]])


In [240]:
labels = torch.tensor([1,2])
print_tensor(labels)

torch.Size([2]) 
 tensor([1, 2])


In [241]:
y_onehot = torch.zeros(batch_size, num_classes)
print_tensor(y_onehot)

torch.Size([2, 3]) 
 tensor([[0., 0., 0.],
        [0., 0., 0.]])


In [242]:
y_onehot = y_onehot.scatter(1, labels.reshape(-1,1), 1)
print_tensor(y_onehot)

torch.Size([2, 3]) 
 tensor([[0., 1., 0.],
        [0., 0., 1.]])


In [243]:
z_prime_y = torch.cat((z_prime, y_onehot), dim=1)
print_tensor(z_prime_y)

torch.Size([2, 7]) 
 tensor([[-1.0374,  1.0577,  1.8441, -1.4078,  0.0000,  1.0000,  0.0000],
        [ 0.5333, -1.4268, -0.3109,  0.9253,  0.0000,  0.0000,  1.0000]])


In [244]:
mu = torch.nn.Linear(num_latent_dims+num_classes, num_latent_dims)
logvar = torch.nn.Linear(num_latent_dims, num_latent_dims)

In [245]:
z_mu = mu(z_prime_y)
z_logvar = logvar(z_prime)

print(z_mu.size(), "\n", z_mu)
print(z_logvar.size(), "\n", z_logvar)

torch.Size([2, 4]) 
 tensor([[ 0.7054, -0.2334, -0.5250,  1.4733],
        [-0.4138,  0.2683,  0.2979, -0.4144]], grad_fn=<AddmmBackward>)
torch.Size([2, 4]) 
 tensor([[ 1.6079,  0.0586,  0.1093,  0.6846],
        [ 0.1963, -0.5290, -0.3938, -0.2295]], grad_fn=<AddmmBackward>)


In [246]:
# Reparameterization trick
std = z_logvar.mul(0.5).exp()
eps = std.new(std.size()).normal_()
z = eps.mul(std).add(z_mu)

print_tensor(z)

torch.Size([2, 4]) 
 tensor([[ 0.9510,  1.2275,  0.0739,  0.6057],
        [-0.8320,  0.4875,  1.2786, -0.2984]], grad_fn=<AddBackward0>)


In [247]:
z = torch.cat((z, y_onehot), dim=1)

In [248]:
print_tensor(z)

torch.Size([2, 7]) 
 tensor([[ 0.9510,  1.2275,  0.0739,  0.6057,  0.0000,  1.0000,  0.0000],
        [-0.8320,  0.4875,  1.2786, -0.2984,  0.0000,  0.0000,  1.0000]],
       grad_fn=<CatBackward>)


## Calculate the loss for the labelled data

In [249]:
batch_kl_loss = -0.5 * torch.sum(1 + z_logvar - z_mu.pow(2) - z_logvar.exp(), dim=1)
print_tensor(batch_kl_loss)

torch.Size([2]) 
 tensor([2.8445, 0.3676], grad_fn=<MulBackward0>)


In [251]:
batch_mse_loss = torch.nn.MSELoss(reduction="sum")(torch.randn(2,7), z)
print_tensor(batch_mse_loss)

torch.Size([]) 
 tensor(34.1392, grad_fn=<SumBackward0>)


In [252]:
mse_loss = batch_mse_loss/batch_size
kl_loss = torch.sum(batch_kl_loss, dim=0)

vae_loss = mse_loss + kl_loss
print_tensor(vae_loss)

torch.Size([]) 
 tensor(20.2816, grad_fn=<AddBackward0>)


In [181]:
pi_x = torch.randn(batch_size, num_classes)
print_tensor(pi_x)

torch.Size([2, 3]) 
 tensor([[-1.4511,  0.7269,  2.6331],
        [ 1.0599, -0.1613,  0.3179]])


In [185]:
ce_loss = torch.nn.CrossEntropyLoss(reduction="none")(pi_x, labels)
print_tensor(ce_loss)

torch.Size([2]) 
 tensor([2.0593, 1.3136])


In [186]:
ce_loss = torch.nn.CrossEntropyLoss()(pi_x, labels)
print_tensor(ce_loss)

torch.Size([]) 
 tensor(1.6864)


In [187]:
loss = vae_loss + ce_loss
print_tensor(loss)

torch.Size([]) 
 tensor(10.4115, grad_fn=<AddBackward0>)


## Miscellanous

In [253]:
batch_mse = torch.randn((4,3,1))
print_tensor(batch_mse)

torch.Size([4, 3, 1]) 
 tensor([[[-1.7813],
         [ 0.7235],
         [ 0.5495]],

        [[-1.9287],
         [ 2.2646],
         [ 1.3902]],

        [[ 0.3168],
         [ 0.5489],
         [-0.7260]],

        [[ 0.0548],
         [ 1.0034],
         [-0.7268]]])


In [255]:
mse = torch.mean(batch_mse, dim=1)
print_tensor(mse)

torch.Size([4, 1]) 
 tensor([[-0.1695],
        [ 0.5753],
        [ 0.0466],
        [ 0.1105]])


In [256]:
mse = torch.mean(mse)
print_tensor(mse)

torch.Size([]) 
 tensor(0.1407)
