### Preparations
---


Define the random seed (passed to `seed_everything` before training)

In [None]:
# Global random seed; handed to seed_everything() before the model is built.
seed = 42

Clone the repository that provides the `src` package used by this notebook

In [None]:
# Fetch the project code; its `src` package is imported a few cells below.
!git clone https://github.com/Godofnothing/CLIP_experimental

Install the dependencies

In [None]:
!pip install -q pytorch-lightning
!pip install -q ftfy regex
!wget https://openaipublic.azureedge.net/clip/bpe_simple_vocab_16e6.txt.gz -O bpe_simple_vocab_16e6.txt.gz

In [None]:
import sys
sys.path.insert(0, "CLIP_experimental")

import torch
import numpy as np

from src import CLIP_Lite, CLIP_Pro
from src import TextTransformer
from src import SimpleTokenizer
from src import CLIPDataset, train_val_split
from src import ClassificationVisualizer

from srgan import GANUpsample

In [None]:
# Pick the GPU when PyTorch can see one.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# The notebook is GPU-only (the Trainer below is created with gpus=1), so stop early on CPU.
assert device == 'cuda'

In [None]:
# Which wrapper to train: 'lite' (CLIP_Lite) or 'pro' (CLIP_Pro, which also uses caption templates)
clip_mode = 'lite'

### Download the model
___

Models available
- RN50
- RN101
- RN50x4
- ViT-B/32

In [None]:
# Fetch OpenAI's CLIP implementation; it is imported below as `CLIP.clip`.
!git clone https://github.com/openai/CLIP

In [None]:
from CLIP import clip

clip_backbone = "ViT-B/32"

# set jit=False to get a regular nn.Module instead of a TorchScript model
model, image_transform = clip.load(clip_backbone, jit=False)
# Read the expected image size from the loaded visual encoder instead of
# assuming 224: not every listed backbone uses it (RN50x4 expects 288x288).
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

# Summary of the loaded model.
print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)

### Get dataset
---

In [None]:
# Image upsampling variant: 'bicubic' (original dataset, resized on the fly by the transforms)
# or 'gan' (separately downloaded GAN-upsampled dataset)
upsample_mode = 'bicubic'

In [None]:
if upsample_mode == 'bicubic':
    # download the origingal dataset
    !pip install -q kaggle
    !mkdir -p ~/.kaggle
    !cp kaggle.json ~/.kaggle/
    !ls ~/.kaggle
    !chmod 600 /root/.kaggle/kaggle.json
    !kaggle datasets download moltean/fruits
    !unzip -q fruits.zip
    dataset_root = 'fruits-360'
elif upsample_mode == 'gan':
    # download the GAN upsampled dataset
    !gdown --id 10Omg4-7u4yfAlQTRfkJm1DIB-bsZEPuh
    !unzip -q fruits-360-gan
    dataset_root = 'fruits-360-gan'
else:
    raise NotImplementedError("Unknown upsampling mode")

### Construction of the dataset and loaders
---

In [None]:
import torchvision.transforms as T
from torch.utils.data import DataLoader

In [None]:
if clip_mode == 'pro':
    # Prompt templates; "{}" is presumably filled with the class name by
    # TextTransformer (it is later called with a class name).
    caption_templates = [
        "a photo of the small {}, type of fruit",
        "a close-up photo of a {}, type of fruit",
        "a cropped photo of the {}, type of fruit",
    ]

    # Tokenizer plus the helper that turns a class name into tokenized captions.
    tokenizer = SimpleTokenizer()
    text_transformer = TextTransformer(tokenizer, templates=caption_templates, context_length=context_length)

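Define the image transformations: in `bicubic` mode the images are resized and center-cropped to the model's input resolution, otherwise they are used at their stored size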

In [None]:
# Per-channel mean / std used to normalise the image tensors.
MEAN = (0.48145466, 0.4578275, 0.40821073)
STD = (0.26862954, 0.26130258, 0.27577711)


def _build_transform(augment):
    """Assemble the image pipeline for the current `upsample_mode`.

    In 'bicubic' mode the image is first resized and center-cropped to
    `input_resolution`; in every other mode it is taken as-is.
    With `augment=True` the random training augmentations are inserted
    before the tensor conversion and normalisation.
    """
    bicubic = T.InterpolationMode.BICUBIC
    steps = []

    if upsample_mode == 'bicubic':
        steps.append(T.Resize(input_resolution, interpolation=bicubic))
        steps.append(T.CenterCrop(input_resolution))

    if augment:
        steps.extend([
            T.RandomHorizontalFlip(p=0.5),
            T.RandomPerspective(),
            T.RandomRotation(degrees=20, interpolation=bicubic),
            T.GaussianBlur(3, sigma=(0.1, 2.0)),
        ])

    steps.append(T.ToTensor())
    steps.append(T.Normalize(MEAN, STD))
    return T.Compose(steps)


train_transform = _build_transform(augment=True)
val_transform = _build_transform(augment=False)

In [None]:
# The two dataset variants name their training folder differently.
train_suffix = 'Training' if upsample_mode == 'bicubic' else 'Train'

train_val_dataset = CLIPDataset(
    f'{dataset_root}/{train_suffix}', 
    image_transform=train_transform, 
    return_indices=True
)

# Split off 10% for validation. Use the notebook-level `seed` rather than a
# second hard-coded 42, so changing the seed cell changes the split as well.
train_dataset, val_dataset = train_val_split(train_val_dataset, val_size=0.1, random_state=seed)
# Each split gets its own pipeline: augmentation for training only.
train_dataset.image_transform = train_transform
val_dataset.image_transform = val_transform

test_dataset = CLIPDataset(
    f'{dataset_root}/Test', 
    image_transform=val_transform, 
    return_indices=True
)

# Reuse the training label mapping so test labels line up with training labels.
test_dataset.idx_to_class = train_dataset.idx_to_class
test_dataset.class_to_idx = train_dataset.class_to_idx

Initialize the tensor of tokenized captions (only needed in `pro` mode)

In [None]:
if clip_mode == 'pro':
    num_classes = len(train_dataset.class_to_idx)
    num_captions = len(caption_templates)

    # One slot of `context_length` integers per (class, caption template) pair, zero-initialised.
    tokenized_captions = torch.zeros((num_classes, num_captions, context_length), dtype=torch.int)

    for idx, class_name in train_dataset.idx_to_class.items():
        # NOTE(review): assumes text_transformer returns something assignable to a
        # (num_captions, context_length) slice — confirm against TextTransformer.
        class_captions = text_transformer(class_name)
        tokenized_captions[idx] = class_captions

    # Move the captions to the selected device.
    tokenized_captions = tokenized_captions.to(device)

In [None]:
import os

# Settings shared by all three loaders: batch size 32 and as many worker
# processes as os.cpu_count() reports.
_loader_kwargs = dict(batch_size=32, num_workers=os.cpu_count())

# Only the training loader shuffles its samples.
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

### The training procedure
---

Initialize the PyTorch Lightning modules

In [None]:
import pytorch_lightning as pl
# Import seed_everything from the package root: that location is stable across
# Lightning releases, whereas `pytorch_lightning.utilities.seed` was removed later.
from pytorch_lightning import seed_everything
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint

# Seed the random number generators; workers=True additionally asks Lightning
# to seed the DataLoader worker processes.
seed_everything(seed, workers=True)

In [None]:
# clip_out_features must match the backbone: 1024 for RN50, 512 for ViT/RN101
if clip_mode == 'lite':
    # Take the class count from the dataset instead of hard-coding 131, so the
    # classification head always matches the label mapping actually loaded.
    clip_wrapper = CLIP_Lite(
        model,
        num_classes=len(train_dataset.class_to_idx),
        clip_out_features=512
    )
elif clip_mode == 'pro':
    clip_wrapper = CLIP_Pro(model, tokenized_captions, clip_out_features=512)
else:
    # Fail here with a clear message rather than with a NameError on
    # `clip_wrapper` in a later cell.
    raise ValueError(f"Unknown clip_mode {clip_mode!r}; expected 'lite' or 'pro'")

In [None]:
# Logs and checkpoints share one directory, named after the mode and backbone.
log_dir = f'logs/CLIP_{clip_mode}_{clip_backbone}'
logger = TensorBoardLogger(log_dir)
# Track the checkpoint with the highest validation accuracy.
checkpoint = ModelCheckpoint(log_dir, monitor='val/accuracy', mode='max')

trainer = pl.Trainer(
    gpus=1,
    gradient_clip_val=1,
    amp_backend='native',
    deterministic=True,
    auto_lr_find=True,
    logger=logger,
    callbacks=[checkpoint]
)

# With auto_lr_find=True, tune() runs the learning-rate finder before training.
# The loader is passed positionally because the keyword was renamed from
# `train_dataloader` to `train_dataloaders` between Lightning 1.x releases;
# the positional form works with either name.
trainer.tune(clip_wrapper, train_loader)

In [None]:
trainer.fit(clip_wrapper, train_loader, val_loader)

# Restore the weights of the best checkpoint recorded by the ModelCheckpoint
# callback before evaluating on the test split.
best_checkpoint = torch.load(checkpoint.best_model_path)
clip_wrapper.load_state_dict(best_checkpoint['state_dict'])

trainer.test(clip_wrapper, test_loader)

TensorBoard logs
---

In [None]:
# Open TensorBoard inline, pointed at the `logs` directory the logger writes to.
%load_ext tensorboard
%tensorboard --logdir logs

Visualize predictions of the model
---

In [None]:
# Move the whole module to the selected device. Looping over parameters and
# assigning `param = param.cuda()` only rebinds the loop variable and leaves
# the module's own parameters where they were.
clip_wrapper = clip_wrapper.to(device)

visualizer = ClassificationVisualizer(
    clip_wrapper, 
    train_dataset, 
    images_in_row=4
)

In [None]:
# Show the model's predictions for 16 images.
visualizer.visualize_predictions(num_images=16)