In [2]:
save_path = '/content/drive/MyDrive/EECS595-Fall2020/model.pt'
train_data_path = '/content/drive/Shared drives/EECS595-Fall2020/Final_Project_Common/EAT/eat_train.json'
temperature = 0.1

In [4]:
import torch
from torch import nn
try:
  import omegaconf
except:
  !pip install omegaconf
  !pip install hydra-core
  !pip install pytorch_lightning
import json
import random
import tqdm
from pytorch_lightning.metrics.functional.classification import auroc, accuracy, recall, precision
from pytorch_lightning.metrics.functional import f1
import math

def kfold(tensor, index):
  try:
    return torch.cat([tensor[:index*tensor.shape[0]//10], tensor[(index+1)*tensor.shape[0]//10:]])
  except:
    return tensor[:index*len(tensor)//10] + tensor[(index+1)*len(tensor)//10:]

def kfoldtest(tensor, index):
  try:
    return tensor[index*tensor.shape[0]//10:(index+1)*tensor.shape[0]//10]
  except:
    return tensor[index*len(tensor)//10:(index+1)*len(tensor)//10]

def smooth_max(x, temp=1.0):
  return torch.sum(torch.softmax(x/temp, dim=-1) * x, dim=-1)

def pad_sentences(sents):
  return sents + ['' for x in range(7-len(sents))]

def get_pairs(words, label, all=False):
  if all:
    allpairs = [(i,j) for i in range(len(words)) for j in range(len(words)) if i < j and i!=j]
  elif label == -1:
    allpairs = random.sample([(i,j) for i in range(len(words)) for j in range(len(words)) if i < j and i!=j], 7)
  else:
    allpairs = [(i, label) for i in range(label)]
  return [(words[i], words[j]) for i,j in allpairs]

def compress_all_pairs(logits):
  output = torch.stack([logits[i*(i+1)//2:(i+1)*(i+2)//2].max() for i in range(math.floor(math.sqrt(8*logits.shape[0]+1)-1)//2)])
  return output

def compress_all_pairs_smooth(logits, temp=1.0):
  output = torch.stack([smooth_max(logits[i*(i+1)//2:(i+1)*(i+2)//2], temp) for i in range(math.floor(math.sqrt(8*logits.shape[0]+1)-1)//2)])
  return output

def score_accuracy(preds, labels, threshold):
  correct = 0
  for p, l in zip(preds, labels):
    p2 = p > threshold
    if l != -1:
      if p2[l-1] and torch.all(~p2[:l-1]):
        correct += 1
    else:
      if torch.all(~p2):
        correct += 1
  return correct/len(preds)

def get_predictions(preds, threshold):
  out = []
  for p in preds:
    p2 = p > threshold
    if torch.any(p2):
      out.append(torch.where(p2)[0][0].item()+1)
    else:
      out.append(-1)
  return out

def log1mexp(x):
    return torch.where(x > -0.693, torch.log(-torch.expm1(x)), torch.log1p(-torch.exp(x)))


# Download RoBERTa already finetuned for MNLI
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.mnli')
roberta.register_classification_head('eat', num_classes=2)
roberta.load_state_dict(torch.load('/content/drive/MyDrive/EECS595-Fall2020/model.pt'))
#roberta.eval()  # disable dropout for evaluation
roberta = roberta.cuda()

from fairseq.data.data_utils import collate_tokens

dev_file = '/content/drive/Shared drives/EECS595-Fall2020/Final_Project_Common/EAT/eat_train.json'
dev_data = json.load(open(dev_file))

sentences = [x['story'] for x in dev_data]
labels = torch.tensor([x['breakpoint'] for x in dev_data])
tokenized_sents = sum([list(zip(pad_sentences(x[:-1]), pad_sentences(x[1:]))) for x in sentences], [])
batch = collate_tokens(
    [roberta.encode(pair[0], pair[1]) for pair in tokenized_sents[:7]], pad_idx=1
)

params = [x for n,x in roberta.named_parameters() if '.eat.' in n]
# For frozen base use
# opt = torch.optim.AdamW(params,lr=2e-5)
opt = torch.optim.AdamW(roberta.parameters(),lr=2e-6)


torch.set_printoptions(sci_mode=False)
# Encode a pair of sentences and make a prediction
split = 1

criterion = nn.BCEWithLogitsLoss()

ema = 0
denom = 0
for epoch in range(50):
  # Comment out this next line for frozen base
  roberta.train()
  for words, label in tqdm.tqdm_notebook(zip(kfold(sentences, split), kfold(labels, split))):
    opt.zero_grad()
    pairs = get_pairs(words, label, True)
    batch = collate_tokens(
      [roberta.encode(pair[0], pair[1]) for pair in pairs], pad_idx=1
    )
    logits = roberta.predict('eat', batch)[:,0]
    predictions = compress_all_pairs_smooth(logits, temperature)
    if label != -1:
      loss = -log1mexp(predictions[:label])
      loss[label-1] = -predictions[label-1]
      loss = loss.mean()
    else:
      loss = -log1mexp(predictions).mean()
    ema = 0.99*ema + loss.item()
    denom = 0.99*denom + 1
    loss.backward()
    opt.step()

  roberta.eval()
  total = 0.0
  count = 0
  preds = []
  bp_preds = []
  ys = []
  bp_ys = []
  for words, label in zip(kfoldtest(sentences, split), kfoldtest(labels, split)):
    pairs = get_pairs(words, label, True)
    batch = collate_tokens(
      [roberta.encode(pair[0], pair[1]) for pair in pairs], pad_idx=1
    )
    with torch.no_grad():
      logits = roberta.predict('eat', batch)[:,0]
      logitsc = compress_all_pairs_smooth(logits, temperature)
      bp_preds.append(torch.exp(logitsc).cpu())
      bp_ys.append(label)
      prediction = logits.max()
      preds.append(torch.exp(prediction).cpu())
      ys.append((label != -1).float().cpu())
      if label != -1:
        loss = -log1mexp(logitsc[:label])
        loss[label-1] = -logitsc[label-1]
        loss = loss.mean()
      else:
        loss = -log1mexp(logitsc).mean()
        total += loss.item()
        count += 1
  torch.save(roberta.state_dict(), save_path)
  preds = torch.stack(preds, dim=0)
  ys = torch.stack(ys, dim=0)
  accs = torch.tensor([accuracy(preds > i/100.0, ys) for i in range(100)])
  score_accs = torch.tensor([score_accuracy(bp_preds, bp_ys, i/100.0) for i in range(100)])
  final_preds = get_predictions(bp_preds, score_accs.argmax().item()/100.0)
  final_preds = [4 if x!=-1 else -1 for x in final_preds]
  f1_score = f1(torch.tensor([x if x!=-1 else 0 for x in final_preds]), torch.tensor([x if x!=-1 else 0 for x in bp_ys]), 6)
  recall_score = recall(torch.tensor([x if x!=-1 else 0 for x in final_preds]), torch.tensor([x if x!=-1 else 0 for x in bp_ys]), 6, 'macro')
  precision_score = precision(torch.tensor([x if x!=-1 else 0 for x in final_preds]), torch.tensor([x if x!=-1 else 0 for x in bp_ys]), 6, 'macro')
  print(f'f1: {f1_score.item()}, recall:{recall_score.item()}, precision:{precision_score.item()}, accuracy:{accuracy(torch.tensor([1 if x!=-1 else 0 for x in final_preds]), torch.tensor([1 if x!=-1 else 0 for x in bp_ys]), 6).item()}')
  print(f'Threshold 1: {accs.argmax().item()/100.0}, Threshold 2: {score_accs.argmax().item()/100.0}')

Using cache found in /root/.cache/torch/hub/pytorch_fairseq_master
Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


f1: 0.6538461446762085, recall:0.26787880063056946, precision:0.2142857164144516, accuracy:0.8269230723381042
Threshold 1: 0.09, Threshold 2: 0.14




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




KeyboardInterrupt: ignored

In [1]:
test_path = '/content/drive/Shared drives/EECS595-Fall2020/Final_Project_Common/EAT/eat_test_unlabeled.json'
threshold = 0.35

In [3]:
import torch
from torch import nn
try:
  import omegaconf
except:
  !pip install omegaconf
  !pip install hydra-core
  !pip install pytorch_lightning
import json
import random
import tqdm
from pytorch_lightning.metrics.functional.classification import auroc, accuracy, recall, precision
from pytorch_lightning.metrics.functional import f1
import math


def pad_sentences(sents):
  return sents + ['' for x in range(7-len(sents))]

def get_pairs(words, label, all=False):
  if all:
    allpairs = [(i,j) for i in range(len(words)) for j in range(len(words)) if i < j and i!=j]
  elif label == -1:
    allpairs = random.sample([(i,j) for i in range(len(words)) for j in range(len(words)) if i < j and i!=j], 7)
  else:
    allpairs = [(i, label) for i in range(label)]
  return [(words[i], words[j]) for i,j in allpairs]

def compress_all_pairs(logits):
  output = torch.stack([logits[i*(i+1)//2:(i+1)*(i+2)//2].max() for i in range(math.floor(math.sqrt(8*logits.shape[0]+1)-1)//2)])
  return output

def get_predictions(preds, threshold):
  out = []
  for p in preds:
    p2 = p > threshold
    if torch.any(p2):
      out.append(torch.where(p2)[0][0].item()+1)
    else:
      out.append(-1)
  return out


# Download RoBERTa already finetuned for MNLI
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.mnli')
roberta.register_classification_head('eat', num_classes=2)
roberta.load_state_dict(torch.load(save_path))
roberta.eval()  # disable dropout for evaluation
roberta = roberta.cuda()

from fairseq.data.data_utils import collate_tokens

# Load Data
dev_file = test_path
dev_data = json.load(open(dev_file))

sentences = [x['story'] for x in dev_data]
ids = [x['id'] for x in dev_data]
tokenized_sents = sum([list(zip(pad_sentences(x[:-1]), pad_sentences(x[1:]))) for x in sentences], [])
batch = collate_tokens(
    [roberta.encode(pair[0], pair[1]) for pair in tokenized_sents[:7]], pad_idx=1
)

#Make Predictions
bp_preds = []
for words in sentences:
  pairs = get_pairs(words, None, True)
  batch = collate_tokens(
    [roberta.encode(pair[0], pair[1]) for pair in pairs], pad_idx=1
  )
  with torch.no_grad():
    logits = roberta.predict('eat', batch)[:,0]
    logitsc = compress_all_pairs(logits)
    bp_preds.append(torch.exp(logitsc).cpu())


final_preds = get_predictions(bp_preds, threshold)
outputs = [{'id': id, 'pred_label': 1 if pred == -1 else 0, 'pred_breakpoint': pred} for id, pred in zip(ids, final_preds)]
print(outputs)
with open('predictions.json', 'w') as f:
  f.write(json.dumps(outputs))

Collecting omegaconf
  Downloading https://files.pythonhosted.org/packages/e5/f6/043b6d255dd6fbf2025110cea35b87f4c5100a181681d8eab496269f0d5b/omegaconf-2.0.5-py3-none-any.whl
Collecting PyYAML>=5.1.*
[?25l  Downloading https://files.pythonhosted.org/packages/64/c2/b80047c7ac2478f9501676c988a5411ed5572f35d1beff9cae07d321512c/PyYAML-5.3.1.tar.gz (269kB)
[K     |█▏                              | 10kB 25.7MB/s eta 0:00:01[K     |██▍                             | 20kB 32.1MB/s eta 0:00:01[K     |███▋                            | 30kB 37.6MB/s eta 0:00:01[K     |████▉                           | 40kB 26.7MB/s eta 0:00:01[K     |██████                          | 51kB 21.5MB/s eta 0:00:01[K     |███████▎                        | 61kB 23.7MB/s eta 0:00:01[K     |████████▌                       | 71kB 19.5MB/s eta 0:00:01[K     |█████████▊                      | 81kB 21.0MB/s eta 0:00:01[K     |███████████                     | 92kB 19.9MB/s eta 0:00:01[K     |████████████▏ 

Downloading: "https://github.com/pytorch/fairseq/archive/master.zip" to /root/.cache/torch/hub/master.zip


running build_ext
cythoning fairseq/data/data_utils_fast.pyx to fairseq/data/data_utils_fast.cpp




cythoning fairseq/data/token_block_utils_fast.pyx to fairseq/data/token_block_utils_fast.cpp
building 'fairseq.libbleu' extension
creating build
creating build/temp.linux-x86_64-3.6
creating build/temp.linux-x86_64-3.6/fairseq
creating build/temp.linux-x86_64-3.6/fairseq/clib
creating build/temp.linux-x86_64-3.6/fairseq/clib/libbleu
x86_64-linux-gnu-gcc -pthread -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -Wdate-time -D_FORTIFY_SOURCE=2 -fPIC -I/usr/include/python3.6m -c fairseq/clib/libbleu/libbleu.cpp -o build/temp.linux-x86_64-3.6/fairseq/clib/libbleu/libbleu.o -std=c++11 -O3 -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=libbleu -D_GLIBCXX_USE_CXX11_ABI=0
x86_64-linux-gnu-gcc -pthread -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -Wdate-time -D_FORTIFY_SOURCE=2 -fPIC -I/usr/include/python3.6m -c fairseq/clib/libbleu/module.cpp -o build/temp.linux-x86_64-3.6/fairseq/clib/libbleu/modul

100%|██████████| 751652118/751652118 [00:26<00:00, 28364866.31B/s]
1042301B [00:00, 2490997.07B/s]
456318B [00:00, 1509489.76B/s]


[{'id': 'test_0', 'pred_label': 0, 'pred_breakpoint': 3}, {'id': 'test_1', 'pred_label': 0, 'pred_breakpoint': 3}, {'id': 'test_2', 'pred_label': 1, 'pred_breakpoint': -1}, {'id': 'test_3', 'pred_label': 0, 'pred_breakpoint': 3}, {'id': 'test_4', 'pred_label': 0, 'pred_breakpoint': 3}, {'id': 'test_5', 'pred_label': 0, 'pred_breakpoint': 3}, {'id': 'test_6', 'pred_label': 0, 'pred_breakpoint': 4}, {'id': 'test_7', 'pred_label': 0, 'pred_breakpoint': 1}, {'id': 'test_8', 'pred_label': 0, 'pred_breakpoint': 2}, {'id': 'test_9', 'pred_label': 1, 'pred_breakpoint': -1}, {'id': 'test_10', 'pred_label': 0, 'pred_breakpoint': 2}, {'id': 'test_11', 'pred_label': 1, 'pred_breakpoint': -1}, {'id': 'test_12', 'pred_label': 1, 'pred_breakpoint': -1}, {'id': 'test_13', 'pred_label': 0, 'pred_breakpoint': 4}, {'id': 'test_14', 'pred_label': 0, 'pred_breakpoint': 1}, {'id': 'test_15', 'pred_label': 0, 'pred_breakpoint': 1}, {'id': 'test_16', 'pred_label': 0, 'pred_breakpoint': 2}, {'id': 'test_17', '