In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import skorch
import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from skorch.helper import predefined_split

In [2]:
class MyClassifier(nn.Module):
    def __init__(self):
        super(MyClassifier, self).__init__()
        self.model = models.efficientnet_b0(pretrained=True)
        
        for param in self.model.parameters():
            param.requires_grad = False
        
        self.fc1 = nn.Sequential(
            nn.Linear(1000, 64),
            nn.ReLU()
        )
        self.fc2 = nn.Sequential(
            nn.Dropout(0.25),
            nn.Linear(64, 10)
        )
    
    def embed(self, x):
        return self.fc1(self.model(x))
    
    def forward(self, x):
        resnet_out = self.embed(x)
        return F.softmax(self.fc2(resnet_out), dim=-1)

In [4]:
from torchinfo import summary

summary(MyClassifier(), input_size=(1, 3, 96, 96))

Layer (type:depth-idx)                                       Output Shape              Param #
MyClassifier                                                 --                        --
├─EfficientNet: 1-1                                          [1, 1000]                 --
│    └─Sequential: 2-1                                       [1, 1280, 3, 3]           --
│    │    └─ConvNormActivation: 3-1                          [1, 32, 48, 48]           (928)
│    │    └─Sequential: 3-2                                  [1, 16, 48, 48]           (1,448)
│    │    └─Sequential: 3-3                                  [1, 24, 24, 24]           (16,714)
│    │    └─Sequential: 3-4                                  [1, 40, 12, 12]           (46,640)
│    │    └─Sequential: 3-5                                  [1, 80, 6, 6]             (242,930)
│    │    └─Sequential: 3-6                                  [1, 112, 6, 6]            (543,148)
│    │    └─Sequential: 3-7                                  

In [3]:
data = datasets.STL10(root='stl10', download=True, transform=transforms.ToTensor())

Files already downloaded and verified


In [4]:
net = skorch.NeuralNetClassifier(
    MyClassifier,
    max_epochs=20,
    train_split=predefined_split(data),
    lr=0.1,
    batch_size=128,
    device='cuda',
    # criterion=
)

In [6]:
trained = net.fit(data, y=None)

  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m1.9202[0m       [32m0.6558[0m        [35m1.0996[0m  5.5114
      2        [36m1.0588[0m       [32m0.7628[0m        [35m0.7346[0m  4.4213
      3        [36m0.9112[0m       [32m0.7814[0m        [35m0.6694[0m  4.4353
      4        [36m0.8333[0m       [32m0.8110[0m        [35m0.5697[0m  4.4731
      5        [36m0.7806[0m       [32m0.8290[0m        [35m0.5171[0m  4.3940
      6        [36m0.7350[0m       [32m0.8452[0m        [35m0.4797[0m  4.5165
      7        [36m0.7150[0m       [32m0.8536[0m        [35m0.4498[0m  4.7205
      8        [36m0.6882[0m       [32m0.8662[0m        [35m0.4207[0m  4.6702
      9        [36m0.6508[0m       [32m0.8754[0m        [35m0.4006[0m  4.5942
     10        [36m0.6382[0m       [32m0.8832[0m        [35m0.3647[0m  4.7451
     11        [36m0.6272[0m       [32m0.89

In [None]:
from random import randint
xs, ys = zip(*[data[randint(0, len(data))] for _ in range(500)])
xs = torch.stack(xs)

In [None]:
from tensorboardX import SummaryWriter

with SummaryWriter() as writer:
    writer.add_embedding(trained.module_.embed(xs.to('cuda')),
                         metadata=ys,
                         label_img=xs)