In [121]:
import torch
import torch.nn as nn
import torch.optim as optim
import math

In [122]:
class experiment(nn.Module):
    """Small two-layer network used to compare training procedures.

    Layout: ``fc1`` is a ``Linear(3, 4)`` followed by ``ReLU``; ``fc2`` is a
    ``Linear(4, 3)`` that produces the raw (un-normalised) outputs.
    """

    def __init__(self):
        super().__init__()
        # Hidden stage: affine map from 3 inputs to 4 units, then ReLU.
        hidden_linear = nn.Linear(3, 4)
        self.fc1 = nn.Sequential(hidden_linear, nn.ReLU())
        # Output stage: affine map from 4 hidden units to 3 outputs.
        self.fc2 = nn.Linear(4, 3)

    def forward(self, x):
        """Apply ``fc1`` and then ``fc2`` to ``x`` and return the result."""
        return self.fc2(self.fc1(x))

# Build three independent instances of the same architecture.
net = experiment()
net2 = experiment()
net3 = experiment()

# Re-draw every parameter of `net` (weights and biases alike) from N(0, 0.01^2).
for params in net.parameters():
    nn.init.normal_(params, mean=0, std=0.01)

# Copy net's parameters into the other two models so that all three start
# from identical values.
for replica in (net2, net3):
    replica.load_state_dict(net.state_dict())

print(net)

experiment(
  (fc1): Sequential(
    (0): Linear(in_features=3, out_features=4, bias=True)
    (1): ReLU()
  )
  (fc2): Linear(in_features=4, out_features=3, bias=True)
)


In [123]:
# Show the output-layer weights of each model, one after another, so they can
# be compared by eye.
for heading, model in (
    ('net1 fc2.weight', net),
    ('net2 fc2.weight', net2),
    ('net3 fc2.weight', net3),
):
    print(heading)
    print(model.state_dict()['fc2.weight'])

net1 fc2.weight
tensor([[ 0.0118, -0.0146, -0.0047,  0.0008],
        [-0.0121, -0.0112,  0.0092, -0.0127],
        [-0.0025,  0.0031, -0.0149, -0.0013]])
net2 fc2.weight
tensor([[ 0.0118, -0.0146, -0.0047,  0.0008],
        [-0.0121, -0.0112,  0.0092, -0.0127],
        [-0.0025,  0.0031, -0.0149, -0.0013]])
net3 fc2.weight
tensor([[ 0.0118, -0.0146, -0.0047,  0.0008],
        [-0.0121, -0.0112,  0.0092, -0.0127],
        [-0.0025,  0.0031, -0.0149, -0.0013]])


In [124]:
# Criterion shared by every training call in this notebook.
# (An nn.MSELoss() alternative, `loss_func2`, is currently disabled.)
loss_func = nn.CrossEntropyLoss()

# One plain SGD optimizer (lr=0.1) per model, created in the order net, net2, net3.
optimizer, optimizer2, optimizer3 = (
    optim.SGD(model.parameters(), lr=0.1) for model in (net, net2, net3)
)
def my_cross_entropy(y, t):
    """Cross entropy between target weights ``t`` and values ``y``.

    Computes ``-sum(t * log(y + 1e-7)) / y.size(0)``: the element-wise
    products are summed over every element and the total is divided by the
    size of the first dimension of ``y``.

    Args:
        y: tensor whose (shifted) logarithm is taken.
        t: tensor of weights multiplied element-wise with ``log(y + 1e-7)``.

    Returns:
        A zero-dimensional tensor holding the averaged loss.
    """
    eps = 1e-7  # keeps log() finite when an entry of y is exactly zero
    first_dim = y.size(0)
    log_values = torch.log(y + eps)
    weighted_total = torch.sum(t * log_values)
    return -weighted_total / first_dim

In [125]:
def train_with_rule(net, loss_func, optimizer, data, labels, PI):
    """Run one optimisation step on a blend of a student and a teacher loss.

    The student loss is ``loss_func(net(data), labels)``. The teacher loss is
    ``my_cross_entropy(p, q)`` where ``p`` is the softmax (over dim 1) of the
    network output and ``q = p * exp(0.5)``. The step minimises
    ``(1 - PI) * student_loss + PI * teacher_loss``.

    The three losses and the updated ``fc2.weight`` are printed.

    NOTE(review): ``q`` is built from the student's own output, is not
    re-normalised, and is not detached, so gradients flow through both
    arguments of the teacher loss. Confirm that this is intended.

    Args:
        net: module called on ``data``; its state dict must contain
            ``'fc2.weight'``.
        loss_func: callable taking ``(outputs, labels)`` and returning a
            scalar tensor.
        optimizer: optimizer whose ``zero_grad``/``step`` are invoked.
        data: input batch passed to ``net``.
        labels: targets passed to ``loss_func``.
        PI: weight of the teacher loss in the blend (presumably within
            [0, 1]; the value is not validated).

    Returns:
        None.
    """
    logits = net(data)

    student_loss = loss_func(logits, labels)
    print("这是student loss")
    print(student_loss)

    # Teacher target: the student's softmax output scaled by a constant factor.
    probabilities = torch.nn.Softmax(dim=1)(logits)
    teacher_q = probabilities * math.exp(0.5)
    teacher_loss = my_cross_entropy(probabilities, teacher_q)
    print("这是teacher loss")
    print(teacher_loss)

    loss = (1 - PI) * student_loss + PI * teacher_loss
    print("这是联合 loss")
    print(loss)

    # Standard update: clear stale gradients, back-propagate, apply the step.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(net.state_dict()['fc2.weight'])

In [126]:
def train_with_norule(net, loss_func, optimizer, data, labels):
    """Run one plain optimisation step and print the loss and ``fc2.weight``.

    Args:
        net: module called on ``data``; its state dict must contain
            ``'fc2.weight'``.
        loss_func: callable taking ``(outputs, labels)`` and returning a
            scalar tensor.
        optimizer: optimizer whose ``zero_grad``/``step`` are invoked.
        data: input batch passed to ``net``.
        labels: targets passed to ``loss_func``.

    Returns:
        None.
    """
    predictions = net(data)
    loss = loss_func(predictions, labels)
    print("这是 loss")
    print(loss)

    # Clear stale gradients, back-propagate, then apply the update.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(net.state_dict()['fc2.weight'])

In [127]:
# Random batch: 5 samples with 3 features each, and 5 integer labels drawn
# from {0, 1, 2}. The inputs are drawn before the labels.
x_input = torch.randn(5, 3)
y_target = torch.empty(5, dtype=torch.long).random_(3)

# One rule-augmented step on `net`, with weight 0.8 on the teacher loss.
train_with_rule(
    net,
    loss_func,
    optimizer,
    data=x_input,
    labels=y_target,
    PI=0.8,
)

这是student loss
tensor(1.0991, grad_fn=<NllLossBackward>)
这是teacher loss
tensor(1.8113, grad_fn=<DivBackward0>)
这是联合 loss
tensor(1.6689, grad_fn=<AddBackward0>)
tensor([[ 0.0118, -0.0144, -0.0047,  0.0009],
        [-0.0122, -0.0112,  0.0092, -0.0127],
        [-0.0024,  0.0030, -0.0149, -0.0013]])


In [128]:
# One plain (no-rule) step on net2, using the same batch.
print('net2 fc2.weight')
train_with_norule(
    net2,
    loss_func,
    optimizer2,
    data=x_input,
    labels=y_target,
)

net2 fc2.weight
这是 loss
tensor(1.0991, grad_fn=<NllLossBackward>)
tensor([[ 0.0119, -0.0140, -0.0044,  0.0010],
        [-0.0126, -0.0115,  0.0091, -0.0128],
        [-0.0021,  0.0028, -0.0150, -0.0014]])


In [129]:
# One plain (no-rule) step on net3, using the same batch.
print('net3 fc2.weight')
train_with_norule(
    net3,
    loss_func,
    optimizer3,
    data=x_input,
    labels=y_target,
)

net3 fc2.weight
这是 loss
tensor(1.0991, grad_fn=<NllLossBackward>)
tensor([[ 0.0119, -0.0140, -0.0044,  0.0010],
        [-0.0126, -0.0115,  0.0091, -0.0128],
        [-0.0021,  0.0028, -0.0150, -0.0014]])
