In [None]:
# pip install torch===1.4.0 torchvision===0.5.0 -f https://download.pytorch.org/whl/torch_stable.html

In [None]:
from torchvision import datasets, transforms, models
from tqdm import tqdm

def dataload(key, bs):
    '''Build shuffled train/test DataLoaders (with training-time augmentation) for a named dataset.

    Args:
        key: 'HIGH10' -> ImageFolder trees under ../data/train/ and ../data/test/ with ImageNet
            mean/std normalization; training images are resized to 256x256, then
            RandomResizedCrop(224) + RandomHorizontalFlip; test images are resized to 224x224.
            'MNIST' -> torchvision MNIST stored under ./MNIST (download=True for the training split).
        bs: batch size used by both loaders.

    Returns:
        (trainloader, testloader); both are created with shuffle=True.

    Raises:
        ValueError: if ``key`` is neither 'HIGH10' nor 'MNIST'.
    '''
    # torch is first imported in a later notebook cell; import it here so this cell stands alone.
    import torch.utils.data

    # Validate up front: an unknown key used to fall through and fail later with an
    # UnboundLocalError on `train_dataset`.
    if key not in ('HIGH10', 'MNIST'):
        raise ValueError("unknown dataset key {!r}; expected 'HIGH10' or 'MNIST'".format(key))

    if key == 'HIGH10':
        traindir = '../data/train/'
        testdir = '../data/test/'
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        test_dataset = datasets.ImageFolder(
            testdir,
            transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                normalize,
            ]))
    else:  # 'MNIST'
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        train_dataset = datasets.MNIST(root='MNIST', train=True, download=True,
                                       transform=transform)
        test_dataset = datasets.MNIST(root='MNIST', train=False,
                                      transform=transform)

    trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=bs, shuffle=True)
    testloader = torch.utils.data.DataLoader(test_dataset, batch_size=bs, shuffle=True)

    return trainloader, testloader


In [None]:
import torch
import numpy as np
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from scipy import signal
import scipy
from torch import nn
import time


def dft_conv(imgR, imgIm, kernelR, kernelIm):
    """Element-wise complex product of two spectra, summed over dim 1.

    Each complex operand is given as separate real / imaginary tensors; all four tensors must
    broadcast together. The product uses the three-multiplication (Gauss) form
        (a + ib)(c + id) = (ac - bd) + i((a + b)(c + d) - ac - bd)
    with (a, b) = kernel and (c, d) = image.

    Returns:
        (real, imag) of the product, each reduced with ``.sum(1)`` (the in-channel axis in the
        callers in this notebook).
    """
    re_re = kernelR * imgR
    im_im = kernelIm * imgIm
    cross = (kernelR + kernelIm) * (imgR + imgIm)

    out_real = (re_re - im_im).sum(1)
    out_imag = (cross - re_re - im_im).sum(1)
    return out_real, out_imag


class FFT_Conv_Layer(nn.Module):
    """Three chained convolution banks evaluated as products in the 2-D Fourier domain.

    Bank k is a Parameter of shape ``(1, in_k, out_k, filtSize, filtSize, imagDim)``; its last
    axis is fed to the complex FFT as (real, imag), so ``imagDim`` is expected to be 2. There is
    no nonlinearity between the banks, so their spectra are multiplied together into a single
    composed spectrum, which is then applied to the input spectrum in one step.

    Args:
        imgSize: spatial size (H == W) of the inputs; filters are zero-padded to
            ``imgSize + filtSize - 1`` to match the padded image.
        inCs, outCs1, outCs2, outCs3: channel counts of the input and of each bank's output.
        imagDim: size of the trailing (real, imag) axis of the filters.
        filtSize: spatial kernel size.
        cuda: place the filter parameters on the GPU.
        save_composed: if True (default), every forward pass writes the composed spectrum to
            ``f123r.npy`` / ``f123i.npy`` in the working directory.
    """

    def __init__(self, imgSize, inCs, outCs1, outCs2, outCs3, imagDim, filtSize, cuda=False,
                 save_composed=True):
        super(FFT_Conv_Layer, self).__init__()
        # Drawn in this order (bank 1, 2, 3) so a seeded run reproduces the same initialization.
        banks = (
            np.random.normal(0, 0.01, (1, inCs, outCs1, filtSize, filtSize, imagDim)),
            np.random.normal(0, 0.01, (1, outCs1, outCs2, filtSize, filtSize, imagDim)),
            np.random.normal(0, 0.01, (1, outCs2, outCs3, filtSize, filtSize, imagDim)),
        )
        self.imgSize = imgSize
        self.filtSize = filtSize
        self.save_composed = save_composed

        # Register the banks as Parameters on every device choice; with cuda=False they used to
        # stay raw numpy arrays (untrainable, missing from state_dict, rejected by F.pad).
        for name, bank in zip(('filts1', 'filts2', 'filts3'), banks):
            weight = torch.from_numpy(bank).type(torch.float32)
            if cuda:
                weight = weight.cuda()
            setattr(self, name, Parameter(weight))

    @staticmethod
    def _legacy_fft():
        """True when torch exposes the pre-1.8 function API (torch.fft / torch.ifft / torch.rfft)."""
        # From torch 1.8 `torch.fft` is a module (not callable) and the old functions are gone.
        return callable(torch.fft)

    @classmethod
    def _fft2_of_real(cls, x):
        """Full (two-sided) 2-D DFT over the last two dims of real ``x``; returns (..., 2) re/im."""
        if cls._legacy_fft():
            return torch.rfft(x, 2, onesided=False)
        return torch.view_as_real(torch.fft.fft2(x))

    @classmethod
    def _fft2(cls, x):
        """2-D DFT over dims (-3, -2) of ``x`` whose last axis holds (real, imag)."""
        if cls._legacy_fft():
            return torch.fft(x, 2)
        return torch.view_as_real(torch.fft.fft2(torch.view_as_complex(x.contiguous())))

    @classmethod
    def _ifft2(cls, x):
        """Inverse 2-D DFT (1/N scaled) over dims (-3, -2) of ``x`` whose last axis is (real, imag)."""
        if cls._legacy_fft():
            return torch.ifft(x, 2)
        return torch.view_as_real(torch.fft.ifft2(torch.view_as_complex(x.contiguous())))

    def forward(self, imgs):
        """Apply the composed three-bank filter to ``imgs``.

        ``imgs`` is presumably real (B, inCs, imgSize, imgSize) — TODO confirm with callers.
        Returns the real part of the inverse transform with one pixel cropped from every side.
        """
        pad = self.filtSize - 1
        # Zero-pad bottom/right so the image spectrum matches the padded filter size P, then add
        # a singleton output-channel axis: (B, I, 1, P, P).
        imgs = F.pad(imgs, (0, pad, 0, pad)).unsqueeze(2)
        img_spec = self._fft2_of_real(imgs)
        imgsR, imgsIm = img_spec[..., 0], img_spec[..., 1]

        # Zero-pad each bank spatially to imgSize + filtSize - 1 and transform it.
        fpad = self.imgSize - 1
        spectra = []
        for bank in (self.filts1, self.filts2, self.filts3):
            spec = self._fft2(F.pad(bank, (0, 0, 0, fpad, 0, fpad)))
            spectra.append((spec[..., 0], spec[..., 1]))
        (f1R, f1Im), (f2R, f2Im), (f3R, f3Im) = spectra

        # Collapse the banks: bank 1's input-channel axis becomes a batch axis,
        # (1, I, C1, P, P) -> (I, C1, 1, P, P), so dft_conv sums over C1 and then over C2.
        f12R, f12Im = dft_conv(f1R[0].unsqueeze(2), f1Im[0].unsqueeze(2), f2R, f2Im)    # (I, C2, P, P)
        f123R, f123Im = dft_conv(f12R.unsqueeze(2), f12Im.unsqueeze(2), f3R, f3Im)      # (I, C3, P, P)

        # Apply the composed spectrum; dft_conv sums over the I input channels -> (B, C3, P, P).
        # (The per-batch sequential three-bank product is no longer computed: its result was
        # always discarded in favor of this one.)
        outR, outIm = dft_conv(imgsR, imgsIm, f123R.unsqueeze(0), f123Im.unsqueeze(0))

        if self.save_composed:
            # Shape (I, C3, 1, P, P). NOTE(review): torch.save pickles the arrays, so despite the
            # .npy suffix these are not NumPy-format files; read them back with torch.load.
            torch.save(f123R.unsqueeze(2).cpu().detach().numpy(), 'f123r.npy')
            torch.save(f123Im.unsqueeze(2).cpu().detach().numpy(), 'f123i.npy')

        # Re-pack (real, imag) on a trailing axis and invert.
        imgs = self._ifft2(torch.stack((outR, outIm), dim=-1))

        # NOTE(review): the fixed 1-pixel crop gives an imgSize-sized output only for filtSize == 3.
        # The spatial filters are initialized with a random imaginary part, so the discarded
        # imaginary output is not necessarily ~0.
        return imgs[:, :, 1:-1, 1:-1, 0]


class StudentNetwork_noRelu(nn.Module):
    """Student: one FFT_Conv_Layer (three conv banks, no ReLU between them) plus an MLP head.

    The conv layer is built for 28x28 inputs with its filters on the GPU (``cuda=True``).
    The output is log-probabilities over 10 classes (LogSoftmax).
    Dropout is active only while BOTH the legacy ``is_training`` flag is True and the module is
    in ``train()`` mode, so ``model.eval()`` disables it as well.
    """

    def __init__(self, in_channels):
        super(StudentNetwork_noRelu, self).__init__()
        self.conv1 = FFT_Conv_Layer(imgSize=28, inCs=in_channels, outCs1=32, outCs2=128, outCs3=256, imagDim=2, filtSize=3, cuda=True)

        self.conv2_bn = nn.BatchNorm2d(256)
        self.fc1 = nn.Linear(9216, 512)  # 256 channels * 6 * 6 after adaptive average pooling
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)
        # Not used by forward; kept so existing references to the attribute keep working.
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
        self.dropout_input = 0.5
        self.dropout_hidden = 0.5
        self.is_training = True
        self.avepool = nn.AdaptiveAvgPool2d((6, 6))
        self.m = nn.LogSoftmax(dim=1)

    def forward(self, x):
        # Respect both the custom flag (toggled by student_solver) and nn.Module.train()/eval().
        drop = self.is_training and self.training

        forw = self.conv1(x)
        forw = self.conv2_bn(forw)
        forw = self.avepool(forw)
        # (B, 256, 6, 6) -> (B, 9216); unlike view(-1, 9216) this cannot silently re-batch.
        forw = torch.flatten(forw, 1)
        forw = F.dropout(forw, p=self.dropout_input, training=drop)
        forw = F.dropout(self.fc1(forw), p=self.dropout_hidden, training=drop)
        forw = F.relu(forw)
        forw = self.fc2(forw)
        forw = F.relu(forw)
        forw = self.fc3(forw)
        return self.m(forw)


class Teacher_Network(nn.Module):
    """Teacher: three ReLU 3x3 convolutions (32 -> 128 -> 256 channels) plus an MLP head.

    The output is log-probabilities over 10 classes (LogSoftmax).
    Dropout is active only while BOTH the legacy ``is_training`` flag is True and the module is
    in ``train()`` mode, so ``model.eval()`` disables it as well.
    """

    def __init__(self, in_channels):
        super(Teacher_Network, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=128, kernel_size=3)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3)

        self.conv2_bn = nn.BatchNorm2d(256)
        self.fc1 = nn.Linear(9216, 512)  # 256 channels * 6 * 6 after adaptive average pooling
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)

        self.dropout_input = 0.5
        self.dropout_hidden = 0.5
        self.is_training = True
        self.avepool = nn.AdaptiveAvgPool2d((6, 6))
        self.m = nn.LogSoftmax(dim=1)

    def forward(self, x):
        # Respect both the custom flag and nn.Module.train()/eval().
        drop = self.is_training and self.training

        forw = nn.functional.relu(self.conv1(x))
        forw = nn.functional.relu(self.conv2(forw))
        forw = nn.functional.relu(self.conv3(forw))

        forw = self.conv2_bn(forw)
        forw = self.avepool(forw)
        # (B, 256, 6, 6) -> (B, 9216); unlike view(-1, 9216) this cannot silently re-batch.
        forw = torch.flatten(forw, 1)
        forw = F.dropout(forw, p=self.dropout_input, training=drop)
        forw = F.dropout(self.fc1(forw), p=self.dropout_hidden, training=drop)
        forw = F.relu(forw)
        forw = self.fc2(forw)
        forw = F.relu(forw)
        forw = self.fc3(forw)
        return self.m(forw)

In [None]:
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models



class teacher_solver():
    """Supervised trainer that checkpoints the model with the lowest validation loss.

    ``criterion`` is applied to the model outputs and integer labels (e.g. ``nn.NLLLoss`` for a
    LogSoftmax head). ``student_optimizer`` updates ``model``'s parameters (the ``student_``
    prefix is historical). ``student_lr_scheduler`` is stored but not stepped by this class.
    Batches are moved to the device that holds the model's parameters.
    """

    def __init__(self, train_loader, test_loader, model, criterion, student_optimizer,
                 student_lr_scheduler,
                 epochs, model_path, model_name):

        self.model_path = model_path
        self.model_name = model_name
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.model = model
        self.student_optimizer = student_optimizer
        self.student_lr_scheduler = student_lr_scheduler
        self.epochs = epochs
        self.criterion = criterion
        self.step = 0

    def _device(self):
        """Device of the model's parameters ('cuda' if the model has none, the old hard-coded choice)."""
        for param in self.model.parameters():
            return param.device
        return torch.device('cuda')

    def train(self):
        """Train for ``self.epochs`` epochs, validating after each one.

        Prints a per-epoch summary and saves ``model.state_dict()`` to ``model_path + model_name``
        whenever the validation loss improves. Returns None.
        """
        from datetime import datetime  # the notebook imports `time`, not `datetime`

        val_loss = None
        best_model = None
        print('epochs', self.epochs)
        for epoch in range(self.epochs):
            print("Start Training...")
            self.val_predictions = []
            self.val_gts = []
            start = datetime.now()
            tr_stu_avg_loss = self.train_loop()
            val_stu_avg_loss, testaccuracy = self.validate()
            print('-' * 50)
            print('Summary: Epoch {0} | Time {1}s'.format(epoch, datetime.now() - start))
            print('Train | Loss {0:.4f}'.format(tr_stu_avg_loss))
            print('Validate | Loss {0:.4f}'.format(val_stu_avg_loss))
            print('Validate | Accuracy {0:.4f}'.format(testaccuracy))
            # Keep the checkpoint with the lowest validation loss seen so far.
            if val_loss is None or val_stu_avg_loss < val_loss:
                val_loss = val_stu_avg_loss
                torch.save(self.model.state_dict(), self.model_path + self.model_name)
                best_model = epoch
            print('best_model is on epoch:', best_model)

    def train_loop(self):
        """Run one epoch over ``self.train_loader``; return the mean per-batch training loss."""
        device = self._device()
        self.model.train()
        running_loss = 0
        for inputs, labels in self.train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            self.student_optimizer.zero_grad()
            logps = self.model(inputs)
            loss = self.criterion(logps, labels)
            loss.backward()
            self.student_optimizer.step()
            running_loss += loss.item()
        return running_loss / len(self.train_loader)

    def validate(self):
        """Return (mean per-batch loss, mean per-batch top-1 accuracy) on ``self.test_loader``."""
        device = self._device()
        self.model.eval()
        # The networks in this notebook gate dropout on a custom `is_training` flag as well as on
        # eval(); switch it off for validation and restore it afterwards.
        flag = getattr(self.model, 'is_training', None)
        if flag is not None:
            self.model.is_training = False
        test_loss = 0
        accuracy = 0
        try:
            with torch.no_grad():
                for inputs, labels in self.test_loader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    logps = self.model(inputs)
                    batch_loss = self.criterion(logps, labels)
                    test_loss += batch_loss.item()
                    # Calculate accuracy
                    ps = torch.exp(logps)
                    _, top_class = ps.topk(1, dim=1)
                    equals = top_class == labels.view(*top_class.shape)
                    accuracy += torch.mean(equals.type(torch.FloatTensor)).item()
        finally:
            if flag is not None:
                self.model.is_training = flag

        testinglosses = test_loss / len(self.test_loader)
        testaccuracy = accuracy / len(self.test_loader)
        return testinglosses, testaccuracy


class student_solver():
    """Knowledge-distillation trainer for a student network.

    Per-batch loss (see ``studentLossFn``):
        alphas * T**2 * KL(softmax(teacher/T) || softmax(student/T)) + (1 - alphas) * CE(student, y)
    ``train`` builds its own SGD optimizer and per-epoch StepLR scheduler from ``learning_rates``,
    ``momentums``, ``weight_decays`` and ``learning_rate_decays``. The ``student_optimizer`` and
    ``student_lr_scheduler`` arguments are stored; ``student_optimizer`` is only used by
    ``train_loop`` when it is called directly before ``train`` has built an optimizer.
    ``criterion``, when not None, is the validation loss; otherwise ``F.cross_entropy`` is used.
    Batches are moved to the device that holds the student's parameters.
    """

    def __init__(self, train_loader, test_loader, model, teacher_model, criterion, student_optimizer,
                 student_lr_scheduler,
                 epochs, model_path, model_name, temperatures=10, alphas=0.5, learning_rates=0.0001,
                 learning_rate_decays=0.95, weight_decays=1e-5, momentums=0.9, dropout_probabilities=(0, 0)):
        # Hyperparameter names are plural for historical reasons; each holds a single value.
        self.model_path = model_path
        self.model_name = model_name
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.model = model
        self.teacher_model = teacher_model
        self.student_optimizer = student_optimizer
        self.student_lr_scheduler = student_lr_scheduler
        self.epochs = epochs
        self.criterion = criterion
        self.step = 0
        self.T = temperatures
        self.alphas = alphas
        self.dropout_input = dropout_probabilities[0]
        self.dropout_hidden = dropout_probabilities[1]
        self.lr_decay = learning_rate_decays
        self.weight_decay = weight_decays
        self.momentum = momentums
        self.lr = learning_rates

        reproducibilitySeed()

    def _device(self):
        """Device of the student's parameters ('cuda' if it has none, the old hard-coded choice)."""
        for param in self.model.parameters():
            return param.device
        return torch.device('cuda')

    def train(self):
        """Distill the teacher into the student for ``self.epochs`` epochs.

        Builds SGD(lr, momentum, weight_decay) plus StepLR(step_size=1, gamma=lr_decay), runs one
        training epoch and one validation pass per epoch, prints a summary, and saves the student's
        state_dict to ``model_path + model_name`` whenever validation loss improves. Returns None.
        """
        from datetime import datetime  # the notebook imports `time`, not `datetime`

        self.model.dropout_input = self.dropout_input
        self.model.dropout_hidden = self.dropout_hidden
        optimizer = optim.SGD(self.model.parameters(), lr=self.lr,
                              momentum=self.momentum, weight_decay=self.weight_decay)
        lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=self.lr_decay)
        self.optimizer = optimizer
        val_loss = None
        best_model = None

        for epoch in range(self.epochs):
            print("Start KD Training...")
            start = datetime.now()
            tr_stu_avg_loss = self.train_loop(optimizer)
            # Step after this epoch's optimizer updates (PyTorch >= 1.1 ordering) so the first
            # epoch trains with the initial learning rate.
            lr_scheduler.step()
            val_stu_avg_loss, testaccuracy = self.validate()
            print('-' * 50)
            print('Summary: Epoch {0} | Time {1}s'.format(epoch, datetime.now() - start))
            print('Train | Loss {0:.4f}'.format(tr_stu_avg_loss))
            print('Validate | Loss {0:.4f}'.format(val_stu_avg_loss))
            print('Validate | Accuracy {0:.4f}'.format(testaccuracy))
            # Keep the checkpoint with the lowest validation loss seen so far.
            if val_loss is None or val_stu_avg_loss < val_loss:
                val_loss = val_stu_avg_loss
                torch.save(self.model.state_dict(), self.model_path + self.model_name)
                best_model = epoch
            print('best_model is on epoch:', best_model)

    def train_loop(self, optimizer=None):
        """Run one distillation epoch over ``self.train_loader``; return the mean batch loss.

        ``optimizer`` defaults to the one built by ``train`` and falls back to
        ``self.student_optimizer``. The teacher is evaluated in eval mode (with its
        ``is_training`` flag off, if it has one) and restored afterwards. Returns NaN for an
        empty loader.
        """
        if optimizer is None:
            optimizer = getattr(self, 'optimizer', None)
            if optimizer is None:
                optimizer = self.student_optimizer
        device = self._device()
        self.model.train()

        use_teacher = self.alphas > 0
        if use_teacher:
            teacher_was_training = self.teacher_model.training
            teacher_flag = getattr(self.teacher_model, 'is_training', None)
            # Deterministic soft targets: no teacher dropout, no BN running-stat updates.
            self.teacher_model.eval()
            if teacher_flag is not None:
                self.teacher_model.is_training = False

        running_loss = 0.0
        n_batches = 0
        try:
            for X, y in self.train_loader:
                X, y = X.to(device), y.to(device)
                optimizer.zero_grad()
                teacher_pred = None
                if use_teacher:
                    with torch.no_grad():
                        teacher_pred = self.teacher_model(X)
                student_pred = self.model(X)
                loss = self.studentLossFn(teacher_pred, student_pred, y)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 20)
                optimizer.step()
                running_loss += loss.item()
                n_batches += 1
        finally:
            if use_teacher:
                self.teacher_model.train(teacher_was_training)
                if teacher_flag is not None:
                    self.teacher_model.is_training = teacher_flag

        return running_loss / n_batches if n_batches else float('nan')

    def validate(self):
        """Return (loss, accuracy) of the student on ``self.test_loader``.

        Uses ``self.criterion`` when set, otherwise ``F.cross_entropy``. On log-softmax outputs
        cross_entropy equals the NLL loss, because log_softmax is idempotent.
        """
        criterion = self.criterion if self.criterion is not None else F.cross_entropy
        loss, val_acc = self.getLossAccuracyOnDataset(self.model, self.test_loader, self._device(), criterion)
        return loss, val_acc

    def getLossAccuracyOnDataset(self, network, dataset_loader, fast_device, criterion=None):
        """Return (loss, accuracy) of ``network`` on ``dataset_loader``, both averaged per sample.

        The network is put in eval mode with ``is_training = False`` for the pass. Afterwards its
        previous train/eval mode is restored and ``is_training`` is set back to True. ``criterion``
        is assumed to be mean-reduced; loss is 0.0 when it is None.

        Raises:
            ValueError: if the loader yields no samples.
        """
        was_training = network.training
        network.eval()
        network.is_training = False
        total_loss = 0.0
        correct = 0
        dataset_size = 0
        try:
            with torch.no_grad():
                for X, y in dataset_loader:
                    X = X.to(fast_device)
                    y = y.to(fast_device)
                    pred = network(X)
                    if criterion is not None:
                        total_loss += criterion(pred, y).item() * y.shape[0]
                    correct += torch.sum(torch.argmax(pred, dim=1) == y).item()
                    dataset_size += y.shape[0]
        finally:
            network.is_training = True
            network.train(was_training)
        if dataset_size == 0:
            raise ValueError('dataset_loader yielded no samples')
        return total_loss / dataset_size, correct / dataset_size

    def studentLossFn(self, teacher_pred, student_pred, y):
        """
        Loss function for student network: Loss = alpha * (distillation loss with soft-target) + (1 - alpha) * (cross-entropy loss with true label)
        The distillation term is scaled by T**2 to keep its gradient magnitude comparable across temperatures.
        Return: loss
        """
        T = self.T
        alpha = self.alphas
        if (alpha > 0):
            loss = F.kl_div(F.log_softmax(student_pred / T, dim=1), F.softmax(teacher_pred / T, dim=1),
                            reduction='batchmean') * (T ** 2) * alpha + F.cross_entropy(student_pred, y) * (
                           1 - alpha)
        else:
            loss = F.cross_entropy(student_pred, y)
        return loss


def reproducibilitySeed(use_gpu=True):
    """Seed the global torch and NumPy RNGs with 0 so runs are repeatable.

    When ``use_gpu`` is true, cuDNN is additionally forced onto deterministic algorithms and its
    autotuner (``benchmark``) is turned off.
    """
    seed = 0
    torch.manual_seed(seed)
    np.random.seed(seed)
    if not use_gpu:
        return
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False



In [None]:


# train_loader, test_loader = dataload('MNIST', bs=10)
# teacher_model = Teacher_Network(in_channels=1).cuda()
# student_model = StudentNetwork_noRelu(in_channels=1).cuda()

# student_model_name = 'student.pth'

# student_model.load_state_dict(torch.load(student_model_name))
# print("init weight from {}".format(student_model_name))
# print(sum([param.nelement() * param.element_size() for param in student_model.parameters()]))

# solver = student_solver(train_loader, test_loader, student_model, teacher_model, None, None,
#                         None,
#                         None, '', student_model_name)
# loss, val_acc = solver.validate()
# print('Validate | Loss {0:.4f}'.format(loss))
# print('Validate | Accuracy {0:.4f}'.format(val_acc))

In [None]:
# Rebuild the student, load its trained weights and report the parameter memory footprint in bytes.
student_model = StudentNetwork_noRelu(in_channels=1).cuda()
student_model_name = 'student.pth'

student_model.load_state_dict(torch.load(student_model_name))
print("init weight from {}".format(student_model_name))
print(sum(param.nelement() * param.element_size() for param in student_model.parameters()))

# One forward pass on a dummy batch. FFT_Conv_Layer.forward saves the composed filter spectrum
# (f123r.npy / f123i.npy) as a side effect. eval() + no_grad() leave BatchNorm's running
# statistics untouched and skip building an autograd graph for this pass.
test = torch.zeros(10, 1, 28, 28).cuda()
student_model.eval()
with torch.no_grad():
    test_out = student_model(test)
test_out