#### Get KD-ATT Checkpoints from Dropbox

In [1]:
# Download the shared Dropbox folder into a local file named KD_ATT, unzip it
# into the working directory, then delete the downloaded archive.
# NOTE(review): --no-check-certificate disables TLS certificate verification for this download.
# Presumably the archive provides the ./SVHN and ./CIFAR10 checkpoint folders
# that the evaluation cells load from — verify against the archive contents.
!wget -O KD_ATT --no-check-certificate https://www.dropbox.com/sh/qhrb5l4btiqdwit/AAC8cyHArPTkiP_Fn9sADGCia
!unzip KD_ATT
!rm KD_ATT

--2019-11-06 19:04:56--  https://www.dropbox.com/sh/qhrb5l4btiqdwit/AAC8cyHArPTkiP_Fn9sADGCia
Resolving www.dropbox.com (www.dropbox.com)... 162.125.8.1, 2620:100:601b:1::a27d:801
Connecting to www.dropbox.com (www.dropbox.com)|162.125.8.1|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: /sh/raw/qhrb5l4btiqdwit/AAC8cyHArPTkiP_Fn9sADGCia [following]
--2019-11-06 19:04:56--  https://www.dropbox.com/sh/raw/qhrb5l4btiqdwit/AAC8cyHArPTkiP_Fn9sADGCia
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://uc4878ce657fc7d7df0519721506.dl.dropboxusercontent.com/zip_by_token_key?key=Ar0SKOCnN6hJy1RF-_AsW7FBcM5koiDEc4IIv4jKxOsmaUh39XIu0fGfUfj6MIJ2a6_yfQbkVFlPdiL5DW3bIoEFFLLvtrIq79MD-ZzAmusiLLpD3wN1wi3JXjKILVWLOrlwoAxHa5g0Xo1jDXzJWEthPfKYp4vpyyfsvGH9_uMxi0LJdn6hJ5S-PF8cVteRj2SgoDMF1oAxinHrnUFf_x01 [following]
--2019-11-06 19:04:56--  https://uc4878ce657fc7d7df0519721506.dl.dropboxusercontent.

In [0]:
import torch
import torchvision
from torch import nn as nn
from torchvision import transforms
# Prefer the first CUDA GPU; fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing at the first `.to(device)` call.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [0]:
class IndividualBlock1(nn.Module):
    """First residual block of a group (pre-activation: BN -> ReLU -> conv).

    Main path: BN -> ReLU -> 3x3 conv (``stride``) -> BN -> ReLU -> 3x3 conv
    (stride 1). When ``subsample_input`` or ``increase_filters`` is set, the
    shortcut is a 1x1 convolution (``conv_inp``) applied to the activated
    input; otherwise the shortcut is the unmodified input.

    Args:
        input_features: number of input channels.
        output_features: number of output channels.
        stride: stride of the first 3x3 convolution (and of the 1x1 shortcut
            convolution, when one is created).
        subsample_input: if True, a 1x1 shortcut convolution is created.
        increase_filters: if True, a 1x1 shortcut convolution is created.

    Note:
        With both flags False the shortcut is the identity, so the caller must
        pass ``input_features == output_features`` and ``stride == 1`` for the
        residual addition to have matching shapes.
    """

    def __init__(self, input_features, output_features, stride, subsample_input=True, increase_filters=True):
        super(IndividualBlock1, self).__init__()

        self.activation = nn.ReLU(inplace=True)

        self.batch_norm1 = nn.BatchNorm2d(input_features)
        self.batch_norm2 = nn.BatchNorm2d(output_features)

        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1, bias=False)

        self.subsample_input = subsample_input
        self.increase_filters = increase_filters
        if subsample_input or increase_filters:
            # The shortcut must shrink the feature map exactly as conv1 does,
            # otherwise the residual addition fails. Previously the stride was
            # hard-coded (2 when subsampling, 1 otherwise); using `stride`
            # gives the same layer for those cases and also supports the rest.
            self.conv_inp = nn.Conv2d(input_features, output_features, kernel_size=1, stride=stride, padding=0, bias=False)

    def forward(self, x):
        # Pre-activation shared by the main path and the projection shortcut.
        activated = self.activation(self.batch_norm1(x))

        x1 = self.conv1(activated)
        x1 = self.activation(self.batch_norm2(x1))
        x1 = self.conv2(x1)

        if self.subsample_input or self.increase_filters:
            # Projection shortcut operates on the activated input.
            return self.conv_inp(activated) + x1
        # Identity shortcut uses the raw (un-normalised) input.
        return x + x1


class IndividualBlockN(nn.Module):
    """Pre-activation residual block with an identity shortcut.

    Applies BN -> ReLU -> 3x3 conv twice and adds the block input to the
    result. Both convolutions use the given ``stride``; because the shortcut
    is the identity, the addition only has matching shapes when the main path
    preserves the input shape (the visible caller always passes ``stride=1``
    and equal input/output widths).

    Args:
        input_features: number of input channels.
        output_features: number of output channels.
        stride: stride used by both 3x3 convolutions.
    """

    def __init__(self, input_features, output_features, stride):
        super().__init__()

        self.activation = nn.ReLU(inplace=True)

        self.batch_norm1 = nn.BatchNorm2d(input_features)
        self.batch_norm2 = nn.BatchNorm2d(output_features)

        # The two convolutions differ only in their input width.
        conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv1 = nn.Conv2d(input_features, output_features, **conv_kwargs)
        self.conv2 = nn.Conv2d(output_features, output_features, **conv_kwargs)

    def forward(self, x):
        residual = self.conv1(self.activation(self.batch_norm1(x)))
        residual = self.conv2(self.activation(self.batch_norm2(residual)))
        return residual + x


class Nblock(nn.Module):
    """A group of ``N`` residual blocks applied in sequence.

    The first block is an ``IndividualBlock1`` built from the given channel,
    stride and shortcut settings; every following block is an
    ``IndividualBlockN`` with ``output_features`` channels and stride 1.

    Args:
        N: number of blocks in the group.
        input_features: channels entering the first block.
        output_features: channels produced by every block.
        stride: stride handed to the first block.
        subsample_input: forwarded to the first block.
        increase_filters: forwarded to the first block.
    """

    def __init__(self, N, input_features, output_features, stride, subsample_input=True, increase_filters=True):
        super().__init__()

        blocks = [
            IndividualBlock1(input_features, output_features, stride, subsample_input, increase_filters)
            if index == 0
            else IndividualBlockN(output_features, output_features, stride=1)
            for index in range(N)
        ]

        self.nblockLayer = nn.Sequential(*blocks)

    def forward(self, x):
        return self.nblockLayer(x)


class WideResNet(nn.Module):
    """Wide residual network (depth ``d``, width ``k``) returning logits and group outputs.

    Layout: 3x3 stem convolution -> three ``Nblock`` groups with ``16*k``,
    ``32*k`` and ``64*k`` channels, each holding ``(d - 4) // 6`` blocks ->
    BN -> ReLU -> 8x8 average pooling -> linear classifier.

    Args:
        d: network depth; each group holds ``(d - 4) // 6`` blocks.
        k: widening factor applied to the base widths 16/32/64.
        n_classes: size of the classifier output.
        input_features: channels of the input image.
        output_features: channels produced by the stem convolution.
        strides: four strides: stem convolution, then the first block of
            group 1, group 2 and group 3.

    ``forward`` returns ``(logits, attention1, attention2, attention3)``, where
    the last three are the raw outputs of the three groups.
    """

    def __init__(self, d, k, n_classes, input_features, output_features, strides):
        super(WideResNet, self).__init__()

        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=strides[0], padding=1, bias=False)

        filters = [16 * k, 32 * k, 64 * k]
        self.out_filters = filters[-1]
        N = (d - 4) // 6
        # With k == 1 the first group uses an identity shortcut, so the stem
        # width (output_features) must then equal filters[0].
        increase_filters = k > 1
        self.block1 = Nblock(N, input_features=output_features, output_features=filters[0], stride=strides[1], subsample_input=False, increase_filters=increase_filters)
        self.block2 = Nblock(N, input_features=filters[0], output_features=filters[1], stride=strides[2])
        self.block3 = Nblock(N, input_features=filters[1], output_features=filters[2], stride=strides[3])

        self.batch_norm = nn.BatchNorm2d(filters[-1])
        self.activation = nn.ReLU(inplace=True)
        # Fixed 8x8 pooling window: the pooled map is 1x1 only when the last
        # group outputs 8x8 maps (e.g. 32x32 inputs with strides [1, 1, 2, 2]).
        self.avg_pool = nn.AvgPool2d(kernel_size=8)
        self.fc = nn.Linear(filters[-1], n_classes)

    def forward(self, x):

        x = self.conv1(x)
        attention1 = self.block1(x)
        attention2 = self.block2(attention1)
        attention3 = self.block3(attention2)
        out = self.batch_norm(attention3)
        out = self.activation(out)
        out = self.avg_pool(out)
        # Flatten per sample. The previous `view(-1, out_filters)` silently
        # folded any leftover spatial positions into the batch dimension when
        # the pooled map was larger than 1x1; now such inputs fail at `fc`
        # with a shape error. For a 1x1 pooled map the result is identical.
        out = torch.flatten(out, 1)

        return self.fc(out), attention1, attention2, attention3

In [0]:
def _test_set_eval(net, device, test_loader):
    with torch.no_grad():
        correct, total = 0, 0
        net.eval()

        for data in test_loader:
            images, labels = data
            images = images.to(device)
            labels = labels.to(device)

            outputs = net(images)[0]
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        accuracy = correct / total
        accuracy = round(100 * accuracy, 2)

        return accuracy

# SVHN

In [5]:
# SVHN test split, converted to tensors and normalised with the per-channel
# mean/std constants below; served in fixed order in batches of 100.
svhn_normalize = transforms.Normalize(
    mean=(0.4377, 0.4438, 0.4728),
    std=(0.1980, 0.2010, 0.1970),
)
svhn_transform = transforms.Compose([transforms.ToTensor(), svhn_normalize])

svhn_test_set = torchvision.datasets.SVHN(
    root='./data', split='test', download=True, transform=svhn_transform
)
# `testloader` is bound to the same loader object as `svhn_test_loader`.
svhn_test_loader = testloader = torch.utils.data.DataLoader(
    svhn_test_set, batch_size=100, shuffle=False, num_workers=0
)

  0%|          | 16384/64275384 [00:00<06:54, 154942.15it/s]

Downloading http://ufldl.stanford.edu/housenumbers/test_32x32.mat to ./data/test_32x32.mat


64282624it [00:01, 55862559.52it/s]                             


#### KD-ATT - Teacher WRN-40-2 Student WRN-16-1 for different values of M

#### KD-ATT for M=10 samples

#### Seed 0

In [6]:
# SVHN test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=10, seed 0.
seed, M = "0", "10"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
# map_location remaps the stored tensors onto `device`, as the other evaluation
# cells do, so loading does not depend on the device the checkpoint was saved from.
net.load_state_dict(
    torch.load(
        './SVHN/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-SVHN-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

37.35

#### Seed 1

In [7]:
# SVHN test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=10, seed 1.
seed, M = "1", "10"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './SVHN/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-SVHN-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

31.32

#### Seed 2

In [8]:
# SVHN test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=10, seed 2.
seed, M = "2", "10"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './SVHN/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-SVHN-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

33.88

#### KD-ATT for M=25 samples

#### Seed 0

In [9]:
# SVHN test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=25, seed 0.
seed, M = "0", "25"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
# map_location remaps the stored tensors onto `device`, as the other evaluation
# cells do, so loading does not depend on the device the checkpoint was saved from.
net.load_state_dict(
    torch.load(
        './SVHN/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-SVHN-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

48.71

#### Seed 1

In [10]:
# SVHN test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=25, seed 1.
seed, M = "1", "25"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
# map_location remaps the stored tensors onto `device`, as the other evaluation
# cells do, so loading does not depend on the device the checkpoint was saved from.
net.load_state_dict(
    torch.load(
        './SVHN/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-SVHN-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

48.89

#### Seed 2

In [11]:
# SVHN test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=25, seed 2.
seed, M = "2", "25"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
# map_location remaps the stored tensors onto `device`, as the other evaluation
# cells do, so loading does not depend on the device the checkpoint was saved from.
net.load_state_dict(
    torch.load(
        './SVHN/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-SVHN-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

47.44

#### KD-ATT for M=50 samples

#### Seed 0

In [12]:
# SVHN test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=50, seed 0.
seed, M = "0", "50"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
# map_location remaps the stored tensors onto `device`, as the other evaluation
# cells do, so loading does not depend on the device the checkpoint was saved from.
net.load_state_dict(
    torch.load(
        './SVHN/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-SVHN-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

68.84

#### Seed 1

In [13]:
# SVHN test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=50, seed 1.
seed, M = "1", "50"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
# map_location remaps the stored tensors onto `device`, as the other evaluation
# cells do, so loading does not depend on the device the checkpoint was saved from.
net.load_state_dict(
    torch.load(
        './SVHN/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-SVHN-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

65.33

#### Seed 2

In [14]:
# SVHN test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=50, seed 2.
seed, M = "2", "50"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
# map_location remaps the stored tensors onto `device`, as the other evaluation
# cells do, so loading does not depend on the device the checkpoint was saved from.
net.load_state_dict(
    torch.load(
        './SVHN/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-SVHN-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

66.48

#### KD-ATT for M=75 samples

#### Seed 0

In [15]:
# SVHN test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=75, seed 0.
seed, M = "0", "75"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
# map_location remaps the stored tensors onto `device`, as the other evaluation
# cells do, so loading does not depend on the device the checkpoint was saved from.
net.load_state_dict(
    torch.load(
        './SVHN/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-SVHN-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

78.51

#### Seed 1

In [16]:
# SVHN test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=75, seed 1.
seed, M = "1", "75"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
# map_location remaps the stored tensors onto `device`, as the other evaluation
# cells do, so loading does not depend on the device the checkpoint was saved from.
net.load_state_dict(
    torch.load(
        './SVHN/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-SVHN-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

78.4

#### Seed 2

In [17]:
# SVHN test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=75, seed 2.
seed, M = "2", "75"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
# map_location remaps the stored tensors onto `device`, as the other evaluation
# cells do, so loading does not depend on the device the checkpoint was saved from.
net.load_state_dict(
    torch.load(
        './SVHN/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-SVHN-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

79.28

#### KD-ATT for M=100 samples

#### Seed 0

In [18]:
# SVHN test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=100, seed 0.
seed, M = "0", "100"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
# map_location remaps the stored tensors onto `device`, as the other evaluation
# cells do, so loading does not depend on the device the checkpoint was saved from.
net.load_state_dict(
    torch.load(
        './SVHN/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-SVHN-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

81.18

#### Seed 1

In [19]:
# SVHN test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=100, seed 1.
seed, M = "1", "100"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
# map_location remaps the stored tensors onto `device`, as the other evaluation
# cells do, so loading does not depend on the device the checkpoint was saved from.
net.load_state_dict(
    torch.load(
        './SVHN/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-SVHN-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

79.63

#### Seed 2

In [20]:
# SVHN test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=100, seed 2.
seed, M = "2", "100"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
# map_location remaps the stored tensors onto `device`, as the other evaluation
# cells do, so loading does not depend on the device the checkpoint was saved from.
net.load_state_dict(
    torch.load(
        './SVHN/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-SVHN-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

81.45

#### KD-ATT for M=5000 samples

#### Seed 0

In [21]:
# SVHN test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=5000, seed 0.
seed, M = "0", "5000"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './SVHN/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-SVHN-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

95.19

#### Seed 1

In [22]:
# SVHN test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=5000, seed 1.
seed, M = "1", "5000"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './SVHN/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-SVHN-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

95.44

#### Seed 2

In [23]:
# SVHN test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=5000, seed 2.
seed, M = "2", "5000"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './SVHN/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-SVHN-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, svhn_test_loader)
accuracy

95.72

# CIFAR10

In [24]:
# CIFAR10 test split, converted to tensors and normalised with the per-channel
# mean/std constants below; served in fixed order in batches of 100.
cifar_normalize = transforms.Normalize(
    mean=(0.4914, 0.4822, 0.4465),
    std=(0.2023, 0.1994, 0.2010),
)
cifar_transform = transforms.Compose([transforms.ToTensor(), cifar_normalize])

cifar_testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=cifar_transform
)
cifar_testloader = torch.utils.data.DataLoader(
    cifar_testset, batch_size=100, shuffle=False, num_workers=0
)

0it [00:00, ?it/s]

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


170500096it [00:03, 44170661.31it/s]                               


Extracting ./data/cifar-10-python.tar.gz to ./data


#### KD-ATT - Teacher WRN-40-2 Student WRN-16-1 for different values of M

#### KD-ATT for M=10 samples

#### Seed 0

In [25]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=10, seed 0.
seed, M = "0", "10"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

39.08

#### Seed 1

In [26]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=10, seed 1.
seed, M = "1", "10"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

35.33

#### Seed 2

In [27]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=10, seed 2.
seed, M = "2", "10"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

36.49

#### KD-ATT for M=25 samples

#### Seed 0

In [28]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=25, seed 0.
seed, M = "0", "25"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

60.05

#### Seed 1

In [29]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=25, seed 1.
seed, M = "1", "25"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

58.94

#### Seed 2

In [30]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=25, seed 2.
seed, M = "2", "25"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

63.05

#### KD-ATT for M=50 samples

#### Seed 0

In [31]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=50, seed 0.
seed, M = "0", "50"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

70.9

#### Seed 1

In [32]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=50, seed 1.
seed, M = "1", "50"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

65.83

#### Seed 2

In [33]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=50, seed 2.
seed, M = "2", "50"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

68.68

#### KD-ATT for M=75 samples

#### Seed 0

In [34]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=75, seed 0.
seed, M = "0", "75"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

73.84

#### Seed 1

In [35]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=75, seed 1.
seed, M = "1", "75"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

74.29

#### Seed 2

In [36]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=75, seed 2.
seed, M = "2", "75"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

77.0

#### KD-ATT for M=100 samples

#### Seed 0

In [37]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=100, seed 0.
seed, M = "0", "100"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

76.67

#### Seed 1

In [38]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=100, seed 1.
seed, M = "1", "100"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

76.72

#### Seed 2

In [39]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=100, seed 2.
seed, M = "2", "100"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

79.57

#### KD-ATT for M=200 samples

#### Seed 0

In [40]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=200, seed 0.
seed, M = "0", "200"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

82.41

#### Seed 1

In [41]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=200, seed 1.
seed, M = "1", "200"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

81.97

#### Seed 2

In [42]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=200, seed 2.
seed, M = "2", "200"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

84.18

#### KD-ATT for M=5000 samples

#### Seed 0

In [43]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=5000, seed 0.
seed, M = "0", "5000"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

92.15

#### Seed 1

In [44]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=5000, seed 1.
seed, M = "1", "5000"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

92.25

#### Seed 2

In [45]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-1 (teacher WRN-40-2), M=5000, seed 2.
seed, M = "2", "5000"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/WRN-40-2-WRN-16-1/M-' + M + '/Seed-' + seed
        + '/kd_att_teacher_wrn-40-2_student_wrn-16-1-M-' + M + '-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

92.17

#### KD-ATT - Teacher WRN-16-2 Student WRN-16-1 M = 200

#### Seed 0

In [46]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-1 with teacher WRN-16-2 (M=200), seed 0.
seed = "0"
Teacher, Student = "WRN-16-2", "WRN-16-1"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/' + Teacher + '-' + Student + '-M200/Seed-' + seed
        + '/kd_att_teacher_' + Teacher.lower() + '_student_' + Student.lower()
        + '-M-200-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

85.51

#### Seed 1

In [47]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-1 with teacher WRN-16-2 (M=200), seed 1.
seed = "1"
Teacher, Student = "WRN-16-2", "WRN-16-1"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/' + Teacher + '-' + Student + '-M200/Seed-' + seed
        + '/kd_att_teacher_' + Teacher.lower() + '_student_' + Student.lower()
        + '-M-200-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

85.26

#### Seed 2

In [48]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-1 with teacher WRN-16-2 (M=200), seed 2.
seed = "2"
Teacher, Student = "WRN-16-2", "WRN-16-1"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/' + Teacher + '-' + Student + '-M200/Seed-' + seed
        + '/kd_att_teacher_' + Teacher.lower() + '_student_' + Student.lower()
        + '-M-200-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

85.89

#### KD-ATT - Teacher WRN-40-1 Student WRN-16-1 M = 200

#### Seed 0

In [49]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-1 with teacher WRN-40-1 (M=200), seed 0.
seed = "0"
Teacher, Student = "WRN-40-1", "WRN-16-1"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/' + Teacher + '-' + Student + '-M200/Seed-' + seed
        + '/kd_att_teacher_' + Teacher.lower() + '_student_' + Student.lower()
        + '-M-200-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

83.9

#### Seed 1

In [50]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-1 with teacher WRN-40-1 (M=200), seed 1.
seed = "1"
Teacher, Student = "WRN-40-1", "WRN-16-1"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/' + Teacher + '-' + Student + '-M200/Seed-' + seed
        + '/kd_att_teacher_' + Teacher.lower() + '_student_' + Student.lower()
        + '-M-200-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

83.35

#### Seed 2

In [51]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-1 with teacher WRN-40-1 (M=200), seed 2.
seed = "2"
Teacher, Student = "WRN-40-1", "WRN-16-1"
d, k = 16, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/' + Teacher + '-' + Student + '-M200/Seed-' + seed
        + '/kd_att_teacher_' + Teacher.lower() + '_student_' + Student.lower()
        + '-M-200-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

83.67

#### KD-ATT - Teacher WRN-40-1 Student WRN-16-2 M = 200

#### Seed 0

In [52]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-2 with teacher WRN-40-1 (M=200), seed 0.
seed = "0"
Teacher, Student = "WRN-40-1", "WRN-16-2"
d, k = 16, 2
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/' + Teacher + '-' + Student + '-M200/Seed-' + seed
        + '/kd_att_teacher_' + Teacher.lower() + '_student_' + Student.lower()
        + '-M-200-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

87.52

#### Seed 1

In [53]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-2 with teacher WRN-40-1 (M=200), seed 1.
seed = "1"
Teacher, Student = "WRN-40-1", "WRN-16-2"
d, k = 16, 2
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/' + Teacher + '-' + Student + '-M200/Seed-' + seed
        + '/kd_att_teacher_' + Teacher.lower() + '_student_' + Student.lower()
        + '-M-200-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

87.14

#### Seed 2

In [54]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-2 with teacher WRN-40-1 (M=200), seed 2.
seed = "2"
Teacher, Student = "WRN-40-1", "WRN-16-2"
d, k = 16, 2
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/' + Teacher + '-' + Student + '-M200/Seed-' + seed
        + '/kd_att_teacher_' + Teacher.lower() + '_student_' + Student.lower()
        + '-M-200-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

87.13

#### KD-ATT - Teacher WRN-40-2 Student WRN-16-2 M = 200

#### Seed 0

In [55]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-2 with teacher WRN-40-2 (M=200), seed 0.
seed = "0"
Teacher, Student = "WRN-40-2", "WRN-16-2"
d, k = 16, 2
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/' + Teacher + '-' + Student + '-M200/Seed-' + seed
        + '/kd_att_teacher_' + Teacher.lower() + '_student_' + Student.lower()
        + '-M-200-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

87.15

#### Seed 1

In [56]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-2 with teacher WRN-40-2 (M=200), seed 1.
seed = "1"
Teacher, Student = "WRN-40-2", "WRN-16-2"
d, k = 16, 2
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/' + Teacher + '-' + Student + '-M200/Seed-' + seed
        + '/kd_att_teacher_' + Teacher.lower() + '_student_' + Student.lower()
        + '-M-200-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

86.49

#### Seed 2

In [57]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-16-2 with teacher WRN-40-2 (M=200), seed 2.
seed = "2"
Teacher, Student = "WRN-40-2", "WRN-16-2"
d, k = 16, 2
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/' + Teacher + '-' + Student + '-M200/Seed-' + seed
        + '/kd_att_teacher_' + Teacher.lower() + '_student_' + Student.lower()
        + '-M-200-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

88.18

#### KD-ATT - Teacher WRN-40-2 Student WRN-40-1 M = 200

#### Seed 0

In [58]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-40-1 with teacher WRN-40-2 (M=200), seed 0.
seed = "0"
Teacher, Student = "WRN-40-2", "WRN-40-1"
d, k = 40, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/' + Teacher + '-' + Student + '-M200/Seed-' + seed
        + '/kd_att_teacher_' + Teacher.lower() + '_student_' + Student.lower()
        + '-M-200-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

88.18

#### Seed 1

In [59]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-40-1 with teacher WRN-40-2 (M=200), seed 1.
seed = "1"
Teacher, Student = "WRN-40-2", "WRN-40-1"
d, k = 40, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/' + Teacher + '-' + Student + '-M200/Seed-' + seed
        + '/kd_att_teacher_' + Teacher.lower() + '_student_' + Student.lower()
        + '-M-200-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

87.77

#### Seed 2

In [60]:
# CIFAR10 test accuracy (%) of the KD-ATT student WRN-40-1 with teacher WRN-40-2 (M=200), seed 2.
seed = "2"
Teacher, Student = "WRN-40-2", "WRN-40-1"
d, k = 40, 1
strides = [1, 1, 2, 2]
net = WideResNet(
    d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides
).to(device)
net.load_state_dict(
    torch.load(
        './CIFAR10/' + Teacher + '-' + Student + '-M200/Seed-' + seed
        + '/kd_att_teacher_' + Teacher.lower() + '_student_' + Student.lower()
        + '-M-200-seed-' + seed + '-CIFAR10-dict.pth',
        map_location=device,
    )
)
accuracy = _test_set_eval(net, device, cifar_testloader)
accuracy

89.29