# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Optimisation schedule.
batch_size = 64
num_epochs = 60
# Agents (pruning policies) are only updated for epochs [0, policy_training_stop_epoch).
policy_training_stop_epoch = 40
learning_rate = 0.01
# Reward assigned to a misclassified sample is -lambda_penalty (a correct one earns +1).
lambda_penalty = 500

def _to_normalized_tensor():
    """Return fresh ToTensor + Normalize steps (standard ImageNet statistics)."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]


# Training pipeline: upsample to 224 px, jitter with a padded random crop and a
# random horizontal flip, then convert to a normalized tensor.
transform_train = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.RandomCrop(224, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    + _to_normalized_tensor()
)

# Evaluation pipeline: deterministic resize, then convert to a normalized tensor.
transform_test = transforms.Compose([transforms.Resize(224), *_to_normalized_tensor()])

def _cifar10_split(is_train, transform):
    """Build (and download if needed) one CIFAR-10 split and its DataLoader.

    The training split is shuffled; the test split keeps its natural order.
    """
    split = torchvision.datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transform)
    split_loader = data.DataLoader(split, batch_size=batch_size, shuffle=is_train, num_workers=4)
    return split, split_loader


train_dataset, train_loader = _cifar10_split(True, transform_train)
test_dataset, test_loader = _cifar10_split(False, transform_test)

class Agent(nn.Module):
    """Per-layer pruning policy that emits one keep-probability per channel.

    A single linear layer maps a state vector of length ``num_channels`` to one
    logit per channel, and a sigmoid turns those logits into Bernoulli
    keep-probabilities.

    Args:
        num_channels: Number of output channels of the conv layer this agent
            controls (both the state size and the action size).
        init_logit: Constant initial bias. The weights start at zero, so every
            channel's initial keep-probability is ``sigmoid(init_logit)``
            regardless of the state. The default of 6.9 gives roughly 0.999, so
            the policy starts out keeping almost every channel.
    """

    def __init__(self, num_channels, init_logit=6.9):
        super().__init__()
        self.policy = nn.Linear(num_channels, num_channels)
        # Zero weights + constant bias => the initial policy ignores the state.
        nn.init.constant_(self.policy.bias, init_logit)
        nn.init.zeros_(self.policy.weight)

    def forward(self, state):
        """Return per-channel keep-probabilities in (0, 1) for ``state``."""
        return torch.sigmoid(self.policy(state))

class PrunableConv2d(nn.Module):
    """Wrap an ``nn.Conv2d`` and multiply its output by a per-channel mask.

    ``channel_mask`` has shape ``(out_channels,)`` and starts as all ones
    (nothing pruned). Callers overwrite it with a new tensor to drop channels;
    a 0 entry zeros the corresponding output feature map.

    Args:
        conv_layer: The convolution to wrap. The mask is created on the same
            device and with the same dtype as its weight.
    """

    def __init__(self, conv_layer):
        super().__init__()
        self.conv = conv_layer
        self.out_channels = conv_layer.out_channels
        # A registered buffer follows model.to()/.cuda()/.cpu()/.double() together
        # with the conv weights (a plain tensor attribute would not).
        # persistent=False keeps state_dict keys identical to before.
        self.register_buffer(
            "channel_mask",
            torch.ones(
                self.out_channels,
                device=conv_layer.weight.device,
                dtype=conv_layer.weight.dtype,
            ),
            persistent=False,
        )

    def forward(self, x):
        """Apply the conv, then broadcast the mask over batch and spatial dims."""
        return self.conv(x) * self.channel_mask.view(1, -1, 1, 1)

vgg16 = torchvision.models.vgg16(pretrained=True)
# Swap the final classifier layer for a 10-way CIFAR-10 head.
vgg16.classifier[6] = nn.Linear(4096, 10)
vgg16 = vgg16.to(device)

# Positions of every Conv2d inside the feature extractor.
conv_layer_indices = [
    position
    for position, module in enumerate(vgg16.features)
    if isinstance(module, nn.Conv2d)
]

# Wrap each conv so its output channels can be masked, and pair it with its
# own pruning agent (agents[i] controls conv_layer_indices[i]).
agents = []
for idx in conv_layer_indices:
    wrapped = PrunableConv2d(vgg16.features[idx])
    vgg16.features[idx] = wrapped.to(device)
    agents.append(Agent(wrapped.out_channels).to(device))

def get_initial_state(num_channels, target_device=None):
    """Return the constant all-ones state vector fed to an agent.

    Args:
        num_channels: Length of the state (the controlled layer's channel count).
        target_device: Where to allocate the tensor; defaults to the module-level
            ``device``.

    Returns:
        A float tensor of ones with shape ``(num_channels,)``.
    """
    if target_device is None:
        target_device = device
    # Allocate directly on the target device instead of building on the CPU and
    # copying; this runs once per layer on every training batch.
    return torch.ones(num_channels, device=target_device)

criterion = nn.CrossEntropyLoss()

# Network weights: SGD with momentum and L2 weight decay.
model_optimizer = optim.SGD(
    vgg16.parameters(),
    lr=learning_rate,
    momentum=0.9,
    weight_decay=5e-4,
)

# One independent Adam optimizer per pruning agent, index-aligned with `agents`.
agent_optimizers = list(
    optim.Adam(policy_net.parameters(), lr=0.01) for policy_net in agents
)

def _sample_channel_masks(model, agents, track_grad):
    """Sample a Bernoulli keep/prune mask for every prunable conv and install it.

    For each (agent, conv index) pair, the agent's keep-probabilities define a
    Bernoulli distribution. One 0/1 action per channel is sampled (1 = keep),
    and a detached copy becomes that layer's ``channel_mask``.

    Args:
        model: Network whose ``features[idx]`` entries are the prunable convs.
        agents: Policies, in the same order as ``conv_layer_indices``.
        track_grad: Whether the policy computations need an autograd graph.

    Returns:
        A list of ``(log_probs, entropy, actions)`` tuples, one per agent, when
        ``track_grad`` is True. An empty list otherwise, because nothing will
        backpropagate into the agents.
    """
    rollouts = []
    # When the agents are frozen, skip building their autograd graphs entirely.
    with torch.set_grad_enabled(track_grad):
        for agent, idx in zip(agents, conv_layer_indices):
            prunable_conv = model.features[idx]
            state = get_initial_state(prunable_conv.out_channels)
            dist = torch.distributions.Bernoulli(agent(state))
            actions = dist.sample()
            if track_grad:
                rollouts.append((dist.log_prob(actions), dist.entropy(), actions))
            # The mask is a constant for the network: no gradient reaches the policy.
            prunable_conv.channel_mask = actions.detach()
    return rollouts


def _update_agents(agent_optimizers, rollouts, predicted, targets, lambda_penalty):
    """Take one REINFORCE step for every agent.

    Batch reward:
        Each correct prediction earns +1 and each wrong one earns
        ``-lambda_penalty``; the reward is the mean over the batch.

    Per-agent update:
        The agent's reward is that mean multiplied by the number of channels it
        pruned. Its loss is ``-sum(log_probs) * reward - 0.01 * sum(entropy)``.
    """
    hit = predicted == targets
    # Same values as the original per-batch torch.ones(...).to(device), but
    # allocated directly on the device of the predictions.
    ones = torch.ones(hit.shape, device=hit.device)
    R_acc_mean = torch.where(hit, ones, -lambda_penalty * ones).mean()

    for agent_opt, (log_probs, entropy, actions) in zip(agent_optimizers, rollouts):
        R_i = torch.sum(1 - actions) * R_acc_mean
        policy_loss = -log_probs.sum() * R_i
        entropy_loss = -0.01 * entropy.sum()
        (policy_loss + entropy_loss).backward()
        agent_opt.step()


def train(model, device, train_loader, optimizer, agents, agent_optimizers, epoch, lambda_penalty):
    """Run one training epoch of the network and, optionally, the pruning agents.

    Each batch proceeds in four steps:
        1. A fresh channel mask is sampled from every agent.
        2. The network takes one optimizer step on the cross-entropy loss.
        3. If ``agent_optimizers`` is not None, every agent takes one policy-gradient
           step, using a reward built from the batch's predictions.
        4. Every 100 batches, the loss and running accuracy are printed.

    Args:
        model: Network with ``PrunableConv2d`` layers at ``conv_layer_indices``.
        device: Device the batches are moved to.
        train_loader: Iterable of ``(inputs, targets)`` batches.
        optimizer: Optimizer for the network parameters.
        agents: One policy per prunable conv layer.
        agent_optimizers: One optimizer per agent, or None to freeze the agents
            (masks are then still sampled, but no policy update happens).
        epoch: Zero-based epoch index, used only for logging.
        lambda_penalty: Penalty for a misclassified sample in the agent reward.
    """
    model.train()
    update_policies = agent_optimizers is not None
    total_correct = 0
    total_samples = 0

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        if update_policies:
            for agent_opt in agent_optimizers:
                agent_opt.zero_grad()

        rollouts = _sample_channel_masks(model, agents, track_grad=update_policies)

        outputs = model(inputs)
        classification_loss = criterion(outputs, targets)

        _, predicted = outputs.max(1)
        total_correct += predicted.eq(targets).sum().item()
        total_samples += targets.size(0)

        classification_loss.backward()
        optimizer.step()

        if update_policies:
            _update_agents(agent_optimizers, rollouts, predicted, targets, lambda_penalty)

        if batch_idx % 100 == 0:
            acc = 100. * total_correct / total_samples
            print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}/{len(train_loader)}], Loss: {classification_loss.item():.4f}, Accuracy: {acc:.2f}%')

def test(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print its loss and accuracy.

    The model runs in eval mode with gradients disabled. Each prunable layer
    uses whatever ``channel_mask`` it currently holds.
    NOTE(review): that is the mask last assigned before this call; no
    deterministic mask is installed here.

    Returns:
        Top-1 accuracy in percent.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            n = targets.size(0)
            # criterion uses the default 'mean' reduction. Re-weight by batch size
            # so a short final batch is not over-counted in the reported loss.
            loss_sum += criterion(outputs, targets).item() * n
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()
            total += n

    if total == 0:
        raise ValueError("test_loader yielded no samples")

    test_loss = loss_sum / total
    accuracy = 100. * correct / total
    print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {accuracy:.2f}%')
    return accuracy

def calculate_pruning(model, agents, threshold=0.5):
    """Report how many channels each agent's policy would prune deterministically.

    A channel counts as kept when its keep-probability is ``>= threshold``.
    The function prints a per-layer line, the totals, and the overall ratio.

    Args:
        model: Network with ``PrunableConv2d`` layers at ``conv_layer_indices``.
        agents: One policy per prunable layer, in the same order.
        threshold: Keep-probability cut-off.

    Returns:
        The overall pruning ratio in percent. This is 0.0 when there are no
        channels, instead of dividing by zero.
    """
    total_channels = 0
    total_pruned_channels = 0
    print("Pruning Summary:")
    for agent, idx in zip(agents, conv_layer_indices):
        num_channels = model.features[idx].out_channels
        total_channels += num_channels
        with torch.no_grad():
            probs = agent(get_initial_state(num_channels))
        # Count on the tensor's device; .item() yields a plain int (no NumPy round-trip).
        num_channels_to_keep = int((probs >= threshold).sum().item())
        num_channels_to_drop = num_channels - num_channels_to_keep
        total_pruned_channels += num_channels_to_drop
        print(f'Layer {idx}: Would prune {num_channels_to_drop} channels out of {num_channels}')
    print(f'Total channels: {total_channels}')
    print(f'Total channels that would be pruned: {total_pruned_channels}')
    pruned_ratio = 100.0 * total_pruned_channels / total_channels if total_channels else 0.0
    print(f'Overall pruning ratio: {pruned_ratio:.2f}%')
    return pruned_ratio

best_accuracy = 0.0
# Last epoch whose training step updates the agents. After it the policies are
# frozen, so the deterministic pruning summary is printed right then. Clamping
# to num_epochs keeps the summary from being silently skipped when
# policy_training_stop_epoch >= num_epochs.
last_policy_epoch = min(policy_training_stop_epoch, num_epochs) - 1

for epoch in range(num_epochs):
    if epoch < policy_training_stop_epoch:
        # Joint phase: train the network and the pruning agents together.
        train(vgg16, device, train_loader, model_optimizer, agents, agent_optimizers, epoch, lambda_penalty)
    else:
        # Fine-tuning phase: agents frozen (masks still sampled), network only.
        train(vgg16, device, train_loader, model_optimizer, agents, None, epoch, lambda_penalty=0)
    accuracy = test(vgg16, device, test_loader)

    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(vgg16.state_dict(), 'vgg16_cifar10_best.pth')

    if epoch == last_policy_epoch:
        print("Calculating how many channels would be pruned based on learned policies...")
        calculate_pruning(vgg16, agents, threshold=0.5)

torch.save(vgg16.state_dict(), 'vgg16_cifar10_final.pth')
