In [None]:
# If you are using Google Colab, install the dependencies first:
# %pip install -U datasets transformers accelerate ftfy pyarrow wandb pandas numpy

In [None]:
from argparse import Namespace

from datasets import load_dataset
from datasets import Dataset

import torch

from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from torchvision import transforms
import wandb

import pandas as pd
import numpy as np

from sklearn.model_selection import StratifiedKFold

In [None]:
from accelerate import Accelerator
from accelerate.utils import GradientAccumulationPlugin
from accelerate.utils import set_seed

In [None]:
from accelerate.utils import write_basic_config

# Write a basic default Accelerate config file so Accelerate can be used without
# running `accelerate config` interactively. Per the Accelerate docs, an existing
# config file is left untouched.
write_basic_config()

## Set up config

In [None]:
# Pick the best available device: CUDA first, then Apple-silicon MPS, else CPU.
DEVICE = torch.device(
    'cuda' if torch.cuda.is_available()
    else ('mps' if torch.backends.mps.is_available() else 'cpu')
)
# DEVICE = 'cpu'  # uncomment to force CPU

CONFIG = Namespace(
    # W&B bookkeeping
    run_name='animal-classifier',
    model_name='animal-classifier-model-v1',  # artifact name used when save_model=True
    # Data / model
    image_size=256,  # images are resized to image_size x image_size
    hidden_dims=256,  # base channel width of the CNN
    horizontal_flip_prob=0.5,
    gaussian_blur_kernel_size=3,
    # Optimization
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=15,
    learning_rate=4e-4,
    seed=1,
    beta_schedule='squaredcos_cap_v2',  # NOTE(review): not read anywhere in this notebook
    lr_exp_schedule_gamma=0.85,  # NOTE(review): not read anywhere in this notebook
    lr_warmup_steps=500,  # used as T_0 (restart period) of CosineAnnealingWarmRestarts
    train_limit=-1,  # NOTE(review): not read anywhere in this notebook
    save_model=False,
    mixed_precision=None,
    grad_accumulation_steps=4,
)
CONFIG.device = DEVICE

## Create Dataset

For now, I am using the following data augmentations:
- RandomHorizontalFlip - randomly flips the image horizontally (with probability `horizontal_flip_prob`)
- GaussianBlur - smooths/blurs the image with a Gaussian filter (kernel size `gaussian_blur_kernel_size`)

In [None]:
def prepare_dataloader(config: Namespace):
    """
    Prepare the train and validation dataloaders for the ``cats_vs_dogs`` dataset.

    Images whose width or height is 100 px or less are dropped. The remaining
    train split is divided with a seeded, label-stratified 10-fold split: fold 0
    (~10%) becomes the validation set and the other folds the training set.

    Every batch is a dict with:
        'image':          image resized to ``config.image_size`` and normalized to (-1, 1)
        'label':          class label (from the dataset's ``labels`` column)
        'original-image': image resized to 512x512 in (0, 1), for W&B visualization

    Random augmentations (horizontal flip, Gaussian blur) are applied to the
    training images only. Validation images are just resized and normalized, so
    evaluation is deterministic.

    Args:
        config: Must provide ``image_size``, ``horizontal_flip_prob``,
            ``gaussian_blur_kernel_size``, ``seed``,
            ``per_device_train_batch_size`` and ``per_device_eval_batch_size``.

    Returns:
        ``(train_dataloader, val_dataloader)``. Only the training loader shuffles,
        using a generator seeded from ``config.seed``.
    """
    # Training preprocessing: resize, random augmentations, then map to (-1, 1).
    train_preprocess = transforms.Compose(
        [
            transforms.Resize((config.image_size, config.image_size)),
            transforms.RandomHorizontalFlip(p=config.horizontal_flip_prob),
            # By default GaussianBlur samples a random sigma for every image.
            transforms.GaussianBlur(kernel_size=config.gaussian_blur_kernel_size),
            transforms.ToTensor(),  # Convert to tensor (0, 1)
            transforms.Normalize([0.5], [0.5]),  # Map to (-1, 1)
        ])

    # Validation preprocessing: same resizing and normalization, but no random
    # augmentation, so validation metrics are reproducible across epochs/runs.
    eval_preprocess = transforms.Compose(
        [
            transforms.Resize((config.image_size, config.image_size)),
            transforms.ToTensor(),  # Convert to tensor (0, 1)
            transforms.Normalize([0.5], [0.5]),  # Map to (-1, 1)
        ])

    # For pre-processing original image for visualization in W&Bs
    preprocess_original = transforms.Compose(
        [
            transforms.Resize((512, 512)),  # Resize
            transforms.ToTensor(),  # Convert to tensor (0, 1)
        ])

    def make_transform(preprocess):
        """Build a batched ``set_transform`` function that applies ``preprocess``."""
        def transform(examples):
            rgb_images = [image.convert('RGB') for image in examples['image']]
            return {'image': [preprocess(image) for image in rgb_images],
                    'label': examples['labels'],
                    'original-image': [preprocess_original(image) for image in rgb_images]}
        return transform

    # Load dataset and remove images that are 100x100 or below.
    dataset = load_dataset('cats_vs_dogs')
    full_train = dataset['train'].filter(
        lambda example: example['image'].size[0] > 100 and example['image'].size[1] > 100)

    # Label-stratified split: the held-out indices of the first fold of a seeded
    # 10-fold split form the validation set. Selecting rows by index avoids
    # re-materializing every image into a new dataset.
    labels = full_train['labels']
    cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=config.seed)
    train_indices, val_indices = next(cv.split(np.zeros(len(labels)), labels))

    train_dataset = full_train.select(train_indices.tolist())
    val_dataset = full_train.select(val_indices.tolist())

    train_dataset.set_transform(make_transform(train_preprocess))
    val_dataset.set_transform(make_transform(eval_preprocess))

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=config.per_device_train_batch_size,
        shuffle=True, generator=torch.Generator().manual_seed(config.seed))

    val_dataloader = torch.utils.data.DataLoader(
        val_dataset, batch_size=config.per_device_eval_batch_size,
        shuffle=False)

    return train_dataloader, val_dataloader

In [None]:
class AnimalClassifier(torch.nn.Module):
    """
    Small CNN image classifier.

    Four conv + max-pool stages (output channels dims, 2*dims, 2*dims, 2*dims)
    are followed by a flatten, a lazily sized projection to 4*dims features, and
    a final linear layer that produces ``num_labels`` logits. A ReLU follows
    every conv layer and the projection. Without these, projection + linear would
    collapse into one affine map and the conv stack would be linear apart from
    max-pooling.
    """

    def __init__(self, in_channels: int, dims: int,
                 num_labels: int):
        """
        Args:
            in_channels: Number of channels of the input images.
            dims: Base channel width; later stages use 2*dims, the projection 4*dims.
            num_labels: Number of output classes.
        """
        super().__init__()

        self.conv_1 = torch.nn.Conv2d(
            in_channels, dims, kernel_size=12)
        self.max_pool_1 = torch.nn.MaxPool2d(kernel_size=3)

        self.conv_2 = torch.nn.Conv2d(
            dims, 2*dims, kernel_size=5)
        self.max_pool_2 = torch.nn.MaxPool2d(kernel_size=3)

        self.conv_3 = torch.nn.Conv2d(
            2*dims, 2*dims, kernel_size=3)
        self.max_pool_3 = torch.nn.MaxPool2d(kernel_size=2)

        self.conv_4 = torch.nn.Conv2d(
            2*dims, 2*dims, kernel_size=3)
        self.max_pool_4 = torch.nn.MaxPool2d(kernel_size=2)

        self.flatten = torch.nn.Flatten()
        # LazyLinear infers its input size (the flattened conv features, which
        # depend on the image resolution) on the first forward pass.
        self.projection = torch.nn.LazyLinear(4*dims)

        self.linear = torch.nn.Linear(4*dims, num_labels)

    def forward(self, x: torch.Tensor):
        """
        Compute class logits.

        Args:
            x: Image batch, expected as (batch, in_channels, H, W) with H and W
                large enough to survive all four conv/pool stages.

        Returns:
            Logits whose last dimension is ``num_labels``.
        """
        relu = torch.nn.functional.relu

        x_ = self.max_pool_1(relu(self.conv_1(x)))
        x_ = self.max_pool_2(relu(self.conv_2(x_)))
        x_ = self.max_pool_3(relu(self.conv_3(x_)))
        x_ = self.max_pool_4(relu(self.conv_4(x_)))

        x_ = self.flatten(x_)
        x_ = relu(self.projection(x_))

        return self.linear(x_)

def create_model(in_dimensions: int, dims: int, num_labels: int):
    """
    Build a fresh, untrained ``AnimalClassifier``.

    Args:
        in_dimensions: Number of input image channels (the classifier's ``in_channels``).
        dims: Base convolution width of the classifier.
        num_labels: Number of output classes.

    Returns:
        The newly constructed ``AnimalClassifier``.
    """
    return AnimalClassifier(in_channels=in_dimensions, dims=dims, num_labels=num_labels)

## Train Model

In [None]:
def compute_loss(preds: torch.Tensor, labels: torch.Tensor):
    """
    Cross-entropy loss between class logits and targets, averaged over the batch.

    Args:
        preds: Unnormalized logits laid out as (batch_size, num_classes, ...),
            which is the layout PyTorch's cross-entropy expects.
        labels: Target class indices, one per example.

    Returns:
        Scalar tensor holding the mean per-example loss.
    """
    return torch.nn.functional.cross_entropy(preds, labels, reduction='mean')

@torch.no_grad()
def eval_loop(epoch: int, model, dataloader,
              wandb_run, accelerator: Accelerator):
    """
    Evaluate ``model`` on ``dataloader`` and report accuracy and loss.

    Prints the accuracy and the batch-averaged cross-entropy loss. When a W&B
    run is given, it also logs them together with a per-example table
    (prediction, epoch, de-normalized input image, original image, ground truth).

    Args:
        epoch: Epoch index stored in every row of the evaluation table.
        model: Classifier returning (batch, num_classes) logits for batch['image'].
        dataloader: Yields dicts with 'image', 'label' and 'original-image'.
        wandb_run: Active W&B run, or None to skip W&B logging (e.g. when debugging).
        accelerator: Accelerator used for printing, or None to fall back to ``print``.

    Returns:
        ``(accuracy, avg_loss)``; both are NaN for an empty dataloader.
    """
    log = print if accelerator is None else accelerator.print
    tensor_to_pil = transforms.ToPILImage()

    # The training loop leaves the model in train mode; evaluate in eval mode
    # and restore the previous mode afterwards.
    was_training = model.training
    model.eval()

    preds, gt = [], []
    images, original_images = [], []
    total_loss = 0.0
    num_batches = 0

    for batch in dataloader:
        logits = model(batch['image'])
        labels = batch['label']

        total_loss += compute_loss(logits, labels).item()
        num_batches += 1

        preds += torch.argmax(logits, dim=-1).tolist()
        gt += labels.tolist()

        if wandb_run is not None:
            for image, original in zip(batch['image'], batch['original-image']):
                # 'image' was mapped to (-1, 1) by Normalize([0.5], [0.5]), while
                # ToPILImage expects floats in (0, 1); undo the normalization first.
                images.append(tensor_to_pil((image * 0.5 + 0.5).clamp(0, 1).cpu()))
                original_images.append(tensor_to_pil(original.cpu()))

    model.train(was_training)

    if num_batches == 0:
        log("Val accuracy and loss: nan - nan (empty dataloader)")
        return float('nan'), float('nan')

    acc = sum(p == g for p, g in zip(preds, gt)) / len(gt)
    avg_loss = total_loss / num_batches

    log(f"Val accuracy and loss: {acc} - {avg_loss}")

    if wandb_run is not None:
        dataframe = pd.DataFrame(preds, columns=['pred'])
        dataframe['epoch'] = epoch
        dataframe['image'] = [wandb.Image(image) for image in images]
        dataframe['original_images'] = [wandb.Image(image) for image in original_images]
        dataframe['gt'] = gt

        table = wandb.Table(data=dataframe)
        wandb_run.log({'accuracy': acc}, commit=False)
        # Log the loss averaged over all validation batches as a plain float.
        wandb_run.log({'val-loss': avg_loss}, commit=False)
        wandb_run.log({'eval-table': table})

    return acc, avg_loss

def training_loop(config: Namespace):
    """
    Train an ``AnimalClassifier`` with Accelerate and log the run to W&B.

    Each epoch does the following:
      - one pass over the training data with gradient accumulation
        (``config.grad_accumulation_steps`` micro-batches per optimizer step);
      - gradient-norm clipping to 1.0, AdamW, and a ``CosineAnnealingWarmRestarts``
        schedule whose restart period ``T_0`` is ``config.lr_warmup_steps``;
      - an evaluation pass over the validation split via ``eval_loop``.

    Per-micro-batch loss/LR and the mean epoch loss are logged to W&B. The final
    weights are uploaded as a W&B artifact when ``config.save_model`` is true.

    Args:
        config: Experiment configuration (see the config cell). ``config.device``
            may be a ``torch.device`` or a device string such as ``'cpu'``.
    """
    wandb_run = wandb.init(project='Animal-Classifier', entity=None,
                           job_type='training',
                           name=config.run_name,
                           config=config)

    set_seed(config.seed)

    grad_accumulation_plugin = GradientAccumulationPlugin(
        num_steps=config.grad_accumulation_steps,
        adjust_scheduler=True,
        sync_with_dataloader=True)

    # A torch.device does not compare equal to a str, so compare the string
    # form; this works for both torch.device('cpu') and the plain string 'cpu'.
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_plugin=grad_accumulation_plugin,
        cpu=(str(config.device) == 'cpu'))

    train_dataloader, val_dataloader = prepare_dataloader(config)
    model = create_model(3, config.hidden_dims, 2)

    # The model's projection layer is lazily sized. A dry forward pass with a
    # correctly shaped input materializes its parameters before they are handed
    # to the optimizer and to accelerator.prepare().
    with torch.no_grad():
        model(torch.zeros(1, 3, config.image_size, config.image_size))

    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

    # Despite its name, lr_warmup_steps is used as the cosine restart period T_0.
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=config.lr_warmup_steps)

    model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, val_dataloader, scheduler)

    num_steps = 0
    for epoch in range(config.num_train_epochs):
        model.train()
        accelerator.print(f"Epoch {epoch}")

        epoch_loss = 0.0
        num_iters = 0

        optimizer.zero_grad()
        for batch in train_dataloader:
            with accelerator.accumulate(model):
                logits = model(batch['image'])
                loss = compute_loss(logits, batch['label'])

                accelerator.backward(loss)
                # Clip only once the accumulated gradient is complete; clipping
                # every micro-batch would rescale a partial sum of gradients.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)

                loss_value = loss.item()
                epoch_loss += loss_value

                # get_last_lr() is the public accessor; get_lr() warns when
                # called outside of scheduler.step().
                wandb_run.log({'loss': loss_value,
                               'lr': scheduler.get_last_lr()[0]},
                              commit=False, step=num_steps)

                num_steps += 1
                num_iters += 1

                # Inside accelerator.accumulate, Accelerate only applies these
                # on gradient-sync steps.
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

        # Attach the epoch loss to the last training step *before* evaluating.
        # eval_loop commits that step; a commit issued after it would push W&B's
        # step counter past num_steps, and the next epoch's first step-indexed
        # log would then be rejected as non-monotonic.
        wandb_run.log({'epoch-loss': epoch_loss / max(num_iters, 1)}, commit=False)

        # Validate model
        eval_loop(epoch, model, val_dataloader, wandb_run, accelerator)

    if config.save_model:
        # Save the unwrapped model's weights and upload them to W&B.
        model_art = wandb.Artifact(config.model_name, type='model')
        torch.save(accelerator.unwrap_model(model).state_dict(), 'model.pt')

        model_art.add_file('model.pt')
        wandb_run.log_artifact(model_art)
    wandb_run.finish()

In [None]:
# # For debugging
# MODEL = create_model(3, CONFIG.hidden_dims, 2)
# train_dataloader, val_dataloader = prepare_dataloader(CONFIG)
# eval_loop(0, MODEL, val_dataloader, None, None)

In [None]:
from accelerate import notebook_launcher

# Launch training_loop(CONFIG) through Accelerate with a single process.
notebook_launcher(training_loop, args=(CONFIG,), num_processes=1)