In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

In [11]:
# Load the CIFAR-10 train and test splits as normalized tensors.
# Normalize applies (x - 0.5) / 0.5 to each of the three colour channels.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# Use the `datasets` name imported from torchvision; the bare `torchvision`
# module itself is not imported in this notebook.
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

Files already downloaded and verified
Files already downloaded and verified


In [12]:
# Keep only the airplane (label 0) and automobile (label 1) images.
KEPT_CLASSES = (0, 1)
BATCH_SIZE = 64  # NOTE(review): batch size is not recorded anywhere in the notebook; adjust as needed

trainset = [(img, label) for img, label in trainset if label in KEPT_CLASSES]
testset = [(img, label) for img, label in testset if label in KEPT_CLASSES]

# The training and evaluation cells further down use `train_loader` and
# `test_loader`, so build them here from the filtered lists.
train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)

In [16]:
# Speaker agent: encodes a flattened input into a short continuous message
class Speaker(nn.Module):
    """Two-layer MLP whose output is squashed element-wise by a sigmoid.

    Args:
        input_size: Number of features in each flattened input vector.
        message_size: Number of entries in the emitted message.
    """

    def __init__(self, input_size, message_size):
        super().__init__()
        # Layer names are part of the state_dict keys; keep them stable.
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, message_size)

    def forward(self, x):
        """Return ``sigmoid(fc2(relu(fc1(x))))`` for the input batch ``x``."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden).sigmoid()

# Listener agent: decodes a message into class logits
class Listener(nn.Module):
    """Two-layer MLP that maps a message vector to unnormalized class scores.

    Args:
        message_size: Number of entries in each incoming message.
        num_classes: Number of logits produced per message.
    """

    def __init__(self, message_size, num_classes):
        super().__init__()
        # Layer names are part of the state_dict keys; keep them stable.
        self.fc1 = nn.Linear(message_size, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, message):
        """Return ``fc2(relu(fc1(message)))``; no softmax is applied."""
        hidden = F.relu(self.fc1(message))
        return self.fc2(hidden)

# Hyperparameters for the two-agent setup
input_size = 32 * 32 * 3  # Flattened image size
message_size = 5  # Size of the message produced by the Speaker
num_classes = 2  # Number of classes in the problem (e.g., 'airplane' and 'automobile')

# Seed PyTorch's global RNG before building the agents so that weight
# initialisation (and later draws from the same generator) repeat when the
# notebook is re-run from a fresh kernel.
SEED = 0
torch.manual_seed(SEED)

speaker = Speaker(input_size, message_size)
listener = Listener(message_size, num_classes)

In [18]:
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical

def train_agents(speaker, listener, dataloader, optimizer_speaker, optimizer_listener, num_epochs):
    """Jointly train the Speaker and Listener, printing the average reward per epoch.

    For every batch the Speaker turns the flattened images into a message, the
    Listener turns that message into class logits, and one class per sample is
    drawn from a Categorical distribution over those logits. The reward is 1
    where the drawn class equals the label and 0 otherwise.

    Two losses are built on the same graph: a REINFORCE term
    ``-log_prob(action) * reward`` and a cross-entropy term on the logits.
    Both terms depend on the parameters of both agents, so each optimizer
    steps on the gradient of their sum.

    Args:
        speaker: Callable mapping a ``(batch, features)`` tensor to a message.
        listener: Callable mapping a message to ``(batch, num_classes)`` logits.
        dataloader: Iterable of ``(images, labels)`` batches that can be
            iterated once per epoch (a ``DataLoader`` or a plain list).
        optimizer_speaker: Optimizer stepped after each batch (built over the
            Speaker's parameters by the caller).
        optimizer_listener: Optimizer stepped after each batch (built over the
            Listener's parameters by the caller).
        num_epochs: Number of passes over ``dataloader``.

    Returns:
        None. One summary line per epoch is printed.
    """
    for epoch in range(num_epochs):
        total_reward = 0.0
        num_samples = 0

        for images, labels in dataloader:
            # Flatten everything but the batch dimension; unlike .view this
            # also accepts non-contiguous tensors.
            images = images.flatten(start_dim=1)

            message = speaker(images)
            logits = listener(message)

            # Sample an action (class) per example from the Listener's logits
            m = Categorical(logits=logits)
            action = m.sample()

            # Binary reward: 1 for a correct sampled class, 0 otherwise
            reward = torch.eq(action, labels).float()

            loss_speaker = (-m.log_prob(action) * reward).mean()
            loss_listener = F.cross_entropy(logits, labels)

            optimizer_speaker.zero_grad()
            optimizer_listener.zero_grad()
            # One backward pass over the summed loss accumulates the same
            # gradients as two separate passes, without retaining the graph.
            (loss_speaker + loss_listener).backward()
            optimizer_speaker.step()
            optimizer_listener.step()

            total_reward += reward.sum().item()
            num_samples += labels.size(0)

        # Average over the samples actually seen, so the figure stays correct
        # with drop_last, samplers, or plain iterables without a `.dataset`.
        avg_reward = total_reward / num_samples if num_samples else 0.0
        print(f'Epoch [{epoch + 1}/{num_epochs}], Average Reward: {avg_reward:.4f}')

# Training run: one Adam optimizer per agent, both with the same learning rate.
num_epochs = 10  # Adjust the number of epochs as needed
optimizer_speaker, optimizer_listener = (
    optim.Adam(agent.parameters(), lr=0.001) for agent in (speaker, listener)
)

# `train_loader` must already be defined and yield (images, labels) batches.
train_agents(
    speaker=speaker,
    listener=listener,
    dataloader=train_loader,
    optimizer_speaker=optimizer_speaker,
    optimizer_listener=optimizer_listener,
    num_epochs=num_epochs,
)


Epoch [1/10], Average Reward: 0.7177
Epoch [2/10], Average Reward: 0.8071
Epoch [3/10], Average Reward: 0.8270
Epoch [4/10], Average Reward: 0.8532
Epoch [5/10], Average Reward: 0.8631
Epoch [6/10], Average Reward: 0.8851
Epoch [7/10], Average Reward: 0.8949
Epoch [8/10], Average Reward: 0.9110
Epoch [9/10], Average Reward: 0.9167
Epoch [10/10], Average Reward: 0.9229


In [19]:
def evaluate_agents(speaker, listener, dataloader):
    """Print the Listener's classification accuracy over ``dataloader``.

    Each batch of images is flattened, passed through the Speaker and then the
    Listener, and the highest-scoring class is compared with the label. Both
    agents are switched to evaluation mode for the duration of the call and
    returned to whatever mode they were in before, so evaluating does not
    leave a model stuck in eval mode.

    Args:
        speaker: Callable mapping a ``(batch, features)`` tensor to a message.
        listener: Callable mapping a message to ``(batch, num_classes)`` logits.
        dataloader: Iterable of ``(images, labels)`` batches.

    Returns:
        None. The accuracy is printed as a percentage; an empty ``dataloader``
        is reported as 0.00% rather than raising ZeroDivisionError.
    """
    # Only nn.Module instances have a train/eval flag; plain callables are skipped.
    modules = [agent for agent in (speaker, listener) if isinstance(agent, torch.nn.Module)]
    previous_modes = [module.training for module in modules]
    for module in modules:
        module.eval()

    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in dataloader:
                # Flatten everything but the batch dimension
                images = images.flatten(start_dim=1)

                logits = listener(speaker(images))
                predicted = logits.argmax(dim=1)

                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        # Restore the original modes even if a batch raised.
        for module, was_training in zip(modules, previous_modes):
            module.train(was_training)

    accuracy = correct / total if total else 0.0
    print(f'Testing Accuracy of the Listener: {100 * accuracy:.2f}%')

# Report accuracy on the test loader; `test_loader` must already be defined
# and yield (images, labels) batches.
evaluate_agents(speaker=speaker, listener=listener, dataloader=test_loader)


Testing Accuracy of the Listener: 89.00%
