# Import Libraries

In [None]:
!pip install pytorch_lightning==1.7.0
!pip install quaterion
!pip install quaterion-models

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytorch_lightning==1.7.0
  Downloading pytorch_lightning-1.7.0-py3-none-any.whl (700 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 700.9/700.9 KB 12.1 MB/s eta 0:00:00
Collecting torchmetrics>=0.7.0
  Downloading torchmetrics-0.11.3-py3-none-any.whl (518 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 518.6/518.6 KB 33.3 MB/s eta 0:00:00
Collecting pyDeprecate>=0.3.1
  Downloading pyDeprecate-0.3.2-py3-none-any.whl (10 kB)
Installing collected packages: pyDeprecate, torchmetrics, pytorch_lightning
Successfully installed pyDeprecate-0.3.2 pytorch_lightning-1.7.0 torchmetrics-0.11.3
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting quaterion
  Downloading quaterion-0.1.34-py3-none-any.whl (74 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
import os
import pickle
import shutil
from typing import Callable, Dict, Optional, Union

import numpy as np
import pytorch_lightning as pl
import torch
import torchvision
import tqdm
from torch import nn
from torch.utils.data import Dataset, Subset
from torchvision import datasets, transforms

from quaterion import Quaterion, TrainableModel
from quaterion.dataset import GroupSimilarityDataLoader, SimilarityGroupSample
from quaterion.eval.attached_metric import AttachedMetric
from quaterion.eval.group import RetrievalRPrecision
from quaterion.loss import SimilarityLoss, TripletLoss
from quaterion.train.cache import CacheConfig, CacheType
from quaterion_models.encoders import Encoder
from quaterion_models.heads import EncoderHead, SkipConnectionHead

# Define Configurations

In [None]:
# Dataset layout produced by extracting the Kaggle archive.
DATASET_PATH = os.path.join('.', 'results')          # ./results
TRAINSET_PATH = os.path.join(DATASET_PATH, 'train')  # ./results/train
TESTSET_PATH = os.path.join(DATASET_PATH, 'test')    # ./results/test

# Hyperparameters
LEARNING_RATE = 1e-3
BATCH_SIZE = 32
EPOCHS = 300
INPUT_SIZE = 224  # side length (px) every image is resized to

# Seed the global RNGs so the random category split (numpy) and torch-side
# randomness (head initialisation, dropout, dataloader shuffling) are reproducible.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Download Dataset

In [None]:
# Colab only: prompt for the Kaggle API token (kaggle.json) when it is missing.
# The uploaded file is written into the current working directory.
if not os.path.exists(os.path.join('.', 'kaggle.json')):
  from google.colab import files
  files.upload()

Saving kaggle.json to kaggle.json


In [None]:
# Put the Kaggle API token where the kaggle CLI looks for it, readable only by the owner.
# `mkdir -p` keeps any other files already in ~/.kaggle instead of deleting the folder.
!if [ ! -f ~/.kaggle/kaggle.json ]; then \
  mkdir -p ~/.kaggle/ && \
  cp kaggle.json ~/.kaggle/ && \
  chmod 600 ~/.kaggle/kaggle.json; \
fi

In [None]:
# Download (if needed) and extract the dataset unless ./results already exists.
# -q keeps the cell output short; -o prevents an interactive overwrite prompt
# that would otherwise hang the cell.
!if [ ! -d results ]; then \
  if [ ! -f even-more-fruitssssss.zip ]; then \
    kaggle datasets download -d yash161101/even-more-fruitssssss; \
  fi && \
  unzip -q -o even-more-fruitssssss.zip; \
fi

Downloading even-more-fruitssssss.zip to /content
100% 327M/328M [00:02<00:00, 126MB/s]
100% 328M/328M [00:02<00:00, 121MB/s]
Archive:  even-more-fruitssssss.zip
  inflating: results/test/Apple/Image_155.png  
  inflating: results/test/Apple/Image_226.jpg  
  inflating: results/test/Apple/Image_240.jpg  
  inflating: results/test/Apple/Image_246.jpeg  
  inflating: results/test/Apple/Image_294.jpg  
  inflating: results/test/Apple/Image_312.jpg  
  inflating: results/test/Apple/Image_313.jpg  
  inflating: results/test/Apple/Image_321.jpg  
  inflating: results/test/Apple/Image_329.jpg  
  inflating: results/test/Apple/Image_51.jpg  
  inflating: results/test/Apple/Image_54.jpeg  
  inflating: results/test/Apple/Image_68.jpg  
  inflating: results/test/Apple/Image_79.jpg  
  inflating: results/test/Banana/Image_139.png  
  inflating: results/test/Banana/Image_194.jpg  
  inflating: results/test/Banana/Image_200.jpg  
  inflating: results/test/Banana/Image_204.jpeg  
  inflating: result

# Load Dataset

In [None]:
class FruitsDataset(Dataset):
  """Adapt an ``(image, label)`` dataset into Quaterion ``SimilarityGroupSample``s.

  The class label of each item becomes its similarity group, and the image is
  passed through ``transform`` whenever the item is accessed.
  """

  def __init__(self, dataset: Dataset, transform: Callable):
    self._transform = transform
    self._dataset = dataset

  def __len__(self) -> int:
    return len(self._dataset)

  def __getitem__(self, index) -> SimilarityGroupSample:
    raw_image, group_id = self._dataset[index]
    return SimilarityGroupSample(obj=self._transform(raw_image), group=group_id)

def get_datasets(
  input_size: int,
  split_cache_path='split_cache.pkl',
):
  """Build train/test ``FruitsDataset``s whose fruit categories are disjoint.

  The on-disk train and test ImageFolders are concatenated and then re-split by
  class: a random half of the categories present goes to training and the rest
  to evaluation. The similarity model is therefore validated on groups it never
  saw during training.

  Args:
    input_size: Side length (pixels) images are resized to.
    split_cache_path: Pickle file that persists the index split across runs. If
      falsy, the split is recomputed on every call and nothing is written.

  Returns:
    ``(train_dataset, test_dataset)``, both ``FruitsDataset`` instances.
  """
  # ImageNet statistics, matching the normalisation of the pretrained backbone.
  mean = [0.485, 0.456, 0.406]
  std = [0.229, 0.224, 0.225]

  train_transform = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
  ])
  test_transform = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
  ])

  trainset = datasets.ImageFolder(TRAINSET_PATH)
  testset = datasets.ImageFolder(TESTSET_PATH)
  # NOTE(review): assumes both folders contain the same class sub-directories,
  # so ImageFolder maps identical fruit names to identical integer labels.
  full_dataset = trainset + testset

  if split_cache_path and os.path.exists(split_cache_path):
    # NOTE: pickle.load can execute arbitrary code; only load caches written here.
    with open(split_cache_path, 'rb') as cache_file:
      train_indices, test_indices = pickle.load(cache_file)
  else:
    # `.targets` yields labels without decoding every image. The concatenation
    # order matches the ConcatDataset built above (train first, then test).
    labels_list = np.array(trainset.targets + testset.targets)
    # Split over the categories actually present rather than a fixed class count.
    categories = np.unique(labels_list)
    train_categories = np.random.choice(categories, size=len(categories) // 2, replace=False)
    labels_mask = np.isin(labels_list, train_categories)
    # flatnonzero always returns 1-D indices (argwhere().squeeze() is 0-D for one match).
    train_indices = np.flatnonzero(labels_mask)
    test_indices = np.flatnonzero(~labels_mask)
    if split_cache_path:
      with open(split_cache_path, 'wb') as cache_file:
        pickle.dump((train_indices, test_indices), cache_file)

  train_dataset = FruitsDataset(
    Subset(full_dataset, train_indices),
    transform=train_transform
  )
  test_dataset = FruitsDataset(
    Subset(full_dataset, test_indices),
    transform=test_transform
  )

  return train_dataset, test_dataset

def get_dataloaders(
  batch_size: int,
  input_size: int,
  shuffle: bool=False,
  split_cache_Path='split_cache.pkl',
):
  """Wrap the train/test datasets from ``get_datasets`` in group dataloaders.

  Only the training loader honours ``shuffle``; the evaluation loader is always
  iterated in a fixed order.

  Returns:
    ``(train_dataloader, test_dataloader)`` as ``GroupSimilarityDataLoader``s.
  """
  train_dataset, test_dataset = get_datasets(input_size, split_cache_Path)
  loaders = tuple(
    GroupSimilarityDataLoader(subset, batch_size=batch_size, shuffle=shuffle_flag)
    for subset, shuffle_flag in ((train_dataset, shuffle), (test_dataset, False))
  )
  return loaders

# Define Encoder

In [None]:
class FruitsEncoder(Encoder):
  """Frozen Quaterion encoder wrapping a feature-extractor backbone.

  Used with a torchvision ResNet-152 whose ``fc`` layer is replaced by
  ``nn.Identity()``, so each image maps to a 2048-d feature vector.
  """

  def __init__(self, encoder_model: nn.Module):
    super().__init__()
    self._encoder = encoder_model
    # Feature width of ResNet-152 once its classification layer is removed.
    self._embedding_size = 2048

  @property
  def trainable(self) -> bool:
    # Frozen backbone; this is what allows Quaterion to cache its embeddings.
    return False

  @property
  def embedding_size(self) -> int:
    return self._embedding_size

  def forward(self, images):
    # Call the module itself (not `.forward`) so registered hooks still run.
    return self._encoder(images)

  def save(self, output_path: str):
    """Pickle the whole backbone module to ``<output_path>/encoder.pth``."""
    os.makedirs(output_path, exist_ok=True)
    torch.save(self._encoder, os.path.join(output_path, 'encoder.pth'))

  @classmethod
  def load(cls, input_path: str) -> 'FruitsEncoder':
    """Restore an encoder written by ``save``.

    Quaterion's servable loader calls ``load`` on the encoder *class* (the base
    ``Encoder.load`` is a classmethod), so this must not be an instance method.
    Calling it on an instance still works.
    """
    encoder_model = torch.load(os.path.join(input_path, 'encoder.pth'))
    return cls(encoder_model)

# Define Model

In [None]:
class Model(TrainableModel):
  """Similarity model: frozen ResNet-152 encoder, skip-connection head, triplet loss.

  Args:
    lr: Adam learning rate.
    mining: Triplet mining strategy forwarded to ``TripletLoss`` ('all' or 'hard').
    margin: Triplet loss margin.
    dropout: Dropout probability inside the ``SkipConnectionHead``.
    cache_batch_size: Batch size used while pre-computing the embedding cache.
  """

  def __init__(
    self,
    lr: float,
    mining: str,
    margin: float = 0.5,
    dropout: float = 0.1,
    cache_batch_size: int = 32,
  ):
    # Store hyperparameters *before* the base constructor: TrainableModel's
    # __init__ invokes the configure_* hooks below, which read these attributes.
    self._lr = lr
    self._mining = mining
    self._margin = margin
    self._dropout = dropout
    self._cache_batch_size = cache_batch_size
    super().__init__()

  def configure_encoders(self) -> Union[Encoder, Dict[str, Encoder]]:
    # ImageNet-pretrained ResNet-152 with the classifier removed -> pooled features.
    pretrained_encoder = torchvision.models.resnet152(weights='IMAGENET1K_V1')
    pretrained_encoder.fc = nn.Identity()
    return FruitsEncoder(pretrained_encoder)

  def configure_head(self, input_embedding_size) -> EncoderHead:
    return SkipConnectionHead(input_embedding_size, dropout=self._dropout)

  def configure_loss(self) -> SimilarityLoss:
    return TripletLoss(mining=self._mining, margin=self._margin)

  def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.model.parameters(), self._lr)
    return optimizer

  def configure_caches(self) -> Optional[CacheConfig]:
    return CacheConfig(
      cache_type=CacheType.AUTO,
      save_dir='./cache_dir',
      batch_size=self._cache_batch_size,
    )

  def configure_metrics(self) -> AttachedMetric:
    return AttachedMetric(
      'rrp',
      metric=RetrievalRPrecision(),
      prog_bar=True,
      on_epoch=True,
      on_step=True,
    )

# Train Model

In [None]:
def train(
  lr: float,
  mining: str,
  batch_size: int,
  epochs: int,
  input_size: int,
  shuffle: bool,
  save_dir: str,
):
  """Fit the similarity model and export it as a Quaterion servable.

  Args:
    lr: Adam learning rate.
    mining: Triplet mining strategy ('all' or 'hard').
    batch_size: Dataloader batch size.
    epochs: Maximum number of training epochs passed to the Lightning trainer.
    input_size: Side length (pixels) images are resized to.
    shuffle: Whether to shuffle the training dataloader.
    save_dir: Output directory; any existing content is deleted first.
  """
  model = Model(lr=lr, mining=mining)

  train_dataloader, val_dataloader = get_dataloaders(
    batch_size=batch_size,
    input_size=input_size,
    shuffle=shuffle,
  )

  # Start from Quaterion's recommended Trainer settings (what `trainer=None` would
  # use) and cap the epoch count, so that `epochs` actually bounds training.
  trainer_kwargs = Quaterion.trainer_defaults(
    trainable_model=model,
    train_dataloader=train_dataloader,
  )
  trainer_kwargs['max_epochs'] = epochs
  trainer = pl.Trainer(**trainer_kwargs)

  Quaterion.fit(
    trainable_model=model,
    trainer=trainer,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
  )

  shutil.rmtree(save_dir, ignore_errors=True)
  model.save_servable(save_dir)

In [None]:
# Train with batch-hard triplet mining, then export the servable to ./fruit_recognition/.
train(
  save_dir='./fruit_recognition/',
  lr=LEARNING_RATE,
  batch_size=BATCH_SIZE,
  epochs=EPOCHS,
  input_size=INPUT_SIZE,
  mining='hard',
  shuffle=True,
)

Downloading: "https://download.pytorch.org/models/resnet152-394f9c45.pth" to /root/.cache/torch/hub/checkpoints/resnet152-394f9c45.pth


  0%|          | 0.00/230M [00:00<?, ?B/s]

100%|██████████| 791/791 [00:32<00:00, 24.38it/s]
INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
  rank_zero_deprecation(
2023-03-08 06:37:28.515 | DEBUG    | quaterion.train.cache_mixin:_cache:168 - Using full cache


Output()

2023-03-08 06:45:58.031 | DEBUG    | quaterion.train.cache_mixin:_cache:223 - Caching has been successfully finished
2023-03-08 06:45:58.045 | DEBUG    | quaterion.train.cache_mixin:save_cache:385 - Cache saved to ./cache_dir


Output()

  rank_zero_deprecation(


# Save Model

In [None]:
# Colab only: mount Google Drive so the exported model outlives this runtime.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
# Colab mounts "My Drive" at /content/drive/MyDrive. Create the target folder first
# so the model lands in models/fruit_recognition/ instead of becoming models/ itself.
!mkdir -p /content/drive/MyDrive/models/ && cp -r ./fruit_recognition/ /content/drive/MyDrive/models/