In [1]:
import torch
import flash
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier

In [2]:
# Fetch the IMDB archive into ./data/. The CSV paths below live under
# data/imdb/, so download_data presumably also unpacks the zip — confirm.
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/")

# One CSV per split; grouped here so the three paths are easy to swap together.
csv_splits = {
    "train_file": "data/imdb/train.csv",
    "val_file": "data/imdb/valid.csv",
    "test_file": "data/imdb/test.csv",
}

# Text comes from the "review" column and the label from the "sentiment" column.
datamodule = TextClassificationData.from_csv(
    input_field="review",
    target_fields="sentiment",
    batch_size=32,
    **csv_splits,
)

# Bare last expression so the notebook displays the data module's repr.
datamodule

  0%|          | 0/22500 [00:00<?, ?ex/s]

  0%|          | 0/2500 [00:00<?, ?ex/s]

  0%|          | 0/2500 [00:00<?, ?ex/s]

  exec(code_obj, self.user_global_ns, self.user_ns)


<flash.text.classification.data.TextClassificationData at 0x16b4be3d0>

In [3]:
# Number of CUDA devices torch can see; later cells use this to configure the trainer.
gpu_count = torch.cuda.device_count()

# Build the report line first, then emit it.
gpu_report = f"Number of GPU devices: {gpu_count}"
print(gpu_report)

Number of GPU devices: 0


In [4]:
# Backbone model card: https://huggingface.co/prajjwal1/bert-tiny

# No seed is set anywhere else in the visible cells; seed torch's global RNG
# here so repeated "Restart & Run All" runs start from the same random state.
torch.manual_seed(42)

# Text classifier on the bert-tiny backbone; the number of output classes is
# read from the data module rather than hard-coded.
classifier_model = TextClassifier(backbone="prajjwal1/bert-tiny", num_classes=datamodule.num_classes)

# Hardware selection uses `accelerator`/`devices`: the `gpus=` argument is
# deprecated in PyTorch Lightning 1.7 and removed in 2.0. With no CUDA device
# this resolves to a single CPU device, which is what `gpus=0` selected.
# NOTE(review): the saved output of this cell reports `max_epochs=10` and a
# checkpoint named `epoch=9`, which does not match the 3 epochs configured
# here — re-run the notebook to refresh the outputs, or confirm the intended value.
trainer = flash.Trainer(
    max_epochs=3,
    accelerator="gpu" if gpu_count > 0 else "cpu",
    devices=gpu_count if gpu_count > 0 else 1,
)

# Fine-tune with the "freeze" strategy (presumably the backbone stays frozen
# and only the classification head is trained — confirm against the Flash docs).
trainer.finetune(classifier_model, datamodule=datamodule, strategy="freeze")

Using 'prajjwal1/bert-tiny' provided by Hugging Face/transformers (https://github.com/huggingface/transformers).
Some weights of the model checkpoint at prajjwal1/bert-tiny were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


In [5]:
# Evaluate on the test split of the data module. No model is passed, so the
# trainer decides which weights to evaluate (the saved output shows it
# restoring them from a checkpoint written during fine-tuning).
test_metrics = trainer.test(datamodule=datamodule)

# Bare last expression so the notebook displays the returned metrics.
test_metrics

Restoring states from the checkpoint path at /Users/richardkuodis/Documents/Development/practical_mlflow/lightning_logs/version_1/checkpoints/epoch=9-step=28120.ckpt
Loaded model weights from checkpoint at /Users/richardkuodis/Documents/Development/practical_mlflow/lightning_logs/version_1/checkpoints/epoch=9-step=28120.ckpt


Testing: 0it [00:00, ?it/s]

[{'test_accuracy': 0.6412000060081482,
  'test_cross_entropy': 0.6316602826118469}]

In [6]:
print('Get prediction outputs for two sample sentences')

# Two short example reviews to run through the fine-tuned classifier.
sample_reviews = [
    "Best movie I have seen.",
    "What a movie!",
]

# Prediction-only data module; batch_size matches the number of sample sentences.
predict_module = TextClassificationData.from_lists(predict_data=sample_reviews, batch_size=2)

# The saved output shows one two-element tensor per sentence — presumably
# unnormalised per-class scores rather than labels; confirm before interpreting.
predictions = trainer.predict(classifier_model, datamodule=predict_module)
print(predictions)

Get prediction outputs for two sample sentences


Predicting: 2812it [00:00, ?it/s]

[[tensor([-0.0758, -0.2617]), tensor([ 0.1331, -0.3732])]]
