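# Improving Multi-Label Classification with a GAN

Compares a discriminator trained only on a small labeled subset (baseline) with the same discriminator trained adversarially against a generator on a mix of labeled and unlabeled data.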

In [1]:
# Auto-format cells with black (nb_black) and reload edited local modules
# (e.g. `model`, `trainer`) automatically before each cell runs.
%load_ext nb_black
%load_ext autoreload
%autoreload 2

<IPython.core.display.Javascript object>

In [3]:
# Change into the project directory; relative paths used below ("secret.json",
# "../weights/...") are resolved from here.
%cd gan-plus-nlp-main/

<IPython.core.display.Javascript object>

In [5]:
# !pip install neptune-client
import os
import gc
import sys
import json
import numpy as np
import pandas as pd
import importlib as imp
import neptune as neptune
from tqdm import tqdm, tqdm_notebook
from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, f1_score
import seaborn as sns
import matplotlib.pyplot as plt

from typing import List, Dict

import torch
import torch.nn.functional as F

import warnings

warnings.simplefilter("ignore")
sys.path.append("..")
# Neptune credentials ("neptune_project", "neptune_token") live outside the notebook.
# A context manager closes the file instead of leaking the handle.
with open("secret.json") as secret_file:
    secret = json.load(secret_file)

<IPython.core.display.Javascript object>

# Data Loading

In [5]:
from datasets import load_dataset, load_from_disk

dataset_name = "vmalperovich/toxic_comments"  # alternative: "go_emotions"

# Download every split and expose the target column as "labels", the column
# name used by all of the code below.
dataset = load_dataset(dataset_name).rename_column("label", "labels")
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'labels', 'id'],
        num_rows: 43410
    })
    validation: Dataset({
        features: ['text', 'labels', 'id'],
        num_rows: 5426
    })
    test: Dataset({
        features: ['text', 'labels', 'id'],
        num_rows: 5427
    })
})

<IPython.core.display.Javascript object>

In [9]:
# Down-sample the validation split so per-epoch evaluation is cheaper.
# Use a dedicated seeded generator: the global np.random.seed(42) is only set in a
# later cell, so the previous unseeded permutation gave a different subset per run.
VALID_SUBSET_SIZE = 3_500
_type = "validation"
_rng = np.random.default_rng(42)
_indexes = _rng.permutation(len(dataset[_type]))
dataset[_type] = dataset[_type].select(_indexes[:VALID_SUBSET_SIZE])
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'labels', 'id'],
        num_rows: 43410
    })
    validation: Dataset({
        features: ['text', 'labels', 'id'],
        num_rows: 3500
    })
    test: Dataset({
        features: ['text', 'labels', 'id'],
        num_rows: 5427
    })
})

<IPython.core.display.Javascript object>

In [6]:
# `.feature` is the element feature of the list-valued "labels" column; its
# `.names` are the class names, indexed by label id.
LABEL_NAMES = dataset["train"].features["labels"].feature.names
NUM_LABELS = len(LABEL_NAMES)


def get_ids2label(ids):
    """Map a sequence of integer label ids to their class names."""
    return [LABEL_NAMES[label_id] for label_id in ids]


LABEL_NAMES, NUM_LABELS

(['admiration',
  'amusement',
  'anger',
  'annoyance',
  'approval',
  'caring',
  'confusion',
  'curiosity',
  'desire',
  'disappointment',
  'disapproval',
  'disgust',
  'embarrassment',
  'excitement',
  'fear',
  'gratitude',
  'grief',
  'joy',
  'love',
  'nervousness',
  'optimism',
  'pride',
  'realization',
  'relief',
  'remorse',
  'sadness',
  'surprise',
  'neutral'],
 28)

<IPython.core.display.Javascript object>

In [8]:
# Word-count statistics of the test texts (presumably what informed
# CONFIG["max_seq_length"] = 24, close to the 95th percentile).
# `dataset_df` was previously an undefined name left over from a deleted cell.
dataset_df = dataset["test"].to_pandas()
dataset_df["text"].str.split().str.len().describe(percentiles=[0.5, 0.7, 0.9, 0.95])

count    5427.000000
mean       12.726000
std         6.672055
min         1.000000
50%        12.000000
70%        17.000000
90%        22.000000
95%        24.000000
max        32.000000
Name: text, dtype: float64

<IPython.core.display.Javascript object>

## Experiment

In [10]:
# Hugging Face checkpoint for the encoder (becomes CONFIG["encoder_name"] and is
# also used to load the tokenizer).
model_name = "distilbert-base-uncased"
# model_name = "bert-base-uncased"

<IPython.core.display.Javascript object>

In [12]:
import torch
import torch.nn.functional as F

# Restrict the process to one physical GPU.
# NOTE: this presumably has to happen before the first CUDA call to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

use_gpu = torch.cuda.is_available()
if not use_gpu:
    print("No GPU available, using the CPU instead.")
else:
    print(f"There are {torch.cuda.device_count()} GPU(s) available.")
    print("We will use the GPU:", torch.cuda.get_device_name())
device = torch.device("cuda" if use_gpu else "cpu")
print(torch.cuda.memory_allocated())

There are 1 GPU(s) available.
We will use the GPU: NVIDIA GeForce RTX 2080 Ti
0


<IPython.core.display.Javascript object>

In [8]:
import sys

# Forget cached project modules so the import below re-executes them from disk.
# `pop(..., None)` removes every listed module that is present. The previous
# try/del chain stopped at the first missing module (skipping the rest) and its
# bare `except` hid any other error.
for _module_name in (
    "base",
    "model",
    "model.discriminator",
    "model.generator",
    "model.utils",
    "trainer",
):
    sys.modules.pop(_module_name, None)

# Collect garbage first, then release cached GPU memory.
gc.collect()
torch.cuda.empty_cache()

import model

model = imp.reload(model)


<IPython.core.display.Javascript object>

In [14]:
# Shared hyper-parameters. The baseline (BASE_CONFIG) and GAN (GAN_CONFIG)
# configurations below start from copies of this dict.
CONFIG = {
    "TASK": "multi-label",
    "encoder_name": model_name,
    "frozen_backbone": False,
    "batch_size": 32,
    "max_seq_length": 24,
    "noise_size": 100,  # generator input size (SimpleSequenceGenerator input_size)
    "dataset_train_size": len(dataset["train"]),
    "dataset_valid_size": len(dataset["validation"]),
    "dataset_test_size": len(dataset["test"]),
    "num_labels": NUM_LABELS,
    "label_names": LABEL_NAMES,
    "lr_discriminator": 5e-5,
    "lr_generator": 5e-5,
    "epsilon": 1e-8,
    "num_train_epochs": 5,
    "multi_gpu": False,
    "dropout_rate": 0.2,
    "apply_scheduler": True,
    "warmup_proportion_d": 0.1,
    "warmup_proportion_g": 0.0,
    "fake_label_index": -1,
    "dataset": dataset_name,
    "save_path": "../weights/best_model.pth",  # checkpoint written by the GAN trainer
    "multi-label": True,
    "multi-label-trh": 0.5,  # presumably the per-label decision threshold — confirm in trainer
}

<IPython.core.display.Javascript object>

In [9]:
from transformers import AutoTokenizer

# Tokenizer matching the encoder checkpoint named by `model_name`.
tokenizer = AutoTokenizer.from_pretrained(model_name)

def get_bool_labels(labels, num_classes):
    """Convert a list of label ids into a multi-hot boolean vector.

    Args:
        labels: iterable of integer class ids present in one example.
        num_classes: total number of classes (length of the output vector).

    Returns:
        ``{"labels": vector}``, where ``vector`` is a boolean numpy array of length
        ``num_classes`` that is ``True`` at every id in ``labels``. The dict form is
        what the ``.map`` call in ``prepare_experiment_datasets`` merges back into
        each example.
    """
    # `np.bool` was deprecated in NumPy 1.20 and removed in 1.24; the builtin
    # `bool` selects the same numpy boolean dtype.
    new_labels = np.zeros(num_classes, dtype=bool)
    new_labels[np.asarray(list(labels), dtype=np.intp)] = True
    return {"labels": new_labels}


def tokenize(x):
    """Tokenize the "text" field of a batch, truncated to CONFIG["max_seq_length"] tokens."""
    texts = x["text"]
    return tokenizer(texts, truncation=True, max_length=CONFIG["max_seq_length"])

<IPython.core.display.Javascript object>

In [16]:
from copy import copy
from datasets import Dataset
from transformers import DataCollatorWithPadding
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

# Sizes of the labeled and label-hidden training subsets for this experiment.
LABELED_SIZE = 400
UNLABELED_SIZE = 3000
FULL_SIZE = LABELED_SIZE + UNLABELED_SIZE

# Number of copies of the labeled subset in the GAN training split:
# floor(log2(FULL_SIZE / LABELED_SIZE)) - 1, never less than 1.
multiplier = max(1, int(np.log2(FULL_SIZE / LABELED_SIZE)) - 1)
print("Multiplier:", multiplier)

# Seeds the global NumPy RNG, which the DataFrame.sample calls below rely on.
np.random.seed(42)

# Pads the examples of each batch to a common length at collation time.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)


def prepare_experiment_datasets(LABELED_SIZE, UNLABELED_SIZE, FULL_SIZE, multiplier):
    """Build the tokenized splits for one semi-supervised experiment.

    Uses the module-level ``dataset``, ``tokenizer``, ``CONFIG``, ``NUM_LABELS`` and
    ``get_bool_labels``. Every split is tokenized (truncated to
    ``CONFIG["max_seq_length"]``), and its label-id lists become boolean multi-hot
    vectors. The ``train`` split is then rebuilt for GAN training from two parts:

    * ``UNLABELED_SIZE`` rows whose labels are overwritten with ``-100`` and whose
      ``labeled_mask`` is ``False``;
    * ``LABELED_SIZE`` labeled rows (``labeled_mask=True``), appended ``multiplier``
      times to oversample the scarce labeled data.

    Args:
        LABELED_SIZE: number of labeled training rows to sample.
        UNLABELED_SIZE: number of training rows whose labels are hidden. They are
            drawn from the rows not picked as labeled, so the two subsets are disjoint.
        FULL_SIZE: ``LABELED_SIZE + UNLABELED_SIZE``. Not used here; kept so
            existing calls keep working.
        multiplier: number of copies of the labeled rows in the GAN ``train`` split.

    Returns:
        The tokenized ``DatasetDict``. Its ``train`` split is replaced by the mixed
        split, and an extra ``train_only_labeled`` split holds each labeled row
        once (used for the discriminator-only baseline).
    """

    def tokenize(batch):
        return tokenizer(
            batch["text"], truncation=True, max_length=CONFIG["max_seq_length"]
        )

    tokenized_dataset = dataset.map(tokenize, batched=True)
    tokenized_dataset = tokenized_dataset.map(
        lambda x: get_bool_labels(x["labels"], NUM_LABELS)
    )
    tokenized_dataset = tokenized_dataset.select_columns(
        ["input_ids", "attention_mask", "labels"]
    )

    train_df = tokenized_dataset["train"].to_pandas()

    labeled_df = train_df.sample(LABELED_SIZE)
    labeled_df["labeled_mask"] = True

    # Sample the unlabeled pool from the remaining rows. It used to be drawn from the
    # whole split, so it could contain rows already used as labeled examples.
    unlabeled_df = train_df.drop(index=labeled_df.index).sample(UNLABELED_SIZE)
    unlabeled_df["labeled_mask"] = False
    # NOTE(review): -100 is the conventional "ignore" label; presumably the trainer
    # excludes these rows from the supervised loss via `labeled_mask` — confirm.
    # `np.int` was removed in NumPy 1.24; the builtin `int` gives the same dtype.
    unlabeled_df["labels"] = unlabeled_df["labels"].apply(
        lambda _: np.full(NUM_LABELS, -100, dtype=int)
    )

    # `DataFrame.append` was removed in pandas 2.0. Concatenate the unlabeled rows
    # followed by `multiplier` copies of the labeled rows (same order as before).
    gan_train_df = pd.concat([unlabeled_df] + [labeled_df] * multiplier)

    tokenized_dataset["train"] = Dataset.from_pandas(
        gan_train_df, preserve_index=False
    )
    tokenized_dataset["train_only_labeled"] = Dataset.from_pandas(
        labeled_df, preserve_index=False
    )
    print(
        "TRAIN (FOR only discriminator):", len(tokenized_dataset["train_only_labeled"])
    )
    print("TRAIN (FOR GAN):", len(tokenized_dataset["train"]))
    return tokenized_dataset


# Build the splits for the first experiment using the sizes defined above.
tokenized_dataset = prepare_experiment_datasets(
    LABELED_SIZE=LABELED_SIZE,
    UNLABELED_SIZE=UNLABELED_SIZE,
    FULL_SIZE=FULL_SIZE,
    multiplier=multiplier,
)
tokenized_dataset

Loading cached processed dataset at /home/valperovich/projects/other/std/data/go/train/cache-c81ed038386eee0b.arrow


Multiplier: 2


Map:   0%|          | 0/3500 [00:00<?, ? examples/s]

Loading cached processed dataset at /home/valperovich/projects/other/std/data/go/test/cache-f6a1c38e956fc1a8.arrow
Loading cached processed dataset at /home/valperovich/projects/other/std/data/go/train/cache-ff67ee714414d7b1.arrow


Map:   0%|          | 0/3500 [00:00<?, ? examples/s]

Loading cached processed dataset at /home/valperovich/projects/other/std/data/go/test/cache-478e8cf156bc6951.arrow


TRAIN (FOR only discriminator): 400
TRAIN (FOR GAN): 3800


DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels', 'labeled_mask'],
        num_rows: 3800
    })
    validation: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 3500
    })
    test: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 5427
    })
    train_only_labeled: Dataset({
        features: ['input_ids', 'attention_mask', 'labels', 'labeled_mask'],
        num_rows: 400
    })
})

<IPython.core.display.Javascript object>

In [131]:
def make_dataloader(split, shuffle):
    """Wrap ``tokenized_dataset[split]`` in a DataLoader that pads batches.

    Args:
        split: name of a split in ``tokenized_dataset``.
        shuffle: sample in random order (training) or sequentially (evaluation).

    Returns:
        A DataLoader using ``CONFIG["batch_size"]`` and the shared ``data_collator``.
    """
    data = tokenized_dataset[split]
    sampler = RandomSampler(data) if shuffle else SequentialSampler(data)
    return DataLoader(
        data,
        batch_size=CONFIG["batch_size"],
        sampler=sampler,
        collate_fn=data_collator,
        pin_memory=True,
    )


# Labeled-only rows for the baseline; the mixed labeled/unlabeled split for the GAN.
train_only_labeled_dataloader = make_dataloader("train_only_labeled", shuffle=True)
train_dataloader = make_dataloader("train", shuffle=True)
valid_dataloader = make_dataloader("validation", shuffle=False)
# The test loader used to hard-code batch_size=32; it now follows CONFIG like the rest.
test_dataloader = make_dataloader("test", shuffle=False)

<IPython.core.display.Javascript object>

### Baseline: train the discriminator only (labeled data)

In [132]:
# Collect garbage first so tensors freed by the collector go back to the CUDA cache,
# then release that cache (same order as the other cleanup cells).
gc.collect()
torch.cuda.empty_cache()

5917

<IPython.core.display.Javascript object>

In [133]:
from trainer import trainer as trainer_module
from trainer import gan_trainer as gan_trainer_module

# Reload both trainer modules so edits on disk take effect without a kernel restart.
trainer_module, gan_trainer_module = (
    imp.reload(trainer_module),
    imp.reload(gan_trainer_module),
)

<IPython.core.display.Javascript object>

In [134]:
from copy import copy

# Baseline (no GAN) configuration: the shared CONFIG plus discriminator-only flags.
BASE_CONFIG = {
    **CONFIG,
    "GAN": False,
    "num_labels": NUM_LABELS,
    "gan_training": False,
    "LABELED_SIZE": LABELED_SIZE,
    "num_train_epochs": 5,
}

discriminator = model.DiscriminatorForMultiLabelClassification(**BASE_CONFIG)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<IPython.core.display.Javascript object>

In [135]:
from trainer import trainer as trainer_module

# Bind the module to `trainer_module`, as the other cells do. Previously `trainer`
# named the module and was then rebound to the trainer instance, which was confusing.
trainer_module = imp.reload(trainer_module)

# Size of the labeled-only training split, recorded in the config.
BASE_CONFIG["num_train_examples"] = len(train_only_labeled_dataloader.dataset)
trainer = trainer_module.TrainerSequenceClassification(
    config=BASE_CONFIG,
    discriminator=discriminator,
    train_dataloader=train_only_labeled_dataloader,
    valid_dataloader=valid_dataloader,
    device=device,
)


Trainable layers 201


<IPython.core.display.Javascript object>

In [136]:
%%time
run = None
tags = ['test']
run = neptune.init_run(
    project=secret["neptune_project"], api_token=secret["neptune_token"], tags=tags
)
run["config"] = trainer.config

for epoch_i in range(BASE_CONFIG["num_train_epochs"]):
    print(f"======== Epoch {epoch_i + 1} / {BASE_CONFIG['num_train_epochs']} ========")
    train_info = trainer.train_epoch(log_env=run)
    valid_metrics = trainer.validation(log_env=run)
# run.stop()

https://app.neptune.ai/vmalperovich/gan-in-nlp/e/GAN2-317


You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


	Train loss discriminator: 0.545
	Test loss discriminator: 0.421
	Test accuracy discriminator: 0.470
	Test f1 discriminator: 0.172
	Train loss discriminator: 0.348
	Test loss discriminator: 0.350
	Test accuracy discriminator: 0.619
	Test f1 discriminator: 0.402
	Train loss discriminator: 0.255
	Test loss discriminator: 0.337
	Test accuracy discriminator: 0.632
	Test f1 discriminator: 0.424
	Train loss discriminator: 0.189
	Test loss discriminator: 0.357
	Test accuracy discriminator: 0.637
	Test f1 discriminator: 0.451
	Train loss discriminator: 0.145
	Test loss discriminator: 0.374
	Test accuracy discriminator: 0.624
	Test f1 discriminator: 0.483
CPU times: user 14.9 s, sys: 246 ms, total: 15.1 s
Wall time: 15.5 s


<IPython.core.display.Javascript object>

In [10]:
# Score the baseline on the test split, attach the report to the open Neptune run,
# then close the run.
predict_info = trainer.predict(
    discriminator,
    test_dataloader,
    label_names=CONFIG["label_names"],
)
run["test"] = predict_info
run.stop()
predict_info

<IPython.core.display.Javascript object>

### Train with the GAN (labeled + unlabeled data)

In [11]:
import gc

# Collect garbage, then release cached GPU memory before building the GAN models.
gc.collect()
torch.cuda.empty_cache()

<IPython.core.display.Javascript object>

In [139]:
from copy import copy

# GAN configuration: the shared CONFIG plus GAN-specific settings.
GAN_CONFIG = {
    **CONFIG,
    "GAN": True,
    "GAN_TYPE": "dummy",
    "mixed_fake_ratio": 0.2,
    "LABELED_SIZE": LABELED_SIZE,
    "UNLABELED_SIZE": UNLABELED_SIZE,
    "gan_training": True,
    "noise_type": "normal",
    "noise_range": (0, 1),
    "warmup_proportion_d": 0.05,
    "num_train_epochs": 5,
}

<IPython.core.display.Javascript object>

In [140]:
# Build the GAN discriminator first, then size the generator's output from *its*
# encoder. This matches the experiment loop and no longer reads the hidden size from
# the baseline discriminator left over from earlier cells.
discriminator = model.DiscriminatorForMultiLabelClassification(**GAN_CONFIG)
generator = model.SimpleSequenceGenerator(
    input_size=CONFIG["noise_size"],
    output_size=discriminator.encoder.config.hidden_size,
)
generator

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Training with GAN mode on!


SimpleSequenceGenerator(
  (layers): Sequential(
    (0): Linear(in_features=100, out_features=768, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Dropout(p=0.2, inplace=False)
    (3): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (4): Linear(in_features=768, out_features=768, bias=True)
    (5): LeakyReLU(negative_slope=0.2, inplace=True)
    (6): Dropout(p=0.2, inplace=False)
    (7): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
)

<IPython.core.display.Javascript object>

In [141]:
from trainer import gan_trainer as gan_trainer_module

gan_trainer_module = imp.reload(gan_trainer_module)

# Size of the mixed labeled + unlabeled GAN training split, recorded in the config.
GAN_CONFIG["num_train_examples"] = len(train_dataloader.dataset)
gan_trainer = gan_trainer_module.GANTrainerSequenceClassification(
    config=GAN_CONFIG,
    generator=generator,
    discriminator=discriminator,
    train_dataloader=train_dataloader,
    valid_dataloader=valid_dataloader,
    save_path=CONFIG["save_path"],
    device=device,
)

Trainable layers 201


<IPython.core.display.Javascript object>

In [142]:
%%time
run = None
# tags = None
run = neptune.init_run(
    project=secret["neptune_project"], api_token=secret["neptune_token"], tags=tags
)
run["config"] = gan_trainer.config


for epoch_i in range(GAN_CONFIG["num_train_epochs"]):
    print(f"======== Epoch {epoch_i + 1} / {GAN_CONFIG['num_train_epochs']} ========")
    train_info = gan_trainer.train_epoch(log_env=run)
    result = gan_trainer.validation(log_env=run)
# run.stop()

https://app.neptune.ai/vmalperovich/gan-in-nlp/e/GAN2-318
	Train loss discriminator: 1.238
	Train loss generator: 0.929
Best model saved!
	Test loss discriminator: 0.358
	Test accuracy discriminator: 0.615
	Test f1 discriminator: 0.420
	Train loss discriminator: 1.021
	Train loss generator: 0.823
	Test loss discriminator: 0.356
	Test accuracy discriminator: 0.624
	Test f1 discriminator: 0.411
	Train loss discriminator: 0.906
	Train loss generator: 0.804
Best model saved!
	Test loss discriminator: 0.396
	Test accuracy discriminator: 0.622
	Test f1 discriminator: 0.428
	Train loss discriminator: 0.840
	Train loss generator: 0.794
Best model saved!
	Test loss discriminator: 0.436
	Test accuracy discriminator: 0.628
	Test f1 discriminator: 0.486
	Train loss discriminator: 0.813
	Train loss generator: 0.788
Best model saved!
	Test loss discriminator: 0.454
	Test accuracy discriminator: 0.626
	Test f1 discriminator: 0.514
CPU times: user 1min 18s, sys: 3.63 s, total: 1min 21s
Wall time: 1min

<IPython.core.display.Javascript object>

In [12]:
# Restore the best checkpoint saved during GAN training, evaluate it on the test
# split, log the report to the open Neptune run and close it.
# `map_location` deserializes the tensors straight onto `device`, so the checkpoint
# also loads when the saving device is not available.
discriminator.load_state_dict(torch.load(CONFIG["save_path"], map_location=device))
predict_info = gan_trainer.predict(
    discriminator, test_dataloader, label_names=CONFIG["label_names"]
)
run["test"] = predict_info
run.stop()
predict_info

<IPython.core.display.Javascript object>

# Experiments: varying the number of labeled examples per class

In [17]:
from trainer import trainer as trainer_module
from trainer import gan_trainer as gan_trainer_module


# Encoder for the final experiment grid. Alternatives tried: distilbert-base-uncased,
# bert-base-uncased, google/electra-base-discriminator.
model_name = "google/electra-small-discriminator"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Rebuild the collator so padding uses the new tokenizer.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

tags = ["FINAL"]

<IPython.core.display.Javascript object>

In [18]:
# Settings shared by every run of the final experiment grid.
_noise_range = (-2, 2)
CONFIG.update(
    encoder_name=model_name,
    noise_range=_noise_range,
    noise_range_str=str(_noise_range),
    noise_type="uniform",
    num_train_epochs=3,
)
print(CONFIG["num_labels"])

NUM_TRIALS_GAN = 2
# Fixed size of the label-hidden pool: 150 x the number of labels.
UNLABELED_SIZE = CONFIG["num_labels"] * 150
print(UNLABELED_SIZE)

28
4200


<IPython.core.display.Javascript object>

In [13]:
# Experiment grid over the number of labeled examples per class. For each size, the
# splits and dataloaders are built once. Then NUM_TRIALS_BASE discriminator-only
# trials and NUM_TRIALS_GAN GAN trials are trained, each logged as its own Neptune run.
NUM_TRIALS_BASE = 2  # was a hard-coded range(2); mirrors NUM_TRIALS_GAN

for per_label in tqdm_notebook([10, 20, 50, 100, 200]):
    LABELED_SIZE = CONFIG["num_labels"] * per_label
    print(f"\n\n\n****************{LABELED_SIZE}***********\n\n\n")

    CONFIG["per_label_samples"] = per_label
    # Drop every global reference to the previous models. The trainers were built
    # with them and presumably keep references, so `del discriminator` alone may not
    # free GPU memory. Collect garbage *before* emptying the CUDA cache.
    discriminator = generator = trainer = gan_trainer = None
    gc.collect()
    torch.cuda.empty_cache()

    FULL_SIZE = LABELED_SIZE + UNLABELED_SIZE
    multiplier = max(1, int(np.log2(FULL_SIZE / LABELED_SIZE)) - 1)
    print("Multiplier:", multiplier)

    tokenized_dataset = prepare_experiment_datasets(
        LABELED_SIZE, UNLABELED_SIZE, FULL_SIZE, multiplier
    )

    # Shuffled loaders for training and ordered loaders for evaluation. All of them
    # use CONFIG["batch_size"]; the test loader used to hard-code 32.
    dataloaders = {
        split: DataLoader(
            tokenized_dataset[split],
            batch_size=CONFIG["batch_size"],
            sampler=sampler_cls(tokenized_dataset[split]),
            collate_fn=data_collator,
            pin_memory=True,
        )
        for split, sampler_cls in (
            ("train_only_labeled", RandomSampler),
            ("train", RandomSampler),
            ("validation", SequentialSampler),
            ("test", SequentialSampler),
        )
    }
    train_only_labeled_dataloader = dataloaders["train_only_labeled"]
    train_dataloader = dataloaders["train"]
    valid_dataloader = dataloaders["validation"]
    test_dataloader = dataloaders["test"]

    for _ in range(NUM_TRIALS_BASE):
        # Baseline: discriminator trained on the labeled rows only.
        print("NO GAN...")
        discriminator = generator = trainer = gan_trainer = None
        gc.collect()
        torch.cuda.empty_cache()
        BASE_CONFIG = copy(CONFIG)
        BASE_CONFIG["num_train_examples"] = len(train_only_labeled_dataloader.dataset)
        BASE_CONFIG["GAN"] = False
        BASE_CONFIG["gan_training"] = False
        BASE_CONFIG["LABELED_SIZE"] = LABELED_SIZE
        discriminator = model.DiscriminatorForMultiLabelClassification(**BASE_CONFIG)
        print(discriminator.encoder_name)
        trainer = trainer_module.TrainerSequenceClassification(
            config=BASE_CONFIG,
            discriminator=discriminator,
            train_dataloader=train_only_labeled_dataloader,
            valid_dataloader=valid_dataloader,
            device=device,
        )
        run = neptune.init_run(
            project=secret["neptune_project"],
            api_token=secret["neptune_token"],
            tags=tags,
        )
        try:
            run["config"] = trainer.config

            for epoch_i in range(BASE_CONFIG["num_train_epochs"]):
                print(f"== Epoch {epoch_i + 1} / {BASE_CONFIG['num_train_epochs']} ==")
                train_info = trainer.train_epoch(log_env=run)
                valid_metrics = trainer.validation(log_env=run)
            # NOTE(review): this evaluates the last-epoch weights, while the GAN trials
            # below reload the saved checkpoint; confirm the asymmetry is intended.
            predict_info = trainer.predict(
                discriminator, test_dataloader, label_names=CONFIG["label_names"]
            )
            run["test"] = predict_info
        finally:
            # Always close the Neptune run, even if the trial crashes.
            run.stop()

    for _ in range(NUM_TRIALS_GAN):
        # GAN: discriminator + generator on the mixed labeled / unlabeled split.
        print("GAN...")
        discriminator = generator = trainer = gan_trainer = None
        gc.collect()
        torch.cuda.empty_cache()
        GAN_CONFIG = copy(CONFIG)
        GAN_CONFIG["GAN"] = True
        GAN_CONFIG["gan_training"] = True
        GAN_CONFIG["GAN_TYPE"] = "dummy"
        GAN_CONFIG["LABELED_SIZE"] = LABELED_SIZE
        GAN_CONFIG["UNLABELED_SIZE"] = UNLABELED_SIZE
        GAN_CONFIG["FULL_SIZE"] = FULL_SIZE
        discriminator = model.DiscriminatorForMultiLabelClassification(**GAN_CONFIG)
        generator = model.SimpleSequenceGenerator(
            input_size=CONFIG["noise_size"],
            output_size=discriminator.encoder.config.hidden_size,
        )

        GAN_CONFIG["num_train_examples"] = len(train_dataloader.dataset)
        # Remove any checkpoint left by an earlier trial or grid size, so the reload
        # below can only pick up weights saved during this trial.
        if os.path.exists(CONFIG["save_path"]):
            os.remove(CONFIG["save_path"])
        gan_trainer = gan_trainer_module.GANTrainerSequenceClassification(
            config=GAN_CONFIG,
            discriminator=discriminator,
            generator=generator,
            train_dataloader=train_dataloader,
            valid_dataloader=valid_dataloader,
            device=device,
            save_path=CONFIG["save_path"],
        )
        run = neptune.init_run(
            project=secret["neptune_project"],
            api_token=secret["neptune_token"],
            tags=tags,
        )
        try:
            run["config"] = gan_trainer.config

            for epoch_i in range(GAN_CONFIG["num_train_epochs"]):
                print(f"== Epoch {epoch_i + 1} / {GAN_CONFIG['num_train_epochs']} ==")
                train_info = gan_trainer.train_epoch(log_env=run)
                result = gan_trainer.validation(log_env=run)
            if os.path.exists(CONFIG["save_path"]):
                # Evaluate the checkpoint written during this trial.
                discriminator.load_state_dict(
                    torch.load(CONFIG["save_path"], map_location=device)
                )
            else:
                print("No checkpoint saved in this trial; evaluating last-epoch weights.")
            predict_info = gan_trainer.predict(
                discriminator, test_dataloader, label_names=CONFIG["label_names"]
            )
            run["test"] = predict_info
        finally:
            # Always close the Neptune run, even if the trial crashes.
            run.stop()

<IPython.core.display.Javascript object>