In [None]:
import torch
from bionet.biomodule import BioModule
from bionet.modules.linear import FCNet
from bionet.datasets import Datasets
from torch import nn
from torchvision import transforms

import numpy as np

from bionet.modules.alexnetMini import AlexNetMini
from bionet.modules.resnetMini import resnet20,resnet32,resnet44,resnet56,resnet110,resnet1202
import torch

import numpy as np
import pickle as pkl


In [None]:
def train(model, dataloader, transform, epochs=2, optimizer=torch.optim.SGD, optimizer_args={'lr':1e-3}, purge=True, scale_grad=True, crystallize=True, verbose=False, snapshot_epochs=(0, 50, 100, 200)):
    """Train `model` with cross-entropy loss and collect weight snapshots.

    Args:
        model: module to train. It must expose `scale_grad()`, `purge()` and
            `crystallize()` whenever the matching flag is enabled; what those
            hooks do is defined by the model, not here.
        dataloader: iterable yielding `(sample, target)` batches.
        transform: callable applied to each batch after it is moved to the
            model's device.
        epochs: number of passes over `dataloader`.
        optimizer: optimizer class, instantiated as
            `optimizer(model.parameters(), **optimizer_args)`.
        optimizer_args: keyword arguments for the optimizer (read only).
        purge, scale_grad, crystallize: enable the corresponding model hook
            between `loss.backward()` and `optimizer.step()`.
        verbose: when true, print the accuracy of every finished epoch.
        snapshot_epochs: epochs after which a copy of the weights is stored.

    Returns:
        A list of `{"epoch": int, "state_dict": ...}` dicts, one per epoch
        listed in `snapshot_epochs` that was actually reached.
    """
    import copy  # local import: only needed for the snapshots below

    optimizer = optimizer(model.parameters(), **optimizer_args)
    device = next(model.parameters()).device
    criterion = nn.CrossEntropyLoss()  # stateless, so build it once
    models = []
    for epoch in range(epochs):
        correct = 0
        seen = 0
        for idx, (sample, target) in enumerate(dataloader):
            model.zero_grad()
            sample, target = transform(sample.to(device)), target.to(device)
            if len(sample.shape) == 3:
                # Rank-3 batches get an extra axis inserted at position 1.
                sample = sample.unsqueeze(1)
            output = model(sample)
            loss = criterion(output, target)
            loss.backward()
            # Model-specific hooks run on the fresh gradients, before the step.
            if scale_grad: model.scale_grad()
            if purge: model.purge()
            if crystallize: model.crystallize()
            optimizer.step()
            pred = output.argmax(dim=1, keepdim=True)
            correct_batch = pred.eq(target.view_as(pred)).sum().item()
            correct += correct_batch
            seen += len(sample)
            print(f'\r{epoch:4d} {idx:4d} accuracy {correct_batch/float(len(sample)):6.4f} loss {loss.item():6.4f} ', end='')
        # Divide by the samples actually seen so a partial last batch does not
        # skew the figure (and an empty loader does not divide by zero).
        if verbose: print(f'Accuracy for Epoch {epoch:3d} : {correct/max(seen, 1): 6.4f} ')
        if epoch in snapshot_epochs:
            # state_dict() hands out tensors that share storage with the live
            # parameters; without a deep copy every snapshot would silently
            # track the weights as training continues.
            models.append({"epoch": epoch, "state_dict": copy.deepcopy(model.state_dict())})
    return models

In [None]:
def eval_resnet(model_name, num_channels, num_classes):
    """Pick the ResNet builder named in `model_name` and call it.

    The depth is detected by substring search. Tokens are tried in the fixed
    order below and the first hit wins; a name matching none of them falls
    back to `resnet20`. The chosen builder is called with
    `(num_channels, num_classes)` and its result is returned unchanged.
    """
    depth_builders = (
        ('1202', resnet1202),
        ('32', resnet32),
        ('44', resnet44),
        ('56', resnet56),
        ('110', resnet110),
    )
    for depth_token, builder in depth_builders:
        if depth_token in model_name:
            return builder(num_channels, num_classes)
    return resnet20(num_channels, num_classes)

def run(model_class =  'FCNet', model_args = {'in_feats':784, 'shapes':[1000], 'num_classes':10}, 
        optimizer = 'SGD', optimizer_args= {'lr':1e-3}, grad_scale=0.09, batch_size=100, vbatch_size=1000,
        dampening_factor = 0.6,fold=5, crystal_thresh=4.5e-5, purge_distance=8.0, accum_neurons=2,
       epochs=201, dataset_name='CIFAR100', purge=True, scale_grad=True, crystallize=False, verbose=False):
    """Build one converted model, train it on `dataset_name`, and return the snapshots.

    Args:
        model_class: architecture name. `'FCNet'` selects `FCNet`, any name
            containing `'resnet'` is resolved through `eval_resnet`, and
            everything else selects `AlexNetMini`.
        model_args: accepted for interface compatibility; every branch below
            replaces it with its own configuration.
        optimizer: name of a class in `torch.optim` (an optimizer class is
            also accepted).
        optimizer_args: keyword arguments for the optimizer.
        grad_scale, dampening_factor, crystal_thresh, purge_distance,
            accum_neurons: forwarded to
            `BioModule.get_convert_to_bionet_converter`.
        batch_size: training batch size.
        vbatch_size, fold: not used for training; only recorded in `options`.
        epochs: number of training epochs.
        dataset_name: name passed to `Datasets`; also determines the channel
            and class counts.
        purge, scale_grad, crystallize, verbose: forwarded to `train`.

    Returns:
        `{'model': <architecture name>, 'options': <dict of the settings used>,
        'models': <snapshots returned by train>}`.

    Raises:
        ValueError: if `optimizer` is a string that names nothing in
            `torch.optim`.
    """
    converter = BioModule.get_convert_to_bionet_converter(grad_scale=grad_scale, dampening_factor=dampening_factor, crystal_thresh=crystal_thresh,purge_distance=purge_distance, accum_neurons=accum_neurons)

    str_model_class = model_class
    str_optimizer = optimizer

    # Honour the requested optimizer instead of silently training with SGD
    # while recording a different name in `options`.
    if isinstance(optimizer, str):
        optimizer_cls = getattr(torch.optim, optimizer, None)
        if optimizer_cls is None:
            raise ValueError(f"Unknown optimizer '{optimizer}': not found in torch.optim")
    else:
        optimizer_cls = optimizer

    num_channels = 3 if dataset_name.upper() != 'MNIST' else 1
    num_classes = 100 if dataset_name.upper()[-3:] == '100' else 10

    if model_class == "FCNet":
        model_class = FCNet
        # NOTE(review): this configuration is hard-coded (3x32x32 inputs, 100
        # classes) and ignores both `model_args` and the dataset-derived
        # `num_channels`/`num_classes` -- confirm before using another dataset.
        model_args = {'in_feats':3*32*32, 'shapes':[3000,3000], 'num_classes':100}
    elif 'resnet' in model_class:
        # Relies on the resnet builders returning a (class, kwargs) pair --
        # TODO confirm against bionet.modules.resnetMini.
        model_class, model_args =  eval_resnet(model_class, num_channels, num_classes)
    else:
        model_args = {"num_classes":num_classes, "num_chans":num_channels, "target":dataset_name.upper()}
        model_class = AlexNetMini

    model_class = converter(model_class)
    dataset = Datasets(dataset_name)
    transform=nn.Sequential(
            transforms.Normalize(**dataset.normalization)
            )

    # Use the GPU when there is one, but stay runnable on CPU-only machines.
    use_cuda = torch.cuda.is_available()
    model = model_class(**model_args)
    if use_cuda:
        model = model.cuda()
    dataset.train()
    dataloader = torch.utils.data.DataLoader(dataset, pin_memory=use_cuda, num_workers=0, shuffle=True, batch_size=batch_size)

    models = train(model, dataloader, transform, epochs=epochs, optimizer=optimizer_cls, optimizer_args=optimizer_args, 
          purge=purge, scale_grad=scale_grad, crystallize=crystallize, verbose=verbose)

    # Copy optimizer_args so the returned record never aliases the shared
    # default dict of this function.
    options = {'model_class' :  str_model_class, 'model_args' : model_args, 'optimizer' :str_optimizer, 'optimizer_args': dict(optimizer_args), 'grad_scale':grad_scale, 
           'batch_size':batch_size, 'vbatch_size':vbatch_size,'dampening_factor':dampening_factor,'fold':fold, 'crystal_thresh':crystal_thresh, 'purge_distance':purge_distance, 'accum_neurons':accum_neurons,'epochs':epochs, 
           'dataset_name':dataset_name, 'purge':purge, 'scale_grad':scale_grad, 'crystallize':crystallize, 'verbose':verbose}
    if verbose:
        # The previous summary referenced accuracy/loss histories and a plot
        # module that are not defined anywhere in this function, so it raised
        # NameError. Report what is actually available instead.
        print(f"Stored snapshots for epochs: {[snapshot['epoch'] for snapshot in models]}")
    return {'model':str_model_class, 'options':options, 'models':models }

# Experiment grid. Each combination is a [purge, scale_grad] pair; `purge`
# toggles fastest, giving [F, F], [T, F], [F, T], [T, T].
combinations = [
    [purge_flag, scale_flag]
    for scale_flag in (False, True)
    for purge_flag in (False, True)
]
# Architectures to sweep over (names understood by `run`).
nets = ['FCNet', 'AlexNetMini', 'resnet20', 'resnet56']
# Values tried for the `accum_neurons` setting.
accums = [0, 2]

In [None]:
from time import time
#result_list = []
# Sweep bookkeeping: the grid size (factors in loop-nesting order) and the
# number of runs finished so far.
num_runs = len(nets) * len(accums) * len(combinations)
i = 0


def hms_from_seconds(time_el):
    """Split a duration in seconds into whole `(hours, minutes, seconds)`.

    Fractional seconds are truncated; all three components are ints.
    """
    hours, leftover = divmod(time_el, 3600)
    minutes, seconds = divmod(leftover, 60)
    return int(hours), int(minutes), int(seconds)
import os

# Create the output directory up front so a missing folder cannot discard
# the result of a finished training run.
os.makedirs('experiments', exist_ok=True)

t0 = time()
for net in nets:
    for accum in accums:
        for combination in combinations:
            t1 = time()
            time_el = int(t1-t0)
            h_el, m_el, s_el = hms_from_seconds(time_el)
            # Remaining time = mean duration of the finished runs times the
            # runs still to do (previously multiplied by the total run count,
            # which reports the projected total instead of what is left).
            approx_time_to_go = 0 if i==0 else (time_el/i)*(num_runs-i)
            h_to, m_to, s_to = hms_from_seconds(approx_time_to_go)
            purge, scale_grad = combination
            print(f"Working on Combination #{i:4d} of {num_runs:4d}: net: {net}, accum: {accum} purge: {purge}, scale_grad: {scale_grad} - time elapsed {h_el:2d}h {m_el:2d}min {s_el:2d}s - {h_to:2d}h {m_to:2d}min {s_to:2d} to go" )
            res = run(model_class = net, accum_neurons=accum, purge= purge, scale_grad=scale_grad)
            result = {'result':res, 'options':{'net':net, 'accum':accum, 'combination':combination, 'purge':purge, 'scale_grad':scale_grad}}
            # Context manager guarantees the file is flushed and closed.
            with open(f'experiments/tmp_res_v3_{net}_{accum}_{purge}_{scale_grad}.pkl','wb') as out_file:
                pkl.dump(result, out_file)
            i+=1
            print()       
        