#### Get Scratches (saved model checkpoints) from Dropbox

In [1]:
# Download the shared Dropbox folder; Dropbox redirects the `?dl=0` link to a zip
# download, which wget saves under the file name `Zero_Shot`.
# NOTE(review): --no-check-certificate disables TLS certificate verification, and the
# URL is unquoted although `?` is a shell glob character -- consider quoting it and
# dropping the flag if the environment's CA bundle works.
!wget -O Zero_Shot --no-check-certificate https://www.dropbox.com/sh/kh0rpm7f52xhrkz/AACqSpx4Yc5BS9ByGpLApr68a?dl=0
# Extract the archive into the working directory, then delete the archive itself.
!unzip Zero_Shot
!rm Zero_Shot
# The generator folders are never referenced in this notebook, which only evaluates
# the student checkpoints under ./SVHN and ./CIFAR10, so they are removed.
!rm -r CIFAR10-Generators/
!rm -r SVHN-Generators/

--2019-11-06 19:09:10--  https://www.dropbox.com/sh/kh0rpm7f52xhrkz/AACqSpx4Yc5BS9ByGpLApr68a?dl=0
Resolving www.dropbox.com (www.dropbox.com)... 162.125.8.1, 2620:100:6018:1::a27d:301
Connecting to www.dropbox.com (www.dropbox.com)|162.125.8.1|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: /sh/raw/kh0rpm7f52xhrkz/AACqSpx4Yc5BS9ByGpLApr68a [following]
--2019-11-06 19:09:15--  https://www.dropbox.com/sh/raw/kh0rpm7f52xhrkz/AACqSpx4Yc5BS9ByGpLApr68a
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://uc690c0182e5a04c59db5b4665e8.dl.dropboxusercontent.com/zip_by_token_key?key=Ar2qV3fqtkjPUzdNtyL47hbe_y0RG5iw727Y_iOa8dkApXgEg1aHKcJFpKQi6_UZ2DL5mTrgc3wBrs-SMsz40Ty8pbhCJ-yd9xodwA81scizGeVZeSFeaiw9mkoqn-YoaYsMEAflnGyrC0eqDx_-cvO9LewAcywIDhfN873hWXLATYsmAwu-7a8YZg824zGUmifLtBwxKqJhDcT_Umx42sDBeKXF7ZhfPAjs25S1CuEQ-w [following]
--2019-11-06 19:09:16--  https://uc690c0182e5a04c59db5b46

In [0]:
import torch
import torchvision
from torch import nn as nn
from torchvision import transforms
# Evaluate on the first CUDA GPU when one is available; otherwise fall back to the CPU
# so the notebook still runs (slowly) on a machine without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [0]:
class IndividualBlock1(nn.Module):
    """First pre-activation residual block of a WideResNet stage.

    The residual branch is ``conv2(relu(bn2(conv1(relu(bn1(x))))))``. When the block
    downsamples (``subsample_input``) or widens (``increase_filters``), the shortcut is a
    1x1 projection ``conv_inp`` applied to the *pre-activated* input. Otherwise the
    shortcut is the raw input, which requires ``input_features == output_features``
    and ``stride == 1``.

    Args:
        input_features: channels of the incoming tensor.
        output_features: channels produced by the block.
        stride: stride of ``conv1``. The projection shortcut uses the same stride, so
            the two branches always agree spatially.
        subsample_input: build a projection shortcut for a downsampling block.
        increase_filters: build a projection shortcut for a block that only widens
            (consulted when ``subsample_input`` is False).
    """

    def __init__(self, input_features, output_features, stride, subsample_input=True, increase_filters=True):
        super(IndividualBlock1, self).__init__()

        self.activation = nn.ReLU(inplace=True)

        self.batch_norm1 = nn.BatchNorm2d(input_features)
        self.batch_norm2 = nn.BatchNorm2d(output_features)

        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1, bias=False)

        self.subsample_input = subsample_input
        self.increase_filters = increase_filters
        if subsample_input or increase_filters:
            # The shortcut stride must follow conv1's stride, otherwise the two branches
            # differ spatially. It used to be hard-coded: 2 when subsampling, 1 when only
            # widening. The 1x1 weight shape is unchanged, so existing checkpoints load.
            self.conv_inp = nn.Conv2d(input_features, output_features, kernel_size=1, stride=stride, padding=0, bias=False)

    def forward(self, x):
        if self.subsample_input or self.increase_filters:
            # Projection path: the pre-activated input feeds both conv1 and the shortcut.
            x = self.activation(self.batch_norm1(x))
            x1 = self.conv1(x)
            shortcut = self.conv_inp(x)
        else:
            # Identity path: the shortcut is the un-normalized input.
            x1 = self.conv1(self.activation(self.batch_norm1(x)))
            shortcut = x
        x1 = self.conv2(self.activation(self.batch_norm2(x1)))
        return shortcut + x1


class IndividualBlockN(nn.Module):
    """Follow-on pre-activation residual block with an identity shortcut.

    Returns ``conv2(relu(bn2(conv1(relu(bn1(x)))))) + x``. Because the shortcut is the
    unmodified input, the sum only lines up when ``input_features == output_features``
    and the convolutions preserve spatial size. ``Nblock`` always builds it with
    ``stride=1``, which preserves the spatial size. Both convolutions use ``stride``.
    """

    def __init__(self, input_features, output_features, stride):
        super(IndividualBlockN, self).__init__()

        # Both 3x3 convolutions share the same settings.
        conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)

        self.activation = nn.ReLU(inplace=True)
        self.batch_norm1 = nn.BatchNorm2d(input_features)
        self.batch_norm2 = nn.BatchNorm2d(output_features)
        self.conv1 = nn.Conv2d(input_features, output_features, **conv_kwargs)
        self.conv2 = nn.Conv2d(output_features, output_features, **conv_kwargs)

    def forward(self, x):
        # The in-place ReLU acts on fresh BatchNorm outputs, so `x` itself is untouched.
        residual = self.conv1(self.activation(self.batch_norm1(x)))
        residual = self.conv2(self.activation(self.batch_norm2(residual)))
        return residual + x


class Nblock(nn.Module):
    """One WideResNet stage: ``N`` residual blocks applied in sequence.

    The first block (``IndividualBlock1``) may change the channel count and/or apply
    ``stride``. The remaining ``N - 1`` blocks (``IndividualBlockN``) keep
    ``output_features`` channels at stride 1. The blocks are held in
    ``self.nblockLayer``, an ``nn.Sequential``. That attribute name appears in
    state-dict keys, so it must not change.
    """

    def __init__(self, N, input_features, output_features, stride, subsample_input=True, increase_filters=True):
        super(Nblock, self).__init__()

        stage = [
            IndividualBlock1(input_features, output_features, stride, subsample_input, increase_filters)
            if position == 0
            else IndividualBlockN(output_features, output_features, stride=1)
            for position in range(N)
        ]
        self.nblockLayer = nn.Sequential(*stage)

    def forward(self, x):
        return self.nblockLayer(x)


class WideResNet(nn.Module):
    """Pre-activation Wide Residual Network (WRN-d-k).

    Args:
        d: network depth. Each of the three stages gets ``N = (d - 4) // 6`` blocks.
        k: widening factor. Stage widths are ``16k, 32k, 64k``.
        n_classes: number of logits produced by the final linear layer.
        input_features: channels of the input image.
        output_features: channels produced by the stem convolution ``conv1``.
        strides: four strides, for the stem conv and for stages 1, 2 and 3.

    ``forward`` returns ``(logits, attention1, attention2, attention3)``, where the
    attention tensors are the raw outputs of the three stages.
    """

    def __init__(self, d, k, n_classes, input_features, output_features, strides):
        super(WideResNet, self).__init__()

        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=strides[0], padding=1, bias=False)

        filters = [16 * k, 32 * k, 64 * k]
        self.out_filters = filters[-1]
        N = (d - 4) // 6
        # Stage 1 needs a 1x1 projection shortcut whenever its width differs from the
        # stem's. `k > 1` is the original rule and keeps the checkpoint layout unchanged.
        # The extra check covers stems whose width is not 16 * k.
        increase_filters = k > 1 or output_features != filters[0]
        self.block1 = Nblock(N, input_features=output_features, output_features=filters[0], stride=strides[1], subsample_input=False, increase_filters=increase_filters)
        self.block2 = Nblock(N, input_features=filters[0], output_features=filters[1], stride=strides[2])
        self.block3 = Nblock(N, input_features=filters[1], output_features=filters[2], stride=strides[3])

        self.batch_norm = nn.BatchNorm2d(filters[-1])
        self.activation = nn.ReLU(inplace=True)
        # Kernel 8 reduces the 8x8 map (32x32 input with strides [1, 1, 2, 2]) to 1x1.
        self.avg_pool = nn.AvgPool2d(kernel_size=8)
        self.fc = nn.Linear(filters[-1], n_classes)

    def forward(self, x):

        x = self.conv1(x)
        attention1 = self.block1(x)
        attention2 = self.block2(attention1)
        attention3 = self.block3(attention2)
        out = self.activation(self.batch_norm(attention3))
        out = self.avg_pool(out)
        # Flatten per sample, keeping the batch dimension explicit. If the pooled map is
        # not 1x1, `fc` now raises a shape error. Previously view(-1, C) silently turned
        # spatial positions into extra batch rows.
        out = torch.flatten(out, 1)

        return self.fc(out), attention1, attention2, attention3

In [0]:
def _test_set_eval(net, device, test_loader):
    with torch.no_grad():
        correct, total = 0, 0
        net.eval()

        for data in test_loader:
            images, labels = data
            images = images.to(device)
            labels = labels.to(device)

            outputs = net(images)[0]
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        accuracy = correct / total
        accuracy = round(100 * accuracy, 2)

        return accuracy

# SVHN

In [5]:
# Per-channel mean and std used to normalize SVHN images. These presumably must match
# the normalization the checkpoints were trained with -- TODO confirm.
svhn_normalize = transforms.Normalize(
    mean=(0.4377, 0.4438, 0.4728),
    std=(0.1980, 0.2010, 0.1970),
)
svhn_transform = transforms.Compose([transforms.ToTensor(), svhn_normalize])

# SVHN test split, iterated in a fixed order (no shuffling) for reproducible evaluation.
svhn_test_set = torchvision.datasets.SVHN(
    root='./data', split='test', download=True, transform=svhn_transform,
)
svhn_test_loader = torch.utils.data.DataLoader(
    svhn_test_set, batch_size=100, shuffle=False, num_workers=0,
)
testloader = svhn_test_loader  # legacy alias to the same loader object

  0%|          | 0/64275384 [00:00<?, ?it/s]

Downloading http://ufldl.stanford.edu/housenumbers/test_32x32.mat to ./data/test_32x32.mat


64282624it [00:02, 30634797.62it/s]                              


#### Zero Shot - Teacher WRN-40-2 Student WRN-16-1 

#### Seed 0

In [6]:
# Zero-shot student WRN-16-1 (teacher WRN-40-2) on SVHN, seed 0.
d = 16
k = 1
strides = [1, 1, 2, 2]
net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides).to(device)
# map_location remaps the checkpoint tensors onto `device`, as the CIFAR-10 cells do.
# Without it, loading fails when the checkpoint was saved on a different device.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
net.load_state_dict(torch.load('./SVHN/zero_shot_teacher_wrn-40-2_student_wrn-16-1-M-0-seed-0-SVHN-dict.pth', map_location=device))
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

94.21

#### Seed 1

In [7]:
# Zero-shot student WRN-16-1 (teacher WRN-40-2) on SVHN, seed 1.
d = 16
k = 1
strides = [1, 1, 2, 2]
net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides).to(device)
# map_location remaps the checkpoint tensors onto `device`, as the CIFAR-10 cells do.
# Without it, loading fails when the checkpoint was saved on a different device.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
net.load_state_dict(torch.load('./SVHN/zero_shot_teacher_wrn-40-2_student_wrn-16-1-M-0-seed-1-SVHN-dict.pth', map_location=device))
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

93.85

#### Seed 2

In [8]:
# Zero-shot student WRN-16-1 (teacher WRN-40-2) on SVHN, seed 2.
d = 16
k = 1
strides = [1, 1, 2, 2]
net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides).to(device)
# map_location remaps the checkpoint tensors onto `device`, as the CIFAR-10 cells do.
# Without it, loading fails when the checkpoint was saved on a different device.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
net.load_state_dict(torch.load('./SVHN/zero_shot_teacher_wrn-40-2_student_wrn-16-1-M-0-seed-2-SVHN-dict.pth', map_location=device))
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

93.94

# CIFAR10

In [9]:
# Per-channel mean and std used to normalize CIFAR-10 images. These presumably must
# match the normalization the checkpoints were trained with -- TODO confirm.
cifar_normalize = transforms.Normalize(
    mean=(0.4914, 0.4822, 0.4465),
    std=(0.2023, 0.1994, 0.2010),
)
cifar_transform = transforms.Compose([transforms.ToTensor(), cifar_normalize])

# CIFAR-10 test split (train=False), iterated in a fixed order for reproducible evaluation.
cifar_testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=cifar_transform,
)
cifar_testloader = torch.utils.data.DataLoader(
    cifar_testset, batch_size=100, shuffle=False, num_workers=0,
)

  0%|          | 0/170498071 [00:00<?, ?it/s]

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


170500096it [00:02, 80991321.64it/s]                               


Extracting ./data/cifar-10-python.tar.gz to ./data


#### Zero Shot - Teacher WRN-16-2  Student WRN-16-1

#### Seed 0

In [10]:
seed="0"
Teacher="WRN-16-2"
Student="WRN-16-1"
d = 16
k = 1
strides = [1, 1, 2, 2]
net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides).to(device)
net.load_state_dict(torch.load('./CIFAR10/'+Teacher+'-'+Student+'/Seed-'+seed+'/reproducibility_zero_shot_teacher_'+Teacher.lower()+'_student_'+Student.lower()+'-M-0-seed-'+seed+'-CIFAR10-dict.pth',map_location=device))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy  

80.59

#### Seed 1

In [11]:
seed="1"
Teacher="WRN-16-2"
Student="WRN-16-1"
d = 16
k = 1
strides = [1, 1, 2, 2]
net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides).to(device)
net.load_state_dict(torch.load('./CIFAR10/'+Teacher+'-'+Student+'/Seed-'+seed+'/reproducibility_zero_shot_teacher_'+Teacher.lower()+'_student_'+Student.lower()+'-M-0-seed-'+seed+'-CIFAR10-dict.pth',map_location=device))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy  

80.7

#### Seed 2

In [12]:
seed="2"
Teacher="WRN-16-2"
Student="WRN-16-1"
d = 16
k = 1
strides = [1, 1, 2, 2]
net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides).to(device)
net.load_state_dict(torch.load('./CIFAR10/'+Teacher+'-'+Student+'/Seed-'+seed+'/reproducibility_zero_shot_teacher_'+Teacher.lower()+'_student_'+Student.lower()+'-M-0-seed-'+seed+'-CIFAR10-dict.pth',map_location=device))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy  

82.48

#### Zero Shot - Teacher WRN-40-1  Student WRN-16-1

#### Seed 0

In [13]:
seed="0"
Teacher="WRN-40-1"
Student="WRN-16-1"
d = 16
k = 1
strides = [1, 1, 2, 2]
net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides).to(device)
net.load_state_dict(torch.load('./CIFAR10/'+Teacher+'-'+Student+'/Seed-'+seed+'/reproducibility_zero_shot_teacher_'+Teacher.lower()+'_student_'+Student.lower()+'-M-0-seed-'+seed+'-CIFAR10-dict.pth',map_location=device))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy  

77.4

#### Seed 1

In [14]:
seed="1"
Teacher="WRN-40-1"
Student="WRN-16-1"
d = 16
k = 1
strides = [1, 1, 2, 2]
net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides).to(device)
net.load_state_dict(torch.load('./CIFAR10/'+Teacher+'-'+Student+'/Seed-'+seed+'/reproducibility_zero_shot_teacher_'+Teacher.lower()+'_student_'+Student.lower()+'-M-0-seed-'+seed+'-CIFAR10-dict.pth',map_location=device))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy  

80.61

#### Seed 2

In [15]:
seed="2"
Teacher="WRN-40-1"
Student="WRN-16-1"
d = 16
k = 1
strides = [1, 1, 2, 2]
net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides).to(device)
net.load_state_dict(torch.load('./CIFAR10/'+Teacher+'-'+Student+'/Seed-'+seed+'/reproducibility_zero_shot_teacher_'+Teacher.lower()+'_student_'+Student.lower()+'-M-0-seed-'+seed+'-CIFAR10-dict.pth',map_location=device))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy  

81.7

#### Zero Shot - Teacher WRN-40-1  Student WRN-16-2

#### Seed 0

In [16]:
seed="0"
Teacher="WRN-40-1"
Student="WRN-16-2"
d = 16
k = 2

strides = [1, 1, 2, 2]
net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides).to(device)
net.load_state_dict(torch.load('./CIFAR10/'+Teacher+'-'+Student+'/Seed-'+seed+'/reproducibility_zero_shot_teacher_'+Teacher.lower()+'_student_'+Student.lower()+'-M-0-seed-'+seed+'-CIFAR10-dict.pth',map_location=device))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy  

88.71

#### Seed 1

In [17]:
seed="1"
Teacher="WRN-40-1"
Student="WRN-16-2"
d = 16
k = 2
strides = [1, 1, 2, 2]
net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides).to(device)
net.load_state_dict(torch.load('./CIFAR10/'+Teacher+'-'+Student+'/Seed-'+seed+'/reproducibility_zero_shot_teacher_'+Teacher.lower()+'_student_'+Student.lower()+'-M-0-seed-'+seed+'-CIFAR10-dict.pth',map_location=device))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy  

87.34

#### Seed 2

In [18]:
seed="2"
Teacher="WRN-40-1"
Student="WRN-16-2"
d = 16
k = 2
strides = [1, 1, 2, 2]
net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides).to(device)
net.load_state_dict(torch.load('./CIFAR10/'+Teacher+'-'+Student+'/Seed-'+seed+'/reproducibility_zero_shot_teacher_'+Teacher.lower()+'_student_'+Student.lower()+'-M-0-seed-'+seed+'-CIFAR10-dict.pth',map_location=device))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy  

87.08

#### Zero Shot - Teacher WRN-40-2  Student WRN-16-1

#### Seed 0

In [19]:
seed="0"
Teacher="WRN-40-2"
Student="WRN-16-1"
d = 16
k = 1
strides = [1, 1, 2, 2]
net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides).to(device)
net.load_state_dict(torch.load('./CIFAR10/'+Teacher+'-'+Student+'/Seed-'+seed+'/reproducibility_zero_shot_teacher_'+Teacher.lower()+'_student_'+Student.lower()+'-M-0-seed-'+seed+'-CIFAR10-dict.pth',map_location=device))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy  

83.73

#### Seed 1

In [20]:
seed="1"
Teacher="WRN-40-2"
Student="WRN-16-1"
d = 16
k = 1
strides = [1, 1, 2, 2]
net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides).to(device)
net.load_state_dict(torch.load('./CIFAR10/'+Teacher+'-'+Student+'/Seed-'+seed+'/reproducibility_zero_shot_teacher_'+Teacher.lower()+'_student_'+Student.lower()+'-M-0-seed-'+seed+'-CIFAR10-dict.pth',map_location=device))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy  

83.76

#### Seed 2

In [21]:
seed="2"
Teacher="WRN-40-2"
Student="WRN-16-1"
d = 16
k = 1
strides = [1, 1, 2, 2]
net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides).to(device)
net.load_state_dict(torch.load('./CIFAR10/'+Teacher+'-'+Student+'/Seed-'+seed+'/reproducibility_zero_shot_teacher_'+Teacher.lower()+'_student_'+Student.lower()+'-M-0-seed-'+seed+'-CIFAR10-dict.pth',map_location=device))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy  

83.42

#### Zero Shot - Teacher WRN-40-2  Student WRN-16-2

#### Seed 0

In [22]:
seed="0"
Teacher="WRN-40-2"
Student="WRN-16-2"
d = 16
k = 2
strides = [1, 1, 2, 2]
net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides).to(device)
net.load_state_dict(torch.load('./CIFAR10/'+Teacher+'-'+Student+'/Seed-'+seed+'/reproducibility_zero_shot_teacher_'+Teacher.lower()+'_student_'+Student.lower()+'-M-0-seed-'+seed+'-CIFAR10-dict.pth',map_location=device))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy  

89.13

#### Seed 1

In [23]:
seed="1"
Teacher="WRN-40-2"
Student="WRN-16-2"
d = 16
k = 2
strides = [1, 1, 2, 2]
net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides).to(device)
net.load_state_dict(torch.load('./CIFAR10/'+Teacher+'-'+Student+'/Seed-'+seed+'/reproducibility_zero_shot_teacher_'+Teacher.lower()+'_student_'+Student.lower()+'-M-0-seed-'+seed+'-CIFAR10-dict.pth',map_location=device))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy  

89.48

#### Seed 2

In [24]:
seed="2"
Teacher="WRN-40-2"
Student="WRN-16-2"
d = 16
k = 2
strides = [1, 1, 2, 2]
net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides).to(device)
net.load_state_dict(torch.load('./CIFAR10/'+Teacher+'-'+Student+'/Seed-'+seed+'/reproducibility_zero_shot_teacher_'+Teacher.lower()+'_student_'+Student.lower()+'-M-0-seed-'+seed+'-CIFAR10-dict.pth',map_location=device))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy  

89.32

#### Zero Shot - Teacher WRN-40-2  Student WRN-40-1

#### Seed 0

In [25]:
seed="0"
Teacher="WRN-40-2"
Student="WRN-40-1"
d = 40
k = 1
strides = [1, 1, 2, 2]
net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides).to(device)
net.load_state_dict(torch.load('./CIFAR10/'+Teacher+'-'+Student+'/Seed-'+seed+'/reproducibility_zero_shot_teacher_'+Teacher.lower()+'_student_'+Student.lower()+'-M-0-seed-'+seed+'-CIFAR10-dict.pth',map_location=device))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy  

87.94

#### Seed 1

In [26]:
seed="1"
Teacher="WRN-40-2"
Student="WRN-40-1"
d = 40
k = 1
strides = [1, 1, 2, 2]
net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides).to(device)
net.load_state_dict(torch.load('./CIFAR10/'+Teacher+'-'+Student+'/Seed-'+seed+'/reproducibility_zero_shot_teacher_'+Teacher.lower()+'_student_'+Student.lower()+'-M-0-seed-'+seed+'-CIFAR10-dict.pth',map_location=device))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy  

87.28

#### Seed 2

In [27]:
seed="2"
Teacher="WRN-40-2"
Student="WRN-40-1"
d = 40
k = 1
strides = [1, 1, 2, 2]
net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides).to(device)
net.load_state_dict(torch.load('./CIFAR10/'+Teacher+'-'+Student+'/Seed-'+seed+'/reproducibility_zero_shot_teacher_'+Teacher.lower()+'_student_'+Student.lower()+'-M-0-seed-'+seed+'-CIFAR10-dict.pth',map_location=device))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy  

87.18