In [1]:
%load_ext autoreload
%autoreload 2
import torch
import torchvision
import numpy as np
import os
from torch.utils.data import DataLoader
from torchvision.transforms import v2
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import tqdm
from torch.distributions import Multinomial, Bernoulli
from tensorboard_logger import configure, log_value
from utils import utils
import torch.backends.cudnn as cudnn
# Let cuDNN benchmark and cache the fastest convolution algorithms. This pays off
# because input resolutions are fixed in this notebook (CenterCrop(384) in the
# transforms); autotuning may make results vary slightly between runs.
torch.backends.cudnn.benchmark = True

2024-05-29 17:18:11.203676: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-05-29 17:18:11.233086: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# High-resolution classifier: ViT-B/16 with the SWAG end-to-end ImageNet-1k weights
# (the enum member is what the string name 'ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1' resolves to).
vit_b_16 = torchvision.models.vit_b_16(
    weights=torchvision.models.ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1
)

In [3]:
def _build_transform(*augmentations):
    """Compose the shared preprocessing pipeline for the 384px ViT-B/16 classifier.

    Steps: PIL image -> uint8 tensor -> float32 scaled to [0, 1] -> bicubic resize
    of the shorter side to 384 -> 384x384 center crop -> any ``augmentations``
    (in the given order) -> ImageNet mean/std normalization.
    """
    return v2.Compose([
        v2.PILToTensor(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Resize(size=384, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
        v2.CenterCrop(384),
        *augmentations,
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


# Training: random flips are applied after cropping and before normalization.
# NOTE(review): vertical flips give upside-down images to a frozen, ImageNet-pretrained
# classifier, which makes its rewards noisier — confirm this augmentation is intended.
train_transform = _build_transform(v2.RandomHorizontalFlip(), v2.RandomVerticalFlip())

# Validation: deterministic preprocessing only.
val_transform = _build_transform()


In [4]:
# Root of the ImageFolder dataset (class sub-folders under train/ and val/).
# Set the PATCHDROP_DATA_ROOT environment variable to point elsewhere instead of
# editing a hard-coded absolute path; the default keeps the original location.
DATA_ROOT = os.environ.get("PATCHDROP_DATA_ROOT", "/media/samar/HDD1T/rl/PatchDrop-main/200-dataset")

traindata = torchvision.datasets.ImageFolder(root=os.path.join(DATA_ROOT, "train"), transform=train_transform)
valdata = torchvision.datasets.ImageFolder(root=os.path.join(DATA_ROOT, "val"), transform=val_transform)

In [5]:
# Pinned host memory is what makes the `.cuda(non_blocking=True)` copies in the
# training/eval loops asynchronous; with pageable memory they are effectively
# synchronous. Only pin when CUDA is present (pinning without it just warns).
_pin_memory = torch.cuda.is_available()

trainloader = DataLoader(traindata, batch_size=8, num_workers=8, shuffle=True, pin_memory=_pin_memory)
valloader = DataLoader(valdata, batch_size=8, num_workers=8, shuffle=False, pin_memory=_pin_memory)

In [6]:
# Number of mini-batches per training epoch (also the tqdm total in `train`).
num_train_batches = len(trainloader)
num_train_batches

25000

In [7]:
# Policy network (the "agent") that decides which high-res patches to keep.
agent = utils.get_model()
agent.cuda()

# The high-resolution classifier is fixed: eval mode plus frozen parameters, so
# no autograd state is ever recorded for its weights.
vit_b_16.eval().cuda()
vit_b_16.requires_grad_(False)

# `mappings` / `patch_size` describe the patch grid consumed by
# utils.agent_chosen_input (exact format defined in utils).
mappings, _, patch_size = utils.action_space_model()

start_epoch = 0
# NOTE(review): `ppo_epochs` and `clip` are not read by train()/test(); the update
# there is a single policy-gradient step with a greedy-policy baseline, not PPO.
ppo_epochs = 3
clip = 0.2
test_interval = 1   # run test() (which also writes a checkpoint) every N epochs
lr_size = 96        # side length of the low-res image given to the agent
alpha = 0.8         # squashes agent probabilities into [1 - alpha, alpha] for exploration
penalty = -0.5      # passed through to utils.compute_reward
max_epochs = 10
lr = 1e-4
optimizer = torch.optim.Adam(agent.parameters(), lr=lr)

cv_dir = "/media/samar/HDD1T/rl/PatchDrop-main/cv_dir/pretrain"
# Ensure the checkpoint directory exists before test() saves into it.
os.makedirs(cv_dir, exist_ok=True)
configure(cv_dir+'/log', flush_secs=5)

In [8]:
def train(epoch):
    """Run one epoch of policy-gradient training for the patch-selection agent.

    Only ``agent`` is updated; the fixed classifier ``vit_b_16`` only scores the
    masked images. For each batch the agent sees an ``lr_size`` x ``lr_size``
    downsampled copy of the image and outputs per-patch keep probabilities. A
    stochastic policy is sampled from them, and the reward of the greedy
    (threshold 0.5) policy is subtracted as a baseline to form the advantage.
    Epoch statistics are printed and written with ``log_value``.

    Args:
        epoch: epoch index, used in the printout and as the logging step.
    """
    agent.train()
    matches, rewards, rewards_baseline, policies = [], [], [], []
    for inputs, targets in tqdm.tqdm(trainloader, total=len(trainloader)):
        inputs = inputs.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        # Run the low-res image through the policy network.
        inputs_agent = F.interpolate(inputs, (lr_size, lr_size))
        probs = torch.sigmoid(agent(inputs_agent))
        # Squash into [1 - alpha, alpha] so every patch keeps some chance of
        # being flipped when sampling (exploration).
        probs = probs * alpha + (1 - alpha) * (1 - probs)

        # Sample the policies from the Bernoulli distribution given by the agent's output.
        distr = Bernoulli(probs)
        policy_sample = distr.sample()

        # Test-time (greedy) policy - used as the baseline policy in the training step.
        policy_map = (probs.detach() >= 0.5).float()

        # The classifier is fixed and the rewards only enter the loss as constant
        # weights, so nothing here needs autograd; skipping graph construction for
        # two ViT forward passes saves a large amount of GPU memory and time.
        with torch.no_grad():
            # Agent-selected high-resolution images (separate copies in case the
            # helper modifies its input in place).
            inputs_map = utils.agent_chosen_input(inputs.clone(), policy_map, mappings, patch_size)
            inputs_sample = utils.agent_chosen_input(inputs.clone(), policy_sample.int(), mappings, patch_size)

            # Forward propagate images through the classifier.
            preds_map = vit_b_16(inputs_map)
            preds_sample = vit_b_16(inputs_sample)

            # Rewards for the baseline and the sampled policy.
            reward_map, match = utils.compute_reward(preds_map, targets, policy_map, penalty)
            reward_sample, _ = utils.compute_reward(preds_sample, targets, policy_sample, penalty)
            advantage = reward_sample.cuda().float() - reward_map.cuda().float()

        # Policy-gradient loss for the policy network only.
        loss = -distr.log_prob(policy_sample)
        loss = loss * advantage.expand_as(policy_sample)
        loss = loss.mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        matches.append(match.cpu())
        rewards.append(reward_sample.cpu())
        rewards_baseline.append(reward_map.cpu())
        policies.append(policy_sample.cpu())

    accuracy, reward, sparsity, variance, policy_set = utils.performance_stats(policies, rewards, matches)

    print('Train: %d | Acc: %.3f | Rw: %.2E | S: %.3f | V: %.3f | #: %d'%(epoch, accuracy, reward, sparsity, variance, len(policy_set)))
    log_value('train_accuracy', accuracy, epoch)
    log_value('train_reward', reward, epoch)
    log_value('train_sparsity', sparsity, epoch)
    log_value('train_variance', variance, epoch)
    # log_value expects a plain number, not a tensor.
    log_value('train_baseline_reward', torch.cat(rewards_baseline, 0).mean().item(), epoch)
    log_value('train_unique_policies', len(policy_set), epoch)

In [9]:
def test(epoch):
    """Evaluate the agent's greedy policy on the validation set and checkpoint it.

    The agent sees an ``lr_size`` x ``lr_size`` downsampled image. Patches whose
    keep probability is >= 0.5 are kept, and the masked high-resolution image is
    classified by the fixed ``vit_b_16``. Accuracy/reward/sparsity statistics
    are printed and logged, then the agent's state dict is saved under ``cv_dir``.

    Args:
        epoch: epoch index, used for logging and in the checkpoint filename.
    """
    agent.eval()

    matches, rewards, policies = [], [], []
    with torch.no_grad():
        for inputs, targets in tqdm.tqdm(valloader, total=len(valloader)):
            inputs = inputs.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True)

            # Low-resolution view for the agent.
            inputs_agent = F.interpolate(inputs, (lr_size, lr_size))
            probs = torch.sigmoid(agent(inputs_agent))

            # Deterministic test-time policy: keep a patch iff its probability >= 0.5.
            policy = (probs >= 0.5).float()

            # Masked high-res image -> fixed classifier.
            inputs = utils.agent_chosen_input(inputs, policy, mappings, patch_size)
            preds = vit_b_16(inputs)

            reward, match = utils.compute_reward(preds, targets, policy, penalty)

            # Move per-batch results to the CPU (as train() does) so results for
            # the whole validation set do not accumulate in GPU memory.
            matches.append(match.cpu())
            rewards.append(reward.cpu())
            policies.append(policy.cpu())

    accuracy, reward, sparsity, variance, policy_set = utils.performance_stats(policies, rewards, matches)

    print('Test - Acc: %.3f | Rw: %.2E | S: %.3f | V: %.3f | #: %d'%(accuracy, reward, sparsity, variance, len(policy_set)))
    log_value('test_accuracy', accuracy, epoch)
    log_value('test_reward', reward, epoch)
    log_value('test_sparsity', sparsity, epoch)
    log_value('test_variance', variance, epoch)
    log_value('test_unique_policies', len(policy_set), epoch)

    # Save the policy network only - the classifier is fixed in this phase.
    state = {
        'agent': agent.state_dict(),
        'epoch': epoch,
        'reward': reward,
        'acc': accuracy,
    }
    # torch.save does not create missing directories.
    os.makedirs(cv_dir, exist_ok=True)
    torch.save(state, cv_dir+'/ckpt_E_%d_A_%.3f_R_%.2E'%(epoch, accuracy, reward))

In [10]:
# Alternate one training epoch with periodic evaluation; test() also writes a checkpoint.
for epoch in range(start_epoch, start_epoch + max_epochs):
    train(epoch)
    if epoch % test_interval != 0:
        continue
    test(epoch)

100%|██████████| 25000/25000 [2:49:16<00:00,  2.46it/s]  


Train: 0 | Acc: 0.768 | Rw: 4.90E-01 | S: 6.818 | V: 1.677 | #: 24568


  inputs, targets = Variable(inputs, volatile=True), Variable(targets).cuda(non_blocking=True)
100%|██████████| 6250/6250 [22:09<00:00,  4.70it/s]  


Test - Acc: 0.781 | Rw: 5.66E-01 | S: 5.862 | V: 0.345 | #: 2


100%|██████████| 25000/25000 [2:56:58<00:00,  2.35it/s]  


Train: 1 | Acc: 0.776 | Rw: 4.94E-01 | S: 6.876 | V: 1.642 | #: 19522


100%|██████████| 6250/6250 [21:00<00:00,  4.96it/s]


Test - Acc: 0.783 | Rw: 5.63E-01 | S: 6.026 | V: 0.451 | #: 5


100%|██████████| 25000/25000 [2:56:22<00:00,  2.36it/s]   


Train: 2 | Acc: 0.781 | Rw: 4.92E-01 | S: 6.978 | V: 1.631 | #: 19800


100%|██████████| 6250/6250 [21:16<00:00,  4.89it/s]  


Test - Acc: 0.784 | Rw: 5.64E-01 | S: 6.055 | V: 0.266 | #: 4


100%|██████████| 25000/25000 [2:58:18<00:00,  2.34it/s]   


Train: 3 | Acc: 0.773 | Rw: 4.92E-01 | S: 6.798 | V: 1.606 | #: 17861


100%|██████████| 6250/6250 [21:10<00:00,  4.92it/s]


Test - Acc: 0.785 | Rw: 5.67E-01 | S: 6.000 | V: 0.010 | #: 2


100%|██████████| 25000/25000 [2:58:46<00:00,  2.33it/s]   


Train: 4 | Acc: 0.774 | Rw: 4.93E-01 | S: 6.794 | V: 1.600 | #: 17618


100%|██████████| 6250/6250 [21:15<00:00,  4.90it/s]  


Test - Acc: 0.785 | Rw: 5.67E-01 | S: 5.991 | V: 0.094 | #: 2


  0%|          | 0/25000 [00:02<?, ?it/s]


KeyboardInterrupt: 