In [7]:
import os
import platform
import sys

os.environ['TOKENIZERS_PARALLELISM'] = 'true'
if platform.system() == 'Linux':
    % load_ext autoreload
    % autoreload 2


    from sys import platform
    from google.colab import drive
    import shutil


    def mount_drive():
        """Mount the user's Google Drive under /content/drive.

        ``force_remount=False`` means an already-mounted Drive is not remounted.
        """
        mount_point = "/content/drive"
        drive.mount(mount_point, force_remount=False)


    # noinspection PyUnresolvedReferences
    def install_deps():
        ! pip install python-Levenshtein -q
        ! pip install pytorch_lightning==1.1.4 -q
        ! pip install wandb -q
        ! pip install transformers -q
        ! pip install pydantic -q
        ! pip install --upgrade --force-reinstall --no-deps kaggle -q
        ! pip install -q kaggle
        ! pip install -q datasets
        ! pip install clean-text
        ! pip install lightning-bolts
        # !pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl


    def download_kaggle():
        """Set up Kaggle credentials and fetch the competition data into ./input.

        Copies ``kaggle.json`` from Drive to ``~/.kaggle/kaggle.json`` (mode 600) if it
        is not there yet. Then, unless the dataset directory already exists, it
        downloads the coleridgeinitiative-show-us-the-data archive and unzips it into
        ``input/coleridgeinitiative-show-us-the-data``.

        Raises:
            subprocess.CalledProcessError: if the kaggle download or the unzip fails.
        """
        import subprocess

        competition = 'coleridgeinitiative-show-us-the-data'
        # os.path.exists() does not expand "~"; the original check on the literal
        # "~/.kaggle/" was therefore always False and everything re-ran every time.
        kaggle_dir = os.path.expanduser('~/.kaggle')
        kaggle_json = os.path.join(kaggle_dir, 'kaggle.json')
        ds_dir = os.path.join('input', competition)

        if not os.path.exists(kaggle_json):
            os.makedirs(kaggle_dir, exist_ok=True)
            shutil.copyfile('drive/MyDrive/data/kaggle.json', kaggle_json)
            os.chmod(kaggle_json, 0o600)

        # Already extracted: skip (unzip without -o would stop and prompt before
        # overwriting existing files).
        if os.path.isdir(ds_dir):
            return

        subprocess.run(['kaggle', 'competitions', 'download', '-c', competition], check=True)
        os.makedirs('input', exist_ok=True)
        subprocess.run(['unzip', '-q', f'{competition}.zip', '-d', ds_dir], check=True)


    # Fresh runtime (Drive not mounted yet): mount Drive, install deps, fetch the data.
    if not os.path.exists("/content/drive"):
        mount_drive()
        install_deps()
        download_kaggle()
    # Make the project checkout on Drive importable (presumably where the word_river
    # package lives). Guarded so re-running the cell does not append duplicates.
    project_root = os.path.abspath('/content/drive/MyDrive/nlp-master')
    if project_root not in sys.path:
        sys.path.append(project_root)


In [None]:
# Run-level settings: CV fold, argv shim, and the W&B credential.
import os
import sys

# Cross-validation fold index used by get_train_test_items() to pick held-out titles.
CV_F = 0
# Replace the kernel's argv with a dummy one, presumably so the argparse-based
# word_river.cli_parser parsers imported below do not choke on Jupyter's arguments.
sys.argv = ['-f ']
# Never hard-code the W&B API key in the notebook: take it from the environment or
# prompt for it. (The key previously committed here should be revoked/rotated.)
if not os.environ.get("WANDB_API_KEY"):
    import getpass
    os.environ["WANDB_API_KEY"] = getpass.getpass("W&B API key: ")

from tqdm import tqdm
from typing import List
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
import pytorch_lightning as pl
import pandas as pd
import numpy as np

from word_river.model.dataset.utils import cleaning, get_augs, get_weights
from word_river.model.pl.utils import PrintLogger
from word_river.train_data.utils import get_publications, Ranger
from word_river.cli_parser.train import model_args, training_args, wandb_args, data_args
from word_river.train_data.prepare_cval import Spliter
from word_river.model.dataset import Collator
from word_river.dtypes import Item, Publication

from pathlib import Path
import multiprocessing as mp
from datasets import load_dataset
import random, nltk
from word_river.model.architecture import MyModel
from word_river.model.pl import ModelPl

import platform

nltk.download('punkt')

# Seed every RNG the pipeline touches (torch, numpy, stdlib random).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

model_args.model_name_or_path = 'distilbert-base-cased'
wandb_args.project, wandb_args.entity = 'coleridgeinitiative', 'alexch'

data_args.max_source_length = 120
data_args.num_workers = mp.cpu_count()
data_args.GET_EVERY_PUB = 4

training_args.mixed_precision = True
training_args.lr_1 = 3e-5
training_args.lr_2 = .01
training_args.per_device_train_batch_size = 40
training_args.early_stopping_patience = 33

# Size of the evaluation subset drawn from test_items: labeled / unlabeled sentences.
len_ones = 1000
len_zer = 9000

re_write = False

# Machine-specific locations: Colab (Linux, Drive mounted) vs. the local macOS checkout.
if platform.system() == 'Linux':
    item_data = Path('/content/drive/MyDrive/data/')
    data_args.ds_dir = Path('input/coleridgeinitiative-show-us-the-data')
    training_args.output_dir = Path('drive/MyDrive/checkpoint')
    training_args.augs_csv = '/content/drive/MyDrive/data/govt2/data_set_800.csv'
    training_args.gpus = '0'
else:
    item_data = Path('/Users/alexch/tmp/input/item_folder')
    data_args.ds_dir = Path('/Users/alexch/tmp/input/coleridgeinitiative-show-us-the-data')
    training_args.output_dir = Path('/Users/alexch/tmp/input/checkpoint')
    training_args.augs_csv = '/Users/alexch/tmp/input/data_set_800.csv'
    training_args.gpus = None

# Derived cache paths are built once from item_data so every platform gets all of them
# (path_all_items used to exist only on Linux, and path_pubs always pointed at macOS).
path_train_items = item_data / 'train_items_2.pickle'
path_test_items = item_data / 'test_items_2.pickle'
path_pubs = item_data / 'pubs.csv'
path_all_items = item_data / 'all_items.csv'


_ADNI_TITLE = "Alzheimer's Disease Neuroimaging Initiative (ADNI)"


def _nan_to_none(value):
    """Map a pandas NaN cell to ``None`` and pass strings through unchanged.

    Raises:
        ValueError: for any other value; these columns should hold only str or NaN.
    """
    if isinstance(value, str):
        return value
    if isinstance(value, float) and np.isnan(value):
        return None
    raise ValueError(f'Unexpected value in dataset column: {value!r}')


def _load_items_frame() -> pd.DataFrame:
    """Return the item table, reading the zip-compressed cache at ``path_all_items``
    or, if absent, building it from publications and writing that cache."""
    import ast

    if path_all_items.exists():
        print('Loading items...')
        items_df = pd.read_csv(path_all_items, compression='zip')
        # Tuples are stored as their repr; literal_eval parses them without the
        # arbitrary-code execution that eval() allows.
        items_df['char_range'] = [ast.literal_eval(x) for x in items_df['char_range']]
        items_df['dataset_title'] = [_nan_to_none(x) for x in items_df['dataset_title']]
        items_df['dataset_label'] = [_nan_to_none(x) for x in items_df['dataset_label']]
        return items_df

    print('Creating items...')
    if path_pubs.exists():
        pubs_df = pd.read_csv(path_pubs)
        pubs_df['dataset_title'] = [ast.literal_eval(x) for x in pubs_df['dataset_title']]
        pubs_df['dataset_label'] = [ast.literal_eval(x) for x in pubs_df['dataset_label']]
    else:
        pubs_df = get_publications(pd.read_csv(data_args.ds_dir / 'train.csv'))
        with mp.Pool(mp.cpu_count()) as pool:
            pubs_df['text'] = list(tqdm(pool.imap(cleaning, pubs_df['text']),
                                        total=len(pubs_df['text']), desc='Split'))
        # noinspection PyUnresolvedReferences
        pubs_df.to_csv(path_pubs, index=False)

    pubs = [Publication(**rec) for rec in pubs_df.to_dict('records')]
    spliter = Spliter(data_args, 'sent_tokenize', ranger=Ranger(data_args, mode='get_longest'), get_first=False)
    items_df = pd.DataFrame([item.dict() for item in spliter(pubs)])
    items_df['pub_title'] = [x.strip() for x in items_df['pub_title']]
    items_df.to_csv(path_all_items, compression='zip', index=False)
    return items_df


def _filter_items(items: List[Item], adni_sample: int) -> List[Item]:
    """Keep sentences mentioning 'data'/'stud', drop weak ADNI hits, downsample ADNI."""
    items = [x for x in items if 'data' in x.text.lower() or 'stud' in x.text.lower()]

    # NOTE(review): unlike the filter above, this check is case-sensitive on item.text;
    # confirm that is intended.
    def is_weak_adni(item):
        return item.dataset_label == 'adni' and not (
                'sample' in item.text or 'data' in item.text or 'stud' in item.text)

    items = [x for x in items if not is_weak_adni(x)]

    adni = [x for x in items if x.dataset_title == _ADNI_TITLE]
    rest = [x for x in items if x.dataset_title != _ADNI_TITLE]
    # Clamp so a small corpus does not make random.sample raise ValueError.
    return random.sample(adni, min(adni_sample, len(adni))) + rest


def _conll_negative_items() -> List[Item]:
    """Wrap every CoNLL-2003 training sentence as an Item with no dataset (a negative)."""
    conll2003 = load_dataset('conll2003')
    return [
        Item(pub_title='conll2003', dataset_title=None, dataset_label=None,
             text=' '.join(tokens), char_range=(1, 1))
        for tokens in conll2003['train']['tokens']
    ]


def _split_by_title(items: List[Item], cv_fold: int, n_folds: int):
    """Split items so train and test share no dataset title and no publication."""
    # Titles ordered by frequency (value_counts sorts descending); every n_folds-th
    # title starting at cv_fold is held out.
    title_stat = pd.Series([x.dataset_title for x in items if x.dataset_title]).value_counts()
    test_titles = set(title_stat.iloc[cv_fold::n_folds].index)
    train_titles = set(title_stat.index) - test_titles

    # A publication mentioning any held-out title goes entirely to test.
    test_pubs = {x.pub_title for x in items if x.dataset_title in test_titles}
    train_pubs = {x.pub_title for x in items} - test_pubs

    train_items = [x for x in items if x.pub_title in train_pubs and x.dataset_title not in test_titles]
    test_items = [x for x in items if x.pub_title in test_pubs and x.dataset_title not in train_titles]
    return train_items, test_items


def get_train_test_items(cv_fold=None, n_folds=5, adni_sample=2500, n_unlabeled=15000):
    """Build title-disjoint train/test Item lists for one cross-validation fold.

    Args:
        cv_fold: held-out fold index; defaults to the notebook-level ``CV_F``.
        n_folds: number of title folds.
        adni_sample: max number of ADNI sentences kept.
        n_unlabeled: max number of unlabeled (no dataset title) training sentences kept
            in addition to the CoNLL-2003 negatives.

    Returns:
        (train_items, test_items). train_items = labeled train sentences + CoNLL-2003
        negatives + a random subset of unlabeled train sentences.
    """
    if cv_fold is None:
        cv_fold = CV_F

    # noinspection PyUnresolvedReferences
    all_items = [Item(**rec) for rec in _load_items_frame().to_dict('records')]
    print('len(all_items)', len(all_items))

    all_items = _filter_items(all_items, adni_sample)
    neg_items = _conll_negative_items()
    train_items, test_items = _split_by_title(all_items, cv_fold, n_folds)

    labeled = [x for x in train_items if x.dataset_title]
    unlabeled = [x for x in train_items if not x.dataset_title]
    train_items = labeled + neg_items + random.sample(unlabeled, min(n_unlabeled, len(unlabeled)))
    return train_items, test_items


In [None]:
train_items, test_items = get_train_test_items()

# Evaluation set: up to len_ones labeled and len_zer unlabeled test sentences
# (clamped so a small fold does not make random.sample raise).
test_pos = [x for x in test_items if x.dataset_title]
test_neg = [x for x in test_items if not x.dataset_title]
eval_items = (random.sample(test_pos, min(len_ones, len(test_pos)))
              + random.sample(test_neg, min(len_zer, len(test_neg))))

# noinspection PyTypeChecker
dl_test = DataLoader(eval_items,
                     collate_fn=Collator(data_args, model_args),
                     batch_size=120,
                     num_workers=data_args.num_workers,
                     shuffle=False)

# Training loader: each pass draws TRAIN_SAMPLES_PER_EPOCH items with replacement,
# weighted by get_weights (presumably to rebalance labeled vs. unlabeled items).
TRAIN_SAMPLES_PER_EPOCH = 900
dl = DataLoader(train_items,
                collate_fn=Collator(data_args,
                                    model_args,
                                    augmentation_prob=0.4,
                                    augmentation_list=get_augs(training_args, data_args)),
                sampler=WeightedRandomSampler(get_weights(train_items), TRAIN_SAMPLES_PER_EPOCH,
                                              replacement=True),
                batch_size=training_args.per_device_train_batch_size,
                num_workers=data_args.num_workers)

model = ModelPl(MyModel, model_args, data_args, training_args, None)
training_args.callbacks_monitor = 'fbeta_val'

logers = [
    WandbLogger(config={**model_args.__dict__, **data_args.__dict__, **training_args.__dict__}),
    PrintLogger(),
]

callbacks = [
    ModelCheckpoint(
        dirpath=training_args.output_dir,
        filename='model',
        monitor=training_args.callbacks_monitor,
        mode=training_args.callbacks_mode,
        verbose=True),
    EarlyStopping(
        monitor=training_args.callbacks_monitor, min_delta=0.001,
        patience=training_args.early_stopping_patience,
        verbose=True,
        mode=training_args.callbacks_mode),
    LearningRateMonitor(logging_interval="epoch"),
]

# The callbacks must be handed to the Trainer; previously they were built but unused,
# so output_dir, the monitor and early_stopping_patience had no effect.
trainer = pl.Trainer(
    precision=16 if training_args.mixed_precision else 32,
    gpus=training_args.gpus,
    logger=logers,
    callbacks=callbacks,
    num_sanity_val_steps=0,
)

# NOTE: tune() only does work when auto_lr_find / auto_scale_batch_size are enabled
# on the Trainer; neither is set here.
trainer.tune(model, dl, dl_test)
trainer.fit(model, dl, dl_test)
