In [None]:
# Various torch packages
import torch
import torch.nn as nn
import torch.nn.functional as F

# torchvision
from torchvision import datasets, transforms

# ------------------------
# get up one directory 
import sys, os
sys.path.append(os.path.abspath('../'))
# ------------------------

# custom packages
import models.aux_funs as maf
import optimizers as op
import regularizers as reg
import train
import math
import utils.configuration as cf
import utils.datasets as ud
from utils.datasets import get_data_set, GaussianSmoothing
from models.fully_connected import fully_connected

In [None]:
# -----------------------------------------------------------------------------------
# Fix random seed
# -----------------------------------------------------------------------------------
# Base seed of the experiment; it is also stored in the configuration as
# 'random_seed' and incremented after every run of the training loop.
random_seed = 3
# Project helper that seeds the random number generators (see utils.configuration).
cf.seed_torch(random_seed)

# Test cases
## Test case 0: No Skips, Denoising
* No skips (regularization parameter for skips is set to a large value)
* Denoising task, no blur is added to the images

## Test case 1: No Skips, Deblurring
* No skips (regularization parameter for skips is set to a large value)
* Deblurring task, blur is added to the images

## Test case 2: Skips, Denoising
* Skips are allowed but regularized
* Denoising task, no blur is added to the images

## Test case 3: Skips, Deblurring
* Skips are allowed but regularized
* Deblurring task, blur is added to the images 

In [None]:
# Which experiment to run (0-3); see the "Test cases" overview above.
test_case = 0

# Regularization strength for the linear weights; identical in every test case.
lamda = 0.07

# test_case -> (lamda_2, add_blur)
#   lamda_2  : regularization strength for the skip coefficients; 1e10 is the
#              "large value" that effectively disables skips
#   add_blur : True for the deblurring task, False for pure denoising
_TEST_CASES = {
    0: (1e10, False),
    1: (1e10, True),
    2: (28*lamda, False),
    3: (28*lamda, True),
}

# Fail right here with a clear message instead of a NameError in a later cell.
if test_case not in _TEST_CASES:
    raise ValueError(
        "Unknown test_case {!r}; expected one of {}".format(test_case, sorted(_TEST_CASES))
    )

lamda_2, add_blur = _TEST_CASES[test_case]

# Configuration

In [None]:
def reshaped_mse_loss(x,y):
    """Mean squared error between `x` and `y` after reshaping `y` to rows of 28*28 values.

    Args:
        x: prediction tensor, compared element-wise against the reshaped target.
        y: target tensor whose element count is a multiple of 28*28.

    Returns:
        Scalar tensor holding the mean squared error.
    """
    # reshape (not view) so that non-contiguous targets, e.g. transposed or
    # sliced tensors, are accepted as well; for contiguous input it is the same.
    return F.mse_loss(x, y.reshape(-1, 28*28))
# -----------------------------------------------------------------------------------
# Parameters
# -----------------------------------------------------------------------------------
conf_args = dict(
    # --- data specification ---
    data_file="../../Data",
    train_split=0.95,
    data_set="Encoder-MNIST",
    download=False,
    add_noise=True,
    noise_std=0.05,
    add_blur=add_blur,
    # --- cuda ---
    # NOTE(review): device index 1 is hard-coded — confirm it exists on the target machine
    use_cuda=True,
    num_workers=4,
    cuda_device=1,
    pin_memory=False,
    # --- training ---
    epochs=100,
    loss=reshaped_mse_loss,
    # --- optimizer ---
    delta=1.0,
    lr=0.001,
    lamda=lamda,
    lamda_2=lamda_2,
    optim="AdaBreg",
    row_group=True,
    reg=reg.reg_l1_l2,
    beta=0.0,
    # --- model: seven layers of width 28*28 ---
    model_size=7*[28*28],
    act_fun=torch.nn.ReLU(),
    # --- initialization ---
    sparse_init=0.01,
    r=[1,5,1],
    # --- misc ---
    random_seed=random_seed,
    eval_acc=False,
    name='main-Encoder-MNIST',
    super_type='Encoder',
)
conf = cf.Conf(**conf_args)

# General ResNet

In [None]:
class BasicBlock(nn.Module):
    """Linear layer combined with a learned weighted sum of all activations seen so far.

    The block consumes a 3-D tensor whose last axis stacks the activations
    produced so far and returns the same tensor with its own output appended
    along that axis.

    Args:
        in_neurons: input width of the linear layer.
        out_neurons: output width of the linear layer; the result is added to
            the skip term, so both must be broadcast-compatible.
        idx: position of the block; it owns ``idx + 1`` skip coefficients.
        act_fun: activation applied to the linear output.
        act_fun_outer: activation applied after the skip term has been added.
    """

    def __init__(self, in_neurons, out_neurons, idx, act_fun, act_fun_outer):
        super().__init__()
        self.act_fun = act_fun
        self.act_fun_outer = act_fun_outer
        self.linear = nn.Linear(in_neurons, out_neurons)

        # one scalar coefficient per stacked activation, all starting at zero
        self.skips = nn.Parameter(torch.zeros(idx + 1), requires_grad=True)
        self.idx = idx

    def forward(self, x):
        # most recent activation feeds the linear layer
        newest = x[:, :, -1]
        # weighted sum over the stacking axis, weights broadcast along it
        skip_term = (self.skips.view(1, 1, -1) * x).sum(dim=2)

        activated = self.act_fun(self.linear(newest))
        out = self.act_fun_outer(activated + skip_term)

        # append the new activation to the stack
        return torch.cat([x, out.unsqueeze(2)], dim=2)
    
class OutBlock(nn.Module):
    """Parameter-free head that returns the last slice along axis 2 of its input."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # same as x[:, :, -1]: pick the final entry of the third axis
        return x.select(2, -1)

class fully_skip_connected(nn.Module):
    """Stack of ``BasicBlock`` layers followed by an ``OutBlock``.

    Args:
        sizes: layer widths; ``len(sizes) - 1`` blocks are created, block ``i``
            mapping ``sizes[i]`` to ``sizes[i+1]``.
        act_fn: activation handed to every block as its inner activation.
        outer_act_fn: activation handed to every block as its outer activation.
        mean: value subtracted from the input before the first block.
        std: value the centred input is divided by.
    """

    def __init__(self, sizes, act_fn, outer_act_fn = nn.Identity(), mean=0.0, std=1.0):
        super().__init__()
        self.mean = mean
        self.std = std
        self.num_l = len(sizes)
        self.act_fn = act_fn

        # block i receives its own index, which fixes its number of skip coefficients
        blocks = [
            BasicBlock(sizes[i], sizes[i+1], i, self.act_fn, outer_act_fn)
            for i in range(self.num_l - 1)
        ]
        self.layers = nn.Sequential(*blocks, OutBlock())

    def forward(self, x):
        normalized = (x - self.mean) / self.std
        # flatten everything except the batch axis, then open the stacking axis
        stacked = torch.flatten(normalized, start_dim=1).unsqueeze(2)
        return self.layers(stacked)

In [None]:
# -----------------------------------------------------------------------------------
# define the model and an instance of the best model class
# -----------------------------------------------------------------------------------
model_kwargs = {'mean':conf.data_set_mean, 'std':conf.data_set_std}

# "fsc": fully skip-connected network defined above, "fc": plain fully connected network
model_conf = "fsc"
model_classes = {"fsc": fully_skip_connected, "fc": fully_connected}

# Fail here with a clear message instead of a NameError on `model` further down.
if model_conf not in model_classes:
    raise ValueError(
        "Unknown model_conf {!r}; expected one of {}".format(model_conf, sorted(model_classes))
    )
model_cls = model_classes[model_conf]

# The working model is built first and the tracker's own copy second, so the
# random number generator is consumed in the same order as before.
model = model_cls(conf.model_size, conf.act_fun, **model_kwargs)
best_model = train.best_model(model_cls(conf.model_size, conf.act_fun, **model_kwargs).to(conf.device))

# sparsify
maf.sparse_bias_uniform_(model, 0,conf.r[0])
maf.sparse_weight_normal_(model, conf.r[1])
maf.sparsify_(model, conf.sparse_init, row_group = conf.row_group)
model = model.to(conf.device)

In [None]:
# -----------------------------------------------------------------------------------
# normalization kwargs passed to the model constructors
# (same values as in the model definition cell)
# -----------------------------------------------------------------------------------
model_kwargs = dict(mean=conf.data_set_mean, std=conf.data_set_std)

def init_weights(conf, model):
    """Re-initialize `model` in place for a fresh run and move it to `conf.device`.

    Args:
        conf: configuration providing `r`, `sparse_init`, `row_group` and `device`.
        model: network to re-initialize.

    Returns:
        The re-initialized model on `conf.device`.
    """
    # re-draw biases and weights, then sparsify
    maf.sparse_bias_uniform_(model, 0, conf.r[0])
    maf.sparse_weight_normal_(model, conf.r[1])
    maf.sparsify_(model, conf.sparse_init, row_group = conf.row_group)

    # Put the skip coefficients back to zero, the value they are constructed
    # with. Nothing visible in this notebook resets them, so without this
    # step every run after the first would start from the previous run's
    # learned skips.
    with torch.no_grad():
        for m in model.modules():
            if hasattr(m, 'skips'):
                m.skips.zero_()

    model = model.to(conf.device)

    return model

In [None]:
# -----------------------------------------------------------------------------------
# Optimizer
# -----------------------------------------------------------------------------------
def get_skips(model):
    """Lazily yield the `skips` attribute of every sub-module of `model` that has one."""
    yield from (m.skips for m in model.modules() if hasattr(m, 'skips'))

def print_skips(model):
    """Print the skip coefficients of every sub-module that has them, rounded to three decimals."""
    holders = (m for m in model.modules() if hasattr(m, 'skips'))
    for holder in holders:
        # round on the original device, move to the CPU, then scale back
        rounded = torch.round(1000 * holder.skips.data).cpu()
        print(0.001 * rounded)
            
def skips_to_list(model):
    """Return the skip coefficients of every sub-module that has them as nested Python lists."""
    return [m.skips.data.tolist() for m in model.modules() if hasattr(m, 'skips')]
    

def init_opt(conf, model):
    """Create the optimizer and learning-rate scheduler for `model`.

    Args:
        conf: configuration providing `optim` ("SGD" or "AdaBreg"), `lr`,
            `beta`, `reg`, `lamda` and `lamda_2`.
        model: network whose parameters are optimized.

    Returns:
        Tuple `(opt, scheduler)`; the scheduler is a `ReduceLROnPlateau`.

    Raises:
        ValueError: if `conf.optim` names an unsupported optimizer.
    """
    # Get access to different model parameters
    weights_linear = maf.get_weights_linear(model)
    biases = maf.get_bias(model)
    skips = get_skips(model)

    # -----------------------------------------------------------------------------------
    # Initialize optimizer
    # -----------------------------------------------------------------------------------
    reg1 = conf.reg(lamda=conf.lamda)       # regularizer for the linear weights
    reg2 = reg.reg_l1(lamda=conf.lamda_2)   # regularizer for the skip coefficients

    if conf.optim == "SGD":
        opt = torch.optim.SGD(model.parameters(), lr=conf.lr, momentum=conf.beta)
    elif conf.optim == "AdaBreg":
        # one parameter group per parameter kind; the biases carry no regularizer
        opt = op.AdaBreg([{'params': weights_linear, 'lr' : conf.lr, 'reg' : reg1},
                           {'params': biases, 'lr': conf.lr},
                           {'params': skips, 'lr':conf.lr, 'reg':reg2}])
    else:
        # previously `opt` stayed unbound and the scheduler line failed with
        # an UnboundLocalError
        raise ValueError(
            "Unknown optimizer {!r}; expected 'SGD' or 'AdaBreg'".format(conf.optim)
        )

    # learning rate scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, factor=0.5, patience=5,threshold=0.01)

    return opt, scheduler

In [None]:
# Set to True to write the current configuration to a CSV file via conf.write_to_csv().
save_params = False
if save_params:
    conf.write_to_csv()

In [None]:
# -----------------------------------------------------------------------------------
# Prepare Data
# -----------------------------------------------------------------------------------
# presumably (channels, kernel size, sigma) — confirm against utils.datasets
smoothing = GaussianSmoothing(1, 9, 5.0)

# corruptions are applied in list order: blur first (if enabled), then noise
corruptions = []
if conf.add_blur:
    corruptions.append(smoothing)
if conf.add_noise:
    corruptions.append(ud.add_noise(std=conf.noise_std,device='cpu'))
transform = transforms.Compose(corruptions)

# load MNIST, wrap both splits with the corruption pipeline, then build the loaders
train_set, test_set = ud.get_mnist(conf)
train_set = ud.get_augmented_dataset(conf, train_set, transform)
test_set = ud.get_augmented_dataset(conf, test_set, transform)
train_loader, valid_loader, test_loader = ud.train_valid_test_split(conf, train_set, test_set)

# Run Specification

In [None]:
# Ten runs, each described by its own empty parameter dict.
# (The keyword form cf.run(**{'num_runs':10}) throws an error.)
runs = cf.run([dict() for _ in range(10)])

# Training

In [None]:
while runs.step():
    # -----------------------------------------------------------------------------------
    # initialize history
    # -----------------------------------------------------------------------------------
    tracked = ['loss', 'node_sparse']
    train_history = {name: [] for name in tracked}
    val_history = {name: [] for name in tracked}

    # -----------------------------------------------------------------------------------
    # Reinit weights and the corresponding optimizer
    # -----------------------------------------------------------------------------------
    model = init_weights(conf, model)
    opt, scheduler = init_opt(conf, model)

    # -----------------------------------------------------------------------------------
    # train the model
    # -----------------------------------------------------------------------------------
    for epoch in range(conf.epochs):
        # three-line banner between epochs
        print(25*"<>", 50*"|", 25*"<>", sep="\n")
        print('Epoch:', epoch)

        # ------------------------------------------------------------------------
        # train step, log the tracked quantities
        # ------------------------------------------------------------------------
        train_data = train.train_step(conf, model, opt, train_loader)
        for key in tracked:
            if key in train_data:
                train_history[key].append(train_data[key])

        # ------------------------------------------------------------------------
        # validation step
        # ------------------------------------------------------------------------
        val_data = train.validation_step(conf, model, opt, valid_loader)

        print_skips(model)

        for key in tracked:
            if key in val_data:
                val_history[key].append(val_data[key])

        # additionally keep one history list per entry of 'node_sparse'
        for i,reg_val in enumerate(val_data['node_sparse']):
            key = "node_sparse" + str(i)
            val_history.setdefault(key, []).append(reg_val)

        # the scheduler is driven by the training loss
        scheduler.step(train_data['loss'])
        print("Learning rate:",opt.param_groups[0]['lr'])
        # NOTE(review): assumes both result dicts contain 'acc' although the
        # configuration sets eval_acc=False — TODO confirm
        best_model(train_data['acc'], val_data['acc'], model=model)

    # test
    test_hist = train.test(conf, model, test_loader)

    # add values to the run history
    runs.add_history(train_history, "train")
    runs.add_history(val_history, "val")
    runs.add_history(test_hist, "test")

    # other properties
    hist = {'skips': skips_to_list(model)}
    runs.add_history(hist, "")

    # update random seed
    conf.random_seed += 1
    cf.seed_torch(conf.random_seed)