In [1]:
import logging
import json
import torch
from types import SimpleNamespace
from vilbert.vilbert import VILBertActionGrounding, BertConfig
from pytorch_transformers.tokenization_bert import BertTokenizer
from pytorch_transformers.optimization import AdamW, WarmupLinearSchedule
import torch.distributed as dist
from VLN_config import config as args






In [2]:
import sys
import os
import torch
import yaml

import numpy as np
import matplotlib.pyplot as plt
import PIL

from PIL import Image
import cv2
import argparse
import glob
import pdb

import torchvision.models as models
import torchvision.transforms as transforms

from faster_rcnn import feature_extractor_new as f_extractor
from faster_rcnn.feature_extractor_new import featureExtractor
#%matplotlib inline  

In [3]:
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
print(args)

namespace(adam_epsilon=1e-08, baseline=False, bert_model='bert-base-uncased', best_features=10, clean_train_sets=True, config_file='config/bert_base_6layer_6conect.json', distributed=False, do_lower_case=True, dynamic_attention=False, from_pretrained='save/multitask_model/multi_task_model.bin', gradient_accumulation_steps=1, img_weight=1, in_memory=False, learning_rate=0.0001, local_rank=-1, max_temporal_memory_buffer=3, mean_layer=True, num_train_epochs=10.0, num_workers=0, objective=1, predict_feature=False, save_name='', seed=42, split='mteval', start_epoch=0, task_specific_tokens=True, tasks='1', threshold_similarity=0.7, track_temporal_features=True, train_batch_size=1, visual_target=0, warmup_proportion=0.1, without_coattention=False)


In [4]:
config = BertConfig.from_json_file(args.config_file)
bert_weight_name = json.load(
    open("config/" + args.bert_model + "_weight_name.json", "r")
)

tokenizer = BertTokenizer.from_pretrained(
    args.bert_model, do_lower_case=args.do_lower_case
)

config.track_temporal_features = args.track_temporal_features
config.mean_layer = args.mean_layer
config.max_temporal_memory_buffer = args.max_temporal_memory_buffer


model = VILBertActionGrounding.from_pretrained(
    args.from_pretrained, config=config, default_gpu=True
)

07/20/2020 18:41:01 - INFO - pytorch_transformers.tokenization_utils -   loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/mikel/.cache/torch/pytorch_transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
07/20/2020 18:41:01 - INFO - vilbert.utils -   loading weights file save/multitask_model/multi_task_model.bin
07/20/2020 18:41:07 - INFO - vilbert.utils -   Weights of VILBertActionGrounding not initialized from pretrained model: ['positional_enc.weight', 'positional_enc.bias', 'img_emb_mean.weight', 'img_emb_mean.bias', 'action_cls.predictions.bias', 'action_cls.predictions.transform.dense.weight', 'action_cls.predictions.transform.dense.bias', 'action_cls.predictions.transform.LayerNorm.weight', 'action_cls.predictions.transform.LayerNorm.bias', 'action_cls.predictions.decoder.weight', 'action_cls.bi_seq_relationship.weight', 'action_

In [5]:
from copy import deepcopy
from faster_rcnn.feature_extractor_new import featureExtractor
import numpy as np
import json
import torch

class DataLoader():
    """This class loads and preprocess the data set ALFRED for VilBERT.
        Input: data.json and model for the feature extractor (FastRCnn)
        Ouput: List of length = number of different instructions on our dataset. Each element of the list
                is a dictionary containing a sequence of images + instruction (text), under the keys
                [imgs] & [desc] respectively. [imgs] its a dictionary with keys [features], 
                [pos_enco] and [infos], gathering the already masked and tokenized features extracted from 
                the fasRCnn, a positional encoder of the bounding boxes and some additional information. 
                [desc] is a dictionary with keys [text_id],[modified_token] and [masked_lm_token], gathering
                the tokenized instruction, the modification after making it and the masked_lm_token for VilBERT, 
                respectively. """
    
    def __init__(self, json_path, model):
        self.data = json.load(open(json_path, "r"))
        self.tokenized_data = deepcopy(self.data)
        self.model = model
        
    def get_processed_data(self):
        return self.tokenized_data
        
    def extract_features(self):
        print("  Extracting features...")
        for i, one_action_data in enumerate(self.tokenized_data):
            print("    Action indx", i)
            f_extractor = featureExtractor(one_action_data["imgs"], self.model)
            features, positional_encoding, infos = f_extractor.extract_features() 
            self.tokenized_data[i]["imgs"] = {"feat":features,"pos_enco": positional_encoding, "spatial":[], "image_mask":[],
                                              "infos":infos, "co_attention_mask":[]}
        return self.tokenized_data
    
    def text_tokenizer(self):
        """We add the special tokens to the text (actions)"""
        print("Tokenizing text...")
        for i, one_action_data in enumerate(self.tokenized_data):
            text = '[CLS]' + one_action_data["desc"][0] + '[SEP]'
            self.length_text = len(text)
            tokenized_text = tokenizer.tokenize(text)
            tokenized_text = tokenizer.convert_tokens_to_ids(tokenized_text)
            segment_ids = [0] * len(tokenized_text)
            input_mask = [1] * len(tokenized_text)
            self.max_length = 37
            if len(tokenized_text) < self.max_length:
                # Note here we pad in front of the sentence
                padding = [0] * (self.max_length - len(tokenized_text))
                tokenized_text = tokenized_text + padding
                input_mask += padding
                segment_ids += padding
                
            self.tokenized_data[i]["desc"] = {"tokenized_text":tokenized_text, 
                                              "input_mask":input_mask,"segment_ids": segment_ids,
                                              "modified_token":[], "masked_lm_token":[]}
            
    def img_tokenizer(self):
        """Add the spacial token IMG before each set of features extracted from an image"""
        print("Tokenizing images...")
        self.extract_features()
        for i, one_action_data in enumerate(self.tokenized_data):
            for j in range(len(one_action_data["imgs"]["feat"])):
                
                mean_pooled_feat = torch.mean(one_action_data["imgs"]["feat"][j], 0).unsqueeze(0) #Equivalent to IMG special token
                one_action_data["imgs"]["feat"][j] = torch.cat((mean_pooled_feat, one_action_data["imgs"]["feat"][j]), dim=0)
                
                one_action_data["imgs"]["infos"][j]["objects"] = torch.cat((torch.tensor([-1]), one_action_data["imgs"]["infos"][j]["objects"] ), dim=0)
                one_action_data["imgs"]["pos_enco"][j] = torch.cat((torch.zeros((1, one_action_data["imgs"]["pos_enco"][j].shape[1])), one_action_data["imgs"]["pos_enco"][j]), dim=0)
                
                boxes = one_action_data["imgs"]["infos"][j]["bbox"]
                image_w = one_action_data["imgs"]["infos"][j]["image_width"]
                image_h = one_action_data["imgs"]["infos"][j]["image_height"]
                
                image_location = torch.zeros((boxes.shape[0], 5))
                image_location[:,:4] = torch.from_numpy(boxes)
                image_location[:,4] = (image_location[:,3] - image_location[:,1]) * (image_location[:,2] - image_location[:,0]) / (float(image_w) * float(image_h))
                image_location[:,0] = image_location[:,0] / float(image_w)
                image_location[:,1] = image_location[:,1] / float(image_h)
                image_location[:,2] = image_location[:,2] / float(image_w)
                image_location[:,3] = image_location[:,3] / float(image_h)
                full_image_5D_encoding = torch.FloatTensor([[0, 0, 1, 1, 1]])
                one_action_data["imgs"]["image_mask"].append(torch.tensor([1] * (int(one_action_data["imgs"]["infos"][j]["num_boxes"]+1))))
                
                spatial_img_location = torch.cat((full_image_5D_encoding, image_location), dim=0)
                one_action_data["imgs"]["spatial"].append(spatial_img_location) 
            
    def mask_text(self):
        
        """We will generate 2 new outputs from the Tokenized text: the modified token and the mask_lm_token
            which will be stored in the dictionary in self.tokenized_data["desc"] 
            under the keys (text_id, modified_token, mask_lm_token)"""
        
        for i, one_action_data in enumerate(self.tokenized_data):
            
            modified_token = deepcopy(one_action_data["desc"]["tokenized_text"])
            masked_lm_labels = -1*np.ones(len(modified_token)).astype(int)
            type_modification = np.random.choice(np.array(["MASK", "random", "unaltered"]), p=[0.8, 0.1,0.1])
            length_token = self.length_text - 2
            num_masked_tokens = int(0.15*length_token)

            if num_masked_tokens<1:
                num_masked_tokens=1

            if type_modification=="MASK":
                for i in range(num_masked_tokens):
                    indx = np.random.randint(length_token)
                    modified_token[indx+1] = tokenizer.encode('MASK')[0]
                    masked_lm_labels[indx+1] = tokenizer.encode('MASK')[0]

            elif type_modification=="random":
                for i in range(num_masked_tokens):
                    indx = np.random.randint(length_token)
                    modified_token[indx+1] = np.random.randint(30522) #size of vacbulary
                    masked_lm_labels[indx+1] = np.random.randint(1024)
            else:
                for i in range(num_masked_tokens):
                    indx = np.random.randint(length_token)
                    modified_token[indx+1] = one_action_data["desc"]["tokenized_text"][indx+1] #size of vacbulary
                    masked_lm_labels[indx+1] = one_action_data["desc"]["tokenized_text"][indx+1]
            
            one_action_data["desc"]["modified_token"] = modified_token
            one_action_data["desc"]["masked_lm_token"] = masked_lm_labels
            
    
    def mask_img(self):
        
        """"We will mask the 15% of the patches features with 90% probability to zeroed features 
        and 10% unalteres"""
        
        for i, one_action_data in enumerate(self.tokenized_data):
            for f in range(len(one_action_data["imgs"]["feat"])):
                length_feat = one_action_data["imgs"]["feat"][f].shape[0]
                type_modification = np.random.choice(np.array(["zeros", "unaltered"]), p=[0.9, 0.1])
                num_masked_tokens = int(0.15*length_feat)

                if num_masked_tokens<1:
                    num_masked_tokens=1

                if type_modification=="zeros":
                    for _ in range(num_masked_tokens):
                        i = np.random.randint(length_feat)
                        one_action_data["imgs"]["feat"][f][i] = torch.zeros((1, one_action_data["imgs"]["feat"][f].shape[1]))

    
    def get_data_masked_train(self):
        """This function executes tranforms the text into tockens, the extractor of features
        the masking in the text and image"""
        self.text_tokenizer()
        self.img_tokenizer()
        self.mask_text()
        self.mask_img()
        data = self.get_processed_data()
        for instruction, data_point in enumerate(data):
            #To avoid issues with last action STOP
            if instruction == len(data)-1:
                break
            for type_data_key, type_data_value in data_point.items():
                for key, value in type_data_value.items():
                    if type_data_key == "imgs":
                        if key == "infos":
                            continue
                        if key == "co_attention_mask":
                            type_data_value[key].append(torch.zeros((data[0]["imgs"]["feat"].shape[0], self.max_length)))
                        if key == "image_mask":
                            cat =  value[0]
                            for i in range(1,len(value)):
                                cat = torch.cat((cat, value[i]), dim=0)
                            type_data_value[key] = cat
                        else:
                            type_data_value[key] = torch.cat(value, dim=0)
                    else: 
                        type_data_value[key] = torch.tensor(value)
                        
        return (data[0]["imgs"]["feat"], data[0]["imgs"]["pos_enco"], data[0]["imgs"]["spatial"],data[0]["imgs"]["image_mask"],
                data[0]["desc"]["tokenized_text"], data[0]["desc"]["modified_token"],data[0]["desc"]["masked_lm_token"],
                 data[0]["desc"]["input_mask"], data[0]["desc"]["segment_ids"], data[0]["imgs"]["co_attention_mask"], data[0]["imgs"]["infos"])
        
        
            

In [6]:
frcnn_model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
DataLoader = DataLoader("short_json_data.json", frcnn_model)
data = DataLoader.get_data_masked_train()

Tokenizing text...
Tokenizing images...
  Extracting features...
    Action indx 0
    Action indx 1
    Action indx 2
    Action indx 3


In [7]:
features_masked, pos_enc, spatial, image_mask, tokenized_text, masked_text, masked_lm_token, input_mask, segment_ids, co_attention_mask, infos = data

In [8]:
print(features_masked.shape)
print(tokenized_text)
print(masked_text)


torch.Size([33, 6144])
tensor([  101,  2175,  2000,  1996, 13065,   102,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0])
tensor([  101,  2175,  2000,  1996, 13065,   102,     0,     0,     0,     0,
            0,  7308,  7308,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0])


In [9]:
image_target = []
for i in range(len(infos)):
    image_target.append(infos[i]["objects"])

image_target = torch.cat(image_target, dim=0)
print(image_target.shape)

torch.Size([33])


In [18]:
model.cpu()
model.train()
optimizer = AdamW(model.parameters(),
                    lr=args.learning_rate,
                    eps=args.adam_epsilon,
                    betas=(0.9, 0.98),)


In [20]:
for key, value in dict(model.named_parameters()).items():
    if not value.requires_grad:
        print("This parameter does have grad", key)

In [21]:
for epoch in range(10):
    pred_t, pred_v, att = model(input_ids = tokenized_text.unsqueeze(0).cpu(),
                            image_feat = features_masked.unsqueeze(0).cpu(), # Linear(2048*config.max_temporal_memory_buffer, 2048)
                            image_loc = spatial.unsqueeze(0).cpu(),  #Linear(in_features=5, out_features=1024, bias=True)
                            image_pos_input = pos_enc.unsqueeze(0).cpu(),   #Linear(7, 2048)/(6, 2048)
                            token_type_ids = segment_ids.unsqueeze(0).cpu(), 
                            attention_mask = input_mask.unsqueeze(0).cpu(), 
                            image_attention_mask = image_mask.unsqueeze(0).cpu(), 
                            co_attention_mask = co_attention_mask.unsqueeze(0).cpu(),
                            masked_lm_labels = masked_lm_token.unsqueeze(0).cpu(), 
                            image_label = None,
                            image_target = None,
                            next_sentence_label=None,
                            output_all_attention_masks=True)

    optimizer.zero_grad()
    masked_lm_loss = model.lang_criterion(pred_t.view(-1, 30522), masked_lm_token.cpu().view(-1))
    img_loss = model.vis_criterion(pred_v.view(-1, 91), image_target.cpu()) # why dim 2 (to check) 
    loss = masked_lm_loss + img_loss
    loss.backward()
    optimizer.step()
    print("loss: ", loss)



Image features after adding pos enc -> torch.Size([1, 33, 2048])
loss:  tensor(1.6649, grad_fn=<AddBackward0>)
Image features after adding pos enc -> torch.Size([1, 33, 2048])
loss:  tensor(2.6292, grad_fn=<AddBackward0>)
Image features after adding pos enc -> torch.Size([1, 33, 2048])
loss:  tensor(1.6354, grad_fn=<AddBackward0>)
Image features after adding pos enc -> torch.Size([1, 33, 2048])
loss:  tensor(1.5269, grad_fn=<AddBackward0>)
Image features after adding pos enc -> torch.Size([1, 33, 2048])
loss:  tensor(1.6374, grad_fn=<AddBackward0>)
Image features after adding pos enc -> torch.Size([1, 33, 2048])
loss:  tensor(1.5980, grad_fn=<AddBackward0>)
Image features after adding pos enc -> torch.Size([1, 33, 2048])
loss:  tensor(1.5098, grad_fn=<AddBackward0>)
Image features after adding pos enc -> torch.Size([1, 33, 2048])
loss:  tensor(1.4486, grad_fn=<AddBackward0>)
Image features after adding pos enc -> torch.Size([1, 33, 2048])
loss:  tensor(1.4301, grad_fn=<AddBackward0>)
I

In [None]:
loss_masked_lang = model.lang_criterion(pred_t, labels_t)


In [21]:
print((att))

([None, None, None, None, None, None, None, None, None, None, None, None], [None, None, None, None, None, None], [None, None, None, None, None, None])


In [None]:
features = torch.stack(feature_list, dim=0).float().cuda()
    spatials = torch.stack(image_location_list, dim=0).float().cuda()
    image_mask = torch.stack(image_mask_list, dim=0).byte().cuda()
    co_attention_mask = torch.zeros((num_image, num_boxes, max_length)).cuda()

    prediction(text, features, spatials, segment_ids, input_mask, image_mask, co_attention_mask, task)

In [None]:
print(data[1]["imgs"]["feat"].shape)
concate = torch.cat((data[1]["imgs"]["feat"].unsqueeze(0), data[0]["imgs"]["feat"].unsqueeze(0)), dim=0)
print(concate.shape)

In [None]:
print(data)

In [None]:
print("Structure of data: ")
print(" · Number of instructions: ", len(data))
print("  · Per instruction we have: ", data[0].keys())
print("    · The 'desc' of the instruction has: ", data[0]["desc"].keys())
print("       - lists of length tokenized text -->",len(data[0]["desc"]["tokenized_text"]))
print("    · The 'imgs' of the instruction has:", data[0]["imgs"].keys())
for k, v in data[0]["imgs"].items():
    print(k, v)
    #print(len(data[0]["imgs"]["spatial"][0]))

In [None]:
model

In [None]:
for key, value in dict(model.named_parameters()).items():
    if not value.requires_grad:
        print("This parameter does have grad", key)

In [None]:
model.eval()
cuda = torch.cuda.is_available()
if cuda: model = model.cuda(0)
#Why do we initialize Tokenizer again?
tokenizer = BertTokenizer.from_pretrained(
    args.bert_model, do_lower_case=args.do_lower_case
)

## Data preparation

In [None]:
def bert_tokenize( text):
    text = '[CLS]' + text + '[SEP]'
    tokenized_text = tokenizer.tokenize(text)
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    return indexed_tokens

In [None]:
text = "swimming elephant"
indexed_tokens = bert_tokenize(text)
#indexed_tokens, modified_indexed_tokens, masked_lm_labels = pretask_mask_lang_tokens(indexed_tokens)

query = '[CLS]' + text + '[SEP]'
tokens = tokenizer.encode(query)
print(tokens)

In [None]:
def prediction(question, features, spatials, segment_ids, input_mask, image_mask, co_attention_mask, task_tokens, ):
    print('input question size question: ', question.shape)
    pos_enc_input = torch.FloatTensor([0,1,2,3,4,5])
    print(question.shape)
    masked_lm_loss, masked_img_loss, prediction_t, prediction_v, all_attention_mask = model(
        input_ids=question, image_feat=features, image_loc=spatials,image_pos_input=pos_enc_input, output_all_attention_masks=True
    )
    return masked_lm_loss, masked_img_loss, prediction_t, prediction_v, all_attention_mask

In [None]:
def custom_prediction(query, task, features, infos):

#     print(query)
#     query = '[CLS]' + query + '[SEP]'
#     tokens = tokenizer.encode(query)
#     print(tokens)
    #tokens= 
    #tokens = tokenizer.add_special_tokens_single_sentence(tokens)
    
    tokens = bert_tokenize(query)
    segment_ids = [0] * len(tokens)
    input_mask = [1] * len(tokens)

    max_length = 37
    if len(tokens) < max_length:
        # Note here we pad in front of the sentence
        padding = [0] * (max_length - len(tokens))
        tokens = tokens + padding
        input_mask += padding
        segment_ids += padding

    text = torch.from_numpy(np.array(tokens)).cuda().unsqueeze(0)
    input_mask = torch.from_numpy(np.array(input_mask)).cuda().unsqueeze(0)
    segment_ids = torch.from_numpy(np.array(segment_ids)).cuda().unsqueeze(0)
    task = torch.from_numpy(np.array(task)).cuda().unsqueeze(0)

    num_image = len(infos)

    feature_list = []
    image_location_list = []
    image_mask_list = []
    for i in range(num_image):
        image_w = infos[i]['image_width']
        image_h = infos[i]['image_height']
        feature = features[i]
        num_boxes = feature.shape[0] #first dim size = number boxes

        g_feat = torch.sum(feature, dim=0) / num_boxes # Mean of features of all the selected regions
        num_boxes = num_boxes + 1
        feature = torch.cat([g_feat.view(1,-1), feature], dim=0)
        boxes = infos[i]['bbox']
        image_location = np.zeros((boxes.shape[0], 5), dtype=np.float32)
        image_location[:,:4] = boxes
        image_location[:,4] = (image_location[:,3] - image_location[:,1]) * (image_location[:,2] - image_location[:,0]) / (float(image_w) * float(image_h))
        image_location[:,0] = image_location[:,0] / float(image_w)
        image_location[:,1] = image_location[:,1] / float(image_h)
        image_location[:,2] = image_location[:,2] / float(image_w)
        image_location[:,3] = image_location[:,3] / float(image_h)
        g_location = np.array([0,0,1,1,1])
        image_location = np.concatenate([np.expand_dims(g_location, axis=0), image_location], axis=0)
        image_mask = [1] * (int(num_boxes))

        feature_list.append(feature)
        image_location_list.append(torch.tensor(image_location))
        image_mask_list.append(torch.tensor(image_mask))

    features = torch.stack(feature_list, dim=0).float().cuda()
    spatials = torch.stack(image_location_list, dim=0).float().cuda()
    image_mask = torch.stack(image_mask_list, dim=0).byte().cuda()
    co_attention_mask = torch.zeros((num_image, num_boxes, max_length)).cuda()
#     print("text: ", text.shape)
#     print("feat: ", features.shape)
#     print("spatials: ", spatials.shape)
#     print("segments_id: ", segment_ids)
#     print("input_mask: ", input_mask)
#     print("image_mask: ", image_mask)
#     print("coatenttion_mask: ", co_attention_mask)
    return prediction(text, features, spatials, segment_ids, input_mask, image_mask, co_attention_mask, task)

In [None]:
frcnn_model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
# read image 
#pic = "faster_rcnn/test2.png"
#pic1 = "faster_rcnn/test.png"
#image_paths = [pic, pic1]



image_path = ['demo/1.jpg']
features, pos_enc, infos = featureExtractor(image_path, frcnn_model).extract_features()
print("features: ", type(features))
print("infos: ", infos)
#features, infos = f_extractor.extract_features(image_path, frcnn_model)


img = PIL.Image.open(image_path[0]).convert('RGB')
img = torch.tensor(np.array(img))

plt.axis('off')
plt.imshow(img)
plt.show()
    
query = "swimming elephant"
task = [9]
masked_lm_loss, masked_img_loss, prediction_t, prediction_v, all_attention_mask = custom_prediction(query, task, features, infos)

In [None]:
print(data[0]["imgs"]["pos_enco"])

In [None]:
masked_lm_loss

In [None]:
 prediction_t.shape

In [None]:
prediction_t

In [None]:
prediction_v.shape

In [None]:
prediction_v

In [None]:
all_attention_mask

In [None]:
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
config.visual_target = args.visual_target

optimizer_grouped_parameters = []
for key, value in dict(model.named_parameters()).items():
    if value.requires_grad:
        if "cls" in key:
            lr = args.learning_rate
        else:
            lr = args.learning_rate * 0.1
        if any(nd in key for nd in no_decay): # No decay
            optimizer_grouped_parameters += [
                {"params": [value],
                 "lr": lr,
                 "weight_decay": 0.0}
            ]
        
        elif not any(nd in key for nd in no_decay):
            optimizer_grouped_parameters += [
                {"params": [value],
                 "lr": lr,
                 "weight_decay": 0.01}
            ]

optimizer = AdamW(
    optimizer_grouped_parameters,
    lr=args.learning_rate,
    eps=args.adam_epsilon,
    betas=(0.9, 0.98),
)

num_dataset_points = 19

num_train_optimization_steps = int(
    num_dataset_points
    / args.train_batch_size
    / args.gradient_accumulation_steps
) * (args.num_train_epochs - args.start_epoch)

scheduler = WarmupLinearSchedule(
    optimizer,
    warmup_steps=args.warmup_proportion * num_train_optimization_steps,
    t_total=num_train_optimization_steps,
)

In [None]:
# model.cuda()
# for state in optimizer.state.values():
#     for k, v in state.items():
#         if torch.is_tensor(v):
#             state[k] = v.cuda()

logger.info("***** Running training *****")
logger.info("Num examples = %d", num_dataset_points)
logger.info("Batch size = %d", args.train_batch_size)
logger.info("Num steps = %d", num_train_optimization_steps)

In [None]:
startIterID = 0
global_step = 0

for epoch in range(int(args.start_epoch), int(args.num_train_epochs)):
    model.train()
    for step, batch in enumerate(train_dataset):
        iterId = startIterID + step + (epochId * len(train_dataset))
    