In [None]:
#!/usr/bin/env python3
""" ImageNet Training Script

This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet
training results with some of the latest networks and training techniques. It favours canonical PyTorch
and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed
and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit.

This script was started from an early version of the PyTorch ImageNet example
(https://github.com/pytorch/examples/tree/master/imagenet)

NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
(https://github.com/NVIDIA/apex/tree/master/examples/imagenet)

Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
import sys
sys.path.append('/cephfs/juxin/QwT/QwT-cls-RepQ-ViT')
import argparse
import copy
import random
import socket
from contextlib import suppress
from functools import partial
from tqdm import tqdm

import torch.distributed
import torch.distributed as dist
import torch.utils.data
from timm.data import Mixup
from timm.data.dataset import ImageDataset
from timm.loss import SoftTargetCrossEntropy
from timm.utils import random_seed, NativeScaler, accuracy
from torch.amp import autocast as amp_autocast
from torch.nn.parallel import DistributedDataParallel as NativeDDP
from timm.scheduler.scheduler_factory import CosineLRScheduler
from torch.utils.data import Dataset

from quant import *
from utils import *
from utils.utils import write, create_transform, create_loader, AverageMeter, broadcast_tensor_from_main_process, gather_tensor_from_multi_processes, compute_quantized_params

HOST_NAME = socket.getfqdn(socket.gethostname())

torch.backends.cudnn.benchmark = True
LINEAR_COMPENSATION_SAMPLES = 128

def seed(seed=0):
    """Seed all random number generators and put cuDNN in deterministic mode.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU plus every visible CUDA device).

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Cover every visible GPU rather than only the current device; this call
    # is silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # cuDNN autotuning can select a different kernel on each run, which
    # defeats the determinism requested on the next line, so it is disabled.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

class CompensationBlock(nn.Module):
    """Wrap a block and add a learnable linear correction ``x @ W + b`` to its output.

    Args:
        W: 2-D weight used both for the parameter shape and, optionally, as
            the initial value of the correction weight.
        b: Bias used as the optional initial value of the correction bias.
        r2_score: Fit quality of ``W``/``b``; the linear initialisation is
            only applied when it is greater than zero.
        block: The wrapped module whose output is corrected.
        linear_init: When false the correction always starts at zero.
        local_rank: Only rank 0 logs which initialisation was chosen.
        block_id: Identifier used in the log message.
    """

    def __init__(self, W, b, r2_score, block, linear_init=True, local_rank=0, block_id=None):
        super().__init__()
        self.block = block

        in_features, out_features = W.size(0), W.size(1)
        self.lora_weight = nn.Parameter(torch.zeros(in_features, out_features))
        self.lora_bias = nn.Parameter(torch.zeros(out_features))

        use_linear_init = linear_init and (r2_score > 0)
        if use_linear_init:
            # Start from the supplied regression solution.
            self.lora_weight.data.copy_(W)
            self.lora_bias.data.copy_(b)
            init_message = 'block {} using linear init'.format(block_id)
        else:
            # Start as an identity wrapper: the correction contributes nothing.
            nn.init.zeros_(self.lora_weight)
            nn.init.zeros_(self.lora_bias)
            init_message = 'block {} using lora init'.format(block_id)

        if local_rank == 0:
            _write(init_message)

    def forward(self, x):
        """Return ``block(x)`` plus the linear correction computed from ``x``."""
        block_out = self.block(x)
        if not self.training:
            # QwT layers run in half mode
            correction = (x.half() @ self.lora_weight.half()).float()
        else:
            correction = x @ self.lora_weight.float()
        return block_out + correction + self.lora_bias

def enable_quant(submodel):
    """Call ``set_quant_state(True, True)`` on every quantized layer in ``submodel``.

    Only ``QuantConv2d``, ``QuantLinear`` and ``QuantMatMul`` modules are touched.
    NOTE(review): the two flags are presumably the input/weight quantization
    switches -- confirm against the ``quant`` package.
    """
    quant_layer_types = (QuantConv2d, QuantLinear, QuantMatMul)
    for layer in submodel.modules():
        if isinstance(layer, quant_layer_types):
            layer.set_quant_state(True, True)

def disable_quant(submodel):
    """Call ``set_quant_state(False, False)`` on every quantized layer in ``submodel``.

    Only ``QuantConv2d``, ``QuantLinear`` and ``QuantMatMul`` modules are touched.
    NOTE(review): the two flags are presumably the input/weight quantization
    switches -- confirm against the ``quant`` package.
    """
    quant_layer_types = (QuantConv2d, QuantLinear, QuantMatMul)
    for layer in submodel.modules():
        if isinstance(layer, quant_layer_types):
            layer.set_quant_state(False, False)

class FeatureDataset(Dataset):
    """Thin ``Dataset`` view over an indexable container of features.

    The container is exposed as the public attribute ``X`` so callers can
    swap in new features without rebuilding the dataset.
    """

    def __init__(self, X):
        # Any object supporting len() and integer indexing works here.
        self.X = X

    def __len__(self) -> int:
        """Return the number of stored feature entries."""
        features = self.X
        return len(features)

    def __getitem__(self, item):
        """Return the feature entry at position ``item``."""
        features = self.X
        return features[item]

def lienar_regression(X, Y, block_id=0):
    """Fit ``Y ~= X @ W + b`` by ordinary least squares over all processes.

    ``X`` and ``Y`` are flattened to 2-D over their last dimension and then
    gathered with ``gather_tensor_from_multi_processes`` (using the global
    ``args.world_size``) before the fit.

    Args:
        X: Input features; the last dimension is the feature dimension.
        Y: Regression targets; the last dimension is the output dimension.
        block_id: Not used by the computation; kept for call-site compatibility.

    Returns:
        Tuple ``(W, b, r2_score)``: the weight matrix, the bias vector and the
        coefficient of determination of the fit on the gathered data.
    """
    X = X.reshape(-1, X.size(-1))
    X = gather_tensor_from_multi_processes(X, world_size=args.world_size)

    Y = Y.reshape(-1, Y.size(-1))
    Y = gather_tensor_from_multi_processes(Y, world_size=args.world_size)

    # Append a constant-one column so the bias is learned as the last weight row.
    ones_column = torch.ones(size=[X.size(0), ], device=X.device).reshape(-1, 1)
    X_add_one = torch.cat([X, ones_column], dim=-1)
    X_add_one_T = X_add_one.t()

    # Solve the normal equations (X^T X) W = X^T Y directly instead of forming
    # the explicit inverse: this is numerically more stable and avoids the
    # large (features x samples) intermediate product.
    W_overall = torch.linalg.solve(X_add_one_T @ X_add_one, X_add_one_T @ Y)

    W = W_overall[:-1, :]
    b = W_overall[-1, :]

    Y_pred = X @ W + b

    # R^2 = 1 - SS_res / SS_tot. When Y has zero variance SS_tot is 0 and the
    # score is NaN or -inf, so a ``r2_score > 0`` check treats it as a failed fit.
    ss_tot = torch.sum((Y - Y.mean(dim=0)).pow(2))
    ss_res = torch.sum((Y - Y_pred).pow(2))
    r2_score = 1 - ss_res / ss_tot

    return W, b, r2_score



# Command-line options. parse_args() is given an empty list below, so the real
# command line is ignored and every option takes its default value.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--model",
    default="vit_small",
    choices=[
        'vit_small',
        'vit_base',
        'deit_tiny',
        'deit_small',
        'deit_base',
        'deit_tiny_distilled',
        'deit_small_distilled',
        'deit_base_distilled',
    ],
    help="model",
)
parser.add_argument('--data_dir', default='../ImageNet', type=str)

# The remaining options are all integers; they are registered in a fixed order
# so the generated --help listing stays the same.
for _flag, _default, _help in (
    ('--w_bits', 4, 'bit-precision of weights'),
    ('--a_bits', 4, 'bit-precision of activation'),
    ('--start_block', 0, None),
    ("--batch_size", 32, "batchsize of validation set"),
    ('--num_workers', 4, None),
    ("--seed", 0, "seed"),
    ("--local-rank", 0, None),
):
    parser.add_argument(_flag, default=_default, type=int, help=_help)

args = parser.parse_args(args=[])

# Augmentation recipe names handed to create_transform().
train_aug, test_aug = 'large_scale_train', 'large_scale_test'
args.drop_path = 0.0
args.num_classes = 1000

# Per model family: (normalization mean, normalization std, crop_pct).
_PREPROCESSING = {
    'deit': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), 0.875),
    'vit': ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), 0.9),
    'swin': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), 0.9),
}

# The family is the part of the model name before the first underscore.
model_type = args.model.split("_")[0]
if model_type not in _PREPROCESSING:
    raise NotImplementedError
mean, std, crop_pct = _PREPROCESSING[model_type]

# Distributed mode is enabled only when WORLD_SIZE is set to a value above 1.
args.distributed = int(os.environ.get('WORLD_SIZE', 1)) > 1

# Single-process defaults; replaced below when running distributed.
args.device, args.world_size = 'cuda:0', 1
args.rank = 0  # global rank
if args.distributed:
    # Bind this process to its own GPU, then join the process group.
    torch.cuda.set_device(args.local_rank)
    args.device = 'cuda:%d' % args.local_rank
    dist.init_process_group(backend='nccl', init_method='env://')
    args.world_size = dist.get_world_size()
    args.rank = dist.get_rank()

assert args.rank >= 0


# The log directory name encodes the run's main hyper-parameters.
args.log_dir = os.path.join('checkpoint', args.model, 'QwT', 'bs_{}_worldsize_{}_w_{}_a_{}_startblock_{}_sed_{}'.format(args.batch_size, args.world_size, args.w_bits, args.a_bits, args.start_block, args.seed))

args.log_file = os.path.join(args.log_dir, 'log.txt')


if args.local_rank == 0:
    # exist_ok avoids a crash when another process creates the directory
    # between an existence check and the creation.
    os.makedirs(args.log_dir, exist_ok=True)

    # Start every run with a fresh log file.
    if os.path.isfile(args.log_file):
        os.remove(args.log_file)

torch.cuda.synchronize()
if args.distributed:
    # torch.cuda.synchronize() only waits for this process's own GPU work.
    # A barrier is required so that no rank logs before the log directory
    # has been prepared and the stale log file removed.
    dist.barrier()

# Logging helper bound to this run's log file.
_write = partial(write, log_file=args.log_file)

if args.distributed:
    _write('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' % (args.rank, args.world_size))
else:
    _write('Training with a single process on 1 GPUs.')
assert args.rank >= 0



if args.local_rank == 0:
    _write(args)

seed(args.seed)

if args.local_rank == 0:
    _write('dataset mean : {} & std : {}'.format(mean, std))

# Train and validation splits live in sibling folders under args.data_dir.
dataset_train = ImageDataset(
    root=os.path.join(args.data_dir, 'train'),
    transform=create_transform(train_aug, mean, std, crop_pct),
)
dataset_eval = ImageDataset(
    root=os.path.join(args.data_dir, 'val'),
    transform=create_transform(test_aug, mean, std, crop_pct),
)

if args.local_rank == 0:
    _write('len of train_set : {}    train_transform : {}'.format(len(dataset_train), dataset_train.transform))
    _write('len of eval_set : {}    eval_transform : {}'.format(len(dataset_eval), dataset_eval.transform))

# Settings shared by both loaders; only the training flag and drop_last differ.
_shared_loader_kwargs = dict(
    batch_size=args.batch_size,
    re_prob=0.0,
    mean=mean,
    std=std,
    num_workers=args.num_workers,
    distributed=args.distributed,
    log_file=args.log_file,
    local_rank=args.local_rank,
    persistent_workers=False,
)
loader_train = create_loader(dataset_train, is_training=True, drop_last=True, **_shared_loader_kwargs)
loader_eval = create_loader(dataset_eval, is_training=False, drop_last=False, **_shared_loader_kwargs)

# Map each short model alias to its full model name; every full name is the
# alias followed by the same patch/resolution suffix.
model_zoo = {
    alias: f'{alias}_patch16_224'
    for alias in (
        'vit_small',
        'vit_base',
        'deit_tiny',
        'deit_tiny_distilled',
        'deit_small',
        'deit_small_distilled',
        'deit_base',
        'deit_base_distilled',
    )
}
#Quant using RepQ-ViT
_write('Building model ...')
model = build_model(model_zoo[args.model], args)
model.to(args.device)
# The string literal below is an inert expression statement that holds a
# disabled baseline-evaluation / quantization snippet; it is never executed.
'''
model.eval()
top1_acc_eval = validate(model, loader_eval)
_write('base eval_acc: {:.2f}'.format(top1_acc_eval.avg))
with open(f'log/{args.model}/acc.log', 'a') as f:
    f.writelines(f'base: {top1_acc_eval.avg}\n')
base_model = copy.deepcopy(model)
wq_params = {'n_bits': args.w_bits, 'channel_wise': True}
aq_params = {'n_bits': args.a_bits, 'channel_wise': False}
# q_model = quant_model(model, input_quant_params=aq_params, weight_quant_params=wq_params)
'''
# q_model is the same object as model, not a copy.
q_model = model
q_model.to(args.device)
q_model.eval()

# Dump the module structure to a text file for inspection.
os.makedirs(f'log/{args.model}', exist_ok=True)
with open(f'log/{args.model}/structure_quant.txt', 'w') as f:
    print(q_model, file=f, end='')


In [21]:
@torch.no_grad()
def generate_compensation_model(q_model, train_loader, args):
    """Wrap every block of ``q_model`` in a ``CompensationBlock``, one block at a time.

    Steps:
      1. Run ``q_model.forward_before_blocks`` on the first few batches of
         ``train_loader`` and keep the outputs as inputs of the first block.
      2. For each block, run it once with quantization disabled and once with
         quantization enabled, then fit a linear regression from the block
         input to the difference of the two outputs.
      3. Replace the block with a ``CompensationBlock`` built from that fit
         and recompute the block outputs; they become the next block's inputs.

    Debug information is written to the module-level file object ``f``, which
    must be open for writing while this function runs.

    Args:
        q_model: Model exposing ``forward_before_blocks`` and ``blocks``.
        train_loader: Iterable yielding ``(image, target)`` batches.
        args: Namespace providing ``device``, ``batch_size``, ``world_size``,
            ``num_workers``, ``start_block`` and ``local_rank``.

    Returns:
        The same ``q_model`` object with its blocks replaced in place.
    """
    _write('start to generate compensation model')

    # Index of the last batch consumed by each loop below. At least one batch
    # is always processed because the check runs after the batch is handled.
    last_batch_idx = LINEAR_COMPENSATION_SAMPLES // args.batch_size // args.world_size - 1

    torch.cuda.synchronize()
    output_t = torch.zeros(size=[0,], device=args.device)
    f.write(f'LINEAR_COMPENSATION_SAMPLES: {LINEAR_COMPENSATION_SAMPLES}\n')
    for i, (image, _) in tqdm(enumerate(train_loader)):
        f.write(f'image.shape: {image.shape}\t')
        image = image.cuda()
        t_out = q_model.forward_before_blocks(image)
        f.write(f't_out.shape: {t_out.shape}\n')
        output_t = torch.cat([output_t, t_out.detach()], dim=0)
        torch.cuda.synchronize()
        if i >= last_batch_idx:
            break

    f.write(f'output_t.shape: {output_t.shape}\n')
    # The dataset's contents are swapped per block so one loader can be reused.
    feature_set = FeatureDataset(output_t.detach().cpu())
    feature_loader = torch.utils.data.DataLoader(feature_set, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    output_previous = output_t
    for block_id in range(len(q_model.blocks)):

        # Inputs of this block are the compensated outputs of the previous one.
        feature_set.X = output_previous.detach().cpu()

        block = q_model.blocks[block_id]
        output_full_precision = torch.zeros(size=[0, ], device=args.device)
        output_quant = torch.zeros(size=[0, ], device=args.device)
        output_t_ = torch.zeros(size=[0, ], device=args.device)
        for i, t_out in tqdm(enumerate(feature_loader)):
            t_out = t_out.cuda()
            if block_id == 0:
                # Check the loader returns the collected features unchanged.
                # The slice uses the real batch size instead of a fixed 32.
                lo = i * args.batch_size
                hi = lo + args.batch_size
                f.write(f't_out=output_t[i*32:i*32+32,:,:]: {output_t[lo:hi, :, :].equal(t_out)}\n')
            disable_quant(block)
            full_precision_out = block(t_out)
            if block_id == 0:
                f.write(f't_out.shape: {t_out.shape}\t')
                f.write(f'full_precision_out.shape: {full_precision_out.shape}\n')

            enable_quant(block)
            quant_out = block(t_out)

            output_t_ = torch.cat([output_t_, t_out.detach()], dim=0)
            output_full_precision = torch.cat([output_full_precision, full_precision_out.detach()], dim=0)
            output_quant = torch.cat([output_quant, quant_out.detach()], dim=0)

            torch.cuda.synchronize()
            if i >= last_batch_idx:
                break

        # The features read back through the loader must match what was stored.
        assert torch.sum((output_previous - output_t_).abs()) < 1e-3
        if block_id == 0:
            f.write(f'output_t_.shape: {output_t_.shape} {output_full_precision.shape}\n')
        # Regress the quantization error (full precision minus quantized) on the block input.
        W, b, r2_score = lienar_regression(output_t_, output_full_precision - output_quant, block_id=block_id)
        q_model.blocks[block_id] = CompensationBlock(W=W, b=b, r2_score=r2_score, block=q_model.blocks[block_id], linear_init=block_id >= args.start_block, local_rank=args.local_rank, block_id=block_id)
        q_model.cuda()

        qwerty_block = q_model.blocks[block_id]

        # Recompute this block's outputs with compensation applied.
        output_previous = torch.zeros(size=[0, ], device=args.device)
        for i, t_out in tqdm(enumerate(feature_loader)):
            t_out = t_out.cuda()
            enable_quant(qwerty_block)
            previous_out = qwerty_block(t_out)

            output_previous = torch.cat([output_previous, previous_out.detach()], dim=0)

            torch.cuda.synchronize()
            if i >= last_batch_idx:
                break

    return q_model


# The context manager closes the debug log even if generation raises.
with open('log/123.log', 'w') as f:
    q_model = generate_compensation_model(q_model, loader_train, args)

start to generate compensation model


3it [00:01,  1.52it/s]
3it [00:00, 34.76it/s]

block 0 using lora init



3it [00:00, 44.65it/s]
3it [00:00, 37.37it/s]

block 1 using lora init



3it [00:00, 37.29it/s]
3it [00:00, 34.06it/s]

block 2 using lora init



3it [00:00, 42.85it/s]
3it [00:00, 38.99it/s]


block 3 using lora init


3it [00:00, 52.43it/s]
3it [00:00, 30.45it/s]


block 4 using lora init


3it [00:00, 44.64it/s]
3it [00:00, 37.75it/s]


block 5 using lora init


3it [00:00, 47.89it/s]
3it [00:00, 38.64it/s]

block 6 using lora init



3it [00:00, 38.21it/s]
3it [00:00, 38.16it/s]


block 7 using lora init


3it [00:00, 48.93it/s]
3it [00:00, 37.51it/s]


block 8 using lora init


3it [00:00, 29.97it/s]
3it [00:00, 32.12it/s]


block 9 using lora init


3it [00:00, 30.86it/s]
3it [00:00, 37.53it/s]


block 10 using lora init


3it [00:00, 43.35it/s]
3it [00:00, 38.33it/s]


block 11 using lora init


3it [00:00, 52.69it/s]
