In [1]:
import numpy as np
import matplotlib.pyplot as plt
from numpy import load
import torch

In [2]:
def to_one_hot(x, num_classes=None):
    """Encode integer class labels as a one-hot matrix.

    Args:
        x: Array-like of non-negative integer labels. Any shape is accepted;
            it is flattened before encoding.
        num_classes: Number of columns in the result. When omitted it is
            inferred as ``max(x) + 1``, which only matches across data splits
            if every split happens to contain the largest label. Pass it
            explicitly to get a fixed width.

    Returns:
        A float ``numpy.ndarray`` of shape ``(x.size, num_classes)`` with a
        single 1 per row.

    Raises:
        ValueError: If a label is negative or not smaller than ``num_classes``.
    """
    labels = np.asarray(x).ravel()
    if num_classes is None:
        # An empty input has no maximum; fall back to a zero-width matrix.
        num_classes = int(labels.max()) + 1 if labels.size else 0

    one_hot = np.zeros((labels.size, num_classes))
    if labels.size:
        # Negative indices would silently wrap around to the last columns.
        if labels.min() < 0 or labels.max() >= num_classes:
            raise ValueError(
                f"labels must lie in [0, {num_classes}); "
                f"got range [{labels.min()}, {labels.max()}]"
            )
        one_hot[np.arange(labels.size), labels] = 1
    return one_hot

In [3]:
# Directory holding the exported .npy arrays, relative to this notebook.
data_path = "../../data/mnist/"

# Labels are cast to int so to_one_hot can use them as column indices.
labels_test = load(f'{data_path}/y_test.npy').astype(int)
labels_train = load(f'{data_path}/y_train.npy').astype(int)

# Inputs are wrapped as tensors unchanged; labels are one-hot encoded first.
x_test = torch.Tensor(load(f'{data_path}/x_test.npy'))
x_train = torch.Tensor(load(f'{data_path}/x_train.npy'))
y_test = torch.Tensor(to_one_hot(labels_test))
y_train = torch.Tensor(to_one_hot(labels_train))

In [4]:
# Quick sanity check of the loaded tensors: training shapes, the value range
# of x_test, and the first one-hot training label.
for summary in (
    x_train.shape,
    y_train.shape,
    torch.min(x_test),
    torch.max(x_test),
    y_train[0],
):
    print(summary)


torch.Size([4000, 784])
torch.Size([4000, 10])
tensor(0.)
tensor(1.)
tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])


In [5]:
def get_random(shape, min_=-0.5, max_=0.5):
    """Return a trainable float tensor with uniformly random entries.

    Args:
        shape: Iterable of dimension sizes, unpacked into the tensor constructor.
        min_: Lower bound passed to ``uniform_``.
        max_: Upper bound passed to ``uniform_``.

    Returns:
        A ``torch.FloatTensor`` of the requested shape with ``requires_grad``
        switched on.
    """
    weights = torch.FloatTensor(*shape)
    # Fill in place, then mark the tensor as a leaf that tracks gradients.
    weights.uniform_(min_, max_)
    weights.requires_grad_()
    return weights

# Model

```yaml
Generator:
  - Input: 100 (90 noise, 10 one-hot)
  - h1: 150
  - h2: 400
  - h3: 784

FeatureExtractor:
  - Input: 784 (28x28)
  - h1: 400
  - h2: 150
  - h3: 100

Discriminator:
  - Input: 100
  - h1: 50
  - h2: 1

Classifier:
  - Input: 100
  - h1: 50
  - h2: 10
```

In [6]:
class Generator(torch.nn.Module):
    """Fully connected generator mapping a 100-dim input to a 784-dim output.

    Layer sizes are 100 -> 150 -> 400 -> 784. Weights and biases are stored
    as ``torch.nn.Parameter`` so that ``parameters()``, ``state_dict()`` and
    ``.to(device)`` see them.

    NOTE(review): there is no nonlinearity between the layers, so the stack
    is equivalent to a single affine map followed by softmax; and softmax
    makes each output row sum to 1. Confirm both are intended.
    """

    def __init__(self):
        super().__init__()
        self.w1 = torch.nn.Parameter(get_random((100, 150)).detach())
        self.b1 = torch.nn.Parameter(torch.zeros(150))
        self.w2 = torch.nn.Parameter(get_random((150, 400)).detach())
        self.b2 = torch.nn.Parameter(torch.zeros(400))
        self.w3 = torch.nn.Parameter(get_random((400, 784)).detach())
        self.b3 = torch.nn.Parameter(torch.zeros(784))

    def forward(self, x, **kwargs):
        """Map inputs of shape (batch, 100) to outputs of shape (batch, 784).

        Extra keyword arguments are accepted and ignored.
        """
        h1 = x.matmul(self.w1) + self.b1
        h2 = h1.matmul(self.w2) + self.b2
        # Explicit dim: the implicit choice is deprecated and emits a warning.
        h3 = torch.nn.functional.softmax(h2.matmul(self.w3) + self.b3, dim=-1)
        return h3

    def optimize(self, lr):
        """Apply one plain gradient-descent step and clear the gradients.

        Parameters are updated in place, so references held elsewhere stay
        valid. Parameters without a gradient (not part of the last backward
        pass) are left untouched instead of raising.

        Args:
            lr: Step size multiplied with each gradient.
        """
        with torch.no_grad():
            for param in self.parameters():
                if param.grad is None:
                    continue
                param.sub_(lr * param.grad)
                # Dropping the gradient keeps the next backward() from accumulating.
                param.grad = None

In [7]:
class FeatureExtractor(torch.nn.Module):
    """Fully connected encoder mapping a 784-dim input to 100 features.

    Layer sizes are 784 -> 400 -> 150 -> 100, each followed by a sigmoid.
    Weights and biases are stored as ``torch.nn.Parameter`` so that
    ``parameters()``, ``state_dict()`` and ``.to(device)`` see them.
    """

    def __init__(self):
        super().__init__()
        self.w1 = torch.nn.Parameter(get_random((784, 400)).detach())
        self.b1 = torch.nn.Parameter(torch.zeros(400))
        self.w2 = torch.nn.Parameter(get_random((400, 150)).detach())
        self.b2 = torch.nn.Parameter(torch.zeros(150))
        self.w3 = torch.nn.Parameter(get_random((150, 100)).detach())
        self.b3 = torch.nn.Parameter(torch.zeros(100))

    def forward(self, x, **kwargs):
        """Map inputs of shape (batch, 784) to features of shape (batch, 100).

        Extra keyword arguments are accepted and ignored.
        """
        h1 = torch.sigmoid(x.matmul(self.w1) + self.b1)
        h2 = torch.sigmoid(h1.matmul(self.w2) + self.b2)
        h3 = torch.sigmoid(h2.matmul(self.w3) + self.b3)
        return h3

    def optimize(self, lr):
        """Apply one plain gradient-descent step and clear the gradients.

        Parameters are updated in place, so references held elsewhere stay
        valid. Parameters without a gradient (not part of the last backward
        pass) are left untouched instead of raising.

        Args:
            lr: Step size multiplied with each gradient.
        """
        with torch.no_grad():
            for param in self.parameters():
                if param.grad is None:
                    continue
                param.sub_(lr * param.grad)
                # Dropping the gradient keeps the next backward() from accumulating.
                param.grad = None

In [8]:
class Discriminator(torch.nn.Module):
    """Two-layer head mapping 100 features to a single sigmoid score.

    Layer sizes are 100 -> 50 -> 1. Weights and biases are stored as
    ``torch.nn.Parameter`` so that ``parameters()``, ``state_dict()`` and
    ``.to(device)`` see them.

    NOTE(review): the hidden layer has no activation, so the two layers
    collapse to one affine map before the sigmoid. Confirm this is intended.
    """

    def __init__(self):
        super().__init__()
        self.w1 = torch.nn.Parameter(get_random((100, 50)).detach())
        self.b1 = torch.nn.Parameter(torch.zeros(50))
        self.w2 = torch.nn.Parameter(get_random((50, 1)).detach())
        self.b2 = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x, **kwargs):
        """Map features of shape (batch, 100) to scores of shape (batch, 1).

        Extra keyword arguments are accepted and ignored.
        """
        h1 = x.matmul(self.w1) + self.b1
        h2 = torch.sigmoid(h1.matmul(self.w2) + self.b2)
        return h2

    def optimize(self, lr):
        """Apply one plain gradient-descent step and clear the gradients.

        Parameters are updated in place, so references held elsewhere stay
        valid. Parameters without a gradient (not part of the last backward
        pass) are left untouched instead of raising.

        Args:
            lr: Step size multiplied with each gradient.
        """
        with torch.no_grad():
            for param in self.parameters():
                if param.grad is None:
                    continue
                param.sub_(lr * param.grad)
                # Dropping the gradient keeps the next backward() from accumulating.
                param.grad = None

In [9]:
    
class Classifier(torch.nn.Module):
    """Two-layer head mapping 100 features to a 10-way softmax.

    Layer sizes are 100 -> 50 -> 10 with a sigmoid hidden layer. Weights and
    biases are stored as ``torch.nn.Parameter`` so that ``parameters()``,
    ``state_dict()`` and ``.to(device)`` see them.
    """

    def __init__(self):
        super().__init__()
        self.w1 = torch.nn.Parameter(get_random((100, 50)).detach())
        self.b1 = torch.nn.Parameter(torch.zeros(50))
        self.w2 = torch.nn.Parameter(get_random((50, 10)).detach())
        self.b2 = torch.nn.Parameter(torch.zeros(10))

    def forward(self, x, **kwargs):
        """Map features of shape (batch, 100) to probabilities of shape (batch, 10).

        Extra keyword arguments are accepted and ignored.
        """
        h1 = torch.sigmoid(x.matmul(self.w1) + self.b1)
        # torch.softmax requires dim; normalise over the class axis.
        h2 = torch.softmax(h1.matmul(self.w2) + self.b2, dim=-1)
        return h2

    def optimize(self, lr):
        """Apply one plain gradient-descent step and clear the gradients.

        Parameters are updated in place, so references held elsewhere stay
        valid. Parameters without a gradient (not part of the last backward
        pass) are left untouched instead of raising.

        Args:
            lr: Step size multiplied with each gradient.
        """
        with torch.no_grad():
            for param in self.parameters():
                if param.grad is None:
                    continue
                param.sub_(lr * param.grad)
                # Dropping the gradient keeps the next backward() from accumulating.
                param.grad = None


In [10]:
fe = FeatureExtractor()
classifier = Classifier()
discriminator = Discriminator()
generator = Generator()

In [13]:
batch_size = 5
epochs = 3
gan_lr = 0.1  # step size passed to every optimize() call in this loop

for epoch in range(epochs):
    print(f'Epoch: {epoch}')
    # Per-epoch running sums:
    # [discriminator on real, discriminator on generated, classifier on generated]
    losses = [0, 0, 0]
    for i in range(0, len(x_train), batch_size):
        start_index = i
        end_index = i+batch_size

        x_batch = x_train[start_index:end_index]
        y_batch = y_train[start_index:end_index]

        # Size the targets from the slice actually drawn (the last one may be
        # short) and shape them (n, 1) like the discriminator output;
        # binary_cross_entropy rejects an input/target size mismatch.
        n_samples = len(x_batch)
        real = torch.ones(n_samples, 1)
        fake = torch.zeros(n_samples, 1)

        # ------------ Train with real image ----------------
        features = fe(x_batch)
        discriminator_out = discriminator(features)

        loss = torch.nn.functional.binary_cross_entropy(discriminator_out, real)

        loss.backward()
        discriminator.optimize(gan_lr)
        fe.optimize(gan_lr)

        losses[0] += float(loss)

        # --------------- Train with fake image -------------------
        # Generator input: 90 uniform noise values followed by the 10-dim
        # one-hot label, built directly in torch (no NumPy round-trip).
        noise = torch.rand(n_samples, 90)
        generator_input = torch.cat((noise, y_batch), dim=1)

        generated = generator(generator_input)
        features = fe(generated)
        discriminator_out = discriminator(features)
        classifier_out = classifier(features)

        loss_discriminator = torch.nn.functional.binary_cross_entropy(discriminator_out, fake)
        loss_classifier = torch.nn.functional.binary_cross_entropy(classifier_out, y_batch)

        # NOTE(review): the generator is stepped on the same loss as the
        # discriminator (targets = fake), so it is pushed towards samples
        # that are scored as fake. An adversarial setup would normally train
        # it against the "real" target instead - confirm the intent.
        loss = loss_discriminator + loss_classifier
        loss.backward()

        discriminator.optimize(gan_lr)
        classifier.optimize(gan_lr)
        fe.optimize(gan_lr)
        generator.optimize(gan_lr)

        losses[1] += float(loss_discriminator)
        losses[2] += float(loss_classifier)

    print(f'  Losses: {losses}')


Epoch: 0
  Losses: [0.6287707017072819, 0.3818388574945857, 265.1557381749153]
Epoch: 1
  Losses: [0.10480095307352144, 0.11042106978129596, 254.13847422599792]
Epoch: 2
  Losses: [0.05033660190800404, 0.07669430607711547, 248.04603338241577]


  if sys.path[0] == '':
  if __name__ == '__main__':


Epoch: 0
 Loss: 62.56604519486427
Epoch: 1
 Loss: 53.11286461353302
Epoch: 2
 Loss: 40.12722598016262
Epoch: 3
 Loss: 30.553485609591007
Epoch: 4
 Loss: 24.46516615524888
Epoch: 5
 Loss: 20.358985599130392
Epoch: 6
 Loss: 17.420197769999504
Epoch: 7
 Loss: 15.226270627230406
Epoch: 8
 Loss: 13.521448962390423
Epoch: 9
 Loss: 12.143279342912138
tensor([[3.6764e-03, 1.4693e-02, 5.2649e-01, 1.4998e-02, 2.1276e-04, 6.2665e-03,
         3.3053e-04, 3.5649e-02, 3.9260e-01, 5.0789e-03],
        [1.3938e-05, 9.8956e-01, 1.7119e-03, 1.9275e-03, 1.8419e-04, 1.5237e-03,
         6.7148e-04, 1.0899e-03, 2.9270e-03, 3.8668e-04]])
tensor([[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]])


  if __name__ == '__main__':


In [None]:
# Supervised pass: only the feature extractor and the classifier are trained
# here, on the real training images.

batch_size = 20
lr = 0.5
epochs = 10

for epoch in range(epochs):
    print(f'Epoch: {epoch}')
    losses = 0  # running sum of the batch losses for this epoch
    for i in range(0, len(x_train), batch_size):
        start_index, end_index = i, i + batch_size
        batch = slice(start_index, end_index)
        x_batch, y_batch = x_train[batch], y_train[batch]

        # Forward through both networks, then backpropagate the BCE loss.
        features = fe(x_batch)
        classes = classifier(features)
        loss = torch.nn.functional.binary_cross_entropy(classes, y_batch)
        loss.backward()

        # Step the feature extractor first, then the classifier.
        for network in (fe, classifier):
            network.optimize(lr)

        losses += loss.detach().numpy()
    print(f' Loss: {losses}')

# Compare predictions with the one-hot targets for two test samples.
with torch.no_grad():
    index = 133
    sample = slice(index, index + 2)
    x_sample, y_sample = x_test[sample], y_test[sample]
    y_pred = classifier(fe(x_sample))
    for shown in (y_pred, y_sample):
        print(shown)