In [1]:
from datasets.text_classification_dataset import MAX_TRAIN_SIZE, MAX_VAL_SIZE
import logging
import os
import random
import pickle
from argparse import ArgumentParser
from datetime import datetime

import numpy as np

import torch

from torch.utils import data

import datasets.utils
from models.cls_oml_ori_v2 import OML
from models.base_models_ori import LabelAwareReplayMemory

In [2]:
# Visiting order of the five dataset ids for each value of the `order` hyper-parameter.
# Each id is also used as the task index into `task_dict` (0=AG, 1=Amazon, 2=Yelp,
# 3=DBPedia, 4=Yahoo per the comments there).
dataset_order_mapping = dict(enumerate(
    [
        [2, 0, 3, 1, 4],  # order 1
        [3, 4, 0, 1, 2],  # order 2
        [2, 4, 1, 3, 0],  # order 3
        [0, 2, 1, 4, 3],  # order 4
    ],
    start=1,
))
# Total number of distinct label ids across all tasks (0..32 in task_dict).
n_classes = 33

In [3]:
# Hyper-parameters for this run, gathered in one dict.
args = dict(
    order=1,                 # key into dataset_order_mapping
    n_epochs=1,
    lr=3e-5,
    inner_lr=0.001 * 10,
    meta_lr=3e-5,
    model="bert",
    learner="oml",
    mini_batch_size=16,
    updates=5 * 1,           # batches grouped into one episode by the collection loop
    write_prob=1.0,
    max_length=448,
    seed=42,
    replay_rate=0.01,
    replay_every=9600,
)

# Values read directly by later cells.
updates, mini_batch_size, order, seed = (
    args[key] for key in ("updates", "mini_batch_size", "order", "seed")
)

In [4]:
## Dataset caching: each label-offset training set is pickled per dataset id so later
## runs can skip the slow get_dataset_train/offset_labels step.
use_db_cache = True
cache_dir = '/data/omler_data/tmp'

# Seed every RNG in use (torch, python `random`, numpy) before anything stochastic runs
# (presumably including the train/val split below — confirm in datasets.utils).
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

if use_db_cache:
    # Without this the first run fails on a fresh machine when writing the cache.
    os.makedirs(cache_dir, exist_ok=True)

# Load the datasets
print('Loading the datasets')
# NOTE: test_datasets is not populated in this notebook.
train_datasets, val_datasets, test_datasets = [], [], []
for dataset_id in dataset_order_mapping[order]:
    train_dataset_file = os.path.join(cache_dir, f"train-{dataset_id}.cache")
    if use_db_cache and os.path.exists(train_dataset_file):
        # SECURITY: pickle.load can execute arbitrary code; only load cache files
        # this pipeline wrote itself.
        with open(train_dataset_file, 'rb') as f:
            train_dataset = pickle.load(f)
    else:
        train_dataset = datasets.utils.get_dataset_train("", dataset_id)
        print('Loaded {}'.format(train_dataset.__class__.__name__))
        train_dataset = datasets.utils.offset_labels(train_dataset)
        if use_db_cache:
            # Dump to a temp file, then rename atomically. An interrupted dump can then
            # never leave a truncated cache that would break the next run's pickle.load.
            tmp_file = train_dataset_file + ".tmp"
            with open(tmp_file, "wb") as f:
                pickle.dump(train_dataset, f, protocol=pickle.HIGHEST_PROTOCOL)
            os.replace(tmp_file, train_dataset_file)
            print(f"Pickle saved at {train_dataset_file}")

    train_dataset, val_dataset = datasets.utils.get_train_val_split(dataset=train_dataset,
                                                                    train_size=MAX_TRAIN_SIZE,
                                                                    val_size=MAX_VAL_SIZE)
    train_datasets.append(train_dataset)
    val_datasets.append(val_dataset)
print('Finished loading all the datasets')

Loading the datasets
Finished loading all the datasets


In [5]:
def convert_to_task_class_key(task_dict):
    """Flatten a ``{task_idx: [class_idx, ...]}`` mapping into ``"task|class"`` keys.

    Keys follow the dict's iteration order and, within a task, the order of its
    class list, e.g. ``{0: [5, 6]} -> ["0|5", "0|6"]``.

    Args:
        task_dict: mapping from task index to an iterable of class indices.

    Returns:
        A list of ``f"{task_idx}|{class_idx}"`` strings.
    """
    return [
        f"{task_idx}|{class_idx}"
        for task_idx, class_list in task_dict.items()
        for class_idx in class_list
    ]

In [6]:
# Label ids owned by each task. Amazon and Yelp share label ids 0-4; the other
# datasets use disjoint ranges, giving label ids 0..32 overall.
task_dict = {
    task_idx: list(label_range)
    for task_idx, label_range in [
        (0, range(5, 9)),    # AG
        (1, range(0, 5)),    # Amazon
        (2, range(0, 5)),    # Yelp
        (3, range(9, 23)),   # DBPedia
        (4, range(23, 33)),  # Yahoo
    ]
}
task_class_list = convert_to_task_class_key(task_dict)
print(task_class_list)

['0|5', '0|6', '0|7', '0|8', '1|0', '1|1', '1|2', '1|3', '1|4', '2|0', '2|1', '2|2', '2|3', '2|4', '3|9', '3|10', '3|11', '3|12', '3|13', '3|14', '3|15', '3|16', '3|17', '3|18', '3|19', '3|20', '3|21', '3|22', '4|23', '4|24', '4|25', '4|26', '4|27', '4|28', '4|29', '4|30', '4|31', '4|32']


In [7]:
# Replay memory keyed by "task|class" (see the key listing further down).
# Use the shared config values instead of re-hard-coding them, so changing
# n_classes / args["write_prob"] in one place stays consistent (values are unchanged:
# 33 and 1.0).
memory_buffer = LabelAwareReplayMemory(
    write_prob=args["write_prob"],
    tuple_size=2,  # presumably (text, label) pairs, matching write_batch(text, labels, ...) — confirm
    n_classes=n_classes,
    validation_split=0.,
    task_dict=task_dict,
    task_aware=True,
)

In [8]:
# Stream every training dataset, in the configured order, into memory_buffer.
# Batches are consumed in groups of `updates` (one episode). A trailing group with
# fewer than `updates` batches is dropped and never written to memory.
# NOTE(review): presumably intentional, to mirror the episodic training loop; confirm.
episode_id = 0

for train_idx, train_dataset in enumerate(train_datasets):
    task_idx = dataset_order_mapping[order][train_idx]
    # Fixed: the original format string had no placeholder, so task_idx was never printed.
    print('Starting with train_idx: {} on task_idx: {}'.format(train_idx, task_idx))

    # Change to each dataset. shuffle=False keeps the batch order deterministic.
    train_dataloader = iter(data.DataLoader(data.ConcatDataset([train_dataset]), batch_size=mini_batch_size, shuffle=False,
                                            collate_fn=datasets.utils.batch_encode))

    while True:
        # Gather one episode's worth of batches.
        support_set = []
        for _ in range(updates):
            try:
                text, labels = next(train_dataloader)
            except StopIteration:
                break
            support_set.append((text, labels))

        if len(support_set) < updates:
            # Loader exhausted mid-episode: discard the partial episode and move on.
            print('Terminating training as all the data is seen')
            break

        for text, labels in support_set:
            memory_buffer.write_batch(text, labels, task_id=task_idx)

        episode_id += 1

Starting with train_idx: 0 on task_idx
Terminating training as all the data is seen
Starting with train_idx: 1 on task_idx
Terminating training as all the data is seen
Starting with train_idx: 2 on task_idx
Terminating training as all the data is seen
Starting with train_idx: 3 on task_idx
Terminating training as all the data is seen
Starting with train_idx: 4 on task_idx
Terminating training as all the data is seen


In [9]:
# Save the rebuilt replay memory next to the checkpoint at model_path
# (same path stem, "_memory_newkey.pickle" suffix). Only the memory is saved here.
model_path = "/data/model_runs/original_oml/aOML-order1-inlr010-2022-08-30-sr-query/OML-order1-id4-2022-08-30_05-21-18.854228.pt"
_model_path0 = os.path.splitext(model_path)[0]
MEMORY_SAVE_LOC = _model_path0 + "_memory_newkey.pickle"
print('Saving ER on path: {}'.format(MEMORY_SAVE_LOC))
# `with` guarantees the file is flushed and closed; the original left the handle open.
with open(MEMORY_SAVE_LOC, "wb") as f:
    pickle.dump(memory_buffer, f, protocol=pickle.HIGHEST_PROTOCOL)

Saving ER on path: /data/model_runs/original_oml/aOML-order1-inlr010-2022-08-30-sr-query/OML-order1-id4-2022-08-30_05-21-18.854228_memory_newkey.pickle


In [10]:
# Every "task|class" key from task_dict should have a bucket in memory, and nothing else.
assert set(memory_buffer.buffer_dict.keys()) == set(task_class_list), \
    "memory_buffer keys do not match task_class_list"
memory_buffer.buffer_dict.keys()

dict_keys(['2|3', '2|2', '2|0', '2|4', '2|1', '0|8', '0|6', '0|5', '0|7', '3|18', '3|19', '3|14', '3|10', '3|16', '3|12', '3|20', '3|11', '3|21', '3|17', '3|15', '3|9', '3|22', '3|13', '1|1', '1|2', '1|3', '1|4', '1|0', '4|25', '4|28', '4|31', '4|32', '4|23', '4|24', '4|30', '4|26', '4|27', '4|29'])

In [12]:
# Peek at the first stored entry for task 2 (Yelp) with label 3.
yelp_label3_entries = memory_buffer.buffer_dict["2|3"]
yelp_label3_entries[0]

"my daughter and i went there on saturday night. the place wasn't very busy, only one other table taken. we sat down and were immediately greeted by the waitress. she was very helpful, explained the menu to me, and was also very patient as i had never had indian food before. my daughter and i both got a chicken dish, chicken tikki masala and tandoori chicken megan's was a boneless chicken and mine had bones in. both were pretty spicy, which surprised me, didn't really expect that, but what did i know, it was all new to me. they were both very good, and i did enjoy them. we also got rice and naan, which was ok, and she brought us out some sauces, the best being the mint chutney. that one little sauce was my favorite thing out of the whole meal. i would spread it on toast and eat it if i could! all in all, i really enjoyed this place. if we are in the neighborhood i would definitely drop in again, not only for the food, but for the fantastic service. the service, if i could rate it separ