In [3]:
import torch
from datasets import IndexedDataset, WeightedDataset, load_data
from torch.utils.data import DataLoader, DistributedSampler

from utils import get_args
from architectures import load_architecture

from samplers import DistributedCustomSampler
from losses import trades_loss, apgd_loss
from tqdm.notebook import tqdm
from architectures import CustomModel, load_architecture, add_lora, set_lora_gradients #load_statedict

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

args = get_args()

args.dataset = 'CIFAR10'
args.selection_method = 'random'
args.aug = 'aug'
args.loss_function = 'APGD'

args.iterations = 10
args.pruning_ratio = 0
args.delta = 1
args.batch_size = 24
args.pruning_strategy = 'random'
args.batch_strategy = 'random'
args.sample_size= 128
args.init_lr = 0.001
args.freeze_epochs = 5

train_dataset, val_dataset, test_dataset, N, train_transform, transform = load_data(args) 

train_dataset = WeightedDataset(args, train_dataset, train_transform, N, prune_ratio = args.pruning_ratio,  )

train_sampler = DistributedCustomSampler(args, train_dataset, num_replicas=2, rank=0, drop_last=True)

trainloader = DataLoader(train_dataset, batch_size=3, ) #sampler = train_sampler, 

# Example Usage

args.backbone = 'deit_small_patch16_224.fb_in1k'
args.N = 10
args.strategy = 'full_fine_tuning'
rank = 0
model = load_architecture(args, N, rank )
model = CustomModel(args, model)

model.set_fine_tuning_strategy()
model.to(rank)
# model = DDP(model, device_ids=[rank])

./data
train size 47500 val size 2500
Files already downloaded and verified
47500 47500


CustomModel(
  (base_model): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (patch_drop): Identity()
    (norm_pre): Identity()
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=384, out_features=1152, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=384, out_features=384, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU(approximate='no

In [4]:
from losses import get_loss, get_eval_loss
import numpy as np
from losses import apgd_loss

rank = 'cuda'

optimizer = torch.optim.SGD( model.parameters(),lr=args.init_lr, weight_decay=args.weight_decay, momentum=args.momentum, nesterov=True, )

for iteration in range(args.iterations):

    model.eval()
    train_sampler.set_epoch(iteration)
    
    for batch_id, batch in tqdm(enumerate( trainloader ) ):

        optimizer.zero_grad()

        data, target, idxs = batch

        # print(data,target, idxs)

        data, target = data.to(rank), target.to(rank) 

        loss_values, logits = get_loss(args, model, data, target, optimizer)

        loss = loss_values.mean()
        #loss = #train_dataset.compute_loss(idxs, loss_values)
        print(loss)

        loss.backward()
        optimizer.step()

        break
    
    model.update_fine_tuning_strategy(iteration)

pruning
remove tail
process


0it [00:00, ?it/s]

tensor(4.8118, device='cuda:0', grad_fn=<MeanBackward0>)
pruning
remove tail
process


0it [00:00, ?it/s]

tensor(4.4371, device='cuda:0', grad_fn=<MeanBackward0>)
pruning
remove tail
process


0it [00:00, ?it/s]

tensor(3.9475, device='cuda:0', grad_fn=<MeanBackward0>)
pruning
remove tail
process


0it [00:00, ?it/s]

tensor(3.5182, device='cuda:0', grad_fn=<MeanBackward0>)
pruning
remove tail
process


0it [00:00, ?it/s]

tensor(3.2499, device='cuda:0', grad_fn=<MeanBackward0>)
pruning
remove tail
process


0it [00:00, ?it/s]

tensor(3.2315, device='cuda:0', grad_fn=<MeanBackward0>)
Unfreezing all layers
pruning
remove tail
process


0it [00:00, ?it/s]

tensor(2.7440, device='cuda:0', grad_fn=<MeanBackward0>)
pruning
remove tail
process


0it [00:00, ?it/s]

tensor(2.2918, device='cuda:0', grad_fn=<MeanBackward0>)
pruning
remove tail
process


0it [00:00, ?it/s]

tensor(2.3178, device='cuda:0', grad_fn=<MeanBackward0>)
pruning
remove tail
process


0it [00:00, ?it/s]

tensor(1.3087, device='cuda:0', grad_fn=<MeanBackward0>)


In [None]:
import timm

pretrained_models = timm.list_models(pretrained=True)

for m in pretrained_models:
    print(m)


bat_resnext26ts.ch_in1k
beit_base_patch16_224.in22k_ft_in22k
beit_base_patch16_224.in22k_ft_in22k_in1k
beit_base_patch16_384.in22k_ft_in22k_in1k
beit_large_patch16_224.in22k_ft_in22k
beit_large_patch16_224.in22k_ft_in22k_in1k
beit_large_patch16_384.in22k_ft_in22k_in1k
beit_large_patch16_512.in22k_ft_in22k_in1k
beitv2_base_patch16_224.in1k_ft_in1k
beitv2_base_patch16_224.in1k_ft_in22k
beitv2_base_patch16_224.in1k_ft_in22k_in1k
beitv2_large_patch16_224.in1k_ft_in1k
beitv2_large_patch16_224.in1k_ft_in22k
beitv2_large_patch16_224.in1k_ft_in22k_in1k
botnet26t_256.c1_in1k
caformer_b36.sail_in1k
caformer_b36.sail_in1k_384
caformer_b36.sail_in22k
caformer_b36.sail_in22k_ft_in1k
caformer_b36.sail_in22k_ft_in1k_384
caformer_m36.sail_in1k
caformer_m36.sail_in1k_384
caformer_m36.sail_in22k
caformer_m36.sail_in22k_ft_in1k
caformer_m36.sail_in22k_ft_in1k_384
caformer_s18.sail_in1k
caformer_s18.sail_in1k_384
caformer_s18.sail_in22k
caformer_s18.sail_in22k_ft_in1k
caformer_s18.sail_in22k_ft_in1k_384
c

In [6]:
1 % 1

0

In [5]:
effective_batch_size = 1024
world_size = 4
per_gpu_batch_size = 320

print( effective_batch_size // (world_size * per_gpu_batch_size), (world_size * per_gpu_batch_size), effective_batch_size) 

0 1280 1024


In [6]:
from architectures.resnet_imagenet import ResNet_imagenet, Bottleneck_imagenet
model = ResNet_imagenet(Bottleneck_imagenet, [3, 4, 6, 3], )
num_features = model.fc.in_features
print(num_features)
model.fc = nn.Linear(num_features, 10)

2048
