In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import Subset
import numpy as np
import matplotlib.pyplot as plt
from torch.autograd import Variable
import torch.optim as optim
%matplotlib inline

def create_imbalanced_mnist(dataset, min_per_class=30, max_per_class=100, seed=None):
    """
    Builds a class-capped subset of a labelled 10-class dataset (written for MNIST).

    For every label 0-9 the indices of all samples carrying that label are collected, then:

    * if the class has more than ``max_per_class`` samples, ``max_per_class`` of them are drawn
      at random without replacement;
    * otherwise, if the class has fewer than ``min_per_class`` samples, it is left out entirely;
    * otherwise every sample of the class is kept.

    Note that every class larger than ``max_per_class`` ends up with exactly ``max_per_class``
    samples, so the result is only imbalanced when some classes hold between ``min_per_class``
    and ``max_per_class`` samples.

    Args:
        dataset: A dataset exposing a ``targets`` sequence (list, NumPy array or tensor) of integer labels,
            e.g. ``torchvision.datasets.MNIST``.
        min_per_class (int, optional): Classes with fewer samples than this are excluded. Defaults to 30.
        max_per_class (int, optional): Classes with more samples than this are randomly downsampled
            to this size. Defaults to 100.
        seed (int, optional): Seed for a dedicated NumPy generator so the selection is reproducible.
            When ``None`` (the default) NumPy's global random state is used.

    Returns:
        torch.utils.data.Subset: A view of ``dataset`` restricted to the selected indices, grouped by label
        in ascending label order.
    """
    # Convert the labels once instead of once per class.
    targets = np.asarray(dataset.targets)

    # A dedicated generator keeps seeded runs reproducible without touching NumPy's global state.
    choose = np.random.choice if seed is None else np.random.default_rng(seed).choice

    indices = []
    for label in range(10):
        class_indices = np.where(targets == label)[0]
        if len(class_indices) > max_per_class:
            class_indices = choose(class_indices, max_per_class, replace=False)
        elif len(class_indices) < min_per_class:
            continue
        # Store plain Python ints so the subset indexes any dataset type cleanly.
        indices.extend(class_indices.tolist())
    return Subset(dataset, indices)

# Preprocessing applied to every image: convert to a tensor, then normalise the
# single channel with mean 0.1307 and standard deviation 0.3081.
transformation = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Training split first, then the test split; both are downloaded into data/ if missing.
full_train_dataset, full_test_dataset = (
    datasets.MNIST(root='data/', train=is_train, transform=transformation, download=True)
    for is_train in (True, False)
)

# Cap the number of samples per class (training split is sampled before the test split).
imbalanced_train_dataset, imbalanced_test_dataset = map(
    create_imbalanced_mnist, (full_train_dataset, full_test_dataset)
)

# Both loaders serve shuffled mini-batches of 32 samples.
train_loader, test_loader = (
    torch.utils.data.DataLoader(dataset=split, batch_size=32, shuffle=True)
    for split in (imbalanced_train_dataset, imbalanced_test_dataset)
)

class ResidualBlock(nn.Module):
    """
    Two-convolution residual unit with a skip connection.

    The main path is ``conv3x3 -> BatchNorm -> ReLU -> conv3x3 -> BatchNorm``. Its result is added to the
    shortcut branch and passed through a final ReLU. The shortcut is the identity when the block keeps both
    the channel count and the spatial resolution; otherwise it is a strided 1x1 convolution followed by
    batch normalization so that both branches have matching shapes.

    Args:
        in_channels (int): Number of channels of the incoming feature map.
        out_channels (int): Number of channels produced by the block.
        stride (int, optional): Stride of the first 3x3 convolution and of the 1x1 shortcut convolution.
            Defaults to 1.

    Attributes:
        conv1 (torch.nn.Conv2d): First 3x3 convolution; the only one that applies ``stride``.
        bn1 (torch.nn.BatchNorm2d): Batch normalization after ``conv1``.
        relu (torch.nn.ReLU): In-place ReLU, reused after ``bn1`` and after the residual sum.
        conv2 (torch.nn.Conv2d): Second 3x3 convolution (stride 1).
        bn2 (torch.nn.BatchNorm2d): Batch normalization after ``conv2``.
        downsample (torch.nn.Sequential): Shortcut branch; empty (identity) when no projection is required.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        """
        Creates the main-path layers and, when needed, the projection shortcut.

        Args:
            in_channels (int): Number of channels of the incoming feature map.
            out_channels (int): Number of channels produced by the block.
            stride (int, optional): Stride of the first convolution and of the shortcut. Defaults to 1.
        """
        super().__init__()

        # Main path.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut: an empty Sequential acts as the identity; a projection is only
        # built when the main path changes the channel count or the resolution.
        shape_is_preserved = stride == 1 and in_channels == out_channels
        if shape_is_preserved:
            shortcut_layers = []
        else:
            shortcut_layers = [
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels),
            ]
        self.downsample = nn.Sequential(*shortcut_layers)

    def forward(self, x):
        """
        Runs the block on a batch of feature maps.

        Args:
            x (torch.Tensor): Input of shape (N, in_channels, H, W).

        Returns:
            torch.Tensor: ``relu(main_path(x) + shortcut(x))`` with ``out_channels`` channels.
        """
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        residual += self.downsample(x)
        return self.relu(residual)

class Net(nn.Module):
    """
    A small ResNet-style classifier for single-channel images with 10 output classes.

    The stem reuses parts of a pretrained ResNet18: its batch-norm, ReLU and max-pool layers are taken as-is,
    while the first convolution is a new 1-input-channel 7x7 convolution whose weights are the mean of the
    pretrained RGB filters over the colour axis (its bias is a freshly initialised parameter). The stem is
    followed by two residual blocks and two fully connected layers.

    Args:
        input_size (tuple, optional): ``(height, width)`` of the input images. It is only used to size
            ``fc1`` from a probe pass through the convolutional stack. Defaults to ``(28, 28)``.

    Attributes:
        initial_layers (torch.nn.Sequential): Stem: 7x7 stride-2 convolution, batch norm, ReLU, max-pool.
        residual_block1 (ResidualBlock): Residual block with 64 input and 64 output channels.
        residual_block2 (ResidualBlock): Residual block with 64 input and 128 output channels, stride 2.
        fc1 (torch.nn.Linear): Maps the flattened convolutional features to 50 hidden units.
        fc2 (torch.nn.Linear): Maps the 50 hidden units to the 10 class scores.

    Methods:
        forward(x):
            Defines the forward pass of the network.

        _get_conv_output(shape):
            Helper function to determine the size of the input to the fully connected layer.
    """

    def __init__(self, input_size=(28, 28)):
        """
        Initializes the stem, the residual blocks and the fully connected layers.

        Args:
            input_size (tuple, optional): ``(height, width)`` of the input images. Defaults to ``(28, 28)``.
        """
        super().__init__()

        resnet = models.resnet18(pretrained=True)

        self.initial_layers = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),  # Change input channels to 1
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
        )

        # Collapse the pretrained RGB filters to one channel by averaging over the colour axis.
        self.initial_layers[0].weight.data = resnet.conv1.weight.data.mean(dim=1, keepdim=True)

        self.residual_block1 = ResidualBlock(64, 64)
        self.residual_block2 = ResidualBlock(64, 128, stride=2)

        # Size fc1 from the real output of the convolutional stack rather than a hand-computed
        # constant, so the layer always matches what forward() feeds into it.
        height, width = input_size
        n_features = self._get_conv_output((1, 1, height, width))
        self.fc1 = nn.Linear(n_features, 50)
        self.fc2 = nn.Linear(50, 10)

    def _forward_features(self, x):
        """
        Runs the stem and both residual blocks and flattens the result per sample.

        Args:
            x (torch.Tensor): Input tensor of shape (N, 1, H, W).

        Returns:
            torch.Tensor: Feature tensor of shape (N, n_features).
        """
        x = self.initial_layers(x)
        x = self.residual_block1(x)
        x = self.residual_block2(x)
        return torch.flatten(x, start_dim=1)

    def forward(self, x):
        """
        Defines the forward pass of the network.

        Args:
            x (torch.Tensor): The input tensor with shape (N, 1, H, W), where N is the batch size,
                              and H and W are the height and width of the input images.

        Returns:
            torch.Tensor: Tensor of shape (N, 10) holding the log-probabilities of each class (log softmax
            over dimension 1).
        """
        x = self._forward_features(x)
        x = F.relu(self.fc1(x))
        # Functional dropout has no module state, so it must be told whether the model is training.
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

    def _get_conv_output(self, shape):
        """
        Helper function to dynamically determine the input size for the fully connected layer.

        A zero-filled probe of the given shape is pushed through the stem and the residual blocks. The pass
        runs in evaluation mode with gradients disabled so that it neither updates the batch-norm running
        statistics nor records an autograd graph; every module's training flag is restored afterwards.

        Args:
            shape (tuple): Shape of the probe tensor, ``(N, 1, H, W)``.

        Returns:
            int: The number of features per sample after the convolutional layers and residual blocks.
        """
        # Create the probe where the weights live so the helper also works after .to(device).
        reference = next(self.parameters(), None)
        if reference is None:
            probe = torch.zeros(shape)
        else:
            probe = torch.zeros(shape, device=reference.device, dtype=reference.dtype)

        # Remember each module's own mode: sub-modules may have been switched individually.
        previous_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                n_features = self._forward_features(probe).size(1)
        finally:
            for module, mode in previous_modes:
                module.training = mode
        return n_features

def fit(epoch, model, data_loader, phase='training', optimizer=None):
    """
    Trains or validates the model for one epoch.

    In the 'training' phase the model is put in training mode and, for every batch, the gradients are reset,
    the negative log-likelihood loss is back-propagated and the optimizer takes one step. In the 'validation'
    phase the model is put in evaluation mode and the batches are only scored, with gradients disabled.
    A one-line summary of the loss and accuracy is printed at the end of the epoch.

    Args:
        epoch (int): The current epoch number. Not used by the function itself; kept for the callers.
        model (torch.nn.Module): The model to be trained or validated. Its output is treated as
            per-class log-probabilities of shape (N, num_classes).
        data_loader (torch.utils.data.DataLoader): Iterable yielding ``(data, target)`` batches.
        phase (str, optional): Either 'training' or 'validation'. Defaults to 'training'.
        optimizer (torch.optim.Optimizer, optional): Optimizer used in the 'training' phase. When omitted,
            the module-level variable named ``optimizer`` is used, as older callers expect.

    Returns:
        tuple: A tuple containing:
            - loss (float): The loss averaged over all samples seen in this epoch.
            - accuracy (float): The percentage of samples seen in this epoch that were classified correctly.

    Raises:
        ValueError: If ``phase`` is not 'training' or 'validation', if no optimizer is available for
            training, or if ``data_loader`` yields no samples.
    """
    if phase not in ('training', 'validation'):
        raise ValueError(f"phase must be 'training' or 'validation', got {phase!r}")
    is_training = phase == 'training'

    if is_training and optimizer is None:
        # Backward compatibility: fall back to the module-level optimizer.
        optimizer = globals().get('optimizer')
        if optimizer is None:
            raise ValueError("the 'training' phase requires an optimizer")

    model.train(is_training)

    running_loss = 0.0
    running_correct = 0
    num_samples = 0

    with torch.set_grad_enabled(is_training):
        for data, target in data_loader:
            if is_training:
                optimizer.zero_grad()

            output = model(data)
            loss = F.nll_loss(output, target)

            # nll_loss returns the batch mean; weight it by the batch size so that the epoch
            # total divided by the sample count is a true per-sample average, even when the
            # last batch is smaller than the others.
            batch_size = target.size(0)
            running_loss += loss.item() * batch_size
            running_correct += (output.argmax(dim=1) == target).sum().item()
            num_samples += batch_size

            if is_training:
                loss.backward()
                optimizer.step()

    if num_samples == 0:
        raise ValueError('data_loader yielded no samples')

    loss = running_loss / num_samples
    accuracy = 100.0 * running_correct / num_samples

    print(f'{phase} loss is {loss:{5}.{2}} and {phase} accuracy is {running_correct}/{num_samples} {accuracy:{10}.{4}}')
    return loss, accuracy

# Build the network, then size its first fully connected layer from a probe pass
# through the convolutional stack for a single 1x28x28 input.
model = Net()
n_features = model._get_conv_output((1, 1, 28, 28))
model.fc1 = nn.Linear(in_features=n_features, out_features=50)

# The optimizer is created after fc1 is replaced so that it tracks the final parameters.
optimizer = optim.SGD(params=model.parameters(), lr=0.01, momentum=0.5)

# Per-epoch history of both phases.
train_losses, train_accuracy = [], []
val_losses, val_accuracy = [], []

# range(1, 20) runs epochs 1 through 19; each epoch trains first and validates second.
for epoch in range(1, 20):
    train_epoch_loss, train_epoch_accuracy = fit(epoch, model, train_loader, phase='training')
    val_epoch_loss, val_epoch_accuracy = fit(epoch, model, test_loader, phase='validation')
    for history, value in (
        (train_losses, train_epoch_loss),
        (train_accuracy, train_epoch_accuracy),
        (val_losses, val_epoch_loss),
        (val_accuracy, val_epoch_accuracy),
    ):
        history.append(value)

# One figure per metric: training values as blue dots, validation values as a red line.
for title, train_curve, val_curve in (
    ('Loss', train_losses, val_losses),
    ('Accuracy', train_accuracy, val_accuracy),
):
    plt.plot(range(1, len(train_curve) + 1), train_curve, 'bo', label='training')
    plt.plot(range(1, len(val_curve) + 1), val_curve, 'r', label='validation')
    plt.title(title)
    plt.legend()
    plt.show()
