In [1]:
# Enable IPython's autoreload extension in mode 2, so edited project modules are
# re-imported automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import torch

In [3]:
from argparse import Namespace

# Experiment settings collected in one object; the cells below read them as `args.<name>`.
args = Namespace(
    random_state=42,  # seed handed to train_test_split
    preprocessing=True,  # when truthy, PIPELINE is passed to load_texts
    test_size=0.2,  # validation fraction handed to train_test_split
    num_queries=20,  # second dimension of the random prediction tensors
    batch_size=5,  # DataLoader batch size and first dimension of the predictions
    input_path="./input/feedback-prize-2021/",  # directory handed to load_texts
)

In [4]:
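# Disabled: one-call setup through build_fdb_data. The following cells build the
# datasets, num_classes and the post-processor by hand instead
# (presumably the same steps -- confirm against datasets.build_fdb_data).
# from datasets import build_fdb_data, collate_fn

# dataset, val, postprocessor, num_classes = build_fdb_data(args)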

In [5]:
import numpy as np

from datasets.processing_funcs import PIPELINE
from datasets.fbp_dataset import FBPDataset, load_texts
from sklearn.preprocessing import OrdinalEncoder
from sklearn.model_selection import train_test_split


# Hand the processing pipeline to the loader only when preprocessing is switched on.
if args.preprocessing:
    preprocess = PIPELINE
else:
    preprocess = []
documents, tags = load_texts(args.input_path, preprocess)  # type: ignore

# Distinct discourse types found in the tags; their count is the number of classes.
label_unique = np.array(tags["discourse_type"].unique())  # type: ignore
num_classes = len(label_unique)

# The encoder is fitted on a single-column 2-D array (one row per distinct label).
encoder = OrdinalEncoder().fit(label_unique.reshape(-1, 1))

# Split the index of `documents` (not the tags) into train and validation parts,
# seeded so the split is the same on every run.
train_idx, val_idx = train_test_split(
    documents.index,
    test_size=args.test_size,
    random_state=args.random_state,
)

# Both datasets share the same tags table and the same fitted encoder.
train_dataset = FBPDataset(documents[train_idx], tags, encoder)  # type:ignore
val_dataset = FBPDataset(documents[val_idx], tags, encoder)  # type:ignore

100%|██████████| 15594/15594 [01:33<00:00, 166.08it/s] 


In [6]:
from torch.utils.data import DataLoader

def collate_fn(batch):
    """Transpose a batch of samples into one list per sample field.

    Args:
        batch: iterable of equally structured samples (each sample itself iterable).

    Returns:
        A tuple with one list per field; the k-th list holds the k-th item of
        every sample. An empty batch yields an empty tuple.
    """
    return tuple(map(list, zip(*batch)))

# Batch the training set; collate_fn groups every sample field into its own list.
dl = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    collate_fn=collate_fn,
)

# Only the first batch is needed here; unpack its three fields for the cells below.
docs, targets, infos = next(iter(dl))

In [7]:
# Random stand-in for model outputs, used to exercise the criterion and the
# post-processing cells without a trained model.
# A dedicated generator seeded from args.random_state makes the tensors identical
# after every kernel restart and leaves torch's global RNG state untouched.
pred_rng = torch.Generator().manual_seed(args.random_state)

# Per-class multiplier for the logits: 1 for the `num_classes` label classes, 10 for
# the extra last slot (presumably the "no object" class of a DETR-style head -- confirm).
logit_scale = torch.tensor([[1.0] * num_classes + [10.0]])
# Per-coordinate multiplier for the boxes: the second coordinate is shrunk to [0, 0.1)
# (presumably the span width -- TODO confirm against the box format used by the model).
box_scale = torch.tensor([[1.0, 0.1]])

predictions = {
    # (batch_size, num_queries, num_classes + 1); the scale broadcasts over the last axis.
    'pred_logits': torch.rand(
        [args.batch_size, args.num_queries, num_classes + 1], generator=pred_rng
    ) * logit_scale,
    # (batch_size, num_queries, 2)
    'pred_boxes': torch.rand(
        [args.batch_size, args.num_queries, 2], generator=pred_rng
    ) * box_scale,
}

In [8]:
from models.matcher import HungarianMatcher
from models.criterion import CriterionDETR

# Matcher built with its default arguments and handed to the criterion below.
matcher = HungarianMatcher()

# Names of the loss groups to compute, and the weight attached to each named loss term.
losses = ['labels', 'boxes', 'cardinality']
weight_dict = {
    'loss_ce': 1.,
    'loss_bbox': 1.,
    'loss_giou': 1.,
}

criterion = CriterionDETR(
    num_classes,
    matcher=matcher,
    weight_dict=weight_dict,
    eos_coef=0.1,  # presumably the relative weight of the "no object" class, as in DETR -- confirm
    losses=losses,
)

In [9]:
# Evaluate the criterion on the random predictions against the targets of the first
# batch; the returned dict of loss terms is displayed as the cell output.
criterion(predictions, targets)

{'loss_ce': tensor(4.1624),
 'loss_bbox': tensor(0.2000),
 'loss_giou': tensor(1.0412),
 'cardinality_error': tensor(7.4000)}

In [17]:
from datasets.postprocess import FBPPostProcess

# Build the post-processor from the fitted label encoder, the loaded tags table and
# the number of classes.
postprocessor = FBPPostProcess(encoder, tags, num_classes)

In [18]:
# Hand the random predictions and the per-sample infos of the first batch to the
# post-processor.
postprocessor.add_outputs(predictions, infos)

In [27]:
# Display the post-processor's results table (presumably filled by add_outputs --
# confirm in FBPPostProcess).
postprocessor.results

Unnamed: 0,id,class,predictionstring,score
0,97E4E42863A3,Evidence,122 123 124 125 126 127 128 129 130 131 132 13...,0.187968
1,97E4E42863A3,Rebuttal,200,0.164482
2,EB793B72A8C4,Evidence,831 832 833 834 835 836 837 838 839 840 841 84...,0.165227
3,EB793B72A8C4,Position,496 497 498 499 500 501 502 503 504 505 506 50...,0.175308
4,EB793B72A8C4,Counterclaim,424 425 426 427 428 429 430 431 432 433 434,0.199229
5,0119F710D008,Claim,11 12 13 14 15 16 17 18 19 20 21 22,0.168848
6,0119F710D008,Claim,62 63 64,0.188737
7,0119F710D008,Counterclaim,0 1 2 3 4,0.199665
8,F00B4D036D97,Position,172 173 174 175 176 177 178 179 180 181 182 18...,0.174763


In [28]:
# Run the post-processor's evaluation.
# NOTE(review): the metric and the reference it compares against live in FBPPostProcess
# and are not visible here; this cell shows no output.
postprocessor.evaluate()

In [51]:
import util.visualization as viz

# Position of the example inside the first batch; valid values are 0 .. len(docs) - 1.
idx = 4

# Text of the chosen sample and the id stored in its info entry.
doc = docs[idx]
id_example = infos[idx]['id']

In [52]:
# Highlight the segments of the selected document using the loaded tags table
# (presumably the ground-truth annotations -- confirm).
viz.highlight_segments(id_example, doc, tags)

In [53]:
# Same document, highlighted with the post-processor's results instead of the tags,
# for a side-by-side comparison with the previous cell.
viz.highlight_segments(id_example, doc, postprocessor.results)

In [None]:
# NOTE(review): repeats the earlier `postprocessor.evaluate()` cell and has never been
# executed (no execution count) -- confirm whether it is still needed.
postprocessor.evaluate()