In [1]:
import torch
from torch import nn

In [None]:
# Stacked Convolutional Auto-Encoder (the unsupervised sub-network)
class SCAE(nn.Module):
  """Stacked convolutional auto-encoder (the unsupervised sub-network).

  Encoder: conv1 -> max-pool -> ReLU -> conv2 -> ReLU.
  Decoder (mirror of the encoder): deconv1 -> max-unpool -> ReLU -> deconv2
  -> ReLU. The reconstruction has the same channel count and spatial size as
  the input.

  Args:
    num_channels: Number of channels of the input images; the reconstruction
      has the same number of channels.
  """

  def __init__(self, num_channels):
    # nn.Module.__init__ must run before any submodule is assigned.
    super().__init__()

    self.conv1 = nn.Conv2d(
        in_channels=num_channels,
        out_channels=64,
        kernel_size=11,
        padding=1,
        stride=2
    )
    self.conv2 = nn.Conv2d(
        in_channels=64,
        out_channels=128,
        kernel_size=5,
        padding=2
    )
    self.deconv1 = nn.ConvTranspose2d(
        in_channels=128,
        out_channels=64,
        kernel_size=5,
        padding=2
    )
    self.deconv2 = nn.ConvTranspose2d(
        in_channels=64,
        out_channels=num_channels,  # reconstruct every input channel
        kernel_size=11,
        padding=1,
        stride=2
    )
    # return_indices=True so the decoder can place values back where the
    # encoder's maxima were.
    self.maxpool = nn.MaxPool2d(2, return_indices=True)
    self.unpool = nn.MaxUnpool2d(2)
    self.relu = nn.ReLU()

  def forward(self, x):
    """Encode ``x`` and return its reconstruction.

    Args:
      x: Image batch, assumed (N, num_channels, H, W).

    Returns:
      Tensor of the same shape as ``x``; non-negative because of the final
      ReLU.
    """
    input_size = x.shape[-2:]

    # Encoder.
    x = self.conv1(x)
    # Remember the pre-pool size: pooling floors odd sizes (e.g. 49 -> 24), and
    # unpooling without it would produce 48 and mis-place the indices.
    pre_pool_size = x.shape[-2:]
    x, indices = self.maxpool(x)
    x = self.relu(x)

    x = self.conv2(x)
    x = self.relu(x)

    # Decoder.
    x = self.deconv1(x)
    x = self.unpool(x, indices, output_size=pre_pool_size)
    x = self.relu(x)

    # output_size lets ConvTranspose2d choose the output_padding that restores
    # the input's spatial size even when the stride-2 conv1 rounded down.
    x = self.deconv2(x, output_size=input_size)
    x = self.relu(x)

    return x

In [None]:
class DeepFont(nn.Module):
  """DeepFont supervised classifier: five conv layers and three FC layers.

  conv1/conv2 use the same configuration as the SCAE encoder (presumably so
  encoder weights can be transferred — TODO confirm against training code).

  Args:
    num_channels: Number of channels of the input images.
    num_classes: Number of output classes (size of the logit vector).
    input_size: Side length of the square input images. It determines
      ``fc6.in_features``. The default of 105 keeps the original 12*12*256.
  """

  def __init__(self, num_channels, num_classes, input_size=105):
    # nn.Module.__init__ must run before any submodule is assigned.
    super().__init__()

    self.conv1 = nn.Conv2d(
        in_channels=num_channels,
        out_channels=64,
        kernel_size=11,
        padding=1,
        stride=2
    )
    self.conv2 = nn.Conv2d(
        in_channels=64,
        out_channels=128,
        kernel_size=5,
        padding=2
    )
    self.conv3 = nn.Conv2d(
        in_channels=128,
        out_channels=256,
        kernel_size=3,
        padding=1
    )
    self.conv4 = nn.Conv2d(
        in_channels=256,
        out_channels=256,
        kernel_size=3,
        padding=1
    )
    # A separate layer. Assigning self.conv4 here would silently share weights.
    self.conv5 = nn.Conv2d(
        in_channels=256,
        out_channels=256,
        kernel_size=3,
        padding=1
    )
    side = self._feature_map_size(input_size)
    self.fc6 = nn.Linear(in_features=side * side * 256, out_features=4096)
    self.fc7 = nn.Linear(in_features=4096, out_features=4096)
    self.fc8 = nn.Linear(in_features=4096, out_features=num_classes)
    self.norm1 = nn.BatchNorm2d(num_features=64)
    self.norm2 = nn.BatchNorm2d(num_features=128)
    self.dropout = nn.Dropout(0.5)
    self.maxpool = nn.MaxPool2d(2)
    self.relu = nn.ReLU()
    self.flatten = nn.Flatten()
    # forward returns raw logits (what nn.CrossEntropyLoss expects).
    # Apply this to the output when class probabilities are wanted.
    self.softmax = nn.Softmax(dim=1)

  @staticmethod
  def _feature_map_size(input_size):
    """Return the side length of the conv5 output for a square input.

    Raises:
      ValueError: If ``input_size`` is too small to leave a non-empty map.
    """
    size = (input_size + 2 * 1 - 11) // 2 + 1  # conv1: k=11, p=1, s=2
    size //= 2  # maxpool after conv1 (conv2 keeps the size: k=5, p=2)
    size //= 2  # maxpool after conv2 (conv3-5 keep the size: k=3, p=1)
    if size < 1:
      raise ValueError(f"input_size={input_size} is too small for DeepFont")
    return size

  def forward(self, x):
    """Compute class logits for an image batch.

    Args:
      x: Image batch, assumed (N, num_channels, input_size, input_size).

    Returns:
      Logits of shape (N, num_classes), with no ReLU or softmax applied.
    """
    x = self.relu(self.maxpool(self.norm1(self.conv1(x))))
    x = self.relu(self.maxpool(self.norm2(self.conv2(x))))

    x = self.relu(self.conv3(x))
    x = self.relu(self.conv4(x))
    x = self.relu(self.conv5(x))

    x = self.flatten(x)

    x = self.dropout(self.relu(self.fc6(x)))
    x = self.dropout(self.relu(self.fc7(x)))

    # No ReLU here: it would clamp the negative logits.
    return self.fc8(x)