# Paired sequence classification improving via GAN

**colab-to-git setting**

In [1]:
%load_ext nb_black
%load_ext autoreload
%autoreload 2

<IPython.core.display.Javascript object>

In [1]:
%cd gan-plus-nlp-main/

In [2]:
# !pip install neptune-client
import os
import gc
import sys
import json
import numpy as np
import pandas as pd
import importlib as imp
import neptune as neptune
from tqdm import tqdm, tqdm_notebook
from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, f1_score
import seaborn as sns
import matplotlib.pyplot as plt
import warnings

from typing import List, Dict
import torch
import torch.nn.functional as F

warnings.simplefilter("ignore")
sys.path.append("..")
# sys.path.append('gan-text-classification')
secret = json.load(open("secret.json"))

# Data Loading

In [73]:
from datasets import load_dataset

dataset_name = "paws"  # snli, anli paws labeled_final
dataset = load_dataset(dataset_name, "labeled_final", ignore_verifications=True)
# dataset = load_dataset(dataset_name, ignore_verifications=True)
dataset = dataset.rename_column("label", "labels")
dataset

Found cached dataset paws (/home/valperovich/.cache/huggingface/datasets/paws/labeled_final/1.1.0/8d567c6472623f42bd2cc635cad06932d0f0cd2f897db56013c1180f4317d338)


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'sentence1', 'sentence2', 'labels'],
        num_rows: 49401
    })
    test: Dataset({
        features: ['id', 'sentence1', 'sentence2', 'labels'],
        num_rows: 8000
    })
    validation: Dataset({
        features: ['id', 'sentence1', 'sentence2', 'labels'],
        num_rows: 8000
    })
})

<IPython.core.display.Javascript object>

In [3]:
# dataset["train"] = dataset["train_r3"]
# dataset["validation"] = dataset["dev_r3"]
# dataset["test"] = dataset["test_r3"]
# dataset.pop("train_r1")
# dataset.pop("dev_r1")
# dataset.pop("test_r1")
# dataset.pop("train_r2")
# dataset.pop("dev_r2")
# dataset.pop("test_r2")
# dataset.pop("train_r3")
# dataset.pop("dev_r3")
# dataset.pop("test_r3")
# dataset

In [4]:
# _type = "validation"
# _indexes = np.random.permutation(len(dataset[_type]))
# dataset[_type] = dataset[_type].select(_indexes[:3_500])
# dataset

In [59]:
# dataset_df = dataset["train"].to_pandas()
# indexes = dataset_df[dataset_df.labels != -1].index
# dataset["train"] = dataset["train"].select(indexes)


# dataset_df = dataset["test"].to_pandas()
# indexes = dataset_df[dataset_df.labels != -1].index
# dataset["test"] = dataset["test"].select(indexes)

<IPython.core.display.Javascript object>

In [5]:
# _indexes = np.random.permutation(len(dataset['train']))
# train_size = int(len(dataset['train']) * 0.75)
# train_indexes = _indexes[:train_size]
# valid_indexes = _indexes[train_size:]
# dataset['validation'] = dataset['train'].select(valid_indexes)

# for _type in ['test', 'validation']:
#     if len(dataset[_type]) > 10_000:
#         _indexes = np.random.permutation(len(dataset[_type]))
#         dataset[_type] = dataset[_type].select(_indexes[:10_000])
# # dataset['train'] = dataset['train'].select(train_indexes)
# dataset

In [75]:
LABEL_NAMES = dataset["train"].features["labels"].names
get_ids2label = lambda ids: [LABEL_NAMES[t] for t in ids]
NUM_LABELS = dataset["train"].features["labels"].num_classes
LABEL_NAMES, NUM_LABELS

(['0', '1'], 2)

<IPython.core.display.Javascript object>

In [77]:
# dataset_df["premise"].str.split().apply(len).describe(percentiles=[0.5, 0.7, 0.9, 0.95])
sent_col1 = "sentence1"  # premise sentence1
sent_col2 = "sentence2"  # hypothesis sentence2
dataset_df[sent_col1].str.split().apply(len).describe(percentiles=[0.5, 0.7, 0.9, 0.95])

count    8000.000000
mean       21.399125
std         5.482890
min         5.000000
50%        21.000000
70%        25.000000
90%        29.000000
95%        30.000000
max        36.000000
Name: sentence1, dtype: float64

<IPython.core.display.Javascript object>

In [78]:
dataset_df["labels"].value_counts()

0    4464
1    3536
Name: labels, dtype: int64

<IPython.core.display.Javascript object>

## Experiment

In [80]:
import torch
import torch.nn.functional as F

os.environ["CUDA_VISIBLE_DEVICES"] = "7"

if torch.cuda.is_available():
    device = torch.device("cuda")
    print("There are %d GPU(s) available." % torch.cuda.device_count())
    print("We will use the GPU:", torch.cuda.get_device_name())
else:
    print("No GPU available, using the CPU instead.")
    device = torch.device("cpu")
print(torch.cuda.memory_allocated())

There are 1 GPU(s) available.
We will use the GPU: NVIDIA GeForce RTX 2080 Ti
2208788480


<IPython.core.display.Javascript object>

In [6]:
model_name = "distilbert-base-uncased"

In [81]:
import sys

try:
    del sys.modules["base"]
    del sys.modules["model"]
    del sys.modules["model.discriminator"]
    del sys.modules["model.generator"]
    del sys.modules["model.utils"]
    del sys.modules["trainer"]
except:
    print("pass")

gc.collect()
torch.cuda.empty_cache()

import model

model = imp.reload(model)

<IPython.core.display.Javascript object>

In [82]:
CONFIG = dict(
    TASK="paired-classification",
    encoder_name=model_name,
    frozen_backbone=False,
    batch_size=32,
    max_seq_length=64,
    noise_size=100,
    dataset_train_size=len(dataset["train"]),
    dataset_valid_size=len(dataset["validation"]),
    dataset_test_size=len(dataset["test"]),
    num_labels=NUM_LABELS,
    label_names=LABEL_NAMES,
    lr_discriminator=5e-5,
    lr_generator=5e-5,
    epsilon=1e-8,
    num_train_epochs=5,
    multi_gpu=False,
    dropout_rate=0.2,
    apply_scheduler=True,
    warmup_proportion_d=0.1,
    warmup_proportion_g=0.0,
    fake_label_index=-1,
    dataset=dataset_name,
    save_path="../weights/best_model.pth",
)

<IPython.core.display.Javascript object>

In [83]:
from copy import copy
from datasets import Dataset
from transformers import AutoTokenizer
from transformers import DataCollatorWithPadding
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

LABELED_SIZE = 200
UNLABELED_SIZE = 3000
FULL_SIZE = LABELED_SIZE + UNLABELED_SIZE
multiplier = int(np.log2(FULL_SIZE / LABELED_SIZE)) - 1
multiplier = max(1, multiplier)
print("Multiplier:", multiplier)

np.random.seed(42)


tokenizer = AutoTokenizer.from_pretrained(model_name)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)


def prepare_experiment_datasets(LABELED_SIZE, UNLABELED_SIZE, FULL_SIZE, multiplier):
    tokenize = lambda x: tokenizer(
        x[sent_col1], x[sent_col2], truncation=True, max_length=CONFIG["max_seq_length"]
    )

    tokenized_dataset = dataset.map(tokenize, batched=True)
    tokenized_dataset = tokenized_dataset.remove_columns([sent_col1, sent_col2])

    tokenized_train_df = tokenized_dataset["train"].to_pandas()
    tokenized_train_df_labeled = tokenized_train_df.sample(LABELED_SIZE)
    tokenized_train_df_labeled["labeled_mask"] = True

    tokenized_train_df = tokenized_train_df.sample(UNLABELED_SIZE)
    tokenized_train_df["labeled_mask"] = False
    tokenized_train_df["labels"] = -100

    for _ in range(multiplier):
        tokenized_train_df = tokenized_train_df.append(tokenized_train_df_labeled)

    try:
        cols = [
            "input_ids",
            "token_type_ids",
            "attention_mask",
            "labeled_mask",
            "labels",
        ]
        tokenized_dataset["train"] = Dataset.from_pandas(
            tokenized_train_df[cols], preserve_index=False
        )
    except:
        cols = ["input_ids", "attention_mask", "labeled_mask", "labels"]
        tokenized_dataset["train"] = Dataset.from_pandas(
            tokenized_train_df[cols], preserve_index=False
        )

    tokenized_dataset["train_only_labeled"] = Dataset.from_pandas(
        tokenized_train_df_labeled[cols], preserve_index=False
    )
    tokenized_dataset["test"] = tokenized_dataset["test"].add_column(
        "labeled_mask",
        [True] * len(tokenized_dataset["test"]),
    )

    tokenized_dataset["validation"] = tokenized_dataset["validation"].add_column(
        "labeled_mask",
        [True] * len(tokenized_dataset["validation"]),
    )

    tokenized_dataset = tokenized_dataset.select_columns(
        cols
    )
    print(
        "TRAIN (FOR only discriminator):", len(tokenized_dataset["train_only_labeled"])
    )
    print("TRAIN (FOR GAN):", len(tokenized_dataset["train"]))
    return tokenized_dataset


tokenized_dataset = prepare_experiment_datasets(
    LABELED_SIZE, UNLABELED_SIZE, FULL_SIZE, multiplier
)
tokenized_dataset

Multiplier: 3


Map:   0%|          | 0/49401 [00:00<?, ? examples/s]

Map:   0%|          | 0/8000 [00:00<?, ? examples/s]

Map:   0%|          | 0/3500 [00:00<?, ? examples/s]

TRAIN (FOR only discriminator): 200
TRAIN (FOR GAN): 3600


DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labeled_mask', 'labels'],
        num_rows: 3600
    })
    test: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask', 'labeled_mask'],
        num_rows: 8000
    })
    validation: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask', 'labeled_mask'],
        num_rows: 3500
    })
    train_only_labeled: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labeled_mask', 'labels'],
        num_rows: 200
    })
})

<IPython.core.display.Javascript object>

In [145]:
train_only_labeled_dataloader = DataLoader(
    tokenized_dataset["train_only_labeled"],
    batch_size=CONFIG["batch_size"],
    sampler=RandomSampler(tokenized_dataset["train_only_labeled"]),
    collate_fn=data_collator,
    pin_memory=True,
)

train_dataloader = DataLoader(
    tokenized_dataset["train"],
    batch_size=CONFIG["batch_size"],
    sampler=RandomSampler(tokenized_dataset["train"]),
    collate_fn=data_collator,
    pin_memory=True,
)

valid_dataloader = DataLoader(
    tokenized_dataset["validation"],
    batch_size=CONFIG["batch_size"],
    sampler=SequentialSampler(tokenized_dataset["validation"]),
    collate_fn=data_collator,
    pin_memory=True,
)

test_dataloader = DataLoader(
    tokenized_dataset["test"],
    batch_size=32,
    sampler=SequentialSampler(tokenized_dataset["test"]),
    collate_fn=data_collator,
    pin_memory=True,
)

<IPython.core.display.Javascript object>

### Train only discriminator

In [146]:
torch.cuda.empty_cache()

<IPython.core.display.Javascript object>

In [147]:
from trainer import trainer as trainer_module
from trainer import gan_trainer as gan_trainer_module

trainer_module = imp.reload(trainer_module)
gan_trainer_module = imp.reload(gan_trainer_module)

<IPython.core.display.Javascript object>

In [148]:
from copy import copy

BASE_CONFIG = copy(CONFIG)
BASE_CONFIG["GAN"] = False
BASE_CONFIG["num_labels"] = NUM_LABELS
BASE_CONFIG["gan_training"] = False
BASE_CONFIG["num_train_epochs"] = 3
BASE_CONFIG["LABELED_SIZE"] = LABELED_SIZE

# BASE_CONFIG

<IPython.core.display.Javascript object>

In [9]:
discriminator = model.DiscriminatorForSequenceClassification(**BASE_CONFIG)

In [150]:
from trainer import trainer

trainer = imp.reload(trainer)

BASE_CONFIG["num_train_examples"] = len(train_only_labeled_dataloader.dataset)
trainer = trainer.TrainerSequenceClassification(
    config=BASE_CONFIG,
    discriminator=discriminator,
    train_dataloader=train_only_labeled_dataloader,
    valid_dataloader=valid_dataloader,
    device=device,
)

Trainable layers 201


<IPython.core.display.Javascript object>

In [151]:
%%time
run = None
tags = ['test']
run = neptune.init_run(
    project=secret["neptune_project"], api_token=secret["neptune_token"], tags=tags
)
run["config"] = trainer.config

for epoch_i in range(BASE_CONFIG["num_train_epochs"]):
    print(f"======== Epoch {epoch_i + 1} / {BASE_CONFIG['num_train_epochs']} ========")
    train_info = trainer.train_epoch(log_env=run)
    valid_metrics = trainer.validation(log_env=run)
# run.stop()

https://app.neptune.ai/vmalperovich/gan-in-nlp/e/GAN2-313
	Train loss discriminator: 1.079
	Test loss discriminator: 1.103
	Test accuracy discriminator: 0.355
	Test f1 discriminator: 0.306
	Train loss discriminator: 0.817
	Test loss discriminator: 1.223
	Test accuracy discriminator: 0.371
	Test f1 discriminator: 0.348
	Train loss discriminator: 0.347
	Test loss discriminator: 1.674
	Test accuracy discriminator: 0.367
	Test f1 discriminator: 0.302
CPU times: user 8.85 s, sys: 348 ms, total: 9.2 s
Wall time: 9.42 s


<IPython.core.display.Javascript object>

In [10]:
predict_info = trainer.predict(
    discriminator, test_dataloader, label_names=CONFIG["label_names"]
)
run["test"] = predict_info
run.stop()
predict_info

### Train via GAN

In [153]:
import gc

gc.collect()
torch.cuda.empty_cache()

<IPython.core.display.Javascript object>

In [155]:
from copy import copy

GAN_CONFIG = copy(CONFIG)
GAN_CONFIG["GAN"] = True
GAN_CONFIG["GAN_TYPE"] = "dummy" 
GAN_CONFIG["mixed_fake_ratio"] = 0.2
GAN_CONFIG["LABELED_SIZE"] = LABELED_SIZE
GAN_CONFIG["UNLABELED_SIZE"] = UNLABELED_SIZE
GAN_CONFIG["gan_training"] = True
GAN_CONFIG["noise_type"] = "normal"
GAN_CONFIG["noise_range"] = (0, 1)
GAN_CONFIG["warmup_proportion_d"] = 0.05
GAN_CONFIG["num_train_epochs"] = 3

# GAN_CONFIG

<IPython.core.display.Javascript object>

In [156]:
generator = model.SimpleSequenceGenerator(
    input_size=CONFIG["noise_size"],
    output_size=discriminator.encoder.config.hidden_size,
)

discriminator = model.DiscriminatorForSequenceClassification(**GAN_CONFIG)
generator

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Training with GAN mode on!
Default fake label index is -1


SimpleSequenceGenerator(
  (layers): Sequential(
    (0): Linear(in_features=100, out_features=768, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Dropout(p=0.2, inplace=False)
    (3): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (4): Linear(in_features=768, out_features=768, bias=True)
    (5): LeakyReLU(negative_slope=0.2, inplace=True)
    (6): Dropout(p=0.2, inplace=False)
    (7): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
)

<IPython.core.display.Javascript object>

In [157]:
from trainer import gan_trainer as gan_trainer_module

gan_trainer_module = imp.reload(gan_trainer_module)

GAN_CONFIG["num_train_examples"] = len(train_dataloader.dataset)
gan_trainer = gan_trainer_module.GANTrainerSequenceClassification(
    config=GAN_CONFIG,
    discriminator=discriminator,
    generator=generator,
    train_dataloader=train_dataloader,
    valid_dataloader=valid_dataloader,
    device=device,
    save_path=CONFIG['save_path']
)


Trainable layers 201


<IPython.core.display.Javascript object>

In [158]:
%%time
run = None
# tags = ['test']
run = neptune.init_run(
    project=secret["neptune_project"], api_token=secret["neptune_token"], tags=tags
)
run["config"] = gan_trainer.config


for epoch_i in range(GAN_CONFIG["num_train_epochs"]):
    print(f"======== Epoch {epoch_i + 1} / {GAN_CONFIG['num_train_epochs']} ========")
    train_info = gan_trainer.train_epoch(log_env=run)
    result = gan_trainer.validation(log_env=run)
# run.stop()

https://app.neptune.ai/vmalperovich/gan-in-nlp/e/GAN2-314
	Train loss discriminator: 1.655
	Train loss generator: 0.855
Best model saved!
	Test loss discriminator: 1.869
	Test accuracy discriminator: 0.335
	Test f1 discriminator: 0.333
	Train loss discriminator: 0.836
	Train loss generator: 0.835
Best model saved!
	Test loss discriminator: 3.421
	Test accuracy discriminator: 0.343
	Test f1 discriminator: 0.343
	Train loss discriminator: 0.764
	Train loss generator: 0.808
	Test loss discriminator: 5.118
	Test accuracy discriminator: 0.347
	Test f1 discriminator: 0.343
CPU times: user 1min 34s, sys: 6.93 s, total: 1min 41s
Wall time: 1min 41s


<IPython.core.display.Javascript object>

In [11]:
discriminator.load_state_dict(torch.load(CONFIG["save_path"]))
predict_info = gan_trainer.predict(
    discriminator, test_dataloader, label_names=CONFIG["label_names"]
)
run["test"] = predict_info
run.stop()
predict_info

# Experiments

In [8]:
from trainer import trainer as trainer_module
from trainer import gan_trainer as gan_trainer_module


# model_name = "bert-base-cased"
# model_name = "distilbert-base-uncased"
model_name = "google/electra-small-discriminator"
# model_name = "google/electra-base-discriminator"
tokenizer = AutoTokenizer.from_pretrained(model_name)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

tags = ["FINAL"]

In [108]:
CONFIG["encoder_name"] = model_name
CONFIG["noise_range"] = (-2, 2)
CONFIG["noise_range_str"] = str(CONFIG["noise_range"])
CONFIG["noise_type"] = "uniform"
CONFIG["num_train_epochs"] = 3
print(CONFIG["num_labels"])

NUM_TRIALS_GAN = 2
UNLABELED_SIZE = CONFIG["num_labels"] * 150

2


<IPython.core.display.Javascript object>

In [12]:
for per_label in tqdm_notebook(
    [5, 10, 20,  50, 100, 200, 400]
):
    print(f"\n\n\n****************{LABELED_SIZE}***********\n\n\n")
    LABELED_SIZE = CONFIG["num_labels"] * per_label
    CONFIG["per_label_samples"] = per_label
    try:
        del discriminator
    except:
        pass
    torch.cuda.empty_cache()
    gc.collect()
    FULL_SIZE = LABELED_SIZE + UNLABELED_SIZE
    multiplier = int(np.log2(FULL_SIZE / LABELED_SIZE)) - 1
    multiplier = max(1, multiplier)
    print("Multiplier:", multiplier)

    tokenized_dataset = prepare_experiment_datasets(
        LABELED_SIZE, UNLABELED_SIZE, FULL_SIZE, multiplier
    )

    train_only_labeled_dataloader = DataLoader(
        tokenized_dataset["train_only_labeled"],
        batch_size=CONFIG["batch_size"],
        sampler=RandomSampler(tokenized_dataset["train_only_labeled"]),
        collate_fn=data_collator,
        pin_memory=True,
    )

    train_dataloader = DataLoader(
        tokenized_dataset["train"],
        batch_size=CONFIG["batch_size"],
        sampler=RandomSampler(tokenized_dataset["train"]),
        collate_fn=data_collator,
        pin_memory=True,
    )

    valid_dataloader = DataLoader(
        tokenized_dataset["validation"],
        batch_size=CONFIG["batch_size"],
        sampler=SequentialSampler(tokenized_dataset["validation"]),
        collate_fn=data_collator,
        pin_memory=True,
    )

    test_dataloader = DataLoader(
        tokenized_dataset["test"],
        batch_size=32,
        sampler=SequentialSampler(tokenized_dataset["test"]),
        collate_fn=data_collator,
        pin_memory=True,
    )
    for _ in range(2):
        # NO GAN
        print("NO GAN...")
        try:
            del discriminator
        except:
            pass
        torch.cuda.empty_cache()
        gc.collect()
        BASE_CONFIG = copy(CONFIG)
        BASE_CONFIG["num_train_examples"] = len(train_only_labeled_dataloader.dataset)
        BASE_CONFIG["GAN"] = False
        BASE_CONFIG["gan_training"] = False
        BASE_CONFIG["LABELED_SIZE"] = LABELED_SIZE
        discriminator = model.DiscriminatorForSequenceClassification(**BASE_CONFIG)
        print(discriminator.encoder_name)
        trainer = trainer_module.TrainerSequenceClassification(
            config=BASE_CONFIG,
            discriminator=discriminator,
            train_dataloader=train_only_labeled_dataloader,
            valid_dataloader=valid_dataloader,
            device=device,
        )
        run = neptune.init_run(
            project=secret["neptune_project"],
            api_token=secret["neptune_token"],
            tags=tags,
        )
        run["config"] = trainer.config

        for epoch_i in range(BASE_CONFIG["num_train_epochs"]):
            print(f"== Epoch {epoch_i + 1} / {BASE_CONFIG['num_train_epochs']} ==")
            train_info = trainer.train_epoch(log_env=run)
            valid_metrics = trainer.validation(log_env=run)
        predict_info = trainer.predict(
            discriminator, test_dataloader, label_names=CONFIG["label_names"]
        )
        run["test"] = predict_info
        run.stop()

    for _ in range(NUM_TRIALS_GAN):
        # GAN
        print("GAN...")
        del discriminator
        gc.collect()
        torch.cuda.empty_cache()
        GAN_CONFIG = copy(CONFIG)
        GAN_CONFIG["GAN"] = True
        GAN_CONFIG["gan_training"] = True
        GAN_CONFIG["GAN_TYPE"] = "dummy"
        GAN_CONFIG["LABELED_SIZE"] = LABELED_SIZE
        GAN_CONFIG["UNLABELED_SIZE"] = UNLABELED_SIZE
        GAN_CONFIG["FULL_SIZE"] = FULL_SIZE
        discriminator = model.DiscriminatorForSequenceClassification(**GAN_CONFIG)
        generator = model.SimpleSequenceGenerator(
            input_size=CONFIG["noise_size"],
            output_size=discriminator.encoder.config.hidden_size,
        )

        GAN_CONFIG["num_train_examples"] = len(train_dataloader.dataset)
        gan_trainer = gan_trainer_module.GANTrainerSequenceClassification(
            config=GAN_CONFIG,
            discriminator=discriminator,
            generator=generator,
            train_dataloader=train_dataloader,
            valid_dataloader=valid_dataloader,
            device=device,
            save_path=CONFIG["save_path"],
        )
        run = neptune.init_run(
            project=secret["neptune_project"],
            api_token=secret["neptune_token"],
            tags=tags,
        )
        run["config"] = gan_trainer.config

        for epoch_i in range(GAN_CONFIG["num_train_epochs"]):
            print(f"== Epoch {epoch_i + 1} / {GAN_CONFIG['num_train_epochs']} ==")
            train_info = gan_trainer.train_epoch(log_env=run)
            result = gan_trainer.validation(log_env=run)
        discriminator.load_state_dict(torch.load(CONFIG["save_path"]))
        predict_info = gan_trainer.predict(
            discriminator, test_dataloader, label_names=CONFIG["label_names"]
        )
        run["test"] = predict_info
        run.stop()