# Building a Basic CNN

In this notebook, we create a simple CNN model to classify the images. This basic model is our starting point and serves as a benchmark for the more advanced models we build later. However, the problem turned out to be much harder than we anticipated, and this model performs poorly.

**Note:** We highly recommend running this notebook on a GPU. 

## 0. Initialization

Import the required packages, then fix the random seeds and select the compute device.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import warnings

warnings.filterwarnings("ignore")

In [None]:
import os

# Work from the project root so the `src.*` imports below and the relative
# `data/` and `results/` paths used later resolve. Guarded so that re-running
# this cell does not keep climbing one directory higher each time.
if not os.path.isdir("src"):
    os.chdir("..")
import requests
import zipfile
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from src.models.simple_cnn import OptimizedCNN
from src.utils import seed_everything
from src.loading import load_data
from src.train import train

In [None]:
# Set seeds
seed_everything()

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



---

<a name='s1'></a>
## 1. Downloading the dataset

Downloading the dataset takes around 4–5 minutes, and unzipping takes about 20 seconds. Both steps are skipped if a `data/` folder already exists.

In [None]:
# If the folder 'data/' does not exist yet, download and extract the dataset.
if not os.path.exists("data/"):
    # Dropbox share link (`dl=1` requests a direct file download)
    dropbox_url = "https://www.dropbox.com/scl/fi/sa14unf8s47e9ym125zgo/data.zip?rlkey=198bg0cmbmmrcjkfufy9064wm&dl=1"

    # File path where the .zip file will be saved
    zip_file_path = "data.zip"

    # Directory to extract the contents of the zip file into (current working directory)
    extraction_path = "."

    # Stream the archive to disk in 1 MiB chunks instead of holding it all in memory.
    # Abort on an HTTP error rather than going on to unzip a missing or partial file.
    with requests.get(dropbox_url, stream=True, timeout=60) as response:
        if response.status_code != 200:
            raise RuntimeError(
                "Failed to download the file. Error code: " + str(response.status_code)
            )
        with open(zip_file_path, "wb") as file:
            for chunk in response.iter_content(chunk_size=1 << 20):
                file.write(chunk)

    print("Download successful. The file has been saved as 'data.zip'.")

    # Unzipping the file
    with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
        zip_ref.extractall(extraction_path)

    print(
        "The contents of the zip file have been extracted to: "
        f"{os.path.abspath(extraction_path)}"
    )

## 2. Data

In [None]:
# Load the .jpeg files in the data folder
# Dataset locations, relative to the current working directory
PATH_IMAGES = "data/images"
PATH_LABELS = "data/labels/trainLabels.csv"

# Training hyperparameters
img_size = (512, 512)  # target image size handed to load_data
batch_size = 8
num_epochs = 20

In [None]:
# Build the training and validation loaders from the label CSV and image folder,
# using the image size and batch size from the config cell.
loaders = load_data(PATH_LABELS, PATH_IMAGES, img_size, batch_size)
train_loader, validation_loader = loaders

In [None]:
# Visualize an image
# Sanity-check the loader: take one training batch, report tensor shapes,
# and display the first image with its label.
images, labels = next(iter(train_loader))
print(images.shape)
print(labels.shape)

# Images come out channels-first (C, H, W); matplotlib expects channels-last (H, W, C).
# NOTE(review): if load_data normalizes pixel values, they may fall outside [0, 1]
# and imshow will clip them — undo the normalization here if the preview looks off.
first_image = images[0].permute(1, 2, 0).cpu().numpy()

fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(first_image)
ax.set_title(f"Label: {labels[0].tolist()}")
ax.axis("off")
plt.show()

## 3. Model

In [None]:
# Baseline CNN trained with cross-entropy loss and the Adam optimizer.
LEARNING_RATE = 3e-5  # Adam step size for the baseline

model = OptimizedCNN()
# Move the model to the target device *before* constructing the optimizer, as the
# PyTorch docs recommend, so the optimizer is built over the parameters that are
# actually trained.
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

model  # display the architecture

### 3.1 Model training and evaluation

In [None]:
# Train the baseline and evaluate it on the validation loader after each epoch.
# The checkpoint's parent folder is created up front so saving cannot fail on a fresh clone.
# NOTE(review): assumes `train` saves the model to `model_name`; confirm in src/train.py.
checkpoint_path = "results/models/cnn_test.pt"
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

train(
    model,
    train_loader,
    validation_loader,
    criterion,
    optimizer,
    device,
    model_name=checkpoint_path,
    num_epochs=num_epochs,
)