In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from researchlib.single_import import *

In [3]:
class ConditionProjection(nn.Module):
    """Join a data tensor with a transformed condition tensor along dim 1.

    Args:
        f: Callable (typically an ``nn.Module``) applied to the condition
            before it is concatenated onto the data.
    """

    def __init__(self, f):
        super().__init__()
        self.f = f

    def forward(self, x):
        """Concatenate ``data`` with ``f(condition)``.

        Args:
            x: A two-element iterable ``(data, condition)``.

        Returns:
            ``torch.cat`` of ``data`` and ``self.f(condition)`` along dim 1, so
            both must agree on every dimension except dim 1.
        """
        features, cond = x
        projected = self.f(cond)
        return torch.cat((features, projected), dim=1)

In [4]:
# Condition-projection network: Linear(10 -> 512), ReLU, Linear(512 -> 784), then a
# Reshape given the target shape (-1, 1, 28, 28) (784 = 1 * 28 * 28).
# `builder` comes from the star import; presumably it chains the listed layers in
# order — TODO confirm.
# NOTE(review): `p` is not referenced by any later cell visible in this notebook;
# presumably it was meant to be wrapped in ConditionProjection — confirm or remove.
p = builder([
    nn.Linear(10, 512),
    nn.ReLU(),
    nn.Linear(512, 784),
    layer.Reshape((-1, 1, 28, 28))
])

In [5]:
# Generator network.
# Input width is 110 — presumably 100 noise dims + 10 one-hot class dims; TODO
# confirm against GANModel.
# Shape notes below are read off the `init_model` log printed further down and
# assume the logged layers run in the order listed.
g = builder([
    nn.Linear(110, 8*4*4),          # 110 -> 128 features
    layer.Reshape((-1, 8, 4, 4)),   # target shape: 8 channels of 4x4
    AutoConvTransposeNet2d(8, 3),   # log shows 8 -> 64 -> 128 -> 256 channels with three stride-2 ConvTranspose2d layers
    nn.Conv2d(256, 128, 3),         # 3x3 kernel, no padding: trims 2 px from each spatial dim
    nn.BatchNorm2d(128),
    nn.ELU(),
    nn.Conv2d(128, 1, 3),           # 3x3 kernel, no padding: down to a single channel
    nn.Sigmoid()                    # squashes outputs into (0, 1)
])  

In [6]:
# Discriminator network with an auxiliary 10-way classification head.
# NOTE(review): how `Auxiliary` routes its input/output is not visible here —
# presumably it taps the 512 features and exposes the head's output to the loss
# while passing the features through unchanged; confirm in researchlib.
d = builder([
    AutoConvNet2d(1, 4),        # conv feature extractor on 1-channel input; must yield 512 features for the Linear layers below
    Auxiliary(builder([
        nn.Linear(512, 10),     # 512 features -> 10 class scores
        nn.LogSoftmax(-1)       # log-probabilities, the input form F.nll_loss expects
    ])),
    nn.Linear(512, 1),          # 512 features -> single real/fake score
    nn.Sigmoid(),               # score squashed into (0, 1)
])

In [7]:
# MNIST training split served in mini-batches of 32.
# normalize=False — presumably leaves pixel values un-normalized so they match the
# generator's Sigmoid output range; TODO confirm in VisionDataset.
train_loader = VisionDataset(vision.MNIST, batch_size=32, train=True, normalize=False)

In [8]:
# Bundle generator `g` and discriminator `d` into a GANModel and hand it to a Runner
# together with the training loader, the 'rmsprop' optimizer name and a 'vanilla'
# GAN loss whose auxiliary term is F.nll_loss.
# NOTE(review): condition=[True, False] presumably means g receives the condition
# and d does not; the third positional argument (None) is presumably the
# validation/test loader — confirm both against the researchlib signatures.
runner = Runner(GANModel(g, d, condition=[True, False]), train_loader, None, 'rmsprop', GANLoss('vanilla', aux_loss=F.nll_loss))

In [12]:
# Apply the 'orthogonal' weight initialization; the log printed below lists each
# Linear / Conv2d / ConvTranspose2d layer that gets initialized.
runner.init_model('orthogonal')

Initialize to orthogonal: Linear(in_features=110, out_features=128, bias=True)
Initialize to orthogonal: Conv2d(8, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Initialize to orthogonal: ConvTranspose2d(64, 64, kernel_size=(2, 2), stride=(2, 2))
Initialize to orthogonal: Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Initialize to orthogonal: ConvTranspose2d(128, 128, kernel_size=(2, 2), stride=(2, 2))
Initialize to orthogonal: Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Initialize to orthogonal: ConvTranspose2d(256, 256, kernel_size=(2, 2), stride=(2, 2))
Initialize to orthogonal: Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1))
Initialize to orthogonal: Conv2d(128, 1, kernel_size=(3, 3), stride=(1, 1))
Initialize to orthogonal: Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Initialize to orthogonal: Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Initialize to orthogonal: Conv2d(128, 256, ker

In [None]:
# Train the GAN. The argument 10 is presumably the number of epochs — TODO confirm
# against Runner.fit.
runner.fit(10)

In [None]:
# Class indices 0..9, built directly on the GPU as int64 with torch.arange
# (replaces the range -> list -> numpy -> from_numpy -> .cuda() round trip, whose
# integer dtype depended on NumPy's platform default), then one-hot encoded over
# 10 classes.
# The trailing .cuda() is kept because to_one_hot's output device is not visible
# here; it is a no-op when the tensor is already on the GPU.
condition = to_one_hot(torch.arange(10, device='cuda'), 10).cuda()
# Draw 10 samples from the model, passing the one-hot rows as the condition data.
img = runner.model.sample(10, condition_data=condition)

In [None]:
import matplotlib.pyplot as plt

# A single row of ten panels, one per sampled image.
fig, arr = plt.subplots(nrows=1, ncols=10, figsize=(20, 20))
for i, panel in enumerate(arr):
    # Index 0 of sample i, detached from autograd and copied to host memory so
    # matplotlib can consume it as a NumPy array.
    pixels = img[i][0].detach().cpu().numpy()
    panel.imshow(pixels, cmap='gray')
plt.show()

In [None]:
# Show the training history recorded by the runner (presumably loss curves or a
# metrics table — confirm in researchlib). Left as the cell's last expression so
# any returned value is rendered by the notebook.
runner.history()