In [11]:
import torch
import datasets
from datasets import load_dataset  # huggingface datasets

import utils
import config
from nltk.tokenize import sent_tokenize
import random
from tqdm import tqdm

def sentence_split(text):
    """Split raw ``text`` into a list of sentences using NLTK's ``sent_tokenize``."""
    sentences = sent_tokenize(text)
    return sentences

import os
def create_instances(dir_path, title_start="========,", ssplit=True):
    """Turn every text file found under ``dir_path`` into a labelled example.

    Each file is converted with ``process_document``; the per-document results
    are gathered column-wise.

    Args:
        dir_path: root directory searched (recursively, via
            ``utils.get_files_recursive``) for input files.
        title_start: prefix identifying section-title lines.
        ssplit: forwarded to ``process_document`` (true means the file is
            already one sentence per line).

    Returns:
        dict with parallel lists under ``'id'``, ``'labels'`` and ``'sentences'``.
    """
    paths = []
    # get_files_recursive presumably appends discovered paths into `paths`.
    utils.get_files_recursive(dir_path, paths)
    print(f"Found {len(paths)} text files.")
    columns = ('id', 'labels', 'sentences')
    res = {column: [] for column in columns}
    for path in tqdm(paths):
        doc = process_document(path, title_start=title_start, ssplit=ssplit)
        for column in columns:
            res[column].append(doc[column])
    return res


def process_document(path, title_start="========,", forbidden_start="***LIST***", ssplit=True):
    """Read one document and label each sentence as segment start (1) or not (0).

    When ``ssplit`` is true the file is taken as already sentence-split: it is
    read with ``utils.load_lines`` and lines starting with ``forbidden_start``
    are discarded. Otherwise the raw text from ``utils.load_txt_file`` is split
    with ``sentence_split``.

    Title lines (those starting with ``title_start``) are dropped. Every other
    line becomes a sentence whose label is 1 when the line directly before it
    was a title line, else 0.

    Returns:
        dict with ``'id'`` (the path), ``'labels'`` and ``'sentences'``.
    """
    if ssplit:
        lines = [ln for ln in utils.load_lines(path) if not ln.startswith(forbidden_start)]
    else:
        lines = sentence_split(utils.load_txt_file(path))

    sentences = []
    labels = []
    previous_was_title = False
    for line in lines:
        is_title = line.startswith(title_start)
        if not is_title:
            sentences.append(line)
            labels.append(1 if previous_was_title else 0)
        previous_was_title = is_title

    return {'id': path, 'labels': labels, 'sentences': sentences}

def get_training_device():
    """Choose the torch device used for training and encoding.

    Backends are probed in order (CUDA, then Apple MPS) and the first one that
    reports itself available is returned; otherwise the CPU is used.
    """
    candidates = (
        ('cuda', lambda: torch.cuda.is_available()),
        ('mps', lambda: torch.backends.mps.is_available()),
    )
    for name, is_available in candidates:
        if is_available():
            return torch.device(name)
    return torch.device('cpu')

# Device shared by the sentence encoder and (later) the segmentation model.
device = get_training_device()

dataset = load_dataset("json", data_files='./data/json/dev.json')
# Split off 10% of the 'train' split (shuffled with a fixed seed) as a
# held-out split; this does not limit how many examples are used.
split_dataset = dataset['train'].train_test_split(test_size=0.1, seed=2357, shuffle=True)
split_dataset['val'] = split_dataset.pop('test')  # rename the test split to val
from sentence_transformers import SentenceTransformer

# Sentence encoder that turns each sentence into a fixed-size embedding.
model_name = 'all-MiniLM-L6-v2'
sentence_model = SentenceTransformer(model_name, device=device)

# Per-example embedding helper (sentence-transformer encoding, no tokenization).
# Not called anywhere in the visible notebook; batched embedding is done by
# preprocess_list_data instead.
def process(example):
    """Return ``example``'s fields plus the embeddings of its sentences.

    Args:
        example: mapping with ``'id'``, ``'labels'`` and ``'sentences'``.

    Returns:
        dict with the three input fields, ``'embeddings'`` (output of
        ``sentence_model.encode`` on the sentences) and ``'len'`` (number of
        embeddings).
    """
    embeddings = sentence_model.encode(example['sentences'])
    return dict(
        id=example['id'],
        labels=example['labels'],
        sentences=example['sentences'],
        embeddings=embeddings,
        len=len(embeddings),
    )

def preprocess_list_data(data, batch_size=512):
    """Embed the sentences of every document in ``data`` with ``sentence_model``.

    Documents are processed ``batch_size`` at a time. The sentences of all
    documents in a batch are concatenated and encoded with a single
    ``sentence_model.encode`` call. The resulting matrix is then sliced back
    into one contiguous chunk per document.

    Args:
        data: a ``datasets.Dataset`` (needs ``len``, ``select`` and iteration
            over rows) with ``'id'``, ``'labels'`` and ``'sentences'`` columns.
        batch_size: number of documents whose sentences are encoded together.

    Returns:
        dict of parallel lists: ``'id'``, ``'labels'``, ``'sentences'``,
        ``'embeddings'`` (one row per sentence of that document) and ``'len'``
        (number of embedding rows for that document).
    """
    all_processed = {
        'id': [],
        'labels': [],
        'sentences': [],
        'embeddings': [],
        'len': []
    }
    # Index lists for each group of up to `batch_size` documents.
    batch_ids = [list(range(i, min(i + batch_size, len(data)))) for i in range(0, len(data), batch_size)]
    for batch_id in tqdm(batch_ids):
        batch = data.select(batch_id)
        all_sentences = []
        for example in batch:
            all_sentences.extend(example['sentences'])
        # One encoder call per batch instead of one per document.
        embs = sentence_model.encode(all_sentences)
        offset = 0
        for example in batch:
            next_offset = offset + len(example['sentences'])
            # This document owns rows [offset, next_offset) of the batch matrix;
            # next_offset is already absolute, so it must not be added to offset.
            embds = embs[offset:next_offset]
            offset = next_offset
            all_processed['id'].append(example['id'])
            all_processed['labels'].append(example['labels'])
            all_processed['sentences'].append(example['sentences'])
            all_processed['embeddings'].append(embds)
            all_processed['len'].append(len(embds))

    return all_processed


def preprocess_dataset(dataset: datasets.Dataset):
    """Embed ``dataset`` with ``preprocess_list_data`` and return a new ``datasets.Dataset``.

    The returned dataset has the columns produced by ``preprocess_list_data``
    (``id``, ``labels``, ``sentences``, ``embeddings``, ``len``).
    """
    columns = preprocess_list_data(dataset)
    return datasets.Dataset.from_dict(columns)

# Embed only the first 10 rows of the train and val splits (a small subset),
# plus every document under ./data/input as the test set.
train_dataset = preprocess_dataset(split_dataset['train'].select(range(10)))
val_dataset = preprocess_dataset(split_dataset['val'].select(range(10)))
test_dataset = preprocess_dataset(datasets.Dataset.from_dict(create_instances(dir_path='./data/input')))


# save the embedded datasets (sentences, labels and sentence embeddings) to disk
train_dataset.save_to_disk('./data/tokenized_train')
val_dataset.save_to_disk('./data/tokenized_val')
test_dataset.save_to_disk('./data/tokenized_test')

100%|██████████| 1/1 [00:02<00:00,  2.53s/it]


Saving the dataset (0/1 shards):   0%|          | 0/10 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/10 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/50 [00:00<?, ? examples/s]

Found cached dataset json (/Users/make/.cache/huggingface/datasets/json/default-00ff1c561751f0d8/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)


  0%|          | 0/1 [00:00<?, ?it/s]

Loading cached split indices for dataset at /Users/make/.cache/huggingface/datasets/json/default-00ff1c561751f0d8/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4/cache-47356b4111c2d823.arrow and /Users/make/.cache/huggingface/datasets/json/default-00ff1c561751f0d8/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4/cache-bfe10a35ef1eedaa.arrow
100%|██████████| 1/1 [00:00<00:00,  2.35it/s]
100%|██████████| 1/1 [00:00<00:00,  2.46it/s]


./data/input	0
Found 50 text files.


100%|██████████| 50/50 [00:00<00:00, 5895.51it/s]
100%|██████████| 1/1 [00:02<00:00,  2.52s/it]


Saving the dataset (0/1 shards):   0%|          | 0/10 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/10 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/50 [00:00<?, ? examples/s]

In [None]:
"""
Convert sentence embeddings and labels for trainable blocks
"""
import config
import random

# Embedding of config.fake_sent; get_blocks pads windows that run past the end
# of a document with (fake_sent_embedding, 0).
fake_sent_embedding = sentence_model.encode([config.fake_sent])[0]
# Dimensionality of one sentence embedding (sizes the numpy records and the model input).
embedding_size = sentence_model.get_sentence_embedding_dimension()


def create_fake_block(block, lines):
    """Return a shuffled and possibly corrupted copy of ``block``.

    The copy is always shuffled. Then, with probability 0.5, it is corrupted:
    each position is independently replaced, with probability 0.5, by
    ``(random element of lines, 0)``. ``block`` itself is left unchanged.
    """
    fake = block.copy()
    random.shuffle(fake)
    if random.random() < 0.5:
        # No corruption this time: only the shuffle applies.
        return fake
    for idx in range(len(fake)):
        if random.random() >= 0.5:
            replacement = lines[random.randint(0, len(lines) - 1)]
            fake[idx] = (replacement, 0)
    return fake


def create_one_instance(block, lines):
    """Pair every entry of ``block`` with the same position of a fake copy.

    The fake copy comes from ``create_fake_block(block, lines)``.

    Returns:
        list of ``(real_item, fake_item, real_label)`` triples, one per
        position. The fake copy's labels are not used.
    """
    fake_block = create_fake_block(block, lines)
    return [(real[0], fake[0], real[1]) for real, fake in zip(block, fake_block)]


def get_blocks(dataset, test=False):
    """Cut each document into fixed-size sentence windows and expand them into records.

    For every document, a window of ``config.sent_window`` consecutive
    ``(embedding, label)`` pairs starts every ``config.sent_stride`` sentences
    (every sentence when ``test`` is true). A window that runs past the end of
    the document is padded with ``(fake_sent_embedding, 0)``. Outside test
    mode, only the first ``config.perc_blocks_train`` fraction of a document's
    windows is kept, and those windows are shuffled. Each window is then
    expanded with ``create_one_instance``, using the document's embeddings as
    the replacement pool.

    Returns:
        flat list of ``(real_embedding, fake_embedding, label)`` triples.
    """
    records = []
    for doc in dataset:
        doc_labels = doc['labels']
        doc_embds = doc['embeddings']
        step = 1 if test else config.sent_stride
        windows = []
        start = 0
        while start < len(doc_labels):
            size = config.sent_window
            window = list(zip(doc_embds[start:start + size], doc_labels[start:start + size]))
            missing = size - len(window)
            if missing > 0:
                window.extend([(fake_sent_embedding, 0)] * missing)
            windows.append(window)
            start += step

        if not test:
            keep = int(config.perc_blocks_train * len(windows))
            windows = windows[:keep]
            random.shuffle(windows)

        for window in windows:
            records.extend(create_one_instance(window, doc_embds))
    return records


# Training windows use config.sent_stride and are truncated/shuffled per
# document; validation windows use stride 1 without truncation or shuffling.
train_blocks = get_blocks(train_dataset)
val_blocks = get_blocks(val_dataset, test=True)

In [None]:
import numpy as np
import os
from tqdm import tqdm

# One record per sentence position: the real sentence embedding, the embedding
# at the same position of the corrupted ("fake") window, and the real label.
dtype = np.dtype([('real', np.float32, embedding_size), ('fake', np.float32, embedding_size), ('label', np.int8)])
# Directory holding the binary record files.
data_dir = "./data/processed"


def save_to_numpy(blocks, filename):
    """Write ``blocks`` to ``filename`` as a flat binary array of ``dtype`` records.

    Args:
        blocks: sequence of ``(real_embedding, fake_embedding, label)`` tuples,
            such as the output of ``get_blocks``.
        filename: destination path. Missing parent directories are created.
            An empty ``blocks`` produces an empty file, because ``np.memmap``
            cannot map a zero-length file.
    """
    directory = os.path.dirname(filename)
    # A bare filename has dirname '' and needs no directory (os.makedirs('')
    # raises); exist_ok avoids the check-then-create race.
    if directory:
        os.makedirs(directory, exist_ok=True)

    if len(blocks) == 0:
        with open(filename, 'wb'):
            pass
        return

    arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(len(blocks),))

    # Convert and write in chunks to bound the size of temporary arrays.
    batch_size = 1024
    for start in tqdm(range(0, len(blocks), batch_size), desc=f'writing {filename}'):
        batch = blocks[start:start + batch_size]
        arr[start:start + len(batch)] = np.array(batch, dtype=dtype)
    arr.flush()


save_to_numpy(train_blocks, os.path.join(data_dir, 'train.bin'))
save_to_numpy(val_blocks, os.path.join(data_dir, 'val.bin'))

In [None]:
# load the record files from disk as read-only, 1-D memory-mapped arrays of `dtype`
train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=dtype, mode='r')
val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=dtype, mode='r')

In [None]:
import torch
from datasets import load_dataset

batch_size = 12  # number of windows per training/eval batch
block_size = config.sent_window  # sentences per window (records per window on disk)

# `device` is a torch.device, which does not support `in` (TypeError); test its
# string form instead, which also works if `device` is a plain string.
device_type = 'cuda' if 'cuda' in str(device) else 'cpu'  # for later use in torch.autocast

train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=dtype, mode='r')
val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=dtype, mode='r')


def get_batch(split):
    """Sample ``batch_size`` random windows from the memory-mapped records.

    ``get_blocks`` pads every window to ``config.sent_window`` (== ``block_size``)
    entries and writes them back to back. Windows therefore start at multiples
    of ``block_size``, and sampling only those offsets keeps each batch row
    inside a single window.

    Args:
        split: ``'train'`` selects ``train_data``; anything else ``val_data``.

    Returns:
        ``(x, y)`` on ``device``. ``x`` is float32 with shape
        ``(batch_size, 2, block_size, embedding_size)``, holding the real and
        fake channels. ``y`` has shape ``(batch_size, block_size)`` and holds
        the int8 labels.

    Raises:
        ValueError: if the split holds fewer than ``block_size`` records.
    """
    data = train_data if split == 'train' else val_data
    num_windows = len(data) // block_size
    if num_windows == 0:
        raise ValueError(
            f"split {split!r} has {len(data)} records, fewer than one window of {block_size}")
    starts = (torch.randint(num_windows, (batch_size,)) * block_size).tolist()
    x = []
    y = []
    for i in starts:
        window = np.copy(data[i:i + block_size])
        # Each record is 8 * embedding_size + 1 bytes, so field views have a
        # stride that is not a multiple of 4; torch.from_numpy rejects that,
        # hence the contiguous copies.
        real = torch.from_numpy(np.ascontiguousarray(window['real']))
        fake = torch.from_numpy(np.ascontiguousarray(window['fake']))
        x.append(torch.stack([real, fake]))
        y.append(torch.from_numpy(np.ascontiguousarray(window['label'])))
    x = torch.stack(x)
    y = torch.stack(y)
    if device_type == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y

In [None]:
# Smoke-test get_batch: x stacks the (real, fake) channels per window,
# y holds one label per sentence position.
sample_batch = get_batch('train')
assert sample_batch[0].shape == (batch_size, 2, block_size, embedding_size)
assert sample_batch[1].shape == (batch_size, block_size)

In [None]:
# create model
import torch
from tqdm import tqdm
from transformers import AdamW


class TopicSegmentationModel(torch.nn.Module):
    """Two-layer MLP that outputs one logit per sentence position of a window.

    Input: ``(batch, 2, window_size, embed_dim)``, with channel 0 holding the
    real sentence embeddings and channel 1 the fake ones. Only the real channel
    is used; it is flattened per example and mapped to ``window_size`` logits.

    Args:
        window_size: sentences per window; defaults to the global ``block_size``.
        embed_dim: embedding dimension; defaults to the global ``embedding_size``.
    """

    def __init__(self, window_size=None, embed_dim=None):
        super().__init__()
        # Fall back to the notebook globals so `TopicSegmentationModel()` keeps working.
        if window_size is None:
            window_size = block_size
        if embed_dim is None:
            embed_dim = embedding_size

        self.linear1 = torch.nn.Linear(window_size * embed_dim, window_size)
        self.activation = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(window_size, window_size)

    def forward(self, x):
        # Flatten with the actual batch dimension (not the global batch_size)
        # so partial batches and inference inputs of any size work.
        x_real = x[:, 0, :, :].reshape(x.size(0), -1)
        x = self.linear1(x_real)
        x = self.activation(x)
        x = self.linear2(x)
        return x


eval_iters = 10  # batches per split averaged by estimate_loss()
eval_interval = 10  # evaluate and checkpoint every `eval_interval` epochs
num_epochs = 1
best_val_loss = 1e9
ckpt_path = './model_checkpoints/ckpt.pt'
num_batches = len(train_data) // batch_size  # optimizer steps per epoch
always_save_checkpoint = True
init_from = "scratch"  # "scratch" starts fresh; any other value resumes from ckpt_path
epoch_offset = 0
# Per-position logits are compared with the 0/1 segment-start labels.
loss_fn = torch.nn.BCEWithLogitsLoss()
model = TopicSegmentationModel()
# get_batch() moves every batch to `device`; the model must live on the same device.
model.to(device)
sentence_model.to(device)

if init_from == 'scratch':
    # init a new model from scratch
    print("Initializing a new model from scratch")
    # Optimize the segmentation head. The sentence embeddings are precomputed,
    # so sentence_model gets no gradients from this loss and optimizing its
    # parameters would leave `model` untrained.
    optimizer = AdamW(model.parameters(), lr=1e-5)
else:
    print(f"Resuming training from {ckpt_path}")
    # resume training from a checkpoint.
    checkpoint = torch.load(ckpt_path, map_location=device)
    state_dict = checkpoint['model']
    # Some checkpoints carry an '_orig_mod.' key prefix; strip it so the keys
    # match this module's parameter names.
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)

    optimizer = AdamW(model.parameters(), lr=1e-5)
    optimizer.load_state_dict(checkpoint['optimizer'])

    epoch_offset = checkpoint['epoch_offset']
    iter_num = checkpoint['iter_num']
    best_val_loss = checkpoint['best_val_loss']


@torch.no_grad()
def estimate_loss():
    """Average ``loss_fn`` over ``eval_iters`` random batches of each split.

    ``model`` is put into eval mode while measuring, and its previous
    train/eval mode is restored afterwards.

    Returns:
        dict mapping ``'train'`` and ``'val'`` to a 0-dim tensor holding the
        mean loss.
    """
    out = {}
    was_training = model.training
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            x, y = get_batch(split)
            y_hat = model(x)
            losses[k] = loss_fn(y_hat, y.float()).item()
        out[split] = losses.mean()
    model.train(was_training)
    return out


for epoch in range(epoch_offset, epoch_offset + num_epochs):
    # Every `eval_interval` epochs, before training that epoch: measure the
    # train/val loss and write a checkpoint.
    if epoch % eval_interval == 0:
        losses = estimate_loss()
        if losses['val'] < best_val_loss or always_save_checkpoint:
            best_val_loss = losses['val']
            print(f"epoch {epoch} train loss {losses['train']:.4f} val loss {losses['val']:.4f}")
            checkpoint = {
                # Save the trainable segmentation head: the resume path loads
                # checkpoint['model'] into `model`.
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'iter_num': epoch,
                'best_val_loss': best_val_loss,
                'epoch_offset': epoch,
            }
            print(f"saving checkpoint to {ckpt_path}")
            # torch.save does not create missing parent directories.
            ckpt_dir = os.path.dirname(ckpt_path)
            if ckpt_dir:
                os.makedirs(ckpt_dir, exist_ok=True)
            torch.save(checkpoint, ckpt_path)

    # One epoch = `num_batches` randomly sampled training batches.
    model.train()
    train_loss = 0.0
    for _ in range(num_batches):
        x, y = get_batch('train')
        optimizer.zero_grad()
        y_hat = model(x)
        loss_value = loss_fn(y_hat, y.float())
        train_loss += loss_value.item()
        loss_value.backward()
        optimizer.step()
    if num_batches > 0:
        print(f"epoch {epoch} mean train loss {train_loss / num_batches:.4f}")


In [None]:
def predict(dataset, output_dir):
    """Score sentences with ``model`` and write one ``.seg`` file per document.

    NOTE(review): this does not run as written in the visible source:
      * ``model.forward(dataset)`` receives the ``dataset`` argument unchanged.
        ``TopicSegmentationModel.forward`` slices ``x[:, 0, :, :]``, so the
        ``datasets.Dataset`` passed at the call site below is presumably not a
        valid input. The windows likely need to be built first and stacked
        into a tensor. TODO confirm.
      * ``test_texts`` and ``write_pred_score`` are not defined anywhere in the
        visible source, so their first use raises ``NameError``.
      * ``preds_blocks[b_ind][relb_ind][1]`` reads a second value per sentence
        position, while the model emits one logit per position. Verify the
        expected prediction layout.

    Intended flow, as far as the code shows: ``test_texts[0]`` holds
    ``(filename, blocks)`` pairs, where each block is a window of entries whose
    first field is presumably the sentence text. Scores from every window
    covering a sentence are averaged. A ``config.seg_start`` line is written
    before each sentence whose mean score is at least the threshold.

    Args:
        dataset: input passed straight to ``model.forward``.
        output_dir: directory receiving ``<document name>.seg`` files, written
            with ``utils.write_list``.
    """
    # NOTE(review): `dataset` is given to the model as-is; see the docstring.
    res = model.forward(dataset)

    print("Documents to segment: " + str(len(test_texts[0])))
    # Flatten all windows; only used for the count printed below.
    flat_blocks = []
    for x in test_texts[0]:
        print(len(x[1]))
        flat_blocks.extend(x[1])

    print("Number of prediction blocks: " + str(len(flat_blocks)))

    print("Predicting with the model (this may take a while, depending on the number of documents)...")
    # Materialize predictions so they can be sliced per document below.
    res_list = list(res)
    print("Predictions completed.")

    # Decision threshold on the averaged per-sentence score.
    thold = 0.3 if config.MODEL_TYPE == "cats" else 0.5

    # Running index into res_list: predictions are consumed document by document.
    glob_cntr = 0
    docs = test_texts[0]

    agg_docs = []

    for i in range(len(docs)):
        fname = docs[i][0]
        if i % 1000 == 1:
            print(fname)
            print(str(i) + " of " + str(len(docs)) + " documents...")
        blocks = docs[i][1]
        preds_blocks = res_list[glob_cntr: glob_cntr + len(blocks)]
        glob_cntr += len(blocks)

        # One entry per window, seeded from the window's first entry:
        # (first field, second field, scores collected for that sentence).
        sent_scores = [(b[0][0], b[0][1], []) for b in blocks]
        for b_ind in range(len(blocks)):
            for relb_ind in range(len(blocks[b_ind])):
                # Stop at the first config.fake_sent entry (presumably padding).
                if blocks[b_ind][relb_ind][0] == config.fake_sent:
                    break
                else:
                    # Assumes stride-1 windows: window b_ind starts at sentence b_ind.
                    sent_ind = b_ind + relb_ind
                    score = preds_blocks[b_ind][relb_ind][1]
                    sent_scores[sent_ind][2].append(score)
        # (first field, second field, mean score, 0/1 decision at `thold`)
        agg_sent_scores = [(x[0], x[1], np.mean(x[2]), (1 if np.mean(x[2]) >= thold else 0)) for x in sent_scores]
        agg_docs.append(agg_sent_scores)

    # printing out predictions
    docnames = [x[0] for x in docs]
    print("Storing segmented texts...")
    docscores = zip(docnames, agg_docs)
    for name, sentscores in docscores:
        print("Document: " + name)
        lines = []
        for s in sentscores:
            # Segment marker before each sentence predicted to start a segment.
            if s[2] >= thold:
                lines.append(config.seg_start)
            lines.append(s[0] + "\t" + str(s[2]) if write_pred_score else s[0])
        utils.write_list(os.path.join(output_dir, name.split("/")[-1] + ".seg"), lines)
    print("Stored.")

# NOTE(review): this call fails as written (undefined `test_texts`, and a
# datasets.Dataset handed to the model); see predict's docstring.
predict(test_dataset, "./data/output")