In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch import Tensor
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10, MNIST
from torch.utils.data import DataLoader, Dataset
import numpy as np
import copy
from contextlib import contextmanager

from six import add_metaclass
from typing import Type, Any, Callable, Union, List, Optional

In [None]:
import models
from models.resnet import ResNet, BasicBlock
from utils import *
from train_2nd_order import weight_decay, trainer
from validate_NC import compute_Wh_b_relation, compute_W_H_relation, compute_ETF, compute_Sigma_B, compute_Sigma_W,compute_info,FCFeatures
from datasets import make_dataset

import sys
import os
import shutil
import scipy.linalg as scilin

In [None]:
# ---- Architecture ----
model = 'resnet18'

# ---- Data ----
dataset = 'mnist'
data_dir = '~/data'
SOTA = False            # forwarded to make_dataset(..., SOTA=SOTA)
batch_size = 1024

# ---- Optimisation / run ----
optimizer = "LBFGS"
lr = 0.1
history_size = 10       # forwarded to args(history_size=...); presumably the L-BFGS memory length — confirm
epochs = 1
device = "cpu"
uid = "tmp4"            # run id forwarded to args(uid=...), which derives ./model_weights/<uid> from it

# ---- Network head / loss ----
loss = 'CrossEntropy'
bias = True
ETF_fc = False
fixdim = 0

In [None]:
class args:
    """Plain attribute container for the run's hyper-parameters.

    Every constructor argument is stored verbatim as an attribute of the same
    name. In addition, the constructor resolves a checkpoint directory and
    stores it in ``self.save_path``:

    * ``uid is None``: a random numeric uid is generated and the directory is
      ``./model_weights/<uid>-<timestamp>``.
    * otherwise: ``./model_weights/<uid>``.

    The directory is created if missing. If it already exists, the constructor
    refuses to reuse it unless ``force=True``. With ``force=True``, any existing
    ``log.txt`` / ``args.txt`` are copied to the first free ``.bk<m>`` slot
    (m in 1..9) before proceeding.

    Raises:
        FileExistsError: the save directory already exists and ``force`` is False.
    """

    def __init__(self, model='resnet18', bias=True, ETF_fc=False, fixdim=0, SOTA=False,
                 width=1024, depth=6, gpu_id=0, seed=6, use_cudnn=True,
                 dataset='mnist', data_dir='~/data', uid=None, force=False,
                 epochs=200, batch_size=1024, loss='CrossEntropy', sample_size=None,
                 lr=0.1, patience=40, decay_type='step', gamma=0.1, optimizer='SGD',
                 weight_decay=5e-4, sep_decay=False, feature_decay_rate=1e-4,
                 history_size=10, ghost_batch=128, device="cpu",

                 # distill setting
                 distill_steps=10,
                 distill_epochs=3,
                 distilled_images_per_class_per_step=1,
                 num_classes=None,
                 distill_lr=0.02,
                 decay_epochs=40,
                 decay_factor=0.5
                 ):
        # Local import: otherwise `datetime` is only in scope if a star import
        # elsewhere in the notebook happens to export the module.
        import datetime

        self.model = model
        self.bias = bias
        self.ETF_fc = ETF_fc
        self.fixdim = fixdim
        self.SOTA = SOTA
        self.width = width
        self.depth = depth
        self.gpu_id = gpu_id
        self.seed = seed
        self.use_cudnn = use_cudnn
        self.dataset = dataset
        self.data_dir = data_dir
        self.uid = uid
        self.force = force
        self.epochs = epochs
        self.batch_size = batch_size
        self.loss = loss
        self.sample_size = sample_size
        self.lr = lr
        self.patience = patience
        self.decay_type = decay_type
        self.gamma = gamma
        self.optimizer = optimizer
        self.weight_decay = weight_decay
        self.sep_decay = sep_decay
        self.feature_decay_rate = feature_decay_rate
        self.history_size = history_size
        self.ghost_batch = ghost_batch
        self.device = device

        # distill setting
        self.distill_steps = distill_steps
        self.distill_epochs = distill_epochs
        self.num_classes = num_classes
        self.distilled_images_per_class_per_step = distilled_images_per_class_per_step
        self.distill_lr = distill_lr
        # No trailing comma here: it would silently turn the value into a 1-tuple.
        self.decay_epochs = decay_epochs
        self.decay_factor = decay_factor

        if self.uid is None:
            unique_id = str(np.random.randint(0, 100000))
            print("revise the unique id to a random number " + str(unique_id))
            self.uid = unique_id
            # Timestamp suffix keeps auto-generated directories from colliding.
            timestamp = datetime.datetime.now().strftime("%a-%b-%d-%H-%M")
            save_path = './model_weights/' + self.uid + '-' + timestamp
        else:
            save_path = './model_weights/' + str(self.uid)

        if not os.path.exists(save_path):
            os.makedirs(save_path, exist_ok=True)
        elif not self.force:
            # Raising a bare string is a TypeError in Python 3; use a real exception.
            raise FileExistsError("please use another uid ")
        else:
            # str(): uid may be passed as a non-string (e.g. an int).
            print("override this uid" + str(self.uid))
            # Back up the previous run's logs into the first free .bk<m> slot.
            # Best-effort: a directory may exist without these files yet.
            for m in range(1, 10):
                if not os.path.exists(save_path + "/log.txt.bk" + str(m)):
                    for name in ("log.txt", "args.txt"):
                        src = save_path + "/" + name
                        if os.path.exists(src):
                            shutil.copy(src, src + ".bk" + str(m))
                    break
        self.save_path = save_path

In [None]:
# Bundle every config-cell value into the `args` namespace consumed downstream.
# This rebinds the name `args` from the class to an instance, so recover the class
# first to make the cell re-executable. Re-running with the same `uid` still fails
# while ./model_weights/<uid> exists, unless force=True is passed.
_args_cls = args if isinstance(args, type) else type(args)
args = _args_cls(model=model, dataset=dataset, data_dir=data_dir,
                 optimizer=optimizer, lr=lr, loss=loss,
                 history_size=history_size, batch_size=batch_size,
                 epochs=epochs,  # use the config value instead of a hardcoded 1
                 uid=uid, device=device,
                 bias=bias, ETF_fc=ETF_fc, fixdim=fixdim, SOTA=SOTA)

In [None]:
# Build the train/test loaders; make_dataset's third return value is kept as num_classes.
trainloader, testloader, num_classes = make_dataset(
    dataset, data_dir, batch_size, SOTA=SOTA,
)
print(num_classes)

In [None]:
# Peek at one training batch to read off the input geometry used to build the model.
images, labels = next(iter(trainloader))
# NOTE: `size_train` is the size of this first batch (<= batch_size), not the dataset size.
size_train, channels, height, width = images.shape
nc = channels
input_size = height, width

# Keep `num_classes` as reported by make_dataset: a single batch is not guaranteed
# to contain every class, so counting its unique labels can under-estimate it.
num_classes_in_batch = len(torch.unique(labels))
if num_classes_in_batch != num_classes:
    print(f"Warning: first batch contains {num_classes_in_batch} of {num_classes} classes")
print("The number of class in our training set is ", num_classes)
print("Batch size:", size_train, "Number of channels:", channels, "input height:", height, "input width:", width)