In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Basic block for ResNet
class BasicBlock(nn.Module):
    """Two-convolution residual block (the ResNet-18/34 "basic" block).

    Residual branch: 3x3 conv (carries ``stride``) -> BN -> ReLU -> 3x3 conv -> BN.
    The shortcut is the input itself, or ``downsample(x)`` when a projection
    module is supplied. The two are summed and passed through a final ReLU.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.
        stride: stride of the first convolution; values > 1 shrink H/W.
        downsample: optional module applied to the input so that the shortcut
            matches the residual branch's shape; ``None`` means identity.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        # Creation order is kept as conv1, bn1, relu, conv2, bn2 so that seeded
        # weight initialisation and state_dict keys stay the same.
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,  # BN right after makes a conv bias redundant
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)

        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))

        # In-place add mirrors the original; BN's output is not needed for backward.
        branch += shortcut
        return self.relu(branch)

# ResNet18 model definition
class ResNet18Modified(nn.Module):
    """ResNet-18 classifier with a configurable number of input channels.

    Layout: 7x7/2 conv stem + BN + ReLU + 3x3/2 max-pool, four stages of two
    ``BasicBlock``s (64, 128, 256, 512 channels; stages 2-4 use stride 2),
    global average pooling, and a linear head.

    With 32x32 inputs the stem reduces the feature map to 8x8 before
    ``layer1``, and ``layer4`` produces a 1x1 map.

    Args:
        num_classes: size of the output logit vector.
        in_channels: channels of the input images (e.g. 1 for grayscale).
    """

    # (stage attribute suffix, in_channels, out_channels, stride)
    _STAGES = (
        (1, 64, 64, 1),
        (2, 64, 128, 2),
        (3, 128, 256, 2),
        (4, 256, 512, 2),
    )

    def __init__(self, num_classes=10, in_channels=3):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # setattr on an nn.Module registers the submodule exactly like
        # ``self.layerN = ...``; stages are built in order layer1..layer4.
        for idx, cin, cout, step in self._STAGES:
            setattr(self, f"layer{idx}", self._make_layer(cin, cout, 2, stride=step))

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, in_channels, out_channels, blocks, stride=1):
        """Build one stage: a (possibly strided/projecting) block followed by ``blocks - 1`` plain blocks."""
        needs_projection = stride != 1 or in_channels != out_channels
        # The projection is created before the first block, matching the original
        # initialisation order.
        shortcut = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
            if needs_projection
            else None
        )
        head = BasicBlock(in_channels, out_channels, stride, shortcut)
        tail = [BasicBlock(out_channels, out_channels) for _ in range(blocks - 1)]
        return nn.Sequential(head, *tail)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        pooled = torch.flatten(self.avgpool(x), 1)
        return self.fc(pooled)

# Training and evaluation function
def train_and_evaluate(model, train_loader, test_loader, num_epochs=20, learning_rate=0.001):
    """Train ``model`` with Adam and cross-entropy, then measure test accuracy.

    The model is moved (in place) to CUDA when available, otherwise CPU.
    After each epoch the mean per-batch training loss is printed; after
    training the test accuracy is printed and returned.

    Args:
        model: classifier module; its output is fed to ``nn.CrossEntropyLoss``
            and arg-maxed over dim 1, so presumably (batch, num_classes) logits.
        train_loader: iterable of ``(inputs, labels)`` batches for training.
        test_loader: iterable of ``(inputs, labels)`` batches for evaluation.
        num_epochs: number of passes over ``train_loader``.
        learning_rate: Adam learning rate.

    Returns:
        Test accuracy in percent (0-100), or ``float('nan')`` when
        ``test_loader`` yields no samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Count batches ourselves: avoids dividing by zero on an empty loader
        # and works for loaders without ``len()``.
        mean_loss = running_loss / num_batches if num_batches else float('nan')
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss:.4f}')

    # Evaluation
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # An empty test set has no defined accuracy; report NaN instead of crashing.
    accuracy = 100 * correct / total if total else float('nan')
    print(f'Test Accuracy: {accuracy:.2f}%')
    return accuracy

# Loading and transforming datasets
def load_data(dataset_name, batch_size=64, root='./data'):
    """Build train/test DataLoaders for a supported torchvision dataset.

    Images are resized to 32x32 and converted to tensors (no normalisation).
    Datasets are downloaded into ``root`` if not already present.

    Args:
        dataset_name: one of ``'MNIST'``, ``'FMNIST'``, ``'CIFAR-10'``.
        batch_size: batch size for both loaders (default 64).
        root: dataset storage directory (default ``'./data'``).

    Returns:
        ``(train_loader, test_loader, in_channels)``; the training loader
        shuffles, the test loader does not. ``in_channels`` is 1 for MNIST /
        FashionMNIST and 3 for CIFAR-10.

    Raises:
        ValueError: if ``dataset_name`` is not supported.
    """
    # public name -> (torchvision.datasets class name, image channels)
    specs = {
        'MNIST': ('MNIST', 1),
        'FMNIST': ('FashionMNIST', 1),
        'CIFAR-10': ('CIFAR10', 3),
    }
    # Validate before building transforms or touching the filesystem.
    try:
        class_name, in_channels = specs[dataset_name]
    except KeyError:
        raise ValueError("Dataset not supported") from None
    dataset_cls = getattr(torchvision.datasets, class_name)

    transform = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor()])
    train_dataset = dataset_cls(root=root, train=True, transform=transform, download=True)
    test_dataset = dataset_cls(root=root, train=False, transform=transform, download=True)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader, in_channels

# Running training for each dataset.
# Guarded so that importing this module does not trigger dataset downloads and
# training; when run as a script or notebook cell ``__name__`` is "__main__",
# so behaviour there is unchanged.
if __name__ == "__main__":
    for dataset_name in ['MNIST', 'FMNIST', 'CIFAR-10']:
        print(f'\nTraining on {dataset_name}')
        train_loader, test_loader, in_channels = load_data(dataset_name)
        model = ResNet18Modified(num_classes=10, in_channels=in_channels)
        train_and_evaluate(model, train_loader, test_loader)



Training on MNIST
Epoch [1/20], Loss: 0.1325
Epoch [2/20], Loss: 0.0598
Epoch [3/20], Loss: 0.0484
Epoch [4/20], Loss: 0.0397
Epoch [5/20], Loss: 0.0351
Epoch [6/20], Loss: 0.0301
Epoch [7/20], Loss: 0.0261
Epoch [8/20], Loss: 0.0252
Epoch [9/20], Loss: 0.0225
Epoch [10/20], Loss: 0.0190
Epoch [11/20], Loss: 0.0175
Epoch [12/20], Loss: 0.0153
Epoch [13/20], Loss: 0.0151
Epoch [14/20], Loss: 0.0136
Epoch [15/20], Loss: 0.0104
Epoch [16/20], Loss: 0.0110
Epoch [17/20], Loss: 0.0098
Epoch [18/20], Loss: 0.0098
Epoch [19/20], Loss: 0.0088
Epoch [20/20], Loss: 0.0073
Test Accuracy: 99.04%

Training on FMNIST
Epoch [1/20], Loss: 0.4241
Epoch [2/20], Loss: 0.3022
Epoch [3/20], Loss: 0.2610
Epoch [4/20], Loss: 0.2397
Epoch [5/20], Loss: 0.2201
Epoch [6/20], Loss: 0.2041
Epoch [7/20], Loss: 0.1874
Epoch [8/20], Loss: 0.1773
Epoch [9/20], Loss: 0.1559
Epoch [10/20], Loss: 0.1417
Epoch [11/20], Loss: 0.1275
Epoch [12/20], Loss: 0.1173
Epoch [13/20], Loss: 0.1048
Epoch [14/20], Loss: 0.0976
Epoch