# ConvNext
### Setting up the modelling environment

In [1]:
import sys
# check whether run in Colab
if 'google.colab' in sys.modules:
    print('Running in Colab.')
    !pip3 install timm==0.5.4 

In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision import datasets
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from timm.data.transforms_factory import create_transform
from timm import create_model
from timm.data.transforms_factory import create_transform

## Loading the model

We start by loading the model using `timm`

In [6]:
# Name of the timm model to instantiate.
convnext_name = "convnext_base_in22k"
# Build the network through timm's factory, requesting pretrained weights.
model = create_model(convnext_name, pretrained=True)


TypeError: 'generator' object is not subscriptable

## Loading the preprocessing

We again use `timm` to build the preprocessing transforms for the ConvNeXt model.

In [23]:
# Build timm's transform pipelines for 224x224 inputs. The training variant
# applies random augmentation (random resized crop, horizontal flip, colour
# jitter); the evaluation variant only resizes and centre-crops. Both end with
# tensor conversion and normalisation, as the printed pipelines show.
convnext_train_feature_extractor = create_transform(224, is_training=True)
convnext_test_feature_extractor = create_transform(224, is_training=False)
# Print both pipelines so the exact preprocessing steps are visible.
print(convnext_train_feature_extractor)
print(convnext_test_feature_extractor)

Compose(
    RandomResizedCropAndInterpolation(size=(224, 224), scale=(0.08, 1.0), ratio=(0.75, 1.3333), interpolation=bilinear)
    RandomHorizontalFlip(p=0.5)
    ColorJitter(brightness=[0.6, 1.4], contrast=[0.6, 1.4], saturation=[0.6, 1.4], hue=None)
    ToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)
Compose(
    Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
    CenterCrop(size=(224, 224))
    ToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)


## Loading the data

We now use torchvision to download and load the CIFAR-10 data. Alternatively we could have loaded it directly, but the format is awful and requires loading the data from a binary format (yuck). While we use the CIFAR-10 data here, it mainly serves to show how to set up a training environment for an arbitrary dataset. We apply the transform inside the dataset, which is possible since we use a torchvision dataset. With a custom dataset we would have to code this in directly.

In [33]:
# Load the CIFAR-10 train/test splits from ../data. `download=True` fetches
# the archive on first use (and is a no-op once the files are present), so the
# cell also works on a fresh machine.
# Each split applies its matching timm transform when a sample is read.
cifar10_train = datasets.CIFAR10(root="../data",
                                 train=True,
                                 download=True,
                                 transform=convnext_train_feature_extractor)
cifar10_test = datasets.CIFAR10(root="../data",
                                train=False,
                                download=True,
                                transform=convnext_test_feature_extractor)

Preprocessing an image then looks like the following. Here the preprocessing gives us a tensor of shape (3, 224, 224). When training we will use a dataloader that gives us batches of shape (batch_size, 3, 224, 224), which is what we want.

In [35]:
# Fetch the first training sample; indexing the dataset applies the training
# transform and returns an (image, label) pair.
image_example, label_example = cifar10_train[0]
# Show the type and shape of the transformed image and the label's value/type.
print(f"image input type: {type(image_example)}")
print(f"image size: {image_example.size()}")
print(f"label: {label_example}")
print(f"label type: {type(label_example)}")

image input type: <class 'torch.Tensor'>
image size: torch.Size([3, 224, 224])
label: 6


In [36]:
# Batch the datasets: the training loader reshuffles the samples, the test
# loader keeps the dataset order. Both use 4 worker processes for loading.
train_dataloader = DataLoader(cifar10_train, batch_size=8, shuffle=True, num_workers=4)
test_dataloader = DataLoader(cifar10_test, batch_size=64, shuffle=False, num_workers=4)

## Training pipeline

The following is a pure PyTorch training pipeline used for training the model.

In [39]:
def train(model, dataloader, criterion, optimizer, device):
    """Run one training epoch and return the average loss and accuracy.

    Args:
        model: Module that maps a batch of inputs to class scores, with the
            class dimension at index 1. It must already live on ``device``.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``;
            its result is treated as the mean loss of the batch.
        optimizer: Optimizer updating the parameters of ``model``.
        device: Device every batch is moved to before the forward pass.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``: ``epoch_loss`` is a Python float,
        ``epoch_acc`` a 0-dim float64 tensor. Both are averaged over the
        samples actually processed.

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    model.train()
    running_loss = 0.0
    # Keep the counter as a tensor from the start so the return type does not
    # depend on how many batches were seen.
    running_corrects = torch.zeros((), dtype=torch.long, device=device)
    # Count the samples that were really processed: with ``drop_last`` or a
    # sampler this differs from ``len(dataloader.dataset)``.
    num_samples = 0
    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        # Force gradient tracking even if the caller is inside ``no_grad``.
        with torch.set_grad_enabled(True):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        preds = outputs.argmax(dim=1)
        batch_size = inputs.size(0)
        # The batch loss is a mean, so weight it by the batch size.
        running_loss += loss.item() * batch_size
        running_corrects += (preds == labels).sum()
        num_samples += batch_size
    if num_samples == 0:
        raise ValueError("dataloader yielded no samples; cannot compute epoch metrics")
    epoch_loss = running_loss / num_samples
    epoch_acc = running_corrects.double() / num_samples
    return epoch_loss, epoch_acc

In [41]:
# Pick the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# `train` moves every batch to `device`, so the model has to live there too.
# Move it before building the optimizer so the optimizer is created from the
# parameters in their final location.
model.to(device)
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

# Uncomment to run one epoch with the plain PyTorch loop:
# train(model, train_dataloader, loss, optimizer, device)

## Pytorch lightning module for training

We now construct a pytorch lightning model for ConvNext.

In [3]:
class ConvNext(pl.LightningModule):
    """LightningModule wrapping a timm model for image classification.

    Args:
        name: timm model name passed to ``create_model``.
        num_classes: Number of output classes of the classification head.
        default_root_dir: Currently unused; kept so existing callers that pass
            it keep working.
        lr: Learning rate of the AdamW optimizer.
    """

    def __init__(self, name="convnext_base_in22k", num_classes=10, default_root_dir="checkpoints/", lr=0.001):
        super().__init__()
        self.lr = lr
        self.loss_fn = nn.CrossEntropyLoss()
        # Build the model from the `name` argument instead of relying on a
        # notebook-global variable.
        self.model = create_model(name, pretrained=True, num_classes=num_classes)

    def forward(self, x):
        """Return the class scores of the wrapped model for batch ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        inputs, labels = batch
        # Do not detach here: the returned loss must stay connected to the
        # graph, otherwise backpropagation has nothing to differentiate.
        outputs = self(inputs)
        loss = self.loss_fn(outputs, labels)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters of the module."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        return optimizer

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss and accuracy for one batch."""
        inputs, labels = batch
        # No gradients are needed for validation metrics.
        outputs = self(inputs).detach()
        loss = self.loss_fn(outputs, labels)
        acc = (outputs.argmax(dim=1) == labels).float().mean()
        self.log("val_loss", loss)
        self.log("val_acc", acc)
        return loss

In [None]:
# Instantiate the Lightning module with its default arguments (10 output classes).
convnext_model = ConvNext()

### Model Callbacks

In [4]:
# Checkpointing policy: monitor the logged "val_loss" (lower is better, hence
# mode="min"), keep the 2 best checkpoints under checkpoints/, and store only
# the model weights. The file name embeds the epoch and the validation loss.
checkpoint_callback = ModelCheckpoint(dirpath="checkpoints/",
                                      filename="{epoch}-{val_loss:.2f}",
                                      monitor="val_loss",
                                      save_top_k=2,
                                      save_weights_only=True,
                                      mode="min")

### Trainer

We now construct a pytorch lightning trainer for ConvNext.

In [None]:
# Train on one GPU when CUDA is available; otherwise leave `gpus` at its
# default (None) so the trainer runs on the CPU instead of raising on
# machines without a GPU.
convnext_trainer = pl.Trainer(gpus=1 if torch.cuda.is_available() else None,
                              max_epochs=10,
                              callbacks=[checkpoint_callback])

In [None]:
# Restrict training to the first 1000 training images. The subset must wrap
# the dataset (not the model), and the dataloader must be built on the subset.
cifar10_train_subset = torch.utils.data.Subset(cifar10_train, range(1000))
subset_dataloader = DataLoader(cifar10_train_subset, batch_size=32, shuffle=True, num_workers=2)
# Fit on the subset and use the test dataloader for validation.
convnext_trainer.fit(convnext_model, subset_dataloader, test_dataloader)