In [1]:
import torch
import torch.nn as nn


class AlexNet(nn.Module):
    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [2]:
net = AlexNet(num_classes=10)

In [3]:
def join_layers(vision_model):
    layers = [
        *vision_model.features,
        vision_model.avgpool,
        lambda x: torch.flatten(x, 1),
        *vision_model.classifier,
    ]
    return layers

In [4]:
pipeline_layer = join_layers(net)

In [5]:
pipeline_layer

[Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2)),
 ReLU(inplace=True),
 MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False),
 Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)),
 ReLU(inplace=True),
 MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False),
 Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 ReLU(inplace=True),
 Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 ReLU(inplace=True),
 Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 ReLU(inplace=True),
 MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False),
 AdaptiveAvgPool2d(output_size=(6, 6)),
 <function __main__.join_layers.<locals>.<lambda>(x)>,
 Dropout(p=0.5, inplace=False),
 Linear(in_features=9216, out_features=4096, bias=True),
 ReLU(inplace=True),
 Dropout(p=0.5, inplace=False),
 Linear(in_features=4096, out_features=4096, bias=True),
 ReLU(inplace=True),
 

In [4]:
from torchvision import datasets, transforms

import torchvision

In [5]:
def cifar_trainset(local_rank, dl_path='/tmp/cifar10-data'):
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Ensure only one rank downloads.
    # Note: if the download path is not on a shared filesytem, remove the semaphore
    # and switch to args.local_rank
   
    trainset = torchvision.datasets.CIFAR10(root=dl_path,
                                            train=True,
                                            download=True,
                                            transform=transform)
    
    return trainset

In [7]:
d = cifar_trainset(0)

Files already downloaded and verified


In [9]:
it = iter(d)

In [12]:
next(it)

(tensor([[[ 2.2318,  2.2318,  2.2318,  ...,  2.2318,  2.2318,  2.2318],
          [ 2.2318,  2.2318,  2.2318,  ...,  2.2318,  2.2318,  2.2318],
          [ 2.2318,  2.2318,  2.2318,  ...,  2.2318,  2.2318,  2.2318],
          ...,
          [-0.2856, -0.3027, -0.3198,  ..., -0.9020, -0.9020, -0.9020],
          [-0.3027, -0.3198, -0.3369,  ..., -0.9020, -0.9020, -0.9020],
          [-0.3198, -0.3369, -0.3541,  ..., -0.9192, -0.9020, -0.9020]],
 
         [[ 2.4111,  2.4111,  2.4111,  ...,  2.4111,  2.4111,  2.4111],
          [ 2.4111,  2.4111,  2.4111,  ...,  2.4111,  2.4111,  2.4111],
          [ 2.4111,  2.4111,  2.4111,  ...,  2.4111,  2.4111,  2.4111],
          ...,
          [-0.0399, -0.0574, -0.0749,  ..., -0.6352, -0.6352, -0.6352],
          [-0.0574, -0.0749, -0.0924,  ..., -0.6527, -0.6527, -0.6527],
          [-0.0749, -0.0924, -0.1099,  ..., -0.6702, -0.6702, -0.6702]],
 
         [[ 2.6226,  2.6226,  2.6226,  ...,  2.6226,  2.6226,  2.6226],
          [ 2.6226,  2.6226,