# Imports

In [None]:
!pip install transformers
!pip install gradio
!pip install torch
!pip install gdown

In [None]:
# Download a file from Google Drive by id into the working directory.
# NOTE(review): presumably this is the `debatefinal.csv` read by the dataset
# cell further down - confirm the downloaded file name matches.
!gdown https://drive.google.com/uc?id=1EAjsKML35-b6TseCmmgfkGIq9uqc_oEI

In [None]:
import torch
from transformers import LongformerForTokenClassification, LongformerTokenizer
from torch.nn import CrossEntropyLoss
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
import csv
import sys

In [None]:
# use this if on colab and loading the data from google drive
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Number of GPUs to use; a value above 1 makes a later cell wrap the model in
# torch.nn.DataParallel.
NUM_GPUS = 1

In [None]:
# Indices of the encoder layers that freeze_model() leaves trainable
# (range(9, 12) -> layers 9, 10 and 11).
unfreeze_range = range(9, 12)
# Learning rate handed to the Adam optimizer.
learning_rate = 0.001

# Model

In [None]:
# Both the model weights and the tokenizer vocabulary come from the same
# pretrained checkpoint, so name it once.
PRETRAINED_NAME = 'allenai/longformer-base-4096'

tokenizer = LongformerTokenizer.from_pretrained(PRETRAINED_NAME)
# Load the token-classification model and move it to the selected device.
longformer = LongformerForTokenClassification.from_pretrained(PRETRAINED_NAME).to(device)

In [None]:
def freeze_model(model, layers=None):
  """Freeze `model` except for its classifier head and selected encoder layers.

  Every parameter is first marked as not requiring gradients; then the
  parameters of `model.classifier` and of `model.longformer.encoder.layer[i]`
  for each `i` in `layers` are switched back on.

  Args:
    model: a module exposing `.classifier` and `.longformer.encoder.layer`.
    layers: iterable of encoder layer indices to keep trainable. When omitted,
      the notebook-level `unfreeze_range` is used (the original behaviour).

  Returns:
    The same `model` object, modified in place.
  """
  if layers is None:
    layers = unfreeze_range

  for param in model.parameters():
    param.requires_grad = False

  trainable_modules = [model.classifier]
  trainable_modules.extend(model.longformer.encoder.layer[i] for i in layers)
  for module in trainable_modules:
    for param in module.parameters():
      param.requires_grad = True
  return model
  
longformer = freeze_model(longformer)

# Build the optimizer while `longformer` is still the bare model: once it is
# wrapped in DataParallel the submodules live under `.module`, so
# `longformer.classifier` would raise AttributeError. The wrapper shares the
# same parameter tensors, so the optimizer keeps working after wrapping.
#
# Only the classification head is optimized.
# NOTE(review): freeze_model() also leaves the encoder layers in
# `unfreeze_range` with requires_grad=True, but their parameters are not passed
# to the optimizer, so they receive gradients and are never updated - confirm
# whether they should be trained or stay frozen.
optimizer = torch.optim.Adam(longformer.classifier.parameters(), lr=learning_rate)

if NUM_GPUS > 1:
  longformer = torch.nn.DataParallel(longformer, list(range(NUM_GPUS)))

# Dataset

In [None]:
# Rows can hold very long text fields; lift the csv module's field size cap.
csv.field_size_limit(sys.maxsize)
class DebateDataset(Dataset):
    """Token-classification examples read from a CSV file.

    Each CSV row is `abstract, label, text, label, text, ...`: the first field
    is the abstract and the remaining fields alternate between an integer
    label and the text span that label applies to. Newlines inside fields are
    replaced by spaces when the file is loaded.

    Args:
        csv_file: path of the CSV file to read.
        tokenizer: callable returning a mapping with an 'input_ids' list, and
            providing `build_inputs_with_special_tokens(ids_a, ids_b)`.
        max_len: maximum length of the combined sequence. The arithmetic below
            assumes the tokenizer adds 4 special tokens to a pair (one before
            the abstract, two between the segments, one at the end).
        max_count: maximum number of rows to load; -1 loads every row.
    """

    def __init__(self, csv_file, tokenizer, max_len, max_count=-1):
        self.df = []
        with open(csv_file, 'r') as f:
            for count, row in enumerate(csv.reader(f), start=1):
                if max_count != -1 and count > max_count:
                    break
                self.df.append([field.replace('\n', ' ') for field in row])
        # Preview the first rows; slicing keeps this safe for files with fewer
        # than two rows.
        for sample in self.df[:2]:
            print(sample)

        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of loaded rows."""
        return len(self.df)

    def __getitem__(self, index):
        """Build the tensors for row `index`.

        Returns:
            A 5-tuple of equally long 1-D tensors:
            (input ids, attention mask, global attention mask, labels,
            loss mask). Global attention covers the first `len(abstract) + 2`
            positions; the loss mask covers the remaining positions.
        """
        row = self.df[index]
        # Use the tokenizer handed to the dataset, not a notebook-level global.
        abstract = self.tokenizer(row[0], add_special_tokens=False)['input_ids']

        tokens = []
        token_labels = []
        # Fields after the abstract come in (label, text) pairs; stopping at
        # len(row) - 1 ignores a trailing label that has no text.
        for i in range(1, len(row) - 1, 2):
            tokenized = self.tokenizer(row[i + 1], add_special_tokens=False)['input_ids']
            token_labels += [int(row[i])] * len(tokenized)
            tokens += tokenized

        # Room left for text tokens once the abstract and the 4 special tokens
        # are accounted for. Clamping at 0 avoids a negative slice bound when
        # the abstract alone is too long; in that case the sequence exceeds
        # max_len but all five tensors still have matching lengths.
        budget = max(self.max_len - (4 + len(abstract)), 0)
        tokens = tokens[:budget]
        token_labels = token_labels[:budget]

        combined = self.tokenizer.build_inputs_with_special_tokens(abstract, tokens)
        # Label 0 for the abstract and the three special tokens around it, and
        # for the closing special token.
        labels = [0] * (len(abstract) + 3) + token_labels + [0]

        # tokens, attention_mask, global_attention_mask, labels, loss_mask
        ret = (
             torch.tensor(combined, dtype=torch.int),
             torch.tensor([1] * len(combined), dtype=torch.bool),
             torch.tensor([1] * (len(abstract)+2) + [0] * (len(tokens) + 2), dtype=torch.bool),
             torch.tensor(labels),
             torch.tensor([0] * (len(abstract)+2) + [1] * (len(tokens) + 2), dtype=torch.bool)
        )
        # Diagnostic: report the example if the tensors disagree in shape.
        if len({tensor.shape for tensor in ret}) != 1:
            print(ret)
            for tensor in ret:
                print(tensor.shape)
        return ret


In [None]:
def collate(sequences):
  """Batch a list of per-example 5-tuples into five padded tensors.

  `sequences` holds tuples ordered as (tokens, attention_mask,
  global_attention_mask, labels, loss_mask). Each field is stacked across the
  batch with batch_first layout and right-padded with 0 to the longest example.

  Returns:
    (tokens, attention_mask, global_attention_mask, labels, loss_mask) batches,
    in that order.
  """
  def pad(tensors):
    return torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True, padding_value=0)

  tokens, attention_mask, global_attention_mask, labels, loss_mask = zip(*sequences)
  return (
    pad(tokens),
    pad(attention_mask),
    pad(global_attention_mask),
    pad(labels),
    pad(loss_mask),
  )

In [None]:
# data_loader = torch.utils.data.DataLoader(dataset, batch_size = 2, shuffle = True, collate_fn=collate, num_workers=1)
# Maximum combined sequence length (abstract + text + special tokens) given to
# DebateDataset; the Gradio handler rejects longer inputs using the same value.
max_len = 2500

In [None]:
from torch.utils.data.sampler import SubsetRandomSampler

# Load at most the first 2000 rows of the CSV.
dataset = DebateDataset('./debatefinal.csv', tokenizer, max_len, max_count=2000)

# Split / loader settings.
batch_size = 2
test_split = .1
valid_split = .1
shuffle = True  # not referenced by the loaders below; ordering comes from the samplers
random_seed= 42

# Sizes of the test and validation partitions (fractions of the dataset, floored).
dataset_size = len(dataset)
test_index = int(np.floor(test_split * dataset_size))
valid_index = int(np.floor(valid_split * dataset_size))

# Shuffle the example indices reproducibly, then carve them into
# [test | validation | train].
np.random.seed(random_seed)
indices = list(range(dataset_size))
np.random.shuffle(indices)

test_indices = indices[:test_index]
val_indices = indices[test_index:test_index + valid_index]
train_indices = indices[test_index + valid_index:]

# One sampler per partition; each draws only from its own index list.
train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)
valid_sampler = SubsetRandomSampler(val_indices)


def _make_loader(sampler):
  """Build a DataLoader over `dataset` restricted to `sampler`'s indices."""
  return torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    collate_fn=collate,
    num_workers=1,
    sampler=sampler,
  )


train_loader = _make_loader(train_sampler)
test_loader = _make_loader(test_sampler)
valid_loader = _make_loader(valid_sampler)


# Train functions

In [None]:
cross_entropy = CrossEntropyLoss()
def loss_fn(logits, labels, mask):
  """Cross-entropy over the positions selected by `mask`.

  The first two dimensions of `logits` are merged, `labels` and `mask` are
  flattened completely, and only the entries where `mask` is set are passed to
  the module-level `cross_entropy` criterion.
  """
  keep = mask.flatten()
  selected_logits = logits.flatten(0, 1)[keep]
  selected_labels = labels.flatten()[keep]
  return cross_entropy(selected_logits, selected_labels)

In [None]:
def accuracy_fn(logits, labels, mask):
    """Count the masked positions whose highest-scoring class equals the label.

    `logits` has its first two dimensions merged and `labels` / `mask` are
    flattened; positions where `mask` is not set are ignored. Returns the
    number of correct predictions as a Python int (not a ratio).
    """
    keep = mask.flatten()
    predicted = torch.max(logits.flatten(0, 1)[keep], dim=1).indices
    expected = labels.flatten()[keep]
    return (predicted == expected).sum().item()

In [None]:
def valid_fn():
  """Evaluate `longformer` on `valid_loader`.

  Puts the model in eval mode (it is left in eval mode on return) and runs the
  whole validation loader without tracking gradients.

  Returns:
    (mean loss per scored token, fraction of scored tokens predicted
    correctly). Both are NaN when the loader yields no scored tokens.
  """
  longformer.eval()
  total_loss = 0.0
  total_correct = 0
  total_count = 0
  with torch.no_grad():
    for tokens, attention_mask, global_attention_mask, labels, loss_mask in valid_loader:
      tokens = tokens.to(device)
      attention_mask = attention_mask.to(device)
      global_attention_mask = global_attention_mask.to(device)
      labels = labels.to(device)
      loss_mask = loss_mask.to(device)

      logits = longformer(
        tokens,
        attention_mask = attention_mask,
        global_attention_mask = global_attention_mask
      )['logits']

      batch_count = loss_mask.sum().item()
      if batch_count == 0:
        # Nothing to score in this batch; loss_fn would be undefined here.
        continue
      # loss_fn already returns the mean over the batch's scored tokens, so
      # weight it by that token count; the final division then gives a true
      # per-token mean instead of normalizing twice.
      total_loss += loss_fn(logits, labels, loss_mask).item() * batch_count
      total_correct += accuracy_fn(logits, labels, loss_mask)
      total_count += batch_count

  if total_count == 0:
    return float('nan'), float('nan')
  return total_loss/total_count, total_correct/total_count

      
      


# Train Loop

In [None]:
# Per-batch training history, appended to on every step.
losses = []
accuracies = []
epochs = 10
for epoch in range(epochs):
  for i, (tokens, attention_mask, global_attention_mask, labels, loss_mask) in enumerate(train_loader):
    # valid_fn() switches the model to eval mode, so re-enable training mode
    # on every step.
    longformer.train()
    # resets the gradients in the tensors
    optimizer.zero_grad()

    tokens = tokens.to(device)
    attention_mask = attention_mask.to(device)
    global_attention_mask = global_attention_mask.to(device)
    labels = labels.to(device)
    loss_mask = loss_mask.to(device)

    logits = longformer(
      tokens,
      attention_mask = attention_mask,
      global_attention_mask = global_attention_mask
    )['logits']

    loss = loss_fn(logits, labels, loss_mask)
    scored_tokens = loss_mask.sum().item()
    # loss_fn already averages over the scored tokens, so record it as is;
    # dividing by the token count again would make the value depend on the
    # batch's length.
    losses.append(loss.item())
    accuracies.append(accuracy_fn(logits, labels, loss_mask) / scored_tokens)
    loss.backward()
    optimizer.step()

    if i % 5 == 0:
      print(f'epoch: {epoch}/{epochs} batch: {i} loss: {losses[-1]} accuracy: {accuracies[-1]}') 
    if i % 50 == 0:
      # Full pass over the validation loader every 50 batches.
      average_loss, average_accuracy = valid_fn()
      print(f'validation loss: {average_loss}, validation accuracy: {average_accuracy}')


In [None]:
import matplotlib.pyplot as plt

# Training loss recorded once per batch. Anchor the axis at 0 and let
# matplotlib choose the upper limit, so the curve is not clipped regardless of
# the scale of the recorded values.
plt.plot(losses)
plt.xlabel('training batch')
plt.ylabel('loss')
plt.ylim(bottom=0);

In [None]:
def predict(model, abstract, text):
    """Print the encoded input with the tokens the model selects shown in red.

    `abstract` and `text` are tokenized with the notebook's `tokenizer`, joined
    into one sequence with special tokens, and run through `model` in eval
    mode. A position is highlighted when its highest-scoring class index is
    non-zero and it lies outside the global-attention prefix (the first
    `len(abstract) + 2` positions). Every token, special tokens included, is
    printed; highlighted ones are wrapped in ANSI red escape codes.

    No length limit is applied here.

    Returns:
        None; the result is written to stdout.
    """
    abstract = tokenizer(abstract, add_special_tokens=False)['input_ids']
    tokens = tokenizer(text, add_special_tokens=False)['input_ids']
    combined = tokenizer.build_inputs_with_special_tokens(abstract, tokens)

    global_attention_mask = torch.tensor([1] * (len(abstract)+2) + [0] * (len(tokens) + 2), dtype=torch.bool).unsqueeze(0).to(device)
    input_ids = torch.tensor(combined, dtype=torch.int).unsqueeze(0).to(device)
    attention_mask = torch.tensor([1] * len(combined), dtype=torch.bool).unsqueeze(0).to(device)

    model.eval()
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(
            input_ids,
            attention_mask = attention_mask,
            global_attention_mask = global_attention_mask
        )['logits']
    # Predicted class index (non-zero == highlight), restricted to positions
    # outside the global-attention prefix.
    mask = torch.logical_and(logits.max(2)[1], torch.logical_not(global_attention_mask))

    for (token, highlight) in zip(input_ids[0], mask[0]):
        if highlight:
            print('\033[31m' + tokenizer.decode(token) + '\033[0m', end='')
        else:
            print(tokenizer.decode(token), end='')


In [None]:
predict(longformer, 'Privacy violations inevitable – tech and corporations',
'The truth is that consumers love the benefits of digital goods and are willing to give up traditionally private information in exchange for the manifold miracles that the Internet and big data bring. Apple and Android each offer more than a million apps, most of which are built upon this model, as are countless other Internet services. More generally, big data promises huge improvements in economic efficiency and productivity, and in health care and safety. Absent abuses on a scale we have not yet seen, the public’s attitude toward giving away personal information in exchange for these benefits will likely persist, even if the government requires firms to make more transparent how they collect and use our data. One piece of evidence for this is that privacy-respecting search engines and email services do not capture large market shares. In general these services are not as easy to use, not as robust, and not as efficacious as their personal-data-heavy competitors. Schneier understands and discusses all this. In the end his position seems to be that we should deny ourselves some (and perhaps a lot) of the benefits big data because the costs to privacy and related values are just too high. We “have to stop the slide” away from privacy, he says, not because privacy is “profitable or efficient, but because it is moral.” But as Schneier also recognizes, privacy is not a static moral concept. “Our personal definitions of privacy are both cultural and situational,” he acknowledges. Consumers are voting with their computer mice and smartphones for more digital goods in exchange for more personal data. The culture increasingly accepts the giveaway of personal information for the benefits of modern computerized life. This trend is not new. “The idea that privacy can’t be invaded at all is utopian,” says Professor Charles Fried of Harvard Law School. “There are amounts and kinds of information which previously were not given out and suddenly they have to be given out. 
People adjust their behavior and conceptions accordingly.” That is Fried in the 1970 Newsweek story, responding to an earlier generation’s panic about big data and data mining. The same point applies today, and will apply as well when the Internet of things makes today’s data mining seem as quaint as 1970s-era computation.'
      )

# Gradio

In [20]:
import gradio as gr

def predict_gradio(abstract, text, threshold):
    """Gradio handler: return `text` with the selected tokens wrapped in spans.

    `abstract` and `text` are tokenized with the notebook's `tokenizer`, joined
    with special tokens and scored by `longformer`. A text token is wrapped in
    `<span class="highlight">...</span>` when the softmax probability of
    class 1 at its position is greater than `threshold`.

    Args:
        abstract: tag / abstract string.
        text: body text to highlight.
        threshold: probability cut-off for class 1.

    Returns:
        The string "TOO LONG" when the combined sequence would exceed
        `max_len`; otherwise the decoded text tokens (special tokens are not
        included), with highlighted tokens wrapped in span tags.
    """
    abstract_ids = tokenizer(abstract, add_special_tokens=False)['input_ids']
    text_ids = tokenizer(text, add_special_tokens=False)['input_ids']

    # 4 accounts for the special tokens added around the two segments.
    if 4 + len(abstract_ids) + len(text_ids) > max_len:
        return "TOO LONG"
    combined = tokenizer.build_inputs_with_special_tokens(abstract_ids, text_ids)

    global_attention_mask = torch.tensor([1] * (len(abstract_ids)+2) + [0] * (len(text_ids) + 2), dtype=torch.bool).unsqueeze(0).to(device)
    input_ids = torch.tensor(combined, dtype=torch.int).unsqueeze(0).to(device)
    attention_mask = torch.tensor([1] * len(combined), dtype=torch.bool).unsqueeze(0).to(device)

    longformer.eval()
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = longformer(
            input_ids,
            attention_mask = attention_mask,
            global_attention_mask = global_attention_mask
        )['logits']
    probs = torch.nn.functional.softmax(logits, dim=2)
    selected = (probs[0, :, 1] > threshold).tolist()

    # Text tokens sit after the abstract and its 3 surrounding special tokens
    # and before the closing special token. Slicing by position, instead of
    # splitting the rendered string on '</s></s>', keeps working when a
    # separator position itself scores above the threshold (always the case at
    # threshold 0).
    text_start = len(abstract_ids) + 3
    text_end = len(combined) - 1
    pieces = []
    for position in range(text_start, text_end):
        piece = tokenizer.decode(combined[position])
        if selected[position]:
            piece = '<span class="highlight">' + piece + '</span>'
        pieces.append(piece)
    return ''.join(pieces)

# Minimal UI: two text inputs, a threshold slider, and a Markdown pane that
# renders the highlighted result returned by predict_gradio.
with gr.Blocks(css=".highlight {background-color: yellow; color: black}") as demo:
    abstract = gr.Textbox(label="Tag")
    text = gr.Textbox(label="Text")
    # Start at 0.5 rather than the slider minimum: with a threshold of 0 every
    # token's class-1 probability exceeds it, so everything would be marked.
    threshold = gr.Slider(label="Threshold", minimum=0, maximum=1, step=0.01, value=0.5)
    highlighted = gr.Markdown(label="Highlighted")
    greet_btn = gr.Button("Magic")
    greet_btn.click(fn=predict_gradio, inputs=[abstract, text, threshold], outputs=highlighted)

# share=True opens a public tunnel URL to this notebook's server; debug=True
# keeps the cell running until interrupted.
demo.launch(share=True, debug=True)

Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://f1932d6819c4f2bd40.gradio.live


