In [1]:
import numpy as np
import matplotlib.pyplot as plt
from numpy import load
import torch

In [2]:
def to_one_hot(x):
    b = np.zeros((x.size, x.max()+1))
    b[np.arange(x.size), x] = 1
    return b

In [3]:
data_path = "../../data/mnist/"
x_test = torch.Tensor(load(f'{data_path}/x_test.npy'))
x_train = torch.Tensor(load(f'{data_path}/x_train.npy'))
y_test = torch.Tensor(to_one_hot(load(f'{data_path}/y_test.npy').astype(int)))
y_train = torch.Tensor(to_one_hot(load(f'{data_path}/y_train.npy').astype(int)))

In [4]:
print(x_train.shape)
print(y_train.shape)

print(torch.min(x_test))
print(torch.max(x_test))
print(y_train[0])


torch.Size([4000, 784])
torch.Size([4000, 10])
tensor(0.)
tensor(1.)
tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])


In [5]:
def get_random(shape, min_=-0.5, max_=0.5):
    return torch.FloatTensor(*shape).uniform_(min_, max_).requires_grad_()

# Model

```yaml
Generator:
  - Input: 100 (90 noise, 10 one-hot)
  - h1: 150
  - h2: 400
  - h3: 784

FeatureExtractor:
  - Input: 784 (28x28)
  - h1: 400
  - h2: 150
  - h3: 100

Discriminator:
  - Input: 100
  - h1: 50
  - h2: 1

Classifier:
  - Input: 100
  - h1: 50
  - h2: 10
```

In [60]:
class Generator(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w1, self.b1 = get_random((100, 150)), torch.zeros(150, requires_grad=True)
        self.w2, self.b2 = get_random((150, 400)), torch.zeros(400, requires_grad=True)
        self.w3, self.b3 = get_random((400, 784)), torch.zeros(784, requires_grad=True)
    
    def forward(self, x, **kwargs):
        # x.shape = (batch, 100)
        h1 = x.matmul(self.w1) + self.b1
        h2 = h1.matmul(self.w2) + self.b2
        h3 = torch.nn.functional.softmax(h2.matmul(self.w3) + self.b3)
        return h3
    
    def optimize(self, lr):
        self.w1 = (self.w1 - lr * self.w1.grad).detach().requires_grad_()
        self.w2 = (self.w2 - lr * self.w2.grad).detach().requires_grad_()
        self.w3 = (self.w3 - lr * self.w3.grad).detach().requires_grad_()
        
        self.b1 = (self.b1 - lr * self.b1.grad).detach().requires_grad_()
        self.b2 = (self.b2 - lr * self.b2.grad).detach().requires_grad_()
        self.b3 = (self.b3 - lr * self.b3.grad).detach().requires_grad_()

In [80]:
class FeatureExtractor(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w1, self.b1 = get_random((784, 400)), torch.zeros(400, requires_grad=True)
        self.w2, self.b2 = get_random((400, 150)), torch.zeros(150, requires_grad=True)
        self.w3, self.b3 = get_random((150, 100)), torch.zeros(100, requires_grad=True)
    
    def forward(self, x, **kwargs):
        h1 = torch.nn.functional.sigmoid(x.matmul(self.w1) + self.b1)
        h2 = torch.nn.functional.sigmoid(h1.matmul(self.w2) + self.b2)
        h3 = torch.nn.functional.sigmoid(h2.matmul(self.w3) + self.b3)
        return h3
    
    def optimize(self, lr):
        
        self.w1 = (self.w1 - lr * self.w1.grad).detach().requires_grad_()
        self.w2 = (self.w2 - lr * self.w2.grad).detach().requires_grad_()
        self.w3 = (self.w3 - lr * self.w3.grad).detach().requires_grad_()
        
        self.b1 = (self.b1 - lr * self.b1.grad).detach().requires_grad_()
        self.b2 = (self.b2 - lr * self.b2.grad).detach().requires_grad_()
        self.b3 = (self.b3 - lr * self.b3.grad).detach().requires_grad_()
        

In [81]:
class Discriminator(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w1, self.b1 = get_random((100, 50)), torch.zeros(50, requires_grad=True)
        self.w2, self.b2 = get_random((50, 1)), torch.zeros(1, requires_grad=True)
    
    def forward(self, x, **kwargs):
        h1 = x.matmul(self.w1) + self.b1
        h2 = torch.nn.functional.tanh(h1.matmul(self.w2) + self.b2)
        return h2
    
    def optimize(self, lr):
        self.w1 = (self.w1 - lr * self.w1.grad).detach().requires_grad_()
        self.w2 = (self.w2 - lr * self.w2.grad).detach().requires_grad_()
        
        self.b1 = (self.b1 - lr * self.b1.grad).detach().requires_grad_()
        self.b2 = (self.b2 - lr * self.b2.grad).detach().requires_grad_()

In [82]:
    
class Classifier(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w1, self.b1 = get_random((100, 50)), torch.zeros(50, requires_grad=True)
        self.w2, self.b2 = get_random((50, 10)), torch.zeros(10, requires_grad=True)
    
    def forward(self, x, **kwargs):
        h1 = torch.nn.functional.sigmoid(x.matmul(self.w1) + self.b1)
        h2 = torch.nn.functional.softmax(h1.matmul(self.w2) + self.b2)
        return h2
    
    def optimize(self, lr):
        self.w1 = (self.w1 - lr * self.w1.grad).detach().requires_grad_()
        self.w2 = (self.w2 - lr * self.w2.grad).detach().requires_grad_()
        
        self.b1 = (self.b1 - lr * self.b1.grad).detach().requires_grad_()
        self.b2 = (self.b2 - lr * self.b2.grad).detach().requires_grad_()


In [107]:
fe = FeatureExtractor()
classifier = Classifier()
discriminator = Discriminator()
generator = Generator()

In [110]:
batch_size = 20
lr = 0.5
epochs = 10

In [112]:
for epoch in range(epochs):
    print(f'Epoch: {epoch}')
    losses = 0
    for i in range(0, len(x_train), batch_size):
        start_index = i
        end_index = i+batch_size
        
        x_batch = x_train[start_index:end_index]
        y_batch = y_train[start_index:end_index]
        
        features = fe(x_batch)
        classes = classifier(features)
        
        loss = torch.nn.functional.binary_cross_entropy(classes, y_batch)
        loss.backward()
        
        fe.optimize(lr)
        classifier.optimize(lr)
        
        losses += loss.detach().numpy()
    print(f' Loss: {losses}')

Epoch: 0
 Loss: 0.1742004120023921
Epoch: 1
 Loss: 0.15941612530878047
Epoch: 2
 Loss: 0.14681795617070748
Epoch: 3
 Loss: 0.13594687085424084
Epoch: 4
 Loss: 0.12647330131585477
Epoch: 5
 Loss: 0.11815060551452916
Epoch: 6
 Loss: 0.11078955536504509
Epoch: 7
 Loss: 0.1042407589557115
Epoch: 8
 Loss: 0.09838390069489833
Epoch: 9
 Loss: 0.09312021703226492


  if __name__ == '__main__':
