### <font color='#FF93D5'> [0] <u>Initialization</u> </font>

In [None]:
# Fetch the cat image archive and unpack it into data/.
# Statement order matters: download -> extract -> delete the archive.

# remove the leftover sample_data folder from the notebook environment
# (presumably the default Colab folder -- nothing below reads it)
!rm -rf sample_data

# download the zip containing the cat images into data/ (-P sets the target directory)
!wget https://github.com/Crash285github/nyaural_catwork/raw/main/data/cats.zip -P data/

# extract the cats into data/
from zipfile import ZipFile
with ZipFile('data/cats.zip', 'r') as cats:
  cats.extractall('data')

# remove the downloaded zip so only the extracted files remain in data/
!rm data/cats.zip

### <font color='#FF7580'>[1] <u>Data manipulation</u> </font>

#### <font color='#777'> [1.0] Imports, parameters</font>

In [None]:
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms as tforms
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline

# relative weights of the train / validation / test splits
# (normalized to fractions later on)
train_frac = 7
val_frac = 2
test_frac = 1

# how many images to show per split in [1.3.1]
number_of_samples = 5
# toggle the dataset preview in [1.3.1]
visualize_datasets = True

#### <font color='#C45A63'> [1.1] Dataset class</font>

In [None]:
class CatDataset(Dataset):
    """In-memory dataset of the images stored directly inside a directory.

    Every regular file in ``main_dir`` (except ``.zip`` archives) is opened,
    converted to RGB, passed through ``transform`` and cached, so
    ``__getitem__`` is a plain list lookup.

    Args:
        main_dir: directory holding the image files.
        transform: callable applied to each RGB ``PIL.Image``; whatever it
            returns is what the dataset yields for that image.
    """

    def __init__(self, main_dir, transform):
        self.main_dir = main_dir
        self.transform = transform
        self.total_imgs = []

        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> image mapping stable between runs and machines
        for img in sorted(os.listdir(main_dir)):
            img_loc = os.path.join(self.main_dir, img)
            # skip archives and anything that is not a regular file
            # (e.g. sub-directories created while extracting)
            if img.endswith(".zip") or not os.path.isfile(img_loc):
                continue
            # the context manager releases the file handle once the pixels
            # have been converted and transformed
            with Image.open(img_loc) as image:
                tensor_image = self.transform(image.convert('RGB'))
            self.total_imgs.append(tensor_image)

    def __len__(self):
        """Return the number of cached images."""
        return len(self.total_imgs)

    def __getitem__(self, idx):
        """Return the already-transformed image stored at position ``idx``."""
        return self.total_imgs[idx]

#### <font color='#C45A63'>[1.2] Dataset details</font>

In [None]:
# image pipeline: PIL image -> tensor, then grayscale conversion
transforms = tforms.Compose([tforms.ToTensor(), tforms.Grayscale()])

# turn the relative split weights into fractions that sum to 1
frac_total = train_frac + val_frac + test_frac
n_train_frac = train_frac / frac_total
n_val_frac = val_frac / frac_total
n_test_frac = test_frac / frac_total

# report the resulting percentages
print(f'{(n_train_frac*100):.1f} % Training data')
print(f'{(n_val_frac*100):.1f} % Validation data')
print(f'{(n_test_frac*100):.1f} % Testing data')

#### <font color='#C45A63'>[1.3] Create Datasets</font>

In [None]:
# load every image found in data/ through the transform pipeline
dataset = CatDataset('data', transforms)

# randomly partition it according to the normalized fractions
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [n_train_frac, n_val_frac, n_test_frac],
)

# report how many images ended up in each split
print(f'{len(train_dataset)}\ttraining images')
print(f'{len(val_dataset)}\tvalidating images')
print(f'{len(test_dataset)}\ttesting images')

##### <font color='#777'>[1.3.1] Visualize Datasets </font>

In [None]:
if visualize_datasets:
   def make_plt(dataset, title):
      """Show a row of randomly chosen images from ``dataset``.

      At most ``number_of_samples`` images are drawn (without repetition);
      a split holding fewer images shows all of them instead of failing.

      Args:
         dataset: indexable collection of image tensors laid out
            channels-first.
         title: heading placed above the row of images.
      """
      # sampling without replacement cannot draw more items than exist
      count = min(number_of_samples, len(dataset))
      if count == 0:
         print(f"'{title}' is empty, nothing to show")
         return

      fig, _ = plt.subplots(1, count)
      fig.set_figwidth(24)
      fig.suptitle(title, fontsize=32)

      selected = np.random.choice(len(dataset), count, False)
      for i, ind in enumerate(selected):
         plt.subplot(1, count, i+1)
         # we permute: (?, 64, 64) --> (64, 64, ?)
         plt.imshow(dataset[ind].permute(1,2,0), cmap='gray')
         plt.axis('off')

   # training
   make_plt(train_dataset, "Training data")

   # validation
   make_plt(val_dataset, "Validation data")

   # testing
   make_plt(test_dataset, "Testing data")

else:
   print("Settable parameter 'visualize_datasets' is False")


#### <font color='#C45A63'>[1.4] Create DataLoaders </font>

In [None]:
# training loader: reshuffled every epoch so batches differ between epochs
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)

# validation loader: fixed order, so the per-epoch validation loss (a mean of
# per-batch means) is comparable between epochs for early stopping
val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=False)

# testing loader: fixed order for the same reason
test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=False)

### <font color='#FFEC72'>[2] <u>Define Classes</u> </font>

#### <font color='#777'> [2.0] Imports</font>

In [None]:
import torch.nn as nn
from copy import deepcopy

#### <font color='#CEBD5C'> [2.1] Dense AutoEncoder</font>

In [None]:
class DenseAutoEncoder(nn.Module):
  """Fully connected autoencoder for inputs with 64*64 = 4096 values per sample.

  The encoder compresses each sample to 4 values; the decoder expands it back
  to 4096 values squashed into [0, 1] by a sigmoid.

  Args:
    l2_lam: weight applied to the L2 penalty returned by ``regularization``.
      The default of 1.0 returns the plain sum of squared parameters.
      NOTE(review): that unscaled sum can be far larger than a per-pixel
      reconstruction loss -- consider passing a small value.
  """

  def __init__(self, l2_lam=1.0):
    super().__init__()
    self.l2_lam = l2_lam
    self.encoder = nn.Sequential(
      nn.Flatten(),             # (N, ...) --> (N, 4096)
      nn.Linear(64*64, 24*24),  # (N, 4096) --> (N, 576)
      nn.ReLU(),
      nn.Linear(24*24, 16*16),  # (N, 576) --> (N, 256)
      nn.ReLU(),
      nn.Linear(16*16, 4*4),    # (N, 256) --> (N, 16)
      nn.ReLU(),
      nn.Linear(4*4, 2*2)       # (N, 16) --> (N, 4)
    )

    self.decoder = nn.Sequential(
      nn.Linear(2*2, 4*4),      # (N, 4) --> (N, 16)
      nn.ReLU(),
      nn.Linear(4*4, 16*16),    # (N, 16) --> (N, 256)
      nn.ReLU(),
      nn.Linear(16*16, 24*24),  # (N, 256) --> (N, 576)
      nn.ReLU(),
      nn.Linear(24*24, 64*64),  # (N, 576) --> (N, 4096)
      nn.Sigmoid()              # outputs in [0, 1]
    )

  def forward(self, x):
    """Encode ``x`` and return its reconstruction of shape (N, 4096)."""
    encoded = self.encoder(x)
    decoded = self.decoder(encoded)

    return decoded

  def regularization(self):
    """Return ``l2_lam`` times the sum of squares of all parameters (weights and biases)."""
    l2_reg = sum(p.pow(2).sum() for p in self.parameters())
    return self.l2_lam * l2_reg

#### <font color='#CEBD5C'> [2.2] Sparse AutoEncoder</font>

In [None]:
class SparseAutoEncoder(nn.Module):
  """Fully connected autoencoder with an L1 penalty on its parameters.

  Same layer layout as a 4096 -> 576 -> 256 -> 16 -> 4 encoder followed by
  the mirrored decoder; the output is squashed into [0, 1] by a sigmoid.

  Args:
    l1_lam: weight applied to the L1 penalty returned by ``regularization``.
      NOTE(review): the penalty is now the entrywise L1 norm, which is much
      larger than the matrix 1-norm used before -- re-tune ``l1_lam``.
  """

  def __init__(self, l1_lam=0.001):
    super().__init__()
    self.l1_lam  = l1_lam
    self.encoder = nn.Sequential(
      nn.Flatten(),             # (N, ...) --> (N, 4096)
      nn.Linear(64*64, 24*24),  # (N, 4096) --> (N, 576)
      nn.ReLU(),
      nn.Linear(24*24, 16*16),  # (N, 576) --> (N, 256)
      nn.ReLU(),
      nn.Linear(16*16, 4*4),    # (N, 256) --> (N, 16)
      nn.ReLU(),
      nn.Linear(4*4, 2*2)       # (N, 16) --> (N, 4)
    )

    self.decoder = nn.Sequential(
      nn.Linear(2*2, 4*4),      # (N, 4) --> (N, 16)
      nn.ReLU(),
      nn.Linear(4*4, 16*16),    # (N, 16) --> (N, 256)
      nn.ReLU(),
      nn.Linear(16*16, 24*24),  # (N, 256) --> (N, 576)
      nn.ReLU(),
      nn.Linear(24*24, 64*64),  # (N, 576) --> (N, 4096)
      nn.Sigmoid()              # outputs in [0, 1]
    )

  def forward(self, x):
    """Encode ``x`` and return its reconstruction of shape (N, 4096)."""
    encoded = self.encoder(x)
    decoded = self.decoder(encoded)

    return decoded

  def regularization(self):
    """Return ``l1_lam`` times the sum of absolute values of all parameters.

    The sum is taken entry by entry: ``torch.linalg.norm(p, 1)`` on a 2-D
    weight would instead yield the matrix 1-norm (largest absolute column
    sum), which penalizes a single column rather than every weight.
    """
    l1_norm = sum(p.abs().sum() for p in self.parameters())
    return self.l1_lam * l1_norm


#### <font color='#CEBD5C'> [2.3] EarlyStopper</font>

In [None]:
class EarlyStopper:
    def __init__(self, patience=1, min_delta=0., model: nn.Module | None = None):
        self.patience: int = patience
        self.min_delta: float = min_delta
        self.counter: int = 0
        self.min_validation_loss: float = np.inf
        self.__model: nn.Module | None = model
        self.__state_dict = None

    def __call__(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
            if self.__model is not None:
                self.__state_dict = deepcopy(self.__model.state_dict())
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

    def load_checkpoint(self):
        if self.__model is not None and self.__state_dict is not None:
            with torch.no_grad():
                self.__model.load_state_dict(self.__state_dict)


### <font color='#77FF5B'>[3] <u>Training Loop</u> </font>

#### <font color='#777'> [3.0] Imports</font>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

#### <font color='#5EAA4C'> [3.1] Models, parameters</font>

In [None]:
# training hyper-parameters shared by both models
learning_rate = 0.001
num_epochs = 30

# prefer the CUDA GPU and fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# models (dense first, then sparse), moved to the chosen device
dense_model = DenseAutoEncoder().to(device)
sparse_model = SparseAutoEncoder().to(device)

# reconstruction losses
dense_criterion = nn.MSELoss()
sparse_criterion = nn.MSELoss()

# one Adam optimizer per model
dense_optimizer = optim.Adam(dense_model.parameters(), lr=learning_rate)
sparse_optimizer = optim.Adam(sparse_model.parameters(), lr=learning_rate)

#### <font color='#5EAA4C'> [3.1.1] Training, evaluation & testing functions</font>

In [None]:
def train_one_epoch(model, optimizer, data_loader, criterion):
  """Run one optimization pass over ``data_loader``.

  Each batch is flattened to rows of 64*64 values, reconstructed by
  ``model`` and scored with ``criterion`` plus ``model.regularization()``.
  Progress is printed in place on a single line.

  Returns:
    The mean of the per-batch losses (regularization included).
  """
  model.train()
  batch_losses = []
  for step, batch in enumerate(data_loader):
    progress = step * 100 / len(data_loader)
    print(f'\r Training: {progress:.1f}%', end='')
    batch = batch.reshape(-1, 64*64).to(device)

    # clear gradients left over from the previous step
    optimizer.zero_grad()

    # forward pass: reconstruction error + model penalty
    output = model(batch)
    batch_loss = criterion(output, batch) + model.regularization()

    # backward pass and parameter update
    batch_loss.backward()
    optimizer.step()

    batch_losses.append(batch_loss.item())
  print('\r Training: 100.0%', end='')
  return np.mean(batch_losses)

def test_model(model, data_loader, criterion, phase):
  """Evaluate ``model`` on ``data_loader`` without updating it.

  Each batch is flattened to rows of 64*64 values, reconstructed by
  ``model`` and scored with ``criterion`` plus ``model.regularization()``.
  Progress is printed in place, prefixed with ``phase``.

  Args:
    model: module exposing ``regularization()``.
    data_loader: iterable of image batches that supports ``len``.
    criterion: loss comparing the reconstruction with the input.
    phase: label shown in the progress line (e.g. 'Validating').

  Returns:
    The mean of the per-batch losses (regularization included).
  """
  model.eval()
  losses = []
  # no gradients are needed for evaluation; skipping autograd bookkeeping
  # saves memory and time
  with torch.no_grad():
    for i, image in enumerate(data_loader):
      print(f'\r {phase}: {(i*100 / len(data_loader)):.1f}%', end='')
      image = image.reshape(-1,64*64).to(device)

      # forward pass
      reconstructed_image = model(image)
      loss = criterion(reconstructed_image, image) + model.regularization()

      losses.append(loss.item())
  print(f'\r {phase}: 100.0%', end='')
  return np.mean(losses)


#### <font color='#5EAA4C'> [3.2] Train DenseAutoEncoder</font>

In [None]:
# Train the dense model; stop early once the validation loss has been clearly
# worse than its best value for 5 checks.
early_stopper = EarlyStopper(model=dense_model, patience=5)
for epoch in range(num_epochs):
  print(f'Epoch {epoch+1}:')

  train_loss = train_one_epoch(dense_model, dense_optimizer, train_dataloader, dense_criterion)
  print(f'\t| loss: {train_loss.item():.4f}')

  val_loss = test_model(dense_model, val_dataloader, dense_criterion, phase='Validating')
  print(f'\t| loss: {val_loss.item():.4f}')

  test_loss = test_model(dense_model, test_dataloader, dense_criterion, phase='Testing')
  print(f'\t| loss: {test_loss.item():.4f}')

  print('=' * 20)

  if early_stopper(val_loss):
    print("Early stopped")
    break

# restore the weights with the best validation loss, whether training stopped
# early or ran through every epoch
early_stopper.load_checkpoint()

#### <font color='#5EAA4C'> [3.3] Train SparseAutoEncoder</font>

In [None]:
# Train the sparse model; stop early once the validation loss has been clearly
# worse than its best value for 5 checks.
early_stopper = EarlyStopper(model=sparse_model, patience=5)
for epoch in range(num_epochs):
  print(f'Epoch {epoch+1}:')

  train_loss = train_one_epoch(sparse_model, sparse_optimizer, train_dataloader, sparse_criterion)
  print(f'\t| loss: {train_loss.item():.4f}')

  val_loss = test_model(sparse_model, val_dataloader, sparse_criterion, phase='Validating')
  print(f'\t| loss: {val_loss.item():.4f}')

  test_loss = test_model(sparse_model, test_dataloader, sparse_criterion, phase='Testing')
  print(f'\t| loss: {test_loss.item():.4f}')

  print('=' * 20)

  if early_stopper(val_loss):
    print("Early stopped")
    break

# restore the weights with the best validation loss, whether training stopped
# early or ran through every epoch
early_stopper.load_checkpoint()

#### <font color='#5EAA4C'> [3.4] Visualize original and reconstructed images</font>