In [37]:
import os
import sys
import yaml

import torch
import warnings

# Silence every warning category for the rest of the session so that library
# warnings do not clutter the cell outputs.
warnings.simplefilter("ignore")
# Put the parent directory on the import path. The project-local modules
# imported in later cells (`model`, `trainer`, `example_utils`) are presumably
# located there -- confirm against the repository layout.
sys.path.append("..")


# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook can still be executed on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

<IPython.core.display.Javascript object>

In [29]:
from datasets import load_dataset

# Identifier handed to `load_dataset`.
dataset_name = "vmalperovich/20ng_not_enough_data"

# Load the dataset and rename its "label" column to "labels" in one chained
# expression; the bare name on the last line displays the resulting object.
dataset = load_dataset(dataset_name).rename_column("label", "labels")
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'labels', 'labeled_mask'],
        num_rows: 4886
    })
    test: Dataset({
        features: ['text', 'labels', 'labeled_mask'],
        num_rows: 7532
    })
    validation: Dataset({
        features: ['text', 'labels', 'labeled_mask'],
        num_rows: 2263
    })
})

<IPython.core.display.Javascript object>

In [35]:
# Read the default settings from the YAML file. The context manager closes the
# file handle as soon as parsing is done; passing a bare `open(...)` inline
# would leave the handle open until garbage collection.
with open("default_config.yml") as config_file:
    config = yaml.safe_load(config_file)

# Overrides specific to this notebook.
config["encoder_name"] = "distilbert-base-uncased"
# Class names are taken from the `labels` feature of the training split, and
# the number of labels is derived from them so the two always agree.
config["label_names"] = dataset["train"].features["labels"].names
config["num_labels"] = len(config["label_names"])
config["num_train_epochs"] = 3

<IPython.core.display.Javascript object>

In [26]:
import model
from trainer import gan_trainer as gan_trainer_module


# Build the discriminator, passing every entry of `config` as a keyword
# argument.
discriminator = model.DiscriminatorForSequenceClassification(**config)
# Build the generator: its input size comes from `config["noise_size"]` and its
# output size is the hidden size of the discriminator's encoder -- presumably
# so generated vectors match the encoder's representations; confirm in `model`.
generator = model.SimpleSequenceGenerator(
    input_size=config["noise_size"],
    output_size=discriminator.encoder.config.hidden_size,
)
# Tokenizer supplied by the discriminator; it is passed to
# `prepare_dataloaders` in a later cell.
tokenizer = discriminator.get_tokenizer()

<IPython.core.display.Javascript object>

In [2]:
from example_utils import prepare_dataloaders


# Build the dataloaders from the dataset, the tokenizer and the config. The
# result is indexed by split name in later cells ("train", "valid", "test").
loaders = prepare_dataloaders(dataset, tokenizer, config)

In [20]:
# Assemble the GAN trainer from the two models, the train/validation
# dataloaders, the target device and the checkpoint path. The same
# `config["save_path"]` is read back in the final cell to reload the weights.
gan_trainer = gan_trainer_module.GANTrainerSequenceClassification(
    config=config,
    discriminator=discriminator,
    generator=generator,
    train_dataloader=loaders["train"],
    valid_dataloader=loaders["valid"],
    device=device,
    save_path=config["save_path"],
)

Trainable layers 102


<IPython.core.display.Javascript object>

In [21]:
# Alternate one training pass and one validation pass per epoch.
for epoch_i in range(config["num_train_epochs"]):
    # Banner showing the 1-based epoch number out of the configured total.
    banner = f"======== Epoch {epoch_i + 1} / {config['num_train_epochs']} ========"
    print(banner)
    # Both calls run with logging disabled (`log_env=None`); their return
    # values are kept in the namespace for later inspection.
    train_info = gan_trainer.train_epoch(log_env=None)
    result = gan_trainer.validation(log_env=None)

	Train loss discriminator: 3.154
	Train loss generator: 0.609
Best model saved!
	Test loss discriminator: 0.995
	Test accuracy discriminator: 0.705
	Test f1 discriminator: 0.681
	Train loss discriminator: 1.193
	Train loss generator: 0.809
Best model saved!
	Test loss discriminator: 0.854
	Test accuracy discriminator: 0.766
	Test f1 discriminator: 0.755
	Train loss discriminator: 0.897
	Train loss generator: 0.796
Best model saved!
	Test loss discriminator: 0.938
	Test accuracy discriminator: 0.766
	Test f1 discriminator: 0.759


<IPython.core.display.Javascript object>

In [25]:
# Reload the discriminator weights stored at `config["save_path"]` (presumably
# the best checkpoint written during validation -- confirm in the trainer).
# `map_location=device` remaps the stored tensors onto the current device, so
# loading does not depend on the device the checkpoint was saved from.
# NOTE(review): `torch.load` unpickles the file; only load checkpoints from a
# trusted source.
discriminator.load_state_dict(
    torch.load(config["save_path"], map_location=device)
)
# Evaluate the restored discriminator on the test dataloader.
predict_info = gan_trainer.predict(
    discriminator, loaders["test"], label_names=config["label_names"]
)
# Display the overall F1 score as the cell's output.
predict_info["overall_f1"]

0.7023236526761144

<IPython.core.display.Javascript object>