In [1]:
import argparse
import os
import random

# Switch to the project root before the project-local imports below.
# NOTE(review): presumably required so that the `libs` package resolves — confirm.
os.chdir("/home/zhuoyan/vision/branch_embedding/")

import yaml
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

from libs.core import load_config
from libs.datasets import make_dataset, make_data_loader
from libs.model import Worker
from libs.utils import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import argparse

def parse_args(input_args):
    """Parse the training job's command-line style options.

    Args:
        input_args: list of argument strings, as they would follow the
            program name on a command line
            (e.g. ``['-c', 'cfg.yaml', '-n', 'job']``).

    Returns:
        argparse.Namespace with ``config``, ``name``, ``gpu`` and
        ``print_freq``. ``print_freq`` is supplied in units of 100
        iterations and is converted to plain iterations before returning.
    """
    # (flags, add_argument keyword options) for every supported option.
    option_table = (
        (('-c', '--config'), {'type': str, 'help': 'config file path'}),
        (('-n', '--name'), {'type': str, 'help': 'job name'}),
        (('-g', '--gpu'), {'type': str, 'default': '0', 'help': 'GPU IDs'}),
        (('-pf', '--print_freq'),
         {'type': int, 'default': 1, 'help': 'print frequency (x100 itrs)'}),
    )

    cli = argparse.ArgumentParser()
    for flags, options in option_table:
        cli.add_argument(*flags, **options)

    parsed = cli.parse_args(input_args)
    # The flag counts in hundreds of iterations; store plain iterations.
    parsed.print_freq = parsed.print_freq * 100
    return parsed

# Example: emulate the command line of a training job, one flag/value pair
# per line. '-pf 2' ends up as print_freq=200 after parse_args scales it.
input_args = [
    '-c', '/home/zhuoyan/vision/branch_embedding/configs/resnet18_cifar10.yaml',
    '-n', 'train_resnet18_cifar10',
    '-pf', '2',
]
args = parse_args(input_args)

# Show the parsed namespace.
print(args)


Namespace(config='/home/zhuoyan/vision/branch_embedding/configs/resnet18_cifar10.yaml', name='train_resnet18_cifar10', gpu='0', print_freq=200)


In [3]:
# Load the config from the path given via -c/--config.
cfg = load_config(args.config)

In [4]:
# Show the 'data' section of the loaded config.
cfg['data']

{'dataset': 'cifar10',
 'root': '/backup/zhuoyan/datasets',
 'downsample': True,
 'train_split': 'train',
 'val_split': 'test',
 'batch_size': 64,
 'num_workers': 8}

In [5]:
# Set up the checkpoint folder for this job under log/<job name>.
os.makedirs('log', exist_ok=True)
ckpt_path = os.path.join('log', args.name)
ensure_path(ckpt_path)

# Load the config. Prefer the copy stored in the checkpoint folder (i.e. a
# previous run of this job exists); otherwise fall back to the config file
# given on the command line.
cfg_path = os.path.join(ckpt_path, 'config.yaml')
try:
    check_file(cfg_path)
    cfg = load_config(cfg_path)
    print('config loaded from checkpoint folder')
    # Mark the config so that later steps can tell this is a resumed job.
    cfg['_resume'] = True
    print("load ckpt")
except Exception:
    # `except Exception` instead of a bare `except:` so that
    # KeyboardInterrupt / SystemExit are not silently swallowed.
    # NOTE(review): assumes check_file/load_config report a missing or
    # unreadable file with an ordinary Exception subclass — confirm.
    check_file(args.config)
    cfg = load_config(args.config)
    print('config loaded from command line')
    print("begin")

config loaded from checkpoint folder
load ckpt


In [6]:
# Show the checkpoint directory used for this job.
ckpt_path

'log/train_resnet18_cifar10'

### worker

In [7]:
# Show the 'model' section of the config.
cfg['model']

{'branch_enc': {'attn_pdrop': 0.1,
  'embd_dim': 256,
  'embd_type': 0,
  'eos': False,
  'n_heads': 4,
  'n_layers': 5,
  'out_dim': 128,
  'path_pdrop': 0.1,
  'pe_type': 0,
  'proj_pdrop': 0.1,
  'seq_len': 7},
 'branch_vae': {'hid_dim': 32, 'in_dim': 7, 'latent_dim': 2, 'n_layers': 3},
 'content_enc': {'arch': 'resnet8_cifar', 'out_dim': 128, 'pretrained': False},
 'resnet': {'arch': 'resnet18', 'dataset': 'cifar10'}}

In [8]:
# Build the worker from the 'model' section of the config.
worker = Worker(cfg['model'])

# Save the active config inside the checkpoint folder. The `with` block
# closes the file handle deterministically instead of leaving it open.
with open(os.path.join(ckpt_path, 'config.yaml'), 'w') as cfg_file:
    yaml.dump(cfg, cfg_file)

In [9]:
# Inspect the worker's `macs_brk` tensor.
# NOTE(review): presumably a per-stage breakdown of MACs — confirm against libs.model.Worker.
worker.macs_brk

tensor([0.1512, 0.1362, 0.1020, 0.1358, 0.1018, 0.1357, 0.1017, 0.1356])

In [10]:
# Display the ResNet module held by the worker (prints its full layer tree).
worker.resnet

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (maxpool): Identity()
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (actv): ReLU(inplace=True)
  (layer1): ModuleList(
    (0-1): 2 x BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Identity()
      (actv): ReLU(inplace=True)
    )
  )
  (layer2): ModuleList(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, 

In [None]:
# NOTE(review): `m` is not defined in any visible cell and this cell was never
# executed; it looks like a leftover and would presumably raise NameError on a
# fresh "Restart & Run All" — confirm and remove.
m

### dataset

In [28]:
# Show the 'data' section of the config that drives dataset/loader creation.
cfg['data']

{'dataset': 'cifar10',
 'root': '/backup/zhuoyan/datasets',
 'downsample': True,
 'train_split': 'train',
 'val_split': 'test',
 'batch_size': 64,
 'num_workers': 8}

In [32]:
# Fix the random seed, using the config's 'seed' entry or 2023 when absent.
# The returned `rng` is later handed to the data loaders as `generator`.
# NOTE(review): which RNGs fix_random_seed seeds is not visible here — confirm in libs.utils.
rng = fix_random_seed(cfg.get('seed', 2023))

In [33]:
# Options common to the training and validation datasets.
shared_dataset_kwargs = dict(
    dataset=cfg['data']['dataset'],
    root=cfg['data']['root'],
    downsample=cfg['data'].get('downsample', False),
)
# Options common to both loaders; both receive the same `rng` generator object.
shared_loader_kwargs = dict(
    generator=rng,
    batch_size=cfg['data']['batch_size'],
    num_workers=cfg['data']['num_workers'],
)

# Training split.
train_set = make_dataset(
    split=cfg['data']['train_split'], **shared_dataset_kwargs
)
train_loader = make_data_loader(
    train_set, is_training=True, **shared_loader_kwargs
)

# Validation split.
val_set = make_dataset(
    split=cfg['data']['val_split'], **shared_dataset_kwargs
)
val_loader = make_data_loader(
    val_set, is_training=False, **shared_loader_kwargs
)

# Report the size of the training data.
itrs_per_epoch = len(train_loader)
print('train data size: {:d}'.format(len(train_set)))
print('number of iterations per epoch: {:d}'.format(itrs_per_epoch))

Files already downloaded and verified
Files already downloaded and verified
train data size: 50000
number of iterations per epoch: 781


### opt, scheduler

In [29]:
# Show the 'model' section of the config.
cfg['model']

{'resnet': {'arch': 'resnet18', 'dataset': 'cifar10'},
 'branch_enc': {'embd_dim': 256,
  'out_dim': 128,
  'n_heads': 4,
  'n_layers': 5,
  'attn_pdrop': 0.1,
  'proj_pdrop': 0.1,
  'path_pdrop': 0.1,
  'eos': False,
  'embd_type': 0,
  'pe_type': 0,
  'seq_len': 7},
 'content_enc': {'out_dim': 128, 'arch': 'resnet8_cifar', 'pretrained': False},
 'branch_vae': {'hid_dim': 32, 'n_layers': 3, 'latent_dim': 2, 'in_dim': 7}}

### train

In [30]:
# Show the 'train' section of the config.
cfg['train']

{'n_branches': 64,
 'k': 1,
 'min_n_positives': 64,
 'temperature': 10,
 'delta_rs': 0.5,
 'rank_weight': 1.0,
 'sort_weight': 1.0,
 'bce_weight': 0.0,
 'bce_loss': 0.0,
 'vae_batch_size': 64}

In [18]:
# Powers of two for a 7-bit code, most significant bit first: 64, 32, ..., 1.
bit_mask = torch.tensor([1 << shift for shift in range(7 - 1, -1, -1)])

In [19]:
# Display `bit_mask` to check the powers of two (most significant bit first).
bit_mask

tensor([64, 32, 16,  8,  4,  2,  1])

In [11]:
# Show the 'train' section of the config.
cfg['train']

{'n_branches': 64,
 'k': 1,
 'min_n_positives': 64,
 'temperature': 10,
 'delta_rs': 0.5,
 'rank_weight': 1.0,
 'sort_weight': 1.0,
 'bce_weight': 0.0,
 'bce_loss': 0.0,
 'vae_batch_size': 64}

In [10]:
# Show the configured `n_branches` value.
cfg['train']['n_branches']

64

In [16]:
# Draw a random ordering of the candidate indices 0..63.
# `random` is imported explicitly in the notebook's import cell rather than
# relying on it leaking in through a star import.
n_candidates = 64  # single source for the population size and the sample size
b_idx = list(range(n_candidates))

# Sampling every item without replacement yields a shuffled copy.
b_idx = random.sample(b_idx, n_candidates)
b_idx = torch.LongTensor(b_idx)
b_idx

tensor([ 0, 27, 53, 19, 32, 38, 63, 50, 17,  5, 10, 25, 55, 35, 14, 34, 12, 31,
        56, 39, 13, 61, 40,  6, 45,  8, 46, 54, 11, 36, 18,  3, 26,  9, 24, 60,
        48, 51,  1,  7, 52, 16, 28, 33, 62, 20, 59, 30, 43, 41, 58, 44, 42, 22,
        15, 47, 21, 37, 57, 49, 23, 29,  4,  2])

In [25]:
# Isolate each bit of every index: broadcasting the column of indices against
# the 7-entry mask gives one row per index, holding the mask value where the
# bit is set and 0 elsewhere.
b_idx.unsqueeze(1) & bit_mask

tensor([[ 0,  0,  0,  0,  0,  0,  0],
        [ 0,  0, 16,  8,  0,  2,  1],
        [ 0, 32, 16,  0,  4,  0,  1],
        [ 0,  0, 16,  0,  0,  2,  1],
        [ 0, 32,  0,  0,  0,  0,  0],
        [ 0, 32,  0,  0,  4,  2,  0],
        [ 0, 32, 16,  8,  4,  2,  1],
        [ 0, 32, 16,  0,  0,  2,  0],
        [ 0,  0, 16,  0,  0,  0,  1],
        [ 0,  0,  0,  0,  4,  0,  1],
        [ 0,  0,  0,  8,  0,  2,  0],
        [ 0,  0, 16,  8,  0,  0,  1],
        [ 0, 32, 16,  0,  4,  2,  1],
        [ 0, 32,  0,  0,  0,  2,  1],
        [ 0,  0,  0,  8,  4,  2,  0],
        [ 0, 32,  0,  0,  0,  2,  0],
        [ 0,  0,  0,  8,  4,  0,  0],
        [ 0,  0, 16,  8,  4,  2,  1],
        [ 0, 32, 16,  8,  0,  0,  0],
        [ 0, 32,  0,  0,  4,  2,  1],
        [ 0,  0,  0,  8,  4,  0,  1],
        [ 0, 32, 16,  8,  4,  0,  1],
        [ 0, 32,  0,  8,  0,  0,  0],
        [ 0,  0,  0,  0,  4,  2,  0],
        [ 0, 32,  0,  8,  4,  0,  1],
        [ 0,  0,  0,  8,  0,  0,  0],
        [ 0,

In [27]:
# Boolean bit pattern of every index: True wherever the corresponding bit of
# the index is set.
masks = (b_idx.unsqueeze(1) & bit_mask) != 0
masks.shape

torch.Size([64, 7])

In [26]:
# Number of distinct values representable with 7 bits (the width of `bit_mask`).
2**7

128