In [None]:
!pip install transformers

In [None]:
!pip install torchdata


In [None]:
# Mount Google Drive (Colab only) so the dataset and checkpoint files are reachable.
# NOTE(review): later cells use relative './drive/MyDrive/...' paths, which presumably
# relies on the working directory being /content — confirm.
from google.colab import drive
drive.mount('/content/drive')

In [4]:
import re
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, f1_score

import torch
import transformers
import torch.nn as nn
from transformers import AutoModel, BertTokenizer, BertForSequenceClassification
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

# Enable tqdm's pandas integration (progress bars for pandas operations).
tqdm.pandas()

# Device for the discriminator. Fall back to the CPU when CUDA is unavailable so that
# later `.to(device)` calls do not fail on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Discriminator: pretrained BERT encoder and its matching tokenizer.

bert = AutoModel.from_pretrained("bert-base-uncased")

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

In [6]:
import os
import io

import torchdata.datapipes as dp
import json
from random import randrange

def snli_train_all():
    """Build a datapipe over the SNLI train file.

    Each line of the jsonl file is parsed as JSON and reduced to a
    ``(sentence1, sentence2)`` tuple.
    """
    raw_lines = dp.iter.FileOpener(['./drive/MyDrive/snli_1.0/snli_1.0_train.jsonl']).readlines(
        decode=True, return_path=False, strip_newline=True)
    records = raw_lines.map(lambda text: json.loads(text.strip()))
    sentence_pairs = records.map(lambda record: (record['sentence1'], record['sentence2']))
    return sentence_pairs

In [7]:
# Mark every BERT weight as trainable, i.e. the whole encoder is fine-tuned
# (no layers are frozen).
for param in bert.parameters():
    param.requires_grad = True

class BERT_Arch(nn.Module):
    """Two-class classifier head on top of a BERT encoder.

    The second element returned by the wrapped encoder is passed through
    Linear(768, 512) -> ReLU -> Dropout(0.1) -> Linear(512, 2) -> LogSoftmax,
    so ``forward`` returns log-probabilities over two classes.
    """

    def __init__(self, bert):
        super().__init__()
        self.bert = bert
        self.dropout = nn.Dropout(0.1)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(768, 512)
        self.fc2 = nn.Linear(512, 2)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, sent_id, mask):
        """Return per-class log-probabilities for a batch of token ids and attention masks."""
        # With return_dict=False the encoder returns a tuple; only its second item is used.
        _unused, pooled = self.bert(sent_id, attention_mask=mask, return_dict=False)
        hidden = self.dropout(self.relu(self.fc1(pooled)))
        return self.softmax(self.fc2(hidden))

In [None]:
# Wrap the pretrained encoder in the classifier head and move it to the discriminator device.
model = BERT_Arch(bert)

model = model.to(device)
# `transformers.AdamW` is deprecated and no longer shipped by recent transformers
# releases; torch's AdamW is the documented replacement. eps and weight_decay are
# pinned to the old transformers defaults so the optimizer behaves the same.
from torch.optim import AdamW

# NOTE(review): lr=1e-3 is far above the range usually used to fine-tune BERT — confirm it is intended.
desc_optimizer = AdamW(model.parameters(), lr= 1e-3, eps=1e-6, weight_decay=0.0)

In [10]:
def desc_train(train_dataloader):
    """Run one training pass of the discriminator over ``train_dataloader``.

    Each batch must unpack into ``(sent_id, mask, labels)`` tensors. The function
    uses the module-level ``model``, ``desc_optimizer`` and ``device`` and returns
    nothing.
    """
    model.train()
    # NOTE(review): the model already returns log-probabilities (LogSoftmax) and
    # CrossEntropyLoss applies log-softmax again; that is numerically a no-op on
    # log-probabilities, but nn.NLLLoss would state the intent more directly.
    cross_entropy = nn.CrossEntropyLoss()
    for batch in tqdm(train_dataloader, total = len(train_dataloader)):
        sent_id, mask, labels = [r.to(device) for r in batch]

        model.zero_grad()
        preds = model(sent_id, mask)
        loss = cross_entropy(preds, labels)
        loss.backward()

        # Clip the global gradient norm to 1.0 before the optimizer update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        desc_optimizer.step()
        

In [None]:
# GENERATOR
!pip install -U spacy
!python -m spacy download en_core_web_sm

In [12]:
import torchdata.datapipes as dp
import json


class SnliLoader:
    """Builds torchdata pipelines of ``(sentence1, sentence2)`` pairs from SNLI jsonl files."""

    # File locations, kept in one place instead of being repeated in every method.
    TRAIN_PATH = './drive/MyDrive/snli_1.0/snli_1.0_train.jsonl'
    DEV_PATH = './drive/MyDrive/snli_1.0/snli_1.0_dev.jsonl'

    def _records(self, path):
        """Datapipe of parsed JSON objects, one per line of ``path``."""
        return dp.iter.FileOpener([path]) \
            .readlines(decode=True, return_path=False, strip_newline=True) \
            .map(lambda line: json.loads(line.strip()))

    def _pairs(self, records):
        """Reduce parsed records to ``(sentence1, sentence2)`` tuples."""
        return records.map(lambda line: (line['sentence1'], line['sentence2']))

    def snli_train_all(self):
        """All sentence pairs of the train file, regardless of label."""
        return self._pairs(self._records(self.TRAIN_PATH))

    def snli_valid_option(self, option):
        """Sentence pairs of the dev file whose ``gold_label`` equals ``option``."""
        matching = self._records(self.DEV_PATH) \
            .filter(lambda line: line['gold_label'] == option)
        return self._pairs(matching)

    def snli_train_option(self, option):
        """Sentence pairs of the train file whose ``gold_label`` equals ``option``."""
        matching = self._records(self.TRAIN_PATH) \
            .filter(lambda line: line['gold_label'] == option)
        return self._pairs(matching)

In [13]:
# Vocabulary settings
gen_tokenizer = 'spacy'

# Special tokens; each *_IDX constant is the position of its token in this list.
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = range(len(special_symbols))


In [None]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator, Vocab
from typing import Iterable, List


class Vocabulary(object):
    """Tokenizer plus a torchtext vocabulary built from the SNLI training pairs."""

    # Shared by all instances; created once when the class body runs.
    token_transform = get_tokenizer(gen_tokenizer)
    special_symbols = special_symbols
    vocab_transform: Vocab

    def __init__(self):
        # Creating an instance builds the vocabulary immediately.
        self.build()

    def yield_tokens(self, data_iter: Iterable) -> List[str]:
        """Yield one token list per sample: elements 0 and 1 joined, then tokenized."""
        # NOTE(review): the two sentences are concatenated without a separator, so
        # the end of the first can fuse with the start of the second before
        # tokenization. Not changed here: altering it changes the vocabulary size
        # that the generator models (and any saved checkpoints) are built around.
        tokenize = self.token_transform
        for sample in data_iter:
            joined = sample[0] + sample[1]
            yield tokenize(joined)

    def build(self) -> Vocab:
        """(Re)build the vocabulary from the full SNLI train set, store it and return it."""
        train_pairs = SnliLoader().snli_train_all()
        vocab = build_vocab_from_iterator(
            self.yield_tokens(train_pairs),
            min_freq=1,
            specials=self.special_symbols,
            special_first=True,
        )
        # Unknown tokens map to the <unk> index.
        vocab.set_default_index(UNK_IDX)
        self.vocab_transform = vocab
        return vocab

    def get_vocab(self) -> Vocab:
        """Return the vocabulary produced by the last ``build`` call."""
        return self.vocab_transform

    def get_tokenizer(self):
        """Return the tokenizer callable."""
        return self.token_transform

In [15]:
# Transformer hyper-parameters (read by Seq2SeqTransformer).

EMB_SIZE = 512  # embedding width; also passed as d_model
NHEAD = 8  # number of attention heads
FFN_HID_DIM = 512  # passed as dim_feedforward
NUM_ENCODER_LAYERS = 4
NUM_DECODER_LAYERS = 4

In [16]:
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
import math
# Device used by the generator-side code; falls back to the CPU when CUDA is unavailable.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class PositionalEncoding(nn.Module):
    """Adds a fixed sinusoidal position table to token embeddings, then applies dropout.

    The table is registered as the ``pos_embedding`` buffer with shape
    ``(maxlen, 1, emb_size)``; ``forward`` slices it by the size of the input's
    first dimension.
    """

    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super().__init__()
        even_dims = torch.arange(0, emb_size, 2)
        inv_freq = torch.exp(-even_dims * math.log(10000) / emb_size)
        positions = torch.arange(0, maxlen).reshape(maxlen, 1)
        angles = positions * inv_freq

        # Even columns hold sines, odd columns hold cosines.
        table = torch.zeros((maxlen, emb_size))
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)

        self.dropout = nn.Dropout(dropout)
        # The extra middle axis lets the table broadcast over the input's second dimension.
        self.register_buffer('pos_embedding', table.unsqueeze(-2))

    def forward(self, token_embedding: Tensor):
        """Return ``dropout(token_embedding + position table)``."""
        seq_len = token_embedding.size(0)
        positioned = token_embedding + self.pos_embedding[:seq_len, :]
        return self.dropout(positioned)

class WordEmbedding(nn.Module):
    """Embedding lookup whose output is multiplied by ``sqrt(emb_size)``."""

    def __init__(self, vocab_size: int, emb_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        """Look up ``tokens`` (cast to long) and scale the vectors."""
        scale = math.sqrt(self.emb_size)
        looked_up = self.embedding(tokens.long())
        return looked_up * scale

class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder transformer with token embeddings, positional encoding and a vocabulary projection.

    Layer sizes come from the module-level EMB_SIZE, NHEAD, FFN_HID_DIM,
    NUM_ENCODER_LAYERS and NUM_DECODER_LAYERS constants. Source and target share
    the same vocabulary size but have separate embedding tables.
    """

    def __init__(self, vocab_size: int, dropout: float = 0.1):
        super().__init__()
        # Keep this creation order: it determines the order of model.parameters().
        self.transformer = Transformer(
            d_model=EMB_SIZE,
            nhead=NHEAD,
            num_encoder_layers=NUM_ENCODER_LAYERS,
            num_decoder_layers=NUM_DECODER_LAYERS,
            dim_feedforward=FFN_HID_DIM,
            dropout=dropout,
        )
        self.generator = nn.Linear(EMB_SIZE, vocab_size)
        self.src_tok_emb = WordEmbedding(vocab_size, EMB_SIZE)
        self.tgt_tok_emb = WordEmbedding(vocab_size, EMB_SIZE)
        self.positional_encoding = PositionalEncoding(EMB_SIZE, dropout=dropout)

    def _embed_source(self, src: Tensor):
        """Source token embeddings with positional encoding applied."""
        return self.positional_encoding(self.src_tok_emb(src))

    def _embed_target(self, tgt: Tensor):
        """Target token embeddings with positional encoding applied."""
        return self.positional_encoding(self.tgt_tok_emb(tgt))

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        """Run the full encoder-decoder and project the result to vocabulary logits."""
        src_emb = self._embed_source(src)
        tgt_emb = self._embed_target(trg)
        decoded = self.transformer(
            src_emb,
            tgt_emb,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            memory_mask=None,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.generator(decoded)

    def encode(self, src: Tensor, src_mask: Tensor):
        """Encoder-only pass over ``src``."""
        embedded = self._embed_source(src)
        return self.transformer.encoder(embedded, src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        """Decoder-only pass over ``tgt`` given the encoder output ``memory``."""
        embedded = self._embed_target(tgt)
        return self.transformer.decoder(embedded, memory, tgt_mask)

In [17]:
#properties
import torch

# Same expression as the DEVICE assignment in the transformer cell; re-running it yields the same device.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [18]:
import torch


class MasksBuilder:
    """Creates attention and padding masks for token tensors whose first axis is the sequence."""

    def create_mask(self, src, tgt):
        """Return ``(src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)``.

        ``src_mask`` is an all-False square boolean mask, ``tgt_mask`` is the causal
        float mask, and the padding masks flag PAD_IDX positions with the first two
        axes swapped.
        """
        src_len, tgt_len = src.shape[0], tgt.shape[0]

        tgt_mask = self.generate_square_subsequent_mask(tgt_len)
        src_mask = torch.zeros((src_len, src_len), dtype=torch.bool, device=DEVICE)

        src_padding_mask = (src == PAD_IDX).transpose(0, 1)
        tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
        return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

    def generate_square_subsequent_mask(self, sz):
        """Float ``(sz, sz)`` mask: 0.0 on and below the diagonal, -inf strictly above it."""
        blocked = torch.full((sz, sz), float('-inf'), device=DEVICE)
        return torch.triu(blocked, diagonal=1)

In [19]:
import torch
from torch import nn
from torch.utils.data import DataLoader


class SentenceGenerator:
    """Seq2seq transformer trained on the SNLI pairs of a single ``gold_label``.

    ``option`` is the label used to filter the train/dev files. The instance owns
    the model, its Adam optimizer, and the helpers needed to train, evaluate,
    greedily decode and apply externally supplied gradients.
    """

    def __init__(self, option, sentence_preprocessor, loss_fn, vocab, batch_size):
        self.BATCH_SIZE = batch_size
        self.sentence_preprocessor = sentence_preprocessor
        self.option = option
        self.vocab = vocab
        self.model = Seq2SeqTransformer(len(vocab.vocab_transform))
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
        self.loss_fn = loss_fn
        self.snli = SnliLoader()
        self.mask_builder = MasksBuilder()

        # Xavier-initialise every weight matrix (parameters with more than one dimension).
        for p in self.model.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        self.model = self.model.to(DEVICE)

    def _teacher_forced_logits(self, src, tgt):
        """Run the model on ``src`` and ``tgt[:-1]`` (teacher forcing).

        Returns ``(logits, tgt_input, tgt_out)`` where ``tgt_input`` is ``tgt``
        without its last row and ``tgt_out`` is ``tgt`` without its first row.
        """
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)
        tgt_input = tgt[:-1, :]
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = self.mask_builder.create_mask(src, tgt_input)
        # The source padding mask doubles as the memory key padding mask.
        logits = self.model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask,
                            src_padding_mask)
        return logits, tgt_input, tgt[1:, :]

    def _batch_loss(self, src, tgt):
        """Loss of predicting ``tgt[1:]`` from ``tgt[:-1]`` for one batch."""
        logits, _, tgt_out = self._teacher_forced_logits(src, tgt)
        return self.loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))

    def train_epoch(self):
        """Train for one pass over the train pairs of ``self.option``; return the mean batch loss.

        Raises ZeroDivisionError when the pipeline yields no batch.
        """
        self.model.train()
        losses = 0
        train_iteration = self.snli.snli_train_option(self.option)
        train_dataloader = DataLoader(train_iteration, batch_size=self.BATCH_SIZE,
                                      collate_fn=self.sentence_preprocessor.collate_fn)
        counter = 0
        for src, tgt in train_dataloader:
            loss = self._batch_loss(src, tgt)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            losses += loss.item()
            counter += 1
        return losses / counter

    def evaluate(self):
        """Return the mean batch loss over the dev pairs of ``self.option``.

        Raises ZeroDivisionError when the pipeline yields no batch.
        """
        self.model.eval()
        losses = 0

        val_iter = self.snli.snli_valid_option(self.option)
        val_dataloader = DataLoader(val_iter, batch_size=self.BATCH_SIZE,
                                    collate_fn=self.sentence_preprocessor.collate_fn)

        counter = 0
        # No parameter update happens here, so skip building autograd graphs.
        with torch.no_grad():
            for src, tgt in val_dataloader:
                losses += self._batch_loss(src, tgt).item()
                counter += 1

        return losses / counter

    def greedy_decode(self, src, src_mask, max_len, start_symbol):
        """Greedily decode one source sequence.

        Starts from ``start_symbol`` and appends the arg-max token until EOS_IDX is
        produced or ``max_len`` tokens exist. Returns a ``(length, 1)`` long tensor.
        """
        src = src.to(DEVICE)
        src_mask = src_mask.to(DEVICE)

        # Decoding only picks arg-max tokens, so gradients are never needed.
        with torch.no_grad():
            memory = self.model.encode(src, src_mask).to(DEVICE)
            ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)
            for _ in range(max_len - 1):
                tgt_mask = (self.mask_builder.generate_square_subsequent_mask(ys.size(0))
                            .type(torch.bool)).to(DEVICE)
                out = self.model.decode(ys, memory, tgt_mask)
                out = out.transpose(0, 1)
                prob = self.model.generator(out[:, -1])
                _, next_word = torch.max(prob, dim=1)
                next_word = next_word.item()

                ys = torch.cat([ys,
                                torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
                if next_word == EOS_IDX:
                    break
        return ys

    def translate(self, src_sentence: str):
        """Generate a sentence for ``src_sentence`` and return it as a space-joined string.

        The literal ``<bos>`` and ``<eos>`` markers are removed from the text.
        """
        self.model.eval()
        src = self.sentence_preprocessor.text_transform(src_sentence).view(-1, 1)
        num_tokens = src.shape[0]
        src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
        tgt_tokens = self.greedy_decode(
            src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()
        return " ".join(self.vocab.vocab_transform.lookup_tokens(list(tgt_tokens.cpu().numpy()))) \
            .replace("<bos>", "") \
            .replace("<eos>", "")

    def backward_model(self, grads, sentence_batch):
        """Apply externally computed per-token gradients to the model and take one optimizer step.

        ``sentence_batch`` is a list of ``(source, target)`` string pairs; ``grads``
        is indexed as ``grads[batch, token]`` and is spread over the logits by
        ``match_grads`` before being back-propagated.
        """
        src, tgt = self.sentence_preprocessor.collate_fn(sentence_batch)
        gen_ret, tgt_input, _ = self._teacher_forced_logits(src, tgt)
        new_grads = self.match_grads(grads, tgt_input).to(DEVICE)
        self.optimizer.zero_grad()
        gen_ret.backward(new_grads)
        self.optimizer.step()

    def match_grads(self, grads, token_batch):
        """Scatter ``grads[batch, token]`` onto a zero tensor shaped like the logits.

        Returns a ``(seq_len, batch, vocab_size)`` tensor that is zero everywhere
        except at ``[t, b, token_batch[t, b]]``, which receives ``grads[b, t]``.
        Columns of ``grads`` beyond ``seq_len`` are ignored.
        """
        pool_size = len(self.vocab.vocab_transform)
        seq_len, batch_size = token_batch.size()[0], token_batch.size()[1]
        new_grads = torch.zeros(seq_len, batch_size, pool_size)
        usable_steps = min(grads.size()[1], seq_len)
        for batch_ind in range(grads.size()[0]):
            for token_ind in range(usable_steps):
                target_pos = token_batch[token_ind, batch_ind]
                # Write the single non-zero entry directly instead of building a
                # vocabulary-sized temporary for every token.
                new_grads[token_ind, batch_ind, target_pos] = grads[batch_ind, token_ind]
        return new_grads


In [20]:
import torch


class ModelLoader:
    """Saves and restores a generator's model and optimizer state.

    Checkpoints are stored at ``saves_base_path + option + ".pt"``; the base path
    is used as a plain string prefix, so it should end with a path separator.
    """

    def __init__(self, saves_base_path="./.saved/spacy/test/"):
        self.saves_base_path = saves_base_path

    def _checkpoint_path(self, sentence_generator):
        """File path of the checkpoint belonging to ``sentence_generator``."""
        return self.saves_base_path + sentence_generator.option + ".pt"

    def load_model(self, sentence_generator: "SentenceGenerator"):
        """Load model and optimizer state into ``sentence_generator`` in place."""
        # Map tensors onto the device the model currently lives on, so a checkpoint
        # written on a GPU can still be loaded on a CPU-only machine.
        first_param = next(sentence_generator.model.parameters(), None)
        target_device = first_param.device if first_param is not None else None
        checkpoint = torch.load(self._checkpoint_path(sentence_generator), map_location=target_device)
        sentence_generator.model.load_state_dict(checkpoint['model_state_dict'])
        sentence_generator.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    def save_model(self, sentence_generator: "SentenceGenerator"):
        """Write model and optimizer state of ``sentence_generator`` to its checkpoint file."""
        path = self._checkpoint_path(sentence_generator)
        # torch.save does not create missing directories.
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save({
            'model_state_dict': sentence_generator.model.state_dict(),
            'optimizer_state_dict': sentence_generator.optimizer.state_dict()
        }, path)

In [21]:
from torch.nn.utils.rnn import pad_sequence


class SentencePreprocessor:
    """Turns raw sentences into BOS/EOS-wrapped id tensors and pads them into batches."""

    def __init__(self, vocab):
        self.vocab: Vocabulary = vocab
        # tokenizer -> vocabulary lookup -> tensor wrapped in BOS/EOS
        pipeline = (self.vocab.get_tokenizer(), self.vocab.get_vocab(), self.tensor_transform)
        self.text_transform = self.sequential_transforms(*pipeline)

    def sequential_transforms(self, *transforms):
        """Return a callable that applies ``transforms`` left to right."""
        def func(txt_input):
            value = txt_input
            for step in transforms:
                value = step(value)
            return value

        return func

    def tensor_transform(self, token_ids: List[int]):
        """Return a 1-D tensor ``[BOS_IDX, *token_ids, EOS_IDX]``."""
        bos = torch.tensor([BOS_IDX])
        body = torch.tensor(token_ids)
        eos = torch.tensor([EOS_IDX])
        return torch.cat((bos, body, eos))

    def collate_fn(self, batch):
        """Encode ``(source, target)`` string pairs and pad each side with PAD_IDX.

        Trailing newlines are stripped from every sentence before encoding.
        Returns ``(src_batch, tgt_batch)`` as produced by ``pad_sequence``.
        """
        encoded = [
            (self.text_transform(src_sample.rstrip("\n")),
             self.text_transform(tgt_sample.rstrip("\n")))
            for src_sample, tgt_sample in batch
        ]
        src_batch = pad_sequence([pair[0] for pair in encoded], padding_value=PAD_IDX)
        tgt_batch = pad_sequence([pair[1] for pair in encoded], padding_value=PAD_IDX)
        return src_batch, tgt_batch

In [22]:
import torch

from timeit import default_timer as timer


class Generator:
    """Pair of sentence generators, one for 'entailment' and one for 'contradiction'.

    Both sub-generators share one vocabulary, one preprocessor and one loss
    function (cross-entropy that ignores PAD_IDX).
    """

    def __init__(self, batch_size):
        torch.manual_seed(0)
        # Vocabulary() already builds the vocabulary in its constructor; read the
        # result instead of calling build() again, which would re-scan the train file.
        self.vt = Vocabulary()
        self.VOCAB_SIZE = len(self.vt.get_vocab())
        self.BATCH_SIZE = batch_size
        self.NUM_EPOCHS = 1
        self.sentence_preprocessor = SentencePreprocessor(self.vt)
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
        self.ENTAILMENT = 'entailment'
        self.CONTRADICTION = 'contradiction'
        self.model_loader = ModelLoader()
        self.subgenerators = {
            self.ENTAILMENT: SentenceGenerator(self.ENTAILMENT, self.sentence_preprocessor, self.loss_fn, self.vt, self.BATCH_SIZE),
            self.CONTRADICTION: SentenceGenerator(self.CONTRADICTION, self.sentence_preprocessor, self.loss_fn, self.vt, self.BATCH_SIZE)
        }

    def train_model(self):
        """Train both sub-generators for NUM_EPOCHS epochs, printing losses and timing."""
        for epoch in range(1, self.NUM_EPOCHS + 1):
            for option in [self.ENTAILMENT, self.CONTRADICTION]:
                start_time = timer()
                train_loss = self.subgenerators[option].train_epoch()
                end_time = timer()
                val_loss = self.subgenerators[option].evaluate()
                print(f"Epoch: {epoch}, GeneratorPart: {option}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s")

    def evaluate(self):
        """Print the validation loss of each sub-generator."""
        for option in [self.ENTAILMENT, self.CONTRADICTION]:
            val_loss = self.subgenerators[option].evaluate()
            print(f"GeneratorPart: {option}, Val loss: {val_loss:.3f}")

    def load_state(self):
        """Restore both sub-generators through the ModelLoader."""
        for option in [self.ENTAILMENT, self.CONTRADICTION]:
            self.model_loader.load_model(self.subgenerators[option])

    def save_state(self):
        """Persist both sub-generators through the ModelLoader."""
        for option in [self.ENTAILMENT, self.CONTRADICTION]:
            self.model_loader.save_model(self.subgenerators[option])

    def generate(self, src: str):
        """Return ``(entailment_sentence, contradiction_sentence)`` generated for ``src``."""
        return self.subgenerators[self.ENTAILMENT].translate(src), self.subgenerators[self.CONTRADICTION].translate(src)


In [23]:
#GAN

BATCH_SIZE = 32
generator = Generator(BATCH_SIZE)

pretrained_base_path = './drive/MyDrive/diplom-collab/models/4_l/'

def load_model():
  """Load the pretrained model and optimizer state of both sub-generators from Drive."""
  for option in [generator.ENTAILMENT, generator.CONTRADICTION]:
    # map_location lets a checkpoint saved on a GPU load on whatever DEVICE resolves to.
    checkpoint = torch.load(pretrained_base_path + option + ".pt", map_location=DEVICE)
    generator.subgenerators[option].model.load_state_dict(checkpoint['model_state_dict'])
    generator.subgenerators[option].optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

load_model()

In [24]:
def tokens_grads_from_emb(embedding_weight, tokens):
    """Map every token id to the dot product of its embedding row and that row's gradient.

    Args:
        embedding_weight: embedding matrix ``(vocab_size, dim)`` whose ``.grad`` has
            been populated by a backward pass.
        tokens: integer tensor of token ids (any shape).

    Returns:
        Float tensor shaped like ``tokens`` on the same device. Positions holding
        token id 0 are 0.0; every other position holds ``<weight[id], grad[id]>``.

    Raises:
        ValueError: if ``embedding_weight.grad`` is None.
    """
    weight_grad = embedding_weight.grad
    if weight_grad is None:
        raise ValueError("embedding_weight.grad is None; run backward() before calling tokens_grads_from_emb")

    # One row-wise dot product per vocabulary entry, computed in a single vectorised
    # pass. detach() keeps the result out of the autograd graph.
    grad_x = (embedding_weight.detach() * weight_grad).sum(dim=1)

    # Gather into a float tensor. Writing into a clone of the integer token tensor
    # would truncate every gradient value to an integer.
    per_token = grad_x[tokens.to(grad_x.device)].to(tokens.device)
    return torch.where(tokens != 0, per_token, torch.zeros_like(per_token))

In [25]:
def create_dataloader_for_desc(data, label):
  """Build a shuffled DataLoader of BERT-tokenized texts that all carry the same label.

  Every text is padded or truncated to 50 tokens. Each batch holds
  (input_ids, attention_mask, labels) and contains BATCH_SIZE rows.
  """
  frame = pd.DataFrame(data={'text': data, 'target': [label] * len(data)})
  texts = frame['text'].astype('str')
  targets = frame['target']

  encoded = tokenizer.batch_encode_plus(
    texts.values,
    max_length = 50,
    padding = 'max_length',
    truncation = True
  )

  dataset = TensorDataset(
    torch.tensor(encoded['input_ids']),
    torch.tensor(encoded['attention_mask']),
    torch.tensor(data=targets.values),
  )
  return DataLoader(dataset, sampler = RandomSampler(dataset), batch_size = BATCH_SIZE)


In [26]:
# For reference (presumably the number of lines in the SNLI train file): total_len = 550_152
# Number of consecutive train pairs that prepare_train_data reads for one epoch.
LIMIT = 10_000

def prepare_train_data(epoch):
  """Collect the training material for one epoch.

  Reads the SNLI train pairs whose index lies in [LIMIT*(epoch-1), LIMIT*epoch) and
  generates an entailment and a contradiction sentence for each premise.

  Returns (data_desc_real, data_desc_fake, data_gen):
    data_desc_real: flat list of the original sentences (both of each pair),
    data_desc_fake: flat list of the generated sentences,
    data_gen: dict mapping each generator option to its (source, generated) pairs.
  """
  window_start = LIMIT * (epoch - 1)
  window_end = LIMIT * epoch
  pairs = snli_train_all().enumerate() \
    .filter(lambda indexed: window_start <= indexed[0] < window_end) \
    .map(lambda indexed: indexed[1])

  data_desc_real, data_desc_fake = [], []
  data_gen_ent, data_gen_contr = [], []

  for src, tgt in tqdm(pairs, total = LIMIT):
    data_desc_real.extend((src, tgt))
    ent, cont = generator.generate(src)
    data_desc_fake.extend((ent, cont))
    data_gen_ent.append((src, ent))
    data_gen_contr.append((src, cont))

  data_gen = {generator.ENTAILMENT: data_gen_ent, generator.CONTRADICTION: data_gen_contr}
  return data_desc_real, data_desc_fake, data_gen

In [27]:
def train_desc(data_desc_real, data_desc_fake):
  """Train the discriminator on the real texts (label 1), then on the generated texts (label 0)."""
  labelled_sets = ((data_desc_real, 1), (data_desc_fake, 0))
  # Build both dataloaders first, then run the two training passes in order.
  dataloaders = [create_dataloader_for_desc(texts, label) for texts, label in labelled_sets]
  for dataloader in dataloaders:
    desc_train(dataloader)

In [28]:
def create_dataloader_for_gen(data):
  """Build an in-order DataLoader of BERT-tokenized texts for the generator update.

  `data` is a list of pairs; the first element of each pair is tokenized (padded or
  truncated to 50 tokens) and labelled 0. Each batch holds
  (input_ids, attention_mask, labels) with BATCH_SIZE rows.

  Batches are produced in the original order of `data`: train_gen pairs batch number
  `step` with data[step*BATCH_SIZE:(step+1)*BATCH_SIZE], which only lines up when
  the rows are not shuffled.
  """
  # NOTE(review): pair[0] is the source sentence, not the generated one (pair[1]) —
  # confirm which text the discriminator is meant to score here.
  df = pd.DataFrame(data={'text': [pair[0] for pair in data], 'target': [0 for _ in range(len(data))]})

  train_text = df['text'].astype('str')
  train_labels = df['target']

  tokens_train = tokenizer.batch_encode_plus(
    train_text.values,
    max_length = 50,
    padding = 'max_length',
    truncation = True
  )

  train_seq = torch.tensor(tokens_train['input_ids'])
  train_mask = torch.tensor(tokens_train['attention_mask'])
  train_y = torch.tensor(data=train_labels.values)

  train_data = TensorDataset(train_seq, train_mask, train_y)
  # Sequential, not random: keeps batch `step` aligned with the caller's slice of `data`.
  train_sampler = SequentialSampler(train_data)
  train_dataloader = DataLoader(train_data, sampler = train_sampler, batch_size = BATCH_SIZE)
  return train_dataloader

In [30]:
def train_gen(data):
    """Update both sub-generators using gradients taken from the discriminator.

    `data` maps each generator option to a list of (source, generated) pairs. For
    every batch the discriminator loss is back-propagated, per-token values are
    derived from the gradient of its word-embedding matrix, and those values are
    handed to the sub-generator's ``backward_model`` together with the matching
    slice of sentence pairs. The discriminator's own optimizer is not stepped here.
    """
    model.train()
    cross_entropy = nn.CrossEntropyLoss()
    for option in [generator.ENTAILMENT, generator.CONTRADICTION]:
        subgenerator = generator.subgenerators[option]
        subgenerator.model.train()
        train_dataloader = create_dataloader_for_gen(data[option])
        for step, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
            # The batch is fed to the discriminator, so it goes to the discriminator's
            # device (`device`), the same one desc_train uses.
            sent_id, mask, labels = [r.to(device) for r in batch]
            model.zero_grad()
            preds = model(sent_id, mask)
            loss = cross_entropy(preds, labels)
            loss.backward()

            embedding = model.bert.embeddings.word_embeddings.weight

            desc_grads = tokens_grads_from_emb(embedding, sent_id)
            # Positional slice: assumes batch `step` holds exactly these rows of data[option].
            sentences_batch = data[option][step * BATCH_SIZE:(step + 1) * BATCH_SIZE]
            subgenerator.backward_model(desc_grads, sentences_batch)

In [None]:
# Joint training of the discriminator and the generators.
NUM_EPOCHS = 1
# The epoch number selects which LIMIT-sized slice of the train file
# prepare_train_data reads; starting at 3 means this run uses the third slice.
# NOTE(review): the start value 3 is hard-coded — presumably a resumed run; confirm.
for epoch in range(3, NUM_EPOCHS+3):
    data_desc_real, data_desc_fake, data_gen = prepare_train_data(epoch)
    train_desc(data_desc_real, data_desc_fake) 
    train_gen(data_gen)

# Print the validation loss of both sub-generators after training.
generator.evaluate()