In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# Step 1: Load the MNIST training split
# Normalize((0.5,), (0.5,)) applies (x - 0.5) / 0.5 to the single image channel.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

trainset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=4, shuffle=True)

# Step 2: Visualize the distribution of labels
def plot_label_distribution(dataset, num_classes=10):
    """Draw a bar chart of how many samples carry each class label.

    Args:
        dataset: Either an object exposing a ``targets`` attribute holding the
            integer labels, or any iterable yielding ``(sample, label)`` pairs.
        num_classes: Minimum number of bars to draw; labels are expected to be
            non-negative integers, normally in ``range(num_classes)``.
            Defaults to 10.
    """
    targets = getattr(dataset, 'targets', None)
    if targets is not None:
        # Fast path: read the stored labels directly instead of loading and
        # transforming every sample only to throw the image away.
        label_tensor = torch.as_tensor(targets).flatten().long()
        label_counts = torch.bincount(label_tensor, minlength=num_classes).tolist()
    else:
        # Fallback for datasets without a `targets` attribute (e.g. subsets).
        label_counts = [0] * num_classes
        for _, label in dataset:
            label_counts[int(label)] += 1

    # One tick per counted class so the bar positions and heights always match.
    labels = [str(i) for i in range(len(label_counts))]
    plt.bar(labels, label_counts)
    plt.xlabel('Labels')
    plt.ylabel('Count')
    plt.title('Distribution of Labels in MNIST Dataset')
    plt.show()

# Render the label histogram for the full training split.
plot_label_distribution(dataset=trainset)

# Step 3: Visualize samples from each class
def visualize_samples(dataset, num_samples=5, num_classes=10):
    """Show a grid with up to ``num_samples`` example images per class.

    Row ``i`` of the grid holds the first samples found whose label is ``i``.
    Cells for which no sample was found are left blank.

    Args:
        dataset: Iterable yielding ``(image, label)`` pairs; each image must
            support ``.squeeze()`` and be displayable by ``imshow``.
        num_samples: Number of columns, i.e. images shown per class.
        num_classes: Number of rows; labels must lie in ``range(num_classes)``.
            Defaults to 10.

    Raises:
        ValueError: If ``num_samples`` or ``num_classes`` is smaller than 1.
        KeyError: If the dataset yields a label outside ``range(num_classes)``.
    """
    if num_samples < 1 or num_classes < 1:
        raise ValueError("num_samples and num_classes must both be at least 1")

    class_samples = {i: [] for i in range(num_classes)}
    missing = num_classes * num_samples

    for image, label in dataset:
        bucket = class_samples[int(label)]
        if len(bucket) < num_samples:
            bucket.append(image)
            missing -= 1
            if missing == 0:
                # Every row is full; no need to load the rest of the dataset.
                break

    # squeeze=False keeps `axes` two-dimensional even for a single row/column.
    fig, axes = plt.subplots(num_classes, num_samples, figsize=(12, 12), squeeze=False)

    for i in range(num_classes):
        row_images = class_samples[i]
        for j in range(num_samples):
            ax = axes[i, j]
            ax.axis('off')
            # A class may have fewer than num_samples images; leave the cell empty.
            if j < len(row_images):
                ax.imshow(row_images[j].squeeze(), cmap='gray')

    fig.suptitle('Samples from Each Class')
    plt.show()

# Show five examples of every digit from the training split.
visualize_samples(dataset=trainset, num_samples=5)

# Step 4: Check for class imbalance
# Count straight from the stored labels (`trainset.targets`) instead of
# iterating the dataset, which would decode and transform every image again.
label_counts = torch.bincount(
    torch.as_tensor(trainset.targets).flatten().long(), minlength=10
).tolist()

print("Class Imbalance:")
for i, count in enumerate(label_counts):
    print(f"Class {i}: {count} samples")

# Step 5: Partition the dataset into train, validation, and test sets
SPLIT_SEED = 42   # fixed seed so the split is identical on every Restart & Run All
BATCH_SIZE = 64   # mini-batch size used by all three loaders

train_size = int(0.8 * len(trainset))
val_size = int(0.1 * len(trainset))
# The test split takes the remainder so the three sizes always sum to len(trainset).
test_size = len(trainset) - train_size - val_size
train_data, val_data, test_data = torch.utils.data.random_split(
    trainset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Wrap each subset in a DataLoader. The training cell further down uses
# `train_loader` and `val_loader`; only the training batches are shuffled.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

# Step 6: Define a function to visualize one feature map (channel) of model.features
def visualize_feature_maps(model, layer_idx, image):
    """Plot a single channel of ``model.features`` evaluated on one image.

    NOTE: a later cell defines another ``visualize_feature_maps`` that takes a
    layer module instead of an index; once that cell runs it replaces this
    function. The ``SimpleCNN`` class defined in that cell has no ``features``
    attribute, so it cannot be used with this version.

    Args:
        model: Module exposing a ``features`` sub-module that is applied to
            the batched image.
        layer_idx: Despite the name, this indexes dimension 1 of the
            ``features`` output (the channel / feature map), not a layer.
        image: A single un-batched image tensor; a leading batch dimension is
            added with ``unsqueeze(0)``.
    """
    model.eval()
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        feature_maps = model.features(image.unsqueeze(0))
    selected_feature_map = feature_maps[0, layer_idx]

    plt.figure(figsize=(8, 8))
    plt.imshow(selected_feature_map.detach().cpu(), cmap='viridis')
    plt.title(f'Feature Map {layer_idx}')
    plt.axis('off')
    plt.show()

# You'll need a trained CNN model that exposes a `features` sub-module for this step.
# Replace 'model' with your trained model and pass the index of the feature map (channel) to display.

# Example:
# visualize_feature_maps(model, 0, image)



In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# Define the CNN model
class SimpleCNN(nn.Module):
    """Two conv/ReLU/max-pool stages followed by a two-layer classifier head.

    ``fc1`` expects 64 * 7 * 7 input features, which matches 1-channel 28x28
    inputs: the padded 3x3 convolutions keep the spatial size and each 2x2
    max-pool halves it (28 -> 14 -> 7). The network returns 10 raw scores per
    sample; no softmax is applied.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 1 -> 32 channels, then halve the spatial size.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Stage 2: 32 -> 64 channels, then halve the spatial size again.
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Classifier head.
        self.fc1 = nn.Linear(in_features=64 * 7 * 7, out_features=128)
        self.relu3 = nn.ReLU()
        self.dropout = nn.Dropout(p=0.25)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Map a batch of images to a ``(batch, 10)`` tensor of class scores."""
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)

        x = self.conv2(x)
        x = self.relu2(x)
        x = self.maxpool2(x)

        # Collapse everything except the batch dimension before the linear layers.
        x = x.flatten(start_dim=1)
        x = self.fc1(x)
        x = self.relu3(x)
        x = self.dropout(x)
        return self.fc2(x)

# Build the network, the classification loss, and the optimizer that updates
# the network's parameters.
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Function to train the model
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10):
    """Train ``model`` and evaluate it on the validation data after every epoch.

    Args:
        model: Module called as ``model(inputs)``; the predicted class is the
            argmax of its output along dimension 1.
        train_loader: Iterable yielding ``(inputs, labels)`` batches used for
            optimisation. ``labels`` must be a tensor whose first dimension is
            the batch size.
        val_loader: Iterable yielding ``(inputs, labels)`` batches used only
            for evaluation (no gradient tracking, no parameter updates).
        criterion: Callable mapping ``(outputs, labels)`` to a scalar loss tensor.
        optimizer: Optimizer whose ``zero_grad`` / ``step`` are called once per
            training batch.
        num_epochs: Number of passes over ``train_loader``. Defaults to 10.

    Returns:
        ``(train_losses, val_losses, train_acc, val_acc)``: four lists with one
        entry per epoch. Losses are the mean of the per-batch loss values and
        accuracies are percentages. An entry is NaN when the corresponding
        loader yielded no batches in that epoch.
    """

    def _summarise(loss_sum, num_batches, num_correct, num_seen):
        # Guard both divisions so an empty loader yields NaN instead of raising
        # ZeroDivisionError.
        mean_loss = loss_sum / num_batches if num_batches else float('nan')
        accuracy = 100 * num_correct / num_seen if num_seen else float('nan')
        return mean_loss, accuracy

    train_losses, val_losses = [], []
    train_acc, val_acc = [], []

    for epoch in range(num_epochs):
        # ---- optimisation pass ----
        model.train()
        running_loss = 0.0
        train_batches = 0
        correct_train = 0
        total_train = 0

        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Batches are counted while iterating so loaders without __len__ work too.
            train_batches += 1
            predicted = outputs.argmax(dim=1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

        train_loss, train_accuracy = _summarise(
            running_loss, train_batches, correct_train, total_train
        )
        train_losses.append(train_loss)
        train_acc.append(train_accuracy)

        # ---- evaluation pass ----
        model.eval()
        val_loss_sum = 0.0
        val_batches = 0
        correct_val = 0
        total_val = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                outputs = model(inputs)
                val_loss_sum += criterion(outputs, labels).item()
                val_batches += 1
                predicted = outputs.argmax(dim=1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()

        val_loss, val_accuracy = _summarise(val_loss_sum, val_batches, correct_val, total_val)
        val_losses.append(val_loss)
        val_acc.append(val_accuracy)

        print(f'Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_losses[-1]:.4f}, Train Acc: {train_accuracy:.2f}%, '
              f'Val Loss: {val_losses[-1]:.4f}, Val Acc: {val_accuracy:.2f}%')

    return train_losses, val_losses, train_acc, val_acc

# Run training for ten epochs and keep the per-epoch loss/accuracy histories.
train_losses, val_losses, train_acc, val_acc = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=10,
)

# Plot the per-epoch histories side by side: loss on the left, accuracy on the right.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

loss_ax.plot(train_losses, label='Train Loss')
loss_ax.plot(val_losses, label='Val Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.legend()

acc_ax.plot(train_acc, label='Train Acc')
acc_ax.plot(val_acc, label='Val Acc')
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Accuracy (%)')
acc_ax.legend()

plt.show()

# Function to display feature maps
def visualize_feature_maps(model, layer, image):
    """Plot every channel that ``layer`` produces for a single image.

    The image is run through the whole ``model`` and the output of ``layer``
    is captured with a forward hook, so ``layer`` can be any sub-module that
    ``model`` calls during its forward pass, not only the first one.

    Args:
        model: The network to run; it is switched to eval mode.
        layer: Sub-module of ``model`` whose output is a 4-D tensor indexed as
            ``[batch, channel, row, column]``.
        image: A single un-batched image tensor; a leading batch dimension is
            added with ``unsqueeze(0)``.

    Raises:
        RuntimeError: If ``layer`` was not called while running ``model``.
        ValueError: If the captured output is not a 4-D tensor.
    """
    model.eval()

    captured = []
    # list.append returns None, so the hook records the output without replacing it.
    handle = layer.register_forward_hook(
        lambda module, inputs, output: captured.append(output)
    )
    try:
        with torch.no_grad():
            model(image.unsqueeze(0))
    finally:
        # Always detach the hook so repeated calls do not stack hooks on the layer.
        handle.remove()

    if not captured:
        raise RuntimeError("layer was not called during the model's forward pass")
    activations = captured[0]
    if not torch.is_tensor(activations) or activations.dim() != 4:
        raise ValueError("layer output must be a 4-D tensor to be shown as feature maps")

    num_features = activations.size(1)
    # Size the grid from the channel count: at most 8 columns, as many rows as needed.
    cols = max(1, min(8, num_features))
    rows = -(-num_features // cols)  # ceiling division

    plt.figure(figsize=(cols, max(rows, 1)))
    for i in range(num_features):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(activations[0, i].detach().cpu(), cmap='viridis')
        plt.axis('off')
    plt.show()

# Take the first example of the validation dataset and display the feature
# maps that a specific layer (here conv1) produces for it.
sample_image, sample_label = val_loader.dataset[0]
visualize_feature_maps(model=model, layer=model.conv1, image=sample_image)
