In [1]:
import json
import os
import argparse

import numpy as np
import random
import torch
from torch import nn, optim
from tqdm.auto import tqdm
from transformers import AutoImageProcessor, AutoModelForImageClassification, AutoConfig
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from pathlib import Path
from torch.utils.data import DataLoader, Subset
import sys
sys.path.append('../lib/exlib/src')
from exlib.modules.sop import SOPImageCls, SOPConfig, get_chained_attr

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

SEED = 42
if SEED != -1:
    # Torch RNG
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    # Python RNG
    np.random.seed(SEED)
    random.seed(SEED)

In [3]:
# model paths
backbone_model_name = '../pt_models/vit-base-patch16-224-imagenet10cls'
backbone_processor_name = 'google/vit-base-patch16-224'
# sop_config_path = 'configs/imagenet_m.json'

# data paths
TRAIN_DATA_DIR = '../data/imagenet_m/train'
VAL_DATA_DIR = '../data/imagenet_m/val'

# training args
batch_size = 16
lr = 0.00001
num_epochs = 20
warmup_steps = 20
mask_batch_size = 64

# experiment args
exp_dir = '../exps/imagenet_m_2h_debug'
os.makedirs(exp_dir, exist_ok=True)

In [4]:
backbone_model = AutoModelForImageClassification.from_pretrained(backbone_model_name)
processor = AutoImageProcessor.from_pretrained(backbone_processor_name)
backbone_config = AutoConfig.from_pretrained(backbone_model_name)

config = SOPConfig(
    attn_patch_size=16,
    num_heads=2,
    num_masks_sample=20,
    num_masks_max=200,
    finetune_layers=['model.classifier']
)
config.__dict__.update(backbone_config.__dict__)
config.num_labels = len(backbone_config.label2id)
# config.save_pretrained(exp_dir)

In [5]:
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

def transform(image):
    # Preprocess the image using the ViTImageProcessor
    image = image.convert("RGB")
    inputs = processor(image, return_tensors='pt')
    return inputs['pixel_values'].squeeze(0)

# Load the dataset
train_dataset = ImageFolder(root=TRAIN_DATA_DIR, transform=transform)
val_dataset = ImageFolder(root=VAL_DATA_DIR, transform=transform)

# Use subset for testing purpose
num_data = 100
train_dataset = Subset(train_dataset, range(num_data))
val_dataset = Subset(val_dataset, range(num_data))

# Create a DataLoader to batch and shuffle the data
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [6]:
from collections import namedtuple

WrappedBackboneOutput = namedtuple("WrappedBackboneOutput", 
                                  ["logits",
                                   "pooler_output"])


class WrappedBackboneModel(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
    
    def forward(self, inputs):
        outputs = self.model(inputs, output_hidden_states=True)
        return WrappedBackboneOutput(outputs.logits, outputs.hidden_states[-1][:,0])

In [7]:
wrapped_backbone_model = WrappedBackboneModel(backbone_model)
wrapped_backbone_model = wrapped_backbone_model.to(device)
class_weights = get_chained_attr(wrapped_backbone_model, config.finetune_layers[0]).weight #.clone().to(device)

In [8]:
model = SOPImageCls(config, wrapped_backbone_model, class_weights=class_weights, projection_layer=None)
model = model.to(device)

In [9]:
model.class_weights

Parameter containing:
tensor([[-0.0185,  0.0317, -0.0175,  ..., -0.0125,  0.0481, -0.0027],
        [-0.0412,  0.0001,  0.0059,  ...,  0.0216,  0.0190, -0.0097],
        [-0.0507,  0.0458, -0.0032,  ..., -0.0404,  0.0268,  0.0041],
        ...,
        [ 0.0043, -0.0497,  0.0034,  ...,  0.0272, -0.0130,  0.0116],
        [-0.0034, -0.0683,  0.0002,  ...,  0.0356, -0.0020,  0.0197],
        [-0.0093,  0.0100, -0.0089,  ...,  0.0260,  0.0024, -0.0216]],
       device='cuda:0', requires_grad=True)

In [10]:
from transformers import get_scheduler

optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
num_training_steps = len(train_dataloader) * num_epochs
lr_scheduler = get_scheduler(
            'inverse_sqrt',
            optimizer=optimizer, 
            num_warmup_steps=warmup_steps
        )
criterion = nn.CrossEntropyLoss()

In [11]:
def eval(model, dataloader, criterion):
    print('Eval ...')
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    total_nnz = 0
    total_num_masks = 0
    with torch.no_grad():
        progress_bar_eval = tqdm(range(len(dataloader)))
        for i, batch in enumerate(dataloader):
            # Now you can use `inputs` and `labels` in your training loop.
            inputs, labels = batch
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs, return_tuple=True)
            
            logits = outputs.logits
            
            for i in range(len(logits)):
                pred = logits[i].argmax(-1).item()

                pred_mask_idxs_sort = outputs.mask_weights[i,:,pred].argsort(descending=True)
                mask_weights_sort = (outputs.mask_weights * outputs.logits_all)[i,pred_mask_idxs_sort,pred]
                masks_sort = outputs.masks[0,pred_mask_idxs_sort]
                masks_sort_used = (masks_sort[mask_weights_sort > 0] > masks_sort[mask_weights_sort > 0].mean()).int()
                mask_weights_sort_used = mask_weights_sort[mask_weights_sort > 0]
                nnz = (masks_sort[mask_weights_sort > 0] > 0).sum() / masks_sort[mask_weights_sort > 0].view(-1).shape[0]
                total_nnz += nnz
                total_num_masks += len(masks_sort_used)
                
                # import pdb; pdb.set_trace()
            
            # val loss
            
            loss = criterion(logits, labels)
            total_loss += loss.item()
            
            # acc
            _, predicted = torch.max(logits.data, 1)
            correct += (predicted == labels).sum().item()
            
            total += labels.size(0)
            
            progress_bar_eval.update(1)
    
    val_acc = correct / total
    val_loss = total_loss / total
    val_nnz = total_nnz / total
    val_n_masks_avg = total_num_masks / total
    
    results = {
        'val_acc': val_acc,
        'val_loss': val_loss,
        'val_nnz': val_nnz,
        'val_n_masks_avg': val_n_masks_avg
    }
    
    print(results)
    
    model.train()
    
    return results

In [12]:
import logging

# track = True
track = False

if track:
    import wandb
    wandb.init(project='sop')
    wandb.run.name = os.path.basename(exp_dir)

# Iterate over the data
best_val_acc = 0.0
step = 0
train_log_interval = 10
val_eval_interval = 100

print('lr', lr)

logging.basicConfig(filename=os.path.join(exp_dir, 'train.log'), level=logging.INFO)

progress_bar = tqdm(range(num_training_steps))
print(num_training_steps)
for epoch in range(num_epochs):
    running_loss = 0.0
    running_total = 0
    print('train_dataloader', len(train_dataloader))
    for i, batch in enumerate(train_dataloader):
        inputs, labels = batch
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        logits = model(inputs, mask_batch_size=mask_batch_size)
        loss = criterion(logits, labels)
        
        # make_dot(loss.mean(), params=dict(model.named_parameters()))
        # break
        
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * labels.size(0)
        running_total += labels.size(0)
        
        if i % train_log_interval == train_log_interval - 1 or i == len(train_dataloader) - 1:
            # Print training loss every 100 batches
            curr_lr = float(optimizer.param_groups[0]['lr'])
            log_message = f'Epoch {epoch}, Batch {i + 1}, Loss {running_loss / running_total:.4f}, LR {curr_lr:.8f}'
            print(log_message)
            logging.info(log_message)
            if track:
                wandb.log({'train_loss': running_loss / running_total,
                        'lr': curr_lr,
                        'epoch': epoch,
                        'step': step})
            running_loss = 0.0
            running_total = 0
            
        if i % val_eval_interval == val_eval_interval - 1 or i == len(train_dataloader) - 1:
            val_results = eval(model, val_dataloader, criterion)
            val_acc = val_results['val_acc']
            val_loss = val_results['val_loss']
            log_message = f'Epoch {epoch}, Step {step}, Val acc {val_acc:.4f}, Val loss {val_loss:.4f}'
            print(log_message)
            logging.info(log_message)
            if track:
                wandb.log({'val_acc': val_acc,
                           'val_loss': val_loss,
                        'epoch': epoch,
                        'step': step})
            
            last_dir = os.path.join(exp_dir, 'last')
            best_dir = os.path.join(exp_dir, 'best')
            checkpoint = {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'step': step,
                    'val_loss': val_loss,
                    'val_acc': val_acc,
                }
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                os.makedirs(best_dir, exist_ok=True)
                best_checkpoint_path = os.path.join(best_dir, 'checkpoint.pth')
                torch.save(checkpoint, best_checkpoint_path)
                config_best_checkpoint_path = os.path.join(best_dir, 'config.json')
                config.save_to_json(config_best_checkpoint_path)
                print(f'Best checkpoint saved at {best_checkpoint_path}')
                
                # model.save_pretrained(best_dir)
            # model.save_pretrained(last_dir)
            os.makedirs(last_dir, exist_ok=True)
            last_checkpoint_path = os.path.join(last_dir, 'checkpoint.pth')
            torch.save(checkpoint, last_checkpoint_path)
            config_last_checkpoint_path = os.path.join(last_dir, 'config.json')
            config.save_to_json(config_best_checkpoint_path)
            print(f'Last checkpoint saved at {last_checkpoint_path}')
            
        lr_scheduler.step()
        progress_bar.update(1)
        
        step += 1
    # break
        
model.save(exp_dir)

lr 1e-05


  0%|          | 0/140 [00:00<?, ?it/s]

140
train_dataloader 7


  return F.conv2d(input, weight, bias, self.stride,


Epoch 0, Batch 7, Loss 1.5250, LR 0.00000300
Eval ...


  0%|          | 0/7 [00:00<?, ?it/s]

{'val_acc': 0.52, 'val_loss': 0.11529592156410218, 'val_nnz': tensor(0.0334, device='cuda:0'), 'val_n_masks_avg': 2.74}
Epoch 0, Step 6, Val acc 0.5200, Val loss 0.1153
Best checkpoint saved at ../exps/imagenet_m_2h_debug/best/checkpoint.pth
Last checkpoint saved at ../exps/imagenet_m_2h_debug/last/checkpoint.pth
train_dataloader 7
Epoch 1, Batch 7, Loss 1.5385, LR 0.00000650
Eval ...


  0%|          | 0/7 [00:00<?, ?it/s]

{'val_acc': 0.59, 'val_loss': 0.10896327555179595, 'val_nnz': tensor(0.0336, device='cuda:0'), 'val_n_masks_avg': 2.8}
Epoch 1, Step 13, Val acc 0.5900, Val loss 0.1090
Best checkpoint saved at ../exps/imagenet_m_2h_debug/best/checkpoint.pth
Last checkpoint saved at ../exps/imagenet_m_2h_debug/last/checkpoint.pth
train_dataloader 7
Epoch 2, Batch 7, Loss 1.2022, LR 0.00001000
Eval ...


  0%|          | 0/7 [00:00<?, ?it/s]

{'val_acc': 0.61, 'val_loss': 0.09831135392189026, 'val_nnz': tensor(0.0326, device='cuda:0'), 'val_n_masks_avg': 2.99}
Epoch 2, Step 20, Val acc 0.6100, Val loss 0.0983
Best checkpoint saved at ../exps/imagenet_m_2h_debug/best/checkpoint.pth
Last checkpoint saved at ../exps/imagenet_m_2h_debug/last/checkpoint.pth
train_dataloader 7
Epoch 3, Batch 7, Loss 1.0434, LR 0.00000861
Eval ...


  0%|          | 0/7 [00:00<?, ?it/s]

{'val_acc': 0.61, 'val_loss': 0.10319440811872482, 'val_nnz': tensor(0.0325, device='cuda:0'), 'val_n_masks_avg': 2.7}
Epoch 3, Step 27, Val acc 0.6100, Val loss 0.1032
Last checkpoint saved at ../exps/imagenet_m_2h_debug/last/checkpoint.pth
train_dataloader 7
Epoch 4, Batch 7, Loss 0.8724, LR 0.00000767
Eval ...


  0%|          | 0/7 [00:00<?, ?it/s]

{'val_acc': 0.59, 'val_loss': 0.11016914665699006, 'val_nnz': tensor(0.0327, device='cuda:0'), 'val_n_masks_avg': 2.63}
Epoch 4, Step 34, Val acc 0.5900, Val loss 0.1102
Last checkpoint saved at ../exps/imagenet_m_2h_debug/last/checkpoint.pth
train_dataloader 7
Epoch 5, Batch 7, Loss 0.8203, LR 0.00000698
Eval ...


  0%|          | 0/7 [00:00<?, ?it/s]

{'val_acc': 0.56, 'val_loss': 0.10996847093105316, 'val_nnz': tensor(0.0326, device='cuda:0'), 'val_n_masks_avg': 2.87}
Epoch 5, Step 41, Val acc 0.5600, Val loss 0.1100
Last checkpoint saved at ../exps/imagenet_m_2h_debug/last/checkpoint.pth
train_dataloader 7
Epoch 6, Batch 7, Loss 0.7557, LR 0.00000645
Eval ...


  0%|          | 0/7 [00:00<?, ?it/s]

{'val_acc': 0.54, 'val_loss': 0.12349519103765488, 'val_nnz': tensor(0.0323, device='cuda:0'), 'val_n_masks_avg': 2.6}
Epoch 6, Step 48, Val acc 0.5400, Val loss 0.1235
Last checkpoint saved at ../exps/imagenet_m_2h_debug/last/checkpoint.pth
train_dataloader 7
Epoch 7, Batch 7, Loss 0.6893, LR 0.00000603
Eval ...


  0%|          | 0/7 [00:00<?, ?it/s]

{'val_acc': 0.53, 'val_loss': 0.12035199970006943, 'val_nnz': tensor(0.0339, device='cuda:0'), 'val_n_masks_avg': 2.55}
Epoch 7, Step 55, Val acc 0.5300, Val loss 0.1204
Last checkpoint saved at ../exps/imagenet_m_2h_debug/last/checkpoint.pth
train_dataloader 7
Epoch 8, Batch 7, Loss 0.6353, LR 0.00000568
Eval ...


  0%|          | 0/7 [00:00<?, ?it/s]

{'val_acc': 0.54, 'val_loss': 0.12330398544669151, 'val_nnz': tensor(0.0337, device='cuda:0'), 'val_n_masks_avg': 2.74}
Epoch 8, Step 62, Val acc 0.5400, Val loss 0.1233
Last checkpoint saved at ../exps/imagenet_m_2h_debug/last/checkpoint.pth
train_dataloader 7
Epoch 9, Batch 7, Loss 0.6185, LR 0.00000538
Eval ...


  0%|          | 0/7 [00:00<?, ?it/s]

{'val_acc': 0.54, 'val_loss': 0.1285986825823784, 'val_nnz': tensor(0.0329, device='cuda:0'), 'val_n_masks_avg': 2.54}
Epoch 9, Step 69, Val acc 0.5400, Val loss 0.1286
Last checkpoint saved at ../exps/imagenet_m_2h_debug/last/checkpoint.pth
train_dataloader 7
Epoch 10, Batch 7, Loss 0.5491, LR 0.00000513
Eval ...


  0%|          | 0/7 [00:00<?, ?it/s]

{'val_acc': 0.52, 'val_loss': 0.12666562601923942, 'val_nnz': tensor(0.0340, device='cuda:0'), 'val_n_masks_avg': 2.68}
Epoch 10, Step 76, Val acc 0.5200, Val loss 0.1267
Last checkpoint saved at ../exps/imagenet_m_2h_debug/last/checkpoint.pth
train_dataloader 7
Epoch 11, Batch 7, Loss 0.5290, LR 0.00000491
Eval ...


  0%|          | 0/7 [00:00<?, ?it/s]

{'val_acc': 0.53, 'val_loss': 0.1359690722823143, 'val_nnz': tensor(0.0340, device='cuda:0'), 'val_n_masks_avg': 2.42}
Epoch 11, Step 83, Val acc 0.5300, Val loss 0.1360
Last checkpoint saved at ../exps/imagenet_m_2h_debug/last/checkpoint.pth
train_dataloader 7
Epoch 12, Batch 7, Loss 0.5870, LR 0.00000471
Eval ...


  0%|          | 0/7 [00:00<?, ?it/s]

{'val_acc': 0.53, 'val_loss': 0.1257265503704548, 'val_nnz': tensor(0.0341, device='cuda:0'), 'val_n_masks_avg': 2.52}
Epoch 12, Step 90, Val acc 0.5300, Val loss 0.1257
Last checkpoint saved at ../exps/imagenet_m_2h_debug/last/checkpoint.pth
train_dataloader 7
Epoch 13, Batch 7, Loss 0.5902, LR 0.00000454
Eval ...


  0%|          | 0/7 [00:00<?, ?it/s]

{'val_acc': 0.54, 'val_loss': 0.12587388277053832, 'val_nnz': tensor(0.0351, device='cuda:0'), 'val_n_masks_avg': 2.41}
Epoch 13, Step 97, Val acc 0.5400, Val loss 0.1259
Last checkpoint saved at ../exps/imagenet_m_2h_debug/last/checkpoint.pth
train_dataloader 7
Epoch 14, Batch 7, Loss 0.5906, LR 0.00000439
Eval ...


  0%|          | 0/7 [00:00<?, ?it/s]

{'val_acc': 0.52, 'val_loss': 0.12749230310320855, 'val_nnz': tensor(0.0357, device='cuda:0'), 'val_n_masks_avg': 2.49}
Epoch 14, Step 104, Val acc 0.5200, Val loss 0.1275
Last checkpoint saved at ../exps/imagenet_m_2h_debug/last/checkpoint.pth
train_dataloader 7
Epoch 15, Batch 7, Loss 0.5731, LR 0.00000424
Eval ...


  0%|          | 0/7 [00:00<?, ?it/s]

{'val_acc': 0.51, 'val_loss': 0.141374394595623, 'val_nnz': tensor(0.0356, device='cuda:0'), 'val_n_masks_avg': 2.37}
Epoch 15, Step 111, Val acc 0.5100, Val loss 0.1414
Last checkpoint saved at ../exps/imagenet_m_2h_debug/last/checkpoint.pth
train_dataloader 7
Epoch 16, Batch 7, Loss 0.5545, LR 0.00000412
Eval ...


  0%|          | 0/7 [00:00<?, ?it/s]

{'val_acc': 0.54, 'val_loss': 0.1364646652340889, 'val_nnz': tensor(0.0361, device='cuda:0'), 'val_n_masks_avg': 2.36}
Epoch 16, Step 118, Val acc 0.5400, Val loss 0.1365
Last checkpoint saved at ../exps/imagenet_m_2h_debug/last/checkpoint.pth
train_dataloader 7
Epoch 17, Batch 7, Loss 0.4611, LR 0.00000400
Eval ...


  0%|          | 0/7 [00:00<?, ?it/s]

{'val_acc': 0.52, 'val_loss': 0.14457677200436592, 'val_nnz': tensor(0.0351, device='cuda:0'), 'val_n_masks_avg': 2.57}
Epoch 17, Step 125, Val acc 0.5200, Val loss 0.1446
Last checkpoint saved at ../exps/imagenet_m_2h_debug/last/checkpoint.pth
train_dataloader 7
Epoch 18, Batch 7, Loss 0.4900, LR 0.00000389
Eval ...


  0%|          | 0/7 [00:00<?, ?it/s]

{'val_acc': 0.51, 'val_loss': 0.14386578559875488, 'val_nnz': tensor(0.0342, device='cuda:0'), 'val_n_masks_avg': 2.4}
Epoch 18, Step 132, Val acc 0.5100, Val loss 0.1439
Last checkpoint saved at ../exps/imagenet_m_2h_debug/last/checkpoint.pth
train_dataloader 7
Epoch 19, Batch 7, Loss 0.4861, LR 0.00000379
Eval ...


  0%|          | 0/7 [00:00<?, ?it/s]

{'val_acc': 0.52, 'val_loss': 0.140265831053257, 'val_nnz': tensor(0.0340, device='cuda:0'), 'val_n_masks_avg': 2.41}
Epoch 19, Step 139, Val acc 0.5200, Val loss 0.1403
Last checkpoint saved at ../exps/imagenet_m_2h_debug/last/checkpoint.pth
Saved model to ../exps/imagenet_m_2h_debug
