In [3]:
import time
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader

from datasets import build_fdb_data, collate_fn
from models import build_models
from main import get_args_parser

# Reuse the project's argument parser so every option keeps its default;
# the only override passed in from the notebook is the device.
cli_overrides = ["--device", "cpu"]
args = get_args_parser().parse_args(cli_overrides)

# torch.device built from the parsed option; later cells move tensors and
# modules onto it.
device = torch.device(args.device)

In [4]:
# --- Data -------------------------------------------------------------------
# build_fdb_data returns the train/val datasets, a postprocessor and the number
# of classes; num_classes is needed below to build the model.
print("Loading Dataset...")

(
    dataset_train,
    dataset_val,
    postprocessor,
    num_classes,
) = build_fdb_data(args)

print("Dataset loaded")

# --- Model ------------------------------------------------------------------
# build_models returns the tokenizer, the model and its training criterion.
print("Loading Models...")

tokenizer, model, criterion = build_models(num_classes, args)
model.to(device)

print("Models Loaded")

# Count only the parameters that will receive gradients.
trainable_params = [p for p in model.parameters() if p.requires_grad]
n_parameters = sum(p.numel() for p in trainable_params)
print("number of params:", n_parameters)

Loading Dataset...


100%|██████████| 15594/15594 [00:02<00:00, 6703.92it/s]


Dataset loaded
Loading Models...


Some weights of the model checkpoint at allenai/led-base-16384 were not used when initializing LEDModel: ['lm_head.weight', 'final_logits_bias']
- This IS expected if you are initializing LEDModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LEDModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Models Loaded
number of params: 422922


In [None]:
from models.detr import DETR, PrepareInputs
from models.matcher import HungarianMatcher
from models.criterion import CriterionDETR
from transformers import LEDModel, LEDTokenizerFast  # type: ignore

# Rebuild the model, tokenizer wrapper and criterion by hand from the project
# classes imported at the top of this cell. This rebinds `model`, `tokenizer`
# and `criterion`, so everything below uses these objects.

# One name for the pretrained checkpoint so the model and its tokenizer cannot
# drift apart.
led_checkpoint = "allenai/led-base-16384"

model = DETR(
    model=LEDModel.from_pretrained(led_checkpoint),
    num_classes=num_classes,
    hidden_dim=args.hidden_dim,
    num_queries=args.num_queries,
)

# NOTE(review): presumably freezes/unfreezes the wrapped LED transformer
# according to --train_transformer; confirm in models.detr.
model.set_transformer_trainable(args.train_transformer)

# The freshly built model starts on the default device. Move it to the
# configured device here (before the optimizer is created in the next cell),
# otherwise any non-CPU --device fails in the training loop, where inputs and
# targets are moved to `device`.
model.to(device)

tokenizer = PrepareInputs(
    tokenizer=LEDTokenizerFast.from_pretrained(led_checkpoint)
)

matcher = HungarianMatcher(
    cost_class=args.set_cost_class,
    cost_bbox=args.set_cost_bbox,
    cost_giou=args.set_cost_giou,
)

# Per-loss multipliers; the training loop only sums loss terms whose key
# appears in this dict.
weight_dict = {
    "loss_ce": 1,
    "loss_bbox": args.bbox_loss_coef,
    "loss_giou": args.giou_loss_coef,
}

# Names of the loss components the criterion is asked to compute.
losses = ["labels", "boxes", "cardinality"]
criterion = CriterionDETR(
    num_classes=num_classes,
    matcher=matcher,
    weight_dict=weight_dict,
    eos_coef=args.eos_coef,
    losses=losses,
)
criterion.to(device)

In [None]:
# Optimizer over every parameter of the current `model`, with the learning
# rate and weight decay taken from the parsed arguments.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=args.lr,
    weight_decay=args.weight_decay,
)

# Step-wise LR decay with a period of args.lr_drop scheduler steps.
# NOTE(review): lr_scheduler.step() is not called in any cell visible here, so
# the learning rate stays constant unless a later cell steps the scheduler.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_drop)

# Options common to the training and validation loaders.
shared_loader_options = dict(
    batch_size=args.batch_size,
    collate_fn=collate_fn,
    num_workers=args.num_workers,
)

# Training batches are reshuffled on every pass over the data.
data_loader_train = DataLoader(dataset_train, shuffle=True, **shared_loader_options)

# Validation keeps dataset order and keeps the final partial batch.
data_loader_val = DataLoader(
    dataset_val,
    shuffle=False,
    drop_last=False,
    **shared_loader_options,
)

In [None]:
# One pass over the training loader: forward each document, compute the
# weighted criterion loss, take an optimizer step, and report running stats on
# the progress bar.
model.train()
criterion.train()

# The per-loss weights are loop-invariant, so read them once. A separate name
# is used so the notebook-level `weight_dict` / `losses` (the list of loss
# names) defined when the criterion was built are not clobbered by this cell.
loss_weights = criterion.weight_dict

loss_list = []
data_bar = tqdm(data_loader_train, desc=f"Train Epoch {0:4d}")
for samples, targets, info in data_bar:
    # perf_counter is monotonic, so the intervals below cannot be skewed by
    # system clock adjustments.
    step_start = time.perf_counter()

    targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

    # Each document is tokenized and run through the model on its own; the
    # per-document outputs are then concatenated key by key into one batch.
    outputs = [model(tokenizer([doc]).to(device)) for doc in samples]
    batch_outputs = {
        key: torch.cat([o[key] for o in outputs]) for key in outputs[0].keys()
    }

    # Mapping of loss name -> loss tensor, as used by the weighted sum below.
    loss_dict = criterion(batch_outputs, targets)

    forward_done = time.perf_counter()

    # Only loss terms that have an entry in the weight dict contribute to the
    # objective; other entries of loss_dict are ignored.
    total_loss = sum(
        loss_dict[name] * loss_weights[name]
        for name in loss_dict.keys()
        if name in loss_weights
    )  # type: ignore

    optimizer.zero_grad()
    total_loss.backward()  # type: ignore
    optimizer.step()

    optim_done = time.perf_counter()

    loss_list.append(total_loss.item())  # type: ignore
    data_bar.set_postfix(
        {
            "lr": optimizer.param_groups[0]["lr"],
            # Running mean of the weighted loss over the steps seen so far.
            "loss": sum(loss_list) / len(loss_list),
            # Covers target transfer, forward passes and the criterion.
            "model time": f"{forward_done - step_start:.2f} s",
            # Covers the weighted sum, backward pass and optimizer step.
            "optim time": f"{optim_done - forward_done:.2f} s",
        }
    )