### 数据集加载

In [None]:
from datasets import load_dataset
from dataclasses import dataclass
from bias_tuning.datasets import build_transform 

from functools import partial

In [2]:
@dataclass
class DatasetConfig:
    input_size: int = 224
    color_jitter: float = 0.3
    aa: str = 'rand-m9-mstd0.5-inc1'
    train_interpolation: str = 'bicubic'
    reprob: float = 0.25
    remode: str = 'pixel'
    recount: int = 1
    eval_crop_ratio: float = 0.875

data_config = DatasetConfig()

In [35]:
from transformers import ViTImageProcessor

processor = ViTImageProcessor.from_pretrained('./vit-base-patch16-224')

def process_example(example):
    if example['image'].mode != 'RGB':
        example['image'] = example['image'].convert('RGB')
    inputs = processor(example['image'], return_tensors='pt')
    inputs['labels'] = example['label']
    return inputs

trainset = load_dataset('/data/jc/dataset/imagenet-1k', split='train', streaming=True)
valset = load_dataset('/data/jc/dataset/imagenet-1k', split='validation', streaming=True)

prepared_trainset = trainset.map(process_example)
prepared_valset = valset.map(process_example)

### 模型搭建

In [None]:
import low_rank
from transformers import ViTForImageClassification

model = ViTForImageClassification.from_pretrained('./vit-base-patch16-224')

count_params = sum(p.numel() for p in model.parameters())
model_lr = low_rank.ModuleLowRank(compress_ratio=2, 
                                name_omit=['norm', 'head', 'patch_embed', 'downsample'],
                                is_approximate=True)
model = model_lr(model)
count_lr_params = sum(p.numel() for p in model.parameters())

print(f'Original model params: {count_params}, Low rank model params: {count_lr_params}')

# only set bias to be optimized
for name, param in model.named_parameters():
    if 'bias' in name:
        param.requires_grad = True
    else:
        param.requires_grad = False

count_learn_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Learnable params: {count_learn_params}')

### Train

In [47]:
import torch 

def collect_fn(batch):
    batch = {
        'pixel_values': torch.cat([x['pixel_values'] for x in batch], dim=0),
        'labels': torch.tensor([x['labels'] for x in batch])
    }
    return batch

In [None]:
import numpy as np 
import evaluate

metric = evaluate.load('accuracy')

def compulate_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=1)
    return metric.compute(predictions=preds, references=labels)

In [50]:
from transformers import Trainer
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir='./results',          # output directory
    # train epochs and batch size
    num_train_epochs=50,              # total number of training epochs
    per_device_train_batch_size=128,  # batch size per device during training
    per_device_eval_batch_size=128,   # batch size for evaluation
    max_steps= int(50*1.2e7/128),

    # learning rate and warmup steps
    learning_rate=5e-4,               # initial learning rate
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    
    # logging
    logging_dir='./logs',            # directory for storing logs
    logging_steps=10,
    report_to='tensorboard',

    # evaluation
    eval_steps=1000,
    evaluation_strategy='steps',
    save_strategy='steps',
    save_total_limit=3,
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,                         # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=prepared_trainset,         # training dataset
    eval_dataset=prepared_valset,             # evaluation dataset
    tokenizer=processor,
    data_collator=collect_fn,
    compute_metrics=compulate_metrics,
)

trainer.train()
trainer.evaluate()

max_steps is given, it will override any value given in num_train_epochs


Step,Training Loss,Validation Loss,Accuracy
5,No log,2.613436,0.52182
