# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torchvision.models as models

# CIFAR-10 input pipeline: convert images to tensors, then normalise each of
# the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Training split first, then the test split; both are downloaded into ./data
# when missing and share the same transform.
train_dataset, test_dataset = (
    datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Only the training loader shuffles; both use batches of 256.
train_loader = DataLoader(dataset=train_dataset, batch_size=256, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=256, shuffle=False)

class VGG16Model(nn.Module):
    """Pretrained torchvision VGG16 with its last classifier layer replaced.

    Args:
        num_classes: Number of output scores produced by the replaced final
            linear layer. Defaults to 10 (CIFAR-10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # NOTE(review): `pretrained=True` is kept so older torchvision releases
        # keep working; newer releases prefer the `weights=` argument.
        self.model = models.vgg16(pretrained=True)

        # Read the input width from the existing head rather than hard-coding
        # 4096, then swap in a head sized for the requested number of classes.
        head_in_features = self.model.classifier[6].in_features
        self.model.classifier[6] = nn.Linear(head_in_features, num_classes)

        # Direct children of the feature extractor, in order. Consumers filter
        # this list for nn.Conv2d instances themselves.
        self.convolutional_layers = list(self.model.features.children())

    def forward(self, x):
        """Return the wrapped VGG16's output for the batch ``x``."""
        return self.model(x)


# Define an agent for each convolutional layer
class LayerAgent:
    """Keep/drop policy over the output channels of one convolutional layer.

    The agent holds one learnable logit per output channel; the sigmoid of a
    logit is the probability of keeping that channel.

    Args:
        layer: Layer exposing ``out_channels`` and ``weight`` (e.g. nn.Conv2d).
    """

    def __init__(self, layer):
        self.layer = layer
        self.num_output_channels = layer.out_channels
        # sigmoid(6.9) is roughly 0.999, so every channel starts out almost
        # certain to be kept. The logits are created on the layer's own device
        # so masks sampled from them can be multiplied into the layer weights
        # without a cross-device error.
        self.weight_array = torch.full(
            (self.num_output_channels,),
            6.9,
            device=layer.weight.device,
            requires_grad=True,
        )
        self.optimizer = optim.Adam([self.weight_array], lr=0.01)

    def get_probability(self):
        """Return the per-channel keep probabilities (sigmoid of the logits)."""
        return torch.sigmoid(self.weight_array)

    def update_weights(self, action, reward):
        """Take one Adam step on the policy-gradient loss.

        Args:
            action: When this is a boolean tensor it is treated as the sampled
                keep mask (True = kept) and the loss uses the log-probability
                of that exact sample: ``log p`` for kept channels and
                ``log(1 - p)`` for dropped ones. Any other value is ignored
                and ``log p`` is used for every channel, which is the
                behaviour callers passing probabilities already rely on.
            reward: Scalar (or broadcastable tensor) weighting the
                log-probabilities.
        """
        logits = self.weight_array
        # logsigmoid(w) == log(sigmoid(w)) but does not underflow to log(0)
        # for very negative logits.
        log_keep = nn.functional.logsigmoid(logits)

        if torch.is_tensor(action) and action.dtype == torch.bool:
            kept = action.to(logits.device)
            # log(1 - sigmoid(w)) == logsigmoid(-w)
            log_drop = nn.functional.logsigmoid(-logits)
            log_prob = torch.where(kept, log_keep, log_drop)
        else:
            log_prob = log_keep

        total_loss = (-log_prob * reward).mean()
        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()


class PruningEnvironment:
    """Pair a model with one ``LayerAgent`` per ``nn.Conv2d`` layer.

    Args:
        model: Module exposing ``convolutional_layers`` (an iterable of
            layers) and callable on a batch to produce per-class scores.
        device: Device the inputs and targets are moved to before the
            forward pass.
        lambda_penalty: Amount subtracted from the accuracy reward for every
            misclassified sample. Defaults to 200.
    """

    def __init__(self, model, device, lambda_penalty=200):
        self.model = model
        # One agent per Conv2d entry, in the order the layers appear.
        self.layer_agents = [
            LayerAgent(layer)
            for layer in self.model.convolutional_layers
            if isinstance(layer, nn.Conv2d)
        ]
        self.lambda_penalty = lambda_penalty  # Penalty for incorrect predictions
        self.device = device  # Store the device

    def step(self, actions, input_data, target):
        """Sample channel masks, prune, score the batch and update the agents.

        Args:
            actions: Not used to drive pruning; the masks are sampled inside
                this method from each agent's current probabilities. Kept so
                existing callers keep working.
            input_data: Batch of inputs for the model.
            target: Ground-truth class index per sample.

        Returns:
            The accuracy reward for this batch as a float: +1 per correct and
            ``-lambda_penalty`` per incorrect prediction, averaged over the
            actual batch size.
        """
        # The sample itself is not differentiable, so detach before drawing.
        masks = [
            torch.bernoulli(agent.get_probability().detach()).bool()
            for agent in self.layer_agents
        ]

        conv_layers = [
            layer
            for layer in self.model.convolutional_layers
            if isinstance(layer, nn.Conv2d)
        ]

        with torch.no_grad():
            # Prune the channels in the model. The weights are multiplied in
            # place and never restored, so a dropped channel's kernel stays
            # zero on later steps. Biases are left untouched.
            for layer, mask in zip(conv_layers, masks):
                # Match the weight's device/dtype so a mask sampled elsewhere
                # (e.g. on the CPU) can still be applied.
                keep = mask.to(device=layer.weight.device, dtype=layer.weight.dtype)
                layer.weight.mul_(keep.view(layer.out_channels, 1, 1, 1))

            # Forward pass; only the argmax is needed, so no graph is built.
            input_data = input_data.to(self.device)
            target = target.to(self.device)
            prediction = self.model(input_data)
            _, predicted_label = torch.max(prediction, 1)

        # Accuracy reward, normalised by the real batch size so a short final
        # batch is scored on the same scale as a full one.
        batch_size = predicted_label.numel()
        if batch_size:
            num_correct = int((predicted_label == target).sum().item())
            num_wrong = batch_size - num_correct
            R_acc = (num_correct - self.lambda_penalty * num_wrong) / batch_size
        else:
            R_acc = 0.0

        for agent, mask in zip(self.layer_agents, masks):
            # Compression reward for the layer: number of dropped channels.
            R_i_C = float((~mask).sum().item())
            # Combine compression and accuracy rewards. The agent receives the
            # mask that was actually sampled, i.e. the action it took.
            agent.update_weights(mask, R_i_C * R_acc)

        return R_acc

def train_model_rl(model, train_loader, epochs, device, criterion=None):
    """Run the RL pruning loop over the training data and log loss/accuracy.

    For every batch the pruning environment samples channel masks, prunes the
    model and updates its layer agents; the model is then evaluated on the
    same batch for monitoring. The model's own weights are not optimised
    here: no backward pass or optimizer step is taken on them.

    Args:
        model: Model to prune; must be usable by ``PruningEnvironment``.
        train_loader: Sized iterable of ``(inputs, labels)`` batches.
        epochs: Number of passes over ``train_loader``.
        device: Device the batches are moved to.
        criterion: Loss used for the reported value. Defaults to
            ``nn.CrossEntropyLoss()`` when omitted.
    """
    if criterion is None:
        criterion = nn.CrossEntropyLoss()

    model.train()

    # Build the environment once. Its agents own the learnable logits and
    # their Adam state; recreating it per batch would reset both every step
    # and nothing would ever be learned.
    env = PruningEnvironment(model, device)
    steps_per_epoch = len(train_loader)

    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        total = 0

        for step_idx, (inputs, labels) in enumerate(train_loader, start=1):
            # Move the batch once so predictions and labels are compared on
            # the same device.
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Get current probabilities and take a step in the environment.
            actions = [agent.get_probability() for agent in env.layer_agents]
            env.step(actions, inputs, labels)

            # Monitoring only: nothing is back-propagated through the model,
            # so skip building the autograd graph.
            with torch.no_grad():
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            batch_loss = loss.item()
            running_loss += batch_loss

            # Calculate accuracy
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            if step_idx % 25 == 0:
                print(f'epoch {epoch+1}/{epochs}, step: {step_idx}/{steps_per_epoch}: loss = {batch_loss:.5f}, acc = {100*(correct/total):.5f}%')

        # Guard the averages so an empty loader reports zeros instead of
        # raising ZeroDivisionError.
        accuracy = 100. * correct / total if total else 0.0
        mean_loss = running_loss / steps_per_epoch if steps_per_epoch else 0.0
        print(f'Epoch [{epoch + 1}/{epochs}], Loss: {mean_loss:.4f}, Accuracy: {accuracy:.4f}%')


# Select the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the network together with its loss and optimizer objects. The
# optimizer is created here but is not passed to train_model_rl below.
model = VGG16Model()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Place the model on the selected device.
model.to(device)

# Run the RL pruning loop for 30 epochs.
train_model_rl(model, train_loader, epochs=30, device=device)

# Evaluate the model after training
def evaluate_model(model, test_loader):
    """Print the model's top-1 accuracy over ``test_loader``.

    Args:
        model: Module returning per-class scores for a batch of inputs.
        test_loader: Iterable of ``(inputs, labels)`` batches.

    Raises:
        ZeroDivisionError: If ``test_loader`` yields no samples.
    """
    model.eval()

    # Evaluate on whichever device the model's parameters live on, so batches
    # coming from a CPU loader do not clash with a model moved to the GPU.
    # A module without parameters is evaluated on the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    accuracy = 100. * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')

# Report accuracy on the CIFAR-10 test loader.
evaluate_model(model=model, test_loader=test_loader)