In [1]:
import sys
from itertools import islice
#sys.path.insert(0,'..')

import pandas as pd
import torch
import torch.nn as nn
from tqdm import tqdm

from model import MLP, VGG11, resnet18
from data import get_ds
from activation import MyReLU
from utils import unif_weight_copy, touch0, diff_grad_chk

In [2]:
# --- Reproducibility -------------------------------------------------------
# Deterministic cuDNN kernels only give repeatable results if the RNG is also
# seeded; model initialisation (and presumably the weight sampling inside
# unif_weight_copy — TODO confirm) draw from torch's RNG.
torch.backends.cudnn.deterministic = True
SEED = 0
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices

# --- Device ----------------------------------------------------------------
GPU_ID = 5  # single place to change the GPU used by the whole experiment
device = f'cuda:{GPU_ID}'
torch.cuda.set_device(GPU_ID)

# --- Experiment settings ---------------------------------------------------
act_fn = MyReLU                 # activation class handed to touch0
models = ['mlp', 'vgg', 'res']  # architectures to sweep
weight_sampling = 100           # weight samples drawn per (model, precision)

# Result table; one row per (model, precision, weight sample, batch).
df = pd.DataFrame(columns=['model', 'data_idx', 'precision', 'touch0', 'diff_grad'])

N_BATCHES = 200           # training mini-batches checked per weight sample
PRECISIONS = [8, 16, 32]  # precision argument forwarded to unif_weight_copy
RESULT_COLUMNS = ['model', 'data_idx', 'precision', 'touch0', 'diff_grad']
RESULT_PATH = 'result0.csv'


def build_model_pair(name):
    """Return ``(dataset_name, net0, net1)`` for architecture ``name``.

    ``net0``/``net1`` are two instances of the same architecture built with
    flag 0 and flag 1 respectively (for the ResNet this is the argument given
    to ``MyReLU``; presumably the same meaning for MLP/VGG11 — TODO confirm).
    Any name other than 'mlp' or 'vgg' builds the bias-enabled ResNet-18 with
    ``nn.Identity`` normalisation. Models are returned on the CPU.
    """
    if name == 'mlp':
        return 'mnist', MLP(512, 0, True), MLP(512, 1, True)
    if name == 'vgg':
        return 'cifar10', VGG11(0, True), VGG11(1, True)
    return (
        'cifar10',
        resnet18(bias=True, norm_layer=nn.Identity, relu_fn=lambda: MyReLU(0)),
        resnet18(bias=True, norm_layer=nn.Identity, relu_fn=lambda: MyReLU(1)),
    )


def check_batch(net0, net1, x, act):
    """Backpropagate one batch through both nets and cross-check them.

    Zeroes both nets' gradients, backpropagates the sum of each net's output
    on ``x``, then compares ``touch0`` counts and the ``diff_grad_chk`` count.

    Returns:
        ``(t0_cnt, df_cnt)``: ``touch0(net0, act)`` and ``diff_grad_chk(net0, net1)``.

    Raises:
        RuntimeError("different weight"): the two nets' touch0 counts differ.
        RuntimeError("different gradient"): diff_grad_chk exceeds the touch0 count.
    """
    net0.zero_grad(); net1.zero_grad()
    # Both forward passes run before either backward pass.
    y0 = net0(x).sum()
    y1 = net1(x).sum()
    y0.backward()
    y1.backward()

    # Count once and reuse; assumes touch0 only reads model state — TODO confirm.
    t0_cnt = touch0(net0, act)
    if t0_cnt != touch0(net1, act):
        raise RuntimeError("different weight")

    df_cnt = diff_grad_chk(net0, net1)
    # Invariant under test: gradients may differ at most where touch0 counts.
    if t0_cnt < df_cnt:
        raise RuntimeError("different gradient")
    return t0_cnt, df_cnt


# Collect plain records and build the DataFrame at checkpoints: growing a frame
# with pd.concat per batch is quadratic in the number of rows.
records = []   # one dict per (model, precision, weight sample, batch)
run_ids = []   # DataFrame index of each record (the weight-sample id)

try:
    for m in models:
        data, NN0, NN1 = build_model_pair(m)
        NN0, NN1 = NN0.to(device), NN1.to(device)
        train, _ = get_ds(data)

        for precision in PRECISIONS:
            for run_id in tqdm(range(weight_sampling), desc=f"{m}_prc_{precision}"):
                NN0, NN1 = unif_weight_copy(NN0, NN1, precision, device)

                # islice stops after N_BATCHES without fetching an extra batch.
                for i, (x, _) in enumerate(islice(train, N_BATCHES)):
                    x = x.double().to(device)
                    t0_cnt, df_cnt = check_batch(NN0, NN1, x, act_fn)
                    records.append({"model": m, "data_idx": i, "precision": precision,
                                    "touch0": t0_cnt, "diff_grad": df_cnt})
                    run_ids.append(run_id)

            # Checkpoint after every (model, precision) sweep.
            df = pd.DataFrame(records, columns=RESULT_COLUMNS, index=run_ids)
            df.to_csv(RESULT_PATH)
finally:
    # Keep everything collected so far in `df`, even after a KeyboardInterrupt.
    df = pd.DataFrame(records, columns=RESULT_COLUMNS, index=run_ids)

mlp_prc_8: 100%|██████████| 100/100 [03:33<00:00,  2.14s/it]
mlp_prc_16: 100%|██████████| 100/100 [03:48<00:00,  2.28s/it]
mlp_prc_32: 100%|██████████| 100/100 [04:01<00:00,  2.41s/it]


Files already downloaded and verified
Files already downloaded and verified


vgg_prc_8: 100%|██████████| 100/100 [2:47:27<00:00, 100.47s/it] 
vgg_prc_16: 100%|██████████| 100/100 [2:49:17<00:00, 101.57s/it] 
vgg_prc_32: 100%|██████████| 100/100 [2:50:09<00:00, 102.09s/it] 


Files already downloaded and verified
Files already downloaded and verified


res_prc_8: 100%|██████████| 100/100 [4:14:01<00:00, 152.42s/it] 
res_prc_16: 100%|██████████| 100/100 [4:14:29<00:00, 152.69s/it] 
res_prc_32:   5%|▌         | 5/100 [11:20<3:35:33, 136.15s/it]


KeyboardInterrupt: 