In [1]:
import sys
sys.path.append('/workspace/lxmert/src')
from lxrt.modeling import *
import torch
from torch.cuda.amp import GradScaler, autocast
import easydict
from collections import OrderedDict
import pickle as pkl
from torch.utils.data import Dataset,DataLoader
import numpy as np
import os
import os
from tqdm.auto import tqdm
import transformers
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   
os.environ["CUDA_VISIBLE_DEVICES"]="5"
import time

In [2]:
tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')

In [3]:
loss_weight = 0.05
iou_threshold = 0.1

In [4]:
ious = pkl.load(open('obj_ious.pkl', 'rb'))

In [5]:
mode = f'lxmert_sequence_encoder_outsegs_obj_aware_{iou_threshold}_frame_seq_contrast_lxmert{loss_weight}_afterlstm'

In [6]:
class segmentDataset(Dataset):
    """Segment dataset."""
    def __init__(self, vidlistpkl, feat_base, cook2_IVD_dir ="/workspace/evidence_retrieval/COOK2_IVD", mode = 'train' ):
        with open(os.path.join(vidlistpkl), "rb") as f:
            vid_pkl = pkl.load(f)
            
        self.vid_list = vid_pkl
        self.feat_base = feat_base
        with open(os.path.join(cook2_IVD_dir, f"pkl/{mode}.pkl"), "rb") as f:
            self.vid_pkl = pkl.load(f)
    def __len__(self):
        return len(self.vid_list)

    def __getitem__(self, idx):
        

        vid = self.vid_list[idx]
        with open(os.path.join(self.feat_base, vid+'.pkl'), 'rb') as fp:
                  item = pkl.load(fp)
        visn_feats, encode, temporal_label, label = item
        query = self.vid_pkl[vid]["query"]
        
        
        
        return {"vid": vid, "visn_feats":visn_feats, 'encode':encode, 'temporal_label':temporal_label,'label':label, 'query':query}

In [7]:
import os
import pickle as pkl
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from statistics import mean
from tqdm import tqdm
import csv
import sys
import time
import numpy as np
import base64
cook2_IVD_dir = "/workspace/COOK2_IVD"


class segmentDataset(Dataset):
    """Segment dataset."""
    def __init__(self, vidlistpkl, feat_base, cook2_IVD_dir ="/workspace/COOK2_IVD", mode = 'train' ):
        with open(os.path.join(vidlistpkl), "rb") as f:
            vid_pkl = pkl.load(f)
            
        self.vid_list = vid_pkl
        self.feat_base = feat_base
        with open(os.path.join(cook2_IVD_dir, f"pkl/{mode}.pkl"), "rb") as f:
            self.vid_pkl = pkl.load(f)
    def __len__(self):
        return len(self.vid_list)

    def __getitem__(self, idx):
        

        vid = self.vid_list[idx]
        with open(os.path.join(self.feat_base, vid+'.pkl'), 'rb') as fp:
                  item = pkl.load(fp)
        visn_feats,encode,label,(trans_input_ids,trans_token_type_ids,trans_attention_mask ) = item
        query = self.vid_pkl[vid]["query"]
        
        
        
        return {"vid": vid, "visn_feats":visn_feats, 'encode':encode, 'label':label, 'query':query,\
               'trans_input_ids':trans_input_ids, 'trans_token_type_ids':trans_token_type_ids,\
               'trans_attention_mask':trans_attention_mask}

In [8]:
segment_dataset = segmentDataset('newsplit_train_vid_list.pkl','feat_dump_outsegs')
valid_dataset = segmentDataset('newsplit_valid_vid_list.pkl','feat_dump_outsegs', mode = 'vid')
test_dataset = segmentDataset('newsplit_test_vid_list.pkl','feat_dump_outsegs', mode = 'test')

In [9]:
class sequence_frame_encoder(torch.nn.Module):
    def __init__(self, config, sequence_encoder = 'LSTM'):
        super(sequence_frame_encoder, self).__init__()
        self.frame_encoder = LXRTFeatureExtraction(config)
        state_dict_path = os.path.join('lxmert', 'snap', 'pretrained', 'model_LXRT.pth') 
        state_dict = torch.load(state_dict_path)
        new_state_dict = OrderedDict()
        for key, value in state_dict.items():
            splittedkey = key.split('.')
            if 'bert' in splittedkey:
                newkey  = '.'.join(splittedkey[splittedkey.index('bert')+1:])
            else:
                newkey  = '.'.join(splittedkey[splittedkey.index('module')+1:])
            new_state_dict[newkey] = value
        self.frame_encoder.load_state_dict(new_state_dict, strict=False)
        del state_dict
        del new_state_dict
        if sequence_encoder == 'LSTM':
            self.sequence_encoder = torch.nn.LSTM(768, hidden_size = 768//2, batch_first = True, bidirectional = True)
            self.sequence_fc = torch.nn.Linear(768, 768)
            self.contrast_fc1 = torch.nn.Linear(768,768)
            self.contrast_fc2 = torch.nn.Linear(768,768)
            
        #transformer later
        #else:
        #    self.sequence_encoder = torch.nn.TransformerEncoderLayer(d_model = 768,nhead = 12,num_encoder_layers = 1,
        #                                                             dim_feedforward=3072, activation == "gelu", batch_first = True )
    def forward(self,visn_feats, trans_input_ids,trans_token_type_ids,trans_attention_mask ):
        frame_feats = self.frame_encoder(trans_input_ids, 
                                        token_type_ids = trans_token_type_ids,
                                        attention_mask = trans_attention_mask,
                                        visual_feats = visn_feats
                                       )
        frame_feats = frame_feats[1]        
        frame_contrast1 = self.contrast_fc1(frame_feats)
        frame_contrast1 = frame_contrast1.squeeze()
        frame_feats = frame_feats.unsqueeze(0)
        frame_feats,(_,_) = self.sequence_encoder(frame_feats)
        frame_contrast2 = self.contrast_fc1(frame_feats)
        frame_contrast2 = frame_contrast2.squeeze()
        frame_contrast = torch.cat([frame_contrast1, frame_contrast2], dim = -1)
        #frame_contrast = self.contrast_fc(frame_feats)
        #frame_contrast = frame_contrast.squeeze()
        frame_feats = self.sequence_fc(frame_feats)
        frame_feats = frame_feats.squeeze()
        return frame_feats, frame_contrast


In [10]:
config = BertConfig('bert_config.json')
"""image_encoder = LXRTImageModel(config)
state_dict_path = os.path.join('lxmert', 'snap', 'pretrained', 'model_LXRT.pth') 
state_dict = torch.load(state_dict_path)
new_state_dict = OrderedDict()
for key, value in state_dict.items():
    splittedkey = key.split('.')
    if 'bert' in splittedkey:
        newkey  = '.'.join(splittedkey[splittedkey.index('bert')+1:])
    else:
        newkey  = '.'.join(splittedkey[splittedkey.index('module')+1:])
    new_state_dict[newkey] = value
image_encoder.load_state_dict(new_state_dict, strict=False)
image_encoder.cuda().train()"""
print()




In [11]:
config = BertConfig('bert_config.json')
"""frame_encoder = LXRTFeatureExtraction(config)
state_dict_path = os.path.join('lxmert', 'snap', 'pretrained', 'model_LXRT.pth') 
state_dict = torch.load(state_dict_path)
new_state_dict = OrderedDict()
for key, value in state_dict.items():
    splittedkey = key.split('.')
    if 'bert' in splittedkey:
        newkey  = '.'.join(splittedkey[splittedkey.index('bert')+1:])
    else:
        newkey  = '.'.join(splittedkey[splittedkey.index('module')+1:])
    new_state_dict[newkey] = value
frame_encoder.load_state_dict(new_state_dict, strict=False)"""
frame_encoder = sequence_frame_encoder(config)
frame_encoder.cuda()
#frame_encoder = nn.DataParallel(frame_encoder, device_ids=[0,1])
print()

LXRT encoder with 12 l_layers, 5 x_layers, and 0 r_layers.



In [12]:
def nearest_trans_idx(frame_idx, trans_idx):
    return np.argmin(abs(trans_idx-frame_idx))

img_h, img_w = img_info['img_h'], img_info['img_w']
boxes = boxes.copy()
boxes[:, (0, 2)] /= img_w
boxes[:, (1, 3)] /= img_h

In [13]:
from transformers import BertModel, BertTokenizer
lang_encoder = BertModel.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
lang_encoder.cuda()
#lang_encoder = nn.DataParallel(lang_encoder, device_ids=[0,1])
print()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).





In [14]:
from lxrt.optimization import BertAdam
epoch = 100
batch_per_epoch = len(segment_dataset)
t_total = int(batch_per_epoch * epoch)
warmup_ratio = 0.05
warmup_iters = int(t_total * warmup_ratio)
optim = BertAdam(list(frame_encoder.parameters()) + list(lang_encoder.parameters()), lr=1e-4, warmup=warmup_ratio, t_total=t_total)
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-1)
scaler = GradScaler()

In [15]:
mode 

'lxmert_sequence_encoder_outsegs_obj_aware_0.1_frame_seq_contrast_lxmert0.05_afterlstm'

batch = segment_dataset[0]

    vid, visn_feats, encode, label = batch['vid'], batch["visn_feats"],batch['encode'],batch['label']
    frame_feats, box_feats = visn_feats
    foodname = batch['query']
    trans_input_ids,trans_token_type_ids,trans_attention_mask  = batch['trans_input_ids'],batch['trans_token_type_ids'],batch['trans_attention_mask']
    label = label.cuda()
    visn_feats = frame_feats.cuda(), box_feats.cuda()
    trans_input_ids = trans_input_ids.cuda()
    trans_token_type_ids = trans_token_type_ids.cuda()
    trans_attention_mask = trans_attention_mask.cuda()
    encode = encode.cuda()
    frame_feats = frame_encoder(visn_feats, trans_input_ids, 
                                        trans_token_type_ids,trans_attention_mask
                                       )

In [16]:
floss_fn = torch.nn.BCEWithLogitsLoss(weight=None, size_average=None, reduce=None, pos_weight=None)

cos = nn.CosineSimilarity(dim=0, eps=1e-6)

contrast_feat = frame_feats[1]
frame_contrast_output = torch.matmul(contrast_feat, contrast_feat.transpose(1,0))

frame_contrast_output = m(frame_contrast_output)

label

contrast_label = []
for i in range(len(label)):
    c_label = []
    for j in range(len(label)):
        c_label.append(int(label[i]==label[j]) if label[i]>-0.5 else 0)
    contrast_label.append(c_label)
contrast_label = torch.tensor(contrast_label).float().cuda()

loss_output = loss(frame_contrast_output, contrast_label)
loss_output = loss_output*mask

torch.mean(loss_output)

mask = (label>-0.5).float()
mask = torch.matmul(mask.unsqueeze(-1), mask.unsqueeze(0))

In [17]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(mode)

In [18]:
def retrieval_forward(batch, frame_encoder, text_encoder, compute_loss = True):
    vid, visn_feats, encode, label = batch['vid'], batch["visn_feats"],batch['encode'],batch['label']
    frame_feats, box_feats = visn_feats
    foodname = batch['query']
    trans_input_ids,trans_token_type_ids,trans_attention_mask  = batch['trans_input_ids'],batch['trans_token_type_ids'],batch['trans_attention_mask']
    label = label.cuda()
    visn_feats = frame_feats.cuda(), box_feats.cuda()
    trans_input_ids = trans_input_ids.cuda()
    trans_token_type_ids = trans_token_type_ids.cuda()
    trans_attention_mask = trans_attention_mask.cuda()
    encode = encode.cuda()
    frame_feats,contrast_feat = frame_encoder(visn_feats, trans_input_ids, 
                                        trans_token_type_ids,trans_attention_mask
                                       )
    pooled_output = frame_feats
    output = lang_encoder(encode)
    sequence_output, lang_pooled_output = output[0], output[1]
    dotoutput = torch.matmul(pooled_output, lang_pooled_output.transpose(1,0))
    if compute_loss:
        loss = loss_fn(dotoutput, label)
        frame_contrast_output = torch.matmul(contrast_feat, contrast_feat.transpose(1,0))
        #frame_contrast_output = sigmoid(frame_contrast_output)
        contrast_label = []
        for i in range(len(label)):
            c_label = []
            for j in range(len(label)):
                c_label.append(int(label[i]==label[j]) if label[i]>-0.5 else 0)
            contrast_label.append(c_label)
        contrast_label = torch.tensor(contrast_label).float().cuda()
        mask = (label>-0.5).float()
        mask = torch.matmul(mask.unsqueeze(-1), mask.unsqueeze(0))
        fcontrast_loss = floss_fn(frame_contrast_output, contrast_label)
        fcontrast_loss = torch.mean(fcontrast_loss*mask) 
        return dotoutput, loss+loss_weight*fcontrast_loss
    else:
        return dotoutput

In [19]:
min_loss = -10000
sampling_num = 3
for epoch in tqdm(range(epoch)):
    frame_encoder.train()
    lang_encoder.train()
    for iter_num in range(len(segment_dataset)):
        with autocast():
            btime_b = time.time()
            batch = segment_dataset[iter_num]
            #print(batch['encode'].shape,batch['trans_input_ids'].shape )
            optim.zero_grad()
            dotoutput, loss = retrieval_forward(batch, frame_encoder, lang_encoder,compute_loss = True)
            writer.add_scalar('Loss/train',loss.detach().cpu(), epoch*len(segment_dataset)+iter_num)
        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(list(frame_encoder.parameters()) + list(lang_encoder.parameters()),1 )
        scaler.step(optim)
        scaler.update()
        torch.cuda.empty_cache()
    with torch.no_grad():
        frame_encoder.eval()
        lang_encoder.eval()
        ks = [1,3,5]
        recalls = {k:[] for k in ks}
        losses = 0
        for iter_num in range(len(valid_dataset)):
            batch = valid_dataset[iter_num]
            dotoutput, loss = retrieval_forward(batch, frame_encoder,  lang_encoder,compute_loss = True)
            losses += loss
            label = list(batch['label'].numpy())
            for k in ks:
                        pred = np.argsort(-1*dotoutput.detach().cpu().numpy(), axis = -1)[:,:k].squeeze()
                        recallatk = 0
                        examples = 0
                        for i, gt in enumerate(label):
                            if k > 1:
                                if gt == -1:
                                    continue
                                if gt in pred[i]:
                                    recallatk +=1
                                examples += 1
                            else:
                                if gt == -1:
                                    continue
                                if gt == pred[i]:
                                    recallatk +=1
                                examples += 1
                        recallatk = recallatk/examples
                        recalls[k].append(recallatk)
        losses = losses/len(valid_dataset)
        writer.add_scalar('Loss/valid',losses.cpu(), epoch)
        print(epoch, losses)
        for k in ks:
            print( k, sum(recalls[k])/len(recalls[k]))
            writer.add_scalar(f'Recall/valid{k}',sum(recalls[k])/len(recalls[k]), epoch)
        if min_loss<sum(recalls[1])/len(recalls[1]):
                    min_loss = sum(recalls[1])/len(recalls[1])
                    torch.save(lang_encoder.state_dict(), os.path.join(f'{mode}','lang_encoder_ori_best_eval_loss.pth'.format(epoch)))
                    torch.save(frame_encoder.state_dict(), os.path.join(f'{mode}','image_encoder_ori_best_eval_loss.pth'.format(epoch)))
        for iter_num in range(len(test_dataset)):
            batch = test_dataset[iter_num]
            dotoutput, loss = retrieval_forward(batch, frame_encoder, lang_encoder, compute_loss = True)
            losses += loss
            label = list(batch['label'].numpy())
            for k in ks:
                        pred = np.argsort(-1*dotoutput.detach().cpu().numpy(), axis = -1)[:,:k].squeeze()
                        recallatk = 0
                        examples = 0
                        for i, gt in enumerate(label):
                            if k > 1:
                                if gt == -1:
                                    continue
                                if gt in pred[i]:
                                    recallatk +=1
                                examples += 1
                            else:
                                if gt == -1:
                                    continue
                                if gt == pred[i]:
                                    recallatk +=1
                                examples += 1
                        recallatk = recallatk/examples
                        recalls[k].append(recallatk)
        losses = losses/len(valid_dataset)
        writer.add_scalar('Loss/test',losses.cpu(), epoch)
        print(epoch, losses)
        for k in ks:
                print( k, sum(recalls[k])/len(recalls[k]))
                writer.add_scalar(f'Recall/test{k}',sum(recalls[k])/len(recalls[k]), epoch)
            

	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  ../torch/csrc/utils/python_arg_parser.cpp:1005.)
  next_m.mul_(beta1).add_(1 - beta1, grad)


0 tensor(2.1175, device='cuda:0')
1 0.20694159842593385
3 0.514609537812975
5 0.7585385703036757


  1%|          | 1/100 [10:30<17:19:58, 630.29s/it]

0 tensor(0.9259, device='cuda:0')
1 0.20466441984772848
3 0.5143639917360101
5 0.7520634952620809
1 tensor(2.0310, device='cuda:0')
1 0.2082063267045154
3 0.5195893505458673
5 0.7579717974406486


  2%|▏         | 2/100 [20:34<16:56:42, 622.48s/it]

1 tensor(0.8917, device='cuda:0')
1 0.20628507306084912
3 0.5201230840621814
5 0.7522193625963607
2 tensor(1.9600, device='cuda:0')
1 0.24049195687746133
3 0.5653026288254843
5 0.7869000431715297


  3%|▎         | 3/100 [30:28<16:32:42, 614.04s/it]

2 tensor(0.8640, device='cuda:0')
1 0.23605196350657975
3 0.5629100720094304
5 0.7837535120715301
3 tensor(1.9493, device='cuda:0')
1 0.25172556898247894
3 0.5817355074114319
5 0.7940252790893059


  4%|▍         | 4/100 [40:28<16:15:26, 609.65s/it]

3 tensor(0.8564, device='cuda:0')
1 0.2500360517125342
3 0.5785440874917205
5 0.7901529061668044
4 tensor(2.0001, device='cuda:0')
1 0.2392252978174203
3 0.5774464582257667
5 0.7919402400537794


  5%|▌         | 5/100 [50:30<16:01:53, 607.51s/it]

4 tensor(0.8710, device='cuda:0')
1 0.23645347539956582
3 0.5742984795764003
5 0.7859390742811536
5 tensor(2.0136, device='cuda:0')
1 0.24360791910957233
3 0.5775358319397665
5 0.7871938065234442


  6%|▌         | 6/100 [1:00:33<15:49:31, 606.08s/it]

5 tensor(0.8700, device='cuda:0')
1 0.2453712820394766
3 0.5763374465362523
5 0.7843978904760028
6 tensor(1.9450, device='cuda:0')
1 0.26698829357129983
3 0.6129203013991367
5 0.8093373505408915


  7%|▋         | 7/100 [1:10:36<15:38:01, 605.17s/it]

6 tensor(0.8437, device='cuda:0')
1 0.2684415220214363
3 0.6110146353400537
5 0.8059063012627548
7 tensor(1.9362, device='cuda:0')
1 0.27930707399633226
3 0.6116129174038388
5 0.8131591398021185


  8%|▊         | 8/100 [1:20:46<15:29:55, 606.47s/it]

7 tensor(0.8336, device='cuda:0')
1 0.28367420346309274
3 0.6094697507487229
5 0.8120020861471179
8 tensor(1.6855, device='cuda:0')
1 0.37033517926729964
3 0.715708783683573
5 0.8831355614005585


  9%|▉         | 9/100 [1:31:01<15:23:45, 609.07s/it]

8 tensor(0.7464, device='cuda:0')
1 0.36690290396788466
3 0.7114904041172907
5 0.8804583163360896
9 tensor(1.6483, device='cuda:0')
1 0.3880374530324281
3 0.7241545906651615
5 0.8920258154659487


 10%|█         | 10/100 [1:41:05<15:11:37, 607.75s/it]

9 tensor(0.7397, device='cuda:0')
1 0.38177169053444454
3 0.7192169625915692
5 0.8869535273284179
10 tensor(1.6291, device='cuda:0')
1 0.3947318437697512
3 0.7382552775626723
5 0.8990326945909276


 11%|█         | 11/100 [1:51:07<14:58:39, 605.84s/it]

10 tensor(0.7305, device='cuda:0')
1 0.39155876698094794
3 0.7318654797958741
5 0.8919560603336075
11 tensor(1.6445, device='cuda:0')
1 0.41041946562441284
3 0.742580761562188
5 0.9000340248827639


 12%|█▏        | 12/100 [2:01:19<14:51:22, 607.76s/it]

11 tensor(0.7311, device='cuda:0')
1 0.41119422026053415
3 0.7404683057447594
5 0.894606624580435
12 tensor(1.5717, device='cuda:0')
1 0.4348045067286645
3 0.7723800793460442
5 0.9130223976574717


 13%|█▎        | 13/100 [2:11:18<14:37:33, 605.21s/it]

12 tensor(0.6982, device='cuda:0')
1 0.43658313639161245
3 0.7684763212593345
5 0.908731312222036
13 tensor(1.5616, device='cuda:0')
1 0.4513161452034621
3 0.7836999356419204
5 0.9170170800862744


 14%|█▍        | 14/100 [2:21:24<14:27:51, 605.49s/it]

13 tensor(0.6905, device='cuda:0')
1 0.4514216354463979
3 0.7794357963485105
5 0.9126181796312296
14 tensor(1.5547, device='cuda:0')
1 0.45782564914729706
3 0.7912451870483646
5 0.9229265391590019


 15%|█▌        | 15/100 [2:31:36<14:20:20, 607.31s/it]

14 tensor(0.6936, device='cuda:0')
1 0.4571098230363188
3 0.7870257105976051
5 0.9194229683167122
15 tensor(1.6173, device='cuda:0')
1 0.4546639485283191
3 0.7852608622378315
5 0.9224860036490339


 16%|█▌        | 16/100 [2:41:31<14:04:57, 603.55s/it]

15 tensor(0.7005, device='cuda:0')
1 0.4561126095505195
3 0.7812528150762793
5 0.9187255739927089
16 tensor(1.6253, device='cuda:0')
1 0.4469118255111159
3 0.7877770940381738
5 0.9202246586782719


 17%|█▋        | 17/100 [2:51:44<13:59:06, 606.59s/it]

16 tensor(0.7227, device='cuda:0')
1 0.4472266806327615
3 0.7844189809414353
5 0.916230430742059
17 tensor(1.5630, device='cuda:0')
1 0.4688647918306575
3 0.8002058303881738
5 0.9252667196299013


 18%|█▊        | 18/100 [3:01:45<13:46:38, 604.86s/it]

17 tensor(0.6810, device='cuda:0')
1 0.4694586024929155
3 0.7940984547730489
5 0.9217548454400796
18 tensor(1.5328, device='cuda:0')
1 0.4795174476588643
3 0.807953163062
5 0.9327709414446901


 19%|█▉        | 19/100 [3:11:51<13:36:54, 605.12s/it]

18 tensor(0.6709, device='cuda:0')
1 0.4793281861107256
3 0.8021809282153244
5 0.9282379865622766
19 tensor(1.5720, device='cuda:0')
1 0.4776284483768212
3 0.8067916642108975
5 0.9299556356783012


 20%|██        | 20/100 [3:21:44<13:21:56, 601.45s/it]

19 tensor(0.6992, device='cuda:0')
1 0.48003553160198104
3 0.7993580718599657
5 0.9252823136269627
20 tensor(1.6072, device='cuda:0')
1 0.48207122437413236
3 0.810305288393491
5 0.9328633250638761


 21%|██        | 21/100 [3:31:40<13:09:54, 599.93s/it]

20 tensor(0.7041, device='cuda:0')
1 0.48066306697085626
3 0.8025088093382454
5 0.9279415416396665
21 tensor(1.5229, device='cuda:0')
1 0.4916953919787311
3 0.8191676850235609
5 0.9381041673365218


 22%|██▏       | 22/100 [3:41:38<12:58:54, 599.16s/it]

21 tensor(0.6730, device='cuda:0')
1 0.4892287248474594
3 0.8121790891842565
5 0.9330479706988388
22 tensor(1.7072, device='cuda:0')
1 0.4677569341634666
3 0.7972406616303027
5 0.9267872765285696


 23%|██▎       | 23/100 [3:51:43<12:51:29, 601.16s/it]

22 tensor(0.7609, device='cuda:0')
1 0.46836050951662034
3 0.7899393353863113
5 0.9206814657255875
23 tensor(1.6559, device='cuda:0')
1 0.4941123712349781
3 0.8144331416434322
5 0.9352885548108179


 24%|██▍       | 24/100 [4:01:53<12:44:41, 603.70s/it]

23 tensor(0.7233, device='cuda:0')
1 0.4926870266361109
3 0.8075658024069864
5 0.9300686861134665
24 tensor(1.7270, device='cuda:0')
1 0.4831611498515635
3 0.8070343639438814
5 0.9308445092785222


 25%|██▌       | 25/100 [4:11:47<12:30:53, 600.71s/it]

24 tensor(0.7455, device='cuda:0')
1 0.4838860377010345
3 0.8009028669184343
5 0.9266952205965676
25 tensor(1.9279, device='cuda:0')
1 0.4750010130299782
3 0.8028778862487261
5 0.9229944167373042


 26%|██▌       | 26/100 [4:21:45<12:19:48, 599.85s/it]

25 tensor(0.8351, device='cuda:0')
1 0.47614224361580654
3 0.7965488422696639
5 0.9187126956071966
26 tensor(1.6440, device='cuda:0')
1 0.4994630408946878
3 0.8268740651473839
5 0.940319007813096


 27%|██▋       | 27/100 [4:31:35<12:06:12, 596.89s/it]

26 tensor(0.7206, device='cuda:0')
1 0.5005517980218981
3 0.8201927920093531
5 0.9361761667162168
27 tensor(1.6698, device='cuda:0')
1 0.49913743212323514
3 0.8272823152688235
5 0.9400829820957264


 28%|██▊       | 28/100 [4:41:35<11:57:21, 597.80s/it]

27 tensor(0.7227, device='cuda:0')
1 0.5018106601681355
3 0.8226987475437129
5 0.9370479808592923
28 tensor(1.6414, device='cuda:0')
1 0.5149898507283281
3 0.8319167180458723
5 0.9421595970571562


 29%|██▉       | 29/100 [4:51:41<11:50:19, 600.27s/it]

28 tensor(0.7098, device='cuda:0')
1 0.5168239433095241
3 0.8268545528859138
5 0.9395612966708791
29 tensor(1.6443, device='cuda:0')
1 0.5090484008255254
3 0.8334668962553585
5 0.9451820854279456


 30%|███       | 30/100 [5:01:48<11:42:38, 602.27s/it]

29 tensor(0.7291, device='cuda:0')
1 0.5129772370749908
3 0.8265357740054043
5 0.9411775906247424
30 tensor(1.6809, device='cuda:0')
1 0.5098016984338395
3 0.8333297645503037
5 0.9413082691240096


 31%|███       | 31/100 [5:11:42<11:29:49, 599.84s/it]

30 tensor(0.7287, device='cuda:0')
1 0.5132070313711158
3 0.8267624824210953
5 0.938917644779468
31 tensor(1.7404, device='cuda:0')
1 0.50666319554748
3 0.830198550303667
5 0.9408007285295767


 32%|███▏      | 32/100 [5:21:39<11:19:01, 599.14s/it]

31 tensor(0.7463, device='cuda:0')
1 0.5116055472636061
3 0.825679247990975
5 0.9391601574391978
32 tensor(1.7833, device='cuda:0')
1 0.5073745761407915
3 0.8315087603854119
5 0.9418586798992118


 33%|███▎      | 33/100 [5:31:39<11:09:24, 599.47s/it]

32 tensor(0.7827, device='cuda:0')
1 0.5099672655132507
3 0.8251609796653916
5 0.9392770745344253
33 tensor(1.6586, device='cuda:0')
1 0.5248945031701341
3 0.8399884697798732
5 0.9486623004386394


 34%|███▍      | 34/100 [5:41:36<10:58:35, 598.72s/it]

33 tensor(0.7219, device='cuda:0')
1 0.5253663607914023
3 0.8335281802888547
5 0.9447273958857229
34 tensor(1.7258, device='cuda:0')
1 0.5174912088037231
3 0.8361257609411965
5 0.9467693970739591


 35%|███▌      | 35/100 [5:51:46<10:52:16, 602.09s/it]

34 tensor(0.7656, device='cuda:0')
1 0.5179363591011608
3 0.8291176398545941
5 0.9418614034381647
35 tensor(1.7453, device='cuda:0')
1 0.514821865881076
3 0.8372959981779085
5 0.9473461362994692


 36%|███▌      | 36/100 [6:01:53<10:43:31, 603.30s/it]

35 tensor(0.7717, device='cuda:0')
1 0.5184162544037434
3 0.8305280242960345
5 0.9435837561809712
36 tensor(1.7640, device='cuda:0')
1 0.5186504863553711
3 0.8355525663090441
5 0.942779878177891


 37%|███▋      | 37/100 [6:11:54<10:32:46, 602.65s/it]

36 tensor(0.7842, device='cuda:0')
1 0.5217409426524784
3 0.8297374582512649
5 0.939821455980639
37 tensor(1.7710, device='cuda:0')
1 0.5242256111964608
3 0.8398491929191062
5 0.9472677470199474


 38%|███▊      | 38/100 [6:22:00<10:23:49, 603.69s/it]

37 tensor(0.7665, device='cuda:0')
1 0.5271210724114146
3 0.8332277882557247
5 0.9434132150877883
38 tensor(1.8377, device='cuda:0')
1 0.5143324244013063
3 0.8342647149279151
5 0.9449902677330079


 39%|███▉      | 39/100 [6:32:18<10:18:06, 607.98s/it]

38 tensor(0.7974, device='cuda:0')
1 0.5173515352107675
3 0.8283629158439167
5 0.9416170494521768
39 tensor(1.8571, device='cuda:0')
1 0.5119872167539429
3 0.8375013336457101
5 0.9485847144307058


 40%|████      | 40/100 [6:42:14<10:04:24, 604.41s/it]

39 tensor(0.8180, device='cuda:0')
1 0.5164608455002676
3 0.8318993472157412
5 0.9440732146222032
40 tensor(1.8737, device='cuda:0')
1 0.5150544366814412
3 0.8344156950978107
5 0.9463876755798066


 41%|████      | 41/100 [6:52:21<9:55:17, 605.37s/it] 

40 tensor(0.8204, device='cuda:0')
1 0.5170886217383608
3 0.8282711725000946
5 0.9424981242269277
41 tensor(1.8485, device='cuda:0')
1 0.5185782154022524
3 0.8387359838483766
5 0.9484746597087497


 42%|████▏     | 42/100 [7:02:35<9:47:25, 607.69s/it]

41 tensor(0.8007, device='cuda:0')
1 0.5242627476620346
3 0.8341470283064829
5 0.9453248935274743
42 tensor(1.8487, device='cuda:0')
1 0.5234004954493674
3 0.8394322803552519
5 0.9492250372086849


 43%|████▎     | 43/100 [7:12:33<9:34:38, 604.89s/it]

42 tensor(0.8163, device='cuda:0')
1 0.525970342673874
3 0.8352606956136852
5 0.9460584974261388
43 tensor(1.8798, device='cuda:0')
1 0.529223901704969
3 0.8406760063357374
5 0.9480864648275625


 44%|████▍     | 44/100 [7:22:40<9:25:09, 605.53s/it]

43 tensor(0.8200, device='cuda:0')
1 0.5304447941976499
3 0.8357151466670952
5 0.9453392187001336
44 tensor(1.8518, device='cuda:0')
1 0.5300031613640456
3 0.8431077632723167
5 0.9481589931940294


 45%|████▌     | 45/100 [7:32:43<9:14:24, 604.82s/it]

44 tensor(0.8014, device='cuda:0')
1 0.5312423729958244
3 0.8384776425669004
5 0.9459278948616201
45 tensor(1.8366, device='cuda:0')
1 0.5339484918849869
3 0.8434565062303951
5 0.9496388335344792


 46%|████▌     | 46/100 [7:42:48<9:04:14, 604.71s/it]

45 tensor(0.7960, device='cuda:0')
1 0.5342227392656573
3 0.8397504357592477
5 0.9462411226714
46 tensor(1.9254, device='cuda:0')
1 0.5284766630127865
3 0.840237415137712
5 0.9476332188886081


 47%|████▋     | 47/100 [7:52:48<8:53:09, 603.57s/it]

46 tensor(0.8424, device='cuda:0')
1 0.5311321335911295
3 0.8360613823239156
5 0.9450967723218833
47 tensor(1.9434, device='cuda:0')
1 0.5299701327160447
3 0.8434956603213885
5 0.9483855815381375


 48%|████▊     | 48/100 [8:02:50<8:42:28, 602.85s/it]

47 tensor(0.8424, device='cuda:0')
1 0.5315170727743089
3 0.8380245247807866
5 0.9442138661470203
48 tensor(1.8920, device='cuda:0')
1 0.5294087180658723
3 0.8432264612176045
5 0.9503057121663409


 49%|████▉     | 49/100 [8:12:57<8:33:33, 604.18s/it]

48 tensor(0.8256, device='cuda:0')
1 0.5306687843061404
3 0.838222962435314
5 0.9465431544565192
49 tensor(1.9203, device='cuda:0')
1 0.5305926211618984
3 0.8442897143039155
5 0.9518573172305281


 50%|█████     | 50/100 [8:22:59<8:22:56, 603.53s/it]

49 tensor(0.8349, device='cuda:0')
1 0.5313223138702936
3 0.8391588333889706
5 0.9475222311519748
50 tensor(2.0713, device='cuda:0')
1 0.5184296867203232
3 0.8379576781304419
5 0.9455935839595632


 51%|█████     | 51/100 [8:33:05<8:13:27, 604.24s/it]

50 tensor(0.8854, device='cuda:0')
1 0.5202378863851477
3 0.8322878510594008
5 0.9425286272303937
51 tensor(2.0225, device='cuda:0')
1 0.5295277988738148
3 0.8432833867090076
5 0.9471131344835089


 52%|█████▏    | 52/100 [8:43:19<8:05:50, 607.30s/it]

51 tensor(0.8636, device='cuda:0')
1 0.5314266405845289
3 0.8382727419292532
5 0.9446242759074515
52 tensor(2.0301, device='cuda:0')
1 0.5294440357005815
3 0.843202080011352
5 0.9487758223044775


 53%|█████▎    | 53/100 [8:53:03<7:50:11, 600.24s/it]

52 tensor(0.8910, device='cuda:0')
1 0.5280331966228916
3 0.8375634885790542
5 0.944744547327722
53 tensor(2.0407, device='cuda:0')
1 0.5246711561746183
3 0.8408325870557604
5 0.9487933468201332


 54%|█████▍    | 54/100 [9:03:18<7:43:36, 604.72s/it]

53 tensor(0.8829, device='cuda:0')
1 0.5256805894748802
3 0.8347800071758258
5 0.946192527411218
54 tensor(2.0852, device='cuda:0')
1 0.5254580093937588
3 0.8432794404436774
5 0.9491016230423773


 55%|█████▌    | 55/100 [9:13:08<7:30:04, 600.10s/it]

54 tensor(0.9088, device='cuda:0')
1 0.5269260366787479
3 0.8380488601707237
5 0.9465821971947554
55 tensor(2.0489, device='cuda:0')
1 0.5325787066298873
3 0.8457702170138993
5 0.9488744296390417


 56%|█████▌    | 56/100 [9:22:56<7:17:36, 596.74s/it]

55 tensor(0.8860, device='cuda:0')
1 0.5333019106417952
3 0.8410960416525068
5 0.9464677416713404
56 tensor(2.0523, device='cuda:0')
1 0.5334411228200089
3 0.8454232749523205
5 0.9484629099249228


 57%|█████▋    | 57/100 [9:33:00<7:09:07, 598.79s/it]

56 tensor(0.8646, device='cuda:0')
1 0.5362990994943312
3 0.8404585592023218
5 0.9463693244116533
57 tensor(2.0554, device='cuda:0')
1 0.5334829653625734
3 0.844942472618665
5 0.9488474149103424


 58%|█████▊    | 58/100 [9:42:59<6:59:18, 599.00s/it]

57 tensor(0.9173, device='cuda:0')
1 0.5332312789110875
3 0.8394721678778412
5 0.9452713918326346
58 tensor(2.0403, device='cuda:0')
1 0.5381188873007637
3 0.8477569863240243
5 0.9516443420717365


 59%|█████▉    | 59/100 [9:53:03<6:50:13, 600.33s/it]

58 tensor(0.8837, device='cuda:0')
1 0.54005965376917
3 0.8430789644980811
5 0.9483804628312116
59 tensor(2.1183, device='cuda:0')
1 0.5280808577056018
3 0.8440910563691875
5 0.9497370989268578


 60%|██████    | 60/100 [10:02:50<6:37:35, 596.38s/it]

59 tensor(0.9154, device='cuda:0')
1 0.5313354852706015
3 0.8378432009684305
5 0.9467502882335255
60 tensor(2.1038, device='cuda:0')
1 0.5386175086865995
3 0.8465477859644579
5 0.9522938520073302


 61%|██████    | 61/100 [10:12:43<6:26:53, 595.22s/it]

60 tensor(0.9046, device='cuda:0')
1 0.5396191760211065
3 0.8406092579608181
5 0.9483493411557196
61 tensor(2.1308, device='cuda:0')
1 0.5283832319995154
3 0.8442198085544194
5 0.950637444611991


 62%|██████▏   | 62/100 [10:22:36<6:16:41, 594.77s/it]

61 tensor(0.9289, device='cuda:0')
1 0.5295784166948768
3 0.8382034638502305
5 0.9471638600638884
62 tensor(2.1815, device='cuda:0')
1 0.534794803162816
3 0.8400493043795327
5 0.9472611464296


 63%|██████▎   | 63/100 [10:32:18<6:04:16, 590.71s/it]

62 tensor(0.9092, device='cuda:0')
1 0.5351633514442642
3 0.8366712683223989
5 0.9450996691249072
63 tensor(2.1597, device='cuda:0')
1 0.5326380413440998
3 0.84573524273191
5 0.949751829563204


 64%|██████▍   | 64/100 [10:42:14<5:55:27, 592.42s/it]

63 tensor(0.9481, device='cuda:0')
1 0.533369565826474
3 0.8395625329020169
5 0.9469811496175012
64 tensor(2.1566, device='cuda:0')
1 0.5335790402294159
3 0.8470515575033215
5 0.9522832630206759


 65%|██████▌   | 65/100 [10:52:01<5:44:35, 590.74s/it]

64 tensor(0.9254, device='cuda:0')
1 0.5348943610908031
3 0.8414505412914781
5 0.9493624391919039
65 tensor(2.2353, device='cuda:0')
1 0.5306745739423153
3 0.8406673352746057
5 0.9493801928360733


 66%|██████▌   | 66/100 [11:02:12<5:38:10, 596.78s/it]

65 tensor(0.9596, device='cuda:0')
1 0.5309052807706376
3 0.8358645566319819
5 0.9469555741981508
66 tensor(2.2385, device='cuda:0')
1 0.5287361767959496
3 0.8437308418283601
5 0.9494721145619197


 67%|██████▋   | 67/100 [11:12:02<5:27:11, 594.90s/it]

66 tensor(0.9793, device='cuda:0')
1 0.5314869593910027
3 0.8383023886672372
5 0.9459844228835325
67 tensor(2.2953, device='cuda:0')
1 0.5316967347230782
3 0.8415898690056485
5 0.9489302785403683


 68%|██████▊   | 68/100 [11:22:03<5:18:15, 596.73s/it]

67 tensor(0.9917, device='cuda:0')
1 0.5320516431191749
3 0.8351166423069817
5 0.945458616428121
68 tensor(2.2981, device='cuda:0')
1 0.5272665011636609
3 0.8447326587425752
5 0.9485638121800489


 69%|██████▉   | 69/100 [11:32:11<5:10:01, 600.05s/it]

68 tensor(1.0022, device='cuda:0')
1 0.52670347476714
3 0.8387335420485242
5 0.9450415247762072
69 tensor(2.3015, device='cuda:0')
1 0.5335270674728636
3 0.8448126638918065
5 0.95074298669603


 70%|███████   | 70/100 [11:42:16<5:00:49, 601.64s/it]

69 tensor(0.9962, device='cuda:0')
1 0.5343532110485448
3 0.8386888417420133
5 0.9481185323466577
70 tensor(2.2714, device='cuda:0')
1 0.5338718610798409
3 0.8465981274387221
5 0.9496927413562484


 71%|███████   | 71/100 [11:52:12<4:49:59, 599.98s/it]

70 tensor(0.9707, device='cuda:0')
1 0.5362125186991398
3 0.8402452751458667
5 0.9474055679134439
71 tensor(2.3486, device='cuda:0')
1 0.5271615010935291
3 0.8410014321323339
5 0.9474540631432513


 72%|███████▏  | 72/100 [12:02:06<4:39:03, 598.00s/it]

71 tensor(1.0234, device='cuda:0')
1 0.5290113084780957
3 0.8333760952383943
5 0.9437608006736259
72 tensor(2.3428, device='cuda:0')
1 0.533531572705756
3 0.8449308640717075
5 0.9506918631954162


 73%|███████▎  | 73/100 [12:12:08<4:29:38, 599.20s/it]

72 tensor(0.9946, device='cuda:0')
1 0.5356363333066613
3 0.8383759344056922
5 0.9472540698276329
73 tensor(2.3751, device='cuda:0')
1 0.5315388187734198
3 0.8445304974739998
5 0.949785324186827


 74%|███████▍  | 74/100 [12:22:01<4:18:54, 597.47s/it]

73 tensor(1.0325, device='cuda:0')
1 0.5338010762788435
3 0.8382127029158606
5 0.9466413865990181
74 tensor(2.2682, device='cuda:0')
1 0.5438275893871111
3 0.8484187622462889
5 0.951223390658105


 75%|███████▌  | 75/100 [12:31:53<4:08:14, 595.80s/it]

74 tensor(0.9610, device='cuda:0')
1 0.546606305778391
3 0.8441268180157449
5 0.9489243926048698
75 tensor(2.2850, device='cuda:0')
1 0.5422712503253888
3 0.8473077655728136
5 0.95168948561866


 76%|███████▌  | 76/100 [12:42:03<4:00:00, 600.03s/it]

75 tensor(0.9699, device='cuda:0')
1 0.5445976340900538
3 0.8425377237436831
5 0.9489852995897875
76 tensor(2.3412, device='cuda:0')
1 0.5344925398288233
3 0.8450299510550848
5 0.9506251182586539


 77%|███████▋  | 77/100 [12:52:05<3:50:13, 600.59s/it]

76 tensor(1.0007, device='cuda:0')
1 0.5369880584844737
3 0.8400821074455977
5 0.9474008289256424
77 tensor(2.4211, device='cuda:0')
1 0.5279363457241494
3 0.8418639803033443
5 0.9478316741208521


 78%|███████▊  | 78/100 [13:02:00<3:39:36, 598.93s/it]

77 tensor(1.0471, device='cuda:0')
1 0.5302795847757301
3 0.836386067095959
5 0.9447929595790674
78 tensor(2.3023, device='cuda:0')
1 0.5403743239938174
3 0.84867044470419
5 0.9508184266742137


 79%|███████▉  | 79/100 [13:11:56<3:29:21, 598.15s/it]

78 tensor(0.9777, device='cuda:0')
1 0.5431753797878158
3 0.8444293489219091
5 0.9480752228450452
79 tensor(2.3701, device='cuda:0')
1 0.537283554234919
3 0.8464845574636387
5 0.9498011216823008


 80%|████████  | 80/100 [13:22:03<3:20:15, 600.78s/it]

79 tensor(1.0138, device='cuda:0')
1 0.5374347558489655
3 0.8402678479752852
5 0.9477683571704354
80 tensor(2.3586, device='cuda:0')
1 0.5357566536655737
3 0.8483824654239303
5 0.9504658084070141


 81%|████████  | 81/100 [13:31:49<3:08:49, 596.26s/it]

80 tensor(1.0127, device='cuda:0')
1 0.5369082714915223
3 0.8419627494974354
5 0.9481434477272757
81 tensor(2.3268, device='cuda:0')
1 0.542991283394292
3 0.8504037621423464
5 0.9512304968402568


 82%|████████▏ | 82/100 [13:41:51<2:59:22, 597.92s/it]

81 tensor(0.9759, device='cuda:0')
1 0.545694769968513
3 0.8464556165517931
5 0.9491881350303615
82 tensor(2.3292, device='cuda:0')
1 0.5458090593548683
3 0.8503668991790128
5 0.9523341717449947


 83%|████████▎ | 83/100 [13:51:55<2:49:58, 599.89s/it]

82 tensor(0.9768, device='cuda:0')
1 0.548853697448889
3 0.8481121132055564
5 0.9494718858597846
83 tensor(2.3887, device='cuda:0')
1 0.5384048211728022
3 0.8480862191735841
5 0.9520020249178328


 84%|████████▍ | 84/100 [14:01:57<2:40:07, 600.49s/it]

83 tensor(1.0234, device='cuda:0')
1 0.5417634709823541
3 0.8421734883613083
5 0.9486588397112398
84 tensor(2.4028, device='cuda:0')
1 0.5391031323185239
3 0.8512230771225543
5 0.9501306515920763


 85%|████████▌ | 85/100 [14:11:58<2:30:07, 600.51s/it]

84 tensor(1.0212, device='cuda:0')
1 0.5419223809981568
3 0.8448088406593275
5 0.9478378294907354
85 tensor(2.3895, device='cuda:0')
1 0.5435493250653113
3 0.8490018887117742
5 0.9522727856578636


 86%|████████▌ | 86/100 [14:21:59<2:20:09, 600.71s/it]

85 tensor(1.0040, device='cuda:0')
1 0.5461995015249694
3 0.843236923043388
5 0.949569794006442
86 tensor(2.4424, device='cuda:0')
1 0.5372127281340182
3 0.8453268305509432
5 0.9494791693676995


 87%|████████▋ | 87/100 [14:32:01<2:10:14, 601.09s/it]

86 tensor(1.0570, device='cuda:0')
1 0.5387015039387463
3 0.8380058358482094
5 0.946639989112318
87 tensor(2.4032, device='cuda:0')
1 0.5469940390712593
3 0.8498261813771631
5 0.9506606227519763


 88%|████████▊ | 88/100 [14:42:01<2:00:08, 600.72s/it]

87 tensor(1.0067, device='cuda:0')
1 0.549116261100001
3 0.84540283517169
5 0.9483777121060091
88 tensor(2.4026, device='cuda:0')
1 0.5472351548150144
3 0.8481717163616909
5 0.950945620118884


 89%|████████▉ | 89/100 [14:51:59<1:49:59, 599.95s/it]

88 tensor(1.0036, device='cuda:0')
1 0.5486423990310264
3 0.8439834479729368
5 0.9485551984839483
89 tensor(2.4100, device='cuda:0')
1 0.5482915167685709
3 0.848395720343706
5 0.9507278269874186


 90%|█████████ | 90/100 [15:02:03<1:40:12, 601.27s/it]

89 tensor(1.0013, device='cuda:0')
1 0.5502476367643584
3 0.8458868325454124
5 0.9481948063076462
90 tensor(2.4121, device='cuda:0')
1 0.5474462761520095
3 0.8493744415560656
5 0.950427119465698


 91%|█████████ | 91/100 [15:12:03<1:30:06, 600.74s/it]

90 tensor(1.0010, device='cuda:0')
1 0.5501252471507865
3 0.8459562507474443
5 0.9479568136864148
91 tensor(2.4114, device='cuda:0')
1 0.5495512755325198
3 0.8515263154173973
5 0.9512850551612152


 92%|█████████▏| 92/100 [15:21:59<1:19:54, 599.29s/it]

91 tensor(1.0098, device='cuda:0')
1 0.5505769445596633
3 0.8461637607488719
5 0.9492526497554754
92 tensor(2.4531, device='cuda:0')
1 0.5392924242757287
3 0.8466223462341074
5 0.9492304242418149


 93%|█████████▎| 93/100 [15:31:57<1:09:52, 598.93s/it]

92 tensor(1.0284, device='cuda:0')
1 0.5429112947672134
3 0.8433514645337001
5 0.9469085209252722
93 tensor(2.4357, device='cuda:0')
1 0.5439961280405085
3 0.8484465479926507
5 0.9495975307566891


 94%|█████████▍| 94/100 [15:41:59<59:59, 599.85s/it]  

93 tensor(1.0148, device='cuda:0')
1 0.5455697154152515
3 0.8461003074929715
5 0.9478659108607734
94 tensor(2.4352, device='cuda:0')
1 0.5480412107715795
3 0.8493603041648103
5 0.9503864434410275


 95%|█████████▌| 95/100 [15:52:00<50:01, 600.27s/it]

94 tensor(1.0108, device='cuda:0')
1 0.5497958091087441
3 0.8466504426937441
5 0.9485712930845158
95 tensor(2.4282, device='cuda:0')
1 0.5474631222286959
3 0.8502781171361289
5 0.9507039419646085


 96%|█████████▌| 96/100 [16:02:03<40:04, 601.01s/it]

95 tensor(1.0064, device='cuda:0')
1 0.549626620861687
3 0.8468032666735019
5 0.9488294166549676
96 tensor(2.4362, device='cuda:0')
1 0.5466429302422986
3 0.8521169023566549
5 0.9512917185831767


 97%|█████████▋| 97/100 [16:11:49<29:49, 596.50s/it]

96 tensor(1.0098, device='cuda:0')
1 0.5489671031861177
3 0.848556116376153
5 0.9491090971911539
97 tensor(2.4399, device='cuda:0')
1 0.548401322049544
3 0.8503050998927814
5 0.9512196278612922


 98%|█████████▊| 98/100 [16:21:47<19:53, 596.94s/it]

97 tensor(1.0136, device='cuda:0')
1 0.5507528628403177
3 0.8471134020125959
5 0.948892423852609
98 tensor(2.4407, device='cuda:0')
1 0.5466890203175988
3 0.8510863200680615
5 0.9506865601021182


 99%|█████████▉| 99/100 [16:31:41<09:56, 596.32s/it]

98 tensor(1.0175, device='cuda:0')
1 0.5486819297363498
3 0.8476043406935041
5 0.9485449167326047
99 tensor(2.4421, device='cuda:0')
1 0.5479518428952148
3 0.8502466512196935
5 0.9499247653510022


100%|██████████| 100/100 [16:41:33<00:00, 600.94s/it]

99 tensor(1.0205, device='cuda:0')
1 0.5501137483362526
3 0.8473539130170674
5 0.9477644358913689





In [19]:
vidpkl = pkl.load(open(os.path.join('COOK2_IVD','pkl','vid.pkl'), 'rb'))

In [20]:
frame_encoder_state_dict = torch.load(os.path.join(mode, 'image_encoder_ori_best_eval_loss.pth'))
frame_encoder.load_state_dict(frame_encoder_state_dict)
lang_encoder_state_dict = torch.load(os.path.join(mode, 'lang_encoder_ori_best_eval_loss.pth'))
lang_encoder.load_state_dict(lang_encoder_state_dict)

<All keys matched successfully>

In [22]:
with torch.no_grad():
        frame_encoder.eval()
        lang_encoder.eval()
        ks = [1,3,5]
        recalls = {k:[] for k in ks}
        losses = 0
        for iter_num in range(len(valid_dataset)):
            batch = valid_dataset[iter_num]
            dotoutput, loss = retrieval_forward(batch, frame_encoder, lang_encoder,epoch, compute_loss = True)
            losses += loss
            label = list(batch['label'].numpy())
            for k in ks:
                        pred = np.argsort(-1*dotoutput.detach().cpu().numpy(), axis = -1)[:,:k].squeeze()
                        recallatk = 0
                        examples = 0
                        for i, gt in enumerate(label):
                            if k > 1:
                                if gt == -1:
                                    continue
                                if gt in pred[i]:
                                    recallatk +=1
                                examples += 1
                            else:
                                if gt == -1:
                                    continue
                                if gt == pred[i]:
                                    recallatk +=1
                                examples += 1
                        recallatk = recallatk/examples
                        recalls[k].append(recallatk)
        losses = losses/len(valid_dataset)
        writer.add_scalar('Loss/valid',losses.cpu(), epoch)
        print(epoch, losses)
        for k in ks:
            print( k, sum(recalls[k])/len(recalls[k]))
        for iter_num in range(len(test_dataset)):
            batch = test_dataset[iter_num]
            dotoutput, loss = retrieval_forward(batch, frame_encoder, lang_encoder,epoch, compute_loss = True)
            losses += loss
            label = list(batch['label'].numpy())
            for k in ks:
                        pred = np.argsort(-1*dotoutput.detach().cpu().numpy(), axis = -1)[:,:k].squeeze()
                        recallatk = 0
                        examples = 0
                        for i, gt in enumerate(label):
                            if k > 1:
                                if gt == -1:
                                    continue
                                if gt in pred[i]:
                                    recallatk +=1
                                examples += 1
                            else:
                                if gt == -1:
                                    continue
                                if gt == pred[i]:
                                    recallatk +=1
                                examples += 1
                        recallatk = recallatk/examples
                        recalls[k].append(recallatk)
        losses = losses/len(valid_dataset)
        print(epoch, losses)
        for k in ks:
                print( k, sum(recalls[k])/len(recalls[k]))
            

99 tensor(2.7031, device='cuda:0')
1 0.5551631858077444
3 0.8598434427384077
5 0.9522972897459667
99 tensor(1.1041, device='cuda:0')
1 0.5587508731452381
3 0.8555428180545263
5 0.9489275610512035


In [20]:
retrieval_pred = {}
with torch.no_grad():
        frame_encoder.eval()
        lang_encoder.eval()
        ks = [1,3,5]
        recalls = {k:[] for k in ks}
        losses = 0
        for iter_num in tqdm(range(len(test_dataset))):
            vid_pred = {}
            batch = test_dataset[iter_num]
            dotoutput, loss = retrieval_forward(batch, frame_encoder, lang_encoder, compute_loss = True)
            losses += loss
            label = list(batch['label'].numpy())
            for k in ks:
                        pred = np.argsort(-1*dotoutput.detach().cpu().numpy(), axis = -1)[:,:k].squeeze()
                        recallatk = 0
                        examples = 0
                        for i, gt in enumerate(label):
                            if k > 1:
                                if gt == -1:
                                    continue
                                if gt in pred[i]:
                                    recallatk +=1
                                examples += 1
                            else:
                                vid_pred[2*i+1] = pred[i]
                                if gt == -1:
                                    continue
                                if gt == pred[i]:
                                    recallatk +=1
                                examples += 1
                        recallatk = recallatk/examples
                        recalls[k].append(recallatk)
            retrieval_pred[batch['vid']] = vid_pred
        losses = losses/len(valid_dataset)
        print(epoch, losses)
        for k in ks:
                print( k, sum(recalls[k])/len(recalls[k]))
with torch.no_grad():
        frame_encoder.eval()
        lang_encoder.eval()
        ks = [1,3,5]
        recalls = {k:[] for k in ks}
        losses = 0
        for iter_num in tqdm(range(len(valid_dataset))):
            vid_pred = {}
            batch = valid_dataset[iter_num]
            dotoutput, loss = retrieval_forward(batch, frame_encoder, lang_encoder, compute_loss = True)
            losses += loss
            label = list(batch['label'].numpy())
            for k in ks:
                        pred = np.argsort(-1*dotoutput.detach().cpu().numpy(), axis = -1)[:,:k].squeeze()
                        recallatk = 0
                        examples = 0
                        for i, gt in enumerate(label):
                            if k > 1:
                                if gt == -1:
                                    continue
                                if gt in pred[i]:
                                    recallatk +=1
                                examples += 1
                            else:
                                vid_pred[2*i+1] = pred[i]
                                if gt == -1:
                                    continue
                                if gt == pred[i]:
                                    recallatk +=1
                                examples += 1
                        recallatk = recallatk/examples
                        recalls[k].append(recallatk)
            retrieval_pred[batch['vid']] = vid_pred
        losses = losses/len(valid_dataset)
        print(epoch, losses)
        for k in ks:
                print( k, sum(recalls[k])/len(recalls[k]))
            

100%|██████████| 135/135 [00:37<00:00,  3.64it/s]
  0%|          | 0/312 [00:00<?, ?it/s]

100 tensor(1.0021, device='cuda:0')
1 0.552947379644617
3 0.8337703010705011
5 0.9445555350399873


100%|██████████| 312/312 [01:19<00:00,  3.90it/s]

100 tensor(2.4114, device='cuda:0')
1 0.5495512755325198
3 0.8515263154173973
5 0.9512850551612152





In [24]:
iou_threshold=0.05

In [25]:
retrieval_pred = {}
retrieval_label = {}
with torch.no_grad():
        frame_encoder.eval()
        lang_encoder.eval()
        ks = [1,3,5]
        recalls = {k:[] for k in ks}
        losses = 0
        for iter_num in tqdm(range(len(test_dataset))):
            vid_pred = {}
            vid_label = {}
            batch = test_dataset[iter_num]
            vid = batch['vid']
            vid_iou = ious[vid]
            dotoutput, loss = retrieval_forward(batch, frame_encoder, lang_encoder, compute_loss = True)
            losses += loss
            label = list(batch['label'].numpy())
            vid_labels = list(set(label))
            vid_labels.remove(-1)
            hard_segs = []
            for i in vid_labels:
                if i==0:
                    if vid_iou[(0,1)]<iou_threshold:
                        hard_segs.append(i)
                elif i == max(vid_labels):
                    if vid_iou[(i-1,i)]<iou_threshold:
                        hard_segs.append(i)
                else:
                    if vid_iou[(i-1,i)]<iou_threshold and vid_iou[(i,i+1)]<iou_threshold:
                        hard_segs.append(i)
            for k in ks:
                        pred = np.argsort(-1*dotoutput.detach().cpu().numpy(), axis = -1)[:,:k].squeeze()
                        recallatk = 0
                        examples = 0
                        for i, gt in enumerate(label):
                            if gt not in hard_segs:
                                continue
                            if k > 1:
                                if gt == -1:
                                    continue
                                if gt in pred[i]:
                                    recallatk +=1
                                examples += 1
                            else:
                                vid_pred[2*i+1] = pred[i]
                                vid_label[2*i+1] = gt
                                if gt == -1:
                                    continue
                                if gt == pred[i]:
                                    recallatk +=1
                                examples += 1
                        if examples<1:
                            continue
                        recallatk = recallatk/examples
                        recalls[k].append(recallatk)
            retrieval_pred[batch['vid']] = vid_pred
            retrieval_label[batch['vid']] = vid_label
        losses = losses/len(valid_dataset)
        print(epoch, losses)
        for k in ks:
                print( k, sum(recalls[k])/len(recalls[k]))

100%|██████████| 135/135 [00:38<00:00,  3.53it/s]

100 tensor(1.0021, device='cuda:0')
1 0.5516363098202698
3 0.8204740007479437
5 0.9387999926629





In [21]:
pkl.dump(retrieval_pred, open('procedure_contrast_retrieval_result.json', 'wb'))

99 tensor(2.5399, device='cuda:0')
1 0.5500848235384761
3 0.8597046990109899
5 0.9546420315049812
99 tensor(1.0750, device='cuda:0')
1 0.5493910851437317
3 0.8543261536011475
5 0.9503941514540208
class sequence_frame_encoder(torch.nn.Module):
    def __init__(self, config, sequence_encoder = 'LSTM'):
        super(sequence_frame_encoder, self).__init__()
        self.frame_encoder = LXRTFeatureExtraction(config)
        state_dict_path = os.path.join('lxmert', 'snap', 'pretrained', 'model_LXRT.pth') 
        state_dict = torch.load(state_dict_path)
        new_state_dict = OrderedDict()
        for key, value in state_dict.items():
            splittedkey = key.split('.')
            if 'bert' in splittedkey:
                newkey  = '.'.join(splittedkey[splittedkey.index('bert')+1:])
            else:
                newkey  = '.'.join(splittedkey[splittedkey.index('module')+1:])
            new_state_dict[newkey] = value
        self.frame_encoder.load_state_dict(new_state_dict, strict=False)
        del state_dict
        del new_state_dict
        if sequence_encoder == 'LSTM':
            self.sequence_encoder = torch.nn.LSTM(768, hidden_size = 768//2, batch_first = True, bidirectional = True)
            self.sequence_fc = torch.nn.Linear(768, 768)
            self.contrast_fc = torch.nn.Linear(768,768)
            
        #transformer later
        #else:
        #    self.sequence_encoder = torch.nn.TransformerEncoderLayer(d_model = 768,nhead = 12,num_encoder_layers = 1,
        #                                                             dim_feedforward=3072, activation == "gelu", batch_first = True )
    def forward(self,visn_feats, trans_input_ids,trans_token_type_ids,trans_attention_mask ):
        frame_feats = self.frame_encoder(trans_input_ids, 
                                        token_type_ids = trans_token_type_ids,
                                        attention_mask = trans_attention_mask,
                                        visual_feats = visn_feats
                                       )
        frame_feats = frame_feats[1]        
        #frame_contrast = self.contrast_fc(frame_feats)
        #frame_contrast = frame_contrast.squeeze()
        frame_feats = frame_feats.unsqueeze(0)
        frame_feats,(_,_) = self.sequence_encoder(frame_feats)
        frame_contrast = self.contrast_fc(frame_feats)
        frame_contrast = frame_contrast.squeeze()
        frame_feats = self.sequence_fc(frame_feats)
        frame_feats = frame_feats.squeeze()
        return frame_feats, frame_contrast
99 tensor(2.3990, device='cuda:0')
1 0.5509897188919297
3 0.8555678469627827
5 0.9494686560239055
99 tensor(1.0209, device='cuda:0')
1 0.5510184906637344
3 0.8487851879782635
5 0.9473827653403941
class sequence_frame_encoder(torch.nn.Module):
    def __init__(self, config, sequence_encoder = 'LSTM'):
        super(sequence_frame_encoder, self).__init__()
        self.frame_encoder = LXRTFeatureExtraction(config)
        state_dict_path = os.path.join('lxmert', 'snap', 'pretrained', 'model_LXRT.pth') 
        state_dict = torch.load(state_dict_path)
        new_state_dict = OrderedDict()
        for key, value in state_dict.items():
            splittedkey = key.split('.')
            if 'bert' in splittedkey:
                newkey  = '.'.join(splittedkey[splittedkey.index('bert')+1:])
            else:
                newkey  = '.'.join(splittedkey[splittedkey.index('module')+1:])
            new_state_dict[newkey] = value
        self.frame_encoder.load_state_dict(new_state_dict, strict=False)
        del state_dict
        del new_state_dict
        if sequence_encoder == 'LSTM':
            self.sequence_encoder = torch.nn.LSTM(768, hidden_size = 768//2, batch_first = True, bidirectional = True)
            self.sequence_fc = torch.nn.Linear(768, 768)
            self.contrast_fc = torch.nn.Linear(768,768)
            
        #transformer later
        #else:
        #    self.sequence_encoder = torch.nn.TransformerEncoderLayer(d_model = 768,nhead = 12,num_encoder_layers = 1,
        #                                                             dim_feedforward=3072, activation == "gelu", batch_first = True )
    def forward(self,visn_feats, trans_input_ids,trans_token_type_ids,trans_attention_mask ):
        frame_feats = self.frame_encoder(trans_input_ids, 
                                        token_type_ids = trans_token_type_ids,
                                        attention_mask = trans_attention_mask,
                                        visual_feats = visn_feats
                                       )
        frame_feats = frame_feats[1]        
        frame_contrast = self.contrast_fc(frame_feats)
        frame_contrast = frame_contrast.squeeze()
        frame_feats = frame_feats.unsqueeze(0)
        frame_feats,(_,_) = self.sequence_encoder(frame_feats)
        #frame_contrast = self.contrast_fc(frame_feats)
        #frame_contrast = frame_contrast.squeeze()
        frame_feats = self.sequence_fc(frame_feats)
        frame_feats = frame_feats.squeeze()
        return frame_feats, frame_contrast

In [24]:
with open('yc2_recipes.json', 'r') as fp:
    recipes = json.load(fp)

In [25]:
with open('vid2frame_indices.pkl', 'rb') as fp:
    frame_indices = pkl.load(fp)

In [24]:

#encode = encode.cuda()

In [26]:
torch.cuda.empty_cache()

In [41]:

recipe_retrieval_result = {}
test_retrieval_result_all = {}
with torch.no_grad():
    with autocast(): 
        for iter_num in range(len(segment_dataset)):
            batch = segment_dataset[iter_num]
            vid, visn_feats, encode, label = batch['vid'], batch["visn_feats"],batch['encode'],batch['label']
            frame_feats, box_feats = visn_feats
            foodname = batch['query']
            vid_frame_indices = frame_indices[vid]
            trans_input_ids,trans_token_type_ids,trans_attention_mask  = batch['trans_input_ids'],batch['trans_token_type_ids'],batch['trans_attention_mask']
            label = label.cuda()
            visn_feats = frame_feats.cuda(), box_feats.cuda()
            trans_input_ids = trans_input_ids.cuda()
            trans_token_type_ids = trans_token_type_ids.cuda()
            trans_attention_mask = trans_attention_mask.cuda()
            vid_recipe = recipes[foodname]
            #recipe_logit[vid] = {}
            flatten_recipe = []
            for recipe in vid_recipe:
                for recipe_doc in recipe['split_ins']:
                    if recipe_doc!= '':
                        flatten_recipe.append(recipe_doc)
            encode = tokenizer.batch_encode_plus(flatten_recipe, return_tensors = 'pt',padding = True)
            frame_feats,_ = frame_encoder(visn_feats, trans_input_ids, 
                                                    trans_token_type_ids,trans_attention_mask
                                                   )
            pooled_output = frame_feats
            output = lang_encoder(input_ids = encode['input_ids'].cuda(), token_type_ids = encode['token_type_ids'].cuda(),
                                  attention_mask = encode['attention_mask'].cuda())
            sequence_output, lang_pooled_output = output[0], output[1]
            dotoutput = torch.matmul(pooled_output, lang_pooled_output.transpose(1,0))
            retrieval_result = torch.topk(dotoutput,k=3, dim = -1)[1].tolist()
            vid_retrieval_result = {vid_frame_indices[i]:[flatten_recipe[j] for j in retrieval_result[i]] for i in range(len(vid_frame_indices))}
            recipe_retrieval_result[vid] = vid_retrieval_result
        for iter_num in range(len(valid_dataset)):
            batch = valid_dataset[iter_num]
            vid, visn_feats, encode, label = batch['vid'], batch["visn_feats"],batch['encode'],batch['label']
            frame_feats, box_feats = visn_feats
            foodname = batch['query']
            vid_frame_indices = frame_indices[vid]
            trans_input_ids,trans_token_type_ids,trans_attention_mask  = batch['trans_input_ids'],batch['trans_token_type_ids'],batch['trans_attention_mask']
            label = label.cuda()
            visn_feats = frame_feats.cuda(), box_feats.cuda()
            trans_input_ids = trans_input_ids.cuda()
            trans_token_type_ids = trans_token_type_ids.cuda()
            trans_attention_mask = trans_attention_mask.cuda()
            vid_recipe = recipes[foodname]
            #recipe_logit[vid] = {}
            flatten_recipe = []
            for recipe in vid_recipe:
                for recipe_doc in recipe['split_ins']:
                    if recipe_doc!= '':
                        flatten_recipe.append(recipe_doc)
            encode = tokenizer.batch_encode_plus(flatten_recipe, return_tensors = 'pt',padding = True)
            frame_feats,_ = frame_encoder(visn_feats, trans_input_ids, 
                                                    trans_token_type_ids,trans_attention_mask
                                                   )
            pooled_output = frame_feats
            output = lang_encoder(input_ids = encode['input_ids'].cuda(), token_type_ids = encode['token_type_ids'].cuda(),
                                  attention_mask = encode['attention_mask'].cuda())
            sequence_output, lang_pooled_output = output[0], output[1]
            dotoutput = torch.matmul(pooled_output, lang_pooled_output.transpose(1,0))
            retrieval_result = torch.topk(dotoutput,k=3, dim = -1)[1].tolist()
            vid_retrieval_result = {vid_frame_indices[i]:[flatten_recipe[j] for j in retrieval_result[i]] for i in range(len(vid_frame_indices))}
            recipe_retrieval_result[vid] = vid_retrieval_result
        for iter_num in range(len(test_dataset)):
            batch = test_dataset[iter_num]
            vid, visn_feats, encode, label = batch['vid'], batch["visn_feats"],batch['encode'],batch['label']
            frame_feats, box_feats = visn_feats
            foodname = batch['query']
            vid_frame_indices = frame_indices[vid]
            trans_input_ids,trans_token_type_ids,trans_attention_mask  = batch['trans_input_ids'],batch['trans_token_type_ids'],batch['trans_attention_mask']
            label = label.cuda()
            visn_feats = frame_feats.cuda(), box_feats.cuda()
            trans_input_ids = trans_input_ids.cuda()
            trans_token_type_ids = trans_token_type_ids.cuda()
            trans_attention_mask = trans_attention_mask.cuda()
            vid_recipe = recipes[foodname]
            #recipe_logit[vid] = {}
            flatten_recipe = []
            for recipe in vid_recipe:
                for recipe_doc in recipe['split_ins']:
                    if recipe_doc!= '':
                        flatten_recipe.append(recipe_doc)
            encode = tokenizer.batch_encode_plus(flatten_recipe, return_tensors = 'pt',padding = True)
            frame_feats,_ = frame_encoder(visn_feats, trans_input_ids, 
                                                    trans_token_type_ids,trans_attention_mask
                                                   )
            pooled_output = frame_feats
            output = lang_encoder(input_ids = encode['input_ids'].cuda(), token_type_ids = encode['token_type_ids'].cuda(),
                                  attention_mask = encode['attention_mask'].cuda())
            sequence_output, lang_pooled_output = output[0], output[1]
            dotoutput = torch.matmul(pooled_output, lang_pooled_output.transpose(1,0))
            retrieval_result = torch.topk(dotoutput,k=3, dim = -1)[1].tolist()
            vid_retrieval_result = {vid_frame_indices[i]:[flatten_recipe[j] for j in retrieval_result[i]] for i in range(len(vid_frame_indices))}
            recipe_retrieval_result[vid] = vid_retrieval_result
            retrieval_result_all = np.argsort(-1*dotoutput.detach().cpu().numpy())
            vid_retrieval_result_all = {vid_frame_indices[i]:[flatten_recipe[j] for j in retrieval_result_all[i]] for i in range(len(vid_frame_indices))}
            test_retrieval_result_all[vid] = vid_retrieval_result_all

In [42]:
json.dump(test_retrieval_result_all, open('lxmert_sequence_procedure_contrast_recipe_test_retrieval_result_all.json', 'w'))

In [35]:
retrieval_result

[[178, 59, 79],
 [2, 79, 82],
 [2, 79, 82],
 [2, 60, 79],
 [2, 60, 0],
 [2, 60, 82],
 [60, 2, 27],
 [0, 27, 177],
 [0, 174, 60],
 [0, 174, 60],
 [60, 0, 174],
 [60, 0, 179],
 [12, 60, 179],
 [179, 12, 60],
 [179, 12, 108],
 [108, 12, 166],
 [179, 108, 12],
 [67, 108, 46],
 [67, 46, 81],
 [67, 104, 7],
 [104, 179, 68],
 [179, 104, 68],
 [0, 179, 174],
 [174, 0, 12],
 [179, 104, 174],
 [179, 104, 81],
 [12, 67, 60],
 [12, 67, 60],
 [12, 38, 179],
 [179, 38, 12],
 [60, 38, 12],
 [60, 38, 179],
 [60, 179, 165],
 [179, 38, 104],
 [38, 12, 81],
 [67, 12, 38],
 [81, 46, 116],
 [81, 46, 42],
 [116, 46, 131],
 [46, 116, 131],
 [46, 131, 42],
 [46, 81, 131],
 [81, 46, 144],
 [81, 144, 6],
 [144, 97, 87],
 [86, 169, 160],
 [9, 47, 57],
 [9, 47, 77],
 [47, 9, 87],
 [87, 29, 47],
 [29, 87, 53],
 [29, 99, 53],
 [18, 120, 53],
 [109, 184, 48],
 [48, 184, 109],
 [184, 48, 120],
 [184, 48, 109],
 [184, 48, 120],
 [184, 120, 109],
 [30, 135, 184],
 [135, 66, 30],
 [135, 66, 38],
 [135, 38, 66],
 [135, 6

In [40]:
np.argsort(-1*dotoutput.detach().cpu().numpy())

array([[178,  59,  79, ..., 107, 109, 172],
       [  2,  79,  82, ...,  85,  53, 109],
       [  2,  79, 177, ...,  85,  53, 109],
       ...,
       [135,  66,  38, ..., 130, 156, 164],
       [135,  38,  66, ..., 130, 156, 164],
       [135,  66,  38, ..., 164,  85, 156]])

In [None]:
dotoutput

In [39]:
dotoutput.detach().cpu().numpy()

array([[  8.7   ,   3.46  ,  10.68  , ...,  -0.415 ,   0.2246,  -4.71  ],
       [ 13.164 ,   5.08  ,  16.03  , ...,  -0.7393,   0.4844,  -7.074 ],
       [ 17.67  ,   4.477 ,  21.44  , ...,  -1.004 ,  -2.32  , -11.25  ],
       ...,
       [  0.4785,   3.332 ,  -0.2866, ...,   0.0984,   1.104 ,   5.062 ],
       [  3.555 ,   3.045 ,   3.95  , ...,   0.1044,   0.124 ,   2.402 ],
       [  4.914 ,   0.5605,   6.5   , ...,   0.7183,  -1.6045,   1.834 ]],
      dtype=float16)

In [28]:
json.dump(recipe_retrieval_result, open('lxmert_sequence_procedure_contrast_recipe_retrieval_result.json', 'w'))

In [29]:
pkl.dump(recipe_retrieval_result, open('lxmert_sequence_procedure_contrast_recipe_retrieval_result.pkl', 'wb'))

In [52]:
encode.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])

Help on method forward in module transformers.models.bert.modeling_bert:

forward(input_ids: Union[torch.Tensor, NoneType] = None, attention_mask: Union[torch.Tensor, NoneType] = None, token_type_ids: Union[torch.Tensor, NoneType] = None, position_ids: Union[torch.Tensor, NoneType] = None, head_mask: Union[torch.Tensor, NoneType] = None, inputs_embeds: Union[torch.Tensor, NoneType] = None, encoder_hidden_states: Union[torch.Tensor, NoneType] = None, encoder_attention_mask: Union[torch.Tensor, NoneType] = None, past_key_values: Union[List[torch.FloatTensor], NoneType] = None, use_cache: Union[bool, NoneType] = None, output_attentions: Union[bool, NoneType] = None, output_hidden_states: Union[bool, NoneType] = None, return_dict: Union[bool, NoneType] = None) -> Union[Tuple[torch.Tensor], transformers.modeling_outputs.BaseModelOutputWithPoolingAndCrossAttentions] method of transformers.models.bert.modeling_bert.BertModel instance
    The [`BertModel`] forward method, overrides the `__call

In [62]:
len(dotoutput), len(vid_frame_indices)

(64, 64)

In [22]:
vid = batch['vid']

In [28]:
vidpkl[vid]['segments']

[(43.0, 55.0, 'add the cabbage and water to a pan'),
 (62.0, 76.0, 'add cream and butter to a pot'),
 (112.0, 120.0, 'add the cabbage to the pot of potatos'),
 (120.0, 127.0, 'add the green onions to the pot'),
 (127.0, 143.0, 'add the bacon and pepper to the pot'),
 (143.0, 151.0, 'mash the ingredients in the pot')]