In [57]:
!pip install torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim




[notice] A new release of pip is available: 24.1.1 -> 24.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [58]:
number_of_epochs = 1
learning_rate = 0.06

In [59]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]  # Normalize the images to [-1, 1]
)

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=6000,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=6000,
                                         shuffle=False, num_workers=2)


def one_hot_encode(labels, num_classes=10):
    """ Converts a batch of labels to one-hot encoded format. """
    batch_size = labels.size(0)
    one_hot_labels = torch.zeros(batch_size, num_classes, device=labels.device)
    one_hot_labels.scatter_(1, labels.unsqueeze(1), 1)
    return one_hot_labels


def create_positive_data(data, labels):
    """ Return original data and one-hot encoded labels. """
    one_hot_labels = one_hot_encode(labels, num_classes=10)
    return data.cuda(), one_hot_labels.cuda()

def create_negative_data(data, labels):
    """ Create negative samples by randomly selecting different labels and return with one-hot encoded format. """
    batch_size = labels.size(0)
    num_classes = 10
    # Generate random labels different from the current labels
    incorrect_labels = (labels + torch.randint(1, num_classes, (batch_size,), device=labels.device)) % num_classes
    one_hot_labels = one_hot_encode(incorrect_labels, num_classes=10)
    return data.cuda(), one_hot_labels.cuda()

Files already downloaded and verified
Files already downloaded and verified


In [60]:
class CustomNetwork(nn.Module):
    def __init__(self):
        super(CustomNetwork, self).__init__()
        # First hidden layer: Receives input and applies an 11x11 convolution
        self.layer1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=11, padding=5, stride=1)
        self.bn1 = nn.BatchNorm2d(num_features=3)  # Batch normalization for 3 output channels
        
        # Second hidden layer: Takes the output of the first layer and applies another 11x11 convolution
        self.layer2 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=11, padding=5, stride=1)
        self.bn2 = nn.BatchNorm2d(num_features=3)  # Batch normalization for 3 output channels
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)
        self.hebbian_factors = torch.ones(2, 10, 3072, requires_grad=True, device='cuda')
        self.hebbian_optimizer = optim.Adam([self.hebbian_factors], lr=learning_rate)
        self.threshold = 0
        
    def soft_plus_loss(self, positive_goodness, negative_goodness, is_second_phase=False):
        if is_second_phase:
            threshold = self.threshold * 2
        else:
            threshold = self.threshold
        return torch.log(1 + torch.exp(torch.cat([
            -positive_goodness + threshold,
            negative_goodness - threshold]))).mean()

    def forward(self, x, layer_num):
        
        if layer_num == 0:
            # Apply first hidden layer
            x = self.layer1(x)
            x = self.bn1(x)
            x = F.relu(x)
        elif layer_num == 1:
            # Apply second hidden layer
            x = self.layer2(x)
            x = self.bn2(x)
            x = F.relu(x)

        # Since no output layer is used, return the feature map from the last layer directly
        return x
    
    def train_network(self, training_data_loader):
        for _ in tqdm(range(number_of_epochs)):
            for images, labels in training_data_loader:
               positive_data, positive_labels = create_positive_data(images, labels)
               negative_data, negative_labels = create_negative_data(images, labels)
               
               for i in range(2):
                   hebbian_factors = self.hebbian_factors[i, :, :]
                   positive_data = self.forward(positive_data, i)
                   flattened_positive_data = positive_data.clone()
                   flattened_positive_data = flattened_positive_data.view(flattened_positive_data.size(0), -1).detach()
                   negative_data = self.forward(negative_data, i)
                   flattened_negative_data = negative_data.clone()
                   flattened_negative_data = flattened_negative_data.view(flattened_negative_data.size(0), -1).detach()
                   
                   
                   negative_goodness = (torch.mm(negative_labels, hebbian_factors) * flattened_positive_data).mean(1)
                   positive_goodness = (torch.mm(positive_labels, hebbian_factors) * flattened_negative_data).mean(1)
                   loss = self.soft_plus_loss(positive_goodness, negative_goodness)
                   self.optimizer.zero_grad()
                   self.hebbian_optimizer.zero_grad()
                   loss.backward()
                   self.hebbian_optimizer.step()
                   self.optimizer.step()
                   
    

    def predict(self, testing_data_loader):
        self.cuda()  # Ensure the model is on GPU
        correct = 0
        total = 0
        for images, actual_labels in testing_data_loader:
            images = images.cuda()  # Ensure images are on GPU
            actual_labels = actual_labels.cuda()  # Ensure labels are on GPU
            batch_size = images.size(0)
            goodness_per_label = []
    
            for label in range(10):  # Assuming 10 classes
                labels = torch.full((batch_size,), label, dtype=torch.long, device=images.device)
                marked_data, one_hot_labels = create_positive_data(images, labels)
                goodness = []
    
                for layer_num in range(2):  # Assuming 2 layers
                    marked_data = self.forward(marked_data, layer_num)
                    flattened_data = marked_data.view(marked_data.size(0), -1).detach()
                    goodness_value = (torch.mm(one_hot_labels, self.hebbian_factors[layer_num, :, :]) * flattened_data).mean(1)
                    goodness.append(goodness_value)
    
                goodness_per_label.append(torch.sum(torch.stack(goodness), dim=0).unsqueeze(1))
    
            goodness_per_label = torch.cat(goodness_per_label, 1)
            predicted_labels = goodness_per_label.argmax(dim=1)
            correct += (predicted_labels == actual_labels).sum().item()
            total += batch_size
    
        accuracy = correct / total
        return accuracy


In [61]:
if __name__ == "__main__":
    torch.cuda.empty_cache()
    torch.manual_seed(1234)
    network = CustomNetwork().cuda()
    for i in range(20):
        network.train_network(trainloader)
        training_acc = network.predict(trainloader)
        testing_acc = network.predict(testloader)
        print(f"Training Acc: {training_acc} in {(i+1) * number_of_epochs} epochs")
        print(f"Testing Acc: {testing_acc} in {(i+1) * number_of_epochs} epochs")
    
    

100%|██████████| 1/1 [00:12<00:00, 12.44s/it]


Training Acc: 0.19172 in 1 epochs
Testing Acc: 0.1957 in 1 epochs


100%|██████████| 1/1 [00:09<00:00,  9.16s/it]


Training Acc: 0.22896 in 2 epochs
Testing Acc: 0.231 in 2 epochs


100%|██████████| 1/1 [00:09<00:00,  9.03s/it]


Training Acc: 0.25428 in 3 epochs
Testing Acc: 0.2578 in 3 epochs


100%|██████████| 1/1 [00:08<00:00,  8.16s/it]


Training Acc: 0.27688 in 4 epochs
Testing Acc: 0.2805 in 4 epochs


100%|██████████| 1/1 [00:09<00:00,  9.87s/it]


Training Acc: 0.2858 in 5 epochs
Testing Acc: 0.2899 in 5 epochs


100%|██████████| 1/1 [00:09<00:00,  9.26s/it]


Training Acc: 0.294 in 6 epochs
Testing Acc: 0.2992 in 6 epochs


100%|██████████| 1/1 [00:08<00:00,  8.63s/it]


Training Acc: 0.30126 in 7 epochs
Testing Acc: 0.305 in 7 epochs


100%|██████████| 1/1 [00:07<00:00,  7.47s/it]


Training Acc: 0.30838 in 8 epochs
Testing Acc: 0.3108 in 8 epochs


100%|██████████| 1/1 [00:10<00:00, 10.07s/it]


Training Acc: 0.31094 in 9 epochs
Testing Acc: 0.3153 in 9 epochs


100%|██████████| 1/1 [00:08<00:00,  8.42s/it]


Training Acc: 0.31478 in 10 epochs
Testing Acc: 0.3174 in 10 epochs


100%|██████████| 1/1 [00:06<00:00,  6.59s/it]


Training Acc: 0.31616 in 11 epochs
Testing Acc: 0.317 in 11 epochs


100%|██████████| 1/1 [00:06<00:00,  6.67s/it]


Training Acc: 0.31934 in 12 epochs
Testing Acc: 0.3206 in 12 epochs


100%|██████████| 1/1 [00:06<00:00,  6.41s/it]


Training Acc: 0.3212 in 13 epochs
Testing Acc: 0.3227 in 13 epochs


100%|██████████| 1/1 [00:07<00:00,  7.04s/it]


Training Acc: 0.32484 in 14 epochs
Testing Acc: 0.3258 in 14 epochs


100%|██████████| 1/1 [00:06<00:00,  6.42s/it]


Training Acc: 0.32646 in 15 epochs
Testing Acc: 0.328 in 15 epochs


100%|██████████| 1/1 [00:06<00:00,  6.69s/it]


Training Acc: 0.3271 in 16 epochs
Testing Acc: 0.3286 in 16 epochs


100%|██████████| 1/1 [00:06<00:00,  6.34s/it]


Training Acc: 0.32912 in 17 epochs
Testing Acc: 0.3306 in 17 epochs


100%|██████████| 1/1 [00:06<00:00,  6.50s/it]


Training Acc: 0.33232 in 18 epochs
Testing Acc: 0.3308 in 18 epochs


100%|██████████| 1/1 [00:06<00:00,  6.66s/it]


Training Acc: 0.33292 in 19 epochs
Testing Acc: 0.3326 in 19 epochs


100%|██████████| 1/1 [00:06<00:00,  6.48s/it]


Training Acc: 0.33572 in 20 epochs
Testing Acc: 0.3382 in 20 epochs
