In [1]:
!pip install pytorch-lightning

Collecting pytorch-lightning
  Downloading pytorch_lightning-2.1.2-py3-none-any.whl (776 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m776.9/776.9 kB[0m [31m9.8 MB/s[0m eta [36m0:00:00[0m
Collecting torchmetrics>=0.7.0 (from pytorch-lightning)
  Downloading torchmetrics-1.2.0-py3-none-any.whl (805 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m805.2/805.2 kB[0m [31m18.4 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.8.0 (from pytorch-lightning)
  Downloading lightning_utilities-0.10.0-py3-none-any.whl (24 kB)
Installing collected packages: lightning-utilities, torchmetrics, pytorch-lightning
Successfully installed lightning-utilities-0.10.0 pytorch-lightning-2.1.2 torchmetrics-1.2.0


In [2]:
!pip install transformers



In [3]:
!pip install datasets

Collecting datasets
  Downloading datasets-2.15.0-py3-none-any.whl (521 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m521.2/521.2 kB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
Collecting pyarrow-hotfix (from datasets)
  Downloading pyarrow_hotfix-0.5-py3-none-any.whl (7.8 kB)
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: pyarrow-hotfix, dill, multiprocess, datasets
Successfully installed datasets-2.15.0 dill-0.3.7 multiprocess-0.70.15 pyarrow-hotfix-0.5


In [4]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [5]:
from PIL import Image
Image.MAX_IMAGE_PIXELS = None

In [6]:
import warnings
warnings.filterwarnings("ignore")

# 1. Loading the data

In [7]:
import torchvision
import torch
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split

In [8]:
dataset = torchvision.datasets.ImageFolder('/content/gdrive/MyDrive/WikiArt')
dataset

Dataset ImageFolder
    Number of datapoints: 42500
    Root location: /content/gdrive/MyDrive/WikiArt

In [9]:
# Split the dataset into train set and validation + test set
# Stratified Sampling for train set and val_test set
train_index, val_test_index= train_test_split(
np.arange(len(dataset.targets)),
test_size=0.2, random_state=91,
shuffle=True,
stratify=dataset.targets)

In [10]:
# Split the val set and test set from val_test set:
test_index, val_index= train_test_split(
np.arange(len(val_test_index)),
test_size=0.5, random_state=91,
shuffle=True,
stratify=[dataset.targets[i] for i in val_test_index])

# 2. Preprocess the data

Huggingface models require 2 things: pixel_values and labels.

Preprocessing images typically comes down to (1) resizing them to a particular size (2) normalizing the color channels (R,G,B) using a mean and standard deviation. These are referred to as image transformations.

In addition, one typically performs what is called data augmentation during training (like random cropping and flipping) to make the model more robust and achieve higher accuracy. Data augmentation is also a great technique to increase the size of the training data. The following augmentation techniques are considered:

- RandomResizedCrop: This technique randomly crops and resizes the input image to a specified size. It helps in introducing variations in the scale and aspect ratio of the input images.
- RandomHorizontalFlip: This technique randomly flips the input image horizontally. The given image was randomly flipped horizontally at a preset probability of 50%
- RandomGrayscale: Convert image to grayscale. Because there is a considerable proportion of images are grayscale, our models need to be invariant to color variations, converting images to grayscale can help achieve this. By removing color information, the model focuses solely on the intensity values, making it less sensitive to changes in color distribution or lighting conditions.

In [11]:
from transformers import ViTImageProcessor
from torchvision.transforms.v2 import Compose, Normalize, Resize, ToTensor, RandomHorizontalFlip, RandomRotation, RandomResizedCrop, CenterCrop

In [12]:
# The model is trained on ImageNet-1k
model_name = "MBZUAI/swiftformer-l3"
image_processor = ViTImageProcessor.from_pretrained(model_name)
image_processor

preprocessor_config.json:   0%|          | 0.00/175 [00:00<?, ?B/s]

ViTImageProcessor {
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.485,
    0.456,
    0.406
  ],
  "image_processor_type": "ViTImageProcessor",
  "image_std": [
    0.229,
    0.224,
    0.225
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 224,
    "width": 224
  }
}

In [13]:
image_mean, image_std = image_processor.image_mean, image_processor.image_std
if "height" in image_processor.size:
    size = (image_processor.size["height"], image_processor.size["width"])
    crop_size = size
    max_size = None
elif "shortest_edge" in image_processor.size:
    size = image_processor.size["shortest_edge"]
    crop_size = (size, size)
    max_size = image_processor.size.get("longest_edge")

In [14]:
normalize = Normalize(mean=image_mean, std=image_std)
train_transforms = Compose([
    RandomResizedCrop(crop_size),
    RandomHorizontalFlip(),
    RandomRotation(10),
    ToTensor(),
    normalize])

val_test_transforms = Compose([
            Resize(size),
            CenterCrop(crop_size),
            ToTensor(),
            normalize])

In [15]:
class TransformDataset(torch.utils.data.Dataset):
  def __init__(self, base_dataset, transformations):
    super(TransformDataset, self).__init__()
    self.base = base_dataset
    self.transformations = transformations

  def __len__(self):
    return len(self.base)

  def __getitem__(self, idx):
    x, y = self.base[idx]
    return self.transformations(x), y

In [16]:
# Subset dataset for train, test and val
train_subset = torch.utils.data.Subset(dataset, train_index)
test_subset = torch.utils.data.Subset(dataset, test_index)
val_subset = torch.utils.data.Subset(dataset, val_index)

In [17]:
# Apply data augmentation and transformation to the subsets:
train_dataset = TransformDataset(train_subset, train_transforms)
test_dataset = TransformDataset(test_subset, val_test_transforms)
val_dataset = TransformDataset(val_subset, val_test_transforms)

In [18]:
print('number of train images: {}'.format(len(train_dataset)))
print('number of test images: {}'.format(len(test_dataset)))
print('number of val images: {}'.format(len(val_dataset)))

number of train images: 34000
number of test images: 4250
number of val images: 4250


In [19]:
label2id = {}
id2label = {}

for i, class_name in enumerate(dataset.classes):
    label2id[class_name] = str(i)
    id2label[str(i)] = class_name

In [20]:
# Create a custom collator:
class ImageClassificationCollator:
    def __init__(self, feature_extractor):
        self.feature_extractor = feature_extractor

    def __call__(self, batch):
        pixel_values =  torch.stack([x[0] for x in batch])
        labels = torch.tensor([x[1] for x in batch], dtype=torch.long)
        return {"pixel_values": pixel_values, "labels": labels}
collator = ImageClassificationCollator(image_processor)

In [21]:
# # create corresponding PyTorch DataLoaders for Pytorch Lightning:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=192, num_workers=12, shuffle=True, collate_fn=collator, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=192, num_workers=12, shuffle=False, collate_fn=collator, pin_memory=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=192, num_workers=12, shuffle=False, collate_fn=collator, pin_memory=True)

# 3. Define and train the model

In [22]:
import pytorch_lightning as pl
from transformers import SwiftFormerForImageClassification , AdamW
import torch.nn as nn

In [23]:
class ViTLightningModule(pl.LightningModule):
    def __init__(self):
        super(ViTLightningModule, self).__init__()
        self.vit = SwiftFormerForImageClassification.from_pretrained(model_name,
                                                              num_labels=len(id2label),
                                                              id2label=id2label,
                                                              label2id=label2id,
                                                              ignore_mismatched_sizes=True)

    def forward(self, pixel_values):
        outputs = self.vit(pixel_values=pixel_values)
        return outputs.logits

    def common_step(self, batch, batch_idx):
        pixel_values = batch['pixel_values']
        labels = batch['labels']
        logits = self(pixel_values)

        criterion = nn.CrossEntropyLoss()
        loss = criterion(logits, labels)
        predictions = logits.argmax(-1)
        correct = (predictions == labels).sum().item()
        accuracy = correct/pixel_values.shape[0]
        return loss, accuracy

    def training_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)
        # logs metrics for each training_step,
        # and the average across the epoch
        self.log("training_acc", accuracy, prog_bar=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True)
        self.log("val_acc", accuracy, on_epoch=True, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)
        self.log("test_loss", loss, on_epoch=True, prog_bar=True)
        self.log("test_acc", accuracy, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        # We could make the optimizer more fancy by adding a scheduler and specifying which parameters do
        # not require weight_decay but just using AdamW out-of-the-box works fine
        return AdamW(self.parameters(), lr=2e-4)

    def train_dataloader(self):
        return train_loader

    def val_dataloader(self):
        return val_loader

    def test_dataloader(self):
        return test_loader

In [24]:
early_stop_callback = pl.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    strict=True,
    verbose=True,
    mode='min')

In [25]:
pl.seed_everything(91)
model = ViTLightningModule()
trainer = pl.Trainer(accelerator='gpu', precision='16-mixed', callbacks=[early_stop_callback])
trainer.fit(model)

INFO:lightning_fabric.utilities.seed:Seed set to 91


config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/114M [00:00<?, ?B/s]

Some weights of SwiftFormerForImageClassification were not initialized from the model checkpoint at MBZUAI/swiftformer-l3 and are newly initialized because the shapes did not match:
- head.weight: found shape torch.Size([1000, 512]) in the checkpoint and torch.Size([13, 512]) in the model instantiated
- head.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([13]) in the model instantiated
- dist_head.weight: found shape torch.Size([1000, 512]) in the checkpoint and torch.Size([13, 512]) in the model instantiated
- dist_head.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([13]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, usi

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved. New best score: 1.128


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.407 >= min_delta = 0.0. New best score: 0.721


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.031 >= min_delta = 0.0. New best score: 0.691


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.047 >= min_delta = 0.0. New best score: 0.644


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.120 >= min_delta = 0.0. New best score: 0.524


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.166 >= min_delta = 0.0. New best score: 0.358


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.116 >= min_delta = 0.0. New best score: 0.242


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Monitored metric val_loss did not improve in the last 3 records. Best score: 0.242. Signaling Trainer to stop.


In [26]:
trainer.test(ckpt_path='best')

INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at /content/lightning_logs/version_0/checkpoints/epoch=11-step=2136.ckpt
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.utilities.rank_zero:Loaded model weights from the checkpoint at /content/lightning_logs/version_0/checkpoints/epoch=11-step=2136.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_loss': 0.24243025481700897, 'test_acc': 0.9284706115722656}]

# 4. Save checkpoints to Google Drive

In [27]:
from google.colab import auth
from googleapiclient.http import MediaFileUpload
from googleapiclient.discovery import build

In [28]:
folders_or_files_to_save= '/content/lightning_logs'
filename = 'Swiftformer'

In [29]:
def save_file_to_drive(name, path):
    file_metadata = {
    'name': name,
    'mimeType': 'application/octet-stream'
    }
    media = MediaFileUpload(path,
                  mimetype='application/octet-stream',
                  resumable=True)
    created = drive_service.files().create(body=file_metadata, media_body=media, fields='id').execute()
    print('File ID: {}'.format(created.get('id')))
    return created

In [30]:
extension_zip = ".zip"
zip_file = filename + extension_zip

# !rm -rf $zip_file
!zip -r $zip_file {folders_or_files_to_save} # FOLDERS TO SAVE INTO ZIP FILE

auth.authenticate_user()
drive_service = build('drive', 'v3')

destination_name = zip_file
path_to_file = zip_file
save_file_to_drive(destination_name, path_to_file)

  adding: content/lightning_logs/ (stored 0%)
  adding: content/lightning_logs/version_0/ (stored 0%)
  adding: content/lightning_logs/version_0/hparams.yaml (stored 0%)
  adding: content/lightning_logs/version_0/events.out.tfevents.1700273391.db6a16d0e569.1581.0 (deflated 67%)
  adding: content/lightning_logs/version_0/checkpoints/ (stored 0%)
  adding: content/lightning_logs/version_0/checkpoints/epoch=11-step=2136.ckpt (deflated 8%)
  adding: content/lightning_logs/version_0/events.out.tfevents.1700278186.db6a16d0e569.1581.1 (deflated 19%)
File ID: 1Sub0ON3225AqwcCrB2EQ0lPbgjFI3zHB


{'id': '1Sub0ON3225AqwcCrB2EQ0lPbgjFI3zHB'}