In [1]:
import os
import random

import numpy as np
import torch
from datasets import load_from_disk, Dataset, concatenate_datasets, ClassLabel, Features, Value, Sequence
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from transformers import BertTokenizer, BertForSequenceClassification, BasicTokenizer

import base


[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /home/jovyan/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to /home/jovyan/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger_eng to
[nltk_data]     /home/jovyan/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger_eng is already up-to-
[nltk_data]       date!


In [2]:
# Use the first CUDA device when one is visible to torch, otherwise fall back to CPU.
use_gpu = torch.cuda.is_available()
device = torch.device("cuda" if use_gpu else "cpu")
if use_gpu:
    print("GPU is available and will be used:", torch.cuda.get_device_name(0))
else:
    print("GPU is not available, using CPU.")

GPU is available and will be used: NVIDIA A100 80GB PCIe MIG 2g.20gb


In [3]:
# Augmentation hyper-parameters passed to base.get_augmented_dataset.
# NOTE(review): the meaning of each key is defined in `base`. From the names and outputs,
# n_iter looks like the number of augmented copies (10 are produced below) and the
# p_* values look like per-token probabilities. Confirm there.
augmentation_params = {"n_iter": 10, "p_mask": 0.1, "p_pos": 0.3, "p_ng": 0.2}

# The augmentation is stochastic (each augmented copy scores differently below), so seed
# every global RNG it might draw from to make Restart & Run All reproducible.
# NOTE(review): assumes `base` draws from the global random / numpy / torch generators. Confirm.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [4]:
DATASET = "trec"  # sub-folder name used under ~/data and ~/models

# Lower-casing word-level splitter used for the length statistics and the augmentation.
# Note: the name `tokenizer` is rebound to the BERT tokenizer further down the notebook.
tokenizer = BasicTokenizer(do_lower_case=True)

In [5]:
train_data = load_from_disk(f"~/data/{DATASET}/train_coarse")

# Pull the text column in one Arrow -> Python conversion instead of formatting every
# row as a dict through map/lambda. list() keeps `sentences` a plain Python list.
sentences = list(train_data["sentence"])

In [6]:
# Number of BasicTokenizer tokens in each training sentence.
token_lengths = [len(tokens) for tokens in map(tokenizer.tokenize, sentences)]


In [7]:
# Mean sentence length plus the lengths ordered longest-first.
avg_tokens = np.mean(token_lengths)
sorted_token_lengths = sorted(token_lengths)[::-1]

In [8]:
# 25 longest lengths, the mean, then the 25 shortest lengths.
for summary in (sorted_token_lengths[:25], avg_tokens, sorted_token_lengths[-25:]):
    print(summary)

[37, 36, 33, 33, 32, 31, 31, 31, 30, 30, 29, 29, 29, 29, 28, 28, 28, 27, 27, 27, 27, 27, 27, 27, 27]
10.821600550332493
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 3, 3, 3]


In [9]:
# Lookup built from the raw training sentences and fed to the augmentation step.
# NOTE(review): presumably maps POS tags to candidate words. See base.get_pos_tag_word_map.
pos_tag_word_map_list = base.get_pos_tag_word_map(
    sentences,
    tokenizer=tokenizer,
)

In [10]:
# Schema of the raw split: sentence text plus the coarse ClassLabel.
train_features = train_data.features
print(train_features)

{'sentence': Value(dtype='string', id=None), 'label': ClassLabel(names=['ABBR', 'ENTY', 'DESC', 'HUM', 'LOC', 'NUM'], id=None)}


In [11]:
# One augmented version of the training split per augmentation iteration
# (10 are formatted in the next cell, matching n_iter).
augmented_datasets = base.get_augmented_dataset(
    augmentation_params,
    train_data,
    pos_tag_word_map_list,
    tokenizer=tokenizer,
    include_idx=False,
)

In [12]:
# Schema for the augmented splits. Reuse the training split's own ClassLabel so the
# label names and order can never drift from the source data.
ds_schema = Features({
    "sentence": Value("string"),
    "label": train_data.features["label"],
})

# Each element of `augmented_datasets` is passed to Dataset.from_dict, so it is
# presumably a column dict like {"sentence": [...], "label": [...]}. TODO confirm in base.
# The loop variable no longer shadows the builtin `iter` or leaks a global `dataset`.
aug_datasets_formated = [
    Dataset.from_dict(aug_columns).cast(ds_schema)
    for aug_columns in augmented_datasets
]

Casting the dataset:   0%|          | 0/4361 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/4361 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/4361 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/4361 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/4361 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/4361 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/4361 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/4361 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/4361 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/4361 [00:00<?, ? examples/s]

In [13]:
# Show one augmented sentence next to the original sentence at the same row index.
sample_idx = 70
for ds_to_show in (aug_datasets_formated[0], train_data):
    print(ds_to_show[sample_idx])

{'sentence': 'what body of water does the danube river flow by ?', 'label': 4}
{'sentence': 'What body of water does the Danube River flow into ?', 'label': 4}


In [14]:
# Teacher model, presumably a BERT checkpoint already fine-tuned for TREC coarse classes.
# NOTE: this rebinds `tokenizer`. From here on it is the BERT tokenizer used by
# base.prepare_dataset, no longer the BasicTokenizer from the top of the notebook.
TEACHER_CHECKPOINT = "carrassi-ni/bert-base-trec-question-classification"
tokenizer = BertTokenizer.from_pretrained(TEACHER_CHECKPOINT)
model = BertForSequenceClassification.from_pretrained(
    TEACHER_CHECKPOINT,
    # Derived from the data instead of a magic 6, so the head always matches the ClassLabel.
    num_labels=train_data.features["label"].num_classes,
)
model.to(device)
model.eval()  # inference mode: dropout disabled while generating logits

# Keep a local copy of the teacher weights. Create the directory first so torch.save
# does not fail on a fresh machine.
model_dir = os.path.join(os.path.expanduser("~"), "models", DATASET)
os.makedirs(model_dir, exist_ok=True)
torch.save(model.state_dict(), os.path.join(model_dir, "teacher_coarse.pth"))

In [None]:
# Tokenize the original training split, score it with the teacher, and keep only the
# sentence, gold label and teacher logits.
train_dataset = base.prepare_dataset(train_data, tokenizer)
# shuffle=False keeps the logits in the same order as the dataset rows.
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=False)
train_logits = base.generate_logits(train_dataloader, model)
train_dataset = (
    train_dataset
    .add_column("logits", train_logits)
    .remove_columns(["token_type_ids", "attention_mask", "input_ids"])
)
train_dataset.set_format(type="torch", columns=["logits", "labels"], device="cpu")

Generating logits for given dataset:   0%|          | 0/35 [00:00<?, ?it/s]

In [16]:
# Spot-check one scored row: gold label plus the teacher's logits.
sample_row = train_dataset[28]
print(sample_row)

{'labels': tensor(5), 'logits': tensor([-0.9769, -1.8649, -1.7079, -1.7614, -1.5043,  6.0615])}


In [17]:
# Agreement between the stored teacher logits and the gold labels on the original split.
base_acc_report = base.check_acc(train_dataset, "Accuracy for base dataset: ")
print(base_acc_report)

Calculating accuracy based on the saved logits:   0%|          | 0/4361 [00:00<?, ?it/s]

Accuracy for base dataset:  0.9867002980967668


In [None]:
# Score each augmented split with the teacher, report accuracy before and after filtering,
# and pool the surviving rows (as plain dicts) for Dataset.from_list later on.
aug_clean_datasets = []
for aug_split in tqdm(aug_datasets_formated, total=len(aug_datasets_formated), desc="Processing augmented datasets: "):
    aug_train_dataset = base.prepare_dataset(aug_split, tokenizer)
    aug_train_dataloader = DataLoader(aug_train_dataset, batch_size=128, shuffle=False)
    aug_train_logits = base.generate_logits(aug_train_dataloader, model)
    aug_train_dataset = (
        aug_train_dataset
        .add_column("logits", aug_train_logits)
        .remove_columns(["token_type_ids", "attention_mask", "input_ids"])
    )
    aug_train_dataset.set_format(type="torch", columns=["logits", "labels"], device="cpu")
    print(base.check_acc(aug_train_dataset, "Accuracy for augmented dataset: "))

    # Drop augmented rows whose prediction differs from the one on the original split.
    # NOTE(review): presumably compares row-by-row, so both datasets must stay row-aligned.
    aug_train_dataset = base.remove_diff_pred_class(train_dataset, aug_train_dataset, pytorch_dataset=False)
    print(base.check_acc(aug_train_dataset, "Accuracy for filtered dataset: "))

    aug_train_dataset.reset_format()
    # Iterating a Dataset yields one dict per row, so this pools rows, not datasets.
    aug_clean_datasets += aug_train_dataset

Processing augmented datasets:   0%|          | 0/10 [00:00<?, ?it/s]

Tokenizing the provided dataset:   0%|          | 0/4361 [00:00<?, ? examples/s]

Generating logits for given dataset:   0%|          | 0/35 [00:00<?, ?it/s]

Calculating accuracy based on the saved logits:   0%|          | 0/4361 [00:00<?, ?it/s]

Accuracy for augmented dataset:  0.7959183673469388


Removing entries from augmented dataset that are different from the base one - based on saved logits:   0%|   …

Calculating accuracy based on the saved logits:   0%|          | 0/3489 [00:00<?, ?it/s]

Accuracy for filtered dataset:  0.9899684723416452


Tokenizing the provided dataset:   0%|          | 0/4361 [00:00<?, ? examples/s]

Generating logits for given dataset:   0%|          | 0/35 [00:00<?, ?it/s]

Calculating accuracy based on the saved logits:   0%|          | 0/4361 [00:00<?, ?it/s]

Accuracy for augmented dataset:  0.7915615684476037


Removing entries from augmented dataset that are different from the base one - based on saved logits:   0%|   …

Calculating accuracy based on the saved logits:   0%|          | 0/3483 [00:00<?, ?it/s]

Accuracy for filtered dataset:  0.9879414298018949


Tokenizing the provided dataset:   0%|          | 0/4361 [00:00<?, ? examples/s]

Generating logits for given dataset:   0%|          | 0/35 [00:00<?, ?it/s]

Calculating accuracy based on the saved logits:   0%|          | 0/4361 [00:00<?, ?it/s]

Accuracy for augmented dataset:  0.7844531070855308


Removing entries from augmented dataset that are different from the base one - based on saved logits:   0%|   …

Calculating accuracy based on the saved logits:   0%|          | 0/3445 [00:00<?, ?it/s]

Accuracy for filtered dataset:  0.9892597968069666


Tokenizing the provided dataset:   0%|          | 0/4361 [00:00<?, ? examples/s]

Generating logits for given dataset:   0%|          | 0/35 [00:00<?, ?it/s]

Calculating accuracy based on the saved logits:   0%|          | 0/4361 [00:00<?, ?it/s]

Accuracy for augmented dataset:  0.7817014446227929


Removing entries from augmented dataset that are different from the base one - based on saved logits:   0%|   …

Calculating accuracy based on the saved logits:   0%|          | 0/3424 [00:00<?, ?it/s]

Accuracy for filtered dataset:  0.9903621495327103


Tokenizing the provided dataset:   0%|          | 0/4361 [00:00<?, ? examples/s]

Generating logits for given dataset:   0%|          | 0/35 [00:00<?, ?it/s]

Calculating accuracy based on the saved logits:   0%|          | 0/4361 [00:00<?, ?it/s]

Accuracy for augmented dataset:  0.7883512955744095


Removing entries from augmented dataset that are different from the base one - based on saved logits:   0%|   …

Calculating accuracy based on the saved logits:   0%|          | 0/3455 [00:00<?, ?it/s]

Accuracy for filtered dataset:  0.9904486251808973


Tokenizing the provided dataset:   0%|          | 0/4361 [00:00<?, ? examples/s]

Generating logits for given dataset:   0%|          | 0/35 [00:00<?, ?it/s]

Calculating accuracy based on the saved logits:   0%|          | 0/4361 [00:00<?, ?it/s]

Accuracy for augmented dataset:  0.7966062829626233


Removing entries from augmented dataset that are different from the base one - based on saved logits:   0%|   …

Calculating accuracy based on the saved logits:   0%|          | 0/3486 [00:00<?, ?it/s]

Accuracy for filtered dataset:  0.9911072862880091


Tokenizing the provided dataset:   0%|          | 0/4361 [00:00<?, ? examples/s]

Generating logits for given dataset:   0%|          | 0/35 [00:00<?, ?it/s]

Calculating accuracy based on the saved logits:   0%|          | 0/4361 [00:00<?, ?it/s]

Accuracy for augmented dataset:  0.7897271268057785


Removing entries from augmented dataset that are different from the base one - based on saved logits:   0%|   …

Calculating accuracy based on the saved logits:   0%|          | 0/3442 [00:00<?, ?it/s]

Accuracy for filtered dataset:  0.9927367809413132


Tokenizing the provided dataset:   0%|          | 0/4361 [00:00<?, ? examples/s]

Generating logits for given dataset:   0%|          | 0/35 [00:00<?, ?it/s]

Calculating accuracy based on the saved logits:   0%|          | 0/4361 [00:00<?, ?it/s]

Accuracy for augmented dataset:  0.7810135290071084


Removing entries from augmented dataset that are different from the base one - based on saved logits:   0%|   …

Calculating accuracy based on the saved logits:   0%|          | 0/3423 [00:00<?, ?it/s]

Accuracy for filtered dataset:  0.9894829097283085


Tokenizing the provided dataset:   0%|          | 0/4361 [00:00<?, ? examples/s]

Generating logits for given dataset:   0%|          | 0/35 [00:00<?, ?it/s]

Calculating accuracy based on the saved logits:   0%|          | 0/4361 [00:00<?, ?it/s]

Accuracy for augmented dataset:  0.7940839257051135


Removing entries from augmented dataset that are different from the base one - based on saved logits:   0%|   …

Calculating accuracy based on the saved logits:   0%|          | 0/3481 [00:00<?, ?it/s]

Accuracy for filtered dataset:  0.9899454179833381


Tokenizing the provided dataset:   0%|          | 0/4361 [00:00<?, ? examples/s]

Generating logits for given dataset:   0%|          | 0/35 [00:00<?, ?it/s]

Calculating accuracy based on the saved logits:   0%|          | 0/4361 [00:00<?, ?it/s]

Accuracy for augmented dataset:  0.7817014446227929


Removing entries from augmented dataset that are different from the base one - based on saved logits:   0%|   …

Calculating accuracy based on the saved logits:   0%|          | 0/3429 [00:00<?, ?it/s]

Accuracy for filtered dataset:  0.989501312335958


In [19]:
# Schema of the scored training split (target schema for the augmented rows).
scored_train_features = train_dataset.features
print(scored_train_features)

{'sentence': Value(dtype='string', id=None), 'labels': ClassLabel(names=['ABBR', 'ENTY', 'DESC', 'HUM', 'LOC', 'NUM'], id=None), 'logits': Sequence(feature=Value(dtype='float32', id=None), length=-1, id=None)}


In [20]:
# Target schema for the pooled augmented rows. The label feature comes from the scored
# training split, so both datasets share the exact same ClassLabel. This satisfies the
# matching-features requirement of concatenate_datasets further down without a second
# hard-coded copy of the label names.
ds_schema = Features({
    "sentence": Value("string"),
    "labels": train_dataset.features["labels"],
    "logits": Sequence(feature=Value(dtype="float32")),
})

# `aug_clean_datasets` holds one dict per augmented row that survived filtering.
# An empty list would build a column-less dataset whose cast fails obscurely.
assert aug_clean_datasets, "no augmented rows survived filtering"
aug_dataset = Dataset.from_list(aug_clean_datasets)
aug_dataset = aug_dataset.cast(ds_schema)


Casting the dataset:   0%|          | 0/34557 [00:00<?, ? examples/s]

In [21]:
# Expose only logits/labels as CPU tensors for the accuracy checks.
for scored_split in (aug_dataset, train_dataset):
    scored_split.set_format(type="torch", columns=["logits", "labels"], device="cpu")

In [22]:
# Teacher accuracy on the original split versus the filtered augmented pool.
for acc_dataset, acc_message in (
    (train_dataset, "Accuracy for base dataset: "),
    (aug_dataset, "Accuracy for augmented dataset: "),
):
    print(base.check_acc(acc_dataset, acc_message))

Calculating accuracy based on the saved logits:   0%|          | 0/4361 [00:00<?, ?it/s]

Accuracy for base dataset:  0.9867002980967668


Calculating accuracy based on the saved logits:   0%|          | 0/34557 [00:00<?, ?it/s]

Accuracy for augmented dataset:  0.9900743698816448


In [23]:
# Original rows first, followed by the filtered augmented rows.
train_all_data = concatenate_datasets([train_dataset, aug_dataset])
train_all_data.set_format(
    type="torch",
    columns=["logits", "labels"],
    device="cpu",
)

In [24]:
# Teacher accuracy over the combined (original + augmented) training data.
combined_acc = base.check_acc(train_all_data, "Accuracy for combined dataset: ")
print(combined_acc)

Calculating accuracy based on the saved logits:   0%|          | 0/38918 [00:00<?, ?it/s]

Accuracy for combined dataset:  0.9896962844956061


In [25]:
# Columns that will be written to disk.
combined_columns = train_all_data.column_names
print(combined_columns)

['sentence', 'labels', 'logits']


In [26]:
# Back to plain Python objects and all columns before saving (same as reset_format()).
train_all_data.set_format()

In [27]:
# Persist the distillation training set (original + augmented, with teacher logits).
combined_out_path = f"~/data/{DATASET}/train-logits-augmented_coarse"
train_all_data.save_to_disk(combined_out_path)

Saving the dataset (0/1 shards):   0%|          | 0/38918 [00:00<?, ? examples/s]

In [28]:
# Also keep the non-augmented scored training split on its own.
train_dataset.set_format()
train_out_path = f"~/data/{DATASET}/train-logits_coarse"
train_dataset.save_to_disk(train_out_path)

Saving the dataset (0/1 shards):   0%|          | 0/4361 [00:00<?, ? examples/s]

In [29]:
# Tokenize the validation split with the BERT tokenizer for teacher scoring.
eval_data = load_from_disk(f"~/data/{DATASET}/eval_coarse")
eval_dataset = base.prepare_dataset(eval_data, tokenizer)
eval_dataloader = DataLoader(
    eval_dataset,
    batch_size=128,
    shuffle=False,  # logits must line up with eval_dataset rows
)

In [30]:
# Tokenize the test split with the BERT tokenizer for teacher scoring.
test_data = load_from_disk(f"~/data/{DATASET}/test_coarse")
test_dataset = base.prepare_dataset(test_data, tokenizer)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=128,
    shuffle=False,  # logits must line up with test_dataset rows
)

In [None]:
# Teacher logits for validation, then test (evaluated in that order).
eval_logits, test_logits = (
    base.generate_logits(loader, model) for loader in (eval_dataloader, test_dataloader)
)

Generating logits for given dataset:   0%|          | 0/9 [00:00<?, ?it/s]

Generating logits for given dataset:   0%|          | 0/4 [00:00<?, ?it/s]

In [32]:
# Attach the teacher logits and drop the tokenizer outputs from the validation split.
eval_dataset.set_format()
eval_dataset = (
    eval_dataset
    .add_column("logits", eval_logits)
    .remove_columns(["token_type_ids", "input_ids", "attention_mask"])
)

In [33]:
# Attach the teacher logits and drop the tokenizer outputs from the test split.
test_dataset.set_format()
test_dataset = (
    test_dataset
    .add_column("logits", test_logits)
    .remove_columns(["token_type_ids", "input_ids", "attention_mask"])
)

In [34]:
# Persist the scored validation and test splits.
for split_dataset, split_path in (
    (eval_dataset, f"~/data/{DATASET}/eval-logits_coarse"),
    (test_dataset, f"~/data/{DATASET}/test-logits_coarse"),
):
    split_dataset.save_to_disk(split_path)

Saving the dataset (0/1 shards):   0%|          | 0/1091 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/500 [00:00<?, ? examples/s]

In [35]:
# Round-trip check: reload the saved splits and recompute teacher accuracy from the stored logits.
eval_data, test_data = (
    load_from_disk(f"~/data/{DATASET}/eval-logits_coarse"),
    load_from_disk(f"~/data/{DATASET}/test-logits_coarse"),
)
for reloaded_split in (eval_data, test_data):
    reloaded_split.set_format(type="torch", columns=["logits", "labels"], device="cpu")

print(base.check_acc(eval_data, "Accuracy for base eval dataset: "))
print(base.check_acc(test_data, "Accuracy for base test dataset: "))

Calculating accuracy based on the saved logits:   0%|          | 0/1091 [00:00<?, ?it/s]

Accuracy for base eval dataset:  0.9825847846012832


Calculating accuracy based on the saved logits:   0%|          | 0/500 [00:00<?, ?it/s]

Accuracy for base test dataset:  0.978
