# Multiple-choice QA: improving via GAN

In [5]:
# nb_black: formats each code cell with black after it runs.
%load_ext nb_black
# Reload edited modules automatically before every cell executes.
%load_ext autoreload
%autoreload 2

<IPython.core.display.Javascript object>

In [1]:
import os

# Enter the project root exactly once: a plain `%cd gan-plus-nlp-main/` fails
# when the cell is re-run because the kernel is already inside that directory.
PROJECT_DIR = "gan-plus-nlp-main"
if os.path.basename(os.getcwd()) != PROJECT_DIR:
    os.chdir(PROJECT_DIR)
print(os.getcwd())

In [2]:
# !pip install neptune-client
import os
import gc
import sys
import json
import numpy as np
import pandas as pd
import importlib as imp
import neptune as neptune
from tqdm import tqdm, tqdm_notebook
from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, f1_score
import seaborn as sns
import matplotlib.pyplot as plt
from typing import List, Dict


import torch
import torch.nn.functional as F
import warnings

# Silence all warnings for the whole session (this also hides library deprecation notices).
warnings.simplefilter("ignore")

# Make the parent directory importable; guard so re-running the cell does not
# keep appending duplicate entries.
if ".." not in sys.path:
    sys.path.append("..")
# sys.path.append('gan-text-classification')

# Neptune credentials are kept out of the notebook in a local JSON file.
with open("secret.json") as secret_file:
    secret = json.load(secret_file)

# Data Loading

In [92]:
from datasets import load_dataset, load_from_disk

# tokenize_func below reads the cosmos_qa columns (context / question / answer0..3),
# and the split sizes printed under this cell are cosmos_qa's. "swag" would need
# the sent1 / sent2 / ending* column names instead, so it cannot run with this setup.
dataset_name = "cosmos_qa"  # cosmos_qa, swag
dataset = load_dataset(dataset_name, ignore_verifications=True)
# dataset = load_from_disk(dataset_path="../data/cosmos/cosmos_qa/")
# dataset = load_dataset(dataset_name, 'regular', ignore_verifications=True)
# NOTE(review): the tokenized "test" split shown further down has 2,985 rows,
# the same as "validation". Test appears to have been replaced by validation in
# a cell that is no longer here. Confirm which split the reported test metrics use.
dataset = dataset.rename_column("label", "labels")
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'context', 'question', 'answer0', 'answer1', 'answer2', 'answer3', 'labels'],
        num_rows: 25262
    })
    test: Dataset({
        features: ['id', 'context', 'question', 'answer0', 'answer1', 'answer2', 'answer3', 'labels'],
        num_rows: 6963
    })
    validation: Dataset({
        features: ['id', 'context', 'question', 'answer0', 'answer1', 'answer2', 'answer3', 'labels'],
        num_rows: 2985
    })
})

<IPython.core.display.Javascript object>

In [3]:
# _type = "test"
# _indexes = np.random.permutation(len(dataset[_type]))
# dataset[_type] = dataset[_type].select(_indexes[:6_000])
# _type = "validation"
# _indexes = np.random.permutation(len(dataset[_type]))
# dataset[_type] = dataset[_type].select(_indexes[:3_000])
# dataset

In [94]:
# LABEL_NAMES = dataset["train"].features["labels"].names
# One class per answer option: count the distinct gold labels in the train split.
observed_labels = dataset["train"].to_pandas().labels.unique()
NUM_LABELS = len(observed_labels)
# Options are named by their index rendered as a string: "0", "1", ...
LABEL_NAMES = [str(idx) for idx in range(NUM_LABELS)]


def get_ids2label(ids):
    """Translate a sequence of integer label ids into their string names."""
    return [LABEL_NAMES[t] for t in ids]


LABEL_NAMES, NUM_LABELS

(['0', '1', '2', '3'], 4)

<IPython.core.display.Javascript object>

In [96]:
# dataset_df["context"].str.split().apply(len).describe(percentiles=[0.5, 0.7, 0.9, 0.95])
# dataset_df["context"].str.split().apply(len).describe(percentiles=[0.5, 0.7, 0.9, 0.95])

<IPython.core.display.Javascript object>

In [97]:
# `dataset_df` was never defined in this notebook (it only existed in an old
# kernel). The counts shown below sum to 2,985, the validation split size.
dataset_df = dataset["validation"].to_pandas()
dataset_df["labels"].value_counts()

2    761
3    751
0    744
1    729
Name: labels, dtype: int64

<IPython.core.display.Javascript object>

## Experiment

In [98]:
import torch
import torch.nn.functional as F

# Physical GPU to expose to this process.
CUDA_DEVICE_ID = "7"
# CUDA_VISIBLE_DEVICES is only read when the CUDA context is created. Once any
# CUDA call has happened (e.g. when this cell is re-run), changing it is a no-op.
if torch.cuda.is_initialized():
    print(
        "CUDA is already initialized; the CUDA_VISIBLE_DEVICES change is ignored "
        "until the kernel is restarted."
    )
os.environ["CUDA_VISIBLE_DEVICES"] = CUDA_DEVICE_ID

if torch.cuda.is_available():
    device = torch.device("cuda")
    print("There are %d GPU(s) available." % torch.cuda.device_count())
    print("We will use the GPU:", torch.cuda.get_device_name())
else:
    print("No GPU available, using the CPU instead.")
    device = torch.device("cpu")
# Bytes currently held by tensors on the current CUDA device (0 if CUDA is unused).
print(torch.cuda.memory_allocated())

There are 1 GPU(s) available.
We will use the GPU: NVIDIA GeForce RTX 2080 Ti
3567115264


<IPython.core.display.Javascript object>

In [100]:
import sys

# Drop cached project modules so the imports below pick up edited source files.
# Each entry is removed independently. The old single try/except aborted at the
# first missing module and silently skipped all the remaining ones.
for module_name in (
    "base",
    "model",
    "model.discriminator",
    "model.generator",
    "model.utils",
    "trainer",
):
    sys.modules.pop(module_name, None)

gc.collect()
torch.cuda.empty_cache()

import model

model = imp.reload(model)

<IPython.core.display.Javascript object>

In [8]:
# Hugging Face checkpoint used for both the tokenizer and the encoder backbone.
model_name: str = "bert-base-uncased"

<IPython.core.display.Javascript object>

In [101]:
# Shared experiment configuration; the baseline and GAN configs are copies of it.
CONFIG = {
    # task / backbone
    "TASK": "multiple-choice",
    "encoder_name": model_name,
    "frozen_backbone": False,
    # batching / tokenization
    "batch_size": 8,
    "max_seq_length": 128,
    # generator input dimensionality
    "noise_size": 100,
    # dataset bookkeeping (logged to Neptune)
    "dataset_train_size": len(dataset["train"]),
    "dataset_valid_size": len(dataset["validation"]),
    "dataset_test_size": len(dataset["test"]),
    "num_labels": NUM_LABELS,
    "label_names": LABEL_NAMES,
    # optimization
    "lr_discriminator": 5e-5,
    "lr_generator": 5e-5,
    "epsilon": 1e-8,
    "num_train_epochs": 5,
    "multi_gpu": False,
    "dropout_rate": 0.2,
    "apply_scheduler": True,
    "warmup_proportion_d": 0.1,
    "warmup_proportion_g": 0.0,
    "fake_label_index": -1,
    "dataset": dataset_name,
    "save_path": "../weights/best_model.pth",
}

<IPython.core.display.Javascript object>

In [103]:
from dataclasses import dataclass
import torch
from transformers.tokenization_utils_base import (
    PreTrainedTokenizerBase,
    PaddingStrategy,
)
from typing import Optional, Union


from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Dataset column names consumed by tokenize_func (cosmos_qa layout).
# For swag the equivalents are "sent1", "sent2" and "ending0".."ending3".
context_name = "context"
question_name = "question"
ending_names = ["answer0", "answer1", "answer2", "answer3"]


def tokenize_func(examples, num_options=4):
    """Tokenize a batch of multiple-choice examples (for ``Dataset.map(batched=True)``).

    Every example is expanded into ``num_options`` sentence pairs of the form
    (context, "question answer_k"). The pairs are tokenized together and then
    regrouped, so each output column holds one list of ``num_options``
    sequences per example.

    Args:
        examples: batched columns. Reads ``context_name``, ``question_name`` and
            every column listed in ``ending_names``.
        num_options: answers per question. Must equal ``len(ending_names)``,
            otherwise the regrouping below would mix sequences of different
            examples.

    Returns:
        dict mapping each tokenizer output key to a per-example list of
        ``num_options`` tokenized sequences.

    Raises:
        ValueError: if ``num_options`` differs from ``len(ending_names)``.
    """
    if num_options != len(ending_names):
        raise ValueError(
            f"num_options={num_options} does not match "
            f"len(ending_names)={len(ending_names)}"
        )

    # Flat, row-major lists: consecutive groups of num_options belong to one example.
    # (Comprehensions instead of sum(list_of_lists, []), which is quadratic.)
    first_sentences = [
        context for context in examples[context_name] for _ in range(num_options)
    ]
    second_sentences = [
        f"{header} {examples[end][i]}"
        for i, header in enumerate(examples[question_name])
        for end in ending_names
    ]

    tokenized_examples = tokenizer(
        first_sentences,
        second_sentences,
        truncation=True,
        max_length=CONFIG["max_seq_length"],
    )
    # Un-flatten: chunk every output column back into groups of num_options.
    return {
        k: [v[i : i + num_options] for i in range(0, len(v), num_options)]
        for k, v in tokenized_examples.items()
    }


@dataclass
class DataCollatorForMultipleChoice:
    """
    Data collator that will dynamically pad the inputs for multiple choice received.

    Each incoming feature holds, for every tokenizer key, one sequence per
    answer option, plus a ``label``/``labels`` entry and an optional
    ``labeled_mask`` flag. The returned batch holds tensors reshaped to
    ``(batch_size, num_choices, -1)``, plus ``labels`` (int64) and
    ``labeled_mask`` (bool).

    Note: ``__call__`` pops the label and mask keys from the incoming feature dicts.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features):
        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature.pop(label_name) for feature in features]
        # Rows without a flag (e.g. validation/test) count as labeled. Popping
        # per row, instead of all-or-nothing under a bare except, keeps a mixed
        # batch from leaving scalar "labeled_mask" entries that would break the
        # per-choice flattening below.
        labeled_mask = [feature.pop("labeled_mask", True) for feature in features]

        batch_size = len(features)
        num_choices = len(features[0]["input_ids"])
        # One dict per (example, choice) pair, example-major, so the tokenizer
        # can pad every choice sequence of the batch together.
        flattened_features = [
            {k: v[i] for k, v in feature.items()}
            for feature in features
            for i in range(num_choices)
        ]

        batch = self.tokenizer.pad(
            flattened_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
        batch["labels"] = torch.tensor(labels, dtype=torch.int64)
        batch["labeled_mask"] = torch.tensor(labeled_mask, dtype=torch.bool)
        return batch

<IPython.core.display.Javascript object>

In [104]:
from copy import copy
from datasets import Dataset
from transformers import DataCollatorWithPadding
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

LABELED_SIZE = 40
UNLABELED_SIZE = 1000
FULL_SIZE = UNLABELED_SIZE + LABELED_SIZE
# Number of labeled-row copies appended to the GAN training split:
# floor(log2(full / labeled)), but never fewer than one.
multiplier = max(1, int(np.log2(FULL_SIZE / LABELED_SIZE)))
print("Multiplier:", multiplier)

# Seeds numpy's global RNG, which DataFrame.sample uses when no random_state is passed.
np.random.seed(42)

data_collator = DataCollatorForMultipleChoice(tokenizer=tokenizer)


def prepare_experiment_datasets(LABELED_SIZE, UNLABELED_SIZE, FULL_SIZE, multiplier):
    """Tokenize ``dataset`` and build the semi-supervised training splits.

    Args:
        LABELED_SIZE: number of train rows kept with their gold labels.
        UNLABELED_SIZE: number of further train rows, disjoint from the labeled
            ones, whose labels are masked (``labels=-100``, ``labeled_mask=False``).
        FULL_SIZE: unused; kept so existing call sites keep working.
        multiplier: how many copies of the labeled rows are appended to the GAN
            training split (oversampling of the small labeled set).

    Returns:
        The tokenized ``DatasetDict``. Its ``train`` split is replaced by the
        unlabeled rows followed by ``multiplier`` copies of the labeled rows, and
        an extra ``train_only_labeled`` split holds just the labeled rows.
    """
    tokenized_dataset = dataset.map(tokenize_func, batched=True)

    # token_type_ids is not produced by every tokenizer, so keep whichever of
    # the model inputs actually exist instead of guessing with a bare except.
    wanted_columns = ["labels", "input_ids", "token_type_ids", "attention_mask"]
    available_columns = set(tokenized_dataset["train"].column_names)
    tokenized_dataset = tokenized_dataset.select_columns(
        [column for column in wanted_columns if column in available_columns]
    )

    train_df = tokenized_dataset["train"].to_pandas()
    labeled_df = train_df.sample(LABELED_SIZE)
    labeled_df["labeled_mask"] = True

    # Draw the unlabeled pool from the remaining rows only. Sampling from the
    # full frame could put the same example in both the labeled and unlabeled sets.
    unlabeled_df = train_df.drop(index=labeled_df.index).sample(UNLABELED_SIZE)
    unlabeled_df["labeled_mask"] = False
    unlabeled_df["labels"] = -100

    # Single concat (DataFrame.append was removed in pandas 2.0); same row order
    # as appending the labeled frame `multiplier` times.
    gan_train_df = pd.concat([unlabeled_df] + [labeled_df] * multiplier)

    tokenized_dataset["train"] = Dataset.from_pandas(
        gan_train_df, preserve_index=False
    )
    tokenized_dataset["train_only_labeled"] = Dataset.from_pandas(
        labeled_df, preserve_index=False
    )
    print(
        "TRAIN (FOR only discriminator):", len(tokenized_dataset["train_only_labeled"])
    )
    print("TRAIN (FOR GAN):", len(tokenized_dataset["train"]))
    return tokenized_dataset


tokenized_dataset = prepare_experiment_datasets(
    LABELED_SIZE, UNLABELED_SIZE, FULL_SIZE, multiplier
)
tokenized_dataset

Multiplier: 4


Map:   0%|          | 0/25262 [00:00<?, ? examples/s]

Map:   0%|          | 0/2985 [00:00<?, ? examples/s]



TRAIN (FOR only discriminator): 40
TRAIN (FOR GAN): 1160


DatasetDict({
    train: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask', 'labeled_mask'],
        num_rows: 1160
    })
    test: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2985
    })
    validation: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2985
    })
    train_only_labeled: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask', 'labeled_mask'],
        num_rows: 40
    })
})

<IPython.core.display.Javascript object>

In [100]:
def _loader(split, batch_size, sampler_cls):
    """DataLoader over ``tokenized_dataset[split]`` using the multiple-choice collator."""
    split_ds = tokenized_dataset[split]
    return DataLoader(
        split_ds,
        batch_size=batch_size,
        sampler=sampler_cls(split_ds),
        collate_fn=data_collator,
        pin_memory=True,
    )


# Shuffled order for training, fixed order for evaluation.
train_only_labeled_dataloader = _loader(
    "train_only_labeled", CONFIG["batch_size"], RandomSampler
)
train_dataloader = _loader("train", CONFIG["batch_size"], RandomSampler)
valid_dataloader = _loader("validation", CONFIG["batch_size"], SequentialSampler)
test_dataloader = _loader("test", 16, SequentialSampler)

<IPython.core.display.Javascript object>

### Train the discriminator only (baseline)

In [101]:
import gc

# del discriminator
# Collect unreachable Python objects first, so the tensors they owned are freed
# before the caching allocator is asked to release its unused blocks.
n_collected = gc.collect()
torch.cuda.empty_cache()
n_collected

1236

<IPython.core.display.Javascript object>

In [102]:
from trainer import trainer as trainer_module
from trainer import gan_trainer as gan_trainer_module

# Pick up edits to both trainer modules without restarting the kernel.
trainer_module, gan_trainer_module = (
    imp.reload(trainer_module),
    imp.reload(gan_trainer_module),
)

<IPython.core.display.Javascript object>

In [103]:
from copy import copy

# Baseline (discriminator-only) run: same settings as CONFIG with GAN disabled.
BASE_CONFIG = {
    **CONFIG,
    "GAN": False,
    "gan_training": False,
    "num_labels": NUM_LABELS,
    "LABELED_SIZE": LABELED_SIZE,
    "num_train_epochs": 3,
}

<IPython.core.display.Javascript object>

In [9]:
# Fresh multiple-choice discriminator for the supervised-only baseline (GAN flags are off in BASE_CONFIG).
discriminator = model.DiscriminatorForMultipleChoice(**BASE_CONFIG)

<IPython.core.display.Javascript object>

In [105]:
# Use the module alias (as the experiment loop does) so the name `trainer` is
# only ever the trainer instance, never the module it was created from.
from trainer import trainer as trainer_module

trainer_module = imp.reload(trainer_module)

BASE_CONFIG["num_train_examples"] = len(train_only_labeled_dataloader.dataset)
# The baseline sees only the labeled subset.
trainer = trainer_module.TrainerSequenceClassification(
    config=BASE_CONFIG,
    discriminator=discriminator,
    train_dataloader=train_only_labeled_dataloader,
    valid_dataloader=valid_dataloader,
    device=device,
)



Trainable layers 201


<IPython.core.display.Javascript object>

In [106]:
%%time
run = None
tags = ['test']
run = neptune.init_run(
    project=secret["neptune_project"], api_token=secret["neptune_token"], tags=tags
)
run["config"] = trainer.config

for epoch_i in range(BASE_CONFIG["num_train_epochs"]):
    print(f"======== Epoch {epoch_i + 1} / {BASE_CONFIG['num_train_epochs']} ========")
    train_info = trainer.train_epoch(log_env=run)
    valid_metrics = trainer.validation(log_env=run)
# run.stop()

https://app.neptune.ai/vmalperovich/gan-in-nlp/e/GAN2-321


You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


	Train loss discriminator: 1.439
	Test loss discriminator: 1.383
	Test accuracy discriminator: 0.304
	Test f1 discriminator: 0.304
	Train loss discriminator: 1.383
	Test loss discriminator: 1.352
	Test accuracy discriminator: 0.280
	Test f1 discriminator: 0.280
	Train loss discriminator: 1.298
	Test loss discriminator: 1.330
	Test accuracy discriminator: 0.397
	Test f1 discriminator: 0.397
CPU times: user 1min 50s, sys: 602 ms, total: 1min 51s
Wall time: 1min 51s


<IPython.core.display.Javascript object>

In [10]:
%%time
predict_info = trainer.predict(
    discriminator, test_dataloader, label_names=CONFIG["label_names"]
)
run["test"] = predict_info
run.stop()
predict_info

<IPython.core.display.Javascript object>

### Train via GAN

In [108]:
import gc

# Collect unreachable Python objects first, so the tensors they owned are freed
# before the caching allocator is asked to release its unused blocks.
n_collected = gc.collect()
torch.cuda.empty_cache()
n_collected
197

<IPython.core.display.Javascript object>

In [109]:
from copy import copy

# GAN run: CONFIG plus the adversarial-training settings.
GAN_CONFIG = {
    **CONFIG,
    "GAN": True,
    "gan_training": True,
    "GAN_TYPE": "dummy",
    "mixed_fake_ratio": 0.2,
    "LABELED_SIZE": LABELED_SIZE,
    "UNLABELED_SIZE": UNLABELED_SIZE,
    "noise_type": "normal",
    "noise_range": (0, 1),
    "warmup_proportion_d": 0.05,
    "gen_multiplier": 4,
    "num_train_epochs": 3,
}

<IPython.core.display.Javascript object>

In [110]:
# Create the fresh GAN discriminator first, then size the generator's output
# from *its* encoder (as the experiment loop does), not from the baseline
# discriminator that was trained earlier.
discriminator = model.DiscriminatorForMultipleChoice(**GAN_CONFIG)
generator = model.SimpleSequenceGenerator(
    input_size=CONFIG["noise_size"],
    output_size=discriminator.encoder.config.hidden_size,
)
generator
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Training with GAN mode on!


SimpleSequenceGenerator(
  (layers): Sequential(
    (0): Linear(in_features=100, out_features=768, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Dropout(p=0.2, inplace=False)
    (3): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (4): Linear(in_features=768, out_features=768, bias=True)
    (5): LeakyReLU(negative_slope=0.2, inplace=True)
    (6): Dropout(p=0.2, inplace=False)
    (7): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
)

<IPython.core.display.Javascript object>

In [111]:
from trainer import gan_trainer as gan_trainer_module

gan_trainer_module = imp.reload(gan_trainer_module)

GAN_CONFIG["num_train_examples"] = len(train_dataloader.dataset)
# The GAN trainer sees the mixed labeled + unlabeled split.
gan_trainer_kwargs = dict(
    config=GAN_CONFIG,
    discriminator=discriminator,
    generator=generator,
    train_dataloader=train_dataloader,
    valid_dataloader=valid_dataloader,
    device=device,
    save_path=CONFIG["save_path"],
)
gan_trainer = gan_trainer_module.GANTrainerMultipleChoice(**gan_trainer_kwargs)


Trainable layers 201


<IPython.core.display.Javascript object>

In [112]:
%%time
run = None
run = neptune.init_run(
    project=secret["neptune_project"], api_token=secret["neptune_token"], tags=tags
)
run["config"] = gan_trainer.config


for epoch_i in range(GAN_CONFIG["num_train_epochs"]):
    print(f"======== Epoch {epoch_i + 1} / {GAN_CONFIG['num_train_epochs']} ========")
    train_info = gan_trainer.train_epoch(log_env=run)
    result = gan_trainer.validation(log_env=run)
# run.stop()

https://app.neptune.ai/vmalperovich/gan-in-nlp/e/GAN2-322
	Train loss discriminator: 1.976
	Train loss generator: 0.901
Best model saved!
	Test loss discriminator: 1.457
	Test accuracy discriminator: 0.317
	Test f1 discriminator: 0.317
	Train loss discriminator: 1.140
	Train loss generator: 0.842
	Test loss discriminator: 1.903
	Test accuracy discriminator: 0.314
	Test f1 discriminator: 0.314
	Train loss discriminator: 0.816
	Train loss generator: 0.808
Best model saved!
	Test loss discriminator: 1.963
	Test accuracy discriminator: 0.333
	Test f1 discriminator: 0.333
CPU times: user 5min 19s, sys: 1.44 s, total: 5min 20s
Wall time: 5min 21s


<IPython.core.display.Javascript object>

In [11]:
%%time
discriminator.load_state_dict(torch.load(CONFIG["save_path"]))
predict_info = gan_trainer.predict(
    discriminator, test_dataloader, label_names=CONFIG["label_names"]
)
run["test"] = predict_info
run.stop()
predict_info

<IPython.core.display.Javascript object>

# Experiments

In [129]:
from trainer import trainer as trainer_module
from trainer import gan_trainer as gan_trainer_module



# Encoder backbone for the sweep. Other checkpoints tried:
#   "bert-base-uncased", "google/electra-small-discriminator", "distilbert-base-uncased"
model_name = "google/electra-base-discriminator"

# tokenize_func and the collator both use this (module-level) tokenizer.
tokenizer = AutoTokenizer.from_pretrained(model_name)
data_collator = DataCollatorForMultipleChoice(tokenizer=tokenizer)

tags = ["FINAL", "updated_qa"]

<IPython.core.display.Javascript object>

In [130]:
CONFIG.update(encoder_name=model_name, noise_range=(-2, 2))
CONFIG.update(
    noise_range_str=str(CONFIG["noise_range"]),  # string copy for logging
    noise_type="uniform",
    num_train_epochs=3,
    gen_multiplier=4,
)
print(CONFIG["num_labels"])

NUM_TRIALS_GAN = 2  # GAN repetitions per labeled-set size
# Unlabeled pool: 200 examples per answer option.
UNLABELED_SIZE = 200 * CONFIG["num_labels"]
print(UNLABELED_SIZE)

4
800


<IPython.core.display.Javascript object>

In [12]:
def _make_dataloader(split_dataset, batch_size, shuffle):
    """Wrap one tokenized split in a DataLoader that uses the multiple-choice collator."""
    sampler = (
        RandomSampler(split_dataset) if shuffle else SequentialSampler(split_dataset)
    )
    return DataLoader(
        split_dataset,
        batch_size=batch_size,
        sampler=sampler,
        collate_fn=data_collator,
        pin_memory=True,
    )


def _release_gpu_memory():
    """Collect unreachable objects first, then release unused cached CUDA blocks."""
    gc.collect()
    torch.cuda.empty_cache()


def _init_neptune_run():
    """Open a Neptune run in the configured project, tagged with ``tags``."""
    return neptune.init_run(
        project=secret["neptune_project"],
        api_token=secret["neptune_token"],
        tags=tags,
    )


for per_label in tqdm_notebook([5, 10, 20, 50]):
    LABELED_SIZE = CONFIG["num_labels"] * per_label
    print(f"\n\n\n****************{LABELED_SIZE}***********\n\n\n")
    CONFIG["per_label_samples"] = per_label

    # `del discriminator` alone frees nothing while the trainers still hold a
    # reference to the model, so drop every model/trainer handle first.
    discriminator = generator = trainer = gan_trainer = None
    _release_gpu_memory()

    FULL_SIZE = LABELED_SIZE + UNLABELED_SIZE
    multiplier = max(1, int(np.log2(FULL_SIZE / LABELED_SIZE)) - 1)
    print("Multiplier:", multiplier)

    tokenized_dataset = prepare_experiment_datasets(
        LABELED_SIZE, UNLABELED_SIZE, FULL_SIZE, multiplier
    )
    train_only_labeled_dataloader = _make_dataloader(
        tokenized_dataset["train_only_labeled"], CONFIG["batch_size"], shuffle=True
    )
    train_dataloader = _make_dataloader(
        tokenized_dataset["train"], CONFIG["batch_size"], shuffle=True
    )
    valid_dataloader = _make_dataloader(
        tokenized_dataset["validation"], CONFIG["batch_size"], shuffle=False
    )
    test_dataloader = _make_dataloader(tokenized_dataset["test"], 16, shuffle=False)

    # ---- Baseline: discriminator trained on the labeled subset only ----
    print("NO GAN...")
    BASE_CONFIG = copy(CONFIG)
    BASE_CONFIG["num_train_examples"] = len(train_only_labeled_dataloader.dataset)
    BASE_CONFIG["GAN"] = False
    BASE_CONFIG["gan_training"] = False
    BASE_CONFIG["LABELED_SIZE"] = LABELED_SIZE
    discriminator = model.DiscriminatorForMultipleChoice(**BASE_CONFIG)
    print(discriminator.encoder_name)
    trainer = trainer_module.TrainerSequenceClassification(
        config=BASE_CONFIG,
        discriminator=discriminator,
        train_dataloader=train_only_labeled_dataloader,
        valid_dataloader=valid_dataloader,
        device=device,
    )
    run = _init_neptune_run()
    try:
        run["config"] = trainer.config
        for epoch_i in range(BASE_CONFIG["num_train_epochs"]):
            print(f"== Epoch {epoch_i + 1} / {BASE_CONFIG['num_train_epochs']} ==")
            train_info = trainer.train_epoch(log_env=run)
            valid_metrics = trainer.validation(log_env=run)
        # NOTE(review): the baseline appears to be scored with its last-epoch
        # weights, while the GAN branch below reloads the checkpoint at
        # save_path. Confirm this asymmetric comparison is intended.
        predict_info = trainer.predict(
            discriminator, test_dataloader, label_names=CONFIG["label_names"]
        )
        run["test"] = predict_info
    finally:
        # Stop the run even if training raises, so it is not left open.
        run.stop()

    # ---- GAN: generator + discriminator on labeled + unlabeled data ----
    for _ in range(NUM_TRIALS_GAN):
        print("GAN...")
        discriminator = generator = trainer = gan_trainer = None
        _release_gpu_memory()
        GAN_CONFIG = copy(CONFIG)
        GAN_CONFIG["GAN"] = True
        GAN_CONFIG["gan_training"] = True
        GAN_CONFIG["GAN_TYPE"] = "dummy"
        GAN_CONFIG["LABELED_SIZE"] = LABELED_SIZE
        GAN_CONFIG["UNLABELED_SIZE"] = UNLABELED_SIZE
        GAN_CONFIG["FULL_SIZE"] = FULL_SIZE
        # NOTE(review): unlike the standalone GAN cell earlier, "mixed_fake_ratio"
        # is not set here. Verify the trainer's fallback is the intended value.
        discriminator = model.DiscriminatorForMultipleChoice(**GAN_CONFIG)
        generator = model.SimpleSequenceGenerator(
            input_size=CONFIG["noise_size"],
            output_size=discriminator.encoder.config.hidden_size,
        )

        GAN_CONFIG["num_train_examples"] = len(train_dataloader.dataset)
        gan_trainer = gan_trainer_module.GANTrainerMultipleChoice(
            config=GAN_CONFIG,
            discriminator=discriminator,
            generator=generator,
            train_dataloader=train_dataloader,
            valid_dataloader=valid_dataloader,
            device=device,
            save_path=CONFIG["save_path"],
        )
        run = _init_neptune_run()
        try:
            run["config"] = gan_trainer.config
            for epoch_i in range(GAN_CONFIG["num_train_epochs"]):
                print(f"== Epoch {epoch_i + 1} / {GAN_CONFIG['num_train_epochs']} ==")
                train_info = gan_trainer.train_epoch(log_env=run)
                result = gan_trainer.validation(log_env=run)
            # Evaluate the checkpoint the GAN trainer wrote to save_path.
            discriminator.load_state_dict(torch.load(CONFIG["save_path"]))
            predict_info = gan_trainer.predict(
                discriminator, test_dataloader, label_names=CONFIG["label_names"]
            )
            run["test"] = predict_info
        finally:
            run.stop()

<IPython.core.display.Javascript object>