In [1]:
!pip install pytorch-lightning



In [2]:
!pip install -U transformers



In [3]:
!pip install -U accelerate



In [4]:
!pip install datasets



In [5]:
!pip install evaluate



In [6]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from datasets import load_dataset, DatasetDict, load_metric

# 1. Loading the data

In [7]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [8]:
dataset = load_dataset('imagefolder', data_dir='/content/gdrive/MyDrive/WikiArt', split ='train')
dataset

Resolving data files:   0%|          | 0/42500 [00:00<?, ?it/s]

Downloading data files:   0%|          | 0/42500 [00:00<?, ?it/s]

Downloading data files: 0it [00:00, ?it/s]

Extracting data files: 0it [00:00, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset({
    features: ['image', 'label'],
    num_rows: 42500
})

In [9]:
# Split the dataset into train, test and valid sets with the train_test_split method
# First, split the dataset into 80% train and 20% test + validation
train_testvalid = dataset.train_test_split(test_size=0.2, seed = 91, stratify_by_column = 'label')
# Then split the 10% test + valid in half test, half valid
test_valid = train_testvalid['test'].train_test_split(test_size=0.5, seed = 91, stratify_by_column = 'label')
# gather every sets into a single DatasetDict
train_test_valid_dataset = DatasetDict({
    'train': train_testvalid['train'],
    'test': test_valid['test'],
    'valid': test_valid['train']})

In [10]:
train_test_valid_dataset

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 34000
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 4250
    })
    valid: Dataset({
        features: ['image', 'label'],
        num_rows: 4250
    })
})

In [11]:
train_ds = train_test_valid_dataset['train']
val_ds = train_test_valid_dataset['valid']
test_ds = train_test_valid_dataset['test']

In [12]:
# each example has 2 features: 'img' (of type Image) and 'label' (of type ClassLabel):
train_test_valid_dataset['train'].features

{'image': Image(decode=True, id=None),
 'label': ClassLabel(names=['Academic_Art', 'Art_Nouveau', 'Baroque', 'Expressionism', 'Japanese_Art', 'Neoclassicism', 'Primitivism', 'Realism', 'Renaissance', 'Rococo', 'Romanticism', 'Symbolism', 'Western_Medieval'], id=None)}

In [13]:
# creating dictionaries which map between integer indices and actual class names:
id2label = {id:label for id, label in enumerate(train_test_valid_dataset['train'].features['label'].names)}
label2id = {label:id for id,label in id2label.items()}
id2label

{0: 'Academic_Art',
 1: 'Art_Nouveau',
 2: 'Baroque',
 3: 'Expressionism',
 4: 'Japanese_Art',
 5: 'Neoclassicism',
 6: 'Primitivism',
 7: 'Realism',
 8: 'Renaissance',
 9: 'Rococo',
 10: 'Romanticism',
 11: 'Symbolism',
 12: 'Western_Medieval'}

In [14]:
len(id2label)

13

# 2. Preprocess the data

Huggingface models require 2 things: pixel_values and labels.

Preprocessing images typically comes down to (1) resizing them to a particular size (2) normalizing the color channels (R,G,B) using a mean and standard deviation. These are referred to as image transformations.

In addition, one typically performs what is called data augmentation during training (like random cropping and flipping) to make the model more robust and achieve higher accuracy. Data augmentation is also a great technique to increase the size of the training data. The following augmentation techniques are considered:

- RandomResizedCrop: This technique randomly crops and resizes the input image to a specified size. It helps in introducing variations in the scale and aspect ratio of the input images. A random crop operation with range of size between 0.08 and 1.0 of the original size and aspect ratio (from 3/4 to 4/3) of the original aspect ratio were conducted. The cropped image was finally resized to 224x224 using random interpolation
- RandomHorizontalFlip: This technique randomly flips the input image horizontally. The given image was randomly flipped horizontally at a preset probability of 50%
- RandomGrayscale: Convert image to grayscale. Because there is a considerable proportion of images are grayscale, our models need to be invariant to color variations, converting images to grayscale can help achieve this. By removing color information, the model focuses solely on the intensity values, making it less sensitive to changes in color distribution or lighting conditions.

In [15]:
from transformers import ViTImageProcessor
from torchvision.transforms.v2 import Compose, Normalize, Resize, ToTensor, RandomHorizontalFlip, RandomResizedCrop, RandomGrayscale, CenterCrop

In [16]:
# The model is pre-trained on ImageNet-21k, a dataset of 14 million labeled images
model_name = "google/vit-base-patch16-224-in21k"
image_processor = ViTImageProcessor.from_pretrained(model_name)
image_processor

Downloading (…)rocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

ViTImageProcessor {
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "ViTImageProcessor",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 224,
    "width": 224
  }
}

In [17]:
image_mean, image_std = image_processor.image_mean, image_processor.image_std
if "height" in image_processor.size:
    size = (image_processor.size["height"], image_processor.size["width"])
    crop_size = size
    max_size = None
elif "shortest_edge" in image_processor.size:
    size = image_processor.size["shortest_edge"]
    crop_size = (size, size)
    max_size = image_processor.size.get("longest_edge")

In [18]:
normalize = Normalize(mean=image_mean, std=image_std)
_train_transforms = Compose([
    RandomResizedCrop(crop_size),
    RandomHorizontalFlip(),
    RandomGrayscale(),
    ToTensor(),
    normalize])

_val_test_transforms = Compose([
            Resize(size),
            CenterCrop(crop_size),
            ToTensor(),
            normalize,
        ])

def train_transforms(examples):
    examples['pixel_values'] = [_train_transforms(image.convert("RGB")) for image in examples['image']]
    return examples

def val_test_transforms(examples):
    examples['pixel_values'] = [_val_test_transforms(image.convert("RGB")) for image in examples['image']]
    return examples



In [19]:
# Apply data transformation and augmentation:
train_ds.set_transform(train_transforms)
val_ds.set_transform(val_test_transforms)
test_ds.set_transform(val_test_transforms)

In [20]:
# create corresponding PyTorch DataLoaders for Pytorch Lightning:
def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

In [21]:
train_dataloader = torch.utils.data.DataLoader(train_ds, shuffle=True, collate_fn=collate_fn, batch_size=8, num_workers=8, pin_memory=True)
val_dataloader = torch.utils.data.DataLoader(val_ds, collate_fn=collate_fn, batch_size=8, num_workers=8, pin_memory=True)
test_dataloader = torch.utils.data.DataLoader(test_ds, collate_fn=collate_fn, batch_size=8, num_workers=8, pin_memory=True)

In [22]:
batch = next(iter(train_dataloader))
for k,v in batch.items():
  if isinstance(v, torch.Tensor):
    print(k, v.shape)

pixel_values torch.Size([8, 3, 224, 224])
labels torch.Size([8])


In [23]:
next(iter(val_dataloader))['pixel_values'].shape

torch.Size([8, 3, 224, 224])

# 3. Define the model

In [24]:
import pytorch_lightning as pl
from transformers import ViTForImageClassification, AdamW
import evaluate, accelerate
import torch.nn as nn

In [25]:
class ViTLightningModule(pl.LightningModule):
    def __init__(self):
        super(ViTLightningModule, self).__init__()
        self.vit = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k',
                                                              num_labels=len(id2label),
                                                              id2label=id2label,
                                                              label2id=label2id)

    def forward(self, pixel_values):
        outputs = self.vit(pixel_values=pixel_values)
        return outputs.logits

    def common_step(self, batch, batch_idx):
        pixel_values = batch['pixel_values']
        labels = batch['labels']
        logits = self(pixel_values)

        criterion = nn.CrossEntropyLoss()
        loss = criterion(logits, labels)
        predictions = logits.argmax(-1)
        correct = (predictions == labels).sum().item()
        accuracy = correct/pixel_values.shape[0]

        return loss, accuracy
    def training_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)
        # logs metrics for each training_step,
        # and the average across the epoch
        self.log("training_loss", loss)
        self.log("training_accuracy", accuracy)

        return loss

    def validation_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True)
        self.log("val_acc", accuracy, on_epoch=True, prog_bar=True)

        return loss

    def test_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)

        return loss

    def configure_optimizers(self):
        # We could make the optimizer more fancy by adding a scheduler and specifying which parameters do
        # not require weight_decay but just using AdamW out-of-the-box works fine
        return AdamW(self.parameters(), lr=2e-5)

    def train_dataloader(self):
        return train_dataloader

    def val_dataloader(self):
        return val_dataloader

    def test_dataloader(self):
        return test_dataloader

# 4. Train the model

In [26]:
early_stop_callback = pl.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=2,
    strict=False,
    verbose=False,
    mode='min')

In [29]:
pl.seed_everything(91)
model = ViTLightningModule()
trainer = pl.Trainer(accelerator='gpu', precision='16-mixed', callbacks=[early_stop_callback], max_epochs=5)
trainer.fit(model)

INFO:lightning_fabric.utilities.seed:Seed set to 91
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name | Type                      | Params
---------------------------------------------------
0 |

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.


In [30]:
trainer.test()

INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at /content/lightning_logs/version_1/checkpoints/epoch=4-step=21250.ckpt
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.utilities.rank_zero:Loaded model weights from the checkpoint at /content/lightning_logs/version_1/checkpoints/epoch=4-step=21250.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]

[{}]

In [35]:
from google.colab import auth
from googleapiclient.http import MediaFileUpload
from googleapiclient.discovery import build

def save_file_to_drive(name, path):
    file_metadata = {
    'name': name,
    'mimeType': 'application/octet-stream'
    }

    media = MediaFileUpload(path,
                  mimetype='application/octet-stream',
                  resumable=True)

    created = drive_service.files().create(body=file_metadata, media_body=media, fields='id').execute()

    print('File ID: {}'.format(created.get('id')))

    return created


extension_zip = ".zip"
filename = 'version1'
folders_or_files_to_save= '/content/lightning_logs/version_1'

zip_file = filename + extension_zip

# !rm -rf $zip_file
!zip -r $zip_file {folders_or_files_to_save} # FOLDERS TO SAVE INTO ZIP FILE

auth.authenticate_user()
drive_service = build('drive', 'v3')

destination_name = zip_file
path_to_file = zip_file
save_file_to_drive(destination_name, path_to_file)

  adding: content/lightning_logs/version_1/ (stored 0%)
  adding: content/lightning_logs/version_1/hparams.yaml (stored 0%)
  adding: content/lightning_logs/version_1/checkpoints/ (stored 0%)
  adding: content/lightning_logs/version_1/checkpoints/epoch=4-step=21250.ckpt (deflated 8%)
  adding: content/lightning_logs/version_1/events.out.tfevents.1699335051.b7cd282f765c.1869.1 (deflated 72%)
File ID: 1qYG17jTQ_lchb9J0xfkeNHH2QLnYslK2


{'id': '1qYG17jTQ_lchb9J0xfkeNHH2QLnYslK2'}