# 1. Import

In [15]:
try:
    import lightning as L
except:
    import lightning as L

from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

from torchmetrics.classification import Accuracy

from PIL import Image

from typing import Any, Callable, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
from torch.utils.data.dataloader import default_collate


from torchvision.ops import StochasticDepth, MLP, Permute
from torchvision.transforms import (
    Compose,
    RandAugment,
    ToTensor,
    Resize,
    Lambda,
    Normalize,
    RandomRotation,
    RandomHorizontalFlip,
    CenterCrop,
    RandomAdjustSharpness
)
from torchvision.transforms.v2 import RandomChoice
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import (
    download_url,
    download_and_extract_archive
)

from functools import partial

import numpy as np
import matplotlib.pyplot as plt

import os
import cv2
import math
import copy
import time
import random
import warnings

import loralib


# Silence every Python warning for the rest of the session.
# NOTE(review): this also hides library deprecation/usage warnings — consider
# narrowing to specific categories.
warnings.filterwarnings("ignore")

# IPython magic (not plain Python): render matplotlib figures inline.
%matplotlib inline
# Global plot styling: light-gray axes background, Computer Modern math text,
# and STIXGeneral as the default font family.
plt.rcParams['axes.facecolor'] = 'lightgray'
plt.rcParams['mathtext.fontset'] = 'cm'
plt.rcParams['font.family'] = 'STIXGeneral'

# 2. Configuration

In [4]:
# Swin backbone hyper-parameters (same values as torchvision's ``swin_t``).
PATCH_SIZE = [4] * 2               # patch height, width
EMBED_DIM = 96                     # channel width of the first stage
DEPTHS = [2, 2, 6, 2]              # blocks per stage
NUM_HEADS = [3, 6, 12, 24]         # attention heads per stage
WINDOW_SIZE = [7] * 2              # attention window height, width
STOCHASTIC_DEPTH_PROB = 0.2

In [5]:
# Side length (pixels) images are resized to, and how many categories are kept.
IMAGE_SIZE = 256
NUM_CLASSES = 8

In [6]:
# Training length and mini-batch size.
EPOCH = 9         # 3 ** 2
BATCH_SIZE = 36   # 6 ** 2

In [7]:
TRUNCATE_PER_CATEGORY = 10_000  # upper bound on images taken from each category

In [8]:
# Momentum coefficient: sqrt(6) / e (~0.901).
MOMENTUM = math.sqrt(6) / math.e
# Golden ratio phi = (1 + sqrt(5)) / 2 (~1.618).
GOLDEN_RATIO = (1.0 + math.sqrt(5)) / 2.0
# NOTE(review): early-stopping patience is normally an integer count of
# checks; this fractional value presumably behaves like 1 — confirm intent.
EARLY_STOPPING_PATIENCE = 1 / 9

In [9]:
# Both values are the golden ratio scaled by a transcendental power of ten:
#   weight decay  = phi * 10^-pi  (~1.17e-3)
#   learning rate = phi * 10^-e   (~3.09e-3)
WEIGHT_DECAY = GOLDEN_RATIO * 10 ** -math.pi
LEARNING_RATE = GOLDEN_RATIO * 10 ** -math.e

In [10]:
# Metric to monitor and its direction: validation accuracy, higher is better.
METRIC_TO_MONITOR = "val_acc"
METRIC_MODE = "max"

In [11]:
# Result bookkeeping, filled in by later cells (each one is an independent dict).
ACC_HISTORY = {}
LOSS_HISTORY = {}
MODEL_NAME = {}
MODEL = {}
BEST_MODEL_PATH = {}

In [12]:
# Create the experiment root and its sub-directories (no-op if they already exist).
for _path in (
    "experiment",
    "experiment/training",
    "experiment/dataset",
    "experiment/model",
):
    os.makedirs(_path, exist_ok=True)
del _path
EXPERIMENT_DIR = "experiment/"

In [13]:
# Draw a fresh seed in [0, 2**31 - 1) and print it so the run can be reproduced.
SEED = int(np.random.randint(2**31 - 1))
print(f"Random seed: {SEED}")

Random seed: 1478610483


# 3. Dataset

## Utils

In [16]:
# ImageNet per-channel statistics. ``Normalize`` applies them to tensors in
# [0, 1] -- the range ``ToTensor`` produces -- so the tensors are NOT rescaled
# to [-1, 1] first (doing both would subtract the wrong mean and divide by the
# wrong scale).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: random augmentation on the PIL image, then resize to a
# square IMAGE_SIZE x IMAGE_SIZE, convert to a tensor and normalize.
AUG_TRANSFORM = Compose(
    [
        RandAugment(),
        RandomAdjustSharpness(sharpness_factor=2, p=0.5),
        RandomHorizontalFlip(p=0.5),
        RandomRotation(35),
        Resize((IMAGE_SIZE, IMAGE_SIZE)),
        ToTensor(),
        Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

# Evaluation pipeline: deterministic resize plus the same tensor normalization
# as training.
TRANSFORM = Compose(
    [
        Resize((IMAGE_SIZE, IMAGE_SIZE)),
        ToTensor(),
        Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

In [None]:
class Caltech256(VisionDataset):
    """Caltech 256 image-classification dataset restricted to a class subset.

    Only the first ``NUM_CLASSES`` categories (in sorted directory order) are
    used. Each category is truncated to at most ``TRUNCATE_PER_CATEGORY``
    ``.jpg`` images. The sorted image list of every category is split
    deterministically:

    * the first 81% go to ``"train"``;
    * the next 9% go to ``"val"``;
    * the final 10% go to ``"test"`` and ``"inference"``.

    Args:
        root (string): Root directory. Images are expected under
            ``<root>/caltech256/256_ObjectCategories`` and are downloaded
            there when ``download`` is True.
        split (string): One of ``"train"``, ``"val"``, ``"test"`` or
            ``"inference"``. ``"inference"`` uses the same images as
            ``"test"``, returns untransformed PIL images and requires
            ``transform`` and ``target_transform`` to be None.
        transform (callable, optional): A function/transform that takes in a
            PIL image and returns a transformed version.
            E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        download (bool, optional): If true, downloads the dataset from the
            internet and puts it in root directory. If dataset is already
            downloaded, it is not downloaded again.

    Raises:
        RuntimeError: If the image folder is missing after the optional download.
    """

    def __init__(
        self,
        root: str,
        split: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(
            os.path.join(root, "caltech256"),
            transform=transform,
            target_transform=target_transform,
        )
        os.makedirs(self.root, exist_ok=True)

        assert split in ["train", "val", "test", "inference"], (
            "Please choose one of these: 'train', 'val', 'test', or 'inference'"
        )

        if split == "inference":
            # Inference returns raw PIL images, so no transforms may be set.
            assert self.transform is None and self.target_transform is None

        self.split = split

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted. You can use download=True "
                "to download it"
            )

        category_root = os.path.join(self.root, "256_ObjectCategories")
        # Keep the first NUM_CLASSES categories; list position is the label.
        categories = sorted(os.listdir(category_root))[:NUM_CLASSES]

        self.y: List[int] = []
        self.x: List[str] = []
        for label, category in enumerate(categories):
            image_paths = self._list_images(os.path.join(category_root, category))
            image_paths = image_paths[:TRUNCATE_PER_CATEGORY]

            start, end = self._split_bounds(len(image_paths))
            selected = image_paths[start:end]

            self.x.extend(selected)
            self.y.extend([label] * len(selected))

        # Readable class names: keep the text after the last "." of each
        # directory name and drop any "-101" suffix.
        self.categories = [
            category.split(".")[-1].replace("-101", "") for category in categories
        ]

    @staticmethod
    def _list_images(directory: str) -> List[str]:
        """Return the sorted paths of the ``.jpg`` files directly inside ``directory``."""
        # os.listdir returns entries in arbitrary order; sorting makes the
        # train/val/test partition reproducible across runs and machines.
        return [
            os.path.join(directory, name)
            for name in sorted(os.listdir(directory))
            if name.endswith(".jpg")
        ]

    def _split_bounds(self, n: int) -> Tuple[int, int]:
        """Return the ``[start, end)`` slice of a category's ``n`` images for this split."""
        if self.split == "train":
            return 0, int(0.81 * n)
        if self.split == "val":
            return int(0.81 * n), int(0.9 * n)
        # "test" and "inference" share the final 10%.
        return int(0.9 * n), n

    def download(self) -> None:
        """Download and extract ``256_ObjectCategories.tar`` into ``self.root``.

        Uses the same archive location and MD5 checksum as
        ``torchvision.datasets.Caltech256``. Does nothing when the image
        folder already exists.
        """
        if self._check_integrity():
            return
        download_and_extract_archive(
            "https://drive.google.com/file/d/1r6o0pSROcV1_VwT4oSjA2FBUSCWGuxLK",
            self.root,
            filename="256_ObjectCategories.tar",
            md5="67b4f42ca05d46448c6bb8ecd2220f6d",
        )

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
            For the ``"inference"`` split the image is the untransformed RGB
            PIL image.
        """
        # Close the underlying file handle as soon as the pixels are decoded.
        with Image.open(self.x[index]) as raw:
            img = raw.convert('RGB')

        target = self.y[index]

        if self.split != "inference":
            if self.transform is not None:
                img = self.transform(img)

            if self.target_transform is not None:
                target = self.target_transform(target)

        return img, target

    def _check_integrity(self) -> bool:
        """Return True if the extracted image folder exists (no checksum check)."""
        return os.path.exists(os.path.join(self.root, "256_ObjectCategories"))

    def __len__(self) -> int:
        """Return the number of images in this split."""
        return len(self.y)
