#### ***TASK 1 - DATA PREPROCESSING***

Define image augmentations in the cell below using two variables:  

- **`transform_train`**: Stores transforms for training images. You can include any augmentations you prefer.  
- **`transform_test`**: Stores transforms for test images. As a best practice, limit these to the essential, non-random steps from `transform_train` (for example resizing, grayscale conversion, and tensor conversion).

Lastly, be sure to convert all images to [tensors](https://www.perplexity.ai/search/i-m-a-student-at-naiss-mlb-and-_EL_nBO9TS694cbTEl5M.A) with the `transforms.ToTensor()` transform. New to transforms? See the [torchvision transforms documentation](https://pytorch.org/vision/stable/transforms.html).

In [3]:
import torchvision.transforms as transforms

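# Training pipeline: random augmentations, then conversion to a 1x28x28 tensor
transform_train = transforms.Compose([
  transforms.RandomRotation(20),         # rotate by a random angle in [-20, 20] degrees
  transforms.RandomHorizontalFlip(0.5),  # mirror left-right with probability 0.5
  transforms.Resize((28, 28)),           # FashionMNIST is already 28x28; kept as a safeguard
  transforms.Grayscale(),                # ensure a single channel
  transforms.ToTensor(),                 # PIL image -> float tensor in [0, 1]
])

# Test pipeline: deterministic steps only (no random augmentation)
transform_test = transforms.Compose([
  transforms.Resize((28, 28)),
  transforms.Grayscale(),
  transforms.ToTensor(),
])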

In [4]:
from torchvision.datasets import FashionMNIST

train_dataset = FashionMNIST(
  root="./data", train=True, download=True, transform=transform_train
)

test_dataset = FashionMNIST(
  root="./data", train=False, download=True, transform=transform_test
)
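Each item in these datasets is an `(image, label)` pair, where the label is an integer class index from 0 to 9. The sketch below maps that index to one of the ten FashionMNIST class names, listed in the same order as torchvision's `FashionMNIST.classes`. To keep it runnable offline, a small `TensorDataset` of random tensors stands in for `train_dataset`; `describe_sample` is a hypothetical helper, not part of torchvision.

```python
import torch
from torch.utils.data import TensorDataset

# FashionMNIST label index -> class name (same order as torchvision's FashionMNIST.classes).
FASHION_MNIST_CLASSES = [
  "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
  "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]

def describe_sample(dataset, index):
  """Hypothetical helper: return the image shape and class name of one sample."""
  image, label = dataset[index]
  return tuple(image.shape), FASHION_MNIST_CLASSES[int(label)]

# Offline stand-in for train_dataset: 4 random 1x28x28 "images" with labels 0, 7, 9, 3.
fake_dataset = TensorDataset(
  torch.rand(4, 1, 28, 28),
  torch.tensor([0, 7, 9, 3]),
)

print(describe_sample(fake_dataset, 1))  # ((1, 28, 28), 'Sneaker')
```

With the real data, `describe_sample(train_dataset, 0)` works the same way; note that `train_dataset` re-applies its random augmentations every time an item is accessed.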

Next, we wrap each dataset (which already applies its transforms) in a [DataLoader](https://www.perplexity.ai/search/i-m-a-student-at-naiss-mlb-and-_EL_nBO9TS694cbTEl5M.A), which groups the images into mini-batches so we can train the model in batches.

In [5]:
from torch.utils.data import DataLoader

train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)  # no need to shuffle for evaluation
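To see what a `DataLoader` actually yields, the sketch below uses a 100-sample `TensorDataset` of random tensors as an offline stand-in for the FashionMNIST datasets (an assumption for illustration). Each batch is an `(images, labels)` pair with images shaped `(batch, channels, height, width)`, and the last batch holds whatever is left over.

```python
import math
import torch
from torch.utils.data import DataLoader, TensorDataset

BATCH_SIZE = 32

# Offline stand-in for a dataset: 100 random 1x28x28 images with labels in [0, 10).
fake_dataset = TensorDataset(torch.rand(100, 1, 28, 28), torch.randint(0, 10, (100,)))
loader = DataLoader(dataset=fake_dataset, batch_size=BATCH_SIZE, shuffle=True)

images, labels = next(iter(loader))
print(images.shape)  # torch.Size([32, 1, 28, 28]) -> (batch, channels, height, width)
print(labels.shape)  # torch.Size([32])

# With drop_last=False (the default), the final batch holds the leftovers: 32 + 32 + 32 + 4.
batch_sizes = [len(y) for _, y in loader]
print(batch_sizes)   # [32, 32, 32, 4]
print(len(loader))   # 4 == ceil(100 / 32)

# The same arithmetic for the real splits (60,000 train / 10,000 test images):
print(math.ceil(60_000 / BATCH_SIZE), math.ceil(10_000 / BATCH_SIZE))  # 1875 313
```

So one pass over `train_loader` is 1875 optimizer steps at `batch_size=32`; a larger batch size means fewer, larger steps per epoch.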

#### ***TASK 2 - CNN Architecture***

This is where you have all the creative freedom in the world. Here are some good questions to ask yourself:

- How many [channels](https://www.perplexity.ai/search/i-m-a-student-at-naiss-mlb-wha-49AG77e4Qp2e7EkARdFsTA) should go into the input layer?
- What measures can I take to avoid [overfitting](https://www.perplexity.ai/search/i-m-a-student-at-naiss-mlb-wha-YdAbhqQzRZaEq39BEQzA6w)?
- What matters most to me: training speed or accuracy, and how much of one am I willing to trade for the other?
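On the first two questions: because both pipelines end in `Grayscale()` and `ToTensor()`, every image arrives with one channel, so the first convolution needs `in_channels=1`. For overfitting, batch normalization and dropout are two common options. The block below is one illustrative sketch, not a recommended architecture; `RegularizedConvBlock` and its `dropout_p=0.25` default are hypothetical choices.

```python
import torch
import torch.nn as nn

class RegularizedConvBlock(nn.Module):
  """Hypothetical block: Conv -> BatchNorm -> ReLU -> MaxPool -> Dropout2d."""

  def __init__(self, in_channels, out_channels, dropout_p=0.25):
    super().__init__()
    self.block = nn.Sequential(
      nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
      nn.BatchNorm2d(out_channels),  # normalizes each channel using batch statistics
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2, stride=2),
      nn.Dropout2d(p=dropout_p),     # zeroes whole feature maps at random while training
    )

  def forward(self, x):
    return self.block(x)

# Grayscale() + ToTensor() produce 1-channel images, so the first layer takes in_channels=1.
block = RegularizedConvBlock(in_channels=1, out_channels=16)
x = torch.rand(8, 1, 28, 28)  # batch of 8 grayscale 28x28 images

block.train()          # dropout active; BatchNorm uses the current batch's statistics
print(block(x).shape)  # torch.Size([8, 16, 14, 14])

block.eval()           # dropout off; BatchNorm uses its running statistics
print(block(x).shape)  # torch.Size([8, 16, 14, 14])

# A 3-channel (RGB-shaped) batch does not match in_channels=1 and raises an error.
mismatch_raised = False
try:
  block(torch.rand(8, 3, 28, 28))
except RuntimeError:
  mismatch_raised = True
print(mismatch_raised)  # True
```

Remember to call `model.train()` before training and `model.eval()` before testing, since dropout and batch normalization behave differently in the two modes.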

Not comfortable with PyTorch? [Start with this video](https://youtu.be/mozBidd58VQ?si=TE2_81TEQko1eDXT). Then go and make me the best [CNN](https://www.datacamp.com/tutorial/introduction-to-convolutional-neural-networks-cnns) I've ever seen :)
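When sizing the classifier head, the number of features fed to the final `nn.Linear` depends on how each layer changes the spatial size. For `Conv2d` and `MaxPool2d` (with dilation 1), each output side is `floor((n + 2*padding - kernel_size) / stride) + 1`. The sketch below applies that formula to three conv (kernel 3, padding 1, stride 1) + max-pool (kernel 2, stride 2) stages, matching the layers in the CNN cell below, and cross-checks the result against PyTorch; `conv_out_size` is a hypothetical helper.

```python
import torch
import torch.nn as nn

def conv_out_size(n, kernel_size, stride=1, padding=0):
  """Output length along one spatial dimension for Conv2d/MaxPool2d (dilation=1)."""
  return (n + 2 * padding - kernel_size) // stride + 1

size = 28
for _ in range(3):
  size = conv_out_size(size, kernel_size=3, stride=1, padding=1)  # 3x3 conv, padding 1: size unchanged
  size = conv_out_size(size, kernel_size=2, stride=2)             # 2x2 max-pool, stride 2: size halved (floored)
print(size)              # 3  (28 -> 14 -> 7 -> 3)
print(64 * size * size)  # 576 features after flattening 64 channels

# Cross-check against real layers on a dummy batch of two 1x28x28 images.
features = nn.Sequential(
  nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),
  nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),
  nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),
)
out = features(torch.rand(2, 1, 28, 28))
print(out.shape)         # torch.Size([2, 64, 3, 3])
```

As long as every convolution preserves the size and every pool has kernel 2 and stride 2, the shortcut `28 // (2 * 2 * 2)` gives the same 3; if you change kernels, strides, or padding, the formula is the safer guide.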

In [None]:
import torch
import torch.nn as nn

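# Three 2x2 max-pools (stride 2) shrink each side 28 -> 14 -> 7 -> 3,
# so the last conv layer outputs 64 feature maps of size 3x3.
IMAGE_WIDTH = 28 // (2 * 2 * 2)
IMAGE_HEIGHT = 28 // (2 * 2 * 2)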

class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.conv = nn.Sequential(
      nn.Conv2d(
        in_channels=1,
        out_channels=16,
        kernel_size=3,
        padding=1,
        stride=1
      ),
      nn.ReLU(),
      nn.MaxPool2d(
        kernel_size=2,
        stride=2
      ),

      nn.Conv2d(
        in_channels=16,
        out_channels=32,
        kernel_size=3,
        padding=1,
        stride=1
      ),
      nn.ReLU(),
      nn.MaxPool2d(
        kernel_size=2,
        stride=2
      ),

      nn.Conv2d(32, 64, 3, padding=1),
      nn.ReLU(),
      nn.MaxPool2d(2, 2)
    )

    self.fc = nn.Sequential(
        nn.Flatten(),
        nn.Linear(
          in_features=(64 * IMAGE_WIDTH * IMAGE_HEIGHT),
          out_features=10,
        ),
        # No Softmax here: nn.CrossEntropyLoss expects raw logits and applies
        # log-softmax internally. Use torch.softmax(x, dim=1) at inference if needed.
    )

  def forward(self, x):
    x = self.conv(x)
    x = self.fc(x)

    return x
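On whether the classifier head needs a softmax layer: `nn.CrossEntropyLoss` expects raw, unnormalized scores (logits) and applies `log_softmax` internally, so a model trained with it should output logits. The sketch below checks that `CrossEntropyLoss` on logits equals `NLLLoss` on `log_softmax(logits)`, and shows how to get probabilities at inference time. The random logits and labels are stand-ins for real model output.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)          # stand-in for model output: 4 images x 10 classes
labels = torch.tensor([0, 3, 9, 7])  # stand-in ground-truth class indices

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, labels)

# CrossEntropyLoss == log_softmax followed by NLLLoss, so the softmax is already built in.
manual = F.nll_loss(F.log_softmax(logits, dim=1), labels)
print(torch.allclose(loss, manual))          # True

# Applying softmax first squashes the scores twice and gives a different, distorted loss.
double_softmax = criterion(torch.softmax(logits, dim=1), labels)
print(torch.allclose(loss, double_softmax))  # False

# At inference time, convert logits to probabilities (each row sums to 1) and take the argmax.
probs = torch.softmax(logits, dim=1)
predictions = probs.argmax(dim=1)
print(probs.sum(dim=1))
```

Because softmax is monotonic, `logits.argmax(dim=1)` already gives the same predictions; the explicit softmax is only needed when you want the probabilities themselves.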