In [1]:
import json
import os
import argparse

import numpy as np
import random
import torch
from torch import nn, optim
from tqdm.auto import tqdm
from transformers import AutoImageProcessor, AutoModelForImageClassification, AutoConfig
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from pathlib import Path
from torch.utils.data import DataLoader, Subset
import sys
sys.path.append('../lib/exlib/src')
from exlib.modules.sop import SOPImageCls, SOPConfig, get_chained_attr

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

SEED = 42
if SEED != -1:
    # Torch RNG
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    # Python RNG
    np.random.seed(SEED)
    random.seed(SEED)

In [3]:
# model paths
backbone_model_name = '../pt_models/vit-base-patch16-224-imagenet10cls'
backbone_processor_name = 'google/vit-base-patch16-224'
# sop_config_path = 'configs/imagenet_m.json'

# data paths
TRAIN_DATA_DIR = '../data/imagenet_m/train'
VAL_DATA_DIR = '../data/imagenet_m/val'

# training args
batch_size = 16
lr = 0.000005
num_epochs = 20
warmup_steps = 2000
mask_batch_size = 64

# experiment args
exp_dir = '../exps/imagenet_m'
os.makedirs(exp_dir, exist_ok=True)

In [4]:
backbone_model = AutoModelForImageClassification.from_pretrained(backbone_model_name)
processor = AutoImageProcessor.from_pretrained(backbone_processor_name)
backbone_config = AutoConfig.from_pretrained(backbone_model_name)

config = SOPConfig(
    attn_patch_size=16,
    num_heads=1,
    num_masks_sample=20,
    num_masks_max=200,
    finetune_layers=['model.classifier']
)
config.__dict__.update(backbone_config.__dict__)
config.num_labels = len(backbone_config.label2id)
# config.save_pretrained(exp_dir)

In [5]:
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

def transform(image):
    # Preprocess the image using the ViTImageProcessor
    image = image.convert("RGB")
    inputs = processor(image, return_tensors='pt')
    return inputs['pixel_values'].squeeze(0)

# Load the dataset
train_dataset = ImageFolder(root=TRAIN_DATA_DIR, transform=transform)
val_dataset = ImageFolder(root=VAL_DATA_DIR, transform=transform)

# Use subset for testing purpose
# num_data = 100
# train_dataset = Subset(train_dataset, range(num_data))
# val_dataset = Subset(val_dataset, range(num_data))

# Create a DataLoader to batch and shuffle the data
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [6]:
from collections import namedtuple

WrappedBackboneOutput = namedtuple("WrappedBackboneOutput", 
                                  ["logits",
                                   "pooler_output"])


class WrappedBackboneModel(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
    
    def forward(self, inputs):
        outputs = self.model(inputs, output_hidden_states=True)
        return WrappedBackboneOutput(outputs.logits, outputs.hidden_states[-1][:,0])

In [7]:
wrapped_backbone_model = WrappedBackboneModel(backbone_model)
wrapped_backbone_model = wrapped_backbone_model.to(device)
class_weights = get_chained_attr(wrapped_backbone_model, config.finetune_layers[0]).weight #.clone().to(device)

In [8]:
model = SOPImageCls(config, wrapped_backbone_model, class_weights=class_weights, projection_layer=None)
model = model.to(device)

deep copy class weights


In [9]:
from transformers import get_scheduler

optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
num_training_steps = len(train_dataloader) * num_epochs
lr_scheduler = get_scheduler(
            'inverse_sqrt',
            optimizer=optimizer, 
            num_warmup_steps=warmup_steps
        )
criterion = nn.CrossEntropyLoss()

In [10]:
def eval(model, dataloader, criterion):
    print('Eval ...')
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        progress_bar_eval = tqdm(range(len(dataloader)))
        for i, batch in enumerate(dataloader):
            # Now you can use `inputs` and `labels` in your training loop.
            inputs, labels = batch
            inputs, labels = inputs.to(device), labels.to(device)

            logits = model(inputs)
            
            # val loss
            loss = criterion(logits, labels)
            total_loss += loss.item()
            
            # acc
            _, predicted = torch.max(logits.data, 1)
            correct += (predicted == labels).sum().item()
            
            total += labels.size(0)
            
            progress_bar_eval.update(1)
    
    val_acc = correct / total
    val_loss = total_loss / total
    
    model.train()
    
    return {
        'val_acc': val_acc,
        'val_loss': val_loss
    }

In [None]:
import logging

track = True
# track = False

if track:
    import wandb
    wandb.init(project='sop')
    wandb.run.name = os.path.basename(exp_dir)

# Iterate over the data
best_val_acc = 0.0
step = 0
train_log_interval = 100
val_eval_interval = 1000

logging.basicConfig(filename=os.path.join(exp_dir, 'train.log'), level=logging.INFO)

progress_bar = tqdm(range(num_training_steps))
for epoch in range(num_epochs):
    running_loss = 0.0
    running_total = 0
    for i, batch in enumerate(train_dataloader):
        inputs, labels = batch
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        logits = model(inputs, mask_batch_size=mask_batch_size)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * labels.size(0)
        running_total += labels.size(0)
        
        if i % train_log_interval == train_log_interval - 1 or i == len(train_dataloader) - 1:
            # Print training loss every 100 batches
            curr_lr = float(optimizer.param_groups[0]['lr'])
            log_message = f'Epoch {epoch}, Batch {i + 1}, Loss {running_loss / running_total:.4f}, LR {curr_lr:.8f}'
            print(log_message)
            logging.info(log_message)
            if track:
                wandb.log({'train_loss': running_loss / running_total,
                        'lr': curr_lr,
                        'epoch': epoch,
                        'step': step})
            running_loss = 0.0
            running_total = 0
            
        if i % val_eval_interval == val_eval_interval - 1 or i == len(train_dataloader) - 1:
            val_results = eval(model, val_dataloader, criterion)
            val_acc = val_results['val_acc']
            val_loss = val_results['val_loss']
            log_message = f'Epoch {epoch}, Step {step}, Val acc {val_acc:.4f}, Val loss {val_loss:.4f}'
            print(log_message)
            logging.info(log_message)
            if track:
                wandb.log({'val_acc': val_acc,
                           'val_loss': val_loss,
                        'epoch': epoch,
                        'step': step})
            
            last_dir = os.path.join(exp_dir, 'last')
            best_dir = os.path.join(exp_dir, 'best')
            checkpoint = {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'step': step,
                    'val_loss': val_loss,
                    'val_acc': val_acc,
                }
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                os.makedirs(best_dir, exist_ok=True)
                best_checkpoint_path = os.path.join(best_dir, 'checkpoint.pth')
                torch.save(checkpoint, best_checkpoint_path)
                config_best_checkpoint_path = os.path.join(best_dir, 'config.json')
                config.save_to_json(config_best_checkpoint_path)
                print(f'Best checkpoint saved at {best_checkpoint_path}')
                
                # model.save_pretrained(best_dir)
            # model.save_pretrained(last_dir)
            os.makedirs(last_dir, exist_ok=True)
            last_checkpoint_path = os.path.join(last_dir, 'checkpoint.pth')
            torch.save(checkpoint, last_checkpoint_path)
            config_last_checkpoint_path = os.path.join(last_dir, 'config.json')
            config.save_to_json(config_best_checkpoint_path)
            print(f'Last checkpoint saved at {last_checkpoint_path}')
            
        lr_scheduler.step()
        progress_bar.update(1)
        
        step += 1
        
model.save(exp_dir)

[34m[1mwandb[0m: Currently logged in as: [33mfallcat[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/16260 [00:00<?, ?it/s]

  return F.conv2d(input, weight, bias, self.stride,


Epoch 0, Batch 100, Loss 1.0559, LR 0.00000025
Epoch 0, Batch 200, Loss 0.9883, LR 0.00000050
Epoch 0, Batch 300, Loss 0.7779, LR 0.00000075
Epoch 0, Batch 400, Loss 0.7004, LR 0.00000100
Epoch 0, Batch 500, Loss 0.6542, LR 0.00000125
Epoch 0, Batch 600, Loss 0.5694, LR 0.00000150
Epoch 0, Batch 700, Loss 0.5254, LR 0.00000175
Epoch 0, Batch 800, Loss 0.5095, LR 0.00000200
Epoch 0, Batch 813, Loss 0.4667, LR 0.00000203
Eval ...


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 0, Step 812, Val acc 0.8060, Val loss 0.0336
Best checkpoint saved at ../exps/imagenet_m/best/checkpoint.pth
Last checkpoint saved at ../exps/imagenet_m/last/checkpoint.pth
Epoch 1, Batch 100, Loss 0.4341, LR 0.00000228
Epoch 1, Batch 200, Loss 0.4446, LR 0.00000253
Epoch 1, Batch 300, Loss 0.3710, LR 0.00000278
Epoch 1, Batch 400, Loss 0.3719, LR 0.00000303
Epoch 1, Batch 500, Loss 0.3902, LR 0.00000328
Epoch 1, Batch 600, Loss 0.4610, LR 0.00000353
Epoch 1, Batch 700, Loss 0.3910, LR 0.00000378
Epoch 1, Batch 800, Loss 0.4060, LR 0.00000403
Epoch 1, Batch 813, Loss 0.3106, LR 0.00000406
Eval ...


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 1, Step 1625, Val acc 0.8560, Val loss 0.0271
Best checkpoint saved at ../exps/imagenet_m/best/checkpoint.pth
Last checkpoint saved at ../exps/imagenet_m/last/checkpoint.pth
Epoch 2, Batch 100, Loss 0.3843, LR 0.00000431
Epoch 2, Batch 200, Loss 0.3632, LR 0.00000456
Epoch 2, Batch 300, Loss 0.3719, LR 0.00000481
Epoch 2, Batch 400, Loss 0.4001, LR 0.00000497
Epoch 2, Batch 500, Loss 0.4232, LR 0.00000485
Epoch 2, Batch 600, Loss 0.3527, LR 0.00000474
Epoch 2, Batch 700, Loss 0.3636, LR 0.00000464
Epoch 2, Batch 800, Loss 0.3434, LR 0.00000454
Epoch 2, Batch 813, Loss 0.4273, LR 0.00000453
Eval ...


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 2, Step 2438, Val acc 0.8680, Val loss 0.0263
Best checkpoint saved at ../exps/imagenet_m/best/checkpoint.pth
Last checkpoint saved at ../exps/imagenet_m/last/checkpoint.pth
Epoch 3, Batch 100, Loss 0.3837, LR 0.00000444
Epoch 3, Batch 200, Loss 0.3694, LR 0.00000435
Epoch 3, Batch 300, Loss 0.3461, LR 0.00000427
Epoch 3, Batch 400, Loss 0.3209, LR 0.00000420
Epoch 3, Batch 500, Loss 0.3450, LR 0.00000413
Epoch 3, Batch 600, Loss 0.3247, LR 0.00000406
Epoch 3, Batch 700, Loss 0.3175, LR 0.00000399
Epoch 3, Batch 800, Loss 0.3763, LR 0.00000393
Epoch 3, Batch 813, Loss 0.3376, LR 0.00000392
Eval ...


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 3, Step 3251, Val acc 0.8760, Val loss 0.0215
Best checkpoint saved at ../exps/imagenet_m/best/checkpoint.pth
Last checkpoint saved at ../exps/imagenet_m/last/checkpoint.pth
Epoch 4, Batch 100, Loss 0.3173, LR 0.00000386
Epoch 4, Batch 200, Loss 0.3370, LR 0.00000381
Epoch 4, Batch 300, Loss 0.3347, LR 0.00000375
Epoch 4, Batch 400, Loss 0.3429, LR 0.00000370
Epoch 4, Batch 500, Loss 0.3559, LR 0.00000365
Epoch 4, Batch 600, Loss 0.3749, LR 0.00000360
Epoch 4, Batch 700, Loss 0.3441, LR 0.00000356
Epoch 4, Batch 800, Loss 0.2651, LR 0.00000351
Epoch 4, Batch 813, Loss 0.2991, LR 0.00000351
Eval ...


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 4, Step 4064, Val acc 0.8900, Val loss 0.0211
Best checkpoint saved at ../exps/imagenet_m/best/checkpoint.pth
Last checkpoint saved at ../exps/imagenet_m/last/checkpoint.pth
Epoch 5, Batch 100, Loss 0.3020, LR 0.00000347
Epoch 5, Batch 200, Loss 0.3595, LR 0.00000342
Epoch 5, Batch 300, Loss 0.2984, LR 0.00000338
Epoch 5, Batch 400, Loss 0.3193, LR 0.00000335
Epoch 5, Batch 500, Loss 0.3058, LR 0.00000331
Epoch 5, Batch 600, Loss 0.2986, LR 0.00000327
Epoch 5, Batch 700, Loss 0.2836, LR 0.00000324
Epoch 5, Batch 800, Loss 0.3101, LR 0.00000321
Epoch 5, Batch 813, Loss 0.3845, LR 0.00000320
Eval ...


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 5, Step 4877, Val acc 0.8860, Val loss 0.0216
Last checkpoint saved at ../exps/imagenet_m/last/checkpoint.pth
Epoch 6, Batch 100, Loss 0.3318, LR 0.00000317
Epoch 6, Batch 200, Loss 0.3022, LR 0.00000314
Epoch 6, Batch 300, Loss 0.2911, LR 0.00000311
Epoch 6, Batch 400, Loss 0.3251, LR 0.00000308
Epoch 6, Batch 500, Loss 0.3230, LR 0.00000305
Epoch 6, Batch 600, Loss 0.3608, LR 0.00000302
Epoch 6, Batch 700, Loss 0.3706, LR 0.00000299
Epoch 6, Batch 800, Loss 0.3164, LR 0.00000297
Epoch 6, Batch 813, Loss 0.2335, LR 0.00000296
Eval ...


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 6, Step 5690, Val acc 0.8880, Val loss 0.0212
Last checkpoint saved at ../exps/imagenet_m/last/checkpoint.pth
Epoch 7, Batch 100, Loss 0.3260, LR 0.00000294
Epoch 7, Batch 200, Loss 0.2781, LR 0.00000291
Epoch 7, Batch 300, Loss 0.2744, LR 0.00000289
Epoch 7, Batch 400, Loss 0.2884, LR 0.00000287
Epoch 7, Batch 500, Loss 0.3076, LR 0.00000284
Epoch 7, Batch 600, Loss 0.2894, LR 0.00000282
Epoch 7, Batch 700, Loss 0.2754, LR 0.00000280
Epoch 7, Batch 800, Loss 0.2740, LR 0.00000278
Epoch 7, Batch 813, Loss 0.2405, LR 0.00000277
Eval ...


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch 7, Step 6503, Val acc 0.8900, Val loss 0.0211
Last checkpoint saved at ../exps/imagenet_m/last/checkpoint.pth
