In [None]:
import torch
import torch.nn as nn
from tqdm import tqdm

In [None]:
class BasicClassifierCell(nn.Module):
    """Single-layer pair classifier: dropout, one linear layer, sigmoid.

    Args:
        input_size: feature size of each of the two input vectors.
        hidden_size: accepted so the constructor can be called like
            ``ClassifierCell``; it is not used here.
        dropout: dropout probability applied to the concatenated pair.
    """

    def __init__(self, input_size, hidden_size=None, dropout=0.):
        super().__init__()
        self.dout = nn.Dropout(p=dropout)
        # The two vectors are concatenated, hence twice the input width.
        self.l = nn.Linear(in_features=2 * input_size, out_features=1)
        self.sig = nn.Sigmoid()

    def forward(self, inputs):
        """Score the pair ``inputs = (vec1, vec2)``.

        Returns:
            Sigmoid output whose last dimension has size 1.
        """
        left, right = inputs
        pair = torch.cat((left, right), dim=-1)
        logits = self.l(self.dout(pair))
        return self.sig(logits)

class ClassifierCell(nn.Module):
    """Two-layer pair classifier with a projection shared by both inputs.

    Each vector goes through the same ``l1`` + ReLU projection; the two
    projections are concatenated, passed through dropout, and mapped to a
    single sigmoid output by ``l2``.

    Args:
        input_size: feature size of each of the two input vectors.
        hidden_size: width of the shared projection.
        dropout: dropout probability applied to the concatenated projections.
    """

    def __init__(self, input_size, hidden_size=128, dropout=0.):
        super().__init__()

        self.l1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.relu1 = nn.ReLU()

        self.dout = nn.Dropout(p=dropout)

        # Input is the concatenation of two projected vectors.
        self.l2 = nn.Linear(in_features=2 * hidden_size, out_features=1)
        self.sig = nn.Sigmoid()

    def forward(self, inputs):
        """Score the pair ``inputs = (vec1, vec2)``.

        Returns:
            Sigmoid output whose last dimension has size 1.
        """
        left, right = inputs

        # Same weights for both sides of the pair.
        projected = tuple(self.relu1(self.l1(vec)) for vec in (left, right))
        merged = self.dout(torch.cat(projected, dim=-1))

        return self.sig(self.l2(merged))
    

class Classifier(nn.Module):
    """Scores a pair of sequences with a shared LSTM encoder and a pair head.

    Both sequences are encoded by the same bidirectional LSTM into one vector
    per batch element; the two vectors are then passed to ``clssfr`` as a tuple.

    Args:
        word_emb_size: feature size of every time step fed to the LSTM.
        lstm_hidden_size: LSTM hidden size per direction; also the
            ``input_size`` given to the classifier head.
        lstm_num_layers: number of stacked LSTM layers.
        clssfr_cell_type: callable that builds the head; it is invoked as
            ``clssfr_cell_type(lstm_hidden_size, hidden_size=..., dropout=0.2)``.
        clssfr_hidden_size: forwarded to the head as ``hidden_size``.
    """

    def __init__(self, word_emb_size, lstm_hidden_size, lstm_num_layers, clssfr_cell_type, clssfr_hidden_size):
        super(Classifier, self).__init__()

        self.is_bidirectional = True
        self.lstm_hidden_size = lstm_hidden_size
        self.lstm = nn.LSTM(input_size=word_emb_size,
                            hidden_size=lstm_hidden_size,
                            num_layers=lstm_num_layers,
                            dropout=0.2, bidirectional=self.is_bidirectional)

        self.clssfr = clssfr_cell_type(lstm_hidden_size,
                                       hidden_size=clssfr_hidden_size,
                                       dropout=0.2)

    def encode(self, seq, h_0=None, c_0=None):
        """Encode one sequence into a single vector per batch element.

        Args:
            seq: tensor indexed as (time, batch, feature); the LSTM is built
                without ``batch_first``.
            h_0: optional initial hidden state for the LSTM.
            c_0: optional initial cell state for the LSTM.
                If only one of ``h_0`` / ``c_0`` is given, the other defaults
                to zeros of the same shape.

        Returns:
            Tensor of shape (batch, lstm_hidden_size).
        """
        if h_0 is None and c_0 is None:
            outs, _ = self.lstm(seq)
        else:
            # nn.LSTM wants both states or neither; fill in the missing one.
            if h_0 is None:
                h_0 = torch.zeros_like(c_0)
            if c_0 is None:
                c_0 = torch.zeros_like(h_0)
            outs, _ = self.lstm(seq, (h_0, c_0))

        if self.is_bidirectional:
            # The forward direction has seen the whole sequence at the last
            # time step, the backward direction at the first one.
            forward_out = outs[-1, :, :self.lstm_hidden_size]
            backward_out = outs[0, :, self.lstm_hidden_size:]
            return (forward_out + backward_out) / 2
        else:
            return outs[-1]

    def forward(self, inputs):
        """Score the pair ``inputs = (seq1, seq2)`` with the classifier head."""
        seq1, seq2 = inputs
        enc1, enc2 = self.encode(seq1), self.encode(seq2)
        return self.clssfr((enc1, enc2))

    def fit(self, inputs, true_outputs, epochs=10):
        """Run ``epochs`` Adam steps (lr=0.0001, BCE loss) on one example/batch.

        Args:
            inputs: ``(seq1, seq2)`` pair as accepted by ``forward``.
            true_outputs: targets in [0, 1] with the same shape as the model
                output; anything ``torch.as_tensor`` accepts.
            epochs: number of optimisation steps.

        Returns:
            ``(losses, preds)``: the loss of every step as a Python float and
            the detached prediction tensor of every step.
        """
        try:
            from tqdm import tqdm as progress
        except ImportError:  # the progress bar is optional
            def progress(iterable):
                return iterable

        # Match the targets to the model's dtype/device once, up front.
        reference = next(self.parameters())
        target = torch.as_tensor(true_outputs, dtype=reference.dtype, device=reference.device)

        self.train()
        loss_f = torch.nn.BCELoss()
        optim = torch.optim.Adam(self.parameters(), lr=0.0001)

        losses, preds = [], []

        for _ in progress(range(epochs)):
            pred = self(inputs)
            loss = loss_f(pred, target)

            optim.zero_grad()
            loss.backward()
            optim.step()

            # Keep plain values only, so the autograd graph of each step
            # can be freed.
            losses.append(loss.item())
            preds.append(pred.detach())

        return losses, preds

In [None]:
# Feature width of every time step fed to the LSTM
# (presumably the hidden size of the commented-out DistilBERT encoder - confirm).
d = 768

# Small bidirectional LSTM encoder with the single-layer pair head;
# BasicClassifierCell ignores its hidden_size, hence None.
cls = Classifier(
    word_emb_size=d,
    lstm_hidden_size=7,
    lstm_num_layers=3,
    clssfr_cell_type=BasicClassifierCell,
    clssfr_hidden_size=None,
)

In [None]:
# Smoke test: two identical inputs of one time step, batch of one and
# d features holding 0, 1, ..., d-1.
x = torch.arange(float(d)).unsqueeze(0).unsqueeze(0)
y = torch.arange(float(d)).unsqueeze(0).unsqueeze(0)

# Plain tensor: the torch.autograd.Variable wrapper is deprecated and
# returns an ordinary tensor anyway.
o = torch.tensor([0.])

cls((x, y))

In [None]:
# from transformers import DistilBertModel
import pickle

# with open("emails_token_ids.pkl", "rb") as handle:
#     ids = pickle.load(handle)[:20]


# bert = DistilBertModel.from_pretrained('distilbert-base-uncased')
# bert.eval()

In [None]:
def cut_up(mail_tensor, n=512):
    """Split a tensor along its first dimension into stackable chunks of ``n``.

    Args:
        mail_tensor: tensor to split.
        n: chunk length along the first dimension.

    Returns:
        ``(chunks, tail)``. ``chunks`` stacks all equally sized pieces.
        ``tail`` is the shorter last piece with a leading dimension of 1,
        or ``None`` when the last piece is as large as the first.
    """
    pieces = mail_tensor.split(n)
    head, tail = pieces[0], pieces[-1]

    # A ragged last piece cannot be stacked with the others; return it apart.
    if tail.nelement() != head.nelement():
        return torch.stack(pieces[:-1]), tail.unsqueeze(0)

    return torch.stack(pieces), None


# WRAPPER FOR MODEL CALL
def email_to_vec(email_body_ids, to_id_first=False, chunk_size=512):
    """Feed an email through ``bert`` chunk by chunk and return the outputs.

    Relies on a module-level ``bert`` model (its loading code in this
    notebook is commented out) and, when ``to_id_first`` is set, on an
    ``email_to_ids`` helper that is not defined in the visible cells.

    Args:
        email_body_ids: token-id tensor of one email, or the raw input of
            ``email_to_ids`` when ``to_id_first`` is true.
        to_id_first: convert the input with ``email_to_ids`` first.
        chunk_size: maximum number of ids passed to the model per chunk.

    Returns:
        2-D numpy array: the model's first output for every chunk, flattened
        to rows of the last dimension (presumably one row per token -
        confirm against the model's output layout).
    """
    input_ids = email_to_ids(email_body_ids) if to_id_first else email_body_ids

    chunks, end_chunk = cut_up(input_ids, chunk_size)

    # NOTE: nothing is moved between devices here; `chunks` and `bert` must
    # already live on the same device.
    # Index [0] rather than tuple-unpacking: it yields the first output both
    # for plain tuples and for dict-like model outputs, whose iteration
    # would produce keys instead of tensors.
    outputs = bert(chunks)[0]

    outputs_flattened = outputs.reshape(-1, outputs.shape[-1])

    if end_chunk is not None:
        # The shorter last chunk is run separately and appended.
        end_output = bert(end_chunk)[0]
        outputs_flattened = torch.cat((outputs_flattened, end_output.squeeze(0)), 0)

    # detach() is a no-op under torch.no_grad() and keeps .numpy() from
    # failing when the caller forgot to disable gradients.
    return outputs_flattened.detach().cpu().numpy()

In [None]:
# Embed every email in `ids` (expected from an earlier cell; its loading
# code near the top of the notebook is commented out). Gradients are not
# needed for feature extraction.
with torch.no_grad():
    first_twenty = list(map(email_to_vec, ids))

In [None]:
# Load the cached embeddings from disk; this overwrites any `first_twenty`
# computed in the cell above.
# NOTE(review): pickle.load can execute arbitrary code - only load a file
# that you produced yourself.
with open("first_twenty.pkl", "rb") as handle:
    first_twenty = pickle.load(handle)

In [None]:
# Embed the first two emails and show the shape of the first embedding.
with torch.no_grad():
    first, second = (email_to_vec(ids[k]) for k in (0, 1))
    print(first.shape)

In [None]:
# Insert a batch axis of size 1 at position 1, the (time, batch, feature)
# layout that the classifier's LSTM indexes.
first_tens, second_tens = (torch.tensor(vec).unsqueeze(1) for vec in (first, second))

print(first_tens.size(), second_tens.size(), sep="\n")

In [None]:
# Score the first two embedded emails with the (so far untrained) classifier.
# NOTE(review): no cls.eval() call appears in the visible cells, so the 0.2
# dropout configured in Classifier is presumably still active and repeated
# runs of this cell can give different values - confirm.
cls((first_tens, second_tens))

In [None]:
# Print the classifier score for every ordered pair of cached embeddings.
# The training cell further down reuses the `one_tens` / `two_tens` left
# behind by the last iteration, so those names must stay module-level.
with torch.no_grad():
    for i, one in enumerate(first_twenty):
        print(i)
        # Build the left-hand tensor once per outer iteration, not once per pair.
        one_tens = torch.tensor(one).unsqueeze(1)
        for two in first_twenty:
            two_tens = torch.tensor(two).unsqueeze(1)

            print("\t", cls((one_tens, two_tens)))
            print()

In [None]:
# Binary cross-entropy on the sigmoid output, optimised with Adam.
loss_f = nn.BCELoss()
optim = torch.optim.Adam(params=cls.parameters(), lr=1e-4)

In [None]:
# Target for the single pair left in `one_tens` / `two_tens` by the
# pairwise-scoring loop above (hidden notebook state).
y = torch.tensor([[0.]])

epochs = 1000

losses, preds = [], []

for i in tqdm(range(epochs)):
    pred = cls((one_tens, two_tens))
    loss = loss_f(pred, y)

    # Record plain floats: tensors that still require grad keep every step's
    # autograd graph alive and cannot be converted to numpy for plotting.
    # pred holds a single value here because it must match y's (1, 1) shape.
    preds.append(pred.item())
    losses.append(loss.item())

    optim.zero_grad()
    loss.backward()
    optim.step()

In [None]:
import matplotlib.pyplot as plt
from tqdm import tqdm

# float() accepts plain numbers as well as one-element tensors, including
# ones that require grad, which matplotlib cannot convert by itself.
plt.plot(range(epochs), [float(loss_value) for loss_value in losses])
plt.xlabel("epoch")
plt.ylabel("BCE loss")
plt.show()

In [None]:
# float() accepts plain numbers as well as one-element tensors, including
# ones that require grad, which matplotlib cannot convert by itself.
plt.plot(range(epochs), [float(pred_value) for pred_value in preds])
plt.xlabel("epoch")
plt.ylabel("predicted probability")
plt.show()