In [1]:
import json
import os
from os.path import exists

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchbearer
import torchvision.transforms as transforms
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torchbearer import Trial
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18

from device import DEVICE

In [4]:
# Backbone for the rotation-prediction pretext task: a ResNet-18 whose
# classifier is swapped for a head with one logit per rotation class.
model1 = resnet18()
model1.fc = nn.Linear(in_features=model1.fc.in_features, out_features=4)  # four rotations
model1 = model1.to(DEVICE)

In [22]:
# Display the module tree to confirm `fc` was replaced by the 4-output rotation head.
model1

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [25]:
class myResNet(nn.Module):
    """ResNet backbone with two linear classification heads.

    Reuses the stem (``conv1``/``bn1``) and the four residual stages
    (``layer1``..``layer4``) of an existing torchvision-style ResNet under the
    same attribute names, so that network's ``state_dict`` keys line up and
    pre-trained weights can be loaded with ``load_state_dict(..., strict=False)``.

    Args:
        model: source network providing ``conv1``, ``bn1`` and
            ``layer1``..``layer4``. The submodules are shared, not copied.
            ``layer4`` is assumed to emit 512 channels (the width of
            resnet18's final features) -- the heads are sized for that.
        num_labeled_classes: output size of ``head1``.
        num_unlabeled_classes: output size of ``head2``.
        has_unlabeled: if True, ``forward`` returns
            ``(head1_logits, head2_logits, features)``; otherwise only
            ``head1_logits``.
    """

    def __init__(self, model, num_labeled_classes = 5, num_unlabeled_classes = 5, has_unlabeled = False):
        super(myResNet, self).__init__()
        self.has_unlabeled = has_unlabeled
        # Shared with ``model``: weights loaded into either network affect both.
        self.conv1    = model.conv1
        self.bn1      = model.bn1
        self.layer1   = model.layer1
        self.layer2   = model.layer2
        self.layer3   = model.layer3
        self.layer4   = model.layer4
        self.head1 = nn.Linear(512, num_labeled_classes)
        self.head2 = nn.Linear(512, num_unlabeled_classes)

    def forward(self, x):
        """Run the backbone and both heads.

        Args:
            x: image batch accepted by ``conv1`` (presumably (N, 3, H, W)).

        Returns:
            ``head1`` logits, or ``(head1 logits, head2 logits, pooled
            features)`` when ``self.has_unlabeled`` is True.
        """
        # Stem: conv -> BN -> ReLU -> 3x3 stride-2 max-pool (same settings as
        # the original ResNet's ``maxpool``).
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.max_pool2d(out, kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pool to (N, C, 1, 1), then flatten to (N, C) so the
        # linear heads receive a feature vector per sample.
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)
        out = F.relu(out)  # add ReLU to benefit ranking
        out1 = self.head1(out)
        out2 = self.head2(out)
        if self.has_unlabeled:
            return out1, out2, out
        return out1

In [26]:
# Wrap the rotation backbone with two fresh classification heads.
# `.to(DEVICE)` is required: model1's layers are already on DEVICE, but the new
# head1/head2 are created on the CPU, so forward() would hit a device mismatch
# whenever DEVICE is a GPU.
model = myResNet(model1).to(DEVICE)

In [28]:
# Display the wrapper's module tree: stem and layer1..layer4 shared with model1, plus head1/head2.
model

myResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d

In [27]:
# load pre-trained model (rotation-prediction weights) into the shared backbone.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
state_dict = torch.load('./selfsupervised_learning/rotnet_cifar10.pth', map_location=DEVICE)
# The 4-way rotation classifier has no counterpart in myResNet (it has head1/head2).
state_dict.pop('fc.weight', None)
state_dict.pop('fc.bias', None)
load_result = model.load_state_dict(state_dict, strict = False)
# strict=False should only tolerate the new, randomly initialised heads; any
# other missing/unexpected key means the backbone weights did not load.
assert not load_result.unexpected_keys, load_result.unexpected_keys
assert set(load_result.missing_keys) <= {'head1.weight', 'head1.bias', 'head2.weight', 'head2.bias'}, load_result.missing_keys
load_result

_IncompatibleKeys(missing_keys=['head1.weight', 'head1.bias', 'head2.weight', 'head2.bias'], unexpected_keys=[])