# Train Teacher Models

In [167]:
# import libraries
!pip install pthflops
import torch
import torch.nn as nn
#from fvcore.nn import FlopCountAnalysis
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import time
import copy
from torchvision.utils import save_image
import os
from pthflops import count_ops
# import utility classes
import networks
import utils

# initialize args
args = type('', (), {})()
args.batch_train = 256
args.batch_real = 256
args.epoch_eval_train = 10  # the number of epoch to train the student model with synthetic dataset 
args.epoch_train = 30 # the number of epoch to train the teacher model
args.lr_net = 0.01 # learning rate of the network
args.lr_img = 0.1 # learning rate of synthetic dataset
args.dsa_strategy = None
args.num_eval = 1
args.num_exp = 1 # number of experiments
args.ipc = 10 # image per class
args.outer_loop, args.inner_loop = utils.get_loops(args.ipc)
args.batch_real = 256 # batch size for original dataset
args.dsa = None
args.method = 'DC'
args.Iteration = 1000 # training iterations
args.model = 'MLP'
args.dataset = 'MNIST'
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
args.dsa_param = utils.ParamDiffAug()
args.dsa = True if args.method == 'DSA' else False
args.dc_aug_param = None
args.dis_metric = 'ours' # distance metric
args.data_path = 'data'
args.save_path = 'result'

# Load dataset
# no augmentation
channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = utils.get_dataset(args.dataset, './data')

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [168]:
#function printing expert trajectory
def trajectory_dataset(it_eval, net, dst_train, testloader, args):
    net = net.to(args.device)
    trajectory = []
    # images_train = images_train.to(args.device)
    # labels_train = labels_train.to(args.device)
    lr = float(args.lr_net)
    Epoch = int(args.epoch_train)
    lr_schedule = [Epoch//2+1]
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
    criterion = nn.CrossEntropyLoss().to(args.device)

    # dst_train = TensorDataset(images_train, labels_train)
    trainloader = torch.utils.data.DataLoader(dst_train, batch_size=args.batch_train, shuffle=True, num_workers=0)

    start = time.time()
    for ep in range(Epoch+1):
        loss_train, acc_train = utils.epoch('train', trainloader, net, optimizer, criterion, args, aug = True)
        if ep in lr_schedule:
            lr *= 0.1
            optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)

        #record training trajectory
        trajectory_layers = []
        trajectory_layers.append(net.fc_1.weight)
        trajectory_layers.append(net.fc_2.weight)
        trajectory_layers.append(net.fc_3.weight)
        trajectory.append(copy.deepcopy(trajectory_layers))
        #print(sum(sum(trajectory_layers[2])))


    #time_train = time.time() - start
    loss_test, acc_test = utils.epoch('test', testloader, net, optimizer, criterion, args, aug = False)
    print('Evaluate_%02d: epoch = %04d train loss = %.6f train acc = %.4f, test acc = %.4f' % (it_eval, Epoch, loss_train, acc_train, acc_test))

    return net, acc_train, acc_test, trajectory

In [169]:
# load and train 5 teacher netwrok (30epochs
it_eval = 1
t_net=[]
expert_trajectory=[]
for i in range(5):
  t_net.append(utils.get_network(args.model,channel=channel,num_classes=num_classes,im_size=im_size).to(args.device))
  _, train_acc, test_acc, expert_trajectory_single = trajectory_dataset(it_eval, t_net[i], dst_train, testloader, args)
  expert_trajectory.append(expert_trajectory_single)

#accurracy = []


#accurracy.append(test_acc)

Evaluate_01: epoch = 0030 train loss = 0.019926 train acc = 0.9967, test acc = 0.9793
Evaluate_01: epoch = 0030 train loss = 0.019178 train acc = 0.9969, test acc = 0.9796
Evaluate_01: epoch = 0030 train loss = 0.019325 train acc = 0.9969, test acc = 0.9795
Evaluate_01: epoch = 0030 train loss = 0.019135 train acc = 0.9972, test acc = 0.9798
Evaluate_01: epoch = 0030 train loss = 0.018763 train acc = 0.9971, test acc = 0.9801


# Train Student

In [172]:
# initialize synthetic dataset: randomly or gaussian noise
images_syn = torch.randn(
    size = (10 * num_classes, channel, im_size[0], im_size[1]),
    dtype = torch.float,
    requires_grad=True, 
    device=args.device
)
labels_syn = torch.tensor(
    [np.ones(10) * n for n in range(num_classes)], 
    dtype=torch.long, 
    requires_grad=False, 
    device=args.device
).view(-1)
dst_train = utils.TensorDataset(images_syn, labels_syn)

"""


#network optimization function
def network_optimization(it_eval, net, images_train, labels_train, testloader, args):
    net = net.to(args.device)
    images_train = images_train.to(args.device)
    labels_train = labels_train.to(args.device)
    lr = float(args.lr_net)
    Epoch = int(args.epoch_eval_train)
    lr_schedule = [Epoch//2+1]
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
    criterion = nn.CrossEntropyLoss().to(args.device)

    dst_train = TensorDataset(images_train, labels_train)
    trainloader = torch.utils.data.DataLoader(dst_train, batch_size=args.batch_train, shuffle=True, num_workers=0)

    start = time.time()
    for ep in range(Epoch+1):
        loss_train, acc_train = epoch('train', trainloader, net, optimizer, criterion, args, aug = True)
        if ep in lr_schedule:
            lr *= 0.1
            optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)

    time_train = time.time() - start
    loss_test, acc_test = epoch('test', testloader, net, optimizer, criterion, args, aug = False)
    print('%s Evaluate_%02d: epoch = %04d train time = %d s train loss = %.6f train acc = %.4f, test acc = %.4f' % (get_time(), it_eval, Epoch, int(time_train), loss_train, acc_train, acc_test))

    return net, acc_train, acc_test"""

  labels_syn = torch.tensor(


"\n\n\n#network optimization function\ndef network_optimization(it_eval, net, images_train, labels_train, testloader, args):\n    net = net.to(args.device)\n    images_train = images_train.to(args.device)\n    labels_train = labels_train.to(args.device)\n    lr = float(args.lr_net)\n    Epoch = int(args.epoch_eval_train)\n    lr_schedule = [Epoch//2+1]\n    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)\n    criterion = nn.CrossEntropyLoss().to(args.device)\n\n    dst_train = TensorDataset(images_train, labels_train)\n    trainloader = torch.utils.data.DataLoader(dst_train, batch_size=args.batch_train, shuffle=True, num_workers=0)\n\n    start = time.time()\n    for ep in range(Epoch+1):\n        loss_train, acc_train = epoch('train', trainloader, net, optimizer, criterion, args, aug = True)\n        if ep in lr_schedule:\n            lr *= 0.1\n            optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)\n

In [173]:
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
from torchvision import models
from torchsummary import summary

import random

for u_e in range(3): #syn_set update_epoch

    # prints a random start epoch
    t = random.choice([3,4,5,6,7,8,9,10,11,12]) #using name t in accordance with the paper
    teacher = random.choice([0,1,2,3,4])

    # load and train netwrok on synthetic dataset
    s_net = utils.get_network(args.model,channel=channel,num_classes=num_classes,im_size=im_size).to(args.device)
    s_net.fc_1.weight=copy.copy(expert_trajectory[teacher][t][0])
    s_net.fc_2.weight=copy.copy(expert_trajectory[teacher][t][1])
    s_net.fc_3.weight=copy.copy(expert_trajectory[teacher][t][2])


    #accurracy = []
    #it_eval = 1
    #_, train_acc, test_acc = utils.evaluate_synset(it_eval, s_net, images_syn, labels_syn, testloader, args)
    #accurracy.append(test_acc)
    #print(labels_syn)

    #Loss function parameters
    #theta t+N
    student_vector=torch.cat((torch.reshape(s_net.fc_1.weight, (-1,)),torch.reshape(s_net.fc_2.weight, (-1,)),torch.reshape(s_net.fc_3.weight, (-1,))),0)

    #theta t
    print("starting epoch=",t)
    print("teacher model=",teacher)
    start=expert_trajectory[teacher][t]
    start[0]=torch.reshape(start[0], (-1,))
    start[1]=torch.reshape(start[1], (-1,))
    start[2]=torch.reshape(start[2], (-1,))
    start_vector=torch.cat((start[0],start[1],start[2]),0)

    #theta t+M
    stop=expert_trajectory[teacher][t+15]
    stop[0]=torch.reshape(stop[0], (-1,))
    stop[1]=torch.reshape(stop[1], (-1,))
    stop[2]=torch.reshape(stop[2], (-1,))
    stop_vector=torch.cat((stop[0],stop[1],stop[2]),0)

    #change of trajectory in teacher model
    change=start_vector-stop_vector 
    #difference with target trajectory
    diff=student_vector-stop_vector

    #numerator=sum(diff*diff)
    #denominator=sum(change*change)
    #Loss = numerator/denominator
    #print("num=",numerator, "\nden=",denominator,"\nloss=", Loss)

    #s_net.requires_grad=False
    optimizer = torch.optim.SGD(      [images_syn, ],       lr=args.lr_img,       momentum=0.5) 

    #print(sum(diff*diff))
    #print(sum(change*change))
    loss = sum(torch.square(diff))/sum(torch.square(change))
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

starting epoch= 4
teacher model= 2
starting epoch= 3
teacher model= 4
starting epoch= 11
teacher model= 4


In [177]:
#test accuracy on real data
args.dataset = 'MNIST'
channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = utils.get_dataset(args.dataset, args.data_path)
_, acc_train, acc_test = utils.evaluate_dataset(it_eval,s_net,dst_train,testloader,args)
print("test with real dataset, accuracy = %.4f"%(acc_test))

[2022-12-13 06:10:02] Evaluate_01: epoch = 0030 train time = 395 s train loss = 0.008933 train acc = 0.9994, test acc = 0.9808
test with real dataset, accuracy = 0.9808


In [None]:
"""#f = torch.nn.Linear((28*28*1), 128)
f1=s_net.fc_1
f1.requires_grad_(True)
f2=s_net.fc_2
f2.requires_grad_(True)
f3=s_net.fc_3
f3.requires_grad_(True)


x = torch.reshape(images_syn[1][0].cpu().detach(),(-1,))
x.requires_grad_(True)


optim = torch.optim.SGD([x], lr=1e-1)
mse = torch.nn.MSELoss()
y = torch.ones(10)  # the desired network response

num_steps = 10  # how many optim steps to take
for _ in range(num_steps):
   print(x)
   #print("\nf(x)=f",f(x))
   student_vector=torch.cat((torch.reshape(f1.weight, (-1,)),torch.reshape(f2.weight, (-1,)),torch.reshape(f3.weight, (-1,))),0)
   #print(x.grad)


   #loss = sum(student_vector-stop_vector)
   loss = sum(f3(f2(f1(x)))-y)
   print("l1",loss)
  #print("      l2      ",sum(student_vector-stop_vector))
   loss.backward()
   optim.step()
   optim.zero_grad()"""

tensor([-0.9931, -0.5961, -0.6495,  0.0716,  1.7227,  1.5579, -1.5856,  2.6244,
         1.9461,  2.0679,  0.8432, -1.6464, -1.1372, -0.2274, -0.1564,  1.2998,
         0.1342,  1.3755, -1.4324, -0.4696, -0.6495,  0.1275,  0.5144,  0.3751,
        -0.2503,  0.4796,  1.0519, -1.3087,  1.9546,  0.5875, -0.7739,  1.1205,
         0.5599,  1.0647, -0.0527,  0.9240,  0.3069,  2.5498, -0.8826,  0.3722,
        -0.0129,  0.5585,  2.5793,  0.2311,  2.3806,  1.1239,  0.6050, -1.7025,
        -2.0710,  1.1498, -0.4500, -0.0157, -1.0158,  0.3019,  0.4190, -0.8803,
        -1.6326,  0.9991,  0.6456,  0.8690, -1.3480,  1.6841,  0.1612, -0.1131,
         1.0756,  0.1031, -0.7188,  0.5213,  0.8304,  0.4944,  0.8884,  0.2142,
        -0.8121, -1.1202, -0.8186,  0.2136,  0.7296, -0.0669,  2.1977,  0.8094,
        -0.0854,  2.0282,  0.3520,  0.9850,  0.7189,  0.0116,  0.3542,  0.4822,
         1.1538,  2.2306, -0.2048, -0.0144,  1.0107,  1.4280,  0.1073,  0.9973,
        -1.0682,  0.5159,  0.9627,  0.09