In [2]:
import pandas as pd

In [21]:
# split a dev set from the training data
train_df = pd.read_csv("Train-splits/testnway/full_shot/NuclearEnergy_train.csv", header=None, names=['text', 'labels'])
test_df = pd.read_csv("Train-splits/testnway/full_shot/NuclearEnergy_test.csv", header=None, names=['text', 'labels'])
dev_df = pd.read_csv("Train-splits/testnway/full_shot/NuclearEnergy_dev.csv", header=None, names=['text', 'labels'])

In [22]:
all_train_labels = train_df.labels.tolist()
unique_labels = list(set(all_train_labels))
label2int = {label:i for i, label in enumerate(unique_labels)}
int2label = {i:label for i, label in enumerate(unique_labels)}


In [23]:
train_df.labels.value_counts()

environment and health issues    209
other                            120
reactorsecurity                  112
alternatives                     100
costs                             98
waste                             87
weapons                           52
reliability                       47
improvement                       33
Name: labels, dtype: int64

In [24]:
# split into 0-shot, few-shot datasets
n_trials = 5
n_shots = [0, 1, 2, 4, 8, 10, 50, 100, 209]

In [37]:
# 1) pretain TARS on n-1 labels complete data
# 2) ft TARS on n-th label few data points
fs_training_data = {}
for label in unique_labels:
    fs_training_data[label] = {"pre":train_df[train_df.labels!=label], "ft":[]}
    # perform pretraining here
    for n_shot in n_shots:
        for trial in range(1,n_trials+1):
            if trial > 1 and n_shot == 0:
                continue
            seed = trial*n_shot
            idx = train_df.labels==label
            sample_ft =  train_df[idx].sample(n=min(n_shot, sum(idx)), random_state=seed)
            fs_training_data[label]["ft"].append({"n_shot":n_shot, "trial":trial, "data":sample_ft})

In [None]:
import pickle
with open("data_samples.pkl", "wb") as f:
    pickle.dump(fs_training_data, f)

In [40]:
from flair.datasets import SentenceDataset
from flair.data import Corpus, Sentence

def get_flair_dataset_from_dataframe(data, text_col, label_col):
    sentences = list(data.apply(lambda row: Sentence(row[text_col]).add_label('class', row[label_col]), axis=1))
    return SentenceDataset(sentences)

dev_dataset = get_flair_dataset_from_dataframe(dev_df, "text", "labels")
test_dataset = get_flair_dataset_from_dataframe(test_df, "text", "labels")

In [42]:
from flair.models import TARSClassifier
from flair.trainers import ModelTrainer

for label in unique_labels:
    print("Pretraining excluding ", label)
    model_name = label+"_pretraining"
    train_dataset = get_flair_dataset_from_dataframe(fs_training_data[label]["pre"], "text", "labels")
    corpus = Corpus(train=train_dataset, dev=dev_dataset, test=test_dataset, name=model_name, sample_missing_splits=False)
    label_dict = corpus.make_label_dictionary(label_type='class')
    tars = TARSClassifier(num_negative_labels_to_sample=4, embeddings='distilbert-base-german-cased') # roberta-large
    tars.add_and_switch_to_new_task(task_name=model_name,
        label_dictionary=label_dict,
        label_type='class',
    )
    trainer = ModelTrainer(tars, corpus)
    trainer.train(
        base_path='models/' + model_name,  # path to store the model artifacts
        learning_rate=0.01,                # use very small learning rate
        mini_batch_size=8,
        max_epochs=20
    )

Pretraining excluding  reliability
2021-11-23 22:51:21,364 Computing label dictionary. Progress:


100%|██████████| 811/811 [00:00<00:00, 1437.28it/s]

2021-11-23 22:51:21,983 Corpus contains the labels: class (#811)
2021-11-23 22:51:21,987 Created (for label 'class') Dictionary with 9 tags: <unk>, other, alternatives, environment and health issues, reactorsecurity, costs, improvement, waste, weapons





2021-11-23 22:51:37,409 TARS initialized without a task. You need to call .add_and_switch_to_new_task() before training this model
2021-11-23 22:51:37,471 ----------------------------------------------------------------------------------------------------
2021-11-23 22:51:37,481 Model: "TARSClassifier(
  (tars_model): TextClassifier(
    (loss_function): CrossEntropyLoss()
    (document_embeddings): TransformerDocumentEmbeddings(
      (model): DistilBertModel(
        (embeddings): Embeddings(
          (word_embeddings): Embedding(31102, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (transformer): Transformer(
          (layer): ModuleList(
            (0): TransformerBlock(
              (attention): MultiHeadSelfAttention(
                (dropout): Dropout(p=0.1, inplace=False)
                (q_lin): Linear(i

KeyboardInterrupt: 