# make a small model, train end-to-end using only VQA dataset

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# import
import argparse
import os
import ruamel.yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
import torch.distributed as dist

# use vqa model
from models.model_vqa import ALBEF

from models.vit import interpolate_pos_embed
from models.tokenization_bert import BertTokenizer

import utils
from dataset.utils import save_result
from dataset import create_dataset, create_sampler, create_loader, vqa_collate_fn

from scheduler import create_scheduler
from optim import create_optimizer

# print and plotting 
from pprint import pprint
import matplotlib.pyplot as plt
from PIL import Image


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# prep data, set True if you have unzipped the data
unzipped = True
VQA_DATA_DIR = 'data'
# VQA_NEWMODEL_DIR = 'pretrained/vqa'

# Ensure the directories exist
os.makedirs(VQA_DATA_DIR, exist_ok=True)
# os.makedirs(VQA_NEWMODEL_DIR, exist_ok=True)
if not unzipped:
    import os
    import zipfile

    # prep downloaded data


    zip_files = [
        'train2014.zip',
        'test2015.zip',        
        'val2014.zip',
        'data.tar.gz'
    ]

    for zip_file in zip_files:
        zip_path = os.path.join(VQA_DATA_DIR, zip_file)
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(VQA_DATA_DIR)
    print('Unzipped all files.')
    


## Setup for training 

In [4]:
# config
args = argparse.Namespace()
args.config = './configs/VQA_only.yaml'
args.checkpoint = '' # './ALBEF_4M.pth'
args.output_dir = './output/vqa_end2end'
args.evaluate = False # to train use False
args.text_encoder = 'bert-base-uncased'
args.text_decoder = 'bert-base-uncased'
args.device = 'cuda'
args.seed = 42
args.distributed = False

config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
pprint(config)

# make result folder and save config
args.result_dir = os.path.join(args.output_dir, 'result')

Path(args.output_dir).mkdir(parents=True, exist_ok=True)
Path(args.result_dir).mkdir(parents=True, exist_ok=True)

yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))

{'alpha': 0.0,
 'answer_list': 'data/answer_list.json',
 'batch_size_test': 16,
 'batch_size_train': 32,
 'bert_config': 'configs/config_bert_small.json',
 'distill': False,
 'eos': '[SEP]',
 'image_res': 224,
 'k_test': 128,
 'optimizer': {'lr': 2e-05, 'opt': 'adamW', 'weight_decay': 0.02},
 'schedular': {'cooldown_epochs': 0,
               'decay_rate': 1,
               'epochs': 8,
               'lr': 2e-05,
               'min_lr': 1e-06,
               'sched': 'cosine',
               'warmup_epochs': 4,
               'warmup_lr': 1e-05},
 'test_file': ['data/vqa_test.json'],
 'train_file': ['data/vqa_train.json', 'data/vqa_val.json'],
 'vg_root': 'data/',
 'vqa_root': 'data/',
 'warm_up': False}


In [5]:
# training functions
def train(model, data_loader, optimizer, tokenizer, epoch, warmup_steps, device, scheduler, config):
    # train
    model.train()

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    header = 'Train Epoch: [{}]'.format(epoch)
    print_freq = 50
    step_size = 100
    warmup_iterations = warmup_steps*step_size

    for i,(image, question, answer, weights, n) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        image, weights = image.to(device,non_blocking=True), weights.to(device,non_blocking=True)
        question_input = tokenizer(question, padding='longest', truncation=True, max_length=25, return_tensors="pt").to(device)
        answer_input = tokenizer(answer, padding='longest', return_tensors="pt").to(device)

        if epoch>0 or not config['warm_up']:
            alpha = config['alpha']
        else:
            alpha = config['alpha']*min(1,i/len(data_loader))

        loss = model(image, question_input, answer_input, train=True, alpha=alpha, k=n, weights=weights)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        metric_logger.update(loss=loss.item())
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        if epoch==0 and i%step_size==0 and i<=warmup_iterations:
            scheduler.step(i//step_size)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger.global_avg())
    return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}

@torch.no_grad()
def evaluation(model, data_loader, tokenizer, device, config) :
    # test
    model.eval()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Generate VQA test result:'
    print_freq = 50

    result = []

    answer_list = [answer+config['eos'] for answer in data_loader.dataset.answer_list]
    answer_input = tokenizer(answer_list, padding='longest', return_tensors='pt').to(device)

    for n, (image, question, question_id) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        image = image.to(device,non_blocking=True)
        question_input = tokenizer(question, padding='longest', return_tensors="pt").to(device)

        topk_ids, topk_probs = model(image, question_input, answer_input, train=False, k=config['k_test'])

        for ques_id, topk_id, topk_prob in zip(question_id, topk_ids, topk_probs):
            ques_id = int(ques_id.item())
            _, pred = topk_prob.max(dim=0)
            result.append({"question_id":ques_id, "answer":data_loader.dataset.answer_list[topk_id[pred]]})

    return result

In [6]:
# setup for training (from main)
utils.init_distributed_mode(args)

device = torch.device(args.device)

# fix the seed for reproducibility
seed = args.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
cudnn.benchmark = True

start_epoch = 0
max_epoch = config['schedular']['epochs']
warmup_steps = config['schedular']['warmup_epochs']

Not using distributed mode


In [7]:
# make dataset and dataloader
print("Creating vqa datasets")
datasets = create_dataset('vqa', config)

if args.distributed:
    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    samplers = create_sampler(datasets, [True, False], num_tasks, global_rank)
else:
    samplers = [None, None]

train_loader, test_loader = create_loader(datasets,samplers,
                                          batch_size=[config['batch_size_train'],config['batch_size_test']],
                                          num_workers=[4,4],is_trains=[True, False],
                                          collate_fns=[vqa_collate_fn,None])

tokenizer = BertTokenizer.from_pretrained(args.text_encoder)

Creating vqa datasets




In [8]:
# clean up GPU memory
import gc
def clear_gpu_memory():
    torch.cuda.empty_cache()
    gc.collect()

clear_gpu_memory()    

In [9]:
#### Model ####
print("Creating model")
model = ALBEF(config=config, text_encoder=args.text_encoder, text_decoder=args.text_decoder, tokenizer=tokenizer)
model = model.to(device)

arg_opt = utils.AttrDict(config['optimizer'])
optimizer = create_optimizer(arg_opt, model)
arg_sche = utils.AttrDict(config['schedular'])
lr_scheduler, _ = create_scheduler(arg_sche, optimizer)

# check model
model


Creating model


ALBEF(
  (visual_encoder): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0-5): 6 x Block(
        (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=384, out_features=1152, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=384, out_features=384, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=1536, out_features=384, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )

In [10]:
# revise the checkpoint to continue training (when error occurs)
args.checkpoint = './output/vqa_end2end/checkpoint_00.pth' # './ALBEF_4M.pth'
if not os.path.exists(args.checkpoint):
    raise FileNotFoundError(f"Checkpoint file '{args.checkpoint}' does not exist.")
# note: need to manually adjust the start epoch
start_epoch = 1
print(f'checkpoint path: {args.checkpoint}. start epoch: {start_epoch}')

checkpoint path: ./output/vqa_end2end/checkpoint_00.pth. start epoch: 1


In [None]:
# load check point to continue training
if args.checkpoint:
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    if args.evaluate:
        state_dict = checkpoint
    else:
        state_dict = checkpoint['model']

    # reshape positional embedding to accomodate for image resolution change
    pos_embed_reshaped = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
    state_dict['visual_encoder.pos_embed'] = pos_embed_reshaped

    if not args.evaluate:
        if config['distill']:
            m_pos_embed_reshaped = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],model.visual_encoder_m)
            state_dict['visual_encoder_m.pos_embed'] = m_pos_embed_reshaped

        for key in list(state_dict.keys()):
            if 'bert' in key:
                encoder_key = key.replace('bert.','')
                state_dict[encoder_key] = state_dict[key]
            # intialize text decoder as multimodal encoder (last 6 layers of model.text_encoder)
            if 'text_encoder' in key:
                if 'layer' in key:
                    # print(key)
                    encoder_keys = key.split('.')
                    # print(encoder_keys)
                    # print(encoder_keys[4])
                    tmp_fix_idx = 5 # for the downsized model, idx 5 is the layer number
                    layer_num = int(encoder_keys[tmp_fix_idx]) # 4
                    if layer_num<6:
                        del state_dict[key]
                        continue
                    else:
                        decoder_layer_num = (layer_num-6)
                        encoder_keys[4] = str(decoder_layer_num)
                        encoder_key = '.'.join(encoder_keys)
                else:
                    encoder_key = key
                decoder_key = encoder_key.replace('text_encoder','text_decoder')
                state_dict[decoder_key] = state_dict[key]

                del state_dict[key]

    msg = model.load_state_dict(state_dict,strict=False)
    print('load checkpoint from %s'%args.checkpoint)
    print(msg)


text_encoder.encoder.layer.0.attention.self.query.weight
['text_encoder', 'encoder', 'layer', '0', 'attention', 'self', 'query', 'weight']
attention


  checkpoint = torch.load(args.checkpoint, map_location='cpu')


ValueError: invalid literal for int() with base 10: 'attention'

In [None]:
# handle distributed training
model_without_ddp = model
if args.distributed:
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
    model_without_ddp = model.module


In [None]:
# training loop, single GPU
print("Start training")
start_time = time.time()

for epoch in range(start_epoch, max_epoch):
    if epoch>0:
        lr_scheduler.step(epoch+warmup_steps)

    if not args.evaluate:
        if args.distributed:
            train_loader.sampler.set_epoch(epoch)

        train_stats = train(model, train_loader, optimizer, tokenizer, epoch, warmup_steps, device, lr_scheduler, config)

    if args.evaluate:
        break

    if utils.is_main_process():
        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                      'epoch': epoch,
                    }
        with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
            f.write(json.dumps(log_stats) + "\n")

        save_obj = {
            'model': model_without_ddp.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'config': config,
            'epoch': epoch,
        }
        torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth'%epoch))
    if args.distributed:
        dist.barrier()
    else:
        pass  # Skip barrier for non-distributed training
    
    # clean up
    clear_gpu_memory()

vqa_result = evaluation(model, test_loader, tokenizer, device, config)
result_file = save_result(vqa_result, args.result_dir, 'vqa_result_epoch%d'%epoch)

total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))

Start training


  offset = -low * scale
  offset = -low * scale
  offset = -low * scale
  offset = -low * scale


Train Epoch: [0]  [    0/20565]  eta: 9:29:51  lr: 0.000010  loss: 23.6020  time: 1.6626  data: 0.3350  max mem: 1804
Train Epoch: [0]  [   50/20565]  eta: 3:30:50  lr: 0.000010  loss: 16.5129  time: 0.5979  data: 0.0001  max mem: 2837
Train Epoch: [0]  [  100/20565]  eta: 3:29:59  lr: 0.000010  loss: 13.5169  time: 0.6183  data: 0.0001  max mem: 2837
Train Epoch: [0]  [  150/20565]  eta: 3:29:14  lr: 0.000013  loss: 14.0854  time: 0.6205  data: 0.0001  max mem: 2963
Train Epoch: [0]  [  200/20565]  eta: 3:29:13  lr: 0.000013  loss: 11.9379  time: 0.6256  data: 0.0001  max mem: 2963
Train Epoch: [0]  [  250/20565]  eta: 3:29:11  lr: 0.000015  loss: 12.3172  time: 0.6120  data: 0.0001  max mem: 2963
Train Epoch: [0]  [  300/20565]  eta: 3:29:05  lr: 0.000015  loss: 11.6710  time: 0.6310  data: 0.0001  max mem: 2963
Train Epoch: [0]  [  350/20565]  eta: 3:28:46  lr: 0.000018  loss: 10.4605  time: 0.6272  data: 0.0001  max mem: 3141
Train Epoch: [0]  [  400/20565]  eta: 3:28:11  lr: 0.000

ValueError: Default process group has not been initialized, please make sure to call init_process_group.