Merge pull request #3 from LiyuanLucasLiu/toparams
Toparams
LiyuanLucasLiu committed Oct 2, 2018
2 parents 4766928 + c7178c3 commit ea4846b
Showing 8 changed files with 156 additions and 29 deletions.
13 changes: 10 additions & 3 deletions ReadMe.md
@@ -24,6 +24,7 @@ Details about LD-Net can be accessed at: https://arxiv.org/abs/1804.07827.
- [Data](#data)
- [Model](#model)
- [Command](#command)
- [Inference](#inference)
- [Citation](#citation)

## Model Notes
@@ -64,15 +65,15 @@ The original pre-trained named entity tagger achieves 91.95 F1, the pruned tagge

| Original Tagger | Pruned Tagger |
| ------------- |------------- |
| [Download Link](http://dmserv4.cs.illinois.edu/ner.th) | Preparing... |
| [Download Link](http://dmserv4.cs.illinois.edu/ner.th) | [Download Link](http://dmserv4.cs.illinois.edu/pner1.th) |

### Chunking

The original pre-trained chunking tagger achieves 96.13 F1; the pruned tagger achieves 95.79 F1.

| Original Tagger | Pruned Tagger |
| ------------- |------------- |
| [Download Link](http://dmserv4.cs.illinois.edu/np.th) | Preparing... |
| [Download Link](http://dmserv4.cs.illinois.edu/np.th) | [Download Link](http://dmserv4.cs.illinois.edu/pnp0.th) |

## Training

@@ -114,13 +115,19 @@ Our implementations are available in ```model_seq``` and ```model_word_ada```, a
| ------------- |------------- |
| [Download Link](http://dmserv4.cs.illinois.edu/ner_dataset.pk) | [Download Link](http://dmserv4.cs.illinois.edu/np_dataset.pk) |

## Inference

For model inference, please check our [LightNER package](https://github.com/LiyuanLucasLiu/LightNER).

## Citation

If you find the implementation useful, please cite the following paper: [Efficient Contextualized Representation: Language Model Pruning for Sequence Labeling](https://arxiv.org/abs/1804.07827)

```
@inproceedings{liu2018efficient,
title = "{Efficient Contextualized Representation: Language Model Pruning for Sequence Labeling}",
author = {Liu, Liyuan and Ren, Xiang and Shang, Jingbo and Peng, Jian and Han, Jiawei},
booktitle = {EMNLP},
year = 2018,
}
}
```
23 changes: 23 additions & 0 deletions model_seq/seqlabel.py
@@ -61,6 +61,7 @@ def __init__(self, f_lm, b_lm,

self.f_lm = f_lm
self.b_lm = b_lm
self.unit_type = unit

self.char_embed = nn.Embedding(c_num, c_dim)
self.word_embed = nn.Embedding(w_num, w_dim)
@@ -83,6 +84,28 @@ def __init__(self, f_lm, b_lm,

self.drop = nn.Dropout(p = droprate)

def to_params(self):
"""
To parameters.
"""
return {
"model_type": "char-lstm-crf",
"forward_lm": self.f_lm.to_params(),
"backward_lm": self.b_lm.to_params(),
"word_embed_num": self.word_embed.num_embeddings,
"word_embed_dim": self.word_embed.embedding_dim,
"char_embed_num": self.char_embed.num_embeddings,
"char_embed_dim": self.char_embed.embedding_dim,
"char_hidden": self.c_hidden,
"char_layers": self.char_fw.num_layers,
"word_hidden": self.word_rnn.hidden_size,
"word_layers": self.word_rnn.num_layers,
"droprate": self.drop.p,
"y_num": self.y_num,
"label_schema": "iobes",
"unit_type": self.unit_type
}

def prune_dense_rnn(self):
"""
Prune dense rnn to be smaller by deleting layers.
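
The `to_params` method added above exports the tagger's architecture as a plain dictionary, so a checkpoint can carry both the trained weights and the configuration needed to rebuild the model later. A minimal sketch of that pattern, assuming a hypothetical `build_tagger` factory that maps the config keys back to constructor arguments (neither the factory nor the file name is part of this commit):

```python
import torch

# Hypothetical round trip built on the to_params() pattern: the config dict
# describes the architecture, the state dict holds the trained weights.
def save_tagger(model, path="tagger_checkpoint.th"):
    torch.save({"config": model.to_params(),
                "state_dict": model.state_dict()}, path)

def load_tagger(build_tagger, path="tagger_checkpoint.th"):
    checkpoint = torch.load(path, map_location="cpu")
    model = build_tagger(checkpoint["config"])       # rebuild the architecture
    model.load_state_dict(checkpoint["state_dict"])  # then restore the weights
    return model
```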
10 changes: 10 additions & 0 deletions model_seq/seqlm.py
@@ -47,6 +47,16 @@ def __init__(self, ori_lm, backward, droprate, fix_rate):

self.backward = backward

def to_params(self):
"""
To parameters.
"""
return {
"rnn_params": self.rnn.to_params(),
"word_embed_num": self.word_embed.num_embeddings,
"word_embed_dim": self.word_embed.embedding_dim
}

def init_hidden(self):
"""
initialize hidden states.
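
The wrapper's `to_params` reads the vocabulary size and embedding width straight from the `nn.Embedding` module instead of caching them at construction time; `num_embeddings` and `embedding_dim` are standard attributes of `torch.nn.Embedding`. A quick illustration (the sizes are arbitrary):

```python
import torch.nn as nn

embed = nn.Embedding(num_embeddings=50000, embedding_dim=300)
print(embed.num_embeddings, embed.embedding_dim)  # 50000 300
```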
59 changes: 45 additions & 14 deletions model_seq/sparse_lm.py
@@ -11,10 +11,6 @@
import torch.nn.functional as F
import model_seq.utils as utils

import torch
import torch.nn as nn
import torch.nn.functional as F

class SBUnit(nn.Module):
"""
The basic recurrent unit for the dense-RNNs wrapper.
@@ -31,7 +27,7 @@ class SBUnit(nn.Module):
def __init__(self, ori_unit, droprate, fix_rate):
super(SBUnit, self).__init__()

self.unit = ori_unit.unit
self.unit_type = ori_unit.unit_type

self.layer = ori_unit.layer

@@ -98,18 +94,40 @@ class SDRNN(nn.Module):
def __init__(self, ori_drnn, droprate, fix_rate):
super(SDRNN, self).__init__()

self.layer_list = [SBUnit(ori_unit, droprate, fix_rate) for ori_unit in ori_drnn.layer._modules.values()]
if ori_drnn.layer:
self.layer_list = [SBUnit(ori_unit, droprate, fix_rate) for ori_unit in ori_drnn.layer._modules.values()]

self.weight_list = nn.Parameter(torch.FloatTensor([1.0] * len(self.layer_list)))
self.weight_list.requires_grad = not fix_rate
self.weight_list = nn.Parameter(torch.FloatTensor([1.0] * len(self.layer_list)))
self.weight_list.requires_grad = not fix_rate

# self.layer = nn.Sequential(*self.layer_list)
self.layer = nn.ModuleList(self.layer_list)
# self.layer = nn.Sequential(*self.layer_list)
self.layer = nn.ModuleList(self.layer_list)

for param in self.layer.parameters():
param.requires_grad = False
else:
self.layer_list = list()
self.weight_list = list()
self.layer = None

for param in self.layer.parameters():
param.requires_grad = False
# self.output_dim = self.layer_list[-1].output_dim
self.emb_dim = ori_drnn.emb_dim
self.output_dim = ori_drnn.output_dim
self.unit_type = ori_drnn.unit_type

self.output_dim = self.layer_list[-1].output_dim
def to_params(self):
"""
To parameters.
"""
return {
"rnn_type": "LDRNN",
"unit_type": self.unit_type,
"layer_num": 0 if not self.layer else len(self.layer),
"emb_dim": self.emb_dim,
"hid_dim": -1 if not self.layer else self.layer[0].increase_rate,
"droprate": -1 if not self.layer else self.layer[0].droprate,
"after_pruned": True
}

def prune_dense_rnn(self):
"""
@@ -143,6 +161,7 @@ def prune_dense_rnn(self):
self.weight_list = nn.Parameter(torch.FloatTensor(new_weight_list))
self.weight_list.requires_grad = False


for param in self.layer.parameters():
param.requires_grad = False

@@ -227,6 +246,17 @@ def __init__(self, ori_lm, backward, droprate, fix_rate):

self.backward = backward

def to_params(self):
"""
To parameters.
"""
return {
"backward": self.backward,
"rnn_params": self.rnn.to_params(),
"word_embed_num": self.word_embed.num_embeddings,
"word_embed_dim": self.word_embed.embedding_dim
}

def prune_dense_rnn(self):
"""
Prune dense rnn to be smaller by deleting layers.
@@ -282,4 +312,5 @@ def forward(self, w_in, ind=None):
out_size = out.size()
out = out.view(out_size[0] * out_size[1], out_size[2]).index_select(0, ind).contiguous().view(out_size)

return out
return out
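
In the pruned language model every dense layer may have been removed, so `SDRNN.to_params` falls back to sentinel values when `self.layer` is empty. The `not self.layer` test covers both the `None` case and an empty container, because `nn.ModuleList` implements `__len__` and is therefore falsy when it holds no modules. A quick check (illustrative only):

```python
import torch.nn as nn

empty = nn.ModuleList([])
nonempty = nn.ModuleList([nn.Linear(4, 4)])
print(not None, not empty, not nonempty)  # True True False
# So "0 if not self.layer else len(self.layer)" gives 0 for a fully pruned
# model (layer is None or empty) and the real depth otherwise.
```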

17 changes: 16 additions & 1 deletion model_word_ada/basic.py
@@ -27,6 +27,7 @@ class BasicUnit(nn.Module):
def __init__(self, unit, input_dim, hid_dim, droprate):
super(BasicUnit, self).__init__()

self.unit_type = unit
rnnunit_map = {'rnn': nn.RNN, 'lstm': nn.LSTM, 'gru': nn.GRU}

self.batch_norm = (unit == 'bnlstm')
@@ -97,9 +98,23 @@ def __init__(self, layer_num, unit, emb_dim, hid_dim, droprate):
layer_list = [BasicUnit(unit, emb_dim, hid_dim, droprate)] + [BasicUnit(unit, hid_dim, hid_dim, droprate) for i in range(layer_num - 1)]
self.layer = nn.Sequential(*layer_list)
self.output_dim = layer_list[-1].output_dim

self.unit_type = unit

self.init_hidden()

def to_params(self):
"""
To parameters.
"""
return {
"rnn_type": "Basic",
"unit_type": self.layer[0].unit_type,
"layer_num": len(self.layer),
"emb_dim": self.layer[0].layer.input_size,
"hid_dim": self.layer[0].layer.hidden_size,
"droprate": self.layer[0].droprate
}

def init_hidden(self):
"""
Initialize hidden states.
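
`BasicRNN.to_params` above recovers the embedding and hidden widths from the first recurrent unit itself; `input_size`, `hidden_size`, and `num_layers` are standard attributes of `nn.RNN`/`nn.LSTM`/`nn.GRU`, so no extra bookkeeping is required. For example (sizes are arbitrary):

```python
import torch.nn as nn

lstm = nn.LSTM(input_size=100, hidden_size=300, num_layers=1)
print(lstm.input_size, lstm.hidden_size, lstm.num_layers)  # 100 300 1
```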
25 changes: 20 additions & 5 deletions model_word_ada/densenet.py
@@ -29,11 +29,11 @@ def __init__(self, unit, input_dim, increase_rate, droprate):

rnnunit_map = {'rnn': nn.RNN, 'lstm': nn.LSTM, 'gru': nn.GRU}

self.unit = unit
self.unit_type = unit

self.layer = rnnunit_map[unit](input_dim, increase_rate, 1)

if 'lstm' == self.unit:
if 'lstm' == self.unit_type:
utils.init_lstm(self.layer)

self.droprate = droprate
@@ -102,13 +102,28 @@ class DenseRNN(nn.Module):
"""
def __init__(self, layer_num, unit, emb_dim, hid_dim, droprate):
super(DenseRNN, self).__init__()


self.unit_type = unit
self.layer_list = [BasicUnit(unit, emb_dim + i * hid_dim, hid_dim, droprate) for i in range(layer_num)]
self.layer = nn.Sequential(*self.layer_list)
self.output_dim = self.layer_list[-1].output_dim
self.layer = nn.Sequential(*self.layer_list) if layer_num > 0 else None
self.output_dim = self.layer_list[-1].output_dim if layer_num > 0 else emb_dim
self.emb_dim = emb_dim

self.init_hidden()

def to_params(self):
"""
To parameters.
"""
return {
"rnn_type": "DenseRNN",
"unit_type": self.layer[0].unit_type,
"layer_num": len(self.layer),
"emb_dim": self.layer[0].input_dim,
"hid_dim": self.layer[0].increase_rate,
"droprate": self.layer[0].droprate
}

def init_hidden(self):
"""
Initialize hidden states.
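
In `DenseRNN`, layer `i` is built with input width `emb_dim + i * hid_dim` because each unit consumes the embedding together with the outputs of all earlier layers. A worked example with illustrative sizes, assuming (as the `increase_rate` naming suggests) that each unit's output concatenates its input with the newly produced hidden state:

```python
emb_dim, hid_dim, layer_num = 100, 300, 3

# Input width of each BasicUnit in DenseRNN:
widths = [emb_dim + i * hid_dim for i in range(layer_num)]
print(widths)                # [100, 400, 700]

# Output of the last unit = its input concatenated with its hidden state,
# i.e. emb_dim + layer_num * hid_dim overall.
print(widths[-1] + hid_dim)  # 1000
```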
26 changes: 21 additions & 5 deletions model_word_ada/ldnet.py
@@ -32,11 +32,11 @@ def __init__(self, unit, input_dim, increase_rate, droprate, layer_drop = 0):

rnnunit_map = {'rnn': nn.RNN, 'lstm': nn.LSTM, 'gru': nn.GRU}

self.unit = unit
self.unit_type = unit

self.layer = rnnunit_map[unit](input_dim, increase_rate, 1)

if 'lstm' == self.unit:
if 'lstm' == self.unit_type:
utils.init_lstm(self.layer)

self.layer_drop = layer_drop
@@ -121,14 +121,30 @@ class LDRNN(nn.Module):
def __init__(self, layer_num, unit, emb_dim, hid_dim, droprate, layer_drop):
super(LDRNN, self).__init__()

self.unit_type = unit
self.layer_list = [BasicUnit(unit, emb_dim + i * hid_dim, hid_dim, droprate, layer_drop) for i in range(layer_num)]

self.layer_num = layer_num
self.layer = nn.ModuleList(self.layer_list)
self.output_dim = self.layer_list[-1].output_dim

self.layer = nn.ModuleList(self.layer_list) if layer_num > 0 else None
self.output_dim = self.layer_list[-1].output_dim if layer_num > 0 else emb_dim
self.emb_dim = emb_dim

self.init_hidden()

def to_params(self):
"""
To parameters.
"""
return {
"rnn_type": "LDRNN",
"unit_type": self.layer[0].unit_type,
"layer_num": len(self.layer),
"emb_dim": self.layer[0].input_dim,
"hid_dim": self.layer[0].increase_rate,
"droprate": self.layer[0].droprate,
"after_pruned": False
}

def init_hidden(self):
"""
Initialize hidden states.
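
For reference, an `LDRNN` built with illustrative sizes would report a dictionary along these lines; the values below are hypothetical, and `after_pruned` is what lets downstream code distinguish this original network (`False`) from the pruned `SDRNN` wrapper above (`True`):

```python
# LDRNN(layer_num=5, unit='lstm', emb_dim=100, hid_dim=300,
#       droprate=0.5, layer_drop=0.5).to_params() would resemble:
params = {
    "rnn_type": "LDRNN",
    "unit_type": "lstm",
    "layer_num": 5,
    "emb_dim": 100,
    "hid_dim": 300,
    "droprate": 0.5,
    "after_pruned": False,
}
```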
12 changes: 11 additions & 1 deletion prune_sparse_seq.py
@@ -236,6 +236,16 @@

seq_model.cpu()
pw.info('Saving model...')
pw.save_checkpoint(model = seq_model, is_best = True)

seq_config = seq_model.to_params()

pw.save_checkpoint(model = seq_model,
is_best = True,
s_dict = {'config': seq_config,
'flm_map': flm_map,
'blm_map': blm_map,
'gw_map': gw_map,
'c_map': c_map,
'y_map': y_map})

pw.close()
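
With this change the checkpoint carries the model configuration and the vocabulary maps next to the weights. Assuming `pw.save_checkpoint` serializes the `s_dict` entries with `torch.save` at the top level of the stored dictionary (an assumption about the wrapper, not something this diff shows), a consumer could read the configuration back roughly as follows; the file name is hypothetical:

```python
import torch

checkpoint = torch.load("checkpoint.th", map_location="cpu")  # hypothetical path

config = checkpoint["config"]  # output of seq_model.to_params()
print(config["model_type"], config["unit_type"], config["y_num"])

# The vocabulary maps travel with the weights, so inference code (for example
# the LightNER package referenced in the ReadMe) can rebuild the tagger
# without access to the original training corpus.
gw_map, c_map, y_map = checkpoint["gw_map"], checkpoint["c_map"], checkpoint["y_map"]
```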
