<a href="https://colab.research.google.com/github/HaixinLiuNeuro/ALBEF/blob/main/colab_load_pretrained4M_freezeFineTuneVQA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Load the pretrained ALBEF 4M model, freeze the encoder parameters, fine-tune on the VQA dataset only, and run the evaluation test

In [1]:
# Mount Google Drive; force_remount forces a fresh mount even if Drive is already mounted.
from google.colab import drive

drive.mount(mountpoint='/content/drive', force_remount=True)

Mounted at /content/drive


In [2]:
# Set up the Google Drive folder where results are saved, and show what it already contains.
import os

# TODO: Fill in the Google Drive path (relative to MyDrive) where you want to save results
GOOGLE_DRIVE_PATH_POST_MYDRIVE = os.path.join('DL_Project', 'ALBEF')
_mydrive_root = os.path.join('/content', 'drive', 'MyDrive')
GOOGLE_DRIVE_PATH = os.path.join(_mydrive_root, GOOGLE_DRIVE_PATH_POST_MYDRIVE)

os.makedirs(GOOGLE_DRIVE_PATH, exist_ok=True)
print(os.listdir(GOOGLE_DRIVE_PATH))

['output', 'vqa_end2end', 'vqa_onlyPretrainModel', 'vqa_nopretrain_noTune']


In [3]:
# If google.colab is not loaded (i.e. running locally), fall back to the current directory.
import sys

IN_COLAB = 'google.colab' in sys.modules
if not IN_COLAB:
  GOOGLE_DRIVE_PATH = '.'
  print('Running locally.')
else:
  print(f'Running in google colab. Our path is `{GOOGLE_DRIVE_PATH}`')

Running in google colab. Our path is `/content/drive/MyDrive/DL_Project/ALBEF`


In [4]:
# Make modules stored in the Drive folder importable.
import math
import sys

import numpy as np

sys.path.append(GOOGLE_DRIVE_PATH)
print('Google Drive Path: {}'.format(GOOGLE_DRIVE_PATH))

Google Drive Path: /content/drive/MyDrive/DL_Project/ALBEF


In [5]:
# Clone the repo to a content
!git clone -b main https://github.com/HaixinLiuNeuro/ALBEF.git /tmp/ALBEF
!cp -r /tmp/ALBEF/* .
!rm -rf /tmp/ALBEF

Cloning into '/tmp/ALBEF'...
remote: Enumerating objects: 402, done.[K
remote: Counting objects: 100% (240/240), done.[K
remote: Compressing objects: 100% (118/118), done.[K
remote: Total 402 (delta 138), reused 124 (delta 122), pack-reused 162 (from 2)[K
Receiving objects: 100% (402/402), 71.62 MiB | 68.35 MiB/s, done.
Resolving deltas: 100% (162/162), done.


In [6]:
# install dependency
!pip install transformers==4.25.1
!pip install ruamel.yaml==0.17.*
!pip install matplotlib


Collecting transformers==4.25.1
  Downloading transformers-4.25.1-py3-none-any.whl.metadata (93 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/93.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.9/93.9 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers==4.25.1)
  Downloading tokenizers-0.13.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Downloading transformers-4.25.1-py3-none-any.whl (5.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.8/5.8 MB[0m [31m65.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading tokenizers-0.13.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m127.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tokenizers, transformers
  Attempting uninstall: t

In [7]:
# import
import argparse
import os
import ruamel.yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
import torch.distributed as dist

# use vqa model
from models.model_vqa_freeze import ALBEF

from models.vit import interpolate_pos_embed
from models.tokenization_bert import BertTokenizer

import utils
from dataset.utils import save_result
from dataset import create_dataset, create_sampler, create_loader, vqa_collate_fn

from scheduler import create_scheduler
from optim import create_optimizer

# print and plotting
from pprint import pprint
import matplotlib.pyplot as plt
from PIL import Image

%load_ext autoreload
%autoreload 2




In [8]:
# %reload_ext autoreload

In [9]:
# prep data
# download from website

# make folder /content/data
DATA_PATH = os.path.join('/content', 'data')
os.makedirs(DATA_PATH, exist_ok=True)

%cd /content/data

# download data from links:
# https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/json_pretrain.zip
# https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/data.tar.gz
# http://images.cocodataset.org/zips/train2014.zip
# http://images.cocodataset.org/zips/val2014.zip
# http://images.cocodataset.org/zips/test2015.zip




# Define the download links
links = [
    # "https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/json_pretrain.zip", # pretrain json
    "https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/data.tar.gz", # for downstream task json
    "http://images.cocodataset.org/zips/train2014.zip", # comment out if only run evaluation
    "http://images.cocodataset.org/zips/val2014.zip",   # comment out if only run evaluation
    "http://images.cocodataset.org/zips/test2015.zip"
]

# Download and extract each file
for link in links:
    filename = link.split('/')[-1]
    print(f"Downloading {filename}...")

    # Download file
    !wget -q --show-progress {link}

    print(f"Extracting {filename}...")

    # Extract based on file extension
    if filename.endswith('.zip'):
      if '//images.cocodataset.org/zips/' in link:
        !unzip -q {filename}
      else:
        !unzip -q -j {filename}  # -j option flattens the directory structure for json_pretrain.zip
    elif filename.endswith('.tar.gz'):
        !tar -xzf {filename} --strip-components=1  # Remove the top-level directory

    # Delete the zip/tar file after extraction
    print(f"Removing {filename}...")
    !rm {filename}

    print(f"Finished processing {filename}")

print("All downloads and extractions completed!")

%cd /content

/content/data
Downloading data.tar.gz...
Extracting data.tar.gz...
^C
Removing data.tar.gz...
Finished processing data.tar.gz
Downloading train2014.zip...
train2014.zip         6%[>                   ] 826.35M  53.1MB/s    eta 3m 33s ^C
Extracting train2014.zip...
[train2014.zip]
  End-of-central-directory signature not found.  Either this file is not
  a zipfile, or it constitutes one disk of a multi-part archive.  In the
  latter case the central directory and zipfile comment will be found on
  the last disk(s) of this archive.
unzip:  cannot find zipfile directory in one of train2014.zip or
        train2014.zip.zip, and cannot find train2014.zip.ZIP, period.
Removing train2014.zip...
Finished processing train2014.zip
Downloading val2014.zip...
val2014.zip           7%[>                   ] 480.01M  58.4MB/s    eta 99s    ^C
Extracting val2014.zip...
[val2014.zip]
  End-of-central-directory signature not found.  Either this file is not
  a zipfile, or it constitutes one disk of a mu

In [10]:
# !rm -rf /content/data

In [11]:
# check files: list the contents of the data folder, then return to /content
%cd /content/data
!ls
%cd /content

/content/data
coco_test.json	 nlvr_train.json     ve_dev.json   vqa_test_dev.json
coco_train.json  refcoco+	     ve_test.json  vqa_test.json
nlvr_test.json	 refcoco+_test.json  vg_qa.json    vqa_train.json
/content


In [None]:
#
FETCH_PRETRAINED_MODEL = True
%cd /content

# download data from links:
# https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/ALBEF_4M.pth
# https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/vqa.pth
# model check point from training
# https://drive.google.com/file/d/1yEsyeB0FkIgWlT2Way_KFLPNLCQy6KoU/view?usp=sharing

if FETCH_PRETRAINED_MODEL:

  # Define the download links
  links = [
      "https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/ALBEF_4M.pth",
      # "https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/vqa.pth"
  ]

  # Download and extract each file
  for link in links:
      filename = link.split('/')[-1]
      print(f"Downloading {filename}...")

      # Download file
      !wget -q --show-progress {link}


      print(f"Finished processing {filename}")

  print("All model downloads completed!")




/content
Downloading ALBEF_4M.pth...

## Setup for training

In [None]:
# config
%cd /content
args = argparse.Namespace()
args.config = './configs/VQA.yaml'
args.checkpoint = './ALBEF_4M.pth'
args.output_dir = 'output/vqa_PretrainModel_freezeTune'
args.evaluate = True # to train use False
args.text_encoder = 'bert-base-uncased'
args.text_decoder = 'bert-base-uncased'
args.device = 'cuda'
args.seed = 42
args.distributed = False

config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
pprint(config)

# make result folder and save config
args.result_dir = os.path.join(args.output_dir, 'result')

Path(args.output_dir).mkdir(parents=True, exist_ok=True)
Path(args.result_dir).mkdir(parents=True, exist_ok=True)

yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))

In [None]:
# training functions
def train(model, data_loader, optimizer, tokenizer, epoch, warmup_steps, device, scheduler, config):
    """Train ``model`` for one epoch over ``data_loader`` and return averaged stats.

    Args:
        model: VQA model; called with ``train=True`` it returns the loss to backpropagate.
        data_loader: iterable of ``(image, question, answer, weights, n)`` batches.
        optimizer: optimizer stepped every iteration; its first param group's ``lr`` is logged.
        tokenizer: tokenizer applied to the question and answer strings.
        epoch: current epoch index; the alpha ramp and the scheduler warm-up only apply to epoch 0.
        warmup_steps: number of scheduler warm-up steps; each step spans ``step_size`` iterations.
        device: device the image/weight tensors and tokenized inputs are moved to.
        scheduler: learning-rate scheduler, stepped by hand during the epoch-0 warm-up.
        config: config dict; reads ``'alpha'`` and ``'warm_up'``.

    Returns:
        dict mapping each logged meter name to its global average as a 3-decimal string.
    """
    model.train()

    logger = utils.MetricLogger(delimiter="  ")
    logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    step_size = 100  # iterations per manual warm-up scheduler step
    warmup_iterations = warmup_steps * step_size
    batches = logger.log_every(data_loader, 50, f'Train Epoch: [{epoch}]')

    for it, (image, question, answer, weights, n) in enumerate(batches):
        image = image.to(device, non_blocking=True)
        weights = weights.to(device, non_blocking=True)
        question_input = tokenizer(question, padding='longest', truncation=True, max_length=25, return_tensors="pt").to(device)
        answer_input = tokenizer(answer, padding='longest', return_tensors="pt").to(device)

        # With warm-up enabled, alpha ramps linearly from 0 up to config['alpha'] over epoch 0.
        ramp = epoch <= 0 and config['warm_up']
        alpha = config['alpha'] * min(1, it / len(data_loader)) if ramp else config['alpha']

        loss = model(image, question_input, answer_input, train=True, alpha=alpha, k=n, weights=weights)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        logger.update(loss=loss.item())
        logger.update(lr=optimizer.param_groups[0]["lr"])

        # Manual LR warm-up: one scheduler step every `step_size` iterations of epoch 0.
        if epoch == 0 and it <= warmup_iterations and it % step_size == 0:
            scheduler.step(it // step_size)

    # gather the stats from all processes
    logger.synchronize_between_processes()
    print("Averaged stats:", logger.global_avg())
    return {name: "{:.3f}".format(meter.global_avg) for name, meter in logger.meters.items()}

@torch.no_grad()
def evaluation(model, data_loader, tokenizer, device, config):
    """Predict an answer for every question in ``data_loader``.

    The candidate answers are ``data_loader.dataset.answer_list``, each suffixed with
    ``config['eos']`` before tokenizing. For every question the model (called with
    ``train=False``) returns ``config['k_test']`` candidate ids with scores, and the
    highest-scoring candidate is kept.

    Returns:
        list of ``{"question_id": int, "answer": <entry of answer_list>}`` dicts.
    """
    model.eval()

    logger = utils.MetricLogger(delimiter="  ")
    candidates = data_loader.dataset.answer_list
    answer_input = tokenizer([cand + config['eos'] for cand in candidates], padding='longest', return_tensors='pt').to(device)

    predictions = []
    for image, question, question_id in logger.log_every(data_loader, 50, 'Generate VQA test result:'):
        image = image.to(device, non_blocking=True)
        question_input = tokenizer(question, padding='longest', return_tensors="pt").to(device)

        topk_ids, topk_probs = model(image, question_input, answer_input, train=False, k=config['k_test'])

        # Keep the best of the top-k candidates for each question in the batch.
        for qid, cand_ids, cand_probs in zip(question_id, topk_ids, topk_probs):
            _, best = cand_probs.max(dim=0)
            predictions.append({"question_id": int(qid.item()), "answer": candidates[cand_ids[best]]})

    return predictions

In [None]:
# Set up (optionally distributed) execution: device, RNG seeds and the epoch schedule.
utils.init_distributed_mode(args)

device = torch.device(args.device)
print(f'device: {device}')

# Seed every RNG in use; adding the rank gives each distributed process its own seed.
seed = args.seed + utils.get_rank()
for _seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    _seed_rng(seed)
cudnn.benchmark = True

start_epoch = 0
max_epoch, warmup_steps = config['schedular']['epochs'], config['schedular']['warmup_epochs']

In [None]:
# Build the VQA train/test datasets, their data loaders, and the tokenizer.
print("Creating vqa datasets")
datasets = create_dataset('vqa', config)

# Samplers are only needed when running distributed.
samplers = [None, None]
if args.distributed:
    samplers = create_sampler(datasets, [True, False], utils.get_world_size(), utils.get_rank())

train_loader, test_loader = create_loader(
    datasets,
    samplers,
    batch_size=[config['batch_size_train'], config['batch_size_test']],
    num_workers=[4, 4],
    is_trains=[True, False],
    collate_fns=[vqa_collate_fn, None],
)

tokenizer = BertTokenizer.from_pretrained(args.text_encoder)

In [None]:
#### Model ####
# Build the model on the target device, then its optimizer and LR scheduler from the config.
print("Creating model")
model = ALBEF(config=config, text_encoder=args.text_encoder, text_decoder=args.text_decoder, tokenizer=tokenizer).to(device)

optimizer = create_optimizer(utils.AttrDict(config['optimizer']), model)
lr_scheduler, _ = create_scheduler(utils.AttrDict(config['schedular']), optimizer)

# display the model architecture
model


In [None]:
# Load a checkpoint into the freshly built model. strict=False: missing/unexpected keys are
# reported in the printed message rather than raising.
if args.checkpoint:
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    # Checkpoints saved by the training loop in this notebook (and the ALBEF_4M.pth the
    # non-evaluate path expects) wrap the weights under 'model'; a bare state dict is used as-is.
    # Passing the wrapper itself would match no parameter names, and strict=False would then
    # silently load nothing.
    if isinstance(checkpoint, dict) and 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint

    if 'visual_encoder.pos_embed' in state_dict:
        # reshape positional embedding to accommodate for image resolution change
        pos_embed_reshaped = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'], model.visual_encoder)
        state_dict['visual_encoder.pos_embed'] = pos_embed_reshaped
    else:
        print("Warning: 'visual_encoder.pos_embed' not found in checkpoint. Skipping positional embedding interpolation.")

    if not args.evaluate:
        if config['distill']:
            m_pos_embed_reshaped = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'], model.visual_encoder_m)
            state_dict['visual_encoder_m.pos_embed'] = m_pos_embed_reshaped

        # Index of the layer number in keys split on '.', e.g.
        # 'text_encoder.bert.encoder.layer.<N>.<...>' -> position 4.
        layer_idx_pos = 4
        # Text-encoder layers >= this index (the multimodal part) initialise the text decoder,
        # renumbered from 0; earlier layers are dropped from these keys.
        decoder_start_layer = 6

        for key in list(state_dict.keys()):
            if 'bert' in key:
                # also store the weight under the key with 'bert.' removed (the original key is kept)
                state_dict[key.replace('bert.', '')] = state_dict[key]
            # initialize text decoder as multimodal encoder (last 6 layers of model.text_encoder)
            if 'text_encoder' not in key:
                continue
            if 'layer' in key:
                key_parts = key.split('.')
                layer_num = int(key_parts[layer_idx_pos])
                if layer_num < decoder_start_layer:
                    del state_dict[key]
                    continue
                key_parts[layer_idx_pos] = str(layer_num - decoder_start_layer)
                encoder_key = '.'.join(key_parts)
            else:
                encoder_key = key
            decoder_key = encoder_key.replace('text_encoder', 'text_decoder')
            state_dict[decoder_key] = state_dict[key]

            del state_dict[key]

    msg = model.load_state_dict(state_dict, strict=False)
    print('load checkpoint from %s' % args.checkpoint)
    print(msg)


In [None]:
# handle distributed training: keep a handle on the unwrapped model, whose state_dict is what gets checkpointed
if args.distributed:
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
    model_without_ddp = model.module
else:
    model_without_ddp = model


In [None]:
# run evaluation without training single GPU
print("Start eval")
start_time = time.time()

for epoch in range(start_epoch, max_epoch):
    if epoch>0:
        lr_scheduler.step(epoch+warmup_steps)

    if not args.evaluate:
        if args.distributed:
            train_loader.sampler.set_epoch(epoch)

        train_stats = train(model, train_loader, optimizer, tokenizer, epoch, warmup_steps, device, lr_scheduler, config)

    if args.evaluate:
        break

    if utils.is_main_process():
        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                      'epoch': epoch,
                    }
        with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
            f.write(json.dumps(log_stats) + "\n")

        save_obj = {
            'model': model_without_ddp.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'config': config,
            'epoch': epoch,
        }
        torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth'%epoch))

        # save result to google drive
        !cp -r {args.output_dir} {GOOGLE_DRIVE_PATH}

    if args.distributed:
        dist.barrier()
    else:
        pass  # Skip barrier for non-distributed training

# evaluation
vqa_result = evaluation(model, test_loader, tokenizer, device, config)
result_file = save_result(vqa_result, args.result_dir, 'vqa_result_epoch%d'%epoch)

total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Time time {}'.format(total_time_str))

In [None]:
print('google drive: {} from colab drive: {}'.format(GOOGLE_DRIVE_PATH, args.output_dir))

In [None]:
# save result to google drive
!cp -r {args.output_dir} {GOOGLE_DRIVE_PATH}

In [None]:
# terminate colab runtime, but only after checking that the results reached Google Drive:
# `!cp` does not raise on failure, and releasing the VM discards the local /content copy.
import os
from google.colab import runtime

drive_copy = os.path.join(GOOGLE_DRIVE_PATH, os.path.basename(os.path.normpath(args.output_dir)))
if os.path.isdir(drive_copy):
    runtime.unassign()
else:
    print(f'Not terminating: {drive_copy} not found, results may not have been saved to Google Drive.')
