# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torchvision.models as models

# Preprocessing shared by both splits: convert each image to a tensor, then
# normalise every RGB channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# CIFAR-10 train / test splits, stored under ./data (download requested).
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Batches of 256 samples; only the training split is shuffled.
train_loader = DataLoader(dataset=train_dataset, batch_size=256, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=256, shuffle=False)

class VGG16Model(nn.Module):
    """VGG-16 whose last classifier layer is replaced by a fresh linear head.

    Args:
        num_classes: Number of logits produced by the replacement head.
            Defaults to 10 (the CIFAR-10 setup used in this file).

    Attributes:
        model: The underlying torchvision VGG-16 network.
        convolutional_layers: Flat list of the direct children of
            ``model.features``. Despite the name it contains every child
            module, not only ``nn.Conv2d``; consumers filter by type.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # NOTE(review): newer torchvision releases prefer the `weights=`
        # argument over `pretrained=True` -- confirm against the installed
        # version before changing it.
        self.model = models.vgg16(pretrained=True)

        # Size the new head from the layer it replaces instead of
        # hard-coding the input width.
        old_head = self.model.classifier[6]
        self.model.classifier[6] = nn.Linear(old_head.in_features, num_classes)

        self.convolutional_layers = list(self.model.features.children())

    def forward(self, x):
        """Return the class logits of the wrapped network for batch ``x``."""
        return self.model(x)

class LayerAgent:
    """Per-layer agent holding one learnable logit per output channel.

    ``sigmoid(logit)`` is exposed by :meth:`get_probability`; the logits are
    optimised with their own Adam optimiser (lr=0.01).

    Args:
        layer: Layer being controlled; only ``layer.out_channels`` is read.
        device: Device on which the logits are allocated.
    """

    def __init__(self, layer, device):
        self.layer = layer
        self.device = device
        self.num_output_channels = layer.out_channels
        # Every logit starts at 6.9, i.e. sigmoid(6.9) ~= 0.999.
        self.weight_array = torch.full(
            (self.num_output_channels,), 6.9, requires_grad=True, device=device
        )
        self.optimizer = optim.Adam([self.weight_array], lr=0.01)

    def get_probability(self):
        """Return the per-channel probabilities ``sigmoid(weight_array)``."""
        return torch.sigmoid(self.weight_array)

    def update_weights(self, action, reward):
        """Take one Adam step on ``mean(-log(p) * reward)``.

        Args:
            action: Accepted for interface compatibility; it is not used by
                the loss.
                NOTE(review): a policy-gradient update would normally depend
                on the sampled action -- confirm whether ignoring it is
                intended.
            reward: Scalar reward, either a tensor or a plain Python number.
        """
        prob = self.get_probability()
        # `as_tensor` accepts plain numbers as well as tensors; the previous
        # `reward.to(...)` raised AttributeError for Python floats/ints.
        reward = torch.as_tensor(reward, dtype=prob.dtype, device=self.device)
        loss = -torch.log(prob) * reward
        total_loss = loss.mean()
        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()

class PruningEnvironment:
    """Samples channel masks, applies them to the model and rewards the agents.

    One ``LayerAgent`` is created for every ``nn.Conv2d`` found in
    ``model.convolutional_layers``, in order.

    Args:
        model: Network exposing ``convolutional_layers`` and callable on a
            batch of inputs.
        device: Device used for masks, inputs and targets.
    """

    def __init__(self, model, device):
        self.model = model
        self.layer_agents = [
            LayerAgent(layer, device)
            for layer in self.model.convolutional_layers
            if isinstance(layer, nn.Conv2d)
        ]
        # Penalty subtracted from the accuracy reward per misclassified sample.
        self.lambda_penalty = 200
        self.device = device

    def step(self, actions, input_data, target):
        """Prune with freshly sampled masks and update every agent once.

        The sampled masks are multiplied into the convolution weights in
        place, so channels whose mask is 0 stay zeroed after the call
        (biases are left untouched).

        Args:
            actions: One entry per agent, forwarded unchanged to
                ``agent.update_weights``.
            input_data: Input batch.
            target: Ground-truth class indices for ``input_data``.
        """
        input_data = input_data.to(self.device)
        target = target.to(self.device)

        # One Bernoulli draw per output channel: True keeps it, False drops it.
        masks = [
            torch.bernoulli(agent.get_probability()).bool().to(self.device)
            for agent in self.layer_agents
        ]

        conv_layers = [
            layer
            for layer in self.model.convolutional_layers
            if isinstance(layer, nn.Conv2d)
        ]

        # Neither the masking nor this forward pass is back-propagated
        # through, so skip autograd bookkeeping for both.
        with torch.no_grad():
            for layer, mask in zip(conv_layers, masks):
                keep = mask.float().view(layer.out_channels, 1, 1, 1)
                layer.weight.mul_(keep)
            prediction = self.model(input_data)
        _, predicted_label = torch.max(prediction, 1)

        # Accuracy reward: +1 per hit, -lambda_penalty per miss, averaged
        # over the batch. Counted in one tensor op rather than a per-sample
        # Python loop (which forced one device sync per element).
        batch_size = input_data.size(0)
        num_correct = (predicted_label == target).sum().item()
        num_wrong = batch_size - num_correct
        R_acc = (num_correct - self.lambda_penalty * num_wrong) / batch_size

        for i, agent in enumerate(self.layer_agents):
            # Compression reward: number of channels dropped in this layer.
            R_i_C = (1 - masks[i].float()).sum().item()
            R_i = R_i_C * R_acc
            agent.update_weights(actions[i], torch.tensor(R_i, device=self.device))

def train_model_rl(model, train_loader, epochs, device):
    """Run the pruning agents over the training data and report metrics.

    For every batch the environment samples channel masks, prunes the model
    and updates the agents; the pruned model is then evaluated on the same
    batch for the loss/accuracy printed per epoch. The model's own weights
    are not optimised here. The loss uses the module-level ``criterion``.

    Args:
        model: Network to prune; must expose ``convolutional_layers``.
        train_loader: Iterable of ``(inputs, labels)`` batches.
        epochs: Number of passes over ``train_loader``.
        device: Device the batches are moved to.
    """
    model.train()

    # Build the environment once. Re-creating it for every batch discarded
    # each agent's logits and Adam state right after a single update, so no
    # agent could accumulate anything across batches.
    env = PruningEnvironment(model, device)

    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            actions = [agent.get_probability() for agent in env.layer_agents]
            env.step(actions, inputs, labels)

            # Metrics only -- nothing is back-propagated from this pass, so
            # avoid recording an autograd graph for it.
            with torch.no_grad():
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            running_loss += loss.item()

            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        accuracy = 100. * correct / total
        print(f'Epoch [{epoch + 1}/{epochs}], Loss: {running_loss / len(train_loader):.4f}, Accuracy: {accuracy:.4f}%')

def count_zero_weight_channels(model):
    """Print and return how many conv output channels have all-zero weights.

    Walks the direct children of ``model.model.features`` and, for every
    ``nn.Conv2d``, counts the output channels whose whole kernel is exactly 0.

    Args:
        model: Wrapper exposing the network as ``model.model.features``.

    Returns:
        int: Total number of all-zero output channels across those layers.
    """
    total_zero_channels = 0

    print("\nZero Weight Channels Report:")
    for i, layer in enumerate(model.model.features.children()):
        if not isinstance(layer, nn.Conv2d):
            continue
        # Collapse (in_channels, kH, kW) into one axis and reduce over it.
        # Same result as `all(dim=(1, 2, 3))`, but a tuple `dim` for `all`
        # is, to my knowledge, only accepted by newer PyTorch releases.
        per_channel_zero = layer.weight.data.flatten(1).eq(0).all(dim=1)
        zero_channels = per_channel_zero.sum().item()
        total_zero_channels += zero_channels
        print(f"Layer {i}: Dropped {int(zero_channels)} channels (weights = 0)")

    print(f"Total zero weight channels across all Conv2d layers: {int(total_zero_channels)}\n")
    return total_zero_channels

def evaluate_model(model, test_loader, device=None):
    """Compute, print and return the top-1 accuracy of ``model``.

    The model is switched to eval mode and left in it.

    Args:
        model: Network to evaluate.
        test_loader: Iterable of ``(inputs, labels)`` batches.
        device: Device the batches are moved to. When omitted, the device of
            the model's first parameter is used (CPU if it has none), so the
            function no longer depends on a module-level ``device`` global.

    Returns:
        float: Accuracy in percent; ``0.0`` when ``test_loader`` yields no
        samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    # Guard the empty-loader case instead of dividing by zero.
    accuracy = 100. * correct / total if total else 0.0
    print(f'Test Accuracy: {accuracy:.2f}%')
    return accuracy

def fine_tune_model(model, train_loader, epochs, device, lr=0.0001):
    """Fine-tune ``model`` with Adam while keeping pruned channels at zero.

    Conv output channels whose weights are entirely zero on entry are treated
    as pruned: their weights are reset to zero after every optimiser step, so
    the per-epoch zero-channel report reflects the pruning instead of being
    erased by gradient updates. Biases are not masked. The loss uses the
    module-level ``criterion``.

    Args:
        model: Network to fine-tune.
        train_loader: Iterable of ``(inputs, labels)`` batches.
        epochs: Number of passes over ``train_loader``.
        device: Device the batches are moved to.
        lr: Adam learning rate.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    model.train()

    # Snapshot, per conv layer, a multiplicative keep-mask that is 0 for
    # channels that are all-zero right now and 1 elsewhere.
    pruned_layers = []
    for layer in model.modules():
        if not isinstance(layer, nn.Conv2d):
            continue
        zero_rows = layer.weight.detach().flatten(1).eq(0).all(dim=1)
        if zero_rows.any():
            keep = (~zero_rows).to(layer.weight.dtype).view(-1, 1, 1, 1)
            pruned_layers.append((layer, keep))

    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # The step may have moved pruned weights away from zero; put
            # them back so the channels stay dropped.
            with torch.no_grad():
                for layer, keep in pruned_layers:
                    layer.weight.mul_(keep)

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        accuracy = 100. * correct / total
        print(f'Fine-Tuning Epoch [{epoch + 1}/{epochs}], Loss: {running_loss / len(train_loader):.4f}, Accuracy: {accuracy:.4f}%')

        total_zero_channels = count_zero_weight_channels(model)
        print(f"Total dropped channels (weights = 0) after fine-tuning: {total_zero_channels}")

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the network and move it to the chosen device.
model = VGG16Model()
model.to(device)

# Loss shared (as a module-level global) by the training helpers.
criterion = nn.CrossEntropyLoss()
# NOTE(review): the helpers visible in this file do not read this optimizer
# (fine_tune_model creates its own) -- confirm whether it is still needed.
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Pipeline: RL-driven pruning, then fine-tuning, then test-set evaluation.
train_model_rl(model, train_loader, epochs=15, device=device)
fine_tune_model(model, train_loader, epochs=20, device=device)
evaluate_model(model, test_loader)