In [1]:
!pip install transformers==4.28.0 datasets evaluate torch sentencepiece tokenizers sacrebleu

Collecting transformers==4.28.0
  Downloading transformers-4.28.0-py3-none-any.whl (7.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m23.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets
  Downloading datasets-2.14.4-py3-none-any.whl (519 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.3/519.3 kB[0m [31m33.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting evaluate
  Downloading evaluate-0.4.0-py3-none-any.whl (81 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m81.4/81.4 kB[0m [31m12.6 MB/s[0m eta [36m0:00:00[0m
Collecting sentencepiece
  Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m47.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting tokenizers
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     

In [2]:
import numpy as np
import evaluate
import torch
from torch import nn
from typing import Dict, Any
from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM, DataCollatorWithPadding, Seq2SeqTrainingArguments, Seq2SeqTrainer,  DataCollatorForSeq2Seq, TrainingArguments, Trainer
from datasets import DatasetDict, Dataset

# Discriminator

In [18]:
## Assuming the dataset is a json in the format [{first_lang:" ", second_lang:" ", context:" "}, {first_lang:" ", second_lang:" ", context:" "},...] in a DataDict
## Also assuming that the user_model is a vaild model for classification
class Discriminator():
    def __init__(self, user_model: str, dataset=None, output_dir="discriminator", first_lang='first_lang', second_lang='second_lang', target='target', learning_rate=2e-5,
                 per_device_train_batch_size=16, per_device_eval_batch_size=16, num_train_epochs=2, weight_decay=0.01,
                 evaluation_strategy="epoch", save_strategy="epoch",):

        # main stuff
        self.dataset = dataset
        self._tokenizer = AutoTokenizer.from_pretrained(user_model, truncation=True)
        self.model = AutoModelForSequenceClassification.from_pretrained(user_model, num_labels=2)
        self.data_collator = DataCollatorWithPadding(tokenizer=self._tokenizer)

        # Args blablabla
        self.output_dir = output_dir
        self.learning_rate = learning_rate
        self.per_device_train_batch_size = per_device_train_batch_size
        self.per_device_eval_batch_size = per_device_eval_batch_size
        self.num_train_epochs = num_train_epochs
        self.weight_decay = weight_decay
        self.evaluation_strategy = evaluation_strategy
        self.save_strategy = save_strategy
        self.first_lang = first_lang
        self.second_lang = second_lang
        self.target = target



    ## Encode the dataset let's goooo
    def _model_inputs(self):

        lang1 = self.first_lang
        lang2 = self.second_lang
        target = self.target

        def preprocess_function(examples):
          inputs = [example[lang1] + ' ' + example[lang2] for example in examples['translation']]
          labels = [int(example[target]) for example in examples['translation']]

          model_inputs = self._tokenizer(inputs, padding="max_length", truncation=True)
          model_inputs['labels'] = labels
          model_inputs['text'] = inputs
          return model_inputs

        token_dataset = self.dataset.map(preprocess_function, batched=True)

        token_train = token_dataset['train']
        token_eval = token_dataset['test']

        return token_train, token_eval

    #Training
    def train(self):

      accuracy = evaluate.load("accuracy")
      def compute_metrics(eval_pred):
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)
        return accuracy.compute(predictions=predictions, references=labels)


      # Prepare the training arguments
      training_args = TrainingArguments(
      output_dir = self.output_dir,
      learning_rate = self.learning_rate,
      per_device_train_batch_size = self.per_device_train_batch_size,
      per_device_eval_batch_size = self.per_device_eval_batch_size,
      num_train_epochs = self.num_train_epochs,
      weight_decay = self.weight_decay,
      evaluation_strategy = self.evaluation_strategy,
      save_strategy = self.save_strategy,
      load_best_model_at_end = True,
      )

      # Create eval and train datasets from the encoded data with labels
      token_train, token_eval = self._model_inputs()

      # Create a Trainer and train the model
      trainer = Trainer(
          model=self.model,
          args=training_args,
          train_dataset=token_train,
          eval_dataset=token_eval,
          data_collator=self.data_collator,
          compute_metrics=compute_metrics,
      )

      trainer.train()

    ## Prediction
    def predict(self, text1, text2, trained_model_dir=None):
        if trained_model_dir == None:
          trained_model_dir = self.output_dir
        # Initialize the tokenizer and model
        tokenizer = self._tokenizer
        model = AutoModelForSequenceClassification.from_pretrained(trained_model_dir)

        # Tokenize the input text
        inputs = tokenizer(text1, text2, padding="max_length", truncation=True, return_tensors="pt")

        # Ensure the model is in evaluation mode
        model.eval()

        # Perform the prediction
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits
            predictions = torch.argmax(logits, dim=1)

        return predictions


# CustomTrainerWithDiscriminator

In [21]:

def postprocess_text(predictions, labels):
    """
    Postprocesses the generated predictions and labels.
    """
    predictions = [pred.strip() for pred in predictions]
    labels = [[label.strip()] for label in labels]
    return predictions, labels

class CustomSeq2SeqTrainer(Seq2SeqTrainer):
    def __init__(self, discriminator, discriminator_dir, tokenizer, model, args, train_dataset=None, eval_dataset=None, data_collator=None):
        """
        Custom trainer for sequence-to-sequence tasks with a discriminator.
        """
        super().__init__(model=model, args=args, train_dataset=train_dataset, eval_dataset=eval_dataset, data_collator=data_collator)
        self.discriminator = discriminator
        self.tokenizer = tokenizer
        self.discriminator_dir = discriminator_dir

    def compute_metrics(self, inputs, eval_predictions):
        """
        Computes evaluation metrics including BLEU and discriminator-based loss.
        """
        metric = evaluate.load("sacrebleu")
        generated_predictions, labels = eval_predictions

        if isinstance(generated_predictions, tuple):
            generated_predictions = generated_predictions[0]

        decoded_generated_preds = self.tokenizer.batch_decode(generated_predictions, skip_special_tokens=True)
        labels = np.where(labels != -100, labels, self.tokenizer.pad_token_id)
        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)

        decoded_generated_preds, decoded_labels = postprocess_text(decoded_generated_preds, decoded_labels)

        discriminator_predictions = self.discriminator.predict(decoded_labels, decoded_generated_preds, self.discriminator_dir)
        labels = torch.tensor(discriminator_predictions, dtype=torch.long, device=self.model.device)

        model_outputs = self.model(**inputs)
        logits = model_outputs.logits

        # Compute cross-entropy loss
        loss_fct = torch.nn.CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))

        # Compute BLEU score
        bleu_result = metric.compute(predictions=decoded_generated_preds, references=decoded_labels)
        bleu_result = {"bleu": bleu_result["score"]}

        prediction_lens = [np.count_nonzero(pred != self.tokenizer.pad_token_id) for pred in generated_predictions]
        bleu_result["gen_len"] = np.mean(prediction_lens)
        bleu_result = {k: round(v, 4) for k, v in bleu_result.items()}

        return loss, bleu_result

# Generator

In [49]:
## Assuming the dataset is a json in the format [{lang:" ", target:" "}, {lang:" ", target:" "},...] in a DataDict
## Also assuming that the user_model is a vaild model for text generation
class Generator():
    def __init__(self, user_model=None, _tokenizer=None, dataset=None, discriminator=None, lang='lang', target='target', output_dir="generator", discriminator_dir=None, learning_rate=2e-5,
                 per_device_train_batch_size=16, per_device_eval_batch_size=16, num_train_epochs=2, weight_decay=0.01,
                 evaluation_strategy="epoch", save_strategy="epoch", split=0.3,):

        # main stuff
        if user_model is None:
            user_model = AutoModelForSeq2SeqLM.from_pretrained("google/mt5-small")
        self.user_model = user_model

        if _tokenizer is None:
            _tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
        self._tokenizer = _tokenizer

        self.dataset = dataset
        self.discriminator = discriminator
        self.discriminator_dir = discriminator_dir

        self.data_collator = DataCollatorForSeq2Seq(tokenizer=self._tokenizer, model=self.user_model)


        # Args blablabla
        self.lang = lang
        self.target = target
        self.output_dir = output_dir
        self.learning_rate = learning_rate
        self.per_device_train_batch_size = per_device_train_batch_size
        self.per_device_eval_batch_size = per_device_eval_batch_size
        self.num_train_epochs = num_train_epochs
        self.weight_decay = weight_decay
        self.evaluation_strategy = evaluation_strategy
        self.save_strategy = save_strategy
        self.split = split


    ## Function for preprocessing
    def _model_inputs(self):

        source_lang = self.lang
        target_lang = self.target

        def preprocess_function(examples):
          inputs = [example[source_lang] for example in examples['translation']]
          targets = [example[target_lang] for example in examples['translation']]
          model_inputs = self._tokenizer(inputs, text_target=targets, max_length=128, truncation=True, padding=True, return_tensors="pt")
          return model_inputs

        token_dataset = self.dataset.map(preprocess_function, batched=True)

        token_train = token_dataset['train']
        token_eval = token_dataset['test']

        return token_train, token_eval


    # Training
    def train(self):

      training_args = Seq2SeqTrainingArguments(
          output_dir = self.output_dir,
          evaluation_strategy = self.evaluation_strategy,
          save_strategy = self.save_strategy,
          learning_rate = self.learning_rate,
          per_device_train_batch_size= self.per_device_train_batch_size,
          per_device_eval_batch_size= self.per_device_eval_batch_size,
          weight_decay = self.weight_decay,
          save_total_limit = 3,
          num_train_epochs = self.num_train_epochs,
          predict_with_generate=True,
          load_best_model_at_end = True,
      )

      token_train, token_eval = self._model_inputs()

      # Create a Trainer and train the model
      trainer = CustomSeq2SeqTrainer(
          discriminator=self.discriminator,
          discriminator_dir=self.discriminator_dir,
          tokenizer=self._tokenizer,
          model=self.user_model,
          args=training_args,
          train_dataset=token_train,
          eval_dataset=token_eval,
          data_collator=self.data_collator,
       )

      trainer.train()

    ## Prediction
    def predict(self, text, trained_generator_dir):
        inputs = self._tokenizer(text, padding="max_length", truncation=True, return_tensors="pt")
        model = AutoModelForSeq2SeqLM.from_pretrained(trained_generator_dir)


        generated_ids = model.generate(**inputs, max_length=50, num_return_sequences=1)
        generated_text = self._tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        return generated_text


# Training

In [24]:
import json

json_file_path = "/content/drive/MyDrive/dataset-context-final.json"

with open(json_file_path, 'r') as json_file:
    json_data = json.load(json_file)

In [25]:
def discriminator_dataset(dataset, first_lang, second_lang, target):

  json_list = []

  for item in dataset:
    temp1 = list(item.values())[0]
    temp2 = list(item.values())[1]
    temp3 = list(item.values())[2]

    data_dict = {
        first_lang: temp1,
        second_lang: temp2,
        target: temp3
    }

    json_list.append(data_dict)

  id = [str(index) for index, value in enumerate(dataset)]
  train_dataset = {
                    'id':id,
                    'translation':json_list
  }

  train_dataset = DatasetDict({"train": Dataset.from_dict(train_dataset)})
  return train_dataset

smol_context_dataset = discriminator_dataset(json_data, 'en', 'zh', 'context')
smol_context_dataset = smol_context_dataset['train'].train_test_split(test_size=0.2)

In [37]:
smol_context_dataset['test']['translation'][200]

{'context': '0',
 'en': '(b) Communications from Parties not included in Annex I to the Convention.',
 'zh': '(二)未列入《公约》附件的缔约方的通讯。'}

In [42]:
from datasets import load_dataset

opus100 = load_dataset("opus100", "en-zh")
opus_smol = opus100['train']['translation'][0:10000]

def generator_dataset(dataset, first_lang, second_lang):

  json_list = []

  for item in dataset:
    lang = list(item.values())[0]
    target = list(item.values())[1]

    data_dict = {
        first_lang: lang,
        second_lang: target
    }

    json_list.append(data_dict)

  id = [str(index) for index, value in enumerate(dataset)]
  train_dataset = {
                    'id':id,
                    'translation':json_list
  }

  train_dataset = DatasetDict({"train": Dataset.from_dict(train_dataset)})
  return train_dataset

opus_smol_dataset = generator_dataset(opus_smol, 'en', 'zh')
opus_smol_dataset = opus_smol_dataset['train'].train_test_split(test_size=0.2)

In [43]:
opus_smol_dataset['test']['translation'][200]

{'en': 'I just got in my head a little bit.', 'zh': '我們這到底是怎麼了 我不知道'}

In [27]:
user_model_name = "bert-base-uncased"
discriminator = Discriminator(user_model_name, smol_context_dataset, first_lang='en', second_lang='zh', target='context')
discriminator.train()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

Map:   0%|          | 0/15999 [00:00<?, ? examples/s]

Map:   0%|          | 0/4000 [00:00<?, ? examples/s]

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.0401,0.024602,0.99425
2,0.0118,0.027647,0.9935


In [39]:
text1 = 'T(b) Communications from Parties not included in Annex I to the Convention.'
text2 = '(二)未列入《公约》附件的缔约方的通讯。'
discriminator.predict(text1, text2, trained_model_dir='/content/discriminator/checkpoint-2000')

tensor([0])

In [51]:
from transformers import BartForConditionalGeneration, BartTokenizer

discriminator_dir = '/content/discriminator/checkpoint-2000'
user_model = BartForConditionalGeneration.from_pretrained("facebook/bart-base", forced_bos_token_id=0)
_tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

generator = Generator(dataset=opus_smol_dataset, user_model=user_model, _tokenizer=_tokenizer, discriminator=discriminator, discriminator_dir=discriminator_dir, lang='en', target='zh', num_train_epochs=2,  output_dir="generator")

In [52]:
generator.train()

Map:   0%|          | 0/8000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss
1,2.1066,0.981034
2,1.0476,0.94564


In [53]:
generator.predict("I just got in my head a little bit.", "/content/generator/checkpoint-1000")

'我们没有这些'