In [None]:
# Mount Google Drive under /content/drive (only works inside Google Colab).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
cd /content/drive/MyDrive/Colab Notebooks/DL_NLP/term_proj/

# Install Requirements



In [None]:
pip install -r requirements.txt

# Packages

In [None]:
import numpy as np
import torch 
import json
import pickle
import os
import torch
import random
import errno
import collections
from collections import Counter
import tensorflow as tf
from sklearn.model_selection import train_test_split
from chainer.dataset import convert
from torch.optim import Adam,lr_scheduler,SGD,RMSprop
import nltk
nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('punkt')
from nltk.tokenize import word_tokenize
from nltk.stem import PorterStemmer
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from torchtext import data
from __future__ import division
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.init as weight_init
import torch.nn.functional as F
import sys
from tqdm import tqdm
import unicodedata
from tensorboardX import SummaryWriter
from collections import defaultdict
from sklearn.utils import compute_class_weight
import math
import string
import io
import six
import pprint
from metrics import ConfusionMatrix
from argparse import ArgumentParser
from datetime import datetime
from collections import namedtuple
import gensim.downloader
import logging
import time
import itertools
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

In [None]:
# Prefer the first CUDA device; fall back to the CPU on runtimes without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load Data and Embeddings

In [None]:
#Preprocess Data
def get_data(file_name,max_nv):
    """Read a ``label<TAB>text`` file, tokenize every line and count word frequencies.

    Each line is lower-cased, the substring ``quot`` and all digits are removed,
    the text is tokenized with NLTK, English stop words are dropped and the
    remaining tokens are lemmatized.

    Args:
        file_name: path of the tab-separated input file; column 0 must parse as
            an int label, column 1 is the raw text.
        max_nv: maximum number of vocabulary entries to return.

    Returns:
        (vocab, labels, texts): the ``max_nv`` most frequent lemmas, the list
        of int labels, and the list of lemma lists (one per input line).

    Raises:
        ValueError / IndexError: if a line has no integer label or no text column.
    """
    frequencies = Counter()
    labels = []
    texts = []
    stop_words = set(stopwords.words('english'))
    lemmatizer = WordNetLemmatizer()
    # The context manager closes the file even when a malformed line raises.
    with open(file_name, 'r') as f:
        for line in f:
            fields = line.split("\t")
            labels.append(int(fields[0]))
            text = fields[1].lower().replace('quot', '')
            text = ''.join(ch for ch in text if not ch.isdigit())
            tokens = word_tokenize(text)
            kept = [tok for tok in tokens if tok not in stop_words]
            # If stop-word removal leaves nothing, fall back to the unfiltered
            # token list (previously the raw string was used, so it was
            # lemmatized character by character).
            words = [lemmatizer.lemmatize(w) for w in (kept or tokens)]
            frequencies.update(words)
            texts.append(words)
    vocab = [word for (word, _) in frequencies.most_common(max_nv)]
    return vocab,labels,texts

In [None]:
#Read Data and make vocab
train_vocab , train_label, train_texts = get_data('./data/aclImdb_tok/train.txt',80000)
test_vocab , test_label, test_texts = get_data('./data/aclImdb_tok/test.txt',80000)
unlabel_vocab , _, unlabel_texts = get_data('./data/aclImdb_tok/unlabel.txt',80000)
# train_vocab , train_label, train_texts = get_data('./data/Agnews/train.txt',75000)
# test_vocab , test_label, test_texts = get_data('./data/Agnews/test.txt',75000)
# unlabel_vocab , _, unlabel_texts = get_data('./data/Agnews/unlabel.txt',75000)

# The special tokens must occupy ids 0..3 in exactly this order, because
# Vocab_Pad hard-codes PAD=0, EOS=1, UNK=2, BOS=3 and those ids are used for
# padding and unknown-word lookup. Appending them after the words gave them
# the last four ids, so ids 0..3 pointed at ordinary words.
special_tokens = ['<pad>', '<eos>', '<unk>', '<bos>']
# sorted() keeps the word -> id mapping identical across kernel restarts
# (plain set iteration order depends on the string hash seed).
vocab = special_tokens + sorted(set(train_vocab + unlabel_vocab))
w2id = {word: index for index, word in enumerate(vocab)}
id2w = {i: w for w, i in w2id.items()}

In [None]:
#Load the saved word-vector matrix (written by the np.save cell of this notebook)
wordvector = np.load('./data/demo.word_vectors.npy')

In [None]:
#Load saved id to word mapping
# NOTE: unpickling can execute arbitrary code; only load files written by this notebook.
with open('id2w.pickle','rb') as f:
    id2w = pickle.load(f)
# Inverse lookup: word -> id.
w2id = {w: i for i, w in id2w.items()}

In [None]:
#Prepare data
# Fixed ids of the four special vocabulary symbols.
Special_Seq = namedtuple('Special_Seq', 'PAD EOS UNK BOS')
Vocab_Pad = Special_Seq(0, 1, 2, 3)

def make_dataset(text, w2id):
    """Convert tokenized lines into arrays of vocabulary ids.

    Args:
        text: iterable of token lists.
        w2id: mapping word -> integer id; words missing from it are mapped to
            ``Vocab_Pad.UNK``.

    Returns:
        list of 1-D ``int64`` numpy arrays, one per input line.
    """
    # An explicit integer dtype keeps empty lines from becoming float64 arrays
    # (numpy's default for an empty list).
    return [
        np.asarray([w2id.get(word, Vocab_Pad.UNK) for word in line], dtype=np.int64)
        for line in text
    ]

#Label training data
temp = make_dataset(train_texts, w2id)
train_data = list(zip(train_label, temp))
# A fixed random_state makes the train/dev split reproducible across runs.
train_data,dev_data = train_test_split(train_data,test_size=0.1,random_state=42)

#Unlabel training data
# `_` holds the label column returned by get_data for the unlabeled file.
temp = make_dataset(unlabel_texts, w2id)
unlabel_data = list(zip(_, temp))

#Test data
temp = make_dataset(test_texts, w2id)
test_data = list(zip(test_label, temp))

In [None]:
#Download the pretrained word embedding (only needed the first time training is run)
embedding = gensim.downloader.load('word2vec-google-news-300')
# embedding = gensim.downloader.load('fasttext-wiki-news-subwords-300')

In [None]:
#Prepare wordembedding matrix while training first time
# Row i holds the pretrained vector of id2w[i]; words missing from the
# pretrained embedding get a uniform random vector in [-0.25, 0.25).
rows = []
for i in range(len(id2w)):
    word = id2w[i]
    rows.append(embedding[word] if word in embedding
                else np.random.uniform(-0.25, 0.25, 300))

wordvector = np.array(rows, dtype=np.float32)

In [None]:
#Save the id -> word mapping so later sessions can reload it instead of rebuilding the vocabulary
with open('id2w.pickle', 'wb') as f:
    pickle.dump(id2w, f, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
#save wordvector matrix
# (the stray leading space that made this cell an IndentationError is removed)
np.save(os.path.join('data', 'demo' + '.word_vectors.npy'), wordvector)

# Definitions

In [None]:
# Legacy typed-tensor aliases passed to Tensor.type(...) throughout the notebook.
# They resolve to the CUDA classes when a GPU is present and to the CPU classes
# otherwise, so `.type(LONG_TYPE)` no longer fails on GPU-less runtimes.
_USE_CUDA = torch.cuda.is_available()
FLOAT_TYPE = torch.cuda.FloatTensor if _USE_CUDA else torch.FloatTensor
INT_TYPE = torch.cuda.IntTensor if _USE_CUDA else torch.IntTensor
LONG_TYPE = torch.cuda.LongTensor if _USE_CUDA else torch.LongTensor
BYTE_TYPE = torch.cuda.ByteTensor if _USE_CUDA else torch.ByteTensor

# One mini-batch: number of examples, label tensor, padded word ids, per-example lengths.
Batch = collections.namedtuple('Batch', ['batch_size','labels','word_ids','sent_len'])

def ensure_directory(directory):
    """Create ``directory`` (after ``~`` expansion) including missing parents.

    An "already exists" error is ignored; any other ``OSError`` propagates.
    """
    target = os.path.expanduser(directory)
    try:
        os.makedirs(target)
    except OSError as err:
        already_there = err.errno == errno.EEXIST
        if not already_there:
            raise


def batch_size_fn(new, count, sofar):
    """Return the padded element count of a batch after adding example ``new``.

    Args:
        new: example whose item ``[1]`` is the token sequence.
        count: number of examples in the batch including ``new``; ``1`` marks
            the start of a new batch and resets the running maximum.
        sofar: unused; kept for the batching-callback signature.

    Returns:
        ``count`` times the longest sequence length seen in the batch, where
        each length is counted with 2 extra positions.
    """
    # The running maximum has to survive between calls; without `global` the
    # name was a local and any call with count != 1 raised UnboundLocalError.
    global max_src_in_batch
    if count == 1:
        max_src_in_batch = 0
    max_src_in_batch = max(max_src_in_batch, len(new[1]) + 2)
    return count * max_src_in_batch

def get_accuracy(cm, output, target):
    """Count correct predictions in a batch and record them in a confusion matrix.

    Args:
        cm: object exposing ``add_batch(targets, predictions)``; it receives
            both as numpy arrays.
        output: score tensor; the prediction is the argmax over the last dim.
        target: tensor of gold class ids.

    Returns:
        Scalar float tensor holding the number of correct predictions.
    """
    predictions = output.max(-1)[1].type_as(target)
    # Tensors always provide .sum(), so the former hasattr() fallback was dead code.
    correct = predictions.eq(target).float().sum()
    cm.add_batch(target.cpu().numpy(), predictions.cpu().numpy())
    return correct

def _get_trainabe_modules():
    """Collect the parameters of the globally defined trainable modules.

    Reads the globals ``embedder``, ``encoder`` and ``clf``; the parameters of
    ``ae`` are appended only when the global ``lambda_ae`` is positive.

    Returns:
        list of parameters.
    """
    modules = [embedder, encoder, clf]
    if lambda_ae > 0:
        modules.append(ae)
    return [param for module in modules for param in module.parameters()]

def at_loss(embedder, encoder, clf, batch, perturb_norm_length=5.0):
    """Adversarial-training loss: cross-entropy on embeddings perturbed along the loss gradient.

    Args:
        embedder: callable mapping ``batch`` to an embedding tensor.
        encoder: callable ``encoder(embedded, batch)``; only item ``[0]`` of its
            result is fed to ``clf``.
        clf: callable producing class scores from the encoder output.
        batch: object with a ``labels`` attribute used as cross-entropy target.
        perturb_norm_length: scale applied to the normalized perturbation.

    Returns:
        Cross-entropy of the classifier on the perturbed embeddings.
    """
    embedded = embedder(batch)
    # Ask autograd to keep the gradient of this tensor so that embedded.grad
    # is available after backward().
    embedded.retain_grad()
    ce = F.cross_entropy((clf(encoder(embedded, batch)[0])), batch.labels)
    # NOTE(review): this backward pass presumably also accumulates gradients in
    # the model parameters that ce depends on — confirm the caller expects that.
    ce.backward()

    # Swap dims 0 and 1 before normalizing (get_normalized_vector flattens all
    # but the first dim) and swap them back afterwards.
    # Assumes embedded is (time, batch, dim) — TODO confirm.
    d = embedded.grad.data.transpose(0, 1).contiguous()
    d = get_normalized_vector(d)
    d = d.transpose(0, 1).contiguous()

    # Embed the batch again and add the scaled perturbation.
    d = embedder(batch) + (perturb_norm_length * Variable(d))
    loss = F.cross_entropy(clf(encoder(d, batch)[0]), batch.labels)
    return loss


def get_normalized_vector(d):
    """Scale each example of a 3-D tensor to (approximately) unit L2 norm, in place.

    Every row of the tensor flattened to ``(B, T*D)`` is first divided by its
    largest absolute value and then by its L2 norm. The divisions are done in
    place on a view of ``d``, so the caller's tensor is modified as well.

    Args:
        d: contiguous tensor of shape ``(B, T, D)``.

    Returns:
        The normalized tensor with shape ``(B, T, D)``.
    """
    batch, steps, dim = d.shape
    flat = d.view(batch, -1)
    # Divide by the max magnitude first (small epsilon avoids division by zero).
    flat /= (1e-12 + flat.abs().max(dim=1, keepdim=True)[0])
    flat /= torch.sqrt(1e-6 + flat.pow(2).sum(dim=1, keepdim=True))
    return flat.view(batch, steps, dim)

def rnn_factory(rnn_type, **kwargs):
    """Instantiate the ``torch.nn`` class named ``rnn_type`` with ``kwargs``.

    Returns:
        (rnn, no_pack_padded_seq) where the flag is always ``False``.
    """
    rnn_cls = getattr(nn, rnn_type)
    return rnn_cls(**kwargs), False


def kl_categorical(p_logit, q_logit):
    """Mean KL divergence KL(p || q) between categorical distributions given as logits.

    The divergence is summed over dim 1 and averaged over the remaining elements.
    """
    log_ratio = F.log_softmax(p_logit, dim=-1) - F.log_softmax(q_logit, dim=-1)
    per_example = torch.sum(F.softmax(p_logit, dim=-1) * log_ratio, 1)
    return per_example.mean()


def vat_loss(embedder, encoder, clf, batch, perturb_norm_length=5.0,
             small_constant_for_finite_diff=1e-1, Ip=1, p_logit=None):
    """Virtual adversarial loss: KL between given logits and logits on perturbed embeddings.

    Args:
        embedder: callable mapping ``batch`` to an embedding tensor.
        encoder: callable ``encoder(x, batch)``; item ``[0]`` of its result is fed to ``clf``.
        clf: callable producing logits from the encoder output.
        batch: batch object forwarded to ``embedder`` and ``encoder``.
        perturb_norm_length: scale of the final perturbation.
        small_constant_for_finite_diff: scale of the probe perturbation used
            while estimating the adversarial direction.
        Ip: number of refinement iterations of the direction.
        p_logit: reference logits. Although the default is ``None``,
            ``p_logit.data`` is read unconditionally, so a tensor must be passed.

    Returns:
        ``kl_categorical`` between ``p_logit`` (detached via ``.data``) and the
        logits obtained on the perturbed embeddings.
    """
    embedded = embedder(batch)
    # Start from a random direction with the same shape as the embeddings.
    d = torch.randn(embedded.shape).type(FLOAT_TYPE)
    # Dims 0 and 1 are swapped around each normalization call
    # (assumes embedded is (time, batch, dim) — TODO confirm).
    d = d.transpose(0, 1).contiguous()
    d = get_normalized_vector(d).transpose(0, 1).contiguous()
    for ip in range(Ip):
        # Probe point: built from embedded.data, so it is detached from the embedder graph.
        x_d = Variable(embedded.data + (small_constant_for_finite_diff * d), requires_grad=True)
        x_d.retain_grad()
        p_d_logit = clf(encoder(x_d, batch)[0])
        kl_loss = kl_categorical(Variable(p_logit.data), p_d_logit)
        # NOTE(review): this backward presumably also accumulates gradients in
        # the encoder/classifier parameters — confirm the caller accounts for it.
        kl_loss.backward()
        # The gradient w.r.t. the probe point becomes the new direction.
        d = x_d.grad.data.transpose(0, 1).contiguous()
        d = get_normalized_vector(d).transpose(0, 1).contiguous()
    x_adv = embedded + (perturb_norm_length * Variable(d))
    p_adv_logit = clf(encoder(x_adv, batch)[0])
    return kl_categorical(Variable(p_logit.data), p_adv_logit)


def entropy_loss(p_logit):
    """Entropy of ``softmax(p_logit)`` summed over all elements, divided by the size of dim 0."""
    probs = F.softmax(p_logit, dim=-1)
    log_probs = F.log_softmax(p_logit, dim=-1)
    rows = p_logit.size()[0]
    return -1 * torch.sum(probs * log_probs) / rows


def seq_func(func, x, reconstruct_shape=True, pad_remover=None):
    """Apply ``func`` to every position of a ``(batch, length, units)`` tensor.

    The tensor is flattened to ``(batch * length, units)`` before ``func`` is
    called. If ``pad_remover`` is truthy, its ``remove`` is applied before and
    its ``restore`` after ``func``.

    Returns:
        The result reshaped to ``(batch, length, out_units)``, or the flat
        result when ``reconstruct_shape`` is false.
    """
    batch, length, units = x.shape
    flat = x.view(batch * length, units)
    if pad_remover:
        flat = pad_remover.restore(func(pad_remover.remove(flat)))
    else:
        flat = func(flat)
    if reconstruct_shape:
        out_units = flat.shape[1]
        result = flat.view(batch, length, out_units)
        assert (result.shape == (batch, length, out_units))
        return result
    return flat

def _linear(in_sz, out_sz, unif):
    l = nn.Linear(in_sz, out_sz)
    weight_init.xavier_uniform(l.weight.data)
    return l


def _append2seq(seq, modules):
    for module_ in modules:
        seq.add_module(str(module_), module_)


def binary_cross_entropy(x, y, smoothing=0., epsilon=1e-12):
    """Mean binary cross-entropy of probabilities ``x`` against targets ``y``.

    Args:
        x: predicted probabilities.
        y: targets; cast to float.
        smoothing: when positive, targets are pulled towards 0.5 by
            ``2 * smoothing``.
        epsilon: added inside both logarithms for numerical safety.
    """
    target = y.float()
    if smoothing > 0:
        doubled = smoothing * 2
        target = target * (1 - doubled) + 0.5 * doubled
    positive_term = torch.log(x + epsilon) * target
    negative_term = torch.log(1.0 - x + epsilon) * (1 - target)
    return -torch.mean(positive_term + negative_term)
    
   

def aeq(*args):
    """Assert that all positional arguments compare equal to the first one."""
    remaining = iter(args)
    reference = next(remaining)
    all_equal = all(value == reference for value in remaining)
    assert all_equal, \
        "Not all arguments have the same value: " + str(args)


def sequence_mask(lengths, max_len=None):
    """Build a boolean mask that is true for positions smaller than each length.

    Args:
        lengths: tensor of sequence lengths.
        max_len: number of mask columns; defaults to ``lengths.max()`` when falsy.

    Returns:
        Tensor of shape ``(lengths.numel(), max_len)``.
    """
    batch_size = lengths.numel()
    if not max_len:
        max_len = lengths.max()
    positions = torch.arange(0, max_len).type_as(lengths)
    grid = positions.repeat(batch_size, 1)
    return grid.lt(lengths.unsqueeze(1))
    

def embedded_dropout(embed, words, dropout=0.1, scale=None):
    """Embedding lookup with whole-row (word-level) dropout on the embedding matrix.

    Args:
        embed: ``nn.Embedding`` whose weight and lookup options are used.
        words: tensor of indices to look up.
        dropout: probability of zeroing an entire embedding row; kept rows are
            rescaled by ``1 / (1 - dropout)``. A falsy value disables dropout.
        scale: optional tensor multiplied into the (masked) weights.

    Returns:
        The embedded ``words``.
    """
    if dropout:
        # One Bernoulli draw per vocabulary row, broadcast over the embedding dim.
        mask = embed.weight.data.new().resize_((embed.weight.size(0), 1)).bernoulli_(1 - dropout).expand_as(embed.weight) / (1 - dropout)
        masked_embed_weight = mask * embed.weight
    else:
        masked_embed_weight = embed.weight
    if scale:
        masked_embed_weight = scale.expand_as(masked_embed_weight) * masked_embed_weight

    # Pass padding_idx through unchanged. F.embedding treats None as "no
    # padding row"; the former substitute -1 is interpreted as the LAST row,
    # which silently blocked the gradient of the last vocabulary entry.
    return F.embedding(words, masked_embed_weight, embed.padding_idx, embed.max_norm,
                       embed.norm_type, embed.scale_grad_by_freq, embed.sparse)


def embedded_adaptive_dropout(embed, words, dropout=0.1, scale=None, is_training=True):
    """Embedding lookup with word-level dropout using a tensor of drop rates.

    Args:
        embed: ``nn.Embedding`` whose weight and lookup options are used.
        words: tensor of indices to look up.
        dropout: drop probabilities; must be a tensor when ``is_training`` is
            true because it is passed to ``torch.unsqueeze``.
            NOTE(review): presumably one rate per vocabulary row — confirm the
            expected shape against the caller.
        scale: optional tensor multiplied into the (masked) weights.
        is_training: apply the dropout mask only when true.

    Returns:
        The embedded ``words``.
    """
    if is_training:
        mask = embed.weight.data.new().resize_(
            (embed.weight.size(0), 1)).bernoulli_(1 - dropout)
        mask = (mask / torch.unsqueeze((1 - dropout), -1)).expand_as(
            embed.weight)
        masked_embed_weight = mask * embed.weight
    else:
        masked_embed_weight = embed.weight
    if scale:
        masked_embed_weight = scale.expand_as(
            masked_embed_weight) * masked_embed_weight

    # F.embedding replaces the private embed._backend.Embedding hook, which no
    # longer exists in current PyTorch. padding_idx is forwarded unchanged
    # (None means "no padding row"; -1 would select the last row).
    return F.embedding(words, masked_embed_weight, embed.padding_idx, embed.max_norm,
                       embed.norm_type, embed.scale_grad_by_freq, embed.sparse)

def seq_pad_concat(batch, device):
    """Collate ``(label, word_ids)`` pairs into a padded ``Batch``.

    Sequences are padded with ``Vocab_Pad.PAD`` via ``convert.concat_examples``
    and the padded block is transposed (dims 0 and 1) before being returned.

    Returns:
        ``Batch`` with the label tensor, the transposed word-id tensor and a
        numpy array of the unpadded sequence lengths.
    """
    labels, word_ids = zip(*batch)

    padded = convert.concat_examples(word_ids,device,padding=Vocab_Pad.PAD)
    lengths = np.array([len(ids) for ids in word_ids])

    word_tensor = Variable(torch.LongTensor(padded).type(LONG_TYPE), requires_grad=False)
    label_tensor = Variable(torch.LongTensor(labels).type(LONG_TYPE), requires_grad=False)

    return Batch(batch_size=len(label_tensor),
                 labels=label_tensor,
                 word_ids=word_tensor.transpose(0, 1).contiguous(),
                 sent_len=lengths)


def seq2seq_pad_concat(ly_batch,device,eos_id=Vocab_Pad.EOS, bos_id=Vocab_Pad.BOS):
    """Build decoder input/output tensors from ``(label, sequence)`` pairs.

    The sequences are padded with 0 to a common length. The input block is the
    padded block with ``bos_id`` prepended to every row; the output block is
    the padded block with one extra column, where ``eos_id`` is written right
    after the last real token of each row.

    Returns:
        (y_in_block, y_out_block) as long tensors.
    """
    _labels, y_seqs = zip(*ly_batch)
    y_block = convert.concat_examples(y_seqs, device, padding=0)

    def _to_long(array):
        return Variable(torch.LongTensor(array).type(LONG_TYPE), requires_grad=False)

    # Targets: one extra trailing column, then EOS after each sequence.
    y_out_block = np.pad(y_block, ((0, 0), (0, 1)), 'constant', constant_values=0)
    for row, seq in enumerate(y_seqs):
        y_out_block[row, len(seq)] = eos_id
    # Inputs: the same tokens shifted right by a leading BOS column.
    y_in_block = np.pad(y_block, ((0, 0), (1, 0)), 'constant', constant_values=bos_id)

    y_in_tensor = _to_long(y_in_block)
    y_out_tensor = _to_long(y_out_block)
    return y_in_tensor, y_out_tensor

def long_0_tensor_alloc(nelements, dtype=None):
    """Allocate a long tensor via ``long_tensor_alloc`` and fill it with zeros.

    ``dtype`` is accepted but ignored.
    """
    return long_tensor_alloc(nelements).zero_()


def long_tensor_alloc(dims, dtype=None):
    """Allocate an uninitialized ``int64`` tensor.

    Args:
        dims: an int (1-D tensor of that length) or a sequence of ints (shape).
        dtype: unused; kept for signature compatibility.

    Returns:
        Tensor of dtype ``torch.long`` with the requested shape.
    """
    if isinstance(dims, int):
        dims = (dims,)
    # A one-element sequence such as (5,) used to be passed straight to
    # torch.LongTensor, which interprets a sequence as DATA and returned
    # tensor([5]) instead of a tensor of length 5.
    return torch.empty(tuple(dims), dtype=torch.long)


def print_sentence(logger, data):
    """Log the sequences in ``data`` as column-aligned rows.

    Args:
        logger: object with an ``info(str)`` method.
        data: mapping name -> sequence of string tokens. The number of columns
            is taken from the first sequence.

    Each token is padded to the widest token of its column plus one space.
    Nothing is logged for an empty mapping.
    """
    if not data:
        return
    # dict.values()/next(iter(...)) work on Python 3; the former
    # itervalues()/iteritems()/keys()[0] calls only exist on Python 2.
    sequences = list(data.values())
    n_columns = len(sequences[0])
    spacings = [max(len(seq[i]) for seq in sequences) for i in range(n_columns)]
    for seq in sequences:
        to_print = "".join(token + " " * (spacing - len(token) + 1)
                           for token, spacing in zip(seq, spacings))
        logger.info(to_print)


def get_logger(filename):
    """Return the logger named ``'logger'`` and mirror all log records into ``filename``.

    The root logger is configured at DEBUG level with a plain-message console
    format, and a timestamped ``FileHandler`` for ``filename`` is attached to
    it. Calling the function again with the same file does not attach a second
    handler, so re-running a notebook cell no longer duplicates every log line.

    Args:
        filename: path of the log file.

    Returns:
        The ``logging.Logger`` named ``'logger'``.
    """
    logger = logging.getLogger('logger')
    logger.setLevel(logging.DEBUG)
    logging.basicConfig(format='%(message)s', level=logging.DEBUG)

    root = logging.getLogger()
    # FileHandler stores the absolute path in baseFilename.
    target = os.path.abspath(filename)
    already_attached = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == target
        for h in root.handlers)
    if not already_attached:
        handler = logging.FileHandler(filename)
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
        root.addHandler(handler)
    return logger

def batch_size_fn(new, count, sofar):
    """Padded element count of the current batch after adding example ``new``.

    This re-definition shadows the ``batch_size_fn`` defined earlier in the
    notebook. The running maximum lives in the module-level name
    ``max_src_in_batch`` and is reset whenever ``count`` is 1.

    Args:
        new: example whose item ``[1]`` is the token sequence.
        count: number of examples in the batch including ``new``.
        sofar: unused.
    """
    global max_src_in_batch
    padded_len = len(new[1]) + 2
    running_max = 0 if count == 1 else max_src_in_batch
    max_src_in_batch = max(running_max, padded_len)
    return count * max_src_in_batch

def report_func(epoch, batch, num_batches, start_time, report_stats,
                report_every, logger):
    """Emit a progress report on every ``report_every``-th batch.

    ``batch`` is zero-based; the report is written via ``report_stats.output``
    with the one-based batch number.
    """
    trigger = -1 % report_every
    if batch % report_every != trigger:
        return
    report_stats.output(epoch, batch + 1, num_batches, start_time, logger)

# Models

In [None]:
class Encoder(nn.Module):
    """LSTM encoder configured from a ``config`` object.

    ``config`` must provide ``encoder_input_size``, ``d_hidden``,
    ``num_layers``, ``lstm_dropout`` and ``brnn`` (bidirectional flag).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        lstm_options = dict(
            input_size=config.encoder_input_size,
            hidden_size=config.d_hidden,
            num_layers=config.num_layers,
            dropout=config.lstm_dropout,
            bidirectional=config.brnn,
        )
        self.rnn = nn.LSTM(**lstm_options)

    def forward(self, inputs, batch_size):
        """Run the LSTM on ``inputs``; ``batch_size`` is accepted but unused.

        Returns:
            (memory_bank, encoder_final) exactly as produced by ``nn.LSTM``.
        """
        outputs, final_state = self.rnn(inputs)
        return outputs, final_state

In [None]:
class LstmPadding(object):
    """Sort a padded batch by length, pack it, and undo the sort after encoding.

    The constructor packs the batch; calling the instance runs an encoder
    function on the packed batch and returns its outputs in the caller's
    original batch order.
    """
    def __init__(self, sent, sent_len, config):
        """
        Args:
            sent: padded input tensor; it is reordered along dim 1
                (assumes (time, batch, ...) layout — TODO confirm).
            sent_len: numpy array of sequence lengths (unary minus and
                np.sort are applied to it).
            config: stored on the instance; not used in this class.
        """
        self.batch_size = len(sent_len)
        self.max_sent_len = max(sent_len)
        # Lengths in descending order, plus the permutation producing that order.
        sent_len, idx_sort = np.sort(sent_len)[::-1], np.argsort(-sent_len)
        # Inverse permutation, used in __call__ to restore the original order.
        self.idx_unsort = np.argsort(idx_sort)
        self.config = config

        idx_sort = torch.from_numpy(idx_sort).type(LONG_TYPE)
        sent = sent.index_select(1, Variable(idx_sort))

        # .copy() turns the reversed (negative-stride) view into a regular array
        # before it is handed to PyTorch.
        self.sent_packed = nn.utils.rnn.pack_padded_sequence(sent, sent_len.copy())

    def __call__(self, lstm_enc_func):
        """Encode the packed batch with ``lstm_enc_func(packed, batch_size)``.

        Returns:
            (memory_bank, enc_final): the padded outputs with dims 0 and 1
            swapped, and the two-item final state; both are restored to the
            original batch order.
        """
        idx_unsort = torch.from_numpy(self.idx_unsort). \
            type(LONG_TYPE)
        memory_bank, enc_final = lstm_enc_func(self.sent_packed, self.batch_size)

        # Both items of the final state are reordered along dim 1.
        enc_final = enc_final[0].index_select(1, Variable(idx_unsort)), \
                    enc_final[1].index_select(1, Variable(idx_unsort))

        memory_bank = nn.utils.rnn.pad_packed_sequence(memory_bank)[0]
        memory_bank = memory_bank.index_select(1, Variable(idx_unsort))
        memory_bank = memory_bank.transpose(0, 1).contiguous()
        return memory_bank, enc_final


In [None]:
class Progbar(object):
    """Text progress bar written to ``sys.stdout``.

    Args:
        target: total number of steps expected.
        width: width of the bar in characters.
        verbose: 1 redraws the bar on every update; 2 prints a single summary
            line once ``target`` is reached.
    """
    def __init__(self, target, width=30, verbose=1):
        self.width = width
        self.target = target
        # name -> [weighted sum, weight] for averaged values, or a raw value
        # for entries supplied through `strict`.
        self.sum_values = {}
        self.unique_values = []
        self.start = time.time()
        self.total_width = 0
        self.seen_so_far = 0
        self.verbose = verbose

    def _format_value(self, k):
        """Format one tracked value: averaged entries as %.4f, raw entries via str()."""
        entry = self.sum_values[k]
        if type(entry) is list:
            return ' - %s: %.4f' % (k, entry[0] / max(1, entry[1]))
        return ' - %s: %s' % (k, entry)

    def update(self, current, values=None, exact=None, strict=None):
        """Advance the bar to step ``current`` and refresh the display.

        Args:
            current: index of the current step.
            values: ``(name, value)`` pairs averaged over the steps seen.
            exact: ``(name, value)`` pairs displayed as given (replacing any
                previous value).
            strict: ``(name, value)`` pairs displayed with ``str()``.
        """
        # None defaults avoid sharing mutable default lists between calls.
        values = values or []
        exact = exact or []
        strict = strict or []

        for k, v in values:
            if k not in self.sum_values:
                self.sum_values[k] = [v * (current - self.seen_so_far), current - self.seen_so_far]
                self.unique_values.append(k)
            else:
                self.sum_values[k][0] += v * (current - self.seen_so_far)
                self.sum_values[k][1] += (current - self.seen_so_far)
        for k, v in exact:
            if k not in self.sum_values:
                self.unique_values.append(k)
            self.sum_values[k] = [v, 1]

        for k, v in strict:
            if k not in self.sum_values:
                self.unique_values.append(k)
            self.sum_values[k] = v

        self.seen_so_far = current

        now = time.time()
        if self.verbose == 1:
            prev_total_width = self.total_width
            # Move the cursor back to overwrite the previously drawn bar.
            sys.stdout.write("\b" * prev_total_width)
            sys.stdout.write("\r")

            numdigits = int(np.floor(np.log10(self.target))) + 1
            barstr = '%%%dd/%%%dd [' % (numdigits, numdigits)
            bar = barstr % (current, self.target)
            prog = float(current)/self.target
            prog_width = int(self.width*prog)
            if prog_width > 0:
                bar += ('='*(prog_width-1))
                if current < self.target:
                    bar += '>'
                else:
                    bar += '='
            bar += ('.'*(self.width-prog_width))
            bar += ']'
            sys.stdout.write(bar)
            self.total_width = len(bar)

            if current:
                time_per_unit = (now - self.start) / current
            else:
                time_per_unit = 0
            eta = time_per_unit*(self.target - current)
            info = ''
            if current < self.target:
                info += ' - ETA: %ds' % eta
            else:
                info += ' - %ds' % (now - self.start)
            for k in self.unique_values:
                info += self._format_value(k)

            self.total_width += len(info)
            if prev_total_width > self.total_width:
                # Pad with spaces to erase leftovers of a longer previous line.
                info += ((prev_total_width-self.total_width) * " ")

            sys.stdout.write(info)
            sys.stdout.flush()

            if current >= self.target:
                sys.stdout.write("\n")

        if self.verbose == 2:
            if current >= self.target:
                info = '%ds' % (now - self.start)
                for k in self.unique_values:
                    # `strict` entries are raw values, not [sum, count] lists;
                    # indexing them as lists used to crash this summary.
                    info += self._format_value(k)
                sys.stdout.write(info + "\n")

    def add(self, n, values=None):
        """Advance the bar by ``n`` steps, forwarding ``values`` to ``update``."""
        self.update(self.seen_so_far+n, values)


In [None]:
class RNNDecoderBase(nn.Module):
    """Base class for attention-based RNN decoders.

    The constructor relies on ``self._build_rnn`` and ``self._input_size``, and
    ``forward`` relies on ``self._run_forward_pass``; none of them is defined
    in this class, so presumably subclasses provide them.
    """
    def __init__(self, rnn_type, bidirectional_encoder, num_layers,
                 hidden_size, attn_type="general", dropout=0.0,
                 embeddings=None):
        """
        Args:
            rnn_type: forwarded to ``self._build_rnn``.
            bidirectional_encoder: when true, ``init_decoder_state`` merges
                pairs of encoder states (see ``_fix_enc_hidden``).
            num_layers: number of RNN layers.
            hidden_size: RNN hidden size; also passed to ``GlobalAttention``.
            attn_type: attention type passed to ``GlobalAttention``.
            dropout: rate used for ``self.dropout`` and for the RNN.
            embeddings: stored as ``self.embeddings``.
        """
        super(RNNDecoderBase, self).__init__()

        self.decoder_type = 'rnn'
        self.bidirectional_encoder = bidirectional_encoder
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.embeddings = embeddings
        self.dropout = nn.Dropout(dropout)

        self.rnn = self._build_rnn(rnn_type,
                                   input_size=self._input_size,
                                   hidden_size=hidden_size,
                                   num_layers=num_layers,
                                   dropout=dropout)
        self.attn = GlobalAttention(hidden_size,
                                    attn_type=attn_type)

    def forward(self, tgt, memory_bank, state, memory_lengths=None):
        """Decode ``tgt`` against ``memory_bank`` and update ``state`` in place.

        Args:
            tgt: 2-D tensor unpacked as ``(tgt_len, batch)``.
            memory_bank: 3-D tensor whose dim 1 must equal the batch size of ``tgt``.
            state: an ``RNNDecoderState``.
            memory_lengths: forwarded to ``_run_forward_pass``.

        Returns:
            (decoder_outputs, state, attns)
        """
        assert isinstance(state, RNNDecoderState)
        tgt_len, tgt_batch = tgt.size()
        _, memory_batch, _ = memory_bank.size()
        aeq(tgt_batch, memory_batch)
  
        decoder_final, decoder_outputs, attns = self._run_forward_pass(
            tgt, memory_bank, state, memory_lengths=memory_lengths)

        # The last decoder output (and coverage, if the attention dict has one)
        # is stored in the state.
        final_output = decoder_outputs[-1]
        coverage = None
        if "coverage" in attns:
            coverage = attns["coverage"][-1].unsqueeze(0)
        state.update_state(decoder_final, final_output.unsqueeze(0), coverage)

        return decoder_outputs, state, attns

    def init_decoder_state(self, encoder_final):
        """Build the initial ``RNNDecoderState`` from the encoder's final state.

        A tuple is treated as an LSTM state and each item is converted
        separately; anything else is converted as a single tensor.
        """
        def _fix_enc_hidden(h):
            # For a bidirectional encoder, concatenate the even- and
            # odd-indexed slices of dim 0 along dim 2 (presumably the forward
            # and backward directions of each layer).
            if self.bidirectional_encoder:
                h = torch.cat([h[0:h.size(0):2], h[1:h.size(0):2]], 2)
            return h

        if isinstance(encoder_final, tuple):  # LSTM
            return RNNDecoderState(self.hidden_size,
                                   tuple([_fix_enc_hidden(enc_hid)
                                         for enc_hid in encoder_final]))
        else: 
            return RNNDecoderState(self.hidden_size,
                                   _fix_enc_hidden(encoder_final))


In [None]:
class ExponentialMovingAverage(object):
    """Keeps exponentially averaged "shadow" copies of named tensors.

    Args:
        decay: upper bound of the decay factor. The effective decay is
            ``min(decay, (1 + n) / (10 + n))`` where ``n`` is the number of
            ``apply`` calls so far, so early updates move the shadow faster.
    """
    def __init__(self, decay=0.999):
        self.decay = decay
        self.num_updates = 0
        self.shadow_variable_dict = {}

    @staticmethod
    def _named_items(var_list):
        """Accept either a mapping name -> tensor or an iterable of ``(name, tensor)`` pairs."""
        return var_list.items() if hasattr(var_list, 'items') else var_list

    def register(self, var_list):
        """Store a detached copy of every tensor in ``var_list`` as its shadow."""
        for name, param in self._named_items(var_list):
            # detach(): the shadow must not stay attached to the autograd
            # graph, otherwise the in-place updates in apply() are tracked.
            self.shadow_variable_dict[name] = param.detach().clone()

    def apply(self, var_list):
        """Move every registered shadow towards the current value of its tensor.

        Only entries with ``requires_grad`` set are updated; each of them must
        have been registered before.
        """
        self.num_updates += 1
        decay = min(self.decay, (1.0 + self.num_updates) / (10.0 + self.num_updates))
        # register() takes a mapping while this method used to require pairs;
        # both forms are now accepted by both methods.
        for name, param in self._named_items(var_list):
            if param.requires_grad:
                assert name in self.shadow_variable_dict
                shadow = self.shadow_variable_dict[name]
                shadow -= (1 - decay) * (shadow - param.data)

class SeqLabelReader(object):
    """Interface for readers of sequence-classification data; every method is a no-op here."""

    def __init__(self):
        """Nothing to initialize in the base class."""

    def build_vocab(self, files, **kwargs):
        """Hook for building a vocabulary from ``files``; returns ``None`` here."""

    def load(self, filename, index, batchsz, **kwargs):
        """Hook for loading ``filename`` into batches; returns ``None`` here."""


class TSVSeqLabelReader(SeqLabelReader):
    """Reader for ``label<TAB>text`` files.

    Args:
        mxlen: fixed length of every encoded example.
        mxfiltsz: filter size; ``mxfiltsz // 2`` positions are left as padding
            on each side of the tokens.
        vec_alloc: allocator called as ``vec_alloc(mxlen, dtype=int)``.
    """
    def __init__(self, mxlen=1000, mxfiltsz=0, vec_alloc=np.zeros):
        super(TSVSeqLabelReader, self).__init__()

        self.vocab = None
        self.label2index = {}
        self.mxlen = mxlen
        self.mxfiltsz = mxfiltsz
        self.vec_alloc = vec_alloc

    @staticmethod
    def splits(text):
        """Lower-case ``text`` and split it on whitespace."""
        return text.lower().split()

    @staticmethod
    def label_and_sentence(line):
        """Split a line into its label (first tab field) and the space-joined remaining fields."""
        label_text = line.strip().lower().split('\t')
        label = label_text[0]
        text = label_text[1:]
        text = ' '.join(text)
        return label, text

    def build_vocab(self, files, **kwargs):
        """Count words and register labels found in ``files``.

        Args:
            files: a file path, a directory path (all regular files inside are
                read), or an iterable of paths; ``None`` entries are skipped.
            class_weight: keyword option; ``"balanced"`` requests balanced
                class weights computed by scikit-learn.

        Returns:
            (vocab, labels, class_weight): word ``Counter``, label names ordered
            by index, and the class-weight array or ``None``.
        """
        label_idx = len(self.label2index)
        if type(files) == str:
            if os.path.isdir(files):
                base = files
                candidates = [os.path.join(base, x) for x in os.listdir(base)]
                files = [path for path in candidates if os.path.isfile(path)]
            else:
                files = [files]

        y = list()
        vocab = Counter()
        for file in files:
            if file is None:
                continue
            with io.open(file, encoding='utf-8', errors='ignore') as f:
                for line in tqdm(f):
                    label, text = TSVSeqLabelReader.label_and_sentence(line)
                    if label not in self.label2index:
                        self.label2index[label] = label_idx
                        label_idx += 1
                    for w in TSVSeqLabelReader.splits(text):
                        vocab[w] += 1
                    y.append(self.label2index[label])

        if kwargs.get("class_weight") == "balanced":
            # Current scikit-learn accepts `classes` and `y` only as keyword
            # arguments; the keyword form also works on older releases.
            class_weight = compute_class_weight(
                "balanced",
                classes=np.array(list(self.label2index.values())),
                y=y)
        else:
            class_weight = None

        return vocab, self.get_labels(), class_weight

    def get_labels(self):
        """Return the label names ordered by their index."""
        labels = [''] * len(self.label2index)
        for label, index in self.label2index.items():
            labels[index] = label
        return labels

    def load(self, filename, index, batchsz, **kwargs):
        """Encode ``filename`` into fixed-length id vectors and wrap them in a data feed.

        Args:
            filename: tab-separated input file.
            index: mapping word -> id; it must contain ``'<PAD>'``, whose id is
                also used for words missing from the mapping.
            batchsz: batch size of the returned feed.
            shuffle: keyword option forwarded to the feed (default ``False``).

        Returns:
            ``SeqLabelDataFeed`` over the encoded ``(x, y)`` examples.

        Raises:
            KeyError: if a label was not registered by ``build_vocab``.
        """
        PAD = index['<PAD>']
        shuffle = kwargs.get('shuffle', False)
        halffiltsz = self.mxfiltsz // 2
        # Room left for tokens after reserving the padding on both sides.
        nozplen = self.mxlen - 2 * halffiltsz

        examples = []
        with io.open(filename, encoding='utf-8', errors='ignore') as f:
            for line in tqdm(f):
                label, text = TSVSeqLabelReader.label_and_sentence(line)
                y = self.label2index[label]
                toks = TSVSeqLabelReader.splits(text)
                mx = min(len(toks), nozplen)
                toks = toks[:mx]
                x = self.vec_alloc(self.mxlen, dtype=int)
                for j in range(len(toks)):
                    w = toks[j]
                    key = index.get(w, PAD)
                    x[j + halffiltsz] = key
                examples.append((x, y))

        return SeqLabelDataFeed(SeqLabelExamples(examples),batchsz=batchsz,
                                shuffle=shuffle,vec_alloc=self.vec_alloc,
                                src_vec_trans=None)


class SeqLabelExamples(object):
    """In-memory list of ``(sequence, label)`` examples with batching helpers.

    Args:
        example_list: list of ``(x, y)`` pairs.
        do_shuffle: shuffle ``example_list`` in place on construction.
    """
    # Positions of the sequence and the label inside an example tuple.
    SEQ = 0
    LABEL = 1

    def __init__(self, example_list, do_shuffle=True):
        self.example_list = example_list
        if do_shuffle:
            random.shuffle(self.example_list)

    def __getitem__(self, i):
        ex = self.example_list[i]
        return ex[SeqLabelExamples.SEQ], ex[SeqLabelExamples.LABEL]

    def __len__(self):
        return len(self.example_list)

    def width(self):
        """Length of the first example's sequence."""
        x, y = self.example_list[0]
        return len(x)

    def batch(self, start, batchsz, vec_alloc=np.empty):
        """Return batch number ``start`` as ``(xb, yb)`` integer arrays.

        The batch begins at example ``start * batchsz``; the last batch is
        truncated when the examples run out.
        """
        siglen = self.width()
        # Builtin int replaces the np.int alias, which NumPy removed in 1.24.
        xb = vec_alloc((batchsz, siglen), dtype=int)
        yb = vec_alloc((batchsz), dtype=int)
        sz = len(self.example_list)
        idx = start * batchsz
        for i in range(batchsz):
            if idx >= sz:
                batchsz = i
                break
            x, y = self.example_list[idx]
            xb[i] = x
            yb[i] = y
            idx += 1
        return xb[: batchsz], yb[: batchsz]

    @staticmethod
    def valid_split(data, splitfrac=0.15):
        """Split ``data`` into two ``SeqLabelExamples``.

        Args:
            data: a ``SeqLabelExamples`` instance.
            splitfrac: fraction of examples placed in the second part.

        Returns:
            (first, second): the first ``1 - splitfrac`` share of the examples
            and the remainder. Both parts are shuffled by the constructor.
        """
        # The examples live in `example_list` (there is no `examples`
        # attribute on this class), and the first slice starts at 0 so that
        # example 0 is no longer dropped.
        numinst = len(data.example_list)
        heldout = int(math.floor(numinst * (1 - splitfrac)))
        heldout_ex = data.example_list[:heldout]
        rest_ex = data.example_list[heldout:]
        return SeqLabelExamples(heldout_ex), SeqLabelExamples(rest_ex)


class DataFeed(object):
    """Base class for batch feeds: iterates over ``steps`` batches, optionally in random order."""

    def __init__(self):
        self.steps = 0
        self.shuffle = False

    def _batch(self, i):
        """Return batch ``i``; the base implementation returns ``None``."""
        return None

    def __getitem__(self, i):
        return self._batch(i)

    def __iter__(self):
        # Visit order: a random permutation when shuffling, otherwise 0..steps-1.
        order = np.arange(self.steps)
        if self.shuffle:
            order = np.random.permutation(order)

        for step in range(self.steps):
            yield self._batch(order[step])

    def __len__(self):
        return self.steps


class ExampleDataFeed(DataFeed):
    """Data feed backed by an example container.

    Keyword options: ``shuffle``, ``vec_alloc``, ``vec_shape``,
    ``src_vec_trans`` and ``trim``.
    """

    def __init__(self, examples, batchsz, **kwargs):
        super().__init__()

        option = kwargs.get
        self.examples = examples
        self.batchsz = batchsz
        self.shuffle = bool(option('shuffle', False))
        self.vec_alloc = option('vec_alloc', np.zeros)
        self.vec_shape = option('vec_shape', np.shape)
        self.src_vec_trans = option('src_vec_trans', None)
        # Round up so that a final, smaller batch is still served.
        n_examples = len(self.examples)
        self.steps = int(math.ceil(n_examples / float(batchsz)))
        self.trim = bool(option('trim', False))


class SeqLabelDataFeed(ExampleDataFeed):
    """Feed yielding ``(x, y)`` batches taken from the example container.

    The constructor is inherited unchanged: ``(examples, batchsz, **kwargs)``.
    """

    def _batch(self, i):
        """Fetch batch ``i`` and apply ``src_vec_trans`` to the inputs when one is set."""
        inputs, labels = self.examples.batch(i, self.batchsz, self.vec_alloc)
        transform = self.src_vec_trans
        if transform is None:
            return inputs, labels
        return transform(inputs), labels

class WeightDrop(torch.nn.Module):
    """Wrap a module and apply dropout to selected weights on every forward pass.

    Args:
        module: the wrapped module.
        weights: names of the parameters of ``module`` to apply dropout to.
        dropout: dropout probability.
        variational: when true, one mask entry per row of the weight is drawn
            (and applied even in eval mode, since ``training=True`` is
            hard-coded for that branch); otherwise element-wise dropout that
            follows ``self.training`` is used.
    """
    def __init__(self, module, weights, dropout=0, variational=False):
        super(WeightDrop, self).__init__()
        self.module = module
        self.weights = weights
        self.dropout = dropout
        self.variational = variational
        self._setup()

    def widget_demagnetizer_y2k_edition(*args, **kwargs):
        # No-op used to replace RNNBase.flatten_parameters on the wrapped module.
        return

    def _setup(self):
        """Re-register every selected weight as ``<name>_raw`` on the wrapped module."""
        if issubclass(type(self.module), torch.nn.RNNBase):
            self.module.flatten_parameters = self.widget_demagnetizer_y2k_edition

        for name_w in self.weights:
            print('Applying weight drop of {} to {}'.format(self.dropout, name_w))
            w = getattr(self.module, name_w)
            # The original name is removed from the parameters so that
            # _setweights can assign the dropped-out tensor under that name.
            del self.module._parameters[name_w]
            self.module.register_parameter(name_w + '_raw',
                                           torch.nn.Parameter(w.data))

    def _setweights(self):
        """Compute the dropped-out weights from the raw parameters and attach them."""
        for name_w in self.weights:
            raw_w = getattr(self.module, name_w + '_raw')
            w = None
            if self.variational:
                mask = torch.autograd.Variable(torch.ones(raw_w.size(0), 1))
                if raw_w.is_cuda: mask = mask.cuda()
                mask = torch.nn.functional.dropout(mask, p=self.dropout,
                                                   training=True)
                w = mask.expand_as(raw_w) * raw_w
            else:
                w = torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training)
            setattr(self.module, name_w, w)

    def forward(self, *args):
        """Refresh the dropped-out weights, then run the wrapped module."""
        self._setweights()
        return self.module.forward(*args)



class AdaptiveDropout(torch.nn.Module):
    """Dropout whose drop probabilities are supplied as a tensor on every call."""

    def __init__(self):
        super(AdaptiveDropout, self).__init__()

    def forward(self, input, p):
        """Apply dropout with drop probabilities ``p`` while training.

        Args:
            input: tensor to drop values from.
            p: tensor of drop probabilities (its ``.data`` is used, so no
                gradient flows into it); it must be compatible with
                ``input`` for comparison and multiplication.

        Returns:
            ``input`` scaled by ``keep_mask / keep_prob`` in training mode,
            ``input`` unchanged in eval mode.
        """
        if not self.training:
            return input
        keep_prob = 1. - p.data
        # Draw the noise on the device of `input` rather than forcing .cuda(),
        # so the module also works on CPU.
        keep = torch.rand(input.size(), device=input.device) < keep_prob
        scale = keep.type_as(keep_prob) / keep_prob
        return torch.mul(input, scale)


class LockedDropout(torch.nn.Module):
    """Dropout that draws one mask of shape ``(1, size(1), size(2))`` and shares it along dim 0."""

    def __init__(self, dropout=None):
        super().__init__()
        self.dropout = dropout

    def forward(self, x):
        """Return ``x`` unchanged in eval mode or when no dropout rate is set."""
        active = self.training and self.dropout
        if not active:
            return x
        keep_prob = 1 - self.dropout
        noise = x.data.new(1, x.size(1), x.size(2)).bernoulli_(keep_prob)
        mask = torch.autograd.Variable(noise, requires_grad=False) / keep_prob
        return mask.expand_as(x) * x

class SequenceCriteria(nn.Module):
    """Cross-entropy criterion with optional per-class weights."""

    def __init__(self, class_weight):
        super().__init__()
        self.criteria = nn.CrossEntropyLoss(weight=class_weight)

    def forward(self, inputs, targets):
        """Return the cross-entropy of ``inputs`` (scores) against ``targets`` (class ids)."""
        return self.criteria(inputs, targets)


class Discriminator(nn.Module):
    """MLP that scores a max-pooled sequence representation.

    The network is a stack of ``Dropout(0.3) -> Linear -> LeakyReLU`` blocks
    (width 1024) followed by ``Dropout(0.3) -> Linear(1024, 1)``.
    """

    def __init__(self, config):
        """Build the MLP.

        Reads ``config.d_hidden``, ``config.brnn`` (doubles the input width)
        and ``config.num_discriminator_layers`` (number of hidden blocks).
        """
        super(Discriminator, self).__init__()
        self.config = config
        in_features = config.d_hidden
        if config.brnn:
            in_features *= 2

        width = 1024
        # Input block first, then (num_discriminator_layers - 1) hidden blocks.
        fan_ins = [in_features] + [width] * (config.num_discriminator_layers - 1)
        stack = []
        for fan_in in fan_ins:
            stack.extend([nn.Dropout(0.3),
                          nn.Linear(fan_in, width),
                          nn.LeakyReLU()])
        stack.extend([nn.Dropout(0.3), nn.Linear(width, 1)])
        self.model = nn.Sequential(*stack)

    def forward(self, x, sequence_lengths, ids_, smoothing=0.0):
        """Return the binary cross-entropy of the discriminator on ``x``.

        Args:
            x: 3-D tensor; it is max-pooled over dimension 1.
            sequence_lengths: Accepted but not used here.
            ids_: Targets handed to ``binary_cross_entropy``.
            smoothing: Forwarded to ``binary_cross_entropy``.
        """
        # Unpacking doubles as a check that x is 3-D.
        _batch, _steps, _dim = x.shape
        pooled, _ = torch.max(x.transpose(1, 2).contiguous(), 2)
        probs = F.sigmoid(self.model(pooled))
        return binary_cross_entropy(probs, ids_, smoothing=smoothing)

class Classifier(nn.Module):
    """Pools encoder states over time and maps them to class logits."""

    def __init__(self, config):
        """Build the optional down-projection and the output layer.

        Reads ``config.d_hidden``, ``config.brnn`` (doubles the input width),
        ``config.down_projection``, ``config.d_down_proj``,
        ``config.num_classes`` and ``config.init_scalar``.
        """
        super(Classifier, self).__init__()
        self.config = config
        seq_in_size = config.d_hidden
        if config.brnn:
            seq_in_size *= 2
        if config.down_projection:
            self.down_projection = _linear(seq_in_size,
                                           config.d_down_proj,
                                           config.init_scalar)
            self.act = nn.ReLU()
            seq_in_size = config.d_down_proj
        self.clf = _linear(seq_in_size, config.num_classes, config.init_scalar)

    def forward(self, x):
        """Return class logits for a batch of encoded sequences.

        Args:
            x: 3-D tensor; dimension 1 is pooled away.

        Returns:
            The output of ``self.clf`` on the pooled representation.

        Raises:
            ValueError: If ``config.pool_type`` is neither ``"max_pool"`` nor
                ``"avg_pool"``.
        """
        x = x.transpose(1, 2).contiguous()
        pool_type = self.config.pool_type
        if pool_type == "max_pool":
            sent_output = torch.max(x, 2)[0]
        elif pool_type == "avg_pool":
            # `self.max_sent_len` is never assigned in this class, so the old
            # code raised AttributeError here. Honour an externally supplied
            # value if present, otherwise use the length being pooled.
            max_sent_len = getattr(self, "max_sent_len", None) or x.size(2)
            sent_output = torch.sum(x, 2) * (float(max_sent_len) ** -0.5)
        else:
            # Previously fell through to an UnboundLocalError.
            raise ValueError(
                "Unsupported pool_type: {!r}".format(pool_type))
        if self.config.down_projection:
            sent_output = self.act(self.down_projection(sent_output))
        logits = self.clf(sent_output)
        return logits


class Embedder(nn.Module):
    """Word-embedding lookup followed by dropout.

    Optionally starts from the module-level ``wordvector`` array and
    optionally re-normalizes the embedding table on every forward pass.
    """

    def __init__(self, config):
        """Create the embedding table and pick the dropout variant.

        Reads ``config.n_vocab``, ``config.d_units``,
        ``config.use_pretrained_embeddings``, ``config.train_embeddings``,
        ``config.adaptive_dropout``, ``config.locked_dropout`` and
        ``config.word_dropout``.
        """
        super(Embedder, self).__init__()
        self.config = config
        self.embedder = nn.Embedding(config.n_vocab,
                                     config.d_units,
                                     padding_idx=Vocab_Pad.PAD)

        if config.use_pretrained_embeddings:
            print("Loading pre-trained word vectors")
            pretrained = torch.from_numpy(wordvector)
            self.embedder.weight = torch.nn.Parameter(
                pretrained, requires_grad=config.train_embeddings)

        self.word_dropout = (
            LockedDropout(dropout=config.locked_dropout)
            if config.adaptive_dropout
            else nn.Dropout(p=config.word_dropout)
        )

    def _normalize(self, emb):
        """Standardize ``emb`` column-wise using frequency-weighted moments.

        ``self.vocab_freqs`` must have been set on the instance beforehand;
        it supplies one weight per row of ``emb``.
        """
        freqs = self.vocab_freqs
        probs = (freqs / torch.sum(freqs)).unsqueeze(-1)
        mean = torch.sum(probs * emb, 0, keepdim=True)
        centered = emb - mean
        var = torch.sum(probs * torch.pow(centered, 2.), 0, keepdim=True)
        # Small constant keeps the square root away from zero.
        return centered / torch.sqrt(1e-6 + var)

    def forward(self, batch):
        """Embed ``batch.word_ids`` and apply the configured dropout."""
        cfg = self.config
        if cfg.normalize_embedding:
            # Overwrites the stored table in place on every call.
            self.embedder.weight.data = self._normalize(self.embedder.weight.data)

        if not cfg.adaptive_dropout:
            return self.word_dropout(self.embedder(batch.word_ids))

        rate = cfg.word_dropout if self.training else 0
        word_embedding = embedded_dropout(self.embedder,
                                          batch.word_ids,
                                          dropout=rate)
        # LockedDropout shares its mask along dim 0, hence the round trip
        # through a transposed layout.
        swapped = word_embedding.transpose(0, 1).contiguous()
        return self.word_dropout(swapped).transpose(0, 1).contiguous()

class LSTMEncoder(nn.Module):
    """Optional input projection, the recurrent ``Encoder``, then dropout."""

    def __init__(self, config):
        """Build the sub-modules.

        Side effect: sets ``config.encoder_input_size`` (to ``config.d_proj``
        when ``config.projection`` is truthy, else ``config.d_units``) before
        ``Encoder(config)`` is constructed.
        """
        super(LSTMEncoder, self).__init__()
        self.dropout = nn.Dropout(p=config.dropout)
        if config.projection:
            self.projection = nn.Linear(config.d_units, config.d_proj)
            self.act1 = nn.ReLU()
            config.encoder_input_size = config.d_proj
        else:
            config.encoder_input_size = config.d_units

        self.lstm_encoder = Encoder(config)
        # Width of the encoder output; computed as in the original code but
        # not consumed anywhere in this class.
        seq_in_size = config.d_hidden
        if config.brnn:
            seq_in_size = seq_in_size * 2
        self.lstm_dropout = nn.Dropout(config.lstm_dropout)
        self.config = config

    def encode_sent(self, embedded, sent_len):
        """Run the encoder over ``embedded`` via the ``LstmPadding`` helper.

        Returns:
            ``(memory_bank, encoder_final)`` as produced by ``LstmPadding``.
        """
        if self.config.projection:
            embedded = self.act1(self.projection(embedded))
        run_padded = LstmPadding(embedded, sent_len, self.config)
        memory_bank, encoder_final = run_padded(self.lstm_encoder)
        return memory_bank, encoder_final

    def forward(self, embedded, batch, *args, **kwargs):
        """Encode a batch and apply dropout to the per-step outputs only.

        ``batch.sent_len`` is forwarded to ``encode_sent``; extra positional
        and keyword arguments are accepted and ignored.
        """
        memory_bank, encoder_final = self.encode_sent(embedded,
                                                      batch.sent_len)
        return self.lstm_dropout(memory_bank), encoder_final


class DecoderState(object):
    """Base class for decoder states.

    Subclasses are expected to expose ``_all``, an iterable of the tensors
    that make up the state.
    """

    def detach(self):
        """Detach every non-``None`` state tensor from the graph, in place."""
        for state in filter(lambda s: s is not None, self._all):
            state.detach_()

    def beam_update(self, idx, positions, beam_size):
        """Reorder the beams of one sentence in place.

        Dimension 1 of every state tensor is viewed as
        ``(beam_size, dim1 // beam_size)``; the slice at ``idx`` of the second
        factor is re-indexed along the beam axis by ``positions``.
        """
        for state in self._all:
            dims = state.size()
            per_beam = dims[1] // beam_size
            if len(dims) == 3:
                grouped = state.view(dims[0], beam_size, per_beam, dims[2])
            else:
                grouped = state.view(dims[0], beam_size, per_beam,
                                     dims[2], dims[3])
            sent_states = grouped[:, :, idx]
            # index_select materializes a copy first, so the in-place write
            # below cannot read already-overwritten rows.
            reordered = sent_states.data.index_select(1, positions)
            sent_states.data.copy_(reordered)


class RNNDecoderState(DecoderState):
    """Decoder state for an RNN decoder: hidden tensors plus an input feed."""

    def __init__(self, hidden_size, rnnstate):
        """Wrap ``rnnstate`` and create a zero input feed.

        Args:
            hidden_size: Size of the last dimension of the input feed.
            rnnstate: A tensor or a tuple of tensors; a single tensor is
                wrapped into a 1-tuple. Dimension 1 of the first tensor is
                taken as the batch size.
        """
        if not isinstance(rnnstate, tuple):
            self.hidden = (rnnstate,)
        else:
            self.hidden = rnnstate
        self.coverage = None

        # Init the input feed: zeros of shape (1, batch, hidden_size) with
        # the dtype/device of the hidden state and no gradient tracking.
        batch_size = self.hidden[0].size(1)
        self.input_feed = self.hidden[0].data.new_zeros(
            batch_size, hidden_size).unsqueeze(0)

    @property
    def _all(self):
        """All state tensors: the hidden tuple followed by the input feed."""
        return self.hidden + (self.input_feed,)

    def update_state(self, rnnstate, input_feed, coverage):
        """Replace the hidden state, input feed and coverage."""
        if not isinstance(rnnstate, tuple):
            self.hidden = (rnnstate,)
        else:
            self.hidden = rnnstate
        self.input_feed = input_feed
        self.coverage = coverage

    def repeat_beam_size_times(self, beam_size):
        """Tile every state tensor ``beam_size`` times along dimension 1.

        The tiled tensors are built from ``.data`` and therefore carry no
        autograd history. The former ``Variable(..., volatile=True)`` wrapper
        was dropped: ``volatile`` was removed in PyTorch 0.4, has no effect
        and only triggers a UserWarning.
        """
        tiled = [e.data.repeat(1, beam_size, 1) for e in self._all]
        self.hidden = tuple(tiled[:-1])
        self.input_feed = tiled[-1]







class Statistics(object):
    """Running totals for one training epoch (losses, counts, timing)."""

    def __init__(self):
        self.clf_loss = 0
        self.ae_loss = 0.
        self.at_loss = 0.
        self.vat_loss = 0.
        self.entropy_loss = 0.
        self.n_words = 0
        self.n_correct = 0
        self.n_sent = 0
        self.grad_norm = 0
        self.start_time = time.time()

    def accuracy(self):
        """Return accuracy in percent; 0.0 while no sentence was counted."""
        # Guard against ZeroDivisionError on a freshly created instance.
        if not self.n_sent:
            return 0.0
        return 100 * (self.n_correct / self.n_sent)

    def elapsed_time(self):
        """Seconds since this object was created."""
        return time.time() - self.start_time

    def output(self, epoch, batch, n_batches, start, logger):
        """Log one progress line via ``logger.info`` and flush stdout.

        Loss and gradient-norm totals are averaged over ``batch + 1``;
        throughput is ``n_words`` divided by the elapsed time.

        Args:
            epoch: Epoch number for the message.
            batch: Zero-based index of the current batch.
            n_batches: Number of batches shown as the total.
            start: Reference timestamp for the "elapsed" field.
            logger: Object exposing ``info(message)``.
        """
        t = self.elapsed_time()
        logger.info(("Epoch %2d, %5d/%5d; "
                     "acc: %6.2f; "
                     "clf_loss: %1.4f; "
                     "at_loss: %1.4f; "
                     "vat_loss: %1.4f; "
                     "entropy_loss: %1.4f; "
                     "ae_loss: %1.4f; "
                     "norm: %2.4f; "
                     "%3.0f tok/s; "
                     "%6.0f s elapsed") %
                    (epoch,
                     batch,
                     n_batches,
                     self.accuracy(),
                     self.clf_loss / (batch + 1),
                     self.at_loss / (batch + 1),
                     self.vat_loss / (batch + 1),
                     self.entropy_loss / (batch + 1),
                     self.ae_loss / (batch + 1),
                     self.grad_norm / (batch + 1),
                     self.n_words / (t + 1e-5),
                     time.time() - start))
        sys.stdout.flush()

    def log(self, prefix, experiment, lr):
        """Report accuracy, throughput and learning rate to ``experiment``.

        ``experiment`` must expose ``add_scalar_value(name, value)``.
        """
        t = self.elapsed_time()
        experiment.add_scalar_value(prefix + "_accuracy", self.accuracy())
        # Same epsilon as in output(): avoids dividing by a zero elapsed time.
        experiment.add_scalar_value(prefix + "_tgtper",
                                    self.n_words / (t + 1e-5))
        experiment.add_scalar_value(prefix + "_lr", lr)



class PrettyMetrics(float):
    """A float whose ``repr`` is fixed to two decimal places."""

    def __repr__(self):
        return "{:0.2f}".format(self)


class ConfusionMatrix(object):
    """Confusion matrix with accuracy / precision / recall / F-score helpers.

    Counts are stored in ``self._cm`` with the true class on the row axis and
    the predicted class on the column axis.
    """

    def __init__(self, labels):
        """Create an all-zero matrix.

        Args:
            labels: Sequence of class labels, or a dict mapping
                ``0..n-1`` to labels (converted to a list in key order).
        """
        if type(labels) is dict:
            self.labels = [labels[i] for i in range(len(labels))]
        else:
            self.labels = labels
        nc = len(self.labels)
        # `np.int` was only an alias of the builtin and was removed in
        # NumPy 1.24; the builtin yields the same dtype.
        self._cm = np.zeros((nc, nc), dtype=int)

    def add(self, truth, guess):
        """Count one example with true class ``truth`` predicted ``guess``."""
        self._cm[truth, guess] += 1

    def __str__(self):
        """Render the matrix as right-aligned text with label headers."""
        # Labels may be ints (e.g. class ids); stringify before measuring.
        names = [str(x) for x in self.labels]
        width = max(8, max((len(x) for x in names), default=0) + 1)
        values = []
        for label in [''] + names:
            values += ["{:>{width}}".format(label, width=width + 1)]
        values += ['\n']
        for i, label in enumerate(names):
            values += ["{:>{width}}".format(label, width=width + 1)]
            for j in range(len(names)):
                values += ["{:{width}d}".format(self._cm[i, j], width=width + 1)]
            values += ['\n']
        values += ['\n']
        return ''.join(values)

    def reset(self):
        """Zero all counts in place."""
        self._cm.fill(0)

    def get_correct(self):
        """Number of examples on the diagonal."""
        return self._cm.diagonal().sum()

    def get_total(self):
        """Total number of counted examples."""
        return self._cm.sum()

    def get_acc(self):
        """Fraction of correct examples; 0.0 for an empty matrix."""
        total = self.get_total()
        if total == 0:
            return 0.0
        return float(self.get_correct()) / total

    def get_recall(self):
        """Per-class recall (diagonal over row sums)."""
        # Tiny constant avoids division by zero for classes never seen.
        total = np.sum(self._cm, axis=1) + 0.0000001
        return np.diag(self._cm) / total

    def get_precision(self):
        """Per-class precision (diagonal over column sums)."""
        total = np.sum(self._cm, axis=0) + 0.0000001
        return np.diag(self._cm) / total

    def get_mean_precision(self):
        """Unweighted mean of the per-class precisions."""
        return np.mean(self.get_precision())

    def get_mean_recall(self):
        """Unweighted mean of the per-class recalls."""
        return np.mean(self.get_recall())

    @staticmethod
    def _f_beta(p, r, beta):
        """F-beta of precision ``p`` and recall ``r``; 0.0 if undefined."""
        if beta < 0:
            raise Exception('Beta must be greater than 0')
        denom = beta * beta * p + r
        if denom == 0:
            # Previously 0/0 -> NaN (with a NumPy warning).
            return 0.0
        return (beta * beta + 1) * p * r / denom

    def get_macro_f(self, beta=1):
        """F-beta computed from the mean precision and mean recall."""
        return self._f_beta(self.get_mean_precision(),
                            self.get_mean_recall(),
                            beta)

    def get_f(self, beta=1):
        """F-beta of class index 1 (the positive class in a binary task)."""
        return self._f_beta(self.get_precision()[1],
                            self.get_recall()[1],
                            beta)

    def get_all_metrics(self):
        """Return a dict of ``PrettyMetrics`` values.

        Two-label matrices report class-1 precision/recall/F1; otherwise the
        macro-averaged variants are reported. Ratios are given in percent.
        """
        metrics = {}
        metrics['acc'] = PrettyMetrics(self.get_acc() * 100)
        metrics['correct'] = PrettyMetrics(self.get_correct())
        metrics['total'] = PrettyMetrics(self.get_total())
        if len(self.labels) == 2:
            metrics['precision'] = PrettyMetrics(self.get_precision()[1] * 100)
            metrics['recall'] = PrettyMetrics(self.get_recall()[1] * 100)
            metrics['f1'] = PrettyMetrics(self.get_f(1) * 100)
        else:
            metrics['mean_precision'] = PrettyMetrics(self.get_mean_precision() * 100)
            metrics['mean_recall'] = PrettyMetrics(self.get_mean_recall() * 100)
            metrics['macro_f1'] = PrettyMetrics(self.get_macro_f(1) * 100)
        return metrics

    def add_batch(self, truth, guess):
        """Count every ``(truth, guess)`` pair of two parallel iterables."""
        for truth_i, guess_i in zip(truth, guess):
            self.add(truth_i, guess_i)

In [None]:
class AEModel(nn.Module):
    """Auto-encoder head: an attentional RNN decoder with a token-level loss.

    The decoder consumes the encoder's memory bank and final state and is
    trained to emit the target tokens built by ``seq2seq_pad_concat``.
    """

    def __init__(self, config):
        """Build embedding, decoder, output projection and criterion.

        Reads ``config.n_vocab``, ``config.d_units`` and
        ``config.hidden_size``.
        """
        super(AEModel, self).__init__()
        self.config = config
        self.embed = nn.Embedding(config.n_vocab,
                                  config.d_units,
                                  padding_idx=Vocab_Pad.PAD)
        self.decoder = StdRNNDecoder(rnn_type='LSTM',
                                     bidirectional_encoder=True,
                                     num_layers=1,
                                     hidden_size=config.hidden_size,
                                     dropout=0.2,
                                     embeddings=self.embed,
                                     attn_type="general")
        self.affine = nn.Linear(config.hidden_size,
                                config.n_vocab,
                                bias=True)
        # Padding positions get weight 0 and so contribute nothing.
        weight = torch.ones(config.n_vocab)
        weight[Vocab_Pad.PAD] = 0
        # `reduction='sum'` is the supported spelling of the deprecated
        # `size_average=False`; the summed loss is normalized by hand in
        # output_and_loss().
        self.criterion = nn.NLLLoss(weight,
                                    reduction='sum')

    def output_and_loss(self, h_block, t_block):
        """Project decoder states to the vocabulary and return the mean NLL.

        Args:
            h_block: 3-D tensor of decoder outputs.
            t_block: Target token ids; flattened to match the projected
                logits.

        Returns:
            Summed NLL divided by the number of target ids ``>= 1``.
        """
        # Unpacking doubles as a check that h_block is 3-D.
        _batch, _length, _units = h_block.shape
        logits_flat = seq_func(self.affine,
                               h_block,
                               reconstruct_shape=False)
        log_probs_flat = F.log_softmax(logits_flat,
                                       dim=-1)
        rebatch, _ = logits_flat.shape
        concat_t_block = t_block.view(rebatch)
        # NOTE(review): counts ids >= 1 as real tokens, which presumably
        # relies on the padding id being 0 -- confirm against Vocab_Pad.PAD.
        weights = (concat_t_block >= 1).float()

        loss = self.criterion(log_probs_flat,
                              concat_t_block)
        # Epsilon protects against a batch with no countable tokens.
        loss = loss.sum() / (weights.sum() + 1e-13)
        return loss

    def forward(self, memory_bank, enc_final, lengths, ly_batch_raw,
                dec_state=None):
        """Decode from the encoder outputs and return the reconstruction loss.

        Args:
            memory_bank: Encoder outputs; transposed on dims 0/1 before being
                passed to the decoder.
            enc_final: Final encoder state used to initialize the decoder.
            lengths: NumPy array of source lengths.
            ly_batch_raw: Raw batch from which decoder inputs/targets are
                built via ``seq2seq_pad_concat``.
            dec_state: Optional decoder state that overrides the one derived
                from ``enc_final``.
        """
        tgt_in_block, tgt_out_block = seq2seq_pad_concat(
            ly_batch_raw, -1)
        tgt_in_block = tgt_in_block.transpose(0, 1).contiguous()

        memory_bank = memory_bank.transpose(0, 1).contiguous()
        lengths = torch.from_numpy(lengths).type(LONG_TYPE)
        enc_state = self.decoder.init_decoder_state(enc_final)
        decoder_outputs, dec_state, attns = self.decoder(tgt_in_block,
                                                         memory_bank,
                                                         enc_state if dec_state is None
                                                         else dec_state,
                                                         memory_lengths=lengths)
        decoder_outputs = decoder_outputs.transpose(0, 1).contiguous()
        loss = self.output_and_loss(decoder_outputs, tgt_out_block)
        return loss


In [None]:
class StdRNNDecoder(RNNDecoderBase):
    """RNN decoder that applies attention after the recurrent pass."""

    def _run_forward_pass(self, tgt, memory_bank, state, memory_lengths=None):
        """Run the RNN over ``tgt`` and attend over ``memory_bank``.

        Args:
            tgt: 2-D tensor of target ids (unpacked as ``tgt_len, tgt_batch``).
            memory_bank: Encoder outputs; transposed on dims 0/1 for the
                attention module.
            state: Decoder state exposing ``hidden``.
            memory_lengths: Forwarded to the attention module.

        Returns:
            ``(decoder_final, decoder_outputs, attns)`` where ``attns`` maps
            ``"std"`` to the attention weights.
        """
        emb = self.embeddings(tgt)

        # A GRU takes a single hidden tensor; otherwise pass the full tuple.
        initial = state.hidden[0] if isinstance(self.rnn, nn.GRU) else state.hidden
        rnn_output, decoder_final = self.rnn(emb, initial)

        # Sanity-check that the RNN preserved length and batch size.
        tgt_len, tgt_batch = tgt.size()
        output_len, output_batch, _ = rnn_output.size()
        aeq(tgt_len, output_len)
        aeq(tgt_batch, output_batch)

        query = rnn_output.transpose(0, 1).contiguous()
        decoder_outputs, p_attn = self.attn(query,
                                            memory_bank.transpose(0, 1),
                                            memory_lengths=memory_lengths)
        attns = {"std": p_attn}

        return decoder_final, self.dropout(decoder_outputs), attns

    def _build_rnn(self, rnn_type, **kwargs):
        """Create the recurrent layer through ``rnn_factory``."""
        built = rnn_factory(rnn_type, **kwargs)
        rnn, _ = built
        return rnn

    @property
    def _input_size(self):
        """Input width of the RNN: the embedding dimension."""
        return self.embeddings.embedding_dim



In [None]:
class Training(object):
    """Builds the model, runs training with periodic dev evaluation, and
    finally evaluates the best checkpoint on the test set.

    The trainer owns the embedder, encoder, classifier (and, when
    ``config.lambda_ae > 0``, the auto-encoder), one Adam optimizer over all
    of them, an optional LR scheduler and exponential moving averages of the
    embedder/encoder/classifier weights (these averages are what gets saved).
    """

    def __init__(self, config, logger=None):
        """Create all modules, the optimizer, scheduler and EMA trackers.

        Args:
            config: Experiment configuration object (attribute access).
            logger: Optional logger; a DEBUG-level default is created when
                omitted.
        """
        if logger is None:
            logger = logging.getLogger('logger')
            logger.setLevel(logging.DEBUG)
            logging.basicConfig(format='%(message)s', level=logging.DEBUG)

        self.logger = logger
        self.config = config
        self.classes = list(config.id2label.keys())
        self.num_classes = config.num_classes

        self.embedder = Embedder(self.config).to(device)
        self.encoder = LSTMEncoder(self.config).to(device)
        self.clf = Classifier(self.config).to(device)
        self.clf_loss = SequenceCriteria(class_weight=None).to(device)
        if self.config.lambda_ae > 0:
            self.ae = AEModel(self.config).to(device)

        self.writer = SummaryWriter(log_dir="TFBoardSummary")
        self.global_steps = 0

        self.enc_clf_opt = Adam(self._get_trainabe_modules(),
                                lr=self.config.lr,
                                betas=(config.beta1, config.beta2),
                                weight_decay=config.weight_decay,
                                eps=config.eps)

        if config.scheduler == "ReduceLROnPlateau":
            self.scheduler = lr_scheduler.ReduceLROnPlateau(self.enc_clf_opt,
                                                            mode='max',
                                                            factor=config.lr_decay,
                                                            patience=config.patience,
                                                            verbose=True)
        elif config.scheduler == "ExponentialLR":
            self.scheduler = lr_scheduler.ExponentialLR(self.enc_clf_opt,
                                                        gamma=config.gamma)

        self._init_or_load_model()
        if config.multi_gpu:
            self.embedder.cuda()
            self.encoder.cuda()
            self.clf.cuda()
            self.clf_loss.cuda()
            if self.config.lambda_ae > 0:
                self.ae.cuda()

        self.ema_embedder = ExponentialMovingAverage(decay=0.999)
        self.ema_embedder.register(self.embedder.state_dict())
        self.ema_encoder = ExponentialMovingAverage(decay=0.999)
        self.ema_encoder.register(self.encoder.state_dict())
        self.ema_clf = ExponentialMovingAverage(decay=0.999)
        self.ema_clf.register(self.clf.state_dict())

        self.time_s = time.time()

    def _get_trainabe_modules(self):
        """Return the parameters of every module the optimizer updates.

        Each parameter appears exactly once: embedder, encoder, classifier
        and, when enabled, the auto-encoder.
        """
        param_list = list(self.embedder.parameters()) + \
                     list(self.encoder.parameters()) + \
                     list(self.clf.parameters())
        if self.config.lambda_ae > 0:
            param_list += list(self.ae.parameters())
        return param_list

    def _get_l2_norm_loss(self):
        """Return the sum of the L2 norms of all trainable parameters."""
        total_norm = 0.
        for p in self._get_trainabe_modules():
            param_norm = p.data.norm(p=2)
            total_norm += param_norm
        return total_norm

    def _init_or_load_model(self):
        """Ensure the output directory exists and reset epoch/best accuracy."""
        ensure_directory(self.config.output_path)
        self.epoch = 0
        self.best_accuracy = -np.inf

    def _compute_vocab_freq(self, train_, dev_):
        """Count token-id occurrences over ``train_`` and ``dev_``.

        Both arguments are iterables of ``(label, ids)`` pairs.

        Returns:
            A tensor of length ``config.n_vocab`` with one count per id.
        """
        counter = collections.Counter()
        for _, ids_ in train_:
            counter.update(ids_)
        for _, ids_ in dev_:
            counter.update(ids_)
        word_freq = np.zeros(self.config.n_vocab)
        for index_, freq_ in counter.items():
            word_freq[index_] = freq_
        return torch.from_numpy(word_freq).type(FLOAT_TYPE)

    def _save_model(self):
        """Write the EMA (shadow) weights and bookkeeping to the checkpoint."""
        state = {'epoch': self.epoch,
                 'state_dict_encoder': self.ema_encoder.shadow_variable_dict,
                 'state_dict_embedder': self.ema_embedder.shadow_variable_dict,
                 'state_dict_clf': self.ema_clf.shadow_variable_dict,
                 'best_accuracy': self.best_accuracy}
        torch.save(state, os.path.join(self.config.output_path,
                                       self.config.model_file))

    def _load_model(self):
        """Restore weights from the checkpoint if loading is enabled.

        Returns:
            True when a checkpoint was loaded, False otherwise.
        """
        checkpoint_path = os.path.join(self.config.output_path,
                                       self.config.model_file)
        if self.config.load_checkpoint and os.path.isfile(checkpoint_path):
            dict_ = torch.load(checkpoint_path)
            self.epoch = dict_['epoch']
            self.best_accuracy = dict_['best_accuracy']
            self.embedder.load_state_dict(dict_['state_dict_embedder'])
            self.encoder.load_state_dict(dict_['state_dict_encoder'])
            self.clf.load_state_dict(dict_['state_dict_clf'])
            self.logger.info(
                "=> loaded checkpoint '{}' (epoch {})".format(checkpoint_path,
                                                              self.epoch))
            return True
        # Explicit falsy result instead of an implicit None.
        return False

    def __call__(self, train, dev, test, unlabel):
        """Train on ``train``/``unlabel`` (validating on ``dev``), then test."""
        if self.config.normalize_embedding:
            self.embedder.vocab_freqs = self._compute_vocab_freq(train, dev)
            print("Embeddings will be normalized during training")
        self.logger.info('Start training')
        self._train(train, dev, unlabel)
        self._evaluate(test)

    def _create_iter(self, data_, wbatchsize, random_shuffler=data.iterator.RandomShuffler()):
        """Return a pooled batch iterator over ``data_``.

        Examples are keyed by the length of their second element and batched
        via the module-level ``batch_size_fn``.
        """
        iter_ = data.iterator.pool(data_,
                                   wbatchsize,
                                   key=lambda x: len(x[1]),
                                   batch_size_fn=batch_size_fn,
                                   random_shuffler=random_shuffler
                                   )
        return iter_

    def _run_epoch(self, train_data, dev_data, unlabel_data):
        """Run one epoch of training with periodic dev evaluation.

        Returns:
            Dev accuracy measured at the end of the epoch.
        """
        report_stats = Statistics()
        cm = ConfusionMatrix(self.classes)
        _, seq_data = list(zip(*train_data))
        total_seq_words = len(list(itertools.chain.from_iterable(seq_data)))
        # Rough estimate of the number of batches, used only for reporting.
        iter_per_epoch = (1.5 * total_seq_words) // self.config.wbatchsize

        self.encoder.train()
        self.clf.train()
        self.embedder.train()
        train_iter = self._create_iter(train_data, self.config.wbatchsize)
        unlabel_iter = self._create_iter(unlabel_data,
                                         self.config.wbatchsize_unlabel)
        for batch_index, train_batch_raw in enumerate(train_iter):
            seq_iter = list(zip(*train_batch_raw))[1]
            seq_words = len(list(itertools.chain.from_iterable(seq_iter)))
            report_stats.n_words += seq_words
            self.global_steps += 1

            if self.config.add_noise:
                train_batch_raw = add_noise(train_batch_raw,
                                            self.config.noise_dropout,
                                            self.config.random_permutation)
            train_batch = seq_pad_concat(train_batch_raw, -1)

            train_embedded = self.embedder(train_batch)
            memory_bank_train, enc_final_train = self.encoder(train_embedded, train_batch)

            # Unlabeled data is only encoded when some loss can consume it.
            # `lambda_entropy` now uses the same `> 0` test as its siblings
            # and as the entropy branch below.
            if self.config.lambda_vat > 0 or self.config.lambda_ae > 0 \
                    or self.config.lambda_entropy > 0:
                try:
                    unlabel_batch_raw = next(unlabel_iter)
                except StopIteration:
                    # Unlabeled stream exhausted: start over.
                    unlabel_iter = self._create_iter(unlabel_data,
                                                     self.config.wbatchsize_unlabel)
                    unlabel_batch_raw = next(unlabel_iter)

                if self.config.add_noise:
                    unlabel_batch_raw = add_noise(unlabel_batch_raw,
                                                  self.config.noise_dropout,
                                                  self.config.random_permutation)
                unlabel_batch = seq_pad_concat(unlabel_batch_raw, -1)
                unlabel_embedded = self.embedder(unlabel_batch)
                memory_bank_unlabel, enc_final_unlabel = self.encoder(
                    unlabel_embedded,
                    unlabel_batch)

            pred = self.clf(memory_bank_train)
            accuracy = self.get_accuracy(cm, pred.data, train_batch.labels.data)
            lclf = self.clf_loss(pred, train_batch.labels)

            # Disabled losses keep a -1 placeholder; their lambda is 0, so
            # they do not affect the total.
            lat = Variable(torch.FloatTensor([-1.]).type(FLOAT_TYPE))
            lvat = Variable(torch.FloatTensor([-1.]).type(FLOAT_TYPE))
            if self.config.lambda_at > 0:
                lat = at_loss(self.embedder,
                              self.encoder,
                              self.clf,
                              train_batch,
                              perturb_norm_length=self.config.perturb_norm_length)

            if self.config.lambda_vat > 0:
                lvat_train = vat_loss(self.embedder,
                                      self.encoder,
                                      self.clf,
                                      train_batch,
                                      p_logit=pred,
                                      perturb_norm_length=self.config.perturb_norm_length)
                if self.config.inc_unlabeled_loss:
                    lvat_unlabel = vat_loss(self.embedder,
                                            self.encoder,
                                            self.clf,
                                            unlabel_batch,
                                            p_logit=self.clf(memory_bank_unlabel),
                                            perturb_norm_length=self.config.perturb_norm_length)
                    if self.config.unlabeled_loss_type == "AvgTrainUnlabel":
                        lvat = 0.5 * (lvat_train + lvat_unlabel)
                    elif self.config.unlabeled_loss_type == "Unlabel":
                        lvat = lvat_unlabel
                else:
                    lvat = lvat_train

            lentropy = Variable(torch.FloatTensor([-1.]).type(FLOAT_TYPE))
            if self.config.lambda_entropy > 0:
                lentropy_train = entropy_loss(pred)
                if self.config.inc_unlabeled_loss:
                    lentropy_unlabel = entropy_loss(self.clf(memory_bank_unlabel))
                    if self.config.unlabeled_loss_type == "AvgTrainUnlabel":
                        lentropy = 0.5 * (lentropy_train + lentropy_unlabel)
                    elif self.config.unlabeled_loss_type == "Unlabel":
                        lentropy = lentropy_unlabel
                else:
                    lentropy = lentropy_train

            lae = Variable(torch.FloatTensor([-1.]).type(FLOAT_TYPE))
            if self.config.lambda_ae > 0:
                lae = self.ae(memory_bank_unlabel,
                              enc_final_unlabel,
                              unlabel_batch.sent_len,
                              unlabel_batch_raw)

            ltotal = (self.config.lambda_clf * lclf) + \
                     (self.config.lambda_ae * lae) + \
                     (self.config.lambda_at * lat) + \
                     (self.config.lambda_vat * lvat) + \
                     (self.config.lambda_entropy * lentropy)

            report_stats.clf_loss += lclf.data.cpu().numpy()
            report_stats.at_loss += lat.data.cpu().numpy()
            report_stats.vat_loss += lvat.data.cpu().numpy()
            report_stats.ae_loss += lae.data.cpu().numpy()
            report_stats.entropy_loss += lentropy.data.cpu().numpy()
            report_stats.n_sent += len(pred)
            report_stats.n_correct += accuracy
            self.enc_clf_opt.zero_grad()
            ltotal.backward()

            # Clip every trainable parameter exactly once. The previous code
            # appended the embedder parameters a second time (they are
            # already in _get_trainabe_modules()), so their gradients were
            # counted twice in the norm and rescaled twice when clipping.
            params_list = self._get_trainabe_modules()
            norm = torch.nn.utils.clip_grad_norm_(params_list,
                                                  self.config.max_norm)
            report_stats.grad_norm += norm
            self.enc_clf_opt.step()
            if self.config.scheduler == "ExponentialLR":
                self.scheduler.step()
            self.ema_embedder.apply(self.embedder.named_parameters())
            self.ema_encoder.apply(self.encoder.named_parameters())
            self.ema_clf.apply(self.clf.named_parameters())

            report_func(self.epoch,
                        batch_index,
                        iter_per_epoch,
                        self.time_s,
                        report_stats,
                        self.config.report_every,
                        self.logger)

            if self.global_steps % self.config.eval_steps == 0:
                cm_, accuracy, prc_dev = self._run_evaluate(dev_data)
                self.logger.info("- dev accuracy {} | best dev accuracy {} ".format(accuracy, self.best_accuracy))
                self.writer.add_scalar("Dev_Accuracy", accuracy,
                                       self.global_steps)
                pred_, lab_ = zip(*prc_dev)
                pred_ = torch.cat(pred_)
                lab_ = torch.cat(lab_)
                self.writer.add_pr_curve("Dev PR-Curve", lab_,
                                         pred_,
                                         self.global_steps)
                pprint.pprint(cm_)
                pprint.pprint(cm_.get_all_metrics())
                if accuracy > self.best_accuracy:
                    self.logger.info("- new best score!")
                    self.best_accuracy = accuracy
                    self._save_model()
                if self.config.scheduler == "ReduceLROnPlateau":
                    self.scheduler.step(accuracy)
                # Evaluation switched the modules to eval mode.
                self.encoder.train()
                self.embedder.train()
                self.clf.train()

                if self.config.weight_decay > 0:
                    print(">> Square Norm: %1.4f " % self._get_l2_norm_loss())

        cm, train_accuracy, _ = self._run_evaluate(train_data)
        self.logger.info("- Train accuracy  {}".format(train_accuracy))
        pprint.pprint(cm.get_all_metrics())

        cm, dev_accuracy, _ = self._run_evaluate(dev_data)
        self.logger.info("- Dev accuracy  {} | best dev accuracy {}".format(dev_accuracy, self.best_accuracy))
        pprint.pprint(cm.get_all_metrics())
        self.writer.add_scalars("Overall_Accuracy",
                                {"Train_Accuracy": train_accuracy,
                                 "Dev_Accuracy": dev_accuracy},
                                self.global_steps)
        return dev_accuracy

    @staticmethod
    def get_accuracy(cm, output, target):
        """Count correct predictions and record them in ``cm``.

        Args:
            cm: Confusion matrix updated via ``add_batch``.
            output: Score tensor; the arg-max over the last dim is the
                prediction.
            target: Tensor of gold class ids.

        Returns:
            The number of correct predictions as a tensor.
        """
        predictions = output.max(-1)[1].type_as(target)
        correct = predictions.eq(target)
        correct = correct.float()
        if not hasattr(correct, 'sum'):
            correct = correct.cpu()
        correct = correct.sum()
        cm.add_batch(target.cpu().numpy(), predictions.cpu().numpy())
        return correct

    def _predict_batch(self, cm, batch):
        """Predict one batch in eval mode.

        Returns:
            ``(pred, n_correct)``; ``cm`` is updated as a side effect.
        """
        self.embedder.eval()
        self.encoder.eval()
        self.clf.eval()
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            pred = self.clf(self.encoder(self.embedder(batch),
                                         batch)[0])
        accuracy = self.get_accuracy(cm, pred.data, batch.labels.data)
        return pred, accuracy

    def chunks(self, l, n=15):
        """Yield successive slices of ``l`` holding at most ``n`` items."""
        for i in range(0, len(l), n):
            yield l[i:i + n]

    def _run_evaluate(self, test_data):
        """Evaluate on ``test_data`` in fixed-size chunks.

        Returns:
            ``(confusion_matrix, accuracy_percent, pr_curve_data)`` where
            ``pr_curve_data`` is a list of ``(P(class 1), labels)`` pairs,
            one per chunk.
        """
        pr_curve_data = []
        cm = ConfusionMatrix(self.classes)
        accuracy_list = []
        test_iter = self.chunks(test_data)
        for test_batch in test_iter:
            test_batch = seq_pad_concat(test_batch, -1)
            pred, acc = self._predict_batch(cm, test_batch)
            accuracy_list.append(acc)
            pr_curve_data.append(
                (F.softmax(pred, -1)[:, 1].data, test_batch.labels.data))
        accuracy = 100 * (sum(accuracy_list) / len(test_data))
        return cm, accuracy, pr_curve_data

    def _train(self, train_data, dev_data, unlabel_data):
        """Run up to ``config.nepochs`` epochs with early stopping.

        ``train_data`` and ``unlabel_data`` are shuffled in place before
        every epoch.
        """
        nepoch_no_imprv = 0

        epoch_start = self.epoch + 1
        epoch_end = self.epoch + self.config.nepochs + 1
        for self.epoch in range(epoch_start, epoch_end):
            self.logger.info(
                "Epoch {:} out of {:}".format(self.epoch, self.config.nepochs))
            random.shuffle(train_data)
            random.shuffle(unlabel_data)
            # _run_epoch may itself raise best_accuracy during its mid-epoch
            # evaluations; remember the value from before the epoch so those
            # improvements are not counted as "no improvement" below.
            best_before_epoch = self.best_accuracy
            accuracy = self._run_epoch(train_data, dev_data, unlabel_data)

            if accuracy > self.best_accuracy:
                self.best_accuracy = accuracy
                self.logger.info("- new best score!")
                self._save_model()

            if self.best_accuracy > best_before_epoch:
                nepoch_no_imprv = 0
            else:
                nepoch_no_imprv += 1
                if nepoch_no_imprv >= self.config.nepoch_no_imprv:
                    self.logger.info(
                        "- early stopping {} epochs without improvement".format(
                            nepoch_no_imprv))
                    break
            if self.config.scheduler == "ReduceLROnPlateau":
                self.scheduler.step(accuracy)

    def _evaluate(self, test_data):
        """Evaluate on the test set and dump predictions to a TSV file.

        Reloads the checkpoint first when ``config.load_checkpoint`` is set.
        Each output line holds the predicted probability of class 1 and the
        gold label, tab-separated.
        """
        self.logger.info("Evaluating model over test set")
        self._load_model()
        _, accuracy, prc_test = self._run_evaluate(test_data)
        pred_, lab_ = zip(*prc_test)
        pred_ = torch.cat(pred_).cpu().tolist()
        lab_ = torch.cat(lab_).cpu().tolist()
        path_ = os.path.join(self.config.output_path, "{}_pred_gt.tsv".format(self.config.exp_name))
        with open(path_, 'w') as fp:
            for p, l in zip(pred_, lab_):
                fp.write(str(p) + '\t' + str(l) + '\n')
        self.logger.info("- test accuracy {}".format(accuracy))
        self.logger.info("- test accuracy {}".format(accuracy))


In [None]:
class GlobalAttention(nn.Module):
    """Attention of one or more queries over a bank of source vectors.

    The scoring function is selected by ``attn_type``:

    * ``"dot"``:     ``score(h_t, h_s) = h_t . h_s``
    * ``"general"``: ``score(h_t, h_s) = (W_in h_t) . h_s``
    * ``"mlp"``:     ``score(h_t, h_s) = v . tanh(W_q h_t + W_c h_s)``

    Scores are normalised with a softmax over the source positions, the
    resulting weights build a context vector from the memory bank, and the
    context concatenated with the query is passed through ``linear_out``
    (followed by ``tanh`` for "dot" and "general").

    ``aeq`` and ``sequence_mask`` are helpers defined elsewhere in the
    notebook. ``aeq`` is used here as a shape-consistency check (presumably
    it asserts that its arguments are equal - confirm against its
    definition).

    Args:
        dim: Size of the query and memory vectors (both must use it).
        attn_type: One of ``"dot"``, ``"general"`` or ``"mlp"``.
    """

    def __init__(self, dim, attn_type="dot"):
        super(GlobalAttention, self).__init__()

        self.dim = dim
        self.attn_type = attn_type
        assert (self.attn_type in ["dot", "general", "mlp"]), (
                "Please select a valid attention type.")

        if self.attn_type == "general":
            self.linear_in = nn.Linear(dim, dim, bias=False)
        elif self.attn_type == "mlp":
            self.linear_context = nn.Linear(dim, dim, bias=False)
            self.linear_query = nn.Linear(dim, dim, bias=True)
            self.v = nn.Linear(dim, 1, bias=False)
        # Only the "mlp" variant uses a bias on the output projection.
        out_bias = self.attn_type == "mlp"
        self.linear_out = nn.Linear(dim*2, dim, bias=out_bias)

        # The softmax is always applied to a 2-D (batch*targetL, sourceL)
        # view, so the last dimension is the source axis. Naming it
        # explicitly avoids PyTorch's implicit-dimension deprecation warning.
        self.sm = nn.Softmax(dim=-1)
        self.tanh = nn.Tanh()

    def score(self, h_t, h_s):
        """Compute unnormalised attention scores.

        Args:
            h_t: Queries of shape ``(batch, tgt_len, dim)``.
            h_s: Source vectors of shape ``(batch, src_len, dim)``.

        Returns:
            Tensor of shape ``(batch, tgt_len, src_len)``.
        """
        src_batch, src_len, src_dim = h_s.size()
        tgt_batch, tgt_len, tgt_dim = h_t.size()
        aeq(src_batch, tgt_batch)
        aeq(src_dim, tgt_dim)
        aeq(self.dim, src_dim)

        if self.attn_type in ["general", "dot"]:
            if self.attn_type == "general":
                # .contiguous() lets a transposed or sliced query be
                # flattened; .view() alone raises on such tensors.
                h_t_ = h_t.contiguous().view(tgt_batch*tgt_len, tgt_dim)
                h_t_ = self.linear_in(h_t_)
                h_t = h_t_.view(tgt_batch, tgt_len, tgt_dim)
            h_s_ = h_s.transpose(1, 2)
            # (batch, tgt_len, dim) x (batch, dim, src_len)
            return torch.bmm(h_t, h_s_)
        else:
            dim = self.dim
            # Project the queries and broadcast them along the source axis.
            wq = self.linear_query(h_t.contiguous().view(-1, dim))
            wq = wq.view(tgt_batch, tgt_len, 1, dim)
            wq = wq.expand(tgt_batch, tgt_len, src_len, dim)

            # Project the source vectors and broadcast along the target axis.
            uh = self.linear_context(h_s.contiguous().view(-1, dim))
            uh = uh.view(src_batch, 1, src_len, dim)
            uh = uh.expand(src_batch, tgt_len, src_len, dim)

            wquh = self.tanh(wq + uh)

            return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len)

    def forward(self, input, memory_bank, memory_lengths=None):
        """Attend over ``memory_bank`` with ``input`` as the query.

        Args:
            input: Query tensor, either ``(batch, dim)`` for a single step or
                ``(batch, targetL, dim)`` for several steps.
            memory_bank: Source vectors of shape ``(batch, sourceL, dim)``.
            memory_lengths: Optional; when given it is passed to
                ``sequence_mask`` and positions where the resulting mask is
                false get zero attention weight (presumably the valid length
                of each source sequence - confirm against ``sequence_mask``).

        Returns:
            Tuple ``(attn_h, align_vectors)``. For a 2-D ``input`` the shapes
            are ``(batch, dim)`` and ``(batch, sourceL)``; for a 3-D
            ``input`` they are ``(targetL, batch, dim)`` and
            ``(targetL, batch, sourceL)``.
        """
        if input.dim() == 2:
            one_step = True
            input = input.unsqueeze(1)
        else:
            one_step = False

        batch, sourceL, dim = memory_bank.size()
        batch_, targetL, dim_ = input.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)

        align = self.score(input, memory_bank)

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths)
            mask = mask.unsqueeze(1)
            # Autograd-tracked masking instead of writing through the legacy
            # `.data` handle; masked positions get weight 0 after the softmax.
            align = align.masked_fill(~mask, -float('inf'))

        align_vectors = self.sm(align.view(batch*targetL, sourceL))
        align_vectors = align_vectors.view(batch, targetL, sourceL)

        # Weighted sum of the source vectors: (batch, targetL, dim).
        c = torch.bmm(align_vectors, memory_bank)

        concat_c = torch.cat([c, input], 2).view(batch*targetL, dim*2)
        attn_h = self.linear_out(concat_c).view(batch, targetL, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = self.tanh(attn_h)

        if one_step:
            # Drop the target axis that was added for the 2-D query.
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, sourceL_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)
        else:
            # Multi-step outputs are returned target-major.
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.transpose(0, 1).contiguous()

            targetL_, batch_, dim_ = attn_h.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            targetL_, batch_, sourceL_ = align_vectors.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)

        return attn_h, align_vectors


# Training

In [None]:
class hyperparameter():
    """Static bag of settings for the training run.

    Every option is a plain class attribute, so it can be read from the
    class or from an instance. The attributes are evaluated once, when this
    cell runs: ``now``, ``random_num`` and therefore ``log_path`` are fixed
    at that moment. Most options are consumed by code in other cells; the
    section comments below group them by name only.
    """

    # -- model / architecture (grouped by name) --
    model='WordLstm'
    d_units=300
    d_proj=200
    d_hidden=512
    projection=False
    d_down_proj=100
    down_projection=False
    num_discriminator_layers=3
    frnn=False
    brnn=True
    timedistributed=False
    init_scalar=0.05
    num_layers=1
    unif=0.25

    # -- devices, clipping, checkpoint / embedding switches (grouped by name) --
    multi_gpu=False
    gpu_ids=[0, 1, 2, 3]
    gradient_clipping=True
    max_norm=1.0
    weight_decay=0.0
    load_checkpoint=False
    use_pretrained_embeddings=True
    train_embeddings=True
    finetune=False
    # Fall back to the expanded "~" when HOME is not set in the environment
    # instead of raising KeyError while the class body is being evaluated.
    home = os.environ.get('HOME', os.path.expanduser('~'))

    # -- training schedule and regularisation (grouped by name) --
    max_iter=None
    nepochs=50            # number of epochs run by the training loop
    dropout=0.5
    word_dropout=0.5
    lstm_dropout=0.5
    locked_dropout=0.5
    batch_size=64
    nepoch_no_imprv=3     # epochs without a new best score before early stopping
    nchkp_no_imprv=30
    hidden_size=1024
    subsampling=1e-4
    class_weight='uniform'

    # -- optimiser / learning-rate schedule (grouped by name) --
    optim='adam'
    eval_steps=1000
    lr=0.001
    lr_decay=0.5
    beta1=0.9
    beta2=0.999
    eps=1e-8
    patience=20
    # The training loop steps its scheduler with the epoch score when this
    # is exactly 'ReduceLROnPlateau'.
    scheduler='ReduceLROnPlateau'
    gamma=0.99995
    adaptive_dropout = False
    pool_type='max_pool'
    dynamic_pool_size=20
    wbatchsize=3000
    wbatchsize_unlabel=12000

    # -- loss weights and perturbation settings (grouped by name) --
    lambda_clf=1.0
    lambda_ae=0.0
    lambda_at=1.0
    lambda_vat=1.0
    lambda_entropy=1.0
    inc_unlabeled_loss=True
    unlabeled_loss_type='AvgTrainUnlabel'
    perturb_norm_length=5.0
    max_embedding_norm=None
    normalize_embedding=False
    add_noise=False
    noise_dropout=0.1
    random_permutation=3

    # -- run bookkeeping: data names, output locations, logging --
    debug=False
    report_every=100
    input='temp'
    save_data='demo'
    output_path="results/clf/"   # directory for the log file and prediction dump
    exp_name='ssl'               # used as prefix of the model and prediction files
    corpus='sst'
    model_file = exp_name + ".pt"
    now = datetime.utcnow().strftime("%Y-%m-%d-%H-%M-%S")  # UTC timestamp
    random_num = random.randint(1, 1000)
    # Argument order follows the placeholders: experiment name first, then
    # the timestamp after "time-", then the random suffix.
    log_path = os.path.join(output_path, "log_{}_time-{}_rand_{}.txt".format(exp_name,now,random_num))

In [None]:
# `hyperparameter` is the configuration class defined in the cell above.
args = hyperparameter()
# exist_ok avoids the separate exists() check and is safe on re-runs.
os.makedirs(args.output_path, exist_ok=True)
logger = get_logger(args.log_path)
# The settings are class attributes, so the instance __dict__ is still empty
# here; dump the class-level values instead so the log records the run
# configuration. default=str keeps the dump from failing on a value that is
# not JSON-serialisable.
config_dump = {
    name: value
    for name, value in vars(type(args)).items()
    if not name.startswith('__')
}
logger.info(json.dumps(config_dump, indent=4, default=str))

id2label = {0:0,1:1}
# id2label = {0:1,1:2,2:3,3:4}
# Vocabulary and label information is attached to the config instance before
# it is handed to Training.
args.id2w = id2w
args.n_vocab = len(id2w)
args.id2label = id2label
args.num_classes = len(id2label)

# Named `trainer` rather than `object` so the builtin is not shadowed for the
# rest of the kernel session (later `class X(object)` definitions would break).
trainer = Training(args, logger)
trainer(train_data, dev_data, test_data, unlabel_data)