In [1]:
import argparse
import torch
import wandb
import os
import time
import numpy as np
from tqdm import tqdm_notebook as tqdm
from torch.utils.data import DataLoader
from data.dataset import get_cf_dataset,get_ctr_dataset,CTR_Dataset,get_avazu_dataset,Avazu_Dataset
from utils import seed_everything
from model.nas import Dnis, AdamNas, FM
from shutil import copyfile
import sys

In [2]:
parser = argparse.ArgumentParser()
parser.add_argument("--data_path", type=str, default='data/avazu/click.pickle')
parser.add_argument("--exp", type=str, default='nas 31')
parser.add_argument("--cuda", nargs='*', type=int, default=[3], help='cuda visible devices')
parser.add_argument("--embedding_dim", type=int, default=64)
parser.add_argument("--batch_size", type=int, default=4096)
parser.add_argument("--lr_w", type=float, default=1e-2)
parser.add_argument("--lr_a", type=float, default=1e-2)
parser.add_argument("--num_epochs", type=int, default=100)
parser.add_argument("--init_alpha", type=float, default=1)
parser.add_argument("--alpha_optim", type=str, default='SGD')
parser.add_argument("--load_checkpoint", type=int, default=1)
parser.add_argument("--warm_start", type=int, default=0)
parser.add_argument("--num_dim_split", type=int, default=64)
parser.add_argument("--search_space", type=str, default='free')
parser.add_argument("--l1", type=float, default=0)
parser.add_argument("--normalize", type=int, default=0)
parser.add_argument("--use_second_grad", type=int, default=1)
#parser.add_argument("--model_name", type=str, default='Wide_and_Deep')
parser.add_argument("--model_name", type=str, default='DeepFM')
parser.add_argument("--alpha_upper_round", type=int, default=0)
parser.add_argument("--dataset_type", type=str, default='ava')
args = parser.parse_args("".split())
os.environ["CUDA_VISIBLE_DEVICES"] = f'{args.cuda}'[1:-1]
device = torch.device('cuda')

In [3]:
train_dataset, val_dataset, test_dataset, num_features = get_avazu_dataset(args.data_path)
num_fields = 23   
batch_size = args.batch_size
val_dataloader = DataLoader(val_dataset, batch_size, shuffle=True, num_workers=8, pin_memory=True)
test_dataloader = DataLoader(test_dataset, batch_size, shuffle=False, num_workers=8, pin_memory=True)

The dataset has been processed. Reading the cache...


In [4]:
#dnis = Dnis(num_features, args.embedding_dim, num_dim_split=args.num_dim_split, search_space=args.search_space,normalize=args.normalize, model_name=args.model_name,num_fields=num_fields, feature_split=[0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1])
dnis = Dnis(num_features, args.embedding_dim, num_dim_split=args.num_dim_split, search_space=args.search_space,
                normalize=args.normalize, model_name=args.model_name,num_fields=num_fields, feature_split=[0.1,0.2,0.2,0.2,0.3])

dnis.feature_nums:tensor([154448, 308897, 308897, 308897, 463350])


In [5]:
#dnis.load_state_dict(torch.load("wandb/run-20210720_122759-3i1ozpwr/files/DNIS-CTR-Avazu.tar"))
dnis.load_state_dict(torch.load("wandb/run-20210716_184325-3qsykpim/files/DNIS-CTR-Avazu.tar"))
#dnis.load_state_dict(torch.load("wandb/run-20210728_125107-sfshogmi/files/DNIS-CTR-Avazu.tar"))


<All keys matched successfully>

In [6]:
dnis.eval()

Dnis(
  (feature_embeddings): Embedding(1544489, 64)
  (model): DeepFM(
    (feature_biases): Embedding(1544489, 1)
    (fc1): Linear(in_features=1472, out_features=400, bias=True)
    (fc2): Linear(in_features=400, out_features=400, bias=True)
    (fc3): Linear(in_features=400, out_features=400, bias=True)
    (fc4): Linear(in_features=400, out_features=1, bias=True)
  )
)

In [7]:
for name, parameter in dnis.named_parameters():
    print (name, parameter)

alpha Parameter containing:
tensor([[8.9068e-01, 7.4982e-01, 6.5229e-02, 2.8706e-01, 4.5175e-01, 3.8904e-02,
         6.0947e-02, 9.4172e-01, 1.1652e-01, 1.3816e-01, 1.0624e-01, 4.5850e-01,
         4.4884e-01, 7.0286e-01, 8.5268e-02, 2.9903e-01, 1.0281e-01, 1.5134e-01,
         1.5441e-01, 5.2884e-02, 7.3251e-02, 3.7504e-01, 6.9772e-02, 1.7425e-01,
         7.6491e-01, 1.4102e-01, 5.4107e-01, 4.8125e-02, 1.2655e-01, 1.0688e-01,
         4.5544e-02, 3.7616e-01, 4.8382e-01, 7.0013e-01, 5.8183e-02, 4.5206e-03,
         7.8600e-02, 1.0800e-01, 5.4182e-01, 6.2045e-01, 5.4111e-02, 1.3194e-01,
         3.8040e-01, 3.9374e-02, 4.1311e-03, 4.2428e-01, 0.0000e+00, 1.0000e+00,
         2.9202e-01, 1.1475e-01, 4.7421e-01, 8.5632e-01, 6.2379e-01, 1.5601e-01,
         3.6184e-01, 1.8883e-01, 3.6037e-02, 4.6095e-01, 2.8671e-01, 2.5716e-02,
         6.1751e-01, 3.1376e-01, 3.8785e-02, 1.2487e-01],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 4.0466e-02, 3.8053e-03, 4.8646e-03,
         0.0000e+00, 0.

In [8]:
from sklearn.metrics import roc_auc_score
def val(model, dataloader):
    model.eval()
    running_loss = 0
    pred_arr = np.array([])
    label_arr = np.array([])
    with torch.no_grad():
        for itr, batch in tqdm(enumerate(dataloader)):
            batch = [item.to(device) for item in batch]
            feature_ids, feature_vals, labels = batch
            outputs = model(feature_ids, feature_vals)
            loss = torch.nn.BCEWithLogitsLoss()(outputs.squeeze(), labels.squeeze())
            running_loss += loss.data.detach().cpu().item()
            pred_arr = np.hstack(
                [pred_arr, outputs.data.detach().cpu()]) if pred_arr.size else outputs.data.detach().cpu()
            label_arr = np.hstack(
                [label_arr, labels.data.detach().cpu()]) if label_arr.size else labels.data.detach().cpu()
        val_loss = running_loss / (itr + 1)
        torch.cuda.empty_cache()
    if args.dataset_type == "ava":
        auc = roc_auc_score(label_arr, pred_arr)
        return val_loss, auc
    return val_loss, 0

In [9]:
dnis=dnis.to(device)
val(dnis,test_dataloader)


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


0it [00:00, ?it/s]

(0.739168377419715, 0.7655923347969603)

In [10]:
alpha_checkpoint = dnis.alpha.data.clone().detach()

In [11]:
alpha_checkpoint

tensor([[8.9068e-01, 7.4982e-01, 6.5229e-02, 2.8706e-01, 4.5175e-01, 3.8904e-02,
         6.0947e-02, 9.4172e-01, 1.1652e-01, 1.3816e-01, 1.0624e-01, 4.5850e-01,
         4.4884e-01, 7.0286e-01, 8.5268e-02, 2.9903e-01, 1.0281e-01, 1.5134e-01,
         1.5441e-01, 5.2884e-02, 7.3251e-02, 3.7504e-01, 6.9772e-02, 1.7425e-01,
         7.6491e-01, 1.4102e-01, 5.4107e-01, 4.8125e-02, 1.2655e-01, 1.0688e-01,
         4.5544e-02, 3.7616e-01, 4.8382e-01, 7.0013e-01, 5.8183e-02, 4.5206e-03,
         7.8600e-02, 1.0800e-01, 5.4182e-01, 6.2045e-01, 5.4111e-02, 1.3194e-01,
         3.8040e-01, 3.9374e-02, 4.1311e-03, 4.2428e-01, 0.0000e+00, 1.0000e+00,
         2.9202e-01, 1.1475e-01, 4.7421e-01, 8.5632e-01, 6.2379e-01, 1.5601e-01,
         3.6184e-01, 1.8883e-01, 3.6037e-02, 4.6095e-01, 2.8671e-01, 2.5716e-02,
         6.1751e-01, 3.1376e-01, 3.8785e-02, 1.2487e-01],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 4.0466e-02, 3.8053e-03, 4.8646e-03,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000

In [12]:
alpha_checkpoint[0][0].item()

0.8906775712966919

In [13]:
for i,row in enumerate(alpha_checkpoint):
    for item in row:
        print("block %d"%(i+1))

block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 1
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2
block 2


In [14]:
for i,row in enumerate(alpha_checkpoint):
    for item in row:
        print(item.item())
# for item in alpha_checkpoint[0]:
#     print(item.item())

0.8906775712966919
0.7498247027397156
0.06522924453020096
0.2870582044124603
0.45175084471702576
0.03890378773212433
0.06094735115766525
0.941724956035614
0.1165192574262619
0.13815782964229584
0.10624378174543381
0.45849746465682983
0.4488407075405121
0.7028577327728271
0.08526826649904251
0.29902926087379456
0.10280556976795197
0.15133629739284515
0.15441220998764038
0.05288417637348175
0.07325130701065063
0.3750394582748413
0.06977176666259766
0.1742529571056366
0.7649118304252625
0.1410234272480011
0.5410747528076172
0.048125170171260834
0.12654587626457214
0.10687520354986191
0.04554407298564911
0.37616461515426636
0.48382243514060974
0.7001307010650635
0.058183200657367706
0.004520602058619261
0.07860010117292404
0.10799648612737656
0.5418176651000977
0.62044757604599
0.05411115661263466
0.13194425404071808
0.3804042935371399
0.03937441483139992
0.00413112947717309
0.4242765009403229
0.0
1.0
0.2920246720314026
0.114754818379879
0.4742065370082855
0.8563236594200134
0.623793900012

In [15]:
#alpha_block_mask = alpha_checkpoint.repeat_interleave(dnis.feature_nums.to(dnis.alpha).long(),
                                                     #   dim=0).repeat_interleave(
  # dnis.embed_dims.to(dnis.alpha).long(), dim=1)

In [16]:
#dnis.feature_embeddings.weight.data *= alpha_block_mask

In [17]:
#dnis.alpha.data = torch.ones_like(dnis.alpha.data).to(dnis.alpha.data)

In [18]:
val(dnis,test_dataloader)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


0it [00:00, ?it/s]

(0.739168377419715, 0.7655923347969603)

In [19]:
arr = np.reshape(dnis.feature_embeddings.weight.data.abs().detach().cpu(),(-1))
arr = np.sort(arr)
embedding_checkpoint = dnis.feature_embeddings.weight.data.clone().detach()

In [20]:
dnis.feature_embeddings.weight.data

tensor([[ 1.1685e-02,  1.0090e-01,  1.0588e-01,  ...,  6.8006e-01,
          1.7277e-01, -5.5160e-01],
        [ 1.2717e-01, -2.9075e-02,  6.5321e-03,  ...,  7.4864e-02,
          5.3864e-02, -2.5512e-01],
        [-3.6245e-01,  5.7768e-01,  9.3803e-02,  ..., -9.2616e-02,
         -2.0556e-01, -2.1883e-01],
        ...,
        [ 5.9820e-04,  2.6121e-04, -8.9212e-04,  ...,  4.4282e-04,
         -4.5947e-04, -1.1509e-03],
        [ 1.9591e-04, -1.4190e-04,  5.2632e-04,  ...,  1.5511e-04,
          7.5685e-04, -5.3025e-04],
        [-5.3668e-03, -4.9665e-03,  2.4543e-03,  ...,  8.1259e-03,
          4.3659e-03, -8.2888e-03]], device='cuda:0')

In [21]:
print(torch.nonzero(dnis.feature_embeddings.weight.data).size(0))
threshold = arr[int(arr.shape[0]*0/100)]
print(f" threshold: {threshold}")
dnis.feature_embeddings.weight.data = embedding_checkpoint.clone().detach()
dnis.feature_embeddings.weight.data[dnis.feature_embeddings.weight.data.abs()<threshold]=0
loss, auc = val(dnis,test_dataloader)
print(f"pruned: {98}%, loss: {loss}, auc: {auc}")

98847296
 threshold: 5.919550960520326e-11


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


0it [00:00, ?it/s]

pruned: 98%, loss: 0.739168377419715, auc: 0.7655923347969603


In [22]:
dnis.train()

Dnis(
  (feature_embeddings): Embedding(1544489, 64)
  (model): DeepFM(
    (feature_biases): Embedding(1544489, 1)
    (fc1): Linear(in_features=1472, out_features=400, bias=True)
    (fc2): Linear(in_features=400, out_features=400, bias=True)
    (fc3): Linear(in_features=400, out_features=400, bias=True)
    (fc4): Linear(in_features=400, out_features=1, bias=True)
  )
)

In [23]:
criterion = torch.nn.BCEWithLogitsLoss()
def train_weights(model, batch,optimizer):
        model.train()
        # update weights and keep gradients of w'/alpha
        feature_ids, feature_vals, labels = batch
        outputs =  model(feature_ids, feature_vals)
        loss = criterion(outputs.squeeze(), labels.squeeze())
        loss.backward()
        optimizer.step()
        model.alpha.grad.data.zero_()
        return loss.data.detach().cpu().item()


In [24]:
print(torch.nonzero(dnis.feature_embeddings.weight.data).size(0))

98847296


In [25]:
#Retrain the subnetwork when the pruned ratio > 98% (Optional)
train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True, num_workers=8, pin_memory=True)
num_epochs = 2
print(len(train_dataloader))
print(len(val_dataloader))
print(len(test_dataloader))
def retrain(model,dataloader):
    parameters_w = [parameter for name, parameter in model.named_parameters() if 'alpha' not in name]
    #parameters_w = [parameter for name, parameter in model.named_parameters() if 'bias' not in name]
    optimizer_w = torch.optim.Adam(parameters_w, lr=0.001)
    scheduler_w = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_w, 'min', verbose=True,patience=0)
    for epoch in range(num_epochs):
            print(f"Starting epoch: {epoch} | phase: train | ⏰: {time.strftime('%H:%M:%S')}")
            model.train()
            running_loss = 0
            # if epoch == 8:
            #     self.optimizer_w.param_groups[0]['lr'] *= 0.1
            for itr, batch in tqdm(enumerate(dataloader)):
                batch = [item.to(device) for item in batch]
                feature_ids, feature_vals, labels = batch
                outputs = model(feature_ids, feature_vals).squeeze()
                loss = criterion(outputs, labels.squeeze())
                loss.backward()
                optimizer_w.step()
                model.zero_grad()
                running_loss += loss.item()
            epoch_loss = running_loss / itr
            print(f"training loss of epoch {epoch}: {epoch_loss}")
            torch.cuda.empty_cache()
            val_loss, val_auc = val(dnis,val_dataloader)
            print(f"val loss of epoch {epoch}: {val_loss}")
            print(f"val auc of epoch {epoch}: {val_auc}")
            loss, auc = val(dnis,test_dataloader)
            print(f"loss: {loss}, auc: {auc}")

6910
1975
988


In [26]:
retrain(dnis,train_dataloader)
print(torch.nonzero(dnis.feature_embeddings.weight.data).size(0))



Starting epoch: 0 | phase: train | ⏰: 21:31:13


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


0it [00:00, ?it/s]

training loss of epoch 0: 0.3780977173675623


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


0it [00:00, ?it/s]

val loss of epoch 0: 0.3820227979859219
val auc of epoch 0: 0.7766510880341078


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


0it [00:00, ?it/s]

loss: 0.38159348996665315, auc: 0.7770365541575974
Starting epoch: 1 | phase: train | ⏰: 21:40:38


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


0it [00:00, ?it/s]

training loss of epoch 1: 0.368109363509874


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


0it [00:00, ?it/s]

val loss of epoch 1: 0.38172210052043576
val auc of epoch 1: 0.7777114231625945


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


0it [00:00, ?it/s]

loss: 0.38134933993039344, auc: 0.7780102989736531
98847296


In [27]:
print(dnis.alpha.data)

tensor([[8.9068e-01, 7.4982e-01, 6.5229e-02, 2.8706e-01, 4.5175e-01, 3.8904e-02,
         6.0947e-02, 9.4172e-01, 1.1652e-01, 1.3816e-01, 1.0624e-01, 4.5850e-01,
         4.4884e-01, 7.0286e-01, 8.5268e-02, 2.9903e-01, 1.0281e-01, 1.5134e-01,
         1.5441e-01, 5.2884e-02, 7.3251e-02, 3.7504e-01, 6.9772e-02, 1.7425e-01,
         7.6491e-01, 1.4102e-01, 5.4107e-01, 4.8125e-02, 1.2655e-01, 1.0688e-01,
         4.5544e-02, 3.7616e-01, 4.8382e-01, 7.0013e-01, 5.8183e-02, 4.5206e-03,
         7.8600e-02, 1.0800e-01, 5.4182e-01, 6.2045e-01, 5.4111e-02, 1.3194e-01,
         3.8040e-01, 3.9374e-02, 4.1311e-03, 4.2428e-01, 0.0000e+00, 1.0000e+00,
         2.9202e-01, 1.1475e-01, 4.7421e-01, 8.5632e-01, 6.2379e-01, 1.5601e-01,
         3.6184e-01, 1.8883e-01, 3.6037e-02, 4.6095e-01, 2.8671e-01, 2.5716e-02,
         6.1751e-01, 3.1376e-01, 3.8785e-02, 1.2487e-01],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 4.0466e-02, 3.8053e-03, 4.8646e-03,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000

In [28]:
alpha_block_mask = alpha_checkpoint.repeat_interleave(dnis.feature_nums.to(dnis.alpha).long(),
                                                        dim=0).repeat_interleave(
   dnis.embed_dims.to(dnis.alpha).long(), dim=1)
                                                
#dnis.feature_embeddings.weight.data *= alpha_block_mask
#dnis.alpha.data = torch.ones_like(dnis.alpha.data).to(dnis.alpha.data)
arr = np.reshape(dnis.feature_embeddings.weight.data.abs().detach().cpu(),(-1))
arr = np.sort(arr)
embedding_checkpoint = dnis.feature_embeddings.weight.data.clone().detach()

In [29]:

threshold = arr[int(arr.shape[0]*95/100)]
print(f" threshold: {threshold}")
dnis.feature_embeddings.weight.data = embedding_checkpoint.clone().detach()
dnis.feature_embeddings.weight.data[dnis.feature_embeddings.weight.data.abs()<threshold]=0
loss, auc = val(dnis,test_dataloader)
print(f"pruned: {98}%, loss: {loss}, auc: {auc}")
print(torch.nonzero(dnis.feature_embeddings.weight.data).size(0))

 threshold: 0.2068343311548233


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


0it [00:00, ?it/s]

pruned: 98%, loss: 0.3857754762718069, auc: 0.7701850390686246
4942365
