# DL final project
*   Alon Meirovich, ID: 330181470
*   Matan Goldfarb, ID: 314623174
*   Talya Yermiahu, ID: 207594193

## Imports


In [1]:
# Imports
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, ConcatDataset, random_split
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from torchvision.transforms import Grayscale, Resize
import time

# Colab-specific setup: mount Google Drive and switch the working directory to
# the project folder, so the relative paths used later in the notebook
# ('./data', the '*.pth' checkpoints) resolve against that folder.
from google.colab import drive
drive.mount('/content/drive')
import os
os.chdir('/content/drive/My Drive/Colab Notebooks/DL Final Project')

# Project-local helpers.
# NOTE(review): presumably project_utils.py lives in the project folder, which
# is why it is imported only after the chdir above - confirm.
from project_utils import CombinedDataset, eval_model

Mounted at /content/drive


## Training Mode


In [2]:
# Training switch: the training cell only runs when this flag is True.
# NOTE(review): the name is a misspelling of "is_training"; the training cell's
# `if` guard uses this exact spelling, so rename both places together.
is_traiting = False

## Data Section

In [3]:
# Parameters
batch_size = 64
split_seed = 42  # fixes the train/validation partition across kernel restarts

# Transformers
# Normalize maps pixels to [-1, 1], so the black MNIST background becomes -1.
# Both geometric augmentations therefore pad with fill=-1; RandomRotation's
# default fill of 0 would paint mid-gray wedges into the rotated corners.
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
    transforms.RandomRotation(10, fill=-1),
    transforms.RandomAffine(degrees=(-10, 10), translate=(0.01, 0.15),
                            scale=(0.9, 1.1), fill=-1)
])

# Deterministic pipeline (no augmentation) used for validation and test data.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
    transforms.Grayscale(),
])

# OOD images are resized to 28x28 and converted to one channel before the
# same normalization is applied.
transform_ood = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load dataset
# The MNIST training images are wrapped twice: once with augmentation (for
# training) and once with the deterministic transform (for validation), so the
# validation loss is not measured on randomly distorted images.
mnist_train_augmented = datasets.MNIST(root='./data', train=True, transform=transform_train, download=True)
mnist_train_plain = datasets.MNIST(root='./data', train=True, transform=transform_test, download=True)
mnist_test = datasets.MNIST(root='./data', train=False, transform=transform_test, download=True)

# Split the train set into train and validation sets (80% / 20%).
# A dedicated, seeded generator keeps the partition reproducible.
train_size = int(0.8 * len(mnist_train_augmented))
val_size = len(mnist_train_augmented) - train_size
mnist_train, val_split = random_split(
    mnist_train_augmented, [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed))
# Same held-out indices, but read through the non-augmented dataset.
mnist_val = torch.utils.data.Subset(mnist_train_plain, val_split.indices)

# Training batches are reshuffled every epoch; evaluation loaders keep a fixed order.
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(mnist_val, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)
# Load OOD datasets for testing
cifar10 = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_ood)
fashion_mnist = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform_ood)
# Concatenate CIFAR-10 and FashionMNIST datasets
ood_dataset = ConcatDataset([cifar10, fashion_mnist])

# Combine the MNIST test set with FashionMNIST as OOD data.
# NOTE(review): ood_dataset (CIFAR-10 + FashionMNIST) is built above but not
# used here; only fashion_mnist is wrapped - confirm whether CIFAR-10 was
# meant to be part of the combined test set.
combined_test_loader = DataLoader(CombinedDataset(mnist_test, fashion_mnist), batch_size=batch_size, shuffle=True)

Files already downloaded and verified


## Model Class


In [4]:
# Model Class
class OSRCNN(nn.Module):
    """Small CNN classifier with a confidence-threshold rejection rule.

    The network is sized for 1x28x28 inputs: two 3x3 convolutions, each
    followed by ReLU and 2x2 max-pooling (28 -> 14 -> 7), then two fully
    connected layers producing 11 logits (10 classes + 1 unknown).

    Output of ``forward`` depends on the mode:
      * training mode, or validation mode switched on via ``set_validation``:
        the raw logits, shape (batch, 11);
      * otherwise: predicted class indices, shape (batch,), where every
        sample whose top softmax probability is below ``th`` is set to 10.
    """

    def __init__(self, th):
        super().__init__()
        # Learnable layers, kept in the order conv1, conv2, fc1, fc2.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(in_features=64 * 7 * 7, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=11)  # 10 classes + 1 unknown
        # Parameter-free building blocks shared by the forward pass.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)
        # Rejection threshold on the top softmax probability.
        self.th = th
        # When True, eval-mode forward passes return logits instead of labels.
        self.valMode = False

    def forward(self, x):
        feature_maps = self.pool(self.relu(self.conv1(x)))
        feature_maps = self.pool(self.relu(self.conv2(feature_maps)))
        flat = feature_maps.view(-1, 64 * 7 * 7)
        logits = self.fc2(self.relu(self.fc1(flat)))

        # Loss computation (training / validation) needs the raw logits.
        if self.training or self.valMode:
            return logits

        # Inference: arg-max label, overridden by 10 when confidence is low.
        with torch.no_grad():
            confidence, y_pred = torch.max(self.softmax(logits), 1)
            low_confidence = confidence < self.th
            y_pred[low_confidence] = 10
        return y_pred

    def set_validation(self, val_mode):
        """Switch validation mode on/off (see the class docstring)."""
        self.valMode = val_mode


## Useful Functions


In [5]:
# Training the OSR model
def train_osr(model, train_loader, criterion, optimizer, epoch):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: network to train; it is switched to training mode.
        train_loader: iterable yielding ``(inputs, labels)`` batches.
        criterion: loss callable taking ``(outputs, labels)``.
        optimizer: optimizer stepping the model's parameters.
        epoch: zero-based epoch index, used only in the progress message.

    Returns:
        The sum of the batch losses divided by the number of batches, or
        ``nan`` when the loader yields no batches.
    """
    model.train()
    # Follow the model's own device rather than a notebook-level global, so the
    # function does not depend on which cells were executed beforehand.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    running_loss = 0.0
    num_batches = 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    # Averaged once after the loop; an empty loader yields nan instead of
    # failing on an unbound variable.
    norm_loss = running_loss / num_batches if num_batches else float("nan")
    print(f"Epoch {epoch+1}, Loss: {norm_loss}")
    return norm_loss

# Validation function for the OSR model
def validate_osr(model, val_loader, criterion):
    """Compute the mean per-batch loss of ``model`` over ``val_loader``.

    The model is put in eval mode and its validation mode is switched on for
    the duration of the pass (via ``model.set_validation``) so that it returns
    outputs the criterion can consume; validation mode is switched off again
    afterwards, even if the pass raises.

    Args:
        model: network exposing ``set_validation(bool)``.
        val_loader: iterable yielding ``(inputs, labels)`` batches.
        criterion: loss callable taking ``(outputs, labels)``.

    Returns:
        The sum of the batch losses divided by the number of batches, or
        ``nan`` when the loader yields no batches.
    """
    model.eval()
    # Follow the model's own device rather than a notebook-level global.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    val_loss = 0.0
    num_batches = 0
    model.set_validation(True)
    try:
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                val_loss += criterion(outputs, labels).item()
                num_batches += 1
    finally:
        # Never leave the model stuck in validation mode after a failure.
        model.set_validation(False)

    norm_val_loss = val_loss / num_batches if num_batches else float("nan")
    print(f"Validation Loss: {norm_val_loss}")
    return norm_val_loss

def plot_accuracy(train_accuracies, val_accuracies):
    """Draw the training and validation accuracy curves on one figure and show it.

    Args:
        train_accuracies: sequence of per-epoch training accuracies.
        val_accuracies: sequence of per-epoch validation accuracies.
    """
    curves = (
        (train_accuracies, 'Training accuracy'),
        (val_accuracies, 'Validation accuracy'),
    )
    for values, caption in curves:
        plt.plot(values, label=caption)

    plt.title('Training and Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.show()


def plot_predictions(model, loader, class_names):
    """Show up to ten images from the first batch with predicted and actual labels.

    The model is put in eval mode and its output for each image is used as a
    class index: an output equal to ``len(class_names)`` is titled "Unknown",
    any other output is looked up in ``class_names``.

    Args:
        model: network whose eval-mode output is one class index per image.
        loader: iterable yielding ``(inputs, labels)`` batches; only the first
            batch is used.
        class_names: sequence mapping class indices to display names.
    """
    model.eval()
    # Follow the model's own device rather than a notebook-level global.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    first_batch = next(iter(loader), None)
    if first_batch is None:
        return  # empty loader: nothing to draw
    inputs, labels = first_batch

    with torch.no_grad():
        outputs = model(inputs.to(device))

    # The 2x5 grid holds at most ten images; smaller batches fill it partially
    # instead of raising an IndexError.
    num_shown = min(10, len(inputs))
    plt.figure(figsize=(10, 4))
    for i in range(num_shown):
        plt.subplot(2, 5, i + 1)
        plt.imshow(inputs[i].cpu().squeeze(), cmap='gray')
        plt.axis('off')
        if outputs[i] == len(class_names):
            plt.title(f'Pred: Unknown, Actual: {class_names[labels[i]]}')
        else:
            plt.title(f'Pred: {class_names[outputs[i]]}, Actual: {class_names[labels[i]]}')
    plt.tight_layout()
    plt.show()


def plot_loss(loss, num):
    """Plot a per-epoch loss curve on a new figure and show it.

    Args:
        loss: sequence of loss values, one per epoch (plotted against 1..N).
        num: selects the title - 1 for the training loss, 2 for the
            validation loss; any other value leaves the plot untitled.
    """
    epoch_axis = range(1, len(loss) + 1)

    plt.figure(figsize=(8, 6))
    plt.plot(epoch_axis, loss)

    if num == 1:
        heading = 'Training Loss Over Epochs'
    elif num == 2:
        heading = 'Validation Loss Over Epochs'
    else:
        heading = None
    if heading is not None:
        plt.title(heading)

    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(True)
    plt.show()


def show_images(images, titles=None):
    """Display images side by side in a single row.

    Args:
        images: sequence of tensors, each viewable as 28x28.
        titles: optional sequence of titles, one per image.
    """
    if len(images) == 0:
        return  # plt.subplots cannot build a grid with zero columns

    # squeeze=False always returns a 2-D array of axes, so a single image works
    # as well (the default returns a bare Axes object that cannot be indexed).
    fig, axs = plt.subplots(1, len(images), figsize=(15, 15), squeeze=False)
    for i, img in enumerate(images):
        ax = axs[0][i]
        ax.imshow(img.view(28, 28).cpu().detach().numpy(), cmap='gray')
        if titles:
            ax.set_title(titles[i])
        ax.axis('off')
    plt.show()

## Training Process


In [6]:
# Train the OSR model (runs only when the training flag is enabled)

if is_traiting:

    # Hyper-parameters
    epochs = 30
    lr = 0.001

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Initialization: seed first so the weight initialisation is repeatable.
    # The threshold passed to the model is 0 during training.
    torch.manual_seed(42)
    model = OSRCNN(0).to(device)
    osr_criterion = nn.CrossEntropyLoss()
    osr_optimizer = optim.Adam(model.parameters(), lr=lr)

    # Training: one pass over the training data plus one validation pass per
    # epoch, recording both mean losses.
    train_losses = []
    val_losses = []
    time_0 = time.time()

    for epoch in range(epochs):
        train_losses.append(train_osr(model, train_loader, osr_criterion, osr_optimizer, epoch))
        val_losses.append(validate_osr(model, val_loader, osr_criterion))
        # Save a checkpoint of the OSR model after every epoch.
        # NOTE(review): the evaluation cells load 'osr_model_epoch_30.pth',
        # which differs from the name written here - confirm which checkpoint
        # is the intended one.
        model_filename = f'osr_soft00001_model_epoch_{epoch + 1}.pth'
        torch.save(model.state_dict(), model_filename)

    print("Total run-time: %s seconds" % (time.time() - time_0))

    # Visual sanity checks: sample predictions and both loss curves.
    class_names = [str(i) for i in range(10)] + ["Unknown"]
    plot_predictions(model, train_loader, class_names)
    plot_loss(train_losses, 1)
    plot_loss(val_losses, 2)

## Evaluate OSR model


In [7]:
# Evaluate the OSR model on the combined (MNIST + OOD) test loader,
# using the weights saved after epoch 30.
th = 0.99  # predictions whose top softmax probability is below this become class 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the network, restore the checkpoint, then move it to the device and
# switch it to inference mode.
osr_model_loaded = OSRCNN(th)
osr_model_loaded.load_state_dict(torch.load('osr_model_epoch_30.pth', map_location=device))
osr_model_loaded = osr_model_loaded.to(device).eval()

acc_mnist, acc_ood, acc_total = eval_model(osr_model_loaded, combined_test_loader, device)
print(f'MNIST Accuracy: {acc_mnist*100:.2f}%')
print(f'OOD Accuracy: {acc_ood*100:.2f}%')
print(f'Total Accuracy: {acc_total*100:.2f}%')

# NOTE: after roughly 50 runs this model approaches about 96% accuracy; a different model may do better.
# NOTE: there is no time limit for training, only for testing.
# TODO: ask Ron how many images the test set will contain.

MNIST Accuracy: 94.96%
OOD Accuracy: 96.14%
Total Accuracy: 95.55%


## Evaluate BaseLine model

In [8]:
# Evaluate the baseline: the same epoch-30 checkpoint with a threshold of 0,
# so no prediction is relabelled as class 10 by the confidence rule.
# Scored on the plain MNIST test loader.
th = 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the network, restore the checkpoint, then move it to the device and
# switch it to inference mode.
baseLine_model_loaded = OSRCNN(th)
baseLine_model_loaded.load_state_dict(torch.load('osr_model_epoch_30.pth', map_location=device))
baseLine_model_loaded = baseLine_model_loaded.to(device).eval()

# acc_ood is returned by eval_model but not reported for this run.
acc_mnist, acc_ood, acc_total = eval_model(baseLine_model_loaded, test_loader, device)
print(f'MNIST Accuracy: {acc_mnist*100:.2f}%')
print(f'Total Accuracy: {acc_total*100:.2f}%')

# NOTE: after roughly 50 runs this model approaches about 96% accuracy; a different model may do better.
# NOTE: there is no time limit for training, only for testing.
# TODO: ask Ron how many images the test set will contain.

MNIST Accuracy: 99.29%
Total Accuracy: 99.29%
