#### Get Zero-Shot checkpoints from Dropbox

In [1]:

# Download the shared Dropbox folder as a single archive named `modified_Zero_Shot`,
# extract it into the current working directory, then delete the archive.
# Presumably the archive contains the ./CIFAR10/... checkpoint tree loaded by the cells below — TODO confirm.
# NOTE(review): --no-check-certificate disables TLS certificate verification for the download;
# confirm this is acceptable before running outside a throwaway environment.
!wget -O modified_Zero_Shot --no-check-certificate https://www.dropbox.com/sh/5oxfl1xisskw0ub/AAB8wpAH3AslQxiyPxbzUhr0a?dl=0
!unzip modified_Zero_Shot
!rm modified_Zero_Shot


--2019-11-06 19:09:13--  https://www.dropbox.com/sh/5oxfl1xisskw0ub/AAB8wpAH3AslQxiyPxbzUhr0a?dl=0
Resolving www.dropbox.com (www.dropbox.com)... 162.125.8.1, 2620:100:601b:1::a27d:801
Connecting to www.dropbox.com (www.dropbox.com)|162.125.8.1|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: /sh/raw/5oxfl1xisskw0ub/AAB8wpAH3AslQxiyPxbzUhr0a [following]
--2019-11-06 19:09:14--  https://www.dropbox.com/sh/raw/5oxfl1xisskw0ub/AAB8wpAH3AslQxiyPxbzUhr0a
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://uc57cd5f4a36c0a167cba88e6031.dl.dropboxusercontent.com/zip_by_token_key?key=Ar1BagEdnj1LpVeopV9yqLvx44N3PmaCJnZ3JOndGdlh6iqd-UXLaXgU5fGgaYH6vekV_7i2ZhTWbFHF8xs7rBa1kOdZrdceMp-kjaz1Dl7r_F_InnTvqO-JA7kb1EqqtQedvXjwtjve_8qS_m2sXPYzXYeB-uc_jrXJGf7qE0d3xt4irY7mEsa8AQrCrnxZbDsQmXhznrSyG_uibpWS47inBrRNfplg0fWzWCtwNuvfX-pz4Zf1u5GOTzwNqkTFBZg [following]
--2019-11-06 19:09:14--  https://uc5

In [0]:
import torch
import torchvision
from torch import nn as nn
from torchvision import transforms
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU so that
# model construction, checkpoint loading (map_location=device) and evaluation still run.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [0]:
class IndividualBlock1(nn.Module):
    """First residual block of a group: BN -> ReLU -> conv3x3 -> BN -> ReLU -> conv3x3 plus a shortcut.

    The shortcut is a 1x1 convolution (``conv_inp``) applied to the pre-activated
    input when ``subsample_input`` or ``increase_filters`` is set; otherwise the
    raw input is added back unchanged.

    Args:
        input_features: number of channels of the incoming tensor.
        output_features: number of channels produced by the block.
        stride: stride of the first 3x3 convolution (``conv1``).
        subsample_input: if True, the shortcut is a 1x1 convolution that uses the
            same ``stride`` as ``conv1`` so both branches keep matching spatial sizes.
        increase_filters: if True (and ``subsample_input`` is False), the shortcut
            is a stride-1 1x1 convolution that only changes the channel count.

    Note:
        With both flags False the input is added directly to the residual branch,
        so the branch output must have the same shape as the input.
    """

    def __init__(self, input_features, output_features, stride, subsample_input=True, increase_filters=True):
        super(IndividualBlock1, self).__init__()

        self.activation = nn.ReLU(inplace=True)

        self.batch_norm1 = nn.BatchNorm2d(input_features)
        self.batch_norm2 = nn.BatchNorm2d(output_features)

        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1, bias=False)

        self.subsample_input = subsample_input
        self.increase_filters = increase_filters
        if subsample_input:
            # The shortcut follows conv1's stride (previously hard-coded to 2), so the two
            # branches agree in spatial size for any stride. Identical to before for stride=2.
            self.conv_inp = nn.Conv2d(input_features, output_features, kernel_size=1, stride=stride, padding=0, bias=False)
        elif increase_filters:
            self.conv_inp = nn.Conv2d(input_features, output_features, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x):
        """Return ``shortcut(x) + residual(x)``."""
        uses_projection = self.subsample_input or self.increase_filters

        # The in-place ReLU acts on the fresh BatchNorm output, never on `x` itself,
        # so `x` stays intact for the identity shortcut.
        pre_activated = self.activation(self.batch_norm1(x))

        # A projection shortcut consumes the pre-activated tensor; the identity shortcut uses the raw input.
        shortcut = self.conv_inp(pre_activated) if uses_projection else x

        residual = self.conv1(pre_activated)
        residual = self.activation(self.batch_norm2(residual))
        residual = self.conv2(residual)

        return shortcut + residual


class IndividualBlockN(nn.Module):
    """Residual block with an identity shortcut: two (BN -> ReLU -> conv3x3) stages, then ``+ x``.

    Args:
        input_features: number of channels of the incoming tensor.
        output_features: number of channels produced by both convolutions.
        stride: stride applied to both 3x3 convolutions.

    Note:
        The input is added back without any projection, so the residual branch
        must produce a tensor whose shape is compatible with the input.
    """

    def __init__(self, input_features, output_features, stride):
        super().__init__()

        # Both convolutions share every setting except their input channel count.
        conv_settings = dict(kernel_size=3, stride=stride, padding=1, bias=False)

        self.activation = nn.ReLU(inplace=True)

        self.batch_norm1 = nn.BatchNorm2d(input_features)
        self.batch_norm2 = nn.BatchNorm2d(output_features)

        self.conv1 = nn.Conv2d(input_features, output_features, **conv_settings)
        self.conv2 = nn.Conv2d(output_features, output_features, **conv_settings)

    def forward(self, x):
        """Return the residual branch output added to the unmodified input."""
        # The in-place ReLU only touches BatchNorm outputs, so `x` is preserved for the sum.
        residual = self.conv1(self.activation(self.batch_norm1(x)))
        residual = self.conv2(self.activation(self.batch_norm2(residual)))
        return residual + x


class Nblock(nn.Module):
    """A group of ``N`` residual blocks chained in an ``nn.Sequential``.

    The first block is an ``IndividualBlock1`` built from the given stride and
    shortcut flags; every following block is an ``IndividualBlockN`` with
    ``output_features`` channels in and out and stride 1.

    Args:
        N: number of blocks in the group.
        input_features: channels entering the first block.
        output_features: channels produced by every block of the group.
        stride: stride forwarded to the first block.
        subsample_input: forwarded to the first block.
        increase_filters: forwarded to the first block.
    """

    def __init__(self, N, input_features, output_features, stride, subsample_input=True, increase_filters=True):
        super().__init__()

        # Position 0 gets the shape-changing block; all later positions keep the shape.
        blocks = [
            IndividualBlock1(input_features, output_features, stride, subsample_input, increase_filters)
            if position == 0
            else IndividualBlockN(output_features, output_features, stride=1)
            for position in range(N)
        ]

        self.nblockLayer = nn.Sequential(*blocks)

    def forward(self, x):
        """Run ``x`` through every block of the group in order."""
        return self.nblockLayer(x)


class WideResNet(nn.Module):
    """Wide residual network that also returns the output of each of its three block groups.

    Args:
        d: depth parameter; each group holds ``(d - 4) // 6`` residual blocks.
        k: widening factor; the groups produce ``16*k``, ``32*k`` and ``64*k`` channels.
        n_classes: size of the final linear layer's output.
        input_features: channels of the input images.
        output_features: channels produced by the initial 3x3 convolution.
        strides: indexable with positions 0..3 — the stride of the initial
            convolution followed by the stride of each of the three groups.

    Note:
        The head uses ``AvgPool2d(kernel_size=8)`` followed by
        ``view(-1, out_filters)``; this yields one row per sample only when the
        pooled map is 1x1 (e.g. an 8x8 final feature map).
    """

    def __init__(self, d, k, n_classes, input_features, output_features, strides):
        super().__init__()

        blocks_per_group = (d - 4) // 6
        widths = [16 * k, 32 * k, 64 * k]
        self.out_filters = widths[-1]

        self.conv1 = nn.Conv2d(
            input_features, output_features,
            kernel_size=3, stride=strides[0], padding=1, bias=False,
        )

        # The first group keeps the resolution; it only needs a projection shortcut when k > 1.
        self.block1 = Nblock(
            blocks_per_group,
            input_features=output_features,
            output_features=widths[0],
            stride=strides[1],
            subsample_input=False,
            increase_filters=(k > 1),
        )
        self.block2 = Nblock(blocks_per_group, input_features=widths[0], output_features=widths[1], stride=strides[2])
        self.block3 = Nblock(blocks_per_group, input_features=widths[1], output_features=widths[2], stride=strides[3])

        self.batch_norm = nn.BatchNorm2d(widths[-1])
        self.activation = nn.ReLU(inplace=True)
        self.avg_pool = nn.AvgPool2d(kernel_size=8)
        self.fc = nn.Linear(widths[-1], n_classes)

    def forward(self, x):
        """Return ``(logits, attention1, attention2, attention3)``.

        ``attention1..3`` are the raw outputs of the three block groups.
        """
        stem = self.conv1(x)
        attention1 = self.block1(stem)
        attention2 = self.block2(attention1)
        attention3 = self.block3(attention2)

        # The in-place ReLU modifies the BatchNorm output, not `attention3`.
        pooled = self.avg_pool(self.activation(self.batch_norm(attention3)))
        flattened = pooled.view(-1, self.out_filters)

        return self.fc(flattened), attention1, attention2, attention3

In [0]:
def _test_set_eval(net, device, test_loader):
    with torch.no_grad():
        correct, total = 0, 0
        net.eval()

        for data in test_loader:
            images, labels = data
            images = images.to(device)
            labels = labels.to(device)

            outputs = net(images)[0]
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        accuracy = correct / total
        accuracy = round(100 * accuracy, 2)

        return accuracy

# CIFAR10

In [5]:
# Per-channel mean / std used to normalise every CIFAR-10 test image.
cifar_normalize = transforms.Normalize(
    mean=(0.4914, 0.4822, 0.4465),
    std=(0.2023, 0.1994, 0.2010),
)

# Evaluation pipeline: tensor conversion followed by normalisation.
cifar_transform = transforms.Compose([transforms.ToTensor(), cifar_normalize])

# Test split only (train=False); downloaded into ./data when requested via download=True.
cifar_testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=cifar_transform,
)

# Fixed order, 100 images per batch, loaded in the main process.
cifar_testloader = torch.utils.data.DataLoader(
    cifar_testset,
    batch_size=100,
    shuffle=False,
    num_workers=0,
)

  0%|          | 0/170498071 [00:00<?, ?it/s]

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


170500096it [00:02, 76204897.08it/s]                               


Extracting ./data/cifar-10-python.tar.gz to ./data


#### Zero Shot - Teacher WRN-16-2  Student WRN-16-1

#### Seed 0

In [6]:
# Evaluate the zero-shot student checkpoint (teacher WRN-16-2, student WRN-16-1, seed 0)
# on the CIFAR-10 test set.
seed = "0"
Teacher = "WRN-16-2"
Student = "WRN-16-1"
d = 16  # depth passed to WideResNet
k = 1  # widening factor passed to WideResNet
strides = [1, 1, 2, 2]

net = WideResNet(
    d=d,
    k=k,
    n_classes=10,
    input_features=3,
    output_features=16,
    strides=strides,
)
net = net.to(device)

# The checkpoint location is derived from the teacher/student names and the seed.
checkpoint_path = (
    './CIFAR10/' + Teacher + '-' + Student
    + '/Seed-' + seed
    + '/zero_shot_teacher_' + Teacher.lower()
    + '_student_' + Student.lower()
    + '-M-0-seed-' + seed + '-CIFAR10-dict.pth'
)
net.load_state_dict(torch.load(checkpoint_path, map_location=device))

accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

82.42

#### Seed 1

In [7]:
# Evaluate the zero-shot student checkpoint (teacher WRN-16-2, student WRN-16-1, seed 1)
# on the CIFAR-10 test set.
seed = "1"
Teacher = "WRN-16-2"
Student = "WRN-16-1"
d = 16  # depth passed to WideResNet
k = 1  # widening factor passed to WideResNet
strides = [1, 1, 2, 2]

net = WideResNet(
    d=d,
    k=k,
    n_classes=10,
    input_features=3,
    output_features=16,
    strides=strides,
)
net = net.to(device)

# The checkpoint location is derived from the teacher/student names and the seed.
checkpoint_path = (
    './CIFAR10/' + Teacher + '-' + Student
    + '/Seed-' + seed
    + '/zero_shot_teacher_' + Teacher.lower()
    + '_student_' + Student.lower()
    + '-M-0-seed-' + seed + '-CIFAR10-dict.pth'
)
net.load_state_dict(torch.load(checkpoint_path, map_location=device))

accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

81.73

#### Seed 2

In [8]:
# Evaluate the zero-shot student checkpoint (teacher WRN-16-2, student WRN-16-1, seed 2)
# on the CIFAR-10 test set.
seed = "2"
Teacher = "WRN-16-2"
Student = "WRN-16-1"
d = 16  # depth passed to WideResNet
k = 1  # widening factor passed to WideResNet
strides = [1, 1, 2, 2]

net = WideResNet(
    d=d,
    k=k,
    n_classes=10,
    input_features=3,
    output_features=16,
    strides=strides,
)
net = net.to(device)

# The checkpoint location is derived from the teacher/student names and the seed.
checkpoint_path = (
    './CIFAR10/' + Teacher + '-' + Student
    + '/Seed-' + seed
    + '/zero_shot_teacher_' + Teacher.lower()
    + '_student_' + Student.lower()
    + '-M-0-seed-' + seed + '-CIFAR10-dict.pth'
)
net.load_state_dict(torch.load(checkpoint_path, map_location=device))

accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

84.32

#### Zero Shot - Teacher WRN-40-1  Student WRN-16-1

#### Seed 0

In [9]:
# Evaluate the zero-shot student checkpoint (teacher WRN-40-1, student WRN-16-1, seed 0)
# on the CIFAR-10 test set.
seed = "0"
Teacher = "WRN-40-1"
Student = "WRN-16-1"
d = 16  # depth passed to WideResNet
k = 1  # widening factor passed to WideResNet
strides = [1, 1, 2, 2]

net = WideResNet(
    d=d,
    k=k,
    n_classes=10,
    input_features=3,
    output_features=16,
    strides=strides,
)
net = net.to(device)

# The checkpoint location is derived from the teacher/student names and the seed.
checkpoint_path = (
    './CIFAR10/' + Teacher + '-' + Student
    + '/Seed-' + seed
    + '/zero_shot_teacher_' + Teacher.lower()
    + '_student_' + Student.lower()
    + '-M-0-seed-' + seed + '-CIFAR10-dict.pth'
)
net.load_state_dict(torch.load(checkpoint_path, map_location=device))

accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

79.87

#### Seed 1

In [10]:
# Evaluate the zero-shot student checkpoint (teacher WRN-40-1, student WRN-16-1, seed 1)
# on the CIFAR-10 test set.
seed = "1"
Teacher = "WRN-40-1"
Student = "WRN-16-1"
d = 16  # depth passed to WideResNet
k = 1  # widening factor passed to WideResNet
strides = [1, 1, 2, 2]

net = WideResNet(
    d=d,
    k=k,
    n_classes=10,
    input_features=3,
    output_features=16,
    strides=strides,
)
net = net.to(device)

# The checkpoint location is derived from the teacher/student names and the seed.
checkpoint_path = (
    './CIFAR10/' + Teacher + '-' + Student
    + '/Seed-' + seed
    + '/zero_shot_teacher_' + Teacher.lower()
    + '_student_' + Student.lower()
    + '-M-0-seed-' + seed + '-CIFAR10-dict.pth'
)
net.load_state_dict(torch.load(checkpoint_path, map_location=device))

accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

84.62

#### Seed 2

In [11]:
# Evaluate the zero-shot student checkpoint (teacher WRN-40-1, student WRN-16-1, seed 2)
# on the CIFAR-10 test set.
seed = "2"
Teacher = "WRN-40-1"
Student = "WRN-16-1"
d = 16  # depth passed to WideResNet
k = 1  # widening factor passed to WideResNet
strides = [1, 1, 2, 2]

net = WideResNet(
    d=d,
    k=k,
    n_classes=10,
    input_features=3,
    output_features=16,
    strides=strides,
)
net = net.to(device)

# The checkpoint location is derived from the teacher/student names and the seed.
checkpoint_path = (
    './CIFAR10/' + Teacher + '-' + Student
    + '/Seed-' + seed
    + '/zero_shot_teacher_' + Teacher.lower()
    + '_student_' + Student.lower()
    + '-M-0-seed-' + seed + '-CIFAR10-dict.pth'
)
net.load_state_dict(torch.load(checkpoint_path, map_location=device))

accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

83.34

#### Zero Shot - Teacher WRN-40-1  Student WRN-16-2

#### Seed 0

In [12]:
# Evaluate the zero-shot student checkpoint (teacher WRN-40-1, student WRN-16-2, seed 0)
# on the CIFAR-10 test set.
seed = "0"
Teacher = "WRN-40-1"
Student = "WRN-16-2"
d = 16  # depth passed to WideResNet
k = 2  # widening factor passed to WideResNet
strides = [1, 1, 2, 2]

net = WideResNet(
    d=d,
    k=k,
    n_classes=10,
    input_features=3,
    output_features=16,
    strides=strides,
)
net = net.to(device)

# The checkpoint location is derived from the teacher/student names and the seed.
checkpoint_path = (
    './CIFAR10/' + Teacher + '-' + Student
    + '/Seed-' + seed
    + '/zero_shot_teacher_' + Teacher.lower()
    + '_student_' + Student.lower()
    + '-M-0-seed-' + seed + '-CIFAR10-dict.pth'
)
net.load_state_dict(torch.load(checkpoint_path, map_location=device))

accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

88.71

#### Seed 1

In [13]:
# Evaluate the zero-shot student checkpoint (teacher WRN-40-1, student WRN-16-2, seed 1)
# on the CIFAR-10 test set.
seed = "1"
Teacher = "WRN-40-1"
Student = "WRN-16-2"
d = 16  # depth passed to WideResNet
k = 2  # widening factor passed to WideResNet
strides = [1, 1, 2, 2]

net = WideResNet(
    d=d,
    k=k,
    n_classes=10,
    input_features=3,
    output_features=16,
    strides=strides,
)
net = net.to(device)

# The checkpoint location is derived from the teacher/student names and the seed.
checkpoint_path = (
    './CIFAR10/' + Teacher + '-' + Student
    + '/Seed-' + seed
    + '/zero_shot_teacher_' + Teacher.lower()
    + '_student_' + Student.lower()
    + '-M-0-seed-' + seed + '-CIFAR10-dict.pth'
)
net.load_state_dict(torch.load(checkpoint_path, map_location=device))

accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

90.11

#### Seed 2

In [23]:
# Evaluate the zero-shot student checkpoint (teacher WRN-40-1, student WRN-16-2, seed 2)
# on the CIFAR-10 test set.
seed = "2"
Teacher = "WRN-40-1"
Student = "WRN-16-2"
d = 16  # depth passed to WideResNet
k = 2  # widening factor passed to WideResNet
strides = [1, 1, 2, 2]

net = WideResNet(
    d=d,
    k=k,
    n_classes=10,
    input_features=3,
    output_features=16,
    strides=strides,
)
net = net.to(device)

# The checkpoint location is derived from the teacher/student names and the seed.
checkpoint_path = (
    './CIFAR10/' + Teacher + '-' + Student
    + '/Seed-' + seed
    + '/zero_shot_teacher_' + Teacher.lower()
    + '_student_' + Student.lower()
    + '-M-0-seed-' + seed + '-CIFAR10-dict.pth'
)
net.load_state_dict(torch.load(checkpoint_path, map_location=device))

accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

88.99

#### Zero Shot - Teacher WRN-40-2  Student WRN-16-1

#### Seed 0

In [14]:
# Evaluate the zero-shot student checkpoint (teacher WRN-40-2, student WRN-16-1, seed 0)
# on the CIFAR-10 test set.
seed = "0"
Teacher = "WRN-40-2"
Student = "WRN-16-1"
d = 16  # depth passed to WideResNet
k = 1  # widening factor passed to WideResNet
strides = [1, 1, 2, 2]

net = WideResNet(
    d=d,
    k=k,
    n_classes=10,
    input_features=3,
    output_features=16,
    strides=strides,
)
net = net.to(device)

# The checkpoint location is derived from the teacher/student names and the seed.
checkpoint_path = (
    './CIFAR10/' + Teacher + '-' + Student
    + '/Seed-' + seed
    + '/zero_shot_teacher_' + Teacher.lower()
    + '_student_' + Student.lower()
    + '-M-0-seed-' + seed + '-CIFAR10-dict.pth'
)
net.load_state_dict(torch.load(checkpoint_path, map_location=device))

accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

85.09

#### Seed 1

In [15]:
# Evaluate the zero-shot student checkpoint (teacher WRN-40-2, student WRN-16-1, seed 1)
# on the CIFAR-10 test set.
seed = "1"
Teacher = "WRN-40-2"
Student = "WRN-16-1"
d = 16  # depth passed to WideResNet
k = 1  # widening factor passed to WideResNet
strides = [1, 1, 2, 2]

net = WideResNet(
    d=d,
    k=k,
    n_classes=10,
    input_features=3,
    output_features=16,
    strides=strides,
)
net = net.to(device)

# The checkpoint location is derived from the teacher/student names and the seed.
checkpoint_path = (
    './CIFAR10/' + Teacher + '-' + Student
    + '/Seed-' + seed
    + '/zero_shot_teacher_' + Teacher.lower()
    + '_student_' + Student.lower()
    + '-M-0-seed-' + seed + '-CIFAR10-dict.pth'
)
net.load_state_dict(torch.load(checkpoint_path, map_location=device))

accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

84.07

#### Seed 2

In [16]:
# Evaluate the zero-shot student checkpoint (teacher WRN-40-2, student WRN-16-1, seed 2)
# on the CIFAR-10 test set.
seed = "2"
Teacher = "WRN-40-2"
Student = "WRN-16-1"
d = 16  # depth passed to WideResNet
k = 1  # widening factor passed to WideResNet
strides = [1, 1, 2, 2]

net = WideResNet(
    d=d,
    k=k,
    n_classes=10,
    input_features=3,
    output_features=16,
    strides=strides,
)
net = net.to(device)

# The checkpoint location is derived from the teacher/student names and the seed.
checkpoint_path = (
    './CIFAR10/' + Teacher + '-' + Student
    + '/Seed-' + seed
    + '/zero_shot_teacher_' + Teacher.lower()
    + '_student_' + Student.lower()
    + '-M-0-seed-' + seed + '-CIFAR10-dict.pth'
)
net.load_state_dict(torch.load(checkpoint_path, map_location=device))

accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

85.18

#### Zero Shot - Teacher WRN-40-2  Student WRN-16-2

#### Seed 0

In [17]:
# Evaluate the zero-shot student checkpoint (teacher WRN-40-2, student WRN-16-2, seed 0)
# on the CIFAR-10 test set.
seed = "0"
Teacher = "WRN-40-2"
Student = "WRN-16-2"
d = 16  # depth passed to WideResNet
k = 2  # widening factor passed to WideResNet
strides = [1, 1, 2, 2]

net = WideResNet(
    d=d,
    k=k,
    n_classes=10,
    input_features=3,
    output_features=16,
    strides=strides,
)
net = net.to(device)

# The checkpoint location is derived from the teacher/student names and the seed.
checkpoint_path = (
    './CIFAR10/' + Teacher + '-' + Student
    + '/Seed-' + seed
    + '/zero_shot_teacher_' + Teacher.lower()
    + '_student_' + Student.lower()
    + '-M-0-seed-' + seed + '-CIFAR10-dict.pth'
)
net.load_state_dict(torch.load(checkpoint_path, map_location=device))

accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

90.67

#### Seed 1

In [18]:
# Evaluate the zero-shot student checkpoint (teacher WRN-40-2, student WRN-16-2, seed 1)
# on the CIFAR-10 test set.
seed = "1"
Teacher = "WRN-40-2"
Student = "WRN-16-2"
d = 16  # depth passed to WideResNet
k = 2  # widening factor passed to WideResNet
strides = [1, 1, 2, 2]

net = WideResNet(
    d=d,
    k=k,
    n_classes=10,
    input_features=3,
    output_features=16,
    strides=strides,
)
net = net.to(device)

# The checkpoint location is derived from the teacher/student names and the seed.
checkpoint_path = (
    './CIFAR10/' + Teacher + '-' + Student
    + '/Seed-' + seed
    + '/zero_shot_teacher_' + Teacher.lower()
    + '_student_' + Student.lower()
    + '-M-0-seed-' + seed + '-CIFAR10-dict.pth'
)
net.load_state_dict(torch.load(checkpoint_path, map_location=device))

accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

91.41

#### Seed 2

In [19]:
# Evaluate the zero-shot student checkpoint (teacher WRN-40-2, student WRN-16-2, seed 2)
# on the CIFAR-10 test set.
seed = "2"
Teacher = "WRN-40-2"
Student = "WRN-16-2"
d = 16  # depth passed to WideResNet
k = 2  # widening factor passed to WideResNet
strides = [1, 1, 2, 2]

net = WideResNet(
    d=d,
    k=k,
    n_classes=10,
    input_features=3,
    output_features=16,
    strides=strides,
)
net = net.to(device)

# The checkpoint location is derived from the teacher/student names and the seed.
checkpoint_path = (
    './CIFAR10/' + Teacher + '-' + Student
    + '/Seed-' + seed
    + '/zero_shot_teacher_' + Teacher.lower()
    + '_student_' + Student.lower()
    + '-M-0-seed-' + seed + '-CIFAR10-dict.pth'
)
net.load_state_dict(torch.load(checkpoint_path, map_location=device))

accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

91.27

#### Zero Shot - Teacher WRN-40-2  Student WRN-40-1

#### Seed 0

In [20]:
# Evaluate the zero-shot student checkpoint (teacher WRN-40-2, student WRN-40-1, seed 0)
# on the CIFAR-10 test set.
seed = "0"
Teacher = "WRN-40-2"
Student = "WRN-40-1"
d = 40  # depth passed to WideResNet
k = 1  # widening factor passed to WideResNet
strides = [1, 1, 2, 2]

net = WideResNet(
    d=d,
    k=k,
    n_classes=10,
    input_features=3,
    output_features=16,
    strides=strides,
)
net = net.to(device)

# The checkpoint location is derived from the teacher/student names and the seed.
checkpoint_path = (
    './CIFAR10/' + Teacher + '-' + Student
    + '/Seed-' + seed
    + '/zero_shot_teacher_' + Teacher.lower()
    + '_student_' + Student.lower()
    + '-M-0-seed-' + seed + '-CIFAR10-dict.pth'
)
net.load_state_dict(torch.load(checkpoint_path, map_location=device))

accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

90.08

#### Seed 1

In [21]:
# Evaluate the zero-shot student checkpoint (teacher WRN-40-2, student WRN-40-1, seed 1)
# on the CIFAR-10 test set.
seed = "1"
Teacher = "WRN-40-2"
Student = "WRN-40-1"
d = 40  # depth passed to WideResNet
k = 1  # widening factor passed to WideResNet
strides = [1, 1, 2, 2]

net = WideResNet(
    d=d,
    k=k,
    n_classes=10,
    input_features=3,
    output_features=16,
    strides=strides,
)
net = net.to(device)

# The checkpoint location is derived from the teacher/student names and the seed.
checkpoint_path = (
    './CIFAR10/' + Teacher + '-' + Student
    + '/Seed-' + seed
    + '/zero_shot_teacher_' + Teacher.lower()
    + '_student_' + Student.lower()
    + '-M-0-seed-' + seed + '-CIFAR10-dict.pth'
)
net.load_state_dict(torch.load(checkpoint_path, map_location=device))

accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

90.16

#### Seed 2

In [22]:
# Evaluate the zero-shot student checkpoint (teacher WRN-40-2, student WRN-40-1, seed 2)
# on the CIFAR-10 test set.
seed = "2"
Teacher = "WRN-40-2"
Student = "WRN-40-1"
d = 40  # depth passed to WideResNet
k = 1  # widening factor passed to WideResNet
strides = [1, 1, 2, 2]

net = WideResNet(
    d=d,
    k=k,
    n_classes=10,
    input_features=3,
    output_features=16,
    strides=strides,
)
net = net.to(device)

# The checkpoint location is derived from the teacher/student names and the seed.
checkpoint_path = (
    './CIFAR10/' + Teacher + '-' + Student
    + '/Seed-' + seed
    + '/zero_shot_teacher_' + Teacher.lower()
    + '_student_' + Student.lower()
    + '-M-0-seed-' + seed + '-CIFAR10-dict.pth'
)
net.load_state_dict(torch.load(checkpoint_path, map_location=device))

accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

90.59