# CS484 Final Project: Semi-Supervised Learning on CIFAR-10

Student Name: Stone Hu

## Team Members and Contributions:

| Name                | Email              | Student Number | Contributions: |
|---------------------|--------------------|----------------|----------------|
| Stone Hu            | ds2hu@uwaterloo.ca | 20890769       | Research/implementation for loss functions, model development, explanation of loss functions and gathered conclusions |
| Justin Metivier     | jmetivie@uwaterloo.ca | 20874949    | Development of generic model for loss function development and MNIST testing, custom data loader for semi-labeled data, model training, helped retrofit models to have specified loss functions |
| Andrew Batek        | abatek@uwaterloo.ca | 20892302      | Preliminary paper review, LeNet implementation, hyperparameter tuning and model training, training result aggregation |
| Matthew Erxleben    | merxlebe@uwaterloo.ca | 20889980    | Preliminary research on model architecture, ResNet implementation, jupyter notebook tables/graphs/notes |

## Code Libraries:

- PyTorch: Model construction and training
- PyTorch Lightning: Data loading tools
- Torchvision: Dataset sourcing
- TensorBoard: Tracking training and visualizing performance

## Project Topic 5: Semi-Supervised Image Classification

## Abstract:

The overarching theme of this project is the implementation of semi-supervised learning across different CNN architectures with different semi-supervised losses. We had two main goals: (1) to understand the implementation and effectiveness of different semi-supervised techniques and loss functions, and (2) to compare the robustness in prediction accuracy of different models across different ratios of labeled to unlabeled data.

We focused on 2 different model architectures: an implementation of LeNet and an implementation of ResNet. Both were modified to use the traditional supervised cross-entropy loss in conjunction with different unsupervised loss techniques, so that labeled and unlabeled data can be handled simultaneously. These loss techniques are mutual information loss, entropy minimization loss, consistency regularization loss, and K-Means clustering loss.

This project compares the accuracy of LeNet and ResNet using all four of these loss functions on the CIFAR-10 Dataset.

## Project Overview:

We implemented 3 different models. The first is a generic CNN architecture, which we trained with cross entropy as the supervised loss and a variety of unsupervised loss techniques. We tested Mutual Information loss (sourced from Bridle, Heading & MacKay, "Unsupervised Classifiers, Mutual Information and Phantom Targets", NIPS 1991), Entropy Minimization loss, and Consistency Regularization loss, and chose the best performing of the three for our final model architecture. This model served mainly as a testing ground for developing the loss functions and doing preliminary research on their effectiveness.

We then chose two preexisting CNN architectures to modify using our findings from the first model. The first of these is an implementation of LeNet (Y. Lecun, L. Bottou, Y. Bengio and P. Haffner, "Gradient-based learning applied to document recognition," in Proceedings of the IEEE, Nov. 1998). The second is an implementation of ResNet (He, Kaiming, X. Zhang, Shaoqing Ren and Jian Sun. “Deep Residual Learning for Image Recognition.” 2016 IEEE Conference on Computer Vision and Pattern Recognition). In addition to the three aforementioned loss functions, we also implemented a K-means loss function using the deep features of the labeled data for these two models. Our project compares the effectiveness of the different loss functions on these two models, as well as their robustness across different ratios of labeled to unlabeled data.

We initially chose MNIST as our dataset because it was easier to develop around. After finding good results across all models, we modified each model to be trained on CIFAR-10, as we hoped this would accentuate any differences between models, loss functions, and label ratios.

## Models:

#### Baseline CNN:

This model aimed to implement a simple, no-frills CNN architecture to see if there were any advantages to a simpler model when handling a mix of labeled and unlabeled data. It consists of 2 convolutional layers followed by a linear layer. We also used it as a sandbox to implement the different loss functions and ensure they were performing as expected. It performed quite well on MNIST; however, we decided not to pursue it on CIFAR-10, as the other two models were better overall. We also expected that the limited depth of the architecture would lead to subpar results on RGB images.
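
As a rough illustration of the architecture described above, the following sketch builds a network with two convolutional layers followed by one linear layer for 28x28 MNIST images. The class name `BaselineCNN`, the channel counts, the kernel sizes, and the pooling layers are assumptions made for this example and are not taken from our `mylibs` code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BaselineCNN(nn.Module):
    """Two conv layers + one linear layer (all sizes are illustrative assumptions)."""

    def __init__(self, in_channels: int = 1, num_classes: int = 10, image_size: int = 28):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        # Two 2x2 max-pools halve the spatial size twice (28 -> 14 -> 7).
        feat_size = image_size // 4
        self.fc = nn.Linear(32 * feat_size * feat_size, num_classes)

    def features(self, x: torch.Tensor) -> torch.Tensor:
        """Return the flattened deep features that feed the linear layer."""
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        return torch.flatten(x, start_dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(self.features(x))


model = BaselineCNN()
logits = model(torch.randn(4, 1, 28, 28))  # a batch of 4 fake MNIST images
print(logits.shape)  # torch.Size([4, 10])
```

Exposing a separate `features` method is a convenience for losses that operate on deep features rather than on the class logits.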


#### LeNet:

We employed LeNet, one of the first convolutional neural network architectures, introduced by Yann LeCun et al. in 1998. With only 5 layers, LeNet is relatively shallow and was designed for small grayscale images. Its simplicity makes it less suited to larger RGB inputs, but it is a good demonstration of early convolutional neural networks.
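
For reference, a LeNet-style network adapted to 32x32 RGB inputs can be sketched as follows. This is a common modern variant and not necessarily identical to our `mylibs.lenet.LeNet`: the name `LeNetSketch` is hypothetical, and the ReLU activations and max-pooling are assumed substitutions for the tanh-style activations and subsampling layers of the 1998 paper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LeNetSketch(nn.Module):
    """LeNet-5 style layout: 2 conv layers + 3 fully connected layers."""

    def __init__(self, in_channels: int = 3, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 6, kernel_size=5)  # 32x32 -> 28x28
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)           # 14x14 -> 10x10
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)  # -> (B, 6, 14, 14)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)  # -> (B, 16, 5, 5)
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


lenet = LeNetSketch()
lenet_logits = lenet(torch.randn(4, 3, 32, 32))  # a batch of 4 fake CIFAR-10 images
print(lenet_logits.shape)  # torch.Size([4, 10])
```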


#### ResNet:
We also implemented ResNet, a popular neural network architecture introduced in 2015 that has become the backbone for many modern vision models. ResNet’s key innovation is the use of residual/skip connections, which allow gradients to flow through many more layers without vanishing. In our project, we used ResNet18, the smallest variant with 18 layers, chosen to balance model capacity with our computational constraints. Due to the skip connections, ResNet18 handles larger RGB images well and allows for exploring deeper versions of ResNet in future work.
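
To make the skip connection concrete, here is a minimal sketch of a ResNet-style basic block, in which the input is added back to the output of two convolutions. The class name `BasicBlockSketch` and the example shapes are assumptions for illustration; this is not copied from our `mylibs.resnet_test` module.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlockSketch(nn.Module):
    """Two 3x3 convolutions whose output is added to a shortcut of the input."""

    def __init__(self, in_ch: int, out_ch: int, stride: int = 1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)
        # Identity shortcut by default; a 1x1 conv matches shapes when they change.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_ch != out_ch:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # The skip connection: gradients can flow through the addition unchanged.
        return F.relu(out + self.shortcut(x))


block = BasicBlockSketch(64, 128, stride=2)
block_out = block(torch.randn(2, 64, 32, 32))
print(block_out.shape)  # torch.Size([2, 128, 16, 16])
```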


## Loss Functions:

For all our models, we used the standard cross-entropy loss as our loss function for labeled (supervised) data, and augmented this loss with a variety of unsupervised loss functions, comparing the effectiveness of each by using each one in isolation with our supervised loss. Our 4 loss functions were KMeans clustering loss, mutual information loss, entropy minimization loss, and consistency regularization loss. Each of these loss functions is motivated by an assumption that we make about our dataset.
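
The way a supervised and an unsupervised term can be combined is sketched below, assuming the total loss is the cross-entropy on labeled data plus a weight `lam` times the unsupervised loss on unlabeled data. The function name `semi_supervised_loss`, the toy linear model, and the placeholder unsupervised loss are hypothetical; only the weighting parameter `lam` (set to 0.1 in our training cells) comes from the notebook.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def semi_supervised_loss(model, x_labeled, y_labeled, x_unlabeled, unsup_loss_fn, lam=0.1):
    """Cross-entropy on labeled data plus a weighted unsupervised term."""
    supervised = F.cross_entropy(model(x_labeled), y_labeled)
    unsupervised = unsup_loss_fn(model(x_unlabeled))
    return supervised + lam * unsupervised


def placeholder_unsup_loss(logits):
    # Stand-in only: any function mapping unlabeled logits to a scalar fits here.
    return logits.pow(2).mean()


torch.manual_seed(0)
toy_model = nn.Linear(8, 10)  # toy stand-in for a CNN
x_l, y_l = torch.randn(4, 8), torch.randint(0, 10, (4,))
x_u = torch.randn(16, 8)

loss = semi_supervised_loss(toy_model, x_l, y_l, x_u, placeholder_unsup_loss, lam=0.1)
loss.backward()  # gradients from both terms reach the model parameters
```

Setting `lam=0` recovers purely supervised training, which makes the weight a convenient knob for checking how much the unsupervised term contributes.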

In KMeans clustering loss, we add an additional set of parameters, `cluster_centers`, to each model. The idea behind this is that when we extract the deep features of an image using a CNN, the points representing the deep features of samples from the same class (e.g. two images of airplanes) should end up close to each other. We therefore assume that the deep features for inputs of the same class will naturally form clusters, which we model using KMeans clusters. We set the number of clusters to the number of output classes that the network is trying to predict, and the dimension of each cluster center to the dimension of the network's deep feature vector. Then, during training, we calculate the KMeans loss between each deep feature and its closest cluster center, and backpropagate this loss to update both the weights of the convolutional network and the cluster center parameters.
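
A minimal sketch of such a clustering loss is shown below, assuming the loss is the mean squared Euclidean distance from each feature vector to its nearest learnable center. The module name `KMeansLoss`, the random initialization, and the feature size of 512 are illustrative assumptions rather than our exact implementation.

```python
import torch
import torch.nn as nn


class KMeansLoss(nn.Module):
    """Mean squared distance from each deep feature to its nearest cluster center."""

    def __init__(self, num_clusters: int, feature_dim: int):
        super().__init__()
        # Learnable centers, updated by backpropagation together with the network.
        self.cluster_centers = nn.Parameter(torch.randn(num_clusters, feature_dim))

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # Pairwise differences between features (B, D) and centers (K, D): (B, K, D).
        diffs = features.unsqueeze(1) - self.cluster_centers.unsqueeze(0)
        sq_dists = diffs.pow(2).sum(dim=2)  # squared Euclidean distances, (B, K)
        nearest, _ = sq_dists.min(dim=1)    # distance to the closest center, (B,)
        return nearest.mean()


kmeans_loss = KMeansLoss(num_clusters=10, feature_dim=512)
features = torch.randn(32, 512, requires_grad=True)  # stand-in for CNN deep features
loss = kmeans_loss(features)
loss.backward()  # gradients flow to both the features and the cluster centers
```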

In entropy minimization loss, we make no modifications to our network, but directly calculate an additional loss for unlabeled data in the training loop. The idea behind entropy minimization is to directly use the formula for entropy as a loss function, and minimize it to reduce uncertainty in the network’s output probability vectors (i.e. make them closer to one-hot). The motivation for this is that elements in our dataset have a clear correct class, and that we want to encourage our model to make more confident (low-entropy) classifications. When combined with labeled training data, the network is more likely to make correct guesses, and entropy minimization loss pushes the network to assign higher probabilities to one single output class (which is hopefully the correct one). To implement this, we take the entropy of the softmax probability vector from the network and add it to our cross entropy loss, and backpropagate this through the entire network.
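
The entropy term described above can be sketched in a few lines. The function name `entropy_loss` is hypothetical; the computation is the standard Shannon entropy of the softmax output, averaged over the batch.

```python
import torch
import torch.nn.functional as F


def entropy_loss(logits: torch.Tensor) -> torch.Tensor:
    """Average entropy H(p) = -sum_k p_k log p_k of the softmax predictions."""
    log_p = F.log_softmax(logits, dim=1)  # numerically stable log-probabilities
    p = log_p.exp()
    return -(p * log_p).sum(dim=1).mean()


# Uniform predictions have maximal entropy, ln(10) for 10 classes (about 2.30).
uniform_entropy = entropy_loss(torch.zeros(4, 10))

# A near one-hot prediction has entropy close to 0.
confident_logits = torch.zeros(1, 10)
confident_logits[0, 0] = 100.0
confident_entropy = entropy_loss(confident_logits)
```

Minimizing this quantity on unlabeled batches pushes predictions from the first case toward the second.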

In mutual information loss, we make no direct modifications to our network, but extract the deep features of the CNN and measure the correlation between those and the input features. The motivation for mutual information loss is that there should be a strong correlation (i.e. high shared information) between the predicted class and the input image, so the loss function is usually formulated on the output softmax probability vector based on this assumption. For our implementation, we expanded on this loss by extracting the deep features of the CNN instead, since we believed there would be a stronger correlation between these and the input images. Then, we backpropagate the loss through the network.
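
For comparison, the softmax-output formulation of this loss can be sketched as the entropy of the batch-averaged prediction minus the average per-sample entropy, negated so that minimizing the loss maximizes mutual information. This sketch is not our deep-feature variant; the function name `mutual_information_loss` and the `eps` stabilizer are assumptions.

```python
import torch
import torch.nn.functional as F


def mutual_information_loss(logits: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Negative estimate of I(x; y) = H(mean_i p_i) - mean_i H(p_i)."""
    p = F.softmax(logits, dim=1)  # (B, K) per-sample class probabilities
    marginal = p.mean(dim=0)      # (K,) average prediction over the batch
    h_marginal = -(marginal * torch.log(marginal + eps)).sum()
    h_conditional = -(p * torch.log(p + eps)).sum(dim=1).mean()
    return -(h_marginal - h_conditional)


# Confident and balanced predictions over 2 classes: I is close to ln(2).
balanced = mutual_information_loss(torch.tensor([[50.0, -50.0], [-50.0, 50.0]]))

# Uniform predictions carry no information about the input: I is 0.
uninformative = mutual_information_loss(torch.zeros(4, 10))
```

The first term rewards using all classes across the batch, while the second rewards confident individual predictions, so the loss is lowest when predictions are both confident and balanced.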

For consistency regularization loss, we implement two forward passes through the network, one with the original image, and one with some noisy transformations applied to the image. We calculate a loss based on the mean squared error between the two softmax probability outputs, and backpropagate this through the network. The motivation for this loss is that the predicted class for input images shouldn’t change based on certain types of added noise, such as transformations, horizontal/vertical flips, etc., so we encourage the network to predict the same class for the original and noisy versions of the same image.
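
The two-pass computation described above can be sketched as follows. The function names, the particular perturbation (a horizontal flip plus Gaussian noise), and the noise level are assumptions chosen for illustration; our actual transformations may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def noisy_view(x: torch.Tensor, noise_std: float = 0.05) -> torch.Tensor:
    """Assumed perturbation: horizontal flip plus small Gaussian noise."""
    flipped = torch.flip(x, dims=[3])  # flip the width axis of (B, C, H, W)
    return flipped + noise_std * torch.randn_like(x)


def consistency_loss(model, x, augment=noisy_view) -> torch.Tensor:
    """MSE between softmax outputs for the original and perturbed images."""
    p_clean = F.softmax(model(x), dim=1)           # first forward pass
    p_noisy = F.softmax(model(augment(x)), dim=1)  # second forward pass
    return F.mse_loss(p_noisy, p_clean)


torch.manual_seed(0)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))  # stand-in for a CNN
x_u = torch.randn(8, 3, 32, 32)  # a batch of fake unlabeled CIFAR-10 images
loss = consistency_loss(toy_model, x_u)
loss.backward()
```

Some implementations detach the clean prediction so that it acts as a fixed target; this sketch lets gradients flow through both passes.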

## Code:

### Imports from external libraries

In [None]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import torch.nn.functional as F
torch.set_float32_matmul_precision('medium')

import torchvision as tv
from torch.utils.data import DataLoader
from lightning.pytorch.utilities.combined_loader import CombinedLoader
import os


In [None]:
from mylibs.util import split_training_data
from mylibs.model import SemiSupervisedClassifier

### Generic CNN on MNIST, used by our group to develop the loss functions

In [None]:
TRAIN_NEW_MODELS = False

In [None]:
transforms = tv.transforms.Compose([
    tv.transforms.ToTensor(),
    tv.transforms.Normalize(mean=[0.1307], std=[0.3081])
])

TRAIN_TEST_SPLIT = 0.9
BATCH_SIZE = 128


m_dataset = tv.datasets.MNIST(
    root="images",
    train=True,
    download=True,
    transform=transforms
)

train_set_size = int(len(m_dataset) * TRAIN_TEST_SPLIT)
test_set_size = len(m_dataset) - train_set_size
train_ds, test_ds = torch.utils.data.random_split(m_dataset, [train_set_size, test_set_size])
test_loader = DataLoader(dataset=test_ds, batch_size = BATCH_SIZE)

labeled_dl, unlabeled_dl = split_training_data(train_ds, 1-0.9, BATCH_SIZE)
mix_dl = CombinedLoader({
    "supervised": labeled_dl,
    "unsupervised": unlabeled_dl
}, mode="max_size_cycle")


if TRAIN_NEW_MODELS:
    classifier = SemiSupervisedClassifier(loss_fcn="mi")
    classifier.train_model(mix_dl)
    classifier.test_model(test_loader=test_loader, plot_loss=True)

### Split Datasets for training/testing and labeled/unlabeled data

In [None]:
TRAIN_TEST_SPLIT = 0.9

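# NOTE: the mean/std below are the MNIST statistics from the previous cell;
# the single value is broadcast across all three CIFAR-10 channels.
transforms = tv.transforms.Compose([
    tv.transforms.ToTensor(),
    tv.transforms.Normalize(mean=[0.1307], std=[0.3081])
])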

dataset = tv.datasets.CIFAR10(
    root="images",
    train=True,
    download=True,
    transform=transforms
)

BATCH_SIZE = 128

# split dataset into training and test set
train_set_size = int(len(dataset) * TRAIN_TEST_SPLIT)
test_set_size = len(dataset) - train_set_size
train_ds, test_ds = torch.utils.data.random_split(dataset, [train_set_size, test_set_size])
test_loader = DataLoader(dataset=test_ds, batch_size = BATCH_SIZE)

### Create DataLoaders for each split above

In [None]:

# 100% labeled DataLoader
labeled_loader, unlabeled_loader = split_training_data(train_ds, unlabeled_split=0)
all_labeled_loader = CombinedLoader({
    "supervised": labeled_loader,
    "unsupervised": unlabeled_loader
}, mode="max_size_cycle")

# 50% labeled vs unlabeled DataLoader
labeled_loader, unlabeled_loader = split_training_data(train_ds, unlabeled_split=0.5)
half_labeled_loader = CombinedLoader({
    "supervised": labeled_loader,
    "unsupervised": unlabeled_loader
}, mode="max_size_cycle")

# 5% labeled vs unlabeled DataLoader
labeled_loader, unlabeled_loader = split_training_data(train_ds, unlabeled_split=0.95)
barely_labeled_loader = CombinedLoader({
    "supervised": labeled_loader,
    "unsupervised": unlabeled_loader
}, mode="max_size_cycle")


### Define Loss Functions and dataloaders for iteration ###

In [None]:
# K means, consistency regularization, entropy minimization, mutual information
loss_fcns = ["km", "cr", "em", "mi"]

data_loaders = [{"name": "all_labeled", "loader": all_labeled_loader}, 
                {"name": "half_labeled", "loader": half_labeled_loader},
                {"name": "barely_labeled", "loader": barely_labeled_loader}]


### LeNet Models ###

In [None]:
from mylibs.lenet import LeNet, train_model

# Create the output directory for saved LeNet checkpoints if it does not exist
os.makedirs("LeNetModels", exist_ok=True)

learning_rate = 1e-3
epochs = 50
lam = 0.1
if TRAIN_NEW_MODELS:
    for loss_fcn in loss_fcns:
        for dl in data_loaders:
            name = dl["name"]
            loader = dl["loader"]
            model_name = f"LeNet_{loss_fcn}_{name}"
            file_name = f"LeNetModels/{model_name}.pth"
            train_model(learning_rate=learning_rate, epochs=epochs, lam=lam, train_loader=loader, test_loader=test_loader, model_file=file_name, unsup_loss_fn=loss_fcn)


### Resnet Models

In [None]:
from mylibs.resnet_test import train_model

# Create the output directory for saved ResNet checkpoints if it does not exist
os.makedirs("ResnetModels", exist_ok=True)

learning_rate = 1e-3
epochs = 20
lam = 0.1
if TRAIN_NEW_MODELS:
    for loss_fcn in loss_fcns:
        for dl in data_loaders:
            name = dl["name"]
            loader = dl["loader"]
            model_name = f"ResNet_{loss_fcn}_{name}"
            file_name = f"ResnetModels/{model_name}.pth"
            train_model(learning_rate=learning_rate, epochs=epochs, lam=lam, train_loader=loader, test_loader=test_loader, model_file=file_name, unsup_loss_fn=loss_fcn)

These models can also be run individually by running their files and modifying the constants for the desired label split and loss function.

## Results:

After training and testing each model, we used TensorBoard to visualize training accuracy, test accuracy, and training loss over all epochs. We chose to use TensorBoard over plotting in Matplotlib for more efficient visualization.



### LeNet:

| Loss Function:                   | Training Accuracy:                                                 | Test Accuracy:                                                    | Training Loss:                                                   | Legend:                                 |
|:--------------------------------:|:-------------------------------------------------------------------:|:-----------------------------------------------------------------:|:----------------------------------------------------------------:|:---------------------------------------:|
| K Means                         | <img src="images/LN_KM_TA.png"     width="300" height="200"/>       | <img src="images/LN_KM_TestA.png" width="300" height="200"/>      | <img src="images/LN_KM_TL.png"     width="300" height="200"/>    | <img src="images/legend5.png" width="300"/> |
| Mutual Information              | <img src="images/LN_MI_TA.png"     width="300" height="200"/>       | <img src="images/LN_MI_TestA.png" width="300" height="200"/>      | <img src="images/LN_MI_TL.png"     width="300" height="200"/>    | <img src="images/legend6.png" width="300"/> |
| Entropy Minimization            | <img src="images/LN_EM_TA.png"     width="300" height="200"/>       | <img src="images/LN_EM_TestA.png" width="300" height="200"/>      | <img src="images/LN_EM_TL.png"     width="300" height="200"/>    | <img src="images/legend7.png" width="300"/> |
| Consistency Regularization       | <img src="images/LN_CR_TA.png"     width="300" height="200"/>       | <img src="images/LN_CR_TestA.png" width="300" height="200"/>      | <img src="images/LN_CR_TL.png"     width="300" height="200"/>    | <img src="images/legend8.png" width="300"/> |


NOTE: As we can see, the curve for 100% labelled data is identical in every graph, regardless of the unsupervised loss function. This is because it is based on the supervised cross-entropy loss alone, and is included only to show the difference between the supervised and semi-supervised models.

The legend labels are structured as follows: the model architecture used (LeNet or ResNet), the dataset used (always CIFAR10), the abbreviated unsupervised loss function (e.g. km is KMeans), and the percentage of data that is labelled (5%, 50% or 100%).

We now look at the test accuracy of LeNet for each of the unsupervised loss functions. For each loss function, we break the performance down into 50% labelled data and 5% labelled data.

##### LeNet Baseline Test Accuracy:

| Supervised Loss Function   | 100% labelled data    |
|:---------------------------|:---------------------:|
| Cross Entropy              | 0.6176                |

##### LeNet Unsupervised Loss Test Accuracy:

| Unsupervised Loss Function | 50% labelled data     | 5% labelled data    |
|:---------------------------|:---------------------:|:---------------------:|
| K-means                    | 0.6060                | 0.6582              |
| Mutual Information         | 0.0900                | 0.0900              |
| Entropy Minimization       | 0.5971                | 0.6284              |
| Consistency Regularization | 0.5626                | 0.6006              |

### ResNet:

| Loss Function:                                                         | Training Accuracy:                                                         | Test Accuracy:                                                            | Training Loss:                                                         | Legend:                                                         |
|:-------------------------------------------------------------------------:|:-------------------------------------------------------------------------:|:------------------------------------------------------------------------:|:----------------------------------------------------------------------:|:----------------------------------------------------------------------:|
| K Means | <img src="images/RN_KM_TA.png"     width="300" height="200"/>              | <img src="images/RN_KM_TestA.png" width="300" height="200"/>              | <img src="images/RN_KM_TL.png"     width="300" height="200"/>          | <img src="images/legend1.png" width="300"/> |
| Mutual Information | <img src="images/RN_MI_TA.png"     width="300" height="200"/>              | <img src="images/RN_MI_TestA.png" width="300" height="200"/>              | <img src="images/RN_MI_TL.png"     width="300" height="200"/>          | <img src="images/legend2.png" width="300"/> |
| Entropy Minimization | <img src="images/RN_EM_TA.png"     width="300" height="200"/>              | <img src="images/RN_EM_TestA.png" width="300" height="200"/>              | <img src="images/RN_EM_TL.png"     width="300" height="200"/>          | <img src="images/legend3.png" width="300"/> |
| Consistency Regularization | <img src="images/RN_CR_TA.png"     width="300" height="200"/>              | <img src="images/RN_CR_TestA.png" width="300" height="200"/>              | <img src="images/RN_CR_TL.png"     width="300" height="200"/>          | <img src="images/legend4.png" width="300"/> |


##### ResNet Baseline Test Accuracy:

| Supervised Loss Function   | 100% labelled data    |
|:---------------------------|:---------------------:|
| Cross Entropy              | 0.7744                |

##### ResNet Unsupervised Loss Test Accuracy:

| Loss Function              | 50% labelled data | 5% labelled data |
|:---------------------------|:-----------------:|:----------------:|
| K-means                    | 0.7046            | 0.3798           |
| Mutual Information         | 0.7800            | 0.5152           |
| Entropy Minimization       | 0.7713            | 0.5336           |
| Consistency Regularization | 0.7722            | 0.5736           |


## Conclusions:

Based on our results, it seems that our LeNet model is too weak to model a dataset like CIFAR10, which is far more complex than the dataset it was originally designed around (MNIST). Regardless of the loss function or the split of labeled/unlabeled data, the test accuracy hovered around 60% (between 0.56 and 0.66), which indicates that the model itself might be the limiting factor. The one exception was mutual information loss, which caused LeNet training to diverge and left the model at a test accuracy of 9%, roughly the level of random guessing over 10 classes.

On the other hand, by implementing a deeper and more complex network in ResNet, we were able to see noticeable differences between the different combinations of loss functions and labeled/unlabeled splits. Almost all of our semi-supervised models with 50% labelled and 50% unlabelled data did just as well as the model trained on the fully labelled dataset. With 5% labelled data and 95% unlabelled, there was not enough labeled training data for the model to accurately represent the distribution of images. Our ResNet model overfitted the few labeled samples, reaching 100% training accuracy on them while test accuracy dropped to between 0.38 and 0.57, regardless of the unsupervised loss function.

Among the unsupervised loss functions, the most standout result was that KMeans loss performed worse than the others on ResNet (0.7046 test accuracy with 50% labelled data and 0.3798 with 5%). Although it is difficult to know for sure what causes models to succeed or fail, our theory is that the deep feature space for ResNet is extremely large. KMeans clustering loss is based on the squared Euclidean distance from a cluster center, and this distance becomes harder to optimize in a high-dimensional space. We believe this leads to slightly worse loss optimization and poor test accuracy as a result.

Aside from this result, we noticed that consistency regularization performed better than the other unsupervised losses with very low amounts of labeled data: with 5% labelled data on ResNet it reached a test accuracy of 0.5736, compared with 0.5336 for entropy minimization and 0.5152 for mutual information. We believe this is because consistency regularization discourages overfitting, which was our greatest challenge in that environment. This loss function discourages the model from predicting different classes under small perturbations of the input, which is similar to the data augmentation strategies employed in fully supervised networks to reduce overfitting.

Overall, we found that with a sufficient baseline of labeled training data and a sufficiently complex model for the dataset, training on additional unlabeled data with semi-supervised methods was very effective: with only half of the labels, three of the four ResNet models matched the fully supervised test accuracy. Consistency regularization loss performed the best out of all unsupervised losses in environments with minimally labeled data.


## References:

LeCun, Y., Bottou, L., Bengio, Y., & Haffner, P. (1998). Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11), 2278–2324. https://doi.org/10.1109/5.726791

Bridle, J. S., Heading, A. J. R., & MacKay, D. J. C. (1991). Unsupervised classifiers, mutual information and "phantom targets". NIPS 1991.

He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. 2016 IEEE Conference on Computer Vision and Pattern Recognition.

Grandvalet, Y., & Bengio, Y. (2006). Entropy regularization. Semi-Supervised Learning, 151–168. https://doi.org/10.7551/mitpress/9780262033589.003.0009

Fan, Y., Kukleva, A., Dai, D., & Schiele, B. (2022). Revisiting consistency regularization for semi-supervised learning. International Journal of Computer Vision, 131(3), 626–643. https://doi.org/10.1007/s11263-022-01723-4 

Cho, M., Vahid, K. A., Adya, S., & Rastegari, M. (2022). DKM: Differentiable K-Means Clustering Layer for Neural Network Compression. arXiv. https://arxiv.org/abs/2108.12659