In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import torch

In [4]:
from argparse import Namespace

# Run configuration shared by the cells below, exposed as attributes (args.<name>).
args = Namespace(
    **{
        "random_state": 42,       # seed handed to train_test_split
        "preprocessing": True,    # whether PIPELINE is applied when loading texts
        "test_size": 0.2,         # validation fraction for train_test_split
        "num_queries": 20,        # query dimension of the synthetic predictions
        "batch_size": 5,          # DataLoader batch size / prediction batch dimension
        "input_path": "./input/feedback-prize-2021/",  # data directory given to load_texts
    }
)

In [5]:
# from datasets import build_fdb_data, collate_fn

# dataset, val, postprocessor, num_classes = build_fdb_data(args)

In [6]:
import numpy as np

from datasets.processing_funcs import PIPELINE
from datasets.fbp_dataset import FBPDataset, load_texts
from sklearn.preprocessing import OrdinalEncoder
from sklearn.model_selection import train_test_split


# Use the processing pipeline only when the config enables it.
if args.preprocessing:
    preprocess = PIPELINE
else:
    preprocess = []
documents, tags = load_texts(args.input_path, preprocess)  # type: ignore

# Fit the ordinal label encoder on the distinct discourse types in the tags.
label_unique = np.array(tags["discourse_type"].unique())  # type: ignore
num_classes = len(label_unique)
encoder = OrdinalEncoder()
encoder.fit(label_unique.reshape(-1, 1))

# Split the document index into train / validation parts with a fixed seed.
train_idx, val_idx = train_test_split(
    documents.index,
    test_size=args.test_size,
    random_state=args.random_state,
)

# Both datasets share the same tags and the same fitted encoder.
train_dataset = FBPDataset(documents[train_idx], tags, encoder)  # type:ignore
val_dataset = FBPDataset(documents[val_idx], tags, encoder)  # type:ignore

100%|██████████| 15594/15594 [00:08<00:00, 1751.77it/s]


In [7]:
from torch.utils.data import DataLoader

def collate_fn(batch):
    """Transpose a batch of per-sample tuples into a tuple of per-field lists.

    ``[(a0, b0), (a1, b1)]`` becomes ``([a0, a1], [b0, b1])``. Items are only
    regrouped; nothing is stacked or converted. An empty batch yields ``()``.
    """
    fields = zip(*batch)
    return tuple(map(list, fields))

# Batch the training set; collate_fn regroups samples into per-field lists.
dl = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    collate_fn=collate_fn,
)

# Take the first batch and unpack its three fields.
docs, targets, infos = next(iter(dl))

In [8]:
# Synthetic model outputs used to exercise the criterion and post-processor
# without a trained model.
#
# A dedicated, seeded generator makes the random tensors (and the losses and
# results computed from them) reproducible on Restart & Run All, without
# touching torch's global RNG state.
pred_generator = torch.Generator().manual_seed(args.random_state)

predictions = {
    # Uniform [0, 1) scores per class, with the extra last slot scaled by 10.
    # NOTE(review): the extra slot is presumably DETR's "no object" class — confirm.
    'pred_logits': torch.rand(
        [args.batch_size, args.num_queries, num_classes + 1],
        generator=pred_generator,
    ) * torch.tensor([[1.0] * num_classes + [10.0]]),
    # Two box values per query; the second is shrunk to [0, 0.1).
    # NOTE(review): presumably (position, width) in normalised units — confirm.
    'pred_boxes': torch.rand(
        [args.batch_size, args.num_queries, 2],
        generator=pred_generator,
    ) * torch.tensor([[1.0, 0.1]]),
}

In [9]:
from models.matcher import HungarianMatcher
from models.criterion import CriterionDETR

# Loss terms requested from the criterion, and the per-loss weight mapping.
losses = ['labels', 'boxes', 'cardinality']
weight_dict = {
    'loss_ce': 1.,
    'loss_bbox': 1.,
    'loss_giou': 1.,
}

# Matcher built with its default arguments.
matcher = HungarianMatcher()

criterion = CriterionDETR(
    num_classes,
    matcher=matcher,
    weight_dict=weight_dict,
    eos_coef=0.1,
    losses=losses,
)

In [10]:
# Run the criterion on the synthetic predictions against the first batch's
# targets; the returned dict of loss values is displayed as the cell output.
criterion(predictions, targets)

{'loss_ce': tensor(4.1917),
 'loss_bbox': tensor(0.1219),
 'loss_giou': tensor(0.8441),
 'cardinality_error': tensor(7.8000)}

In [40]:
from datasets.postprocess import FBPPostProcess

# Build the post-processor from the fitted label encoder, the loaded tags
# and the number of classes.
postprocessor = FBPPostProcess(encoder, tags, num_classes)

In [41]:
# Hand the synthetic predictions and the batch's info records to the post-processor.
# NOTE(review): presumably this accumulates rows in `postprocessor.results` — confirm.
postprocessor.add_outputs(predictions, infos)

In [42]:
# Display the post-processor's results table (id, class, predictionstring, score).
postprocessor.results

Unnamed: 0,id,class,predictionstring,score
0,97E4E42863A3,Rebuttal,272 273 274 275 276 277 278 279 280 281 282 28...,0.201358
1,EB793B72A8C4,Rebuttal,85 86 87,0.167245
2,0119F710D008,Evidence,33 34 35 36 37 38 39 40 41,0.152022
3,0119F710D008,Position,87 88 89 90 91 92 93 94 95 96 97 98 99 100 101...,0.180292
4,F00B4D036D97,Lead,50 51 52 53 54 55 56 57 58 59 60 61 62,0.150565
5,F00B4D036D97,Position,146 147 148 149 150 151 152 153 154 155 156 15...,0.181172
6,F00B4D036D97,Position,117 118 119 120 121 122 123 124 125 126 127 12...,0.15393


In [None]:
# Evaluate the accumulated results.
# NOTE(review): this cell has no execution count in the saved notebook, so no
# output is recorded for it; presumably scores against the tags given at
# construction — confirm.
postprocessor.evaluate()

In [54]:
import util.visualization as viz

# Pick one sample of the first batch to visualise: its text and its id.
idx = 0
doc, id_example = docs[idx], infos[idx]['id']

In [55]:
# Highlight the example document using the segments from the loaded tags.
viz.highlight_segments(id_example, doc, tags)

In [56]:
# Highlight the same document using the post-processor's results, for
# comparison with the tags-based view.
viz.highlight_segments(id_example, doc, postprocessor.results)

In [34]:
# Reset the post-processor's stored results.
# NOTE(review): this cell's execution count (34) is lower than the cells above
# it (40-56), so it was run out of order; re-check under Restart & Run All.
postprocessor.reset_results()