# SWIN
### Setting up the modelling environment

In [1]:
# check whether run in Colab
if 'google.colab' in sys.modules:
    print('Running in Colab.')
    !pip3 transformers

In [2]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision import datasets
import pytorch_lightning as pl
from pytorch_lightning import callbacks
from timm.data.transforms_factory import create_transform
from transformers import AutoModelForImageClassification
from torchinfo import summary

  from .autonotebook import tqdm as notebook_tqdm


## Loading the preprocessing

For the Swin model we build the preprocessing pipeline directly with `torchvision` transforms (resize to 224×224, convert to float32, normalise) rather than with `timm`.

In [1]:
# Preprocessing applied to every image before it reaches the Swin model.
# Requires the import cell (`T`, `torch`) to have been executed first.
swin_feature_extractor = T.Compose([
    # PIL image -> uint8 tensor
    T.PILToTensor(),
    # scale to the 224x224 resolution used in the rest of the notebook
    T.Resize(
        size=(224, 224),
        interpolation=T.InterpolationMode.BILINEAR,
        antialias=False,
    ),
    # uint8 -> float32
    T.ConvertImageDtype(dtype=torch.float32),
    # per-channel normalisation
    T.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
swin_feature_extractor

NameError: name 'T' is not defined

## Loading the data

We now use torchvision to download and load the CIFAR-10 data. Alternatively we could have loaded it directly, but the raw format is awful and requires parsing the data from a binary format (yuck). Although we use CIFAR-10 here, it mainly serves to show how to set up a training environment for an arbitrary dataset. We pass the transform to the dataset itself, which is possible because we use a torchvision dataset. With a custom dataset we would have to code this in directly.

In [4]:
# CIFAR-10 train/test splits with the Swin preprocessing applied per sample.
# `download=True` fetches the archive into ../data when it is not already
# there, so the cell also works on a fresh machine.
cifar10_train = datasets.CIFAR10(root="../data",
                                 train=True,
                                 download=True,
                                 transform=swin_feature_extractor)
cifar10_test = datasets.CIFAR10(root="../data",
                                train=False,
                                download=True,
                                transform=swin_feature_extractor)

Preprocessing an image then looks like the following. Here the preprocessing gives us a tensor of shape (3, 224, 224). When training we will use a dataloader that yields batches of shape (batch_size, 3, 224, 224), which is what we want.

In [5]:
# Pull the first training sample and report its type, size and label.
image_example, label_example = cifar10_train[0]
print(
    f"image input type: {type(image_example)}",
    f"image size: {image_example.size()}",
    f"label: {label_example}",
    sep="\n",
)

image input type: <class 'torch.Tensor'>
image size: torch.Size([3, 224, 224])
label: 6


In [6]:
# Shuffled batches of 8 for training.
train_dataloader = DataLoader(
    dataset=cifar10_train,
    batch_size=8,
    shuffle=True,
    num_workers=4,
)
# Ordered batches of 64 for the held-out split.
test_dataloader = DataLoader(
    dataset=cifar10_test,
    batch_size=64,
    shuffle=False,
    num_workers=4,
)

## Training with Pytorch lightning

We now construct a PyTorch Lightning model for Swin.

In [7]:
class SWINModel(pl.LightningModule):
    """Lightning module wrapping a pretrained Hugging Face image classifier.

    The classifier is loaded with ``AutoModelForImageClassification`` and
    trained with cross-entropy loss and the AdamW optimizer.

    Args:
        name: Model identifier passed to ``from_pretrained``.
        num_classes: Number of target classes; forwarded as ``num_labels``.
        default_root_dir: Stored on the instance; not used by this class itself.
        lr: Learning rate given to AdamW. Defaults to the previously
            hard-coded value of 0.001.
    """

    def __init__(self,
                 name="microsoft/swin-base-patch4-window7-224-in22k",
                 num_classes=10,
                 default_root_dir="checkpoints/",
                 lr=0.001):
        super().__init__()
        self.name = name
        self.num_classes = num_classes
        self.default_root_dir = default_root_dir
        self.lr = lr
        self.loss_fn = nn.CrossEntropyLoss()
        # ignore_mismatched_sizes lets the pretrained classifier head be
        # replaced by a freshly initialised one with `num_classes` outputs.
        self.model = AutoModelForImageClassification.from_pretrained(name,
                                                                     num_labels=num_classes,
                                                                     ignore_mismatched_sizes=True)

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its logits."""
        outs = self.model(x)
        return outs.logits

    def _compute_loss(self, batch):
        """Return the cross-entropy loss for one ``(inputs, labels)`` batch."""
        inputs, labels = batch
        outputs = self(inputs)
        return self.loss_fn(outputs, labels)

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the training loss."""
        loss = self._compute_loss(batch)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters using ``self.lr``."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        return optimizer

    def validation_step(self, batch, batch_idx):
        """Compute, log (as ``val_loss``) and return the validation loss."""
        loss = self._compute_loss(batch)
        self.log("val_loss", loss)
        return loss

In [8]:
# Instantiate the model with its default settings and summarise it for a
# single 3x224x224 input.
swin_model = SWINModel()
summary(
    model=swin_model,
    input_size=(1, 3, 224, 224),
)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Some weights of SwinForImageClassification were not initialized from the model checkpoint at microsoft/swin-base-patch4-window7-224-in22k and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([21841, 1024]) in the checkpoint and torch.Size([10, 1024]) in the model instantiated
- classifier.bias: found shape torch.Size([21841]) in the checkpoint and torch.Size([10]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Layer (type:depth-idx)                                                 Output Shape              Param #
SWINModel                                                              [1, 10]                   --
├─SwinForImageClassification: 1-1                                      [1, 10]                   --
│    └─SwinModel: 2-1                                                  [1, 1024]                 --
│    │    └─SwinEmbeddings: 3-1                                        [1, 3136, 128]            6,528
│    │    └─SwinEncoder: 3-2                                           [1, 49, 1024]             86,734,648
│    │    └─LayerNorm: 3-3                                             [1, 49, 1024]             2,048
│    │    └─AdaptiveAvgPool1d: 3-4                                     [1, 1024, 1]              --
│    └─Linear: 2-2                                                     [1, 10]                   10,250
Total params: 86,753,474
Trainable params: 86,753,474
Non-trainable params: 0

### Model Callbacks

In [9]:
# callback for creating checkpoints
checkpoint_callback = callbacks.ModelCheckpoint(dirpath="checkpoints/",
                                      filename="{epoch}-{val_loss:.2f}",
                                      monitor="val_loss",
                                      save_top_k=2,
                                      save_weights_only=True,
                                      mode="min")

# Rich model summary
rich_summary_callback = callbacks.RichModelSummary()

### Trainer

We now construct a PyTorch Lightning trainer for Swin.

In [14]:
# Pick the accelerator: GPU when CUDA is available, otherwise CPU.
gpus = "cpu"
if torch.cuda.is_available():
    gpus = "gpu"

# Single-device trainer running 10 epochs with the callbacks defined above.
swin_trainer = pl.Trainer(
    max_epochs=10,
    accelerator=gpus,
    devices=1,
    callbacks=[checkpoint_callback, rich_summary_callback],
)

Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [18]:
# Fit the model; the CIFAR-10 test loader is used as the validation loader.
swin_trainer.fit(
    model=swin_model,
    train_dataloaders=train_dataloader,
    val_dataloaders=test_dataloader,
)

Missing logger folder: /Users/rasmus/Library/CloudStorage/OneDrive-UniversityofCopenhagen/Skole/kandidat/1. Semester/Advanced Topics in Image Analysis (ATIA)/testing/models/lightning_logs


Epoch 0:   0%|          | 19/6407 [01:27<8:12:40,  4.63s/it, loss=2.62, v_num=0]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [12]:
# Display the module hierarchy of the Lightning model via its repr.
swin_model

SWINModel(
  (loss_fn): CrossEntropyLoss()
  (model): SwinForImageClassification(
    (swin): SwinModel(
      (embeddings): SwinEmbeddings(
        (patch_embeddings): SwinPatchEmbeddings(
          (projection): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
        )
        (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (encoder): SwinEncoder(
        (layers): ModuleList(
          (0): SwinStage(
            (blocks): ModuleList(
              (0): SwinLayer(
                (layernorm_before): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
                (attention): SwinAttention(
                  (self): SwinSelfAttention(
                    (query): Linear(in_features=128, out_features=128, bias=True)
                    (key): Linear(in_features=128, out_features=128, bias=True)
                    (value): Linear(in_features=128, out_features=128, bias=True)
                    (dropo