In [1]:
from torch import flatten, manual_seed, nn, randn

In [7]:
class AlexNet(nn.Module):
  """Single-device AlexNet: five conv layers followed by a three-layer MLP classifier.

  ``fc1`` is sized for a 256 x 6 x 6 = 9216-element feature map, which is what the
  conv/pool stack produces for a 224 x 224 input.

  Args:
    num_classes: Number of logits produced by ``fc3`` (default 10).
  """

  def __init__(self, num_classes: int = 10):
    super().__init__()
    self.conv1 = nn.Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    self.conv2 = nn.Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    self.conv3 = nn.Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    self.conv4 = nn.Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    self.conv5 = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    self.fc1 = nn.Linear(in_features=9216, out_features=4096, bias=True)
    self.fc2 = nn.Linear(in_features=4096, out_features=4096, bias=True)
    self.fc3 = nn.Linear(in_features=4096, out_features=num_classes, bias=True)

    self.relu = nn.ReLU(inplace=True)
    self.dropout = nn.Dropout(p=0.5, inplace=False)

  def forward(self, x):
    """Compute class logits.

    Args:
      x: Image tensor of shape ``(3, H, W)`` or ``(N, 3, H, W)``; H = W = 224 gives
        the 6 x 6 map that ``fc1`` expects.

    Returns:
      Logits of shape ``(num_classes,)`` for an unbatched input, or
      ``(N, num_classes)`` for a batched one.
    """
    x = self.maxpool(self.relu(self.conv1(x)))
    x = self.maxpool(self.relu(self.conv2(x)))
    x = self.relu(self.conv3(x))
    x = self.relu(self.conv4(x))
    x = self.maxpool(self.relu(self.conv5(x)))
    # Flatten only the trailing (C, H, W) dims so a leading batch dim survives;
    # a plain flatten(x) folds the batch into the feature axis and breaks fc1 for N > 1.
    x = flatten(x, start_dim=x.dim() - 3)
    x = self.dropout(self.relu(self.fc1(x)))
    x = self.dropout(self.relu(self.fc2(x)))
    return self.fc3(x)

In [30]:
class AlexNetStrongClientModel(nn.Module):
  """Client that runs the whole AlexNet conv stack locally, plus a single linear head.

  ``forward`` returns the client's own logits together with the pooled conv5 feature
  map, so the latter can be handed to a server-side classifier.

  Args:
    num_classes: Number of logits produced by ``fc`` (default 10).
  """

  def __init__(self, num_classes: int = 10):
    super().__init__()
    self.conv1 = nn.Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    self.conv2 = nn.Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    self.conv3 = nn.Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    self.conv4 = nn.Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    self.conv5 = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    # Sized for the 256 x 6 x 6 map produced from a 224 x 224 input.
    self.fc = nn.Linear(in_features=256*6*6, out_features=num_classes)

    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    self.relu = nn.ReLU(inplace=True)

  def forward(self, x):
    """Run the conv stack and the local head.

    Args:
      x: Image tensor of shape ``(3, H, W)`` or ``(N, 3, H, W)``.

    Returns:
      ``(logits, x_off)``: logits of shape ``(num_classes,)`` or ``(N, num_classes)``,
      and the pooled conv5 map of shape ``(256, h, w)`` or ``(N, 256, h, w)``
      (h = w = 6 for 224 x 224 inputs).
    """
    x = self.maxpool(self.relu(self.conv1(x)))
    x = self.maxpool(self.relu(self.conv2(x)))
    x = self.relu(self.conv3(x))
    x = self.relu(self.conv4(x))
    x_off = self.maxpool(self.relu(self.conv5(x)))
    # Flatten only the trailing (C, H, W) dims so a leading batch dim survives.
    x = self.fc(flatten(x_off, start_dim=x_off.dim() - 3))
    return x, x_off

class AlexNetWeakClientModel(nn.Module):
  """Weak client: only the first two conv -> ReLU -> max-pool stages of AlexNet.

  The 192-channel output matches the input channels of
  ``AlexNetWeakClientOffloadedModel.conv3``, which continues the computation.
  """

  def __init__(self):
    super().__init__()
    # Same hyper-parameters as the other variants, written in scalar shorthand
    # (Conv2d/MaxPool2d expand ints to pairs and use padding=0, dilation=1 by default).
    self.conv1 = nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2)
    self.conv2 = nn.Conv2d(64, 192, kernel_size=5, stride=1, padding=2)

    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
    self.relu = nn.ReLU(inplace=True)

  def forward(self, x):
    """Apply conv -> ReLU -> max-pool for conv1 and then conv2; return the pooled map."""
    for stage in (self.conv1, self.conv2):
      x = self.maxpool(self.relu(stage(x)))
    return x
  
class AlexNetWeakClientOffloadedModel(nn.Module):
  """Offloaded tail of the weak client: conv3-conv5 plus a single linear head.

  Continues from a 192-channel feature map (conv3 takes 192 input channels) and,
  like the strong client, returns both its own logits and the pooled conv5 map.

  Args:
    num_classes: Number of logits produced by ``fc`` (default 10).
  """

  def __init__(self, num_classes: int = 10):
    super().__init__()
    self.conv3 = nn.Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    self.conv4 = nn.Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    self.conv5 = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    # Sized for the 256 x 6 x 6 map obtained when the original image was 224 x 224.
    self.fc = nn.Linear(in_features=256*6*6, out_features=num_classes)
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    self.relu = nn.ReLU(inplace=True)

  def forward(self, x):
    """Run conv3-conv5 and the local head.

    Args:
      x: Feature map of shape ``(192, H, W)`` or ``(N, 192, H, W)``.

    Returns:
      ``(logits, x_off)``: logits of shape ``(num_classes,)`` or ``(N, num_classes)``,
      and the pooled conv5 map of shape ``(256, h, w)`` or ``(N, 256, h, w)``.
    """
    x = self.relu(self.conv3(x))
    x = self.relu(self.conv4(x))
    x_off = self.maxpool(self.relu(self.conv5(x)))
    # Flatten only the trailing (C, H, W) dims so a leading batch dim survives.
    x = self.fc(flatten(x_off, start_dim=x_off.dim() - 3))
    return x, x_off
  
class AlexNetServer(nn.Module):
  """Server-side AlexNet classifier: fc1 -> fc2 -> fc3 with ReLU and dropout.

  Consumes a pooled conv5 feature map (256 x 6 x 6 = 9216 features for ``fc1``).

  Args:
    num_classes: Number of logits produced by ``fc3`` (default 10).
  """

  def __init__(self, num_classes: int = 10) -> None:
    super().__init__()
    self.fc1 = nn.Linear(in_features=9216, out_features=4096, bias=True)
    self.fc2 = nn.Linear(in_features=4096, out_features=4096, bias=True)
    self.fc3 = nn.Linear(in_features=4096, out_features=num_classes, bias=True)
    self.dropout = nn.Dropout(p=0.5, inplace=False)
    self.relu = nn.ReLU(inplace=True)

  def forward(self, x):
    """Classify an offloaded feature map.

    Args:
      x: Feature map of shape ``(256, 6, 6)`` or ``(N, 256, 6, 6)``. Inputs with
        three or fewer dims (including already-flat vectors) are flattened entirely.

    Returns:
      Logits of shape ``(num_classes,)``, or ``(N, num_classes)`` for 4-D input.
    """
    # Keep a leading batch dim for 4-D input; anything with <= 3 dims is flattened
    # completely, so unbatched and pre-flattened inputs keep working.
    x = flatten(x, start_dim=max(x.dim() - 3, 0))
    x = self.dropout(self.relu(self.fc1(x)))
    x = self.dropout(self.relu(self.fc2(x)))
    return self.fc3(x)

In [31]:
# Seed the global RNG before building the models so weight initialisation, the
# random input created next and the dropout masks repeat on Restart & Run All.
SEED = 0
manual_seed(SEED)

strong = AlexNetStrongClientModel()
weak = AlexNetWeakClientModel()
weak_off = AlexNetWeakClientOffloadedModel()
server = AlexNetServer()

In [32]:
# Unbatched dummy image laid out as (C, H, W) at the 224 x 224 resolution
# the fully connected layers are sized for.
inp = randn((3, 224, 224))
inp.size()

torch.Size([3, 224, 224])

In [33]:
# Strong client: the whole conv stack runs locally; x_off is the pooled conv5 map
# that is passed to the server below.
x, x_off = strong(inp)
# Guard the client/server contract: AlexNetServer.fc1 consumes 256 * 6 * 6 features.
assert x_off.shape[-3:] == (256, 6, 6), f"unexpected offload shape {tuple(x_off.shape)}"
x.shape, x_off.shape

(torch.Size([10]), torch.Size([256, 6, 6]))

In [35]:
# Server-side classifier applied to the strong client's offloaded features.
# NOTE: nothing here calls .eval(), so the server is in training mode and its
# dropout layers are active; repeated calls give different logits.
server(x_off)

tensor([-0.0039, -0.0044, -0.0248,  0.0220, -0.0209, -0.0154,  0.0088, -0.0101,
        -0.0126,  0.0175], grad_fn=<ViewBackward0>)

In [36]:
# Weak client split: conv1-conv2 on the weak device, conv3-conv5 plus head offloaded.
x, x_off = weak_off(weak(inp))
# Guard the client/server contract: AlexNetServer.fc1 consumes 256 * 6 * 6 features.
assert x_off.shape[-3:] == (256, 6, 6), f"unexpected offload shape {tuple(x_off.shape)}"
x.shape, x_off.shape

(torch.Size([10]), torch.Size([256, 6, 6]))

In [37]:
# Same server applied to the weak-client path's features. `weak`/`weak_off` have
# their own independently initialised conv weights, so these logits are not
# expected to match the strong-client call. Dropout is active here as well,
# because no .eval() call is made.
server(x_off)

tensor([ 0.0002, -0.0233, -0.0044,  0.0137, -0.0218, -0.0164,  0.0094,  0.0106,
         0.0017,  0.0240], grad_fn=<ViewBackward0>)