#### Download the pretrained checkpoints (hosted on Dropbox)

In [1]:
# Download the zipped checkpoint folder from Dropbox, extract it, then delete the archive.
# - The URL is quoted so the shell never treats `?` as a glob character.
# - TLS verification stays on: the extracted .pth files are later unpickled by torch.load.
# - `unzip -o` overwrites existing files instead of prompting, so re-running the cell
#   (e.g. Restart & Run All) does not stall waiting for input.
!wget -O Zero-Shot-M "https://www.dropbox.com/sh/g6itjt8bep6ftg5/AABSZeQNUxVnpnTtQk327h6Ja?dl=0"
!unzip -o Zero-Shot-M
!rm Zero-Shot-M

--2019-11-06 19:20:00--  https://www.dropbox.com/sh/g6itjt8bep6ftg5/AABSZeQNUxVnpnTtQk327h6Ja?dl=0
Resolving www.dropbox.com (www.dropbox.com)... 162.125.82.1, 2620:100:6032:1::a27d:5201
Connecting to www.dropbox.com (www.dropbox.com)|162.125.82.1|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: /sh/raw/g6itjt8bep6ftg5/AABSZeQNUxVnpnTtQk327h6Ja [following]
--2019-11-06 19:20:01--  https://www.dropbox.com/sh/raw/g6itjt8bep6ftg5/AABSZeQNUxVnpnTtQk327h6Ja
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://uc9f8073ef5f096f99aa059d954a.dl.dropboxusercontent.com/zip_by_token_key?key=Ar0ya4Ybz6BucEtVZn4QzMOB5YXsgkjbniXcuEQ7ye1lBjWCDdwRuGkcKD_HDTqLmBssENOEtNvTgBz9tR4ug3GWUhDykpiOIGTAwZ1TN_hesNW5VyNF-_WMfRdzewyuU2BQrNne4Pkc-B8zXc56d8UQzsiV9fxEFv39gUIevH3JM7aYTQ0JmMc2QWqNP8vVLT-A-50E_9SpE-S2unVg96nJmiPL2SguR40In3VZFY2Y2w [following]
--2019-11-06 19:20:01--  https://uc9f8073ef5f096f99aa0

In [0]:
import torch
import torchvision
from torch import nn as nn
from torchvision import transforms
# Use the first GPU when one is available, otherwise fall back to the CPU. The
# checkpoints are loaded with map_location=device, so either choice works.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [0]:
class IndividualBlock1(nn.Module):
    """First pre-activation residual unit of a WideResNet group.

    Main path: BN -> ReLU -> 3x3 conv (``stride``) -> BN -> ReLU -> 3x3 conv (stride 1).
    Shortcut: a 1x1 projection of the pre-activated input when ``subsample_input``
    or ``increase_filters`` is set, otherwise the unmodified input (identity).

    Args:
        input_features: channels of the incoming tensor.
        output_features: channels produced by the block.
        stride: stride of the first 3x3 conv. The projection shortcut uses the same
            stride so that both branches produce the same spatial size.
        subsample_input: use a projection shortcut (intended for downsampling groups).
        increase_filters: use a projection shortcut because the channel count changes.
    """

    def __init__(self, input_features, output_features, stride, subsample_input=True, increase_filters=True):
        super(IndividualBlock1, self).__init__()

        self.activation = nn.ReLU(inplace=True)

        self.batch_norm1 = nn.BatchNorm2d(input_features)
        self.batch_norm2 = nn.BatchNorm2d(output_features)

        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1, bias=False)

        self.subsample_input = subsample_input
        self.increase_filters = increase_filters
        if subsample_input or increase_filters:
            # The shortcut must change the resolution by the same factor as conv1.
            # A hard-coded stride (2 when subsampling, 1 otherwise) breaks the
            # residual sum whenever `stride` disagrees with it.
            self.conv_inp = nn.Conv2d(input_features, output_features, kernel_size=1, stride=stride, padding=0, bias=False)

    def forward(self, x):
        use_projection = self.subsample_input or self.increase_filters

        if use_projection:
            # Pre-activate once and share the result between conv1 and the projection.
            x = self.batch_norm1(x)
            x = self.activation(x)
            x1 = self.conv1(x)
        else:
            # The identity shortcut needs the raw input, so pre-activate a separate tensor.
            x1 = self.batch_norm1(x)
            x1 = self.activation(x1)
            x1 = self.conv1(x1)
        x1 = self.batch_norm2(x1)
        x1 = self.activation(x1)
        x1 = self.conv2(x1)

        if use_projection:
            return self.conv_inp(x) + x1
        return x + x1


class IndividualBlockN(nn.Module):
    """Subsequent pre-activation residual unit of a WideResNet group.

    Computes BN -> ReLU -> 3x3 conv -> BN -> ReLU -> 3x3 conv and adds the result
    to the unmodified input (identity shortcut). Because of the identity shortcut,
    the block only works when ``input_features == output_features`` and
    ``stride == 1``, which is how ``Nblock`` constructs it.
    """

    def __init__(self, input_features, output_features, stride):
        super(IndividualBlockN, self).__init__()

        self.activation = nn.ReLU(inplace=True)

        self.batch_norm1 = nn.BatchNorm2d(input_features)
        self.batch_norm2 = nn.BatchNorm2d(output_features)

        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
        # Only the first conv may stride; the second always preserves the spatial size,
        # as in IndividualBlock1. Striding both would downsample twice.
        self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, x):
        x1 = self.batch_norm1(x)
        x1 = self.activation(x1)
        x1 = self.conv1(x1)
        x1 = self.batch_norm2(x1)
        x1 = self.activation(x1)
        x1 = self.conv2(x1)

        return x1 + x


class Nblock(nn.Module):
    """A WideResNet group made of ``N`` residual units applied in sequence.

    The first unit is an ``IndividualBlock1``, which receives ``stride``,
    ``subsample_input`` and ``increase_filters``. The remaining ``N - 1`` units are
    ``IndividualBlockN`` modules with ``output_features`` channels at stride 1.
    A non-positive ``N`` yields an empty group.
    """

    def __init__(self, N, input_features, output_features, stride, subsample_input=True, increase_filters=True):
        super(Nblock, self).__init__()

        units = []
        if N > 0:
            units.append(IndividualBlock1(input_features, output_features, stride, subsample_input, increase_filters))
            units.extend(IndividualBlockN(output_features, output_features, stride=1) for _ in range(N - 1))

        # The attribute name determines the state_dict keys ("nblockLayer.<i>...."),
        # so it must stay as-is for saved checkpoints to load.
        self.nblockLayer = nn.Sequential(*units)

    def forward(self, x):
        return self.nblockLayer(x)


class WideResNet(nn.Module):
    """Pre-activation Wide Residual Network (WRN-d-k) that also exposes group outputs.

    Args:
        d: network depth; each of the three groups gets ``N = (d - 4) // 6`` units.
        k: widening factor; the groups have 16k, 32k and 64k channels.
        n_classes: output size of the final linear layer.
        input_features: channels of the input image.
        output_features: channels produced by the stem convolution.
        strides: four strides, used by the stem conv and then groups 1, 2 and 3.

    ``forward`` returns ``(logits, attention1, attention2, attention3)``, where the
    attention tensors are the raw outputs of the three groups.
    """

    def __init__(self, d, k, n_classes, input_features, output_features, strides):
        super(WideResNet, self).__init__()

        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=strides[0], padding=1, bias=False)

        filters = [16 * k, 32 * k, 64 * k]
        self.out_filters = filters[-1]
        N = (d - 4) // 6
        increase_filters = k > 1
        self.block1 = Nblock(N, input_features=output_features, output_features=filters[0], stride=strides[1], subsample_input=False, increase_filters=increase_filters)
        self.block2 = Nblock(N, input_features=filters[0], output_features=filters[1], stride=strides[2])
        self.block3 = Nblock(N, input_features=filters[1], output_features=filters[2], stride=strides[3])

        self.batch_norm = nn.BatchNorm2d(filters[-1])
        self.activation = nn.ReLU(inplace=True)
        # Global average pooling over whatever spatial size group 3 produces. For
        # 32x32 inputs with strides [1, 1, 2, 2] that map is 8x8, where this is
        # equivalent to AvgPool2d(kernel_size=8). A fixed 8x8 kernel breaks (or
        # leaves several positions) for other input sizes or strides.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(filters[-1], n_classes)

    def forward(self, x):

        x = self.conv1(x)
        attention1 = self.block1(x)
        attention2 = self.block2(attention1)
        attention3 = self.block3(attention2)
        out = self.batch_norm(attention3)
        out = self.activation(out)
        out = self.avg_pool(out)
        # Flatten per sample. view(-1, C) would silently fold any leftover spatial
        # positions into the batch dimension.
        out = torch.flatten(out, 1)

        return self.fc(out), attention1, attention2, attention3

In [0]:
def _test_set_eval(net, device, test_loader):
    with torch.no_grad():
        correct, total = 0, 0
        net.eval()

        for data in test_loader:
            images, labels = data
            images = images.to(device)
            labels = labels.to(device)

            outputs = net(images)[0]
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        accuracy = correct / total
        accuracy = round(100 * accuracy, 2)

        return accuracy

# SVHN

In [5]:
# Per-channel normalisation for SVHN test images.
# NOTE(review): presumably these are the statistics the checkpoints were trained
# with; keep them unchanged so the reported accuracies are reproducible.
svhn_normalize = transforms.Normalize(mean=(0.4377, 0.4438, 0.4728), std=(0.1980, 0.2010, 0.1970))
svhn_transform = transforms.Compose([transforms.ToTensor(), svhn_normalize])

svhn_test_set = torchvision.datasets.SVHN(
    root='./data',
    split='test',
    download=True,
    transform=svhn_transform,
)
svhn_test_loader = torch.utils.data.DataLoader(
    svhn_test_set,
    batch_size=100,
    shuffle=False,
    num_workers=0,
)
testloader = svhn_test_loader  # legacy alias for the same loader object

0it [00:00, ?it/s]

Downloading http://ufldl.stanford.edu/housenumbers/test_32x32.mat to ./data/test_32x32.mat


64282624it [00:06, 9560399.85it/s]                               


#### Zero-Shot-M: teacher WRN-40-2, student WRN-16-1, evaluated for several numbers M of fine-tuning samples

#### M=10 samples

#### Seed 0

In [6]:
# SVHN, M=10, seed 0: rebuild the WRN-16-1 student, load its weights, report test accuracy (%).
seed, M = "0", "10"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10,
    input_features=3, output_features=16,
    strides=strides,
).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
net.load_state_dict(torch.load(
    './Checkpoint_Extra_M_samples_Reproducibility_Zero_Shot_Teacher-WideResNet-40-2-Student-WideResNet-16-1-SVHN-M-' + M + '-Zero-Shot-seed-' + seed + '.pth',
    map_location=device,
))
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

94.29

#### Seed 1

In [7]:
# SVHN, M=10, seed 1: rebuild the WRN-16-1 student, load its weights, report test accuracy (%).
seed, M = "1", "10"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10,
    input_features=3, output_features=16,
    strides=strides,
).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
net.load_state_dict(torch.load(
    './Checkpoint_Extra_M_samples_Reproducibility_Zero_Shot_Teacher-WideResNet-40-2-Student-WideResNet-16-1-SVHN-M-' + M + '-Zero-Shot-seed-' + seed + '.pth',
    map_location=device,
))
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

93.9

#### Seed 2

In [8]:
# SVHN, M=10, seed 2: rebuild the WRN-16-1 student, load its weights, report test accuracy (%).
seed, M = "2", "10"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10,
    input_features=3, output_features=16,
    strides=strides,
).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
net.load_state_dict(torch.load(
    './Checkpoint_Extra_M_samples_Reproducibility_Zero_Shot_Teacher-WideResNet-40-2-Student-WideResNet-16-1-SVHN-M-' + M + '-Zero-Shot-seed-' + seed + '.pth',
    map_location=device,
))
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

94.0

#### M=25 samples

#### Seed 0

In [9]:
# SVHN, M=25, seed 0: rebuild the WRN-16-1 student, load its weights, report test accuracy (%).
seed, M = "0", "25"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10,
    input_features=3, output_features=16,
    strides=strides,
).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
net.load_state_dict(torch.load(
    './Checkpoint_Extra_M_samples_Reproducibility_Zero_Shot_Teacher-WideResNet-40-2-Student-WideResNet-16-1-SVHN-M-' + M + '-Zero-Shot-seed-' + seed + '.pth',
    map_location=device,
))
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

94.26

#### Seed 1

In [10]:
# SVHN, M=25, seed 1: rebuild the WRN-16-1 student, load its weights, report test accuracy (%).
seed, M = "1", "25"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10,
    input_features=3, output_features=16,
    strides=strides,
).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
net.load_state_dict(torch.load(
    './Checkpoint_Extra_M_samples_Reproducibility_Zero_Shot_Teacher-WideResNet-40-2-Student-WideResNet-16-1-SVHN-M-' + M + '-Zero-Shot-seed-' + seed + '.pth',
    map_location=device,
))
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

93.97

#### Seed 2

In [11]:
# SVHN, M=25, seed 2: rebuild the WRN-16-1 student, load its weights, report test accuracy (%).
seed, M = "2", "25"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10,
    input_features=3, output_features=16,
    strides=strides,
).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
net.load_state_dict(torch.load(
    './Checkpoint_Extra_M_samples_Reproducibility_Zero_Shot_Teacher-WideResNet-40-2-Student-WideResNet-16-1-SVHN-M-' + M + '-Zero-Shot-seed-' + seed + '.pth',
    map_location=device,
))
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

93.98

#### M=50 samples

#### Seed 0

In [12]:
# SVHN, M=50, seed 0: rebuild the WRN-16-1 student, load its weights, report test accuracy (%).
seed, M = "0", "50"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10,
    input_features=3, output_features=16,
    strides=strides,
).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
net.load_state_dict(torch.load(
    './Checkpoint_Extra_M_samples_Reproducibility_Zero_Shot_Teacher-WideResNet-40-2-Student-WideResNet-16-1-SVHN-M-' + M + '-Zero-Shot-seed-' + seed + '.pth',
    map_location=device,
))
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

94.26

#### Seed 1

In [13]:
# SVHN, M=50, seed 1: rebuild the WRN-16-1 student, load its weights, report test accuracy (%).
seed, M = "1", "50"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10,
    input_features=3, output_features=16,
    strides=strides,
).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
net.load_state_dict(torch.load(
    './Checkpoint_Extra_M_samples_Reproducibility_Zero_Shot_Teacher-WideResNet-40-2-Student-WideResNet-16-1-SVHN-M-' + M + '-Zero-Shot-seed-' + seed + '.pth',
    map_location=device,
))
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

93.94

#### Seed 2

In [14]:
# SVHN, M=50, seed 2: rebuild the WRN-16-1 student, load its weights, report test accuracy (%).
seed, M = "2", "50"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10,
    input_features=3, output_features=16,
    strides=strides,
).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
net.load_state_dict(torch.load(
    './Checkpoint_Extra_M_samples_Reproducibility_Zero_Shot_Teacher-WideResNet-40-2-Student-WideResNet-16-1-SVHN-M-' + M + '-Zero-Shot-seed-' + seed + '.pth',
    map_location=device,
))
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

93.97

#### M=75 samples

#### Seed 0

In [15]:
# SVHN, M=75, seed 0: rebuild the WRN-16-1 student, load its weights, report test accuracy (%).
seed, M = "0", "75"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10,
    input_features=3, output_features=16,
    strides=strides,
).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
net.load_state_dict(torch.load(
    './Checkpoint_Extra_M_samples_Reproducibility_Zero_Shot_Teacher-WideResNet-40-2-Student-WideResNet-16-1-SVHN-M-' + M + '-Zero-Shot-seed-' + seed + '.pth',
    map_location=device,
))
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

94.27

#### Seed 1

In [16]:
# SVHN, M=75, seed 1: rebuild the WRN-16-1 student, load its weights, report test accuracy (%).
seed, M = "1", "75"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10,
    input_features=3, output_features=16,
    strides=strides,
).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
net.load_state_dict(torch.load(
    './Checkpoint_Extra_M_samples_Reproducibility_Zero_Shot_Teacher-WideResNet-40-2-Student-WideResNet-16-1-SVHN-M-' + M + '-Zero-Shot-seed-' + seed + '.pth',
    map_location=device,
))
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

93.95

#### Seed 2 — skipped: fine-tuning did not improve on the initial model

#### M=100 samples

#### Seed 0

In [17]:
# SVHN, M=100, seed 0: rebuild the WRN-16-1 student, load its weights, report test accuracy (%).
seed, M = "0", "100"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10,
    input_features=3, output_features=16,
    strides=strides,
).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
net.load_state_dict(torch.load(
    './Checkpoint_Extra_M_samples_Reproducibility_Zero_Shot_Teacher-WideResNet-40-2-Student-WideResNet-16-1-SVHN-M-' + M + '-Zero-Shot-seed-' + seed + '.pth',
    map_location=device,
))
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

94.24

#### Seed 1

In [18]:
# SVHN, M=100, seed 1: rebuild the WRN-16-1 student, load its weights, report test accuracy (%).
seed, M = "1", "100"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10,
    input_features=3, output_features=16,
    strides=strides,
).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
net.load_state_dict(torch.load(
    './Checkpoint_Extra_M_samples_Reproducibility_Zero_Shot_Teacher-WideResNet-40-2-Student-WideResNet-16-1-SVHN-M-' + M + '-Zero-Shot-seed-' + seed + '.pth',
    map_location=device,
))
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

93.97

#### Seed 2 — skipped: fine-tuning did not improve on the initial model

# CIFAR10

In [19]:
# Per-channel normalisation for CIFAR-10 test images.
# NOTE(review): keep these exact mean/std values. Presumably they match the
# preprocessing used when the checkpoints were trained, and changing them would
# shift the reported accuracies.
cifar_normalize = transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
cifar_transform = transforms.Compose([transforms.ToTensor(), cifar_normalize])

cifar_testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=cifar_transform,
)
cifar_testloader = torch.utils.data.DataLoader(
    cifar_testset,
    batch_size=100,
    shuffle=False,
    num_workers=0,
)

0it [00:00, ?it/s]

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


 99%|█████████▉| 168894464/170498071 [00:12<00:00, 17695885.57it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data


#### Zero-Shot-M: teacher WRN-40-2, student WRN-16-1, evaluated for several numbers M of fine-tuning samples

#### M=10 samples

#### Seed 0

In [20]:
# CIFAR-10, M=10, seed 0: rebuild the WRN-16-1 student, load its weights, report test accuracy (%).
seed, M = "0", "10"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10,
    input_features=3, output_features=16,
    strides=strides,
).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
net.load_state_dict(torch.load(
    './Checkpoint_Extra_M_samples_Reproducibility_Zero_Shot_Teacher-WideResNet-40-2-Student-WideResNet-16-1-CIFAR10-M-' + M + '-Zero-Shot-seed-' + seed + '.pth',
    map_location=device,
))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

83.89

#### Seed 1

In [21]:
# CIFAR-10, M=10, seed 1: rebuild the WRN-16-1 student, load its weights, report test accuracy (%).
seed, M = "1", "10"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10,
    input_features=3, output_features=16,
    strides=strides,
).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
net.load_state_dict(torch.load(
    './Checkpoint_Extra_M_samples_Reproducibility_Zero_Shot_Teacher-WideResNet-40-2-Student-WideResNet-16-1-CIFAR10-M-' + M + '-Zero-Shot-seed-' + seed + '.pth',
    map_location=device,
))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

83.37

#### Seed 2

In [22]:
# CIFAR-10, M=10, seed 2: rebuild the WRN-16-1 student, load its weights, report test accuracy (%).
seed, M = "2", "10"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10,
    input_features=3, output_features=16,
    strides=strides,
).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
net.load_state_dict(torch.load(
    './Checkpoint_Extra_M_samples_Reproducibility_Zero_Shot_Teacher-WideResNet-40-2-Student-WideResNet-16-1-CIFAR10-M-' + M + '-Zero-Shot-seed-' + seed + '.pth',
    map_location=device,
))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

83.77

#### M=25 samples

#### Seed 0

In [23]:
# CIFAR-10, M=25, seed 0: rebuild the WRN-16-1 student, load its weights, report test accuracy (%).
seed, M = "0", "25"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10,
    input_features=3, output_features=16,
    strides=strides,
).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
net.load_state_dict(torch.load(
    './Checkpoint_Extra_M_samples_Reproducibility_Zero_Shot_Teacher-WideResNet-40-2-Student-WideResNet-16-1-CIFAR10-M-' + M + '-Zero-Shot-seed-' + seed + '.pth',
    map_location=device,
))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

84.08

#### Seed 1

In [24]:
# CIFAR-10, M=25, seed 1: rebuild the WRN-16-1 student, load its weights, report test accuracy (%).
seed, M = "1", "25"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10,
    input_features=3, output_features=16,
    strides=strides,
).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
net.load_state_dict(torch.load(
    './Checkpoint_Extra_M_samples_Reproducibility_Zero_Shot_Teacher-WideResNet-40-2-Student-WideResNet-16-1-CIFAR10-M-' + M + '-Zero-Shot-seed-' + seed + '.pth',
    map_location=device,
))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

83.57

#### Seed 2

In [25]:
# CIFAR-10, M=25, seed 2: rebuild the WRN-16-1 student, load its weights, report test accuracy (%).
seed, M = "2", "25"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10,
    input_features=3, output_features=16,
    strides=strides,
).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
net.load_state_dict(torch.load(
    './Checkpoint_Extra_M_samples_Reproducibility_Zero_Shot_Teacher-WideResNet-40-2-Student-WideResNet-16-1-CIFAR10-M-' + M + '-Zero-Shot-seed-' + seed + '.pth',
    map_location=device,
))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

84.22

#### M=50 samples

#### Seed 0

In [26]:
# CIFAR-10, M=50, seed 0: rebuild the WRN-16-1 student, load its weights, report test accuracy (%).
seed, M = "0", "50"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10,
    input_features=3, output_features=16,
    strides=strides,
).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
net.load_state_dict(torch.load(
    './Checkpoint_Extra_M_samples_Reproducibility_Zero_Shot_Teacher-WideResNet-40-2-Student-WideResNet-16-1-CIFAR10-M-' + M + '-Zero-Shot-seed-' + seed + '.pth',
    map_location=device,
))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

84.69

#### Seed 1

In [27]:
# CIFAR-10, M=50, seed 1: rebuild the WRN-16-1 student, load its weights, report test accuracy (%).
seed, M = "1", "50"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10,
    input_features=3, output_features=16,
    strides=strides,
).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
net.load_state_dict(torch.load(
    './Checkpoint_Extra_M_samples_Reproducibility_Zero_Shot_Teacher-WideResNet-40-2-Student-WideResNet-16-1-CIFAR10-M-' + M + '-Zero-Shot-seed-' + seed + '.pth',
    map_location=device,
))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

84.37

#### Seed 2

In [28]:
# CIFAR-10, M=50, seed 2: rebuild the WRN-16-1 student, load its weights, report test accuracy (%).
seed, M = "2", "50"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10,
    input_features=3, output_features=16,
    strides=strides,
).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
net.load_state_dict(torch.load(
    './Checkpoint_Extra_M_samples_Reproducibility_Zero_Shot_Teacher-WideResNet-40-2-Student-WideResNet-16-1-CIFAR10-M-' + M + '-Zero-Shot-seed-' + seed + '.pth',
    map_location=device,
))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

84.94

#### M=75 samples

#### Seed 0

In [29]:
# CIFAR-10, M=75, seed 0: rebuild the WRN-16-1 student, load its weights, report test accuracy (%).
seed, M = "0", "75"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10,
    input_features=3, output_features=16,
    strides=strides,
).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
net.load_state_dict(torch.load(
    './Checkpoint_Extra_M_samples_Reproducibility_Zero_Shot_Teacher-WideResNet-40-2-Student-WideResNet-16-1-CIFAR10-M-' + M + '-Zero-Shot-seed-' + seed + '.pth',
    map_location=device,
))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

170500096it [00:30, 17695885.57it/s]                               

84.98

#### Seed 1

In [30]:
# CIFAR-10, M=75, seed 1: rebuild the WRN-16-1 student, load its weights, report test accuracy (%).
seed, M = "1", "75"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10,
    input_features=3, output_features=16,
    strides=strides,
).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
net.load_state_dict(torch.load(
    './Checkpoint_Extra_M_samples_Reproducibility_Zero_Shot_Teacher-WideResNet-40-2-Student-WideResNet-16-1-CIFAR10-M-' + M + '-Zero-Shot-seed-' + seed + '.pth',
    map_location=device,
))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

84.53

#### Seed 2

In [31]:
# CIFAR-10, M=75, seed 2: rebuild the WRN-16-1 student, load its weights, report test accuracy (%).
seed, M = "2", "75"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10,
    input_features=3, output_features=16,
    strides=strides,
).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
net.load_state_dict(torch.load(
    './Checkpoint_Extra_M_samples_Reproducibility_Zero_Shot_Teacher-WideResNet-40-2-Student-WideResNet-16-1-CIFAR10-M-' + M + '-Zero-Shot-seed-' + seed + '.pth',
    map_location=device,
))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

85.0

#### M=100 samples

#### Seed 0

In [32]:
# CIFAR-10, M=100, seed 0: rebuild the WRN-16-1 student, load its weights, report test accuracy (%).
seed, M = "0", "100"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10,
    input_features=3, output_features=16,
    strides=strides,
).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
net.load_state_dict(torch.load(
    './Checkpoint_Extra_M_samples_Reproducibility_Zero_Shot_Teacher-WideResNet-40-2-Student-WideResNet-16-1-CIFAR10-M-' + M + '-Zero-Shot-seed-' + seed + '.pth',
    map_location=device,
))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

85.27

#### Seed 1

In [33]:
# CIFAR-10, M=100, seed 1: rebuild the WRN-16-1 student, load its weights, report test accuracy (%).
seed, M = "1", "100"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10,
    input_features=3, output_features=16,
    strides=strides,
).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
net.load_state_dict(torch.load(
    './Checkpoint_Extra_M_samples_Reproducibility_Zero_Shot_Teacher-WideResNet-40-2-Student-WideResNet-16-1-CIFAR10-M-' + M + '-Zero-Shot-seed-' + seed + '.pth',
    map_location=device,
))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

84.73

#### Seed 2

In [34]:
# CIFAR-10, M=100, seed 2: rebuild the WRN-16-1 student, load its weights, report test accuracy (%).
seed, M = "2", "100"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10,
    input_features=3, output_features=16,
    strides=strides,
).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
net.load_state_dict(torch.load(
    './Checkpoint_Extra_M_samples_Reproducibility_Zero_Shot_Teacher-WideResNet-40-2-Student-WideResNet-16-1-CIFAR10-M-' + M + '-Zero-Shot-seed-' + seed + '.pth',
    map_location=device,
))
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

85.35