In [1]:
import json
import logging
import os
import shutil
from pathlib import Path

# from megatron.data.dataset_utils import get_indexed_dataset_

import horovod.torch as hvd
from dotenv import load_dotenv
import torch
import numpy as np
from torch.utils.data import DataLoader, DistributedSampler
import datasets
from huggingface_hub import hf_hub_download
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score

import sys
sys.path.append('..')
from trainer import Trainer, TrainerArgs


import transformers  # noqa: E402
from transformers import AutoConfig, AutoTokenizer, HfArgumentParser  # noqa: E402

from utils import collect_run_configuration, get_cls_by_name, get_optimizer  # noqa: E402
import optimizers  # noqa: E402

In [175]:
class holder:
    """Minimal attribute bag standing in for the argparse namespace used by the training script."""


# Instantiate rather than aliasing the class: the settings then live in the instance __dict__, so
# `vars(args)` returns just these settings (not the class mappingproxy with `__module__`,
# `__dict__`, ...), and the `holder` class itself is left unmodified.
args = holder()
args.task_name = 'quality'
args.seed = 0
args.batch_size = 2
args.data_n_workers = 1
args.target_seq_len = 10
args.gradient_accumulation_steps = 1
args.sum_loss = True
args.input_seg_size = 512
args.input_size = 512
args.bptt_depth = -1
args.model_attr = 'model'
args.source_prefix = ''
args.input_seq_len = 512

args.num_mem_tokens = 0

### Load tokenizer

In [2]:
# BART-base tokenizer: same hub id as the model config loaded in the "Load model" section.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path='facebook/bart-base')



In [3]:
# Sanity check of the EOS id (2 for this tokenizer); it is passed to the RMT wrapper as eos_token_id.
tokenizer.eos_token_id

2

### Load model

In [130]:
# List the training run directory to confirm model_best.pth exists before loading it below.
!ls ../../runs/debug/hyperpartisan_news_detection/facebook/bart-base/lr1e-05_linear_adamw_wd1e-03_512_mem0/run_1

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
config.json
events.out.tfevents.1658487760.dgx2.1487056.0
events.out.tfevents.1658492318.dgx2.1498274.0
model_best.pth


In [131]:
# Best checkpoint written by the hyperpartisan training run listed above.
cpt_path = ('../../runs/debug/hyperpartisan_news_detection/facebook/bart-base/'
            'lr1e-05_linear_adamw_wd1e-03_512_mem0/run_1/model_best.pth')

In [132]:
from modeling_rmt_enc_dec import RMTEncoderDecoderForConditionalGeneration

In [150]:
from transformers import BartForConditionalGeneration

# Weights come from our own checkpoint (applied to the wrapped model further down), so the BART
# skeleton is built from the hub config rather than via from_pretrained.
# NOTE(review): torch.load unpickles the file -- only point cpt_path at trusted checkpoints.
state_dict = torch.load(f=cpt_path, map_location='cpu')
config = AutoConfig.from_pretrained(pretrained_model_name_or_path='facebook/bart-base')
model = BartForConditionalGeneration(config=config)

In [151]:
# Dead cell: commented-out duplicate of the `!ls` run-directory listing above; safe to delete.
# !ls ../../runs/debug/hyperpartisan_news_detection/facebook/bart-base/lr1e-05_linear_adamw_wd1e-03_512_mem0/run_1

In [152]:
# Wrap the plain BART model in the RMT encoder-decoder. Guarded so that re-running this cell does
# not wrap an already-wrapped model a second time (RMT(RMT(BART))), which would also break the
# `model.model.load_state_dict(...)` call that follows.
if not isinstance(model, RMTEncoderDecoderForConditionalGeneration):
    model = RMTEncoderDecoderForConditionalGeneration(base_model=model)

In [153]:
# Load the trained weights into the wrapped backbone (`model.model`); the checkpoint dict stores them
# under 'model_state_dict'. load_state_dict is strict by default, so a key mismatch would raise.
model.model.load_state_dict(state_dict['model_state_dict'])

<All keys matched successfully>

In [154]:
# Configure the RMT wrapper from `args` (backbone_cls is not passed, so the wrapper's default applies)
# and hand it the tokenizer's special-token ids.
model.set_params(
    num_mem_tokens=args.num_mem_tokens,
    input_size=args.input_size,
    input_seg_size=args.input_seg_size,
    model_attr=args.model_attr,
    sum_loss=args.sum_loss,
    bptt_depth=args.bptt_depth,
    pad_token_id=tokenizer.pad_token_id,
    cls_token_id=tokenizer.cls_token_id,
    sep_token_id=tokenizer.sep_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

### Hyperpartisan

In [155]:
from torch.utils.data import Dataset

class HyperpartisanDataset(Dataset):
    """Hyperpartisan news examples read from a JSON-lines file.

    Every non-blank line is parsed with ``json.loads``; items are ``(text, label)`` tuples taken
    from the ``x_field`` and ``label_field`` keys of each record.

    Args:
        datafile: path to the ``.jsonl`` file (``str`` or ``pathlib.Path``).
        x_field: record key holding the input text.
        label_field: record key holding the target label.
    """

    def __init__(self, datafile, x_field='text', label_field='label'):
        if isinstance(datafile, str):
            # convert str path to Path
            datafile = Path(datafile)
        # Read inside a `with` block so the handle is closed deterministically, decode as UTF-8
        # rather than the locale default, and skip blank lines, which json.loads would reject.
        with datafile.open('r', encoding='utf-8') as f:
            self.data = [json.loads(line) for line in f if line.strip()]
        self.x_field = x_field
        self.label_field = label_field

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        return record[self.x_field], record[self.label_field]


def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace.

    Args:
        s: raw answer string.

    Returns:
        The normalized string, e.g. ``"The Cat, a Dog!" -> "cat dog"``.
    """
    # `re` and `string` are not among this notebook's imports; import them here so calling the
    # function does not raise NameError.
    import re
    import string

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))

In [156]:
# `logger` is used by collate_fn below but is not created anywhere else in this notebook.
logger = logging.getLogger(__name__)

global_attention_first_token = False  # should be True for LED
encode_plus_kwargs = {'truncation': True,
                      'padding': 'longest',
                      'pad_to_multiple_of': 1}
generate_kwargs = {'max_length': args.target_seq_len, 'min_length': args.target_seq_len}


def collate_fn(batch):
    """Tokenize a batch of ``(text, label)`` pairs for seq2seq training.

    Inputs (optionally prefixed with ``args.source_prefix``) are truncated to ``args.input_seq_len``
    tokens. Labels are tokenized as targets, truncated to ``args.target_seq_len``, and their padding
    is replaced by -100 so the loss ignores those positions.

    Returns:
        The tokenizer's ``BatchEncoding`` with an added ``'labels'`` tensor.
    """
    inputs, labels = zip(*batch)
    if args.source_prefix:
        inputs = [args.source_prefix + inp for inp in inputs]
    features = tokenizer.batch_encode_plus(list(inputs), max_length=args.input_seq_len,
                                           return_tensors='pt', **encode_plus_kwargs)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer.batch_encode_plus(list(labels), max_length=args.target_seq_len,
                                             return_tensors='pt', **encode_plus_kwargs).input_ids
    labels[labels == tokenizer.pad_token_id] = -100
    features['labels'] = labels
    if 'global_attention_mask' in features:
        # features["global_attention_mask"] = [[1] + [0] * (len(attn_mask) - 1) for attn_mask in features["attention_mask"]]
        logger.warning('WHAT SHOULD BE HERE FOR LED??')
    return features

In [157]:
# Hyperpartisan train split (JSON lines). Set HYPERPARTISAN_TRAIN_PATH to point elsewhere instead of
# editing this machine-specific absolute default.
data_path = os.environ.get('HYPERPARTISAN_TRAIN_PATH',
                           '/home/kuratov/data/hyperpartisan_news_detection/train.jsonl')

In [74]:
# Initialise Horovod; this must run before any hvd.rank() / hvd.size() call in the sampler cells.
# NOTE(review): the recorded execution counts (In [74] here vs In [45] for the dataloader cell) show
# out-of-order execution; on a fresh kernel run this cell first.
hvd.init()

In [45]:
train_dataset = HyperpartisanDataset(data_path)
# Shard across Horovod workers; DistributedSampler derives its shuffle order from seed + epoch
# (the epoch is advanced with train_sampler.set_epoch).
train_sampler = DistributedSampler(
    train_dataset,
    num_replicas=hvd.size(),
    rank=hvd.rank(),
    shuffle=True,
    seed=args.seed,
    drop_last=False)
# One loader batch per worker covers all gradient-accumulation steps.
per_worker_batch_size = args.gradient_accumulation_steps * args.batch_size
global_batch_size = hvd.size() * per_worker_batch_size
kwargs = dict(pin_memory=True, num_workers=args.data_n_workers)
train_dataloader = DataLoader(
    train_dataset,
    sampler=train_sampler,
    batch_size=per_worker_batch_size,
    collate_fn=collate_fn,
    **kwargs)

In [176]:
# Pull a single batch to inspect collate_fn output. next(iter(...)) raises StopIteration on an empty
# loader instead of silently leaving `item` unset (or stale from an earlier run).
item = next(iter(train_dataloader))

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [142]:
# Raw (text, label) pair of the first training example, used by the single-example
# encode / forward / generate cells below.
src, tgt = train_dataset[0]

In [166]:
encode_plus_kwargs = dict(truncation=True, padding='longest', pad_to_multiple_of=1)
# NOTE(review): the dict above is not used in this cell (both encodings pass options inline);
# it re-sets the same values as the collate cell.
# Source: truncated and padded to the tokenizer's model_max_length. Target: truncated, unpadded.
inp = tokenizer.batch_encode_plus(
    [src],
    return_tensors='pt',
    padding='max_length',
    truncation=True,
)
labels = tokenizer.batch_encode_plus(
    [tgt],
    return_tensors='pt',
    padding=False,
    truncation=True,
).input_ids

In [167]:
# Inspection-only forward pass. The backbone was built with BartForConditionalGeneration(config),
# which leaves it in training mode (dropout active), so switch it to eval and skip autograd
# bookkeeping.
# NOTE(review): if the RMT wrapper has dropout layers of its own, `model.eval()` would cover them too.
model.model.eval()
with torch.no_grad():
    output = model(**inp, labels=labels)

In [159]:
# Generate in eval mode: the backbone was constructed directly via BartForConditionalGeneration(config),
# which leaves it in training mode, so dropout would make the generated tokens stochastic.
model.model.eval()
generated = model.generate(**inp)

In [160]:
# Shape and token ids produced by the generate cell above (decoded in a later cell).
generated.shape, generated

(torch.Size([1, 20]),
 tensor([[   2,    0, 4148,    5,   97,  865,    6,    5,  433,   34, 1179,    5,
          3302,    9,   10, 2235, 2235, 2235,   12,    2]]))

In [122]:
# NOTE(review): stale output -- this cell's execution count (In [122]) precedes the generate cell
# (In [159]), so it shows an earlier `generated`, not the current one. Candidate for deletion.
generated.shape, generated

(torch.Size([1, 20]),
 tensor([[   2,    0,    6,   61,   16,   10, 3556,   13,    5,  232,    4,  370,
           240,   10, 3556,    9,  476,    4,   20,    2]]))

In [120]:
# NOTE(review): stale output -- this cell's execution count (In [120]) precedes the generate cell
# (In [159]), so it shows an earlier `generated`, not the current one. Candidate for deletion.
generated.shape, generated

(torch.Size([1, 20]),
 tensor([[  2,   0, 113, 170, 348, 300,   7, 120,  66,   9, 259,  60,  26,  65,
            9,   5, 917,   4,  22,   2]]))

In [161]:
# Input article text of the first training example (the displayed output appears truncated).
src

'As seen on The Five Police Group Boycotts Ben &amp; Jerry\'s After Black Lives Matter Endorsement As more women come forward with allegations of sexual misconduct against Donald Trump, the Republican nominee is lashing out at the mainstream media, which he claims is running with this "total fiction" to distract from Hillary Clinton\'s scandals. At a rally in Greensboro, N.C., Trump said these accusations are part of a conspiracy involving the Clinton campaign and what he called “the corrupt media.” Greg Gutfeld agreed with much of Trump\'s assessment on "The Five" today, saying, "The media wants Hillary to win. There\'s no doubt about that." Judge Jeanine on Trump Accusers: \'All on One Day? It\'s a Little Too Convenient\' Gutfeld explained that Trump has actually been "set up," as he was given countless hours of coverage and "billions of dollars of free media," only to have the rug pulled out from under him. "The media . [and] Hillary and Obama, they knew elevating Donald Trump - giv

In [162]:
# Its label as a string; collate_fn tokenizes this string as the seq2seq target.
tgt

'false'

In [163]:
# Decode the generated ids; special tokens are kept because skip_special_tokens is not passed.
tokenizer.batch_decode(generated)

['</s><s>On the other hand, the media has raised the possibility of a Clinton Clinton Clinton-</s>']

### QuALITY

In [82]:
# NOTE(review): displays `inp` from the Hyperpartisan single-example encoding, not QuALITY data; its
# execution count (In [82]) precedes the encoding cell (In [166]), so the output is from an earlier run.
inp

{'input_ids': tensor([[   0, 1620,  450,  ...,    1,    1,    1]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0]])}

In [24]:
# Reuse the task configured in `args` so the dataset loaded below cannot drift from it.
task_name = args.task_name

In [25]:
def collate_fn(batch):
    """Collate SCROLLS QuALITY examples for the seq2seq (BART) model used in this notebook.

    QuALITY targets (``'output'``) are answer strings compared against generated text, so they are
    tokenized as decoder labels, with padded positions set to -100 so the loss ignores them.
    Class-id label maps only apply to the encoder-only contract_nli setup.

    Args:
        batch: list of SCROLLS examples with ``'input'`` and ``'output'`` string fields.

    Returns:
        The tokenizer's ``BatchEncoding`` with an added ``'labels'`` tensor.
    """
    # cut too long strings because they may slow down tokenization
    inputs = [b['input'][:args.input_seq_len * 10] for b in batch]
    targets = [b['output'][:args.target_seq_len * 10] for b in batch]
    if args.source_prefix:
        inputs = [args.source_prefix + inp for inp in inputs]
    # assumes encode_plus_kwargs has no 'max_length' key (true for the dicts defined above)
    features = tokenizer.batch_encode_plus(inputs, max_length=args.input_seq_len,
                                           return_tensors='pt', **encode_plus_kwargs)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer.batch_encode_plus(targets, max_length=args.target_seq_len,
                                             return_tensors='pt', **encode_plus_kwargs).input_ids
    labels[labels == tokenizer.pad_token_id] = -100
    features['labels'] = labels
    return features


dataset = datasets.load_dataset('tau/scrolls', task_name)
train_dataset = dataset['train']
# shuffle train data each epoch (one loop over train_dataset)
train_sampler = DistributedSampler(train_dataset, rank=hvd.rank(), num_replicas=hvd.size(), shuffle=True,
                                   drop_last=False, seed=args.seed)

kwargs = {'pin_memory': True, 'num_workers': args.data_n_workers}
train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size * args.gradient_accumulation_steps,
                              sampler=train_sampler, collate_fn=collate_fn, **kwargs)

In [27]:
# get validation dataset
# Validation split of the same SCROLLS task; the loader itself is built in the next cell.
valid_dataloader = None
valid_dataset = dataset['validation']

In [31]:
# Shard validation like the training data. DistributedSampler requires 0 <= rank < num_replicas,
# so take both from Horovod (rank 0 of 1 in a single-process notebook).
valid_sampler = DistributedSampler(valid_dataset, rank=hvd.rank(), num_replicas=hvd.size(), shuffle=False)
valid_dataloader = DataLoader(valid_dataset, batch_size=args.batch_size, sampler=valid_sampler,
                              collate_fn=collate_fn)

In [32]:
# Inspect the first raw QuALITY training example; collate_fn reads its 'input' and 'output' fields.
train_dataset[0]

{'id': '52995_I3M5VUMM_1',
 'pid': '52995_I3M5VUMM_1_0',
 'input': 'Why is Si retirement so significant to the Space Exploration Team? \n\n (A) There aren’t enough working people in the world. They won’t be able to find a replacement.\n (B) As one of two remaining spacemen, it would likely mean the defunding and shut down of the Space Exploration Team.\n (C) Training new spacemen is costly and time consuming. They won’t have anyone else ready after him.\n (D) His retirement may inspire others to stop working as well, which would be hugely detrimental as most people don\'t feel the drive to work as is.  \n\n\nSPACEMAN ON A SPREE\n\n\n\n\n   BY MACK REYNOLDS\n\n\n\n\n   Illustrated by Nodel\n\n\n\n\n   What\'s more important—Man\'s conquest\n\n\n   of space, or one spaceman\'s life?\n\n\n\n\n\n\n   I\n\n\n\n\n   They gave him a gold watch. It was meant to be symbolical, of course.\n In the old tradition. It was in the way of an antique, being one of the\n timepieces made generations past

In [None]:
# NOTE(review): this cell was a truncated paste of the next cell -- its `if/elif` header, the
# encoder-decoder branch, and the body of the final `if` were missing, so it did not parse.
# Restored from the next cell so it runs; the next cell supersedes it and this one can be deleted.
if args.model_type == 'encoder-decoder':
    global_attention_first_token = False  # should be True for LED
    encode_plus_kwargs = {'truncation': True, 'padding': 'longest', 'pad_to_multiple_of': 1}
    generate_kwargs = {}

    def collate_fn(batch):
        # cut too long strings because they may slow down tokenization
        inputs = [b['input'][:args.input_seq_len * 10] for b in batch]
        labels = [b['output'][:args.target_seq_len * 10] for b in batch]
        if args.input_prefix:
            inputs = [args.input_prefix + inp for inp in inputs]
        features = tokenizer.batch_encode_plus(list(inputs), max_length=args.input_seq_len, return_tensors='pt',
                                               **encode_plus_kwargs)
        with tokenizer.as_target_tokenizer():
            labels = tokenizer.batch_encode_plus(list(labels), max_length=args.target_seq_len, return_tensors='pt',
                                                 **encode_plus_kwargs).input_ids
        labels[labels == tokenizer.pad_token_id] = -100
        features['labels'] = labels
        if 'global_attention_mask' in features:
            raise RuntimeError('What global attention mask for Longformer and LongformerEncoder-Decoder should be?')
        return features

elif args.model_type == 'encoder' and args.task_name == 'contract_nli':
    if args.use_generate_on_valid:
        raise RuntimeError('use_generate_on_valid should be set to False for encoder-only models')

    encode_plus_kwargs = {'max_length': args.input_seq_len,
                          'truncation': True,
                          'padding': 'longest',
                          'pad_to_multiple_of': 1}
    generate_kwargs = {}
    labels_map = {'Contradiction': 0, 'Entailment': 1, 'Not mentioned': 2}
    num_labels = len(labels_map)

    def collate_fn(batch):
        # cut too long strings because they may slow down tokenization
        inputs = [b['input'][:args.input_seq_len * 10] for b in batch]
        labels = [b['output'][:args.target_seq_len * 10] for b in batch]
        if args.input_prefix:
            inputs = [args.input_prefix + inp for inp in inputs]
        features = tokenizer.batch_encode_plus(list(inputs), return_tensors='pt', **encode_plus_kwargs)
        labels = np.array([labels_map[t] for t in labels])
        features['labels'] = torch.from_numpy(labels)
        return features

else:
    raise NotImplementedError('only encoder-decoder models are supported for scrolls datasets or '
                              'encoder models only for contract_nli task')

# get train dataset
if hvd.rank() == 0:
    logger.info(f'preparing dataset for: {args.task_name}')
dataset = datasets.load_dataset('tau/scrolls', args.task_name)
train_dataset = dataset['train']
# shuffle train data each epoch (one loop over train_dataset)
train_sampler = DistributedSampler(train_dataset, rank=hvd.rank(), num_replicas=hvd.size(), shuffle=True,
                                   drop_last=False, seed=args.seed)
per_worker_batch_size = args.batch_size * args.gradient_accumulation_steps
global_batch_size = per_worker_batch_size * hvd.size()
kwargs = {'pin_memory': True, 'num_workers': args.data_n_workers}
train_dataloader = DataLoader(train_dataset, batch_size=per_worker_batch_size, sampler=train_sampler,
                              collate_fn=collate_fn, **kwargs)
# get validation dataset
valid_dataloader = None
if hvd.rank() == 0:
    logger.info(f'preparing validation data from: {args.task_name}')
valid_dataset = dataset['validation']
valid_sampler = DistributedSampler(valid_dataset, rank=hvd.rank(), num_replicas=hvd.size(), shuffle=False)
valid_dataloader = DataLoader(valid_dataset, batch_size=per_worker_batch_size, sampler=valid_sampler,
                              collate_fn=collate_fn, **kwargs)
if args.valid_interval is None:
    args.valid_interval = args.log_interval

In [None]:
def _clip_and_prefix(batch):
    """Return ``(sources, targets)`` string lists from a batch of SCROLLS examples.

    Raw strings are cut to 10x the token budgets (very long strings slow tokenization down), and
    ``args.input_prefix`` is prepended to every source when it is set.
    """
    sources = [example['input'][:args.input_seq_len * 10] for example in batch]
    targets = [example['output'][:args.target_seq_len * 10] for example in batch]
    if args.input_prefix:
        sources = [args.input_prefix + text for text in sources]
    return sources, targets


if args.model_type == 'encoder-decoder':
    global_attention_first_token = False  # should be True for LED
    encode_plus_kwargs = {'truncation': True, 'padding': 'longest', 'pad_to_multiple_of': 1}
    generate_kwargs = {}

    def collate_fn(batch):
        """Tokenize sources and seq2seq targets; padded target positions become -100."""
        sources, targets = _clip_and_prefix(batch)
        features = tokenizer.batch_encode_plus(list(sources), max_length=args.input_seq_len, return_tensors='pt',
                                               **encode_plus_kwargs)
        with tokenizer.as_target_tokenizer():
            target_ids = tokenizer.batch_encode_plus(list(targets), max_length=args.target_seq_len,
                                                     return_tensors='pt', **encode_plus_kwargs).input_ids
        target_ids[target_ids == tokenizer.pad_token_id] = -100
        features['labels'] = target_ids
        if 'global_attention_mask' in features:
            raise RuntimeError('What global attention mask for Longformer and LongformerEncoder-Decoder should be?')
        return features

elif args.model_type == 'encoder' and args.task_name == 'contract_nli':
    if args.use_generate_on_valid:
        raise RuntimeError('use_generate_on_valid should be set to False for encoder-only models')

    encode_plus_kwargs = {'max_length': args.input_seq_len, 'truncation': True, 'padding': 'longest',
                          'pad_to_multiple_of': 1}
    generate_kwargs = {}
    # ContractNLI label strings -> class ids for the classification head
    labels_map = {'Contradiction': 0, 'Entailment': 1, 'Not mentioned': 2}
    num_labels = len(labels_map)

    def collate_fn(batch):
        """Tokenize sources and map ContractNLI label strings to class-id tensors."""
        sources, targets = _clip_and_prefix(batch)
        features = tokenizer.batch_encode_plus(list(sources), return_tensors='pt', **encode_plus_kwargs)
        features['labels'] = torch.from_numpy(np.array([labels_map[t] for t in targets]))
        return features

else:
    raise NotImplementedError('only encoder-decoder models are supported for scrolls datasets or '
                              'encoder models only for contract_nli task')

# ---- training data ----
if hvd.rank() == 0:
    logger.info(f'preparing dataset for: {args.task_name}')
dataset = datasets.load_dataset('tau/scrolls', args.task_name)
train_dataset = dataset['train']
# Per-worker shard; DistributedSampler reshuffles from seed + epoch (advanced via set_epoch).
train_sampler = DistributedSampler(
    train_dataset,
    num_replicas=hvd.size(),
    rank=hvd.rank(),
    shuffle=True,
    seed=args.seed,
    drop_last=False)
# A loader batch holds batch_size * gradient_accumulation_steps examples for this worker.
per_worker_batch_size = args.gradient_accumulation_steps * args.batch_size
global_batch_size = hvd.size() * per_worker_batch_size
kwargs = dict(pin_memory=True, num_workers=args.data_n_workers)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=per_worker_batch_size,
                              collate_fn=collate_fn, **kwargs)

# ---- validation data ----
valid_dataloader = None
if hvd.rank() == 0:
    logger.info(f'preparing validation data from: {args.task_name}')
valid_dataset = dataset['validation']
valid_sampler = DistributedSampler(valid_dataset, num_replicas=hvd.size(), rank=hvd.rank(), shuffle=False)
valid_dataloader = DataLoader(valid_dataset, sampler=valid_sampler, batch_size=per_worker_batch_size,
                              collate_fn=collate_fn, **kwargs)
# Without an explicit validation interval, validate at the logging cadence.
if args.valid_interval is None:
    args.valid_interval = args.log_interval

# define model: either from a config (random init) or from a pretrained checkpoint
model_cls = get_cls_by_name(args.model_cls)
if hvd.rank() == 0:
    logger.info(f'Using model class: {model_cls}')
if not args.from_pretrained:
    model_cfg = AutoConfig.from_pretrained(args.model_cfg)
    if args.model_type == 'encoder' and args.task_name == 'contract_nli':
        model_cfg.num_labels = num_labels
    model = model_cls(config=model_cfg)
else:
    if hvd.rank() == 0:
        logger.info(f'Loading pretrained model: {args.from_pretrained}')
    if args.model_type == 'encoder-decoder':
        model = model_cls.from_pretrained(args.from_pretrained)
    elif args.model_type == 'encoder' and args.task_name == 'contract_nli':
        model = model_cls.from_pretrained(args.from_pretrained, num_labels=num_labels)
    else:
        # Fail loudly: otherwise an unsupported combination silently keeps whatever `model` already
        # exists in the kernel (e.g. the RMT-wrapped BART built earlier in this notebook).
        raise NotImplementedError('only encoder-decoder models are supported for scrolls datasets or '
                                  'encoder models only for contract_nli task')

# Aydar: pass memory settings to the (pretrained) model; skipped when num_mem_tokens is None.
if args.num_mem_tokens is not None:
    backbone_cls = None if args.backbone_cls is None else get_cls_by_name(args.backbone_cls)
    model.set_params(
        num_mem_tokens=args.num_mem_tokens,
        input_size=args.input_size,
        input_seg_size=args.input_seg_size,
        model_attr=args.model_attr,
        backbone_cls=backbone_cls,
        sum_loss=args.sum_loss,
        bptt_depth=args.bptt_depth,
        pad_token_id=tokenizer.pad_token_id,
        cls_token_id=tokenizer.cls_token_id,
        sep_token_id=tokenizer.sep_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

# Optionally freeze the backbone: only parameters whose name contains 'classifier' keep training.
if not args.backbone_trainable:
    for name, param in model.named_parameters():
        if 'classifier' in name:
            print(f'{name} remains trainable')
            continue
        print(f'{name} is frozen')
        param.requires_grad = False

# define optimizer: the class is looked up by name (optimizers, torch.optim, transformers.optimization)
optimizer_cls = get_optimizer(args.optimizer)
if optimizer_cls is None:
    raise RuntimeError(f'{args.optimizer} was not found in optimizers, torch.optim, transformers.optimization')

if hvd.rank() == 0:
    logger.info(f'Using optimizer class: {optimizer_cls}')

# todo: group optimizer params
if optimizer_cls not in (transformers.optimization.Adafactor, optimizers.Adafactor):
    optimizer = optimizer_cls(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
else:
    # Adafactor also takes its schedule flags
    # https://github.com/huggingface/transformers/pull/9751/files -> transformers 4.3.0
    optimizer = optimizer_cls(model.parameters(), lr=args.lr, weight_decay=args.weight_decay,
                              scale_parameter=args.scale_parameter,
                              relative_step=args.relative_step,
                              warmup_init=args.warmup_init)

# for encoder only classification (and seq2seq generation outputs)
def keep_for_metrics_fn(batch, output):
    """Select the values from a batch and model output that ``metrics_fn`` needs later.

    - If the output carries ``'generation_outputs'``, keep them together with the batch labels.
    - For ``args.model_type == 'encoder'``, keep the labels plus argmax class predictions over the
      last dimension of the (detached) logits.

    Returns:
        dict with a subset of ``'labels'``, ``'generation_outputs'``, ``'predictions'``.
    """
    kept = {}
    if 'generation_outputs' in output:
        kept['labels'] = batch['labels']
        kept['generation_outputs'] = output['generation_outputs']
    if args.model_type == 'encoder':
        kept['labels'] = batch['labels']
        kept['predictions'] = output['logits'].detach().argmax(dim=-1)
    return kept

# HF datasets can compute metrics on each gpu process and then aggregate them on the process with rank 0;
# synchronization is done by using temporary files on a shared filesystem.
# Rank and number of workers are set by the num_process and process_id params.
# BUT our Trainer aggregates all predictions from all gpus!
#   this would lead to computing metrics for predictions repeated xN_GPUS times
# Options to try:
# - keep_in_memory=True, may lead to OOM for large validation sets: after sync, predictions and targets for the
#       full validation set would be stored on each GPU -> xN_GPUs RAM
#   - implemented currently
# - compute metrics on batch lvl
# - add support of HF metrics and turn off aggregation in case if metric has .add_batch method
# NOTE(review): `scrolls_metric_path` (and `task_to_metric`, read by metrics_fn) are not defined anywhere in
# this notebook and must be set before this cell runs. `datasets.load_metric` is deprecated in recent
# `datasets` releases in favour of the `evaluate` library -- confirm against the installed version.
scrolls_metric = datasets.load_metric(scrolls_metric_path, args.task_name, keep_in_memory=True)

def metrics_fn(data):
    """Compute validation metrics from the values gathered by ``keep_for_metrics_fn``.

    - encoder-decoder with ``'generation_outputs'``: -100 entries in ``data['labels']`` are replaced
      by the pad id (this modifies ``data``), labels and generations are decoded to text without
      special tokens, and ``scrolls_metric`` scores them; only the metric names listed in
      ``task_to_metric[args.task_name]`` are reported.
    - encoder + contract_nli: ``exact_match`` is accuracy x 100, ``f1_micro`` is micro-F1 (0-1 scale).

    Returns:
        dict of metric name -> value; empty when nothing applicable was collected.
    """
    metrics = {}
    y = p = None
    if args.model_type == 'encoder-decoder' and 'generation_outputs' in data:
        pad_id = tokenizer.pad_token_id
        labels = data['labels']
        if isinstance(labels, list):
            data['labels'] = [[pad_id if tok == -100 else tok for tok in seq] for seq in labels]
        else:
            labels[labels == -100] = pad_id
        y = tokenizer.batch_decode(data['labels'], skip_special_tokens=True)
        p = tokenizer.batch_decode(data['generation_outputs'], skip_special_tokens=True)
        # todo: do we need to better clean P to remove tokens after eos? not remove special tokens only
    elif args.model_type == 'encoder':
        y, p = data['labels'], data['predictions']

    if y is None or p is None:
        return metrics

    if args.model_type == 'encoder-decoder':
        result = scrolls_metric.compute(predictions=p, references=[[ref] for ref in y])
        metrics.update((name, result[name]) for name in task_to_metric[args.task_name])
    elif args.model_type == 'encoder' and args.task_name == 'contract_nli':
        metrics['exact_match'] = accuracy_score(y, p) * 100
        metrics['f1_micro'] = f1_score(y, p, average='micro')
    return metrics

# Build the Trainer with the loaders, sampler and metric hooks defined above; `generate_kwargs` is
# forwarded only when args.use_generate_on_valid is set, otherwise an empty dict is passed.
trainer = Trainer(args, model, optimizer, train_dataloader, valid_dataloader, train_sampler,
                  keep_for_metrics_fn=keep_for_metrics_fn, metrics_fn=metrics_fn,
                  generate_kwargs=generate_kwargs if args.use_generate_on_valid else {})

if not args.validate_only:
    # train loop
    trainer.train()
    # make sure all workers are done before reloading weights / validating
    hvd.barrier()
    # run validation after training, on the best saved checkpoint when save_best is enabled
    if args.save_best:
        best_model_path = str(Path(args.model_path) / 'model_best.pth')
        if hvd.rank() == 0:
            logger.info(f'Loading best saved model from {best_model_path}')
        trainer.load(best_model_path)
    if valid_dataloader is not None:
        if hvd.rank() == 0:
            logger.info('Running validation on valid data:')
        trainer.validate(valid_dataloader, write_tb=False)
else:
    # validation-only mode: evaluate the train and valid splits, do not write to tensorboard
    if hvd.rank() == 0:
        logger.info('Running validation on train set:')
    trainer.validate(train_dataloader, write_tb=False)
    if valid_dataloader is not None:
        if hvd.rank() == 0:
            logger.info('Running validation on valid data:')
        trainer.validate(valid_dataloader, write_tb=False)