In [4]:
# Notebook setup cell: choose the GPU, import the project / training
# dependencies, and initialise logging.
import os
# Expose only CUDA device index 1 to this process. The assignment is made
# before `torch` is imported below, so it is in place before any CUDA
# initialisation can happen.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
from video_moment_retrieval.utils.logging import init_logging, logger
from video_moment_retrieval.testing.modeling_dd import DoubleDecoderModel, DDDataset
from video_moment_retrieval.testing.configuration_dd import  DoubleDecoderConfig 
from transformers import TrainingArguments, Trainer, PreTrainedTokenizerFast
from tokenizers import Tokenizer
from torch.utils.data import DataLoader
import torch

# Call the project's logging initialiser (imported from
# video_moment_retrieval.utils.logging). What it configures is not visible
# from this notebook.
init_logging()

In [6]:
# ---------------------------------------------------------------------------
# Paths used by this training run.
# ---------------------------------------------------------------------------
TOKENIZER_FILE = "tokenizer.json"
TRAIN_ANNOTATIONS = "../../qvhighlights_features/highlight_train_release.jsonl"
VAL_ANNOTATIONS = "../../qvhighlights_features/highlight_val_release.jsonl"
BERT_FEATURES_DIR = "../../qvhighlights_features/bert_features"
RESNET_FEATURES_DIR = "../../qvhighlights_features/resnet_features"
OUTPUT_DIR = "./train_output/double_decoder_2"

# ---------------------------------------------------------------------------
# Target-side tokenizer: a `tokenizers` file wrapped in the transformers
# fast-tokenizer interface, with explicit pad / unknown tokens.
# ---------------------------------------------------------------------------
raw_tokenizer = Tokenizer.from_file(TOKENIZER_FILE)
tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=raw_tokenizer,
    pad_token="[PAD]",
    unk_token="[UNK]",
)

# ---------------------------------------------------------------------------
# Train / validation datasets share the same feature directories and
# tokenizer; only the annotation file differs.
# ---------------------------------------------------------------------------
train_dataset = DDDataset(
    TRAIN_ANNOTATIONS,
    BERT_FEATURES_DIR,
    RESNET_FEATURES_DIR,
    target_tokenizer=tokenizer,
)
eval_dataset = DDDataset(
    VAL_ANNOTATIONS,
    BERT_FEATURES_DIR,
    RESNET_FEATURES_DIR,
    target_tokenizer=tokenizer,
)

# ---------------------------------------------------------------------------
# Model. The config receives, positionally, the tokenizer's vocabulary size
# and its pad-token id.
# ---------------------------------------------------------------------------
vocab_size = len(tokenizer.vocab)
config = DoubleDecoderConfig(vocab_size, tokenizer.pad_token_id)
# logger.info("Running model using config %s", config)
model = DoubleDecoderModel(config)

# ---------------------------------------------------------------------------
# Training configuration.
# ---------------------------------------------------------------------------
train_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    # Batching / data loading
    per_device_train_batch_size=32,
    gradient_accumulation_steps=1,
    dataloader_num_workers=2,
    label_names=["labels"],
    # Optimisation
    learning_rate=1e-5,
    weight_decay=1e-1,
    max_grad_norm=0.1,
    lr_scheduler_type="constant_with_warmup",
    warmup_steps=500,
    num_train_epochs=200,
    # Log, evaluate and checkpoint on the same 200-step cadence
    logging_steps=200,
    eval_strategy="steps",
    eval_steps=200,
    save_steps=200,
    eval_do_concat_batches=False,
    # Reload the best checkpoint when training ends; lower metric is better
    load_best_model_at_end=True,
    greater_is_better=False,
)

# The training dataset's own `collate` method is used to build batches for
# both the training and the evaluation dataset.
trainer = Trainer(
    model=model,
    args=train_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=train_dataset.collate,
)

trainer.train()

Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss,Validation Loss
200,3.5355,2.746895
400,2.6247,2.403987
600,2.3551,2.183485
800,2.1634,2.066316
1000,2.0519,2.002259




: 

In [93]:
# Restore a trained DoubleDecoderModel from a saved checkpoint directory.
# NOTE(review): this path lives under "double_decoder", while the training
# cell in this notebook writes to "double_decoder_2" - confirm this is the
# intended run.
checkpoint_dir = "./train_output/double_decoder/checkpoint-10000"
dd_model = DoubleDecoderModel.from_pretrained(checkpoint_dir)

In [120]:
# Rebuild the target tokenizer from its serialized file so this cell can be
# run without the training cell.
raw_tokenizer = Tokenizer.from_file("./tokenizer.json")
tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=raw_tokenizer,
    pad_token="[PAD]",
    unk_token="[UNK]",
)

# NOTE: despite its name, `train_dataset` here wraps the *validation*
# annotation file (highlight_val_release.jsonl). The name is kept because the
# following cells refer to it.
train_dataset = DDDataset(
    "../../qvhighlights_features/highlight_val_release.jsonl",
    "../../qvhighlights_features/bert_features",
    "../../qvhighlights_features/resnet_features",
    target_tokenizer=tokenizer,
)


In [121]:
# One example per batch, collated with the dataset's own `collate` method,
# so each example can be inspected / decoded on its own.
loader = DataLoader(
    dataset=train_dataset,
    batch_size=1,
    collate_fn=train_dataset.collate,
)

In [96]:
# Manual iterator over the loader so single batches can be pulled with
# `next(it)` in separate cells.
it = iter(loader)

In [112]:
# Pull the next batch from the iterator; re-running this cell advances to the
# following example.
batch = next(it)

In [113]:
# Decode the target token ids of the first example in the batch back to text
# and display the string as the cell result.
first_example_label_ids = batch["labels"][0]
labels = tokenizer.decode(first_example_label_ids)
labels

'[ 14 48 ] ] </s>'

In [114]:
# Keep only the first decoder position (ids and matching attention mask), so
# the decoder input is reduced to its initial token.
for decoder_key in ("decoder_input_ids", "decoder_attention_mask"):
    batch[decoder_key] = batch[decoder_key][:, :1]

In [115]:
# Remove the target ids from the batch. The membership check makes the cell
# safe to re-run: an unguarded `del` raises KeyError the second time, because
# the key is already gone from this same `batch` object.
if "labels" in batch:
    del batch["labels"]

In [122]:
# Greedy decoding over the first few examples of `loader`.
#
# For each example this prints the model's greedy prediction, then the decoded
# reference labels, then a blank line.
#
# Changes relative to the previous version of this cell:
#   * the pre-loop edits to `batch` were dead code (the `for` loop rebinds
#     `batch` immediately) and have been removed;
#   * generation is capped at MAX_NEW_TOKENS so a model that never emits
#     "</s>" cannot hang the kernel;
#   * forward passes run under `torch.no_grad()` so no autograd graph is kept;
#   * the "</s>" id is looked up once instead of on every step.

START_TOKEN_ID = 12    # id fed as the first decoder token (value kept from the original cell)
NUM_EXAMPLES = 20      # how many batches from `loader` to decode
MAX_NEW_TOKENS = 128   # safety cap on generated tokens per example

eos_token_id = tokenizer.vocab["</s>"]

# Make sure layers such as dropout are in inference mode, however `dd_model`
# was obtained.
dd_model.eval()

with torch.no_grad():
    for example_idx, batch in enumerate(loader):
        if example_idx >= NUM_EXAMPLES:
            break

        # Keep the reference text for printing, then drop the targets so only
        # model inputs are passed to the forward call.
        labels = tokenizer.decode(batch["labels"][0])
        del batch["labels"]

        # Start from a single start token with a matching all-ones mask.
        batch["decoder_input_ids"] = torch.tensor([[START_TOKEN_ID]], dtype=torch.long)
        batch["decoder_attention_mask"] = torch.ones_like(batch["decoder_input_ids"])

        for _ in range(MAX_NEW_TOKENS):
            if batch["decoder_input_ids"][0, -1].item() == eos_token_id:
                break
            # outputs[0] is indexed as (batch, sequence, vocab): take the last
            # position and pick the highest-scoring token id.
            outputs = dd_model(**batch)
            next_token = outputs[0][:, -1].argmax(dim=-1, keepdim=True)
            batch["decoder_input_ids"] = torch.cat(
                [batch["decoder_input_ids"], next_token], dim=-1
            )
            # Every generated position is a real token, so the mask stays all ones.
            batch["decoder_attention_mask"] = torch.ones_like(batch["decoder_input_ids"])

        if batch["decoder_input_ids"][0, -1].item() != eos_token_id:
            print(f"(stopped after {MAX_NEW_TOKENS} tokens without </s>)")

        print(tokenizer.decode(batch["decoder_input_ids"][0]))
        print(labels)
        print()

[ [ 18 70 ] ] </s>
[ 82 150 ] ] </s>

[ [ 58 104 ] ] </s>
[ 118 136 ] ] </s>

[ [ 22 40 ] ] </s>
[ 56 76 ] [ 96 150 ] ] </s>

[ [ 0 114 ] ] </s>
[ 36 60 ] ] </s>

[ [ 92 118 ] ] </s>
[ 78 92 ] ] </s>

[ [ 0 150 ] ] </s>
[ 0 74 ] [ 76 142 ] [ 144 150 ] ] </s>

[ [ 80 150 ] ] </s>
[ 44 136 ] ] </s>

[ [ 0 16 ] ] </s>
[ 10 108 ] ] </s>

[ [ 0 22 ] [ 28 64 ] ] </s>
[ 0 16 ] ] </s>

[ [ 50 60 ] ] </s>
[ 0 144 ] ] </s>

[ [ 24 46 ] [ 62 78 ] ] </s>
[ 6 20 ] ] </s>

[ [ 0 18 ] [ 22 32 ] ] </s>
[ 0 2 ] [ 52 74 ] [ 84 92 ] ] </s>

[ [ 0 92 ] ] </s>
[ 52 88 ] [ 96 98 ] ] </s>

[ [ 82 44 ] ] </s>
[ 34 62 ] ] </s>

[ [ 0 128 ] [ 140 150 ] ] </s>
[ 26 46 ] ] </s>

[ [ 92 124 ] ] </s>
[ 106 122 ] ] </s>

[ [ 46 122 ] ] </s>
[ 40 90 ] ] </s>

[ [ 114 124 ] ] </s>
[ 0 46 ] ] </s>

[ [ 114 118 ] [ 120 122 ] ] </s>
[ 0 14 ] ] </s>

[ [ 44 84 ] ] </s>
[ 32 54 ] ] </s>

