In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable
from torch.nn.utils import weight_norm

In [2]:
%run utils.ipynb
#from utils import *
%run config.ipynb
#from config import *
%run image_model.ipynb
#from image_model import *

In [3]:
class GenerateNet(nn.Module):
    """Label-conditioned generator that emits a binarised image and classifies it.

    Pipeline:
    1. A label id is embedded.
    2. Two fully connected layers expand it into a
       ``4 x img_size x img_size`` feature map.
    3. The map passes through ``block1``/``block2`` (``InceptionBlock``) and a
       3x3 convolution, giving a single-channel sigmoid map.
    4. The map is binarised with a straight-through estimator and scaled to
       ``{0, 255}``.
    5. The result is passed to ``self.classification``.

    ``InceptionBlock`` and ``ImageNet`` are not defined here. They are presumably
    provided by the ``%run`` notebooks (``image_model.ipynb``).

    Args:
        num_class: Number of rows in the label embedding table (valid ids are
            ``0 .. num_class - 1``).
        img_size: Height and width of the generated image.
        embed_dim: Size of the label embedding.
        hidden_dim: Output width of ``fc1``.
        dropout: Dropout probability used after the embedding and after ``bn2``.

    The defaults reproduce the original hard-coded architecture (100 labels,
    28x28 images, 1024-d embedding, 512 hidden units, dropout 0.1). Layers are
    created in the same order with the same attribute names, so
    ``GenerateNet()`` has the same ``state_dict`` keys and shapes as before.
    """

    def __init__(self, num_class=100, img_size=28, embed_dim=1024,
                 hidden_dim=512, dropout=0.1):
        super().__init__()
        self.num_class = num_class
        self.img_size = img_size
        self.init_channels = 4  # channels of the map reshaped from fc2's output

        self.drop = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

        # NOTE(review): padding_idx=0 fixes label 0's embedding at the zero
        # vector and blocks its gradient. If 0 is a real class id this is
        # probably unintended -- confirm with the data pipeline.
        self.embed_layer = nn.Embedding(num_class, embed_dim, padding_idx=0)

        feat_size = self.init_channels * img_size * img_size
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, feat_size)
        self.bn2 = nn.BatchNorm1d(feat_size)

        self.block1 = InceptionBlock(self.init_channels, 8)
        self.block2 = InceptionBlock(8, 16)

        self.pose_conv = nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1)

        self.classification = ImageNet()

    def forward(self, input_label):
        """Generate a binarised image for each label and classify it.

        Args:
            input_label: integer tensor of label ids. Any shape is accepted
                (contiguous or not); it is flattened to ``(batch,)``.

        Returns:
            ``(image, predict, l2_loss)``:
            - ``image``: ``(batch, 1, img_size, img_size)`` float tensor whose
              values are 0 or 255. In the backward pass, gradients bypass the
              rounding (straight-through estimator).
            - ``predict``: whatever ``self.classification(image, None)`` returns.
            - ``l2_loss``: scalar sum of squares of the sigmoid map *before*
              binarisation.
        """
        x = input_label.reshape(-1)  # (batch,); reshape also handles non-contiguous input
        x = self.embed_layer(x)      # (batch, embed_dim)
        x = self.drop(x)

        # NOTE(review): there is no activation between bn1 and fc2, so in eval
        # mode fc1 -> bn1 -> fc2 -> bn2 collapses to one affine map. Confirm
        # whether a ReLU was intended here.
        x = self.fc1(x)
        x = self.bn1(x)
        x = self.fc2(x)
        x = self.bn2(x)
        x = self.drop(x)

        x = x.view(-1, self.init_channels, self.img_size, self.img_size)
        x = self.relu(self.block1(x))
        x = self.relu(self.block2(x))
        x = self.sigmoid(self.pose_conv(x))  # (batch, 1, H, W), values in (0, 1)
        l2_loss = (x ** 2).sum()

        # Straight-through estimator: the forward value becomes round(x) (0 or 1).
        # The correction term is built from detached tensors, so the gradient
        # w.r.t. x is the identity.
        hard_x = torch.round(x.detach())
        x = (hard_x - x.detach()) + x

        x = x * 255
        predict = self.classification(x, None)

        return x, predict, l2_loss