In [None]:
import json
import os
from collections import  Counter
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.nn.functional import cross_entropy
from torch.autograd import Variable
config = {
    "elmo": {
        "activation": "relu",
        # (kernel width, number of filters) pairs for the character CNN
        "filters": [[1, 32], [2, 32], [3, 64], [4, 128], [5, 256], [6, 512], [7, 1024]],
        "n_highway": 2,
        "word_dim": 300,
        "char_dim": 50,
        "max_char_token": 50,
        "min_count": 5,
        "max_length": 256,
        "output_dim": 150,
        "units": 256,
        "n_layers": 2,
    },
    "batch_size": 64,
    "epochs": 5,
    "lr": 1e-5,
}
# Seed torch's global RNG so weight init and data shuffling are reproducible.
SEED = 42
torch.manual_seed(SEED)
model_path = "./elmo_model"

In [None]:
# The corpus is expected to be a JSON array of sentence strings.
CORPUS_PATH = "data/corpus.json"
N_SAMPLES = None  # e.g. 1000 for a quick debug run; None uses the full corpus

# JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
with open(CORPUS_PATH, encoding="utf-8") as f:
    corpus = json.load(f)
assert isinstance(corpus, list), "corpus.json must contain a JSON array of sentences"
if N_SAMPLES is not None:
    corpus = corpus[:N_SAMPLES]

In [None]:
# Prefer the GPU whenever torch can see one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

In [None]:
class Tokenizer:
    """Word- and character-level tokenizer for the ELMo model.

    Text is lower-cased and split on whitespace. Each word maps to one word id
    and to a fixed-width row of character ids. Vocabularies built with
    ``from_corpus`` reserve ids 0 and 1 for ``<oov>`` and ``<pad>``.
    """

    OOV_TOKEN = "<oov>"
    PAD_TOKEN = "<pad>"

    def __init__(self, word2id, ch2id):
        self.word2id = word2id
        self.ch2id = ch2id
        self.id2word = {i: word for word, i in word2id.items()}
        self.id2ch = {i: char for char, i in ch2id.items()}

    @classmethod
    def from_corpus(cls, corpus, min_count=5):
        """Build word and char vocabularies from an iterable of sentences.

        Args:
            corpus: iterable of raw sentence strings (iterated once).
            min_count: minimum frequency a lower-cased word needs to be kept.

        Returns:
            A Tokenizer with:
            - a word vocabulary holding every word seen at least ``min_count``
              times, most frequent first;
            - a char vocabulary holding every character of the lower-cased
              corpus, in first-seen order.
        """
        specials = (cls.OOV_TOKEN, cls.PAD_TOKEN)
        word_count = Counter()
        char_lexicon = {tok: i for i, tok in enumerate(specials)}
        for sentence in corpus:
            # Lower-case for chars too: tokenize() lower-cases before char
            # lookup, so upper-case char entries would never be used.
            words = sentence.lower().split()
            word_count.update(words)
            for word in words:
                for ch in word:
                    if ch not in char_lexicon:
                        char_lexicon[ch] = len(char_lexicon)

        word_lexicon = {tok: i for i, tok in enumerate(specials)}
        # most_common() sorts by descending count (stable for ties). Filtering
        # keeps *every* word meeting the threshold and handles an empty corpus.
        for word, count in word_count.most_common():
            if count >= min_count and word not in word_lexicon:
                word_lexicon[word] = len(word_lexicon)
        return cls(word_lexicon, char_lexicon)

    @classmethod
    def from_file(cls, path):
        """Load a tokenizer previously written by ``save(path)``."""
        with open(f"{path}/tokenizer.json", encoding="utf-8") as f:
            d = json.load(f)
        return cls(d["word2id"], d["ch2id"])

    def tokenize(self, text, max_length=512, max_char=50):
        """Convert ``text`` to padded word-id and char-id tensors.

        Returns:
            (w, c): a LongTensor of shape ``(max_length,)`` holding word ids,
            and a LongTensor of shape ``(max_length, max_char)`` holding char
            ids. Words past ``max_length`` and chars past ``max_char`` are
            truncated. Unknown entries map to ``<oov>``; unused slots hold
            ``<pad>``.
        """
        words = text.lower().split()[:max_length]
        word_oov, word_pad = self.word2id.get(self.OOV_TOKEN), self.word2id.get(self.PAD_TOKEN)
        char_oov, char_pad = self.ch2id.get(self.OOV_TOKEN), self.ch2id.get(self.PAD_TOKEN)
        # Fill plain Python lists and convert once; per-element tensor writes
        # are much slower, and this runs for every item the DataLoader fetches.
        word_ids = [word_pad] * max_length
        char_ids = [[char_pad] * max_char for _ in range(max_length)]
        for i, word in enumerate(words):
            word_ids[i] = self.word2id.get(word, word_oov)
            row = char_ids[i]
            for j, ch in enumerate(word[:max_char]):
                row[j] = self.ch2id.get(ch, char_oov)
        w = torch.tensor(word_ids, dtype=torch.long)
        c = torch.tensor(char_ids, dtype=torch.long).view(max_length, max_char)
        return w, c

    def save(self, path):
        """Write both vocabularies to ``{path}/tokenizer.json``, creating ``path`` if needed."""
        # makedirs handles nested paths. It only tolerates an existing
        # directory, instead of swallowing every error like the old bare except.
        os.makedirs(path, exist_ok=True)
        tok = {
            "word2id": self.word2id,
            "ch2id": self.ch2id,
        }
        with open(f"{path}/tokenizer.json", "w", encoding="utf-8") as f:
            json.dump(tok, f, indent=4)

In [None]:
tokenizer = Tokenizer.from_corpus(corpus, min_count=config["elmo"]["min_count"])

In [None]:
class ELMoDataSet(Dataset):
    """Map-style dataset yielding ``(word_ids, char_ids)`` for each sentence.

    Args:
        corpus: indexable sequence of sentence strings.
        tokenizer: object exposing ``tokenize(text, max_length=..., max_char=...)``.
        max_length: words per example. Defaults to ``config["elmo"]["max_length"]``.
        max_char: chars per word. Defaults to ``config["elmo"]["max_char_token"]``.
    """

    def __init__(self, corpus, tokenizer, max_length=None, max_char=None):
        self.corpus = corpus
        self.tokenizer = tokenizer
        # Read the notebook-level config once, and only when explicit sizes
        # were not given, so the class also works without that global.
        if max_length is None:
            max_length = config["elmo"]["max_length"]
        if max_char is None:
            max_char = config["elmo"]["max_char_token"]
        self.max_length = max_length
        self.max_char = max_char

    def __getitem__(self, idx):
        w, c = self.tokenizer.tokenize(self.corpus[idx], max_length=self.max_length, max_char=self.max_char)
        return w, c

    def __len__(self):
        return len(self.corpus)

In [None]:
data = ELMoDataSet(corpus=corpus, tokenizer=tokenizer)

In [None]:
# Shuffle so each epoch visits the sentences in a fresh order.
data_loader = DataLoader(data, batch_size=config["batch_size"], shuffle=True)

In [None]:
# Based on https://gist.github.com/Redchards/65f1a6f758a1a5c5efb56f83933c3f6e
# Original paper: "Highway Networks" (Srivastava et al., 2015), https://arxiv.org/abs/1505.00387
class HighWay(nn.Module):
    """Stack of highway layers.

    Each layer computes ``y = g * x + (1 - g) * act(h)``, where ``[h | g_pre]
    = Linear(x)`` and ``g = sigmoid(g_pre)``.

    Args:
        input_dim: size of the last input dimension (preserved in the output).
        num_layers: number of stacked highway layers.
        activation: callable applied to the transform branch ``h``.
    """

    def __init__(self, input_dim, num_layers=1, activation=nn.functional.relu):
        super(HighWay, self).__init__()
        self._input_dim = input_dim
        # One Linear per layer yields both halves [transform | gate] in one matmul.
        self._layers = nn.ModuleList([nn.Linear(input_dim, input_dim * 2) for _ in range(num_layers)])
        self._activation = activation
        # Start the gate bias at 1. Then sigmoid(gate) > 0.5, so layers
        # initially favour carrying the input through unchanged.
        with torch.no_grad():
            for layer in self._layers:
                layer.bias[input_dim:].fill_(1)

    def forward(self, inputs):
        """Apply all layers to ``inputs`` of shape ``(..., input_dim)``.

        Any number of leading dimensions is supported.
        """
        dim = self._input_dim
        current_input = inputs
        for layer in self._layers:
            projected_input = layer(current_input)
            # Split on the last axis so any leading (batch/seq) dims work.
            nonlinear_part = self._activation(projected_input[..., :dim])
            gate = torch.sigmoid(projected_input[..., dim:])
            current_input = gate * current_input + (1 - gate) * nonlinear_part
        return current_input

In [None]:
class ELMo(nn.Module):
    """Character-aware bidirectional LSTM language model (ELMo-style).

    Each token is represented by its word embedding concatenated with a
    char-CNN + highway encoding, then projected to ``output_dim``. Separate
    forward and backward LSTM stacks run over the projected sequence.
    ``fwl`` / ``bwl`` map ``units``-sized LSTM states to word-vocabulary-sized
    logits.
    """

    _ACTIVATIONS = {"relu": nn.ReLU, "tanh": nn.Tanh}

    def __init__(self, tokenizer, config):
        super(ELMo, self).__init__()
        self.config = config
        self.tokenizer = tokenizer
        elmo_cfg = config["elmo"]
        self.word_embedder = nn.Embedding(len(tokenizer.word2id), elmo_cfg["word_dim"],
                                          padding_idx=tokenizer.word2id.get("<pad>"))
        self.char_embedder = nn.Embedding(len(tokenizer.ch2id), elmo_cfg["char_dim"],
                                          padding_idx=tokenizer.ch2id.get("<pad>"))
        self.output_dim = elmo_cfg["output_dim"]
        activation = elmo_cfg["activation"]
        # Fail fast with a clear message; an unknown name used to leave
        # self.act undefined and crash later with an AttributeError.
        if activation not in self._ACTIVATIONS:
            raise ValueError(
                f"Unsupported activation {activation!r}; expected one of {sorted(self._ACTIVATIONS)}")
        self.act = self._ACTIVATIONS[activation]()
        # Char CNN: one Conv1d per (kernel width, number of filters) pair.
        filters = elmo_cfg["filters"]
        self.convolutions = nn.ModuleList(
            nn.Conv1d(in_channels=elmo_cfg["char_dim"], out_channels=num, kernel_size=width, bias=True)
            for width, num in filters
        )
        self.n_filters = sum(num for _, num in filters)
        self.n_highway = elmo_cfg["n_highway"]
        self.highways = HighWay(self.n_filters, self.n_highway, activation=self.act)
        # Token representation = word embedding ++ char-CNN features.
        self.emb_dim = elmo_cfg["word_dim"] + self.n_filters
        self.projection = nn.Linear(self.emb_dim, self.output_dim, bias=True)
        units = elmo_cfg["units"]
        self.f = nn.ModuleList()
        self.b = nn.ModuleList()
        # At least one layer per direction is always built, so fs[-1]/bs[-1]
        # are always LSTM outputs of size `units`.
        for layer_idx in range(max(1, elmo_cfg["n_layers"])):
            in_size = self.output_dim if layer_idx == 0 else units
            self.f.append(nn.LSTM(input_size=in_size, hidden_size=units, batch_first=True))
            self.b.append(nn.LSTM(input_size=in_size, hidden_size=units, batch_first=True))
        self.fwl = nn.Linear(in_features=units, out_features=len(tokenizer.word2id))
        self.bwl = nn.Linear(in_features=units, out_features=len(tokenizer.word2id))

    def forward(self, word_inp, chars_inp):
        """Run the bidirectional LM encoder.

        Args:
            word_inp: LongTensor ``(batch, seq_len)`` of word ids.
            chars_inp: LongTensor of char ids with ``batch * seq_len`` rows once
                flattened, presumably ``(batch, seq_len, max_char)`` as produced
                by ``Tokenizer.tokenize`` + DataLoader batching.

        Returns:
            ``(fs, bs)``: two lists with one entry more than the number of
            LSTM layers.
            - ``fs[0]`` is the projected embedding of positions ``[0, seq_len-1)``.
            - ``bs[0]`` is the projected embedding of positions ``[1, seq_len)``.
            - ``fs[k]`` / ``bs[k]`` are the k-th forward / backward LSTM outputs.
            ``bs`` entries are flipped back into left-to-right order.
        """
        batch_size, seq_len = word_inp.size(0), word_inp.size(1)
        word_emb = self.word_embedder(word_inp)
        # Encode each word's characters independently. reshape (not view)
        # tolerates non-contiguous input. transpose gives Conv1d its
        # (N, channels, length) layout.
        char_emb = self.char_embedder(chars_inp.reshape(batch_size * seq_len, -1)).transpose(1, 2)
        # Max-over-time pooling per filter bank, then the activation.
        convs = [self.act(torch.max(conv(char_emb), dim=-1).values) for conv in self.convolutions]
        char_emb = self.highways(torch.cat(convs, dim=-1))
        token_embedding = torch.cat([word_emb, char_emb.view(batch_size, -1, self.n_filters)], dim=2)
        embeddings = self.projection(token_embedding)
        fs = [embeddings[:, :-1, :]]
        bs = [embeddings[:, 1:, :]]
        for fl, bl in zip(self.f, self.b):
            o_f, _ = fl(fs[-1])
            fs.append(o_f)
            # Backward LM: run the LSTM over the reversed sequence, then flip
            # back so indices line up with the forward stream.
            o_b, _ = bl(torch.flip(bs[-1], dims=(1,)))
            bs.append(torch.flip(o_b, dims=(1,)))
        return fs, bs

    def save_model(self, path):
        """Save weights, config and tokenizer under directory ``path`` (created if needed)."""
        os.makedirs(path, exist_ok=True)
        torch.save(self.state_dict(), f"{path}/model.pt")
        with open(f"{path}/config.json", "w", encoding="utf-8") as f:
            json.dump(self.config, f, indent=4)
        self.tokenizer.save(path)

    @classmethod
    def from_checkpoint(cls, path, map_location="cpu"):
        """Rebuild a model saved with ``save_model(path)``.

        ``map_location`` is forwarded to ``torch.load``. The default "cpu" lets
        GPU-trained checkpoints load on CPU-only machines. The freshly built
        model lives on CPU, so its parameters end up there either way.
        """
        with open(f"{path}/config.json", encoding="utf-8") as f:
            config = json.load(f)
        tokenizer = Tokenizer.from_file(path)
        model = cls(tokenizer, config)
        model.load_state_dict(torch.load(f"{path}/model.pt", map_location=map_location))
        return model

In [None]:
model = ELMo(tokenizer=tokenizer, config=config)

In [None]:
model.to(device=device)

In [None]:
vocab_size = len(tokenizer.word2id)
pad_id = tokenizer.word2id["<pad>"]
opt = optim.Adam(model.parameters(), lr=config["lr"])
model.train()
for epoch in range(config["epochs"]):
    print(f"Epoch: {epoch+1}")
    progress = tqdm(data_loader)
    for w, c in progress:
        w = w.to(device)
        c = c.to(device)
        # Forward LM predicts the next token; backward LM predicts the previous one.
        fwd_targets = w[:, 1:].reshape(-1)
        bwd_targets = w[:, :-1].reshape(-1)
        # Skip a batch with no real (non-pad) target in either direction:
        # cross_entropy over only ignored targets returns NaN.
        if not ((fwd_targets != pad_id).any() and (bwd_targets != pad_id).any()):
            continue
        f, b = model(w, c)
        fwd_logits = model.fwl(f[-1]).reshape(-1, vocab_size)
        bwd_logits = model.bwl(b[-1]).reshape(-1, vocab_size)
        # ignore_index keeps <pad> positions out of the loss, so the model is
        # not trained to predict padding.
        loss = (
            cross_entropy(fwd_logits, fwd_targets, ignore_index=pad_id)
            + cross_entropy(bwd_logits, bwd_targets, ignore_index=pad_id)
        ) / 2
        opt.zero_grad()
        loss.backward()
        opt.step()
        progress.set_postfix(loss=f"{loss.item():.4f}")
    model.save_model(model_path)