In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, Subset

from citation_mathcer import ExperimentConfig, Experiment



# Single place for every tunable of this run; grouped by concern.
config = ExperimentConfig(
    # Run identity / reproducibility
    project_name="citation-matching",
    run_name=None,
    seed=42,
    # Model
    model_name="FacebookAI/roberta-base",
    # log(1 / 0.07) — presumably a CLIP-style temperature initialisation; TODO confirm
    initial_logit_scale=np.log(1 / 0.07),
    # Data
    collate_sample_size=None,
    train_ratio=0.9,
    batch_size=270,
    # Optimisation
    learning_rate=1e-4,
    logits_learning_rate=1e-2,
    max_grad_norm=0.5,
    # Checkpointing
    checkpoint_dir="./checkpoints",
    checkpoint_every=100,
)

# Build the experiment and fetch its results, pointing `get_results` at an
# existing cache file (a tokenized cache, judging by the file name).
tokenized_cache_path = './cache/tokenized_1caf5def_eb27a5477eaa3d549aebc4886f3717d1.pt'
experiment = Experiment(config)
results = experiment.get_results(cache_path=tokenized_cache_path)

# Train from scratch.
# NOTE(review): `train_citation_model` is not among the names imported in the
# import cell of this notebook — confirm where it is defined before relying on
# Restart & Run All.
trained_model = train_citation_model(experiment, results)

# To resume from a checkpoint instead, set `resume_from` and rebuild the experiment:
# config.resume_from = "checkpoints/citation-matching/feasible-pine-64/checkpoint-step-60.pt"
# experiment = Experiment(config)
# trained_model = train_citation_model(experiment, results)

In [None]:
# Collate the raw results into a dataset, then carve out train/validation
# subsets and prepare a seeded RNG for the DataLoader.
tokenizer = experiment.get_tokenizer()
collated = collate(results, tokenizer, config)
dataset = CitationDataset(collated)

# Sequential (unshuffled) split: the first `train_ratio` share of the indices
# goes to training, the remainder to validation.
indices = np.arange(len(dataset))
train_size = int(len(dataset) * config.train_ratio)
train_indices = indices[:train_size]
val_indices = indices[train_size:]

# `Subset` comes from the notebook's import cell (torch.utils.data).
train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)

# Seeded generator so that DataLoader shuffling is reproducible between runs.
# The `+ 0` is a no-op offset kept from the original — presumably a placeholder
# for a per-rank/per-worker offset; TODO confirm or drop.
generator = torch.Generator()
generator.manual_seed(config.seed + 0)




In [None]:
# Sanity-check the collate function: iterate over the whole dataset once and
# verify that special-token counts line up with the id/label tensors.
# Note: with drop_last=True a trailing partial batch is never yielded, so it is
# not covered by these checks.
dataloader = DataLoader(
    dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=True,
    collate_fn=citation_collate_fn,
    generator=generator,
)

cite_token_id = experiment.config.cite_token_id
ref_token_id = experiment.config.ref_token_id

# Fail fast on the first inconsistent batch; `batch` and `bi` are left bound to
# the offending batch so the following cells can inspect it.
for bi, batch in enumerate(dataloader):
    # Special cite tokens correspond one-to-one to the cited article ids.
    n_cite_tokens = int((batch['source_ids'] == cite_token_id).sum())
    n_cited = batch['cited_art_ids'].shape[0]
    assert n_cite_tokens == n_cited, (
        f"batch {bi}: {n_cite_tokens} cite tokens in source_ids "
        f"but {n_cited} cited_art_ids"
    )

    # Each cited article id has a corresponding target label.
    n_labels = batch['labels'].shape[0]
    assert n_cited == n_labels, (
        f"batch {bi}: {n_cited} cited_art_ids but {n_labels} labels"
    )

    # Special ref tokens correspond one-to-one to the target article ids.
    n_ref_tokens = int((batch['target_ids'] == ref_token_id).sum())
    n_targets = batch['target_art_ids'].shape[0]
    assert n_ref_tokens == n_targets, (
        f"batch {bi}: {n_ref_tokens} ref tokens in target_ids "
        f"but {n_targets} target_art_ids"
    )

In [None]:
# Compare the number of special ref tokens with the number of target article
# ids (and show the labels shape) for the batch left over from the loop above.
ref_token_total = (batch['target_ids'] == experiment.config.ref_token_id).sum()
target_article_count = batch['target_art_ids'].shape[0]
ref_token_total, target_article_count, batch['labels'].shape


In [None]:
# Locate the rows of `target_ids` that contain exactly two ref tokens.
ref_tokens_per_row = (batch['target_ids'] == experiment.config.ref_token_id).sum(dim=1)
torch.where(ref_tokens_per_row == 2)

In [None]:
# Show, as 0/1 integers, where the ref token sits within one row of `target_ids`.
# The row index is hard-coded — presumably taken by hand from the output of the
# previous cell; TODO derive it programmatically so the cell survives a re-run.
inspected_row = 728
ref_mask_for_row = batch['target_ids'][inspected_row, :] == experiment.config.ref_token_id
ref_mask_for_row * 1