In [None]:
import locale
def getpreferredencoding(do_setlocale=True):
    """Report UTF-8 as the preferred encoding, whatever the environment says.

    Args:
        do_setlocale: Accepted so the signature stays call-compatible with
            ``locale.getpreferredencoding``; the value is ignored.

    Returns:
        The literal string ``"UTF-8"``.
    """
    return "UTF-8"


# Swap the stdlib lookup for the stub above so later callers in this kernel get UTF-8.
locale.getpreferredencoding = getpreferredencoding

In [None]:
%%capture
# Export / transfer tooling: ONNX, ONNX Runtime, the ONNX simplifier and magic-wormhole.
!pip install onnx onnxruntime onnxsim magic-wormhole
# Provides ImbalancedDatasetSampler, imported in the data-loading cell.
!pip install torchsampler
# Provides the CutMix dataset wrapper, imported in the data-loading cell.
!pip install git+https://github.com/ildoonet/cutmix
# Downloads sam.py into the working directory so `from sam import SAM` resolves later.
# NOTE(review): none of these installs/downloads are version-pinned; results may drift over time.
!wget https://raw.githubusercontent.com/davda54/sam/main/sam.py

In [None]:
import torch
import torchvision.models as models
import torchvision.transforms as transforms

from torch import nn
from torchvision.models import MobileNet_V3_Small_Weights, MobileNet_V3_Large_Weights

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained MobileNetV3-Small used as a feature extractor: its stock classification head
# is replaced by an identity, so the forward pass returns whatever used to feed that head.
model = models.mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.DEFAULT)
model.classifier = nn.Identity()
# Inference mode, then move to the selected device (both calls hand the module back).
model = model.eval().to(device)

In [None]:
import os
from glob import glob

from torch.utils.data import Dataset
from PIL import Image

class CustomImageDataset(Dataset):
    """Folder-per-label image dataset with a mutable augmentation schedule.

    Every file matching one of ``extensions`` directly inside a label folder becomes
    one sample carrying that folder's key as its label. The current ``image_size``
    and ``cutout_p`` are forwarded to ``transform`` on every access and can be
    advanced with :meth:`next_epoch`.

    Args:
        labels_dir: Mapping of label key -> directory holding that label's images.
        transform: Callable invoked as ``transform(image, size=..., cutout_p=...)``
            with an RGB PIL image; its return value is the sample.
        extensions: File extensions (without the dot) to collect. The default keeps
            the original behaviour of picking up ``*.png`` then ``*.jpg``.
    """

    def __init__(self, labels_dir, transform, extensions=("png", "jpg")):
        self.image_paths = []
        self.labels = []
        for key, label_dir in labels_dir.items():
            for extension in extensions:
                # glob() yields files in arbitrary filesystem order; sorting makes the
                # index -> sample mapping reproducible across runs and machines.
                matches = sorted(glob(os.path.join(label_dir, f"*.{extension}")))
                self.image_paths.extend(matches)
                self.labels.extend([key] * len(matches))
        self.transform = transform
        # Starting point of the schedule advanced by next_epoch().
        self.image_size = 160
        self.cutout_p = 0.1

    def __len__(self):
        """Return the number of collected image files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(transformed_image, label)``."""
        image_path, label = self.image_paths[idx], self.labels[idx]
        # Image.open is lazy and keeps the file handle open; convert() forces the pixel
        # data to be read and returns an independent RGB image, so the handle can be
        # released deterministically as soon as the with-block exits.
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        image = self.transform(image, size=self.image_size, cutout_p=self.cutout_p)
        return image, label

    def next_epoch(self):
        """Advance the schedule: +16 px image size (max 224), +0.15 cutout_p (max 0.5)."""
        self.image_size = min(224, self.image_size + 16)
        self.cutout_p = min(0.5, self.cutout_p + 0.15)

    def get_labels(self):
        """Return the per-sample label list, aligned with ``image_paths``."""
        return self.labels

In [None]:
class Classifier(torch.nn.Module):
    """Small MLP head: Linear -> Hardswish -> Dropout(0.2) -> Linear.

    Args:
        in_channels: Number of features per sample after flattening.
        out_channels: Width of the hidden layer.
        num_classes: Number of output logits.
    """

    def __init__(self, in_channels, out_channels, num_classes):
        super().__init__()
        hidden = nn.Linear(in_channels, out_channels)
        head = nn.Linear(out_channels, num_classes)
        self.classifier = nn.Sequential(
            hidden,
            nn.Hardswish(inplace=True),
            nn.Dropout(p=0.2, inplace=True),
            head,
        )
        # Re-initialise both linear layers: small Gaussian weights, zero biases.
        for layer in (hidden, head):
            nn.init.normal_(layer.weight, 0, 0.01)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Flatten everything after the batch dimension and return the logits."""
        return self.classifier(x.flatten(1))

In [None]:
import torch.nn.functional as F

class SoftTargetCrossEntropy(nn.Module):
    """Cross-entropy between logits and a per-class target weight vector."""

    def forward(self, x, target):
        """Return the batch mean of ``sum(-target * log_softmax(x))`` over the last dim.

        Args:
            x: Logits; softmax is taken over the last dimension.
            target: Per-class target weights, multiplied elementwise with the
                log-probabilities.
        """
        log_probs = F.log_softmax(x, dim=-1)
        per_sample = (-target * log_probs).sum(dim=-1)
        return per_sample.mean()

## Training

### Data Structure

The training data is expected to have the following directory structure:

```
train
└── class_name
    ├── yes
    │   └── abc.png
    └── no
        └── abc.png
```

You can send data easily using magic-wormhole or upload it manually.

In [None]:
# Create the folder that the training data is received into.
!mkdir train
# Receive the data via magic-wormhole from inside that folder.
# NOTE(review): the wormhole code is hard-coded; presumably it has to be replaced with the
# code shown by the sender for each transfer — confirm before running.
!cd train && wormhole receive 15-december-prowler --accept-file

In [None]:
from torchvision import transforms as T
from torchvision.transforms.functional import InterpolationMode

def preprocess_image(image, size, cutout_p):
    """Apply the training augmentation pipeline to a single image.

    Args:
        image: Input image; the pipeline calls ``PILToTensor`` on it after the
            PIL-level augmentations.
        size: Output size handed to ``RandomResizedCrop``.
        cutout_p: Probability handed to ``RandomErasing``.

    Returns:
        The result of running the composed pipeline on ``image``.
    """
    bilinear = InterpolationMode.BILINEAR
    steps = [
        # Augmentations applied before the image is turned into a tensor.
        T.RandomResizedCrop(size=size, interpolation=bilinear),
        T.ColorJitter(brightness=0.2, saturation=0.15, contrast=0.15, hue=0.1),
        T.TrivialAugmentWide(interpolation=bilinear),
        # Tensor conversion and normalisation.
        T.PILToTensor(),
        T.ConvertImageDtype(torch.float),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        # Random erasing on the normalised tensor, filled with random values.
        T.RandomErasing(p=cutout_p, value="random"),
    ]
    return T.Compose(steps)(image)

In [None]:
from cutmix.cutmix import CutMix
from torch.utils.data import DataLoader
from torchsampler import ImbalancedDatasetSampler

from sam import SAM

# Training configuration.
BATCH_SIZE = 16
NUM_EPOCHS = 15
CLASS_NAME = "clothing"

# Label key -> folder holding that label's images (0 = "no", 1 = "yes").
labels_dir = {
    0: f"/content/train/{CLASS_NAME}/no",
    1: f"/content/train/{CLASS_NAME}/yes",
}

# The folder dataset is wrapped in CutMix; the wrapped dataset stays reachable as
# `dataset.dataset`, which is what the sampler below and the training loop use.
dataset = CutMix(
    CustomImageDataset(labels_dir, preprocess_image),
    num_class=len(labels_dir),
    beta=1.0,
    prob=0.2,
    num_mix=1,
)

# The sampler is built from the unwrapped dataset rather than the CutMix wrapper.
train_dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    sampler=ImbalancedDatasetSampler(dataset.dataset),
    pin_memory=True,
    num_workers=4,
)

In [None]:
# Trainable head: 576 input features, 1024 hidden units, one logit per label.
classifier = Classifier(in_channels=576, out_channels=1024, num_classes=len(labels_dir))
classifier = classifier.to(device)
criterion = SoftTargetCrossEntropy()

# SAM wraps AdamW; the LR scheduler is attached to the wrapped base optimizer.
optimizer = SAM(classifier.parameters(), torch.optim.AdamW, lr=3e-4, rho=1.0, adaptive=True)

# NOTE(review): the scheduler is stepped once per batch in the training loop, so this
# T_max covers roughly two epochs of steps, not NUM_EPOCHS — confirm this is intended.
cosine_t_max = (len(dataset) // BATCH_SIZE) * 2
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer.base_optimizer,
    T_max=cosine_t_max,
)

In [None]:
from tqdm.notebook import trange, tqdm

for epoch in range(1, NUM_EPOCHS + 1):
    # Only the head is trained; the backbone `model` is used under no_grad below.
    classifier.to(device).train()

    # Per-epoch accumulators (sample-weighted loss, correct predictions, sample count).
    running_loss = 0.0
    running_accuracy = 0.0
    num_samples = 0

    for inputs, targets in tqdm(train_dataloader):
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad()

        # Backbone features are computed without gradients and reused for both SAM passes.
        with torch.no_grad():
            emb_inputs = model(inputs)

        # SAM pass 1: loss/gradient at the current weights, then step to the perturbed weights.
        outputs = classifier(emb_inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.first_step(zero_grad=True)

        # SAM pass 2: gradient at the perturbed weights drives the real update.
        # Deliberately not assigned to `outputs`: these logits come from the perturbed
        # weights and must not be used for the reported metrics.
        criterion(classifier(emb_inputs), targets).backward()
        optimizer.second_step(zero_grad=True)
        scheduler.step()

        # Metrics use the first-pass logits so accuracy and loss describe the same
        # forward pass. Targets are per-class vectors, hence the argmax on both sides.
        predicted = torch.argmax(outputs.detach(), dim=1)
        running_accuracy += torch.argmax(targets, dim=1).eq(predicted).sum().item()
        running_loss += loss.item() * len(targets)
        num_samples += len(targets)

    # Sample-weighted epoch averages.
    train_loss = running_loss / num_samples
    train_accuracy = running_accuracy / num_samples

    # Every third epoch, advance the wrapped dataset's image-size / cutout schedule.
    if epoch % 3 == 0:
        dataset.dataset.next_epoch()

    print("Epoch {} - Train Loss: {:.4f} - Train Accuracy: {:.4f}".format(
        epoch, train_loss, train_accuracy))

## ONNX Export

In [None]:
# Export only the classifier head; the example input is one 576-feature row, matching
# the in_channels the head was built with.
with torch.no_grad():
    dummy_embedding = torch.randn(1, 576).to(device)
    torch.onnx.export(
        classifier.eval(),
        dummy_embedding,
        f"{CLASS_NAME}.onnx",
        opset_version=17,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
    )

In [None]:
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType

# Dynamic quantization of the exported classifier.
# NOTE(review): the input and output paths are the same string, so the FP32 export is
# overwritten in place; later cells keep reading "{CLASS_NAME}.onnx".
model_fp32 = f"{CLASS_NAME}.onnx"
model_quant = f"{CLASS_NAME}.onnx"
# NOTE(review): presumably quantize_dynamic writes to `model_quant` and returns nothing,
# which would leave `quantized_model` as None — confirm; the name is not used again here.
quantized_model = quantize_dynamic(model_fp32, model_quant)

In [None]:
# Run the ONNX simplifier, writing the result back over the same file.
!onnxsim "{CLASS_NAME}.onnx" "{CLASS_NAME}.onnx"

In [None]:
# Convert the ONNX file with ONNX Runtime's converter; the "{CLASS_NAME}.ort" file sent in
# the next section is presumably produced here — confirm.
!python3 -m onnxruntime.tools.convert_onnx_models_to_ort  "{CLASS_NAME}.onnx" --optimization_style Fixed
# Remove the operator config file left next to the model; it is not used by this notebook.
!rm "{CLASS_NAME}.required_operators.config"

## Post Training

In [None]:
# Send the converted model file to another machine via magic-wormhole.
!wormhole send "{CLASS_NAME}.ort"

## Testing

In [None]:
from torchvision import transforms as T

# Inference-time preprocessing: no random augmentation, only a resize, float conversion
# and the same mean/std normalisation that `preprocess_image` uses for training.
# NOTE(review): `InterpolationMode` is not imported in this cell; it relies on the earlier
# `from torchvision.transforms.functional import InterpolationMode` cell having been run.
test_preprocess = T.Compose([
    T.Resize(size=224, interpolation=InterpolationMode.BILINEAR),
    T.ConvertImageDtype(torch.float),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
from IPython.display import display
from PIL import Image
from torchvision.io import ImageReadMode, read_image
import time

classifier.eval()
# Scan the "no" folder of the class the classifier was trained on and show every image
# predicted as label 1, i.e. candidate false positives or mislabelled negatives.
for image_file in glob(f"train/{CLASS_NAME}/no/*.*"):
    # Decode straight to 3-channel RGB so grayscale / RGBA files match the 3-channel
    # Normalize and backbone, mirroring the RGB conversion done by the training dataset.
    input_image = read_image(image_file, mode=ImageReadMode.RGB)

    # Preprocess and add the batch dimension.
    input_batch = test_preprocess(input_image).unsqueeze(0).to(device)

    # Backbone features followed by the classifier head, without gradients.
    with torch.no_grad():
        emb_inputs = model(input_batch)
        outputs = classifier(emb_inputs)

    probs = torch.nn.functional.softmax(outputs, dim=1)
    _, predicted = torch.max(outputs, 1)
    if predicted == 1:
        display(Image.open(image_file))
        print(predicted, probs)
    # Short pause between images, kept from the original loop.
    time.sleep(0.01)

In [None]:
# Manual clean-up: delete one specific image from a "no" folder.
# NOTE(review): the path is hard-coded to the "vinyl" class and a single file id, which
# looks like a leftover from an earlier session — confirm before re-running.
!rm train/vinyl/no/bb92e4cdd878c292.*