In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision import models
from torchvision.models.resnet import resnet18, ResNet18_Weights

# Reproducibility: seed PyTorch's default generator before any model weights are
# initialised, the DataLoader shuffles, or torchvision's random transforms run
# (they all draw from torch's RNG). This does not by itself make CUDA/cuDNN
# kernels deterministic.
SEED = 42
torch.manual_seed(SEED)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# CIFAR-10 preprocessing, keyed by split.
# Both splits end with ToTensor (pixels -> floats in [0, 1]) followed by a
# per-channel Normalize with mean 0.5 / std 0.5, i.e. values rescaled to [-1, 1].
# Only the training split gets random augmentation in front of that.
transform = dict(
    train=transforms.Compose([
        transforms.RandomCrop(size=32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        transforms.RandomRotation(degrees=10),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]),
    test=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]),
)
# Load both CIFAR-10 splits under ./data (downloaded if not already present).
# Each split is paired with its own pipeline from `transform`; train is built first.
train_dataset, test_dataset = (
    datasets.CIFAR10(root='./data', train=(split == 'train'), download=True, transform=transform[split])
    for split in ('train', 'test')
)

100%|██████████| 170M/170M [00:05<00:00, 31.9MB/s]


In [None]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F


# class IntermediateBlock(nn.Module):
#     def __init__(self, in_channels, out_channels, num_convs):
#         super(IntermediateBlock, self).__init__()
#         self.convs = nn.ModuleList([
#             nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, padding=1)
#             for i in range(num_convs)
#         ])
#         # The Linear layer should match the number of output channels from the conv layers
#         self.fc = nn.Linear(out_channels, num_convs)  # Fully connected layer for generating 'a'

#     def forward(self, x):
#         for conv in self.convs:
#             x = F.relu(conv(x))

#         m = torch.mean(x, dim=[2, 3])  # Mean across spatial dimensions
#         a = F.softmax(self.fc(m), dim=1)  # Generate weights 'a'

#         x_prime = torch.zeros_like(x)
#         for i, conv in enumerate(self.convs):
#             conv_output = conv(x)
#             weight = a[:, i].unsqueeze(1).unsqueeze(2).unsqueeze(3)
#             x_prime += weight * conv_output

#         return x_prime


# class OutputBlock(nn.Module):
#     def __init__(self, in_features, num_classes, hidden_layers=[]):
#         super(OutputBlock, self).__init__()
#         self.layers = nn.Sequential()

#         # Set the correct input size
#         input_size = in_features  # This should be 8192 given the ResNet structure for CIFAR-10

#         # Create layers
#         for hidden_size in hidden_layers:
#             self.layers.add_module('fc', nn.Linear(input_size, hidden_size))
#             self.layers.add_module('relu', nn.ReLU(inplace=True))
#             input_size = hidden_size

#         # Final classification layer
#         self.layers.add_module('fc_final', nn.Linear(input_size, num_classes))

#     def forward(self, x):
#         x = torch.flatten(x, 1)  # Flatten the tensor
#         x = self.layers(x)       # Pass through the layers
#         return x

# # Adjustments in the ResNet modification function
# def modify_resnet18_for_cifar10_with_custom_blocks():
#     resnet18 = models.resnet18(pretrained=False)
#     resnet18.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
#     resnet18.maxpool = nn.Identity()

#     # Insert IntermediateBlock after layer1 of ResNet (for example)
#     in_channels = resnet18.layer1[-1].conv2.out_channels  # Get the number of output channels from the last conv layer in layer1
#     resnet18.layer1.add_module("intermediate_block", IntermediateBlock(in_channels, in_channels, num_convs=2))

#     # Assuming the output size before the fc layer is 8192 after flattening
#     in_features = 8192
#     resnet18.fc = OutputBlock(in_features, num_classes=10, hidden_layers=[512, 256])

#     return resnet18


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import models

import torch
import torch.nn as nn
import torch.nn.functional as F


import torch
import torch.nn as nn
import torch.nn.functional as F

class IntermediateBlock(nn.Module):
    """Softmax-weighted mixture of parallel 3x3 convolutions.

    Each of the ``num_convs`` convolutions is applied independently to the same
    input ``x``. A per-sample weight vector ``a`` (one entry per convolution) is
    computed from the spatial mean of ``x`` through a linear layer and a softmax.
    The output is the ``a``-weighted sum of the convolution outputs.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by every convolution, and
            therefore by the block.
        num_convs: Number of parallel convolutions; must be at least 1.

    Raises:
        ValueError: If ``num_convs`` is smaller than 1.
    """

    def __init__(self, in_channels, out_channels, num_convs):
        super().__init__()
        if num_convs < 1:
            # With no convolutions there is nothing to mix and forward() would
            # have no output to return.
            raise ValueError(f"num_convs must be >= 1, got {num_convs}")
        # kernel_size=3 with padding=1 keeps the spatial size unchanged.
        self.convs = nn.ModuleList([
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
            for _ in range(num_convs)
        ])
        # Maps the channel-wise spatial mean of the input to one logit per conv.
        self.fc = nn.Linear(in_channels, num_convs)

    def forward(self, x):
        """Return the softmax-weighted sum of the parallel convolution outputs.

        Args:
            x: Input feature map, expected as (batch, in_channels, H, W).

        Returns:
            Tensor of shape (batch, out_channels, H, W).
        """
        m = torch.mean(x, dim=[2, 3])  # (batch, in_channels)
        a = F.softmax(self.fc(m), dim=1)  # (batch, num_convs); each row sums to 1

        # (batch, num_convs, out_channels, H, W)
        conv_outputs = torch.stack([conv(x) for conv in self.convs], dim=1)
        # Broadcast each sample's weights over channels and spatial positions.
        weights = a[:, :, None, None, None]
        return (weights * conv_outputs).sum(dim=1)


class OutputBlock(nn.Module):
    """Flatten the input and map it to class logits with a small MLP.

    The MLP has one ``Linear -> ReLU`` pair per entry of ``hidden_layers``,
    followed by a final ``Linear`` to ``num_classes`` outputs.

    Args:
        in_features: Number of features after flattening every dimension
            except the first (e.g. C*H*W for a feature map).
        num_classes: Number of output logits.
        hidden_layers: Hidden layer sizes, in order. Empty (the default) gives
            a single linear classifier.
    """

    def __init__(self, in_features, num_classes, hidden_layers=()):
        super().__init__()
        self.layers = nn.Sequential()

        input_size = in_features
        for i, hidden_size in enumerate(hidden_layers):
            # Names must be unique per layer: add_module() with an
            # already-registered name silently replaces the earlier module, which
            # would keep only the last hidden layer (with a mismatched input size).
            self.layers.add_module(f'fc{i}', nn.Linear(input_size, hidden_size))
            self.layers.add_module(f'relu{i}', nn.ReLU(inplace=True))
            input_size = hidden_size

        # Final classification layer.
        self.layers.add_module('fc_final', nn.Linear(input_size, num_classes))

    def forward(self, x):
        """Flatten ``x`` from dim 1 onward and return (batch, num_classes) logits."""
        x = torch.flatten(x, 1)
        return self.layers(x)


In [None]:
# Training hyperparameters: number of epochs, Adam learning rate, mini-batch size.
epochs, learning_rate, batch_size = 200, 1e-3, 128

# Batch both splits; only training batches are reshuffled each epoch.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)


class CustomCIFAR10Model(nn.Module):
    """Two-branch image classifier whose branch outputs are fused by a linear layer.

    Branch 1: ``IntermediateBlock`` (3 -> 64 channels) -> dropout ->
    ``OutputBlock``, giving a ``num_classes``-dimensional vector.
    Branch 2: a randomly initialised torchvision ResNet-18 whose final ``fc`` is
    replaced by ``nn.Identity`` so it returns its pre-classifier features.
    The two outputs are concatenated and mapped to ``num_classes`` logits.

    Args:
        num_classes: Number of output classes.
        image_size: Height/width of the square input images. Used to size the
            ``OutputBlock`` input; 32 matches CIFAR-10.
        dropout_p: Dropout probability applied to the IntermediateBlock output.
    """

    def __init__(self, num_classes=10, image_size=32, dropout_p=0.5):
        super().__init__()
        intermediate_channels = 64
        # IntermediateBlock processes the input images first.
        self.intermediate_block = IntermediateBlock(3, intermediate_channels, num_convs=2)
        self.dropout = nn.Dropout(dropout_p)

        # IntermediateBlock's 3x3 / padding=1 convs keep H and W, so the
        # flattened size is channels * H * W (64 * 32 * 32 = 65536 for CIFAR-10).
        self.output_block = OutputBlock(
            intermediate_channels * image_size * image_size,
            num_classes=num_classes,
        )

        # `weights=None` (random init) replaces the deprecated `pretrained=False`.
        # NOTE(review): this is the stock ImageNet ResNet-18 stem (no conv1/maxpool
        # changes), which downsamples 32x32 inputs aggressively — confirm intended.
        self.resnet18 = models.resnet18(weights=None)
        # Read the feature width before removing the classifier (512 for ResNet-18).
        resnet_features = self.resnet18.fc.in_features
        self.resnet18.fc = nn.Identity()

        # Fuses ResNet features with the class-dimension vector from branch 1.
        self.combine_fc = nn.Linear(resnet_features + num_classes, num_classes)

    def forward(self, x):
        """Return (batch, num_classes) logits.

        Args:
            x: Image batch, expected as (batch, 3, image_size, image_size).
        """
        # Branch 1: mixed-conv features -> dropout -> class-dimension vector.
        intermediate = self.dropout(self.intermediate_block(x))
        class_vector = self.output_block(torch.flatten(intermediate, 1))

        # Branch 2: ResNet-18 features.
        features = self.resnet18(x)

        # Concatenate both branches and classify.
        combined = torch.cat((features, class_vector), dim=1)
        return self.combine_fc(combined)



# Instantiate the model on `device` together with its loss, optimiser and LR schedule.
model = CustomCIFAR10Model().to(device)
criterion = nn.CrossEntropyLoss()
# Adam with weight_decay=1e-4 (L2 term added to the gradient, not decoupled as in AdamW).
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate, weight_decay=1e-4)
# StepLR multiplies the LR by 0.1 every 30 scheduler steps. The training loop calls
# scheduler.step() once per epoch, so starting from 1e-3 the LR is 1e-4 after epoch 30,
# 1e-5 after epoch 60, ... and 1e-9 after epoch 180.
# NOTE(review): with epochs=200 most epochs run at a negligible LR — consider a larger
# step_size or a schedule tied to `epochs`.
scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=30, gamma=0.1)




In [None]:
def train(model, device, train_loader, optimizer, criterion):
    """Run one training epoch, print a summary, and return the epoch metrics.

    Args:
        model: Module to train; it is put into training mode.
        device: Device each batch is moved to. The model is assumed to already
            live there.
        train_loader: Iterable of ``(data, targets)`` batches.
        optimizer: Optimizer stepping the model's parameters.
        criterion: Loss called as ``criterion(outputs, targets)``.

    Returns:
        ``(average_loss, accuracy_percent)``. ``average_loss`` is the mean of the
        per-batch losses. Both values are ``0.0`` if the loader yields no batches.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    for data, targets in train_loader:
        data, targets = data.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    # Guard against an empty loader instead of raising ZeroDivisionError.
    train_loss = running_loss / num_batches if num_batches else 0.0
    accuracy = 100. * correct / total if total else 0.0
    print(f'\nTrain set: Average loss: {train_loss:.4f}, Accuracy: {correct}/{total} ({accuracy:.2f}%)')
    return train_loss, accuracy

def evaluate(model, device, test_loader, criterion):
    """Evaluate the model without gradients, print a summary, and return the metrics.

    Args:
        model: Module to evaluate; it is put into eval mode.
        device: Device each batch is moved to. The model is assumed to already
            live there.
        test_loader: Iterable of ``(data, targets)`` batches.
        criterion: Loss called as ``criterion(outputs, targets)``.

    Returns:
        ``(average_loss, accuracy_percent)``. ``average_loss`` is the mean of the
        per-batch losses. Both values are ``0.0`` if the loader yields no batches.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    with torch.no_grad():
        for data, targets in test_loader:
            data, targets = data.to(device), targets.to(device)
            outputs = model(data)
            running_loss += criterion(outputs, targets).item()
            num_batches += 1
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # Guard against an empty loader instead of raising ZeroDivisionError.
    test_loss = running_loss / num_batches if num_batches else 0.0
    accuracy = 100. * correct / total if total else 0.0
    print(f'Test set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{total} ({accuracy:.2f}%)')
    return test_loss, accuracy



In [None]:
# Main loop: one training pass and one test evaluation per epoch, then advance
# the LR schedule by a single step.
for epoch in range(1, epochs + 1):
    print(f'Epoch: {epoch}/{epochs}')
    train(model=model, device=device, train_loader=train_loader,
          optimizer=optimizer, criterion=criterion)
    evaluate(model=model, device=device, test_loader=test_loader,
             criterion=criterion)
    scheduler.step()
    print('--------------------------------')

Epoch: 1/200

Train set: Average loss: 1.6323, Accuracy: 20368/50000 (40.74%)
Test set: Average loss: 1.3504, Accuracy: 5136/10000 (51.36%)
--------------------------------
Epoch: 2/200

Train set: Average loss: 1.3105, Accuracy: 26500/50000 (53.00%)
Test set: Average loss: 1.1307, Accuracy: 6050/10000 (60.50%)
--------------------------------
Epoch: 3/200

Train set: Average loss: 1.1538, Accuracy: 29640/50000 (59.28%)
Test set: Average loss: 0.9693, Accuracy: 6576/10000 (65.76%)
--------------------------------
Epoch: 4/200

Train set: Average loss: 1.0473, Accuracy: 31445/50000 (62.89%)
Test set: Average loss: 0.9857, Accuracy: 6579/10000 (65.79%)
--------------------------------
Epoch: 5/200

Train set: Average loss: 0.9807, Accuracy: 32808/50000 (65.62%)
Test set: Average loss: 0.8996, Accuracy: 6889/10000 (68.89%)
--------------------------------
Epoch: 6/200

Train set: Average loss: 0.9286, Accuracy: 33736/50000 (67.47%)
Test set: Average loss: 0.8008, Accuracy: 7201/10000 (72.