In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [2]:
class BasicBlockUnit( nn.Module ):
    """Residual unit: BN -> ReLU -> 3x3 conv, twice, added to a shortcut.

    The shortcut is the unmodified block input, or a 1x1 convolution of the
    first activation when ``convdim`` is true.

    Args:
        ni: number of input channels.
        no: number of output channels.
        stride: stride of the first 3x3 convolution and of the 1x1 shortcut.
        convdim: when true, create the 1x1 shortcut convolution.
        segment: divisor applied to the BatchNorm widths and to the
            convolution output widths.
        i: identifier stored as ``self.id``; not used by this class.
        rd_buf: stored as ``self.rd_buf``; not used by this class.

    NOTE(review): the convolutions still take ``ni`` / ``no`` input channels
    while the BatchNorm layers and convolution outputs are divided by
    ``segment``, so the layer shapes only line up when ``segment == 1`` --
    confirm the intent before using another value.
    """

    def __init__( self, ni, no, stride=1, convdim=False, segment=1, i=None, rd_buf=None ):
        super( BasicBlockUnit, self ).__init__()
        self.segment = segment
        self.ni      = ni
        self.no      = no
        self.id    = i
        self.bn0   = nn.BatchNorm2d(int(ni/self.segment))
        self.relu0 = nn.ReLU(inplace=True)
        self.conv0 = nn.Conv2d( ni, int(no/self.segment), 3, padding=1, stride=stride, bias=False )
        self.bn1   = nn.BatchNorm2d(int(no/self.segment))
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d( no, int(no/self.segment), 3, stride=1, padding=1, bias=False )
        self.convdim = nn.Conv2d(ni,int(no/self.segment),1,stride=stride, bias=False) if convdim else None
        self.rd_buf = rd_buf

    def forward( self, x ):
        """Return ``conv1(relu(bn1(conv0(relu(bn0(x)))))) + shortcut``."""
        res = x
        x = self.bn0  (x)
        x = self.relu0(x)

        x1 = self.conv0(x)
        x1 = self.bn1  (x1)
        x1 = self.relu1(x1)

        x2 = self.conv1(x1)
        if self.convdim is not None:
            # The projection shortcut consumes the activated input, not the raw one.
            res = self.convdim(x)
        return x2 + res

    @staticmethod
    def _checked_parameter( current, tensor ):
        """Wrap ``tensor`` as a Parameter after checking it matches ``current``'s shape."""
        assert( current.shape == tensor.shape )
        return nn.Parameter(tensor)

    def set_param( self, param, prefix ):
        """Load learnable weights from ``param``.

        Keys read: ``prefix + '.convdim'`` (only when the shortcut convolution
        exists), ``'.bn0.weight'``, ``'.bn0.bias'``, ``'.conv0'``,
        ``'.bn1.weight'``, ``'.bn1.bias'`` and ``'.conv1'``.

        Raises:
            AssertionError: if a tensor's shape differs from the layer it replaces.
            KeyError: if a key is missing from ``param``.
        """
        if self.convdim is not None:
            self.convdim.weight = self._checked_parameter( self.convdim.weight, param[ prefix+'.convdim' ] )
        self.bn0.weight   = self._checked_parameter( self.bn0.weight,   param[ prefix+'.bn0.weight' ] )
        self.bn0.bias     = self._checked_parameter( self.bn0.bias,     param[ prefix+'.bn0.bias' ] )
        self.conv0.weight = self._checked_parameter( self.conv0.weight, param[ prefix+'.conv0' ] )
        self.bn1.weight   = self._checked_parameter( self.bn1.weight,   param[ prefix+'.bn1.weight' ] )
        self.bn1.bias     = self._checked_parameter( self.bn1.bias,     param[ prefix+'.bn1.bias' ] )
        self.conv1.weight = self._checked_parameter( self.conv1.weight, param[ prefix+'.conv1' ] )

    def set_stats( self, stats, prefix ):
        """Load the BatchNorm running statistics from ``stats``.

        Keys read: ``prefix + '.bn{0,1}.running_{mean,var}'``.

        The statistics are stored as plain tensors so they remain registered
        buffers: they must not appear in ``parameters()`` or receive gradients.
        Each tensor is copied, so the checkpoint dict is never aliased.

        Raises:
            AssertionError: if a tensor's shape differs from the existing buffer.
            KeyError: if a key is missing from ``stats``.
        """
        for bn, bn_name in ( (self.bn0, '.bn0'), (self.bn1, '.bn1') ):
            for stat_name in ( 'running_mean', 'running_var' ):
                new_value = stats[ prefix + bn_name + '.' + stat_name ]
                assert( getattr(bn, stat_name).shape == new_value.shape )
                # Assigning a plain tensor to an existing buffer name updates the buffer.
                setattr( bn, stat_name, new_value.detach().clone() )

    def set_requires_grad( self, val ):
        """Set ``requires_grad`` to ``val`` on every parameter of this unit."""
        for para in self.parameters():
            para.requires_grad = val


def AvgPool2d_in_conv( n, kernel_size ):
    """Build a bias-free ``nn.Conv2d`` that performs per-channel average pooling.

    Output channel ``i`` reads only input channel ``i``, and every tap of that
    kernel equals ``1 / (kernel_height * kernel_width)``, so the layer gives
    the same result as ``nn.AvgPool2d(kernel_size, stride=1, padding=0)``.
    (A kernel of ones would sum the window instead of averaging it.)

    Args:
        n: number of channels (input and output).
        kernel_size: pooling window, in any form ``nn.Conv2d`` accepts.

    Returns:
        The ``nn.Conv2d`` layer. Its weight is frozen (``requires_grad=False``)
        because it encodes a fixed pooling kernel.
    """
    avg = nn.Conv2d( n, n, kernel_size, bias=False)
    kernel_h, kernel_w = avg.weight.shape[-2:]
    avg_weight = torch.zeros_like(avg.weight)
    channel = torch.arange(n)
    # Only the (i, i) channel pairs are non-zero: channels are never mixed.
    avg_weight[channel, channel] = 1.0 / float(kernel_h * kernel_w)
    avg.weight = nn.Parameter(avg_weight, requires_grad=False)
    return avg

In [3]:
class Group(nn.Module):
    """A sequence of ``n`` ``BasicBlockUnit`` blocks applied one after another.

    The first block maps ``ni`` to ``no`` channels with the given ``stride``
    and a 1x1 shortcut convolution; the remaining ``n - 1`` blocks map ``no``
    to ``no`` channels with the block defaults.

    The blocks are held in an ``nn.ModuleList`` so they are registered
    sub-modules: they show up in ``parameters()`` / ``state_dict()`` and follow
    ``train()``, ``eval()`` and ``.to(device)`` of the parent.

    Args:
        ni: input channels of the group.
        no: output channels of the group.
        n: number of blocks.
        stride: stride of the first block.
        segment: forwarded to every block.
        id: forwarded to every block as its ``i`` argument.
    """

    def __init__( self, ni, no, n, stride=1, segment=1, id=None ):
        super( Group, self ).__init__()
        self.blocks = nn.ModuleList()
        self.blocks.append( BasicBlockUnit(ni,no,stride,convdim=True,segment=segment,i=id) )
        for _ in range(1,n):
            self.blocks.append( BasicBlockUnit(no,no,segment=segment,i=id) )

    def forward( self, x ):
        """Run ``x`` through every block in order."""
        for block in self.blocks:
            # Call the module rather than .forward() so registered hooks still run.
            x = block(x)
        return x

    def set_param( self, param, prefix ):
        """Load the weights of block ``k`` from keys under ``prefix + '.block' + str(k)``."""
        for k, block in enumerate(self.blocks):
            block.set_param( param, prefix + '.block' + str(k) )

    def set_stats( self, stats, prefix ):
        """Load the BatchNorm statistics of block ``k`` from keys under ``prefix + '.block' + str(k)``."""
        for k, block in enumerate(self.blocks):
            block.set_stats( stats, prefix + '.block' + str(k) )

    def eval( self ):
        """Put the group and all of its blocks in evaluation mode; returns ``self``."""
        return self.train(False)

    def set_requires_grad( self, val ):
        """Set ``requires_grad`` to ``val`` on every parameter of every block."""
        for block in self.blocks:
            block.set_requires_grad(val)

In [27]:
class WRN_extract(nn.Module):
    """Feature extractor of one student: conv0 -> g0 -> g1 -> g2 -> 1x1 conv.

    Args:
        n: number of blocks in each group.
        num_classes, no_avgpool, reduce_ip, reduce_port: accepted but not used
            by this class.
        batch_size: stored as ``self.batch_size``; not used by ``forward``.
        weights: optional dict of weight tensors; loaded with prefix ``'student.'``.
        stats: optional dict of BatchNorm running statistics; loaded with
            prefix ``'student.'``.
        widths: indexable of at least four channel counts: output widths of
            g0, g1, g2 and of the final 1x1 convolution. Required.
        segment: divisor applied to the conv0 output width; also forwarded to g0.
        id: stored as ``self.id`` and forwarded to g0.
        id_stu: index of this student; selects the checkpoint key suffix
            (``'a'``, ``'b'``, ...).
        students: total number of students; must be 2 or 8.

    Raises:
        ValueError: if ``students`` is not 2 or 8, or ``widths`` is missing or
            has fewer than four entries.
        AssertionError: if ``id_stu >= students``.

    NOTE(review): conv0 produces ``16 / segment`` channels while g0 is built
    for 16 input channels, so the shapes only agree when ``segment == 1`` --
    confirm the intent before using another value.
    """

    def __init__(self, n=2, num_classes=10, batch_size=1, no_avgpool=True, weights=None, stats=None, widths=None, segment=1, reduce_ip='127.0.0.1', reduce_port=None, id=0, id_stu=0, students=None):

        super(WRN_extract,self).__init__()
        # Fail early with a clear message instead of an AttributeError/TypeError later on.
        if students not in (2, 8):
            raise ValueError("WRN_extract supports students=2 or students=8, got %r" % (students,))
        if widths is None or len(widths) < 4:
            raise ValueError("WRN_extract needs `widths` with at least 4 entries, got %r" % (widths,))
        assert(id_stu<students)

        self.id = id
        self.batch_size = batch_size
        self.segment = segment
        self.id_stu  = id_stu

        self.conv0 = nn.Conv2d(3,int(16/segment),3, padding=1, bias=False)
        self.groups = []   # kept for compatibility; not used by the visible code
        self.g0 = Group(16, int(widths[0]), n=n, segment=segment, id=id)
        self.g1 = Group( int(widths[0]), no=int(widths[1]), n=n, stride=2 )
        self.g2 = Group( int(widths[1]), no=int(widths[2]), n=n, stride=2 )
        self.conv_g2_dimComm = nn.Conv2d( int(widths[2]), int(widths[3]), 1, bias=False)

        # Checkpoint key for each group. With eight students, every two
        # consecutive students read the same 'group0<letter>' entry; with two
        # students the single 'group0' entry is read by both.
        gindex = chr( ord('a')+id_stu )
        if students == 8:
            g0_key = 'group0' + chr( ord('a')+int(id_stu/2) )
        else:
            g0_key = 'group0'
        self.groups_info = [
            ( g0_key,          self.g0 ),
            ( 'group1'+gindex, self.g1 ),
            ( 'group2'+gindex, self.g2 ),
        ]

        if weights is not None:
            self.load_weight(weights, prefix='student.')

        if stats is not None:
            self.load_stats(stats, prefix='student.')

    def __del__(self):
        # Only acts when an attribute ``c`` has been attached to the instance;
        # nothing in the visible code sets it.
        if hasattr( self, 'c' ):
            print("Close the reduce port")
            self.c.close()

    def load_weight( self, param, prefix="" ):
        """Load conv0, the per-student 1x1 convolution and all group weights.

        Raises:
            AssertionError: if a tensor's shape differs from the layer it replaces.
            KeyError: if a key is missing from ``param``.
        """
        assert( self.conv0.weight.shape == param[prefix+'conv0'].shape)
        self.conv0.weight = nn.Parameter(param[prefix+'conv0'])
        student_letter = chr( ord('a')+self.id_stu )
        dimcomm_key = prefix+'conv_g2'+student_letter+'_dimComm'
        assert( self.conv_g2_dimComm.weight.shape == param[dimcomm_key].shape)
        self.conv_g2_dimComm.weight = nn.Parameter(param[dimcomm_key])
        for sub_prefix, g in self.groups_info:
            g.set_param( param, prefix+sub_prefix )

    def load_stats( self, stats, prefix="" ):
        """Load the BatchNorm running statistics of every group."""
        for sub_prefix, g in self.groups_info:
            g.set_stats( stats, prefix+sub_prefix )

    def forward( self, x ):
        """Return the feature map produced by the 1x1 convolution after g2."""
        x = self.conv0(x)
        x = self.g0(x)
        x = self.g1(x)
        x = self.g2(x)
        x = self.conv_g2_dimComm(x)
        return x

    def set_requires_grad( self, val ):
        """Set ``requires_grad`` to ``val`` on the groups and on every registered parameter."""
        for g in [self.g0, self.g1, self.g2]:
            g.set_requires_grad(val)
        for para in self.parameters():
            para.requires_grad = val
        return

    def eval( self ):
        """Put the extractor in evaluation mode; returns ``self``."""
        # Switch every group explicitly before flipping this module's own flag.
        for g in [self.g0, self.g1, self.g2]:
            g.eval()
        return self.train(False)

In [5]:
class WRN_fc(nn.Module):
    """Classifier head: BatchNorm -> ReLU -> 8x8 pooling -> fully connected layer.

    The input width is fixed by ``students``: 34 + 46 = 80 channels for two
    students, 4 * 34 + 4 * 46 = 320 channels for eight.

    Args:
        n: stored as ``self.n``; not used by ``forward``.
        num_classes: number of outputs of the fully connected layer.
        batch_size: stored as ``self.batch_size``; ``forward`` takes the batch
            dimension from its input.
        no_avgpool: when true, pooling is done by the convolution returned by
            ``AvgPool2d_in_conv``; otherwise by ``nn.AvgPool2d(8, 1, 0)``.
        weights: optional dict of weight tensors; loaded with prefix ``'student.'``.
        stats: optional dict of BatchNorm running statistics; loaded with
            prefix ``'student.'``.
        widths: ignored -- always replaced by the table for ``students``.
        students: number of students feeding this head; must be 2 or 8.

    Raises:
        ValueError: if ``students`` is not 2 or 8.
    """

    def __init__(self, n=6, num_classes=10, batch_size=1, no_avgpool=True, weights=None, stats=None, widths=torch.Tensor([16,32,64]).mul(4), students=None):
        super(WRN_fc,self).__init__()

        if students == 2:
            widths = torch.Tensor([32, 64, 64, 87, 87, 34, 46]).int()
        elif students == 8:
            widths = torch.Tensor([32, 64, 64, 87, 87, 34*4, 46*4]).int()
        else:
            raise ValueError("WRN_fc supports students=2 or students=8, got %r" % (students,))
        self.width    = int(widths[5]+widths[6])
        self.n = n
        self.batch_size = batch_size
        self.fc       = nn.Linear( self.width, num_classes )
        self.bn       = nn.BatchNorm2d(self.width)
        self.relu     = nn.ReLU(inplace=True)

        if no_avgpool:
            self.avg_pool = AvgPool2d_in_conv(self.width,8)
        else:
            self.avg_pool = nn.AvgPool2d(8,1,0)

        if weights is not None:
            self.load_weight(weights, prefix='student.')

        if stats is not None:
            self.load_stats(stats, prefix='student.')

    def load_weight( self, param, prefix ):
        """Load ``fc`` and ``bn`` weights and biases from ``param``.

        Raises:
            AssertionError: if a tensor's shape differs from the layer it replaces.
            KeyError: if a key is missing from ``param``.
        """
        for module, name in ( (self.fc, 'fc'), (self.bn, 'bn') ):
            for attr in ( 'weight', 'bias' ):
                new_value = param[ prefix + name + '.' + attr ]
                assert( getattr(module, attr).shape == new_value.shape )
                setattr( module, attr, nn.Parameter(new_value) )

    def load_stats( self, stats, prefix ):
        """Load the BatchNorm running mean and variance from ``stats``.

        The statistics are stored as plain tensors so they remain registered
        buffers rather than becoming learnable parameters.

        Raises:
            AssertionError: if a tensor's shape differs from the existing buffer.
            KeyError: if a key is missing from ``stats``.
        """
        for stat_name in ( 'running_mean', 'running_var' ):
            new_value = stats[ prefix + 'bn.' + stat_name ]
            assert( getattr(self.bn, stat_name).shape == new_value.shape )
            setattr( self.bn, stat_name, new_value.detach().clone() )
        return

    def forward(self, x):
        """Return class scores of shape ``(batch, num_classes)``."""
        x = self.bn(x)
        x = self.relu(x)
        x = self.avg_pool(x)
        # Flatten everything but the batch dimension, so any batch size works;
        # a pooled map that is not 1x1 makes ``fc`` raise instead of being
        # silently reshaped.
        x = self.fc(torch.flatten(x, 1))
        return x

    def set_requires_grad( self, val ):
        """Set ``requires_grad`` to ``val`` on every parameter of the head."""
        for para in self.parameters():
            para.requires_grad = val

In [30]:
class WRN(nn.Module):
    """Complete model: one ``WRN_extract`` per student, concatenated, then ``WRN_fc``.

    Students alternate between a final feature width of 34 (even ``id_stu``)
    and 46 (odd ``id_stu``); their outputs are concatenated on the channel
    dimension before the classifier head.

    Args:
        n, widths: accepted but not used by this class.
        num_classes, no_avgpool: forwarded to ``WRN_fc``.
        batch_size: forwarded to every sub-module.
        weights: dict of weight tensors forwarded to every sub-module. When
            ``None`` the sub-modules keep their random initialisation.
        stats: dict of BatchNorm running statistics forwarded to every
            sub-module. When ``None`` the default statistics are kept.
        students: number of students; must be 2 or 8.

    Raises:
        ValueError: if ``students`` is not 2 or 8.
    """

    def __init__(self, n=6, num_classes=10, batch_size=1, no_avgpool=True, weights=None, stats=None, widths=torch.Tensor([16,32,64]).mul(4), students=None):
        super(WRN,self).__init__()
        if students not in (2, 8):
            raise ValueError("WRN supports students=2 or students=8, got %r" % (students,))

        # ModuleList registers the extractors, so parameters(), state_dict()
        # and .to(device) see them.
        self.extr = nn.ModuleList()
        for id_stu in range(students):
            last_width = 34 if id_stu % 2 == 0 else 46
            self.extr.append( WRN_extract( batch_size=batch_size, weights=weights, stats=stats, widths=torch.Tensor([32, 64, 87, last_width]), id_stu=id_stu, students=students))
        self.fc    = WRN_fc( num_classes=num_classes, batch_size=batch_size, no_avgpool=no_avgpool, weights=weights, stats=stats, students=students )

    def forward(self, x):
        """Return class scores for the image batch ``x``."""
        features = torch.cat([extractor(x) for extractor in self.extr], dim=1)
        return self.fc(features)

    def eval( self ):
        """Put the head and every extractor in evaluation mode; returns ``self``."""
        # Call each child's own eval() explicitly, then flip this module's flag.
        self.fc.eval()
        for m in self.extr:
            m.eval()
        return self.train(False)

    def set_requires_grad( self, val ):
        """Set ``requires_grad`` to ``val`` on every parameter of the model."""
        self.fc.set_requires_grad(val)
        for m in self.extr:
            m.set_requires_grad(val)
        for para in self.parameters():
            para.requires_grad = val

In [51]:
class Args:
    """Settings shared by the model-generation cells."""
    dataset    = 'CIFAR10'     # dataset name handed to get_dataset
    cifar      = ".."          # root directory handed to the CIFAR10 loader
    config     = 'config.xml'  # not read anywhere in the visible cells
    pt7_path   = None          # checkpoint path; assigned by the cells below
    batch_size = 1             # batch dimension of the dummy export inputs


args = Args()

# Model generation
## 2 students

In [52]:
# Export every student feature extractor and the classifier head of the
# 2-student model to ONNX.
from torch.autograd import Variable
import torch.onnx

verbose = False
args.output = 'wrn_2s_'
args.pt7_path = "/home/chingyi/Downloads/NoNN_fixFLOPS_RPI_models/logs/ST_swrn_v6a_WRN40-4_b=10K_a=0.9_norm_2S_fixFLOPS_run4/model.pt7"

# The checkpoint is read once; it holds the weights under 'params' and the
# BatchNorm running statistics under 'stats'.
param = torch.load(  args.pt7_path, map_location='cpu' )

n_students = 2
extr = []
for id_stu in range(n_students):
    # Even students end in 34 feature channels, odd students in 46.
    last_width = 34 if id_stu % 2 == 0 else 46
    extr.append( WRN_extract( batch_size=args.batch_size, weights=param['params'], stats=param['stats'], widths=torch.Tensor([32, 64, 87, last_width]), id_stu=id_stu, students=n_students))
fc = WRN_fc(                  batch_size=args.batch_size, weights=param['params'], stats=param['stats'], students=n_students )

# Freeze everything and switch to inference mode before tracing.
for m in extr+[fc]:
    m.set_requires_grad(False)
    m.eval()

# NOTE(review): args.output already ends with '_', so the files are named
# 'wrn_2s__g<i>.onnx' / 'wrn_2s__fc.onnx' (double underscore). Left unchanged
# in case downstream tooling relies on these names -- confirm.
for i, m in enumerate(extr):
    dummy_input = Variable(torch.randn(args.batch_size, 3, 32, 32))
    torch.onnx.export(m, dummy_input, args.output+"_g"+str(i)+".onnx", verbose=verbose)

# The head consumes the concatenated student features: fc.width channels, 8x8 map.
dummy_input = Variable(torch.randn(args.batch_size, fc.width, 8, 8))
torch.onnx.export(fc, dummy_input, args.output+"_fc.onnx", verbose=verbose)

## 8 students

In [53]:
# Export every student feature extractor and the classifier head of the
# 8-student model to ONNX.
from torch.autograd import Variable
import torch.onnx

verbose = False
args.output = 'wrn_8s_'
args.pt7_path = "/home/chingyi/Downloads/NoNN_fixFLOPS_RPI_models/logs/ST_swrn_v6n_quadruple6a_WRN40-4_b=10K_a=0.9_norm_8S_run5_fixFLOPS/model.pt7"

# The checkpoint is read once; it holds the weights under 'params' and the
# BatchNorm running statistics under 'stats'.
param = torch.load(  args.pt7_path, map_location='cpu' )

n_students = 8
extr = []
for id_stu in range(n_students):
    # Even students end in 34 feature channels, odd students in 46.
    last_width = 34 if id_stu % 2 == 0 else 46
    extr.append( WRN_extract( batch_size=args.batch_size, weights=param['params'], stats=param['stats'], widths=torch.Tensor([32, 64, 87, last_width]), id_stu=id_stu, students=n_students))
fc = WRN_fc(                  batch_size=args.batch_size, weights=param['params'], stats=param['stats'], students=n_students )

# Freeze everything and switch to inference mode before tracing.
for m in extr+[fc]:
    m.set_requires_grad(False)
    m.eval()

# NOTE(review): the extractor files are named 'wrn_8s__8s_g<i>.onnx' while the
# head is 'wrn_8s__fc.onnx'; the 2-student cell uses '<output>_g<i>.onnx'.
# Left unchanged in case downstream tooling relies on these names -- confirm.
for i, m in enumerate(extr):
    dummy_input = Variable(torch.randn(args.batch_size, 3, 32, 32))
    torch.onnx.export(m, dummy_input, args.output+"_8s_g"+str(i)+".onnx", verbose=verbose)

# The head consumes the concatenated student features: fc.width channels, 8x8 map.
dummy_input = Variable(torch.randn(args.batch_size, fc.width, 8, 8))
torch.onnx.export(fc, dummy_input, args.output+"_fc.onnx", verbose=verbose)

# Inference
## Argument Setting

In [32]:
class Args:
    """Settings for the inference cells (replaces the earlier ``Args``)."""
    # What to run
    inference  = True
    to_onnx    = False
    output     = 'wrn'
    # Data
    dataset    = 'CIFAR10'     # dataset name handed to get_dataset
    cifar      = ".."          # root directory handed to the CIFAR10 loader
    config     = 'config.xml'
    batch_size = 1
    # Default checkpoint; the per-model cells overwrite it
    pt7_path   = "/home/chingyi/Downloads/NoNN_fixFLOPS_RPI_models/logs/ST_swrn_v6a_WRN40-4_b=10K_a=0.9_norm_2S_fixFLOPS_run4/model.pt7"


args = Args()

## Utility Functions

In [33]:
def get_dataset( dataset_name ):
    """Return the test split of the named dataset.

    Only ``"CIFAR10"`` is supported. It is read from (and downloaded to, if
    missing) the directory ``args.cifar``.

    Args:
        dataset_name: name of the dataset to load.

    Returns:
        A ``torchvision.datasets.CIFAR10`` instance built with ``train=False``.

    Raises:
        ValueError: if ``dataset_name`` is not ``"CIFAR10"``.
    """
    if dataset_name != "CIFAR10":
        raise ValueError("Unsupported dataset %r; only 'CIFAR10' is available" % (dataset_name,))
    # Imported lazily: torchvision is only needed by the inference cells.
    import torchvision
    return torchvision.datasets.CIFAR10( args.cifar, train=False, download=True)

def normalize_unsqueeze( img, mean, std ):
    """Convert an image to a normalised tensor with a leading batch dimension.

    The image is converted with ``torchvision.transforms.ToTensor`` and then
    normalised channel-wise as ``(pixel - mean) / std``.

    Args:
        img: image accepted by ``ToTensor``.
        mean: per-channel means, on the 0-255 pixel scale.
        std: per-channel standard deviations, on the 0-255 pixel scale.

    Returns:
        Tensor of shape ``(1, C, H, W)``.
    """
    # Imported lazily: torchvision is only needed by the inference cells.
    import torchvision
    img_tensor = torchvision.transforms.ToTensor()(img)

    # ToTensor divides 8-bit pixel values by 255, so the statistics must be
    # brought to the same scale with 255 (not 256).
    pixel_scale = 255.0
    # Shape (C, 1, 1) broadcasts over the height and width of the CxHxW image.
    mean_t = torch.tensor(mean, dtype=img_tensor.dtype).view(-1, 1, 1) / pixel_scale
    std_t  = torch.tensor(std,  dtype=img_tensor.dtype).view(-1, 1, 1) / pixel_scale

    return ((img_tensor - mean_t) / std_t).unsqueeze(0)

## 2 students

In [56]:
# Evaluate the 2-student model on the CIFAR-10 test split, one image at a
# time, printing the running "correct / total" tally after every image.
CIFAR_MEAN = [125.3, 123.0, 113.9]
CIFAR_STD  = [63, 62.1, 66.7]

args.pt7_path = "/home/chingyi/Downloads/NoNN_fixFLOPS_RPI_models/logs/ST_swrn_v6a_WRN40-4_b=10K_a=0.9_norm_2S_fixFLOPS_run4/model.pt7"
param = torch.load(  args.pt7_path, map_location='cpu' )

# Build the model, freeze it and switch it to inference mode.
model = WRN( weights=param['params'], stats=param['stats'], students=2 )
model.set_requires_grad(False)
model.eval()

correct = 0
total   = 0
cifar = get_dataset( args.dataset )
for batch_idx, (inputs, targets) in enumerate(cifar):
    normed = normalize_unsqueeze(inputs, CIFAR_MEAN, CIFAR_STD)
    scrs = model(normed)
    # The predicted class is the index of the highest score.
    predicted = np.argmax(scrs.numpy())
    total   = batch_idx + 1
    correct = correct + (predicted==targets)
    print(correct,'/',total)

Files already downloaded and verified
1 / 1
2 / 2
3 / 3
4 / 4
5 / 5
6 / 6
7 / 7
8 / 8
9 / 9
10 / 10
11 / 11
12 / 12
13 / 13
14 / 14
15 / 15
15 / 16
16 / 17
17 / 18
18 / 19
19 / 20
20 / 21
21 / 22
22 / 23
23 / 24
24 / 25
25 / 26
26 / 27
27 / 28
28 / 29
29 / 30
30 / 31
31 / 32
32 / 33
32 / 34
33 / 35
34 / 36
35 / 37
35 / 38
36 / 39
37 / 40
38 / 41
39 / 42
40 / 43
41 / 44
42 / 45
43 / 46
44 / 47
45 / 48
46 / 49
47 / 50
48 / 51
49 / 52
49 / 53
50 / 54
51 / 55
52 / 56
53 / 57
54 / 58
54 / 59
54 / 60
55 / 61
55 / 62
56 / 63
57 / 64
58 / 65
59 / 66
60 / 67
61 / 68
62 / 69
63 / 70
64 / 71
65 / 72
66 / 73
67 / 74
68 / 75
69 / 76
70 / 77
71 / 78
72 / 79
73 / 80
74 / 81
75 / 82
76 / 83
77 / 84
78 / 85
79 / 86
80 / 87
81 / 88
82 / 89
83 / 90
84 / 91
85 / 92
86 / 93
87 / 94
88 / 95
89 / 96
90 / 97
91 / 98
92 / 99
93 / 100
94 / 101
95 / 102
96 / 103
97 / 104
98 / 105
99 / 106
100 / 107
101 / 108
102 / 109
103 / 110
104 / 111
105 / 112
106 / 113
107 / 114
108 / 115
109 / 116
110 / 117
111 / 118
111 /

KeyboardInterrupt: 

## 8 students

In [57]:
# Evaluate the 8-student model on the CIFAR-10 test split, one image at a
# time, printing the running "correct / total" tally after every image.
CIFAR_MEAN = [125.3, 123.0, 113.9]
CIFAR_STD  = [63, 62.1, 66.7]

args.pt7_path = "/home/chingyi/Downloads/NoNN_fixFLOPS_RPI_models/logs/ST_swrn_v6n_quadruple6a_WRN40-4_b=10K_a=0.9_norm_8S_run5_fixFLOPS/model.pt7"
param = torch.load(  args.pt7_path, map_location='cpu' )

# Build the model, freeze it and switch it to inference mode.
model = WRN( weights=param['params'], stats=param['stats'], students=8 )
model.set_requires_grad(False)
model.eval()

correct = 0
total   = 0
cifar = get_dataset( args.dataset )
for batch_idx, (inputs, targets) in enumerate(cifar):
    normed = normalize_unsqueeze(inputs, CIFAR_MEAN, CIFAR_STD)
    scrs = model(normed)
    # The predicted class is the index of the highest score.
    predicted = np.argmax(scrs.numpy())
    total   = batch_idx + 1
    correct = correct + (predicted==targets)
    print(correct,'/',total)

Files already downloaded and verified
1 / 1
2 / 2
3 / 3
4 / 4
5 / 5
6 / 6
7 / 7
8 / 8
9 / 9
10 / 10
11 / 11
12 / 12
13 / 13
14 / 14
15 / 15
16 / 16
17 / 17
18 / 18
19 / 19
20 / 20
21 / 21
22 / 22
23 / 23
24 / 24
25 / 25
26 / 26
27 / 27
28 / 28


KeyboardInterrupt: 