### Preparations
---

Clone the repository

In [None]:
!git clone https://github.com/Godofnothing/CLIP_experimental

Installation of the dependencies

In [None]:
!pip install -q pytorch-lightning
!pip install -q ftfy regex
!wget https://openaipublic.azureedge.net/clip/bpe_simple_vocab_16e6.txt.gz -O bpe_simple_vocab_16e6.txt.gz

In [None]:
import sys
sys.path.insert(0, "CLIP_experimental")

import torch
import numpy as np

from src import CLIP_Lite, CLIP_Pro
from src import TextTransformer
from src import SimpleTokenizer
from src import CLIPDataset

In [None]:
# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Which wrapper to train: 'lite' or 'pro'.
clip_mode = 'lite'
# How to train it: 'classic' or 'cosine_similarity' for lite | 'visual' or 'visual&text' for pro.
train_mode = 'classic'

# Fail fast on a typo or a mismatched pair: the cells further down branch on these
# strings, and an unexpected value would otherwise fall silently into the wrong branch
# (or leave a variable undefined) long after this cell has run.
VALID_TRAIN_MODES = {
    'lite': ('classic', 'cosine_similarity'),
    'pro': ('visual', 'visual&text'),
}
if clip_mode not in VALID_TRAIN_MODES:
    raise ValueError(
        f"Unknown clip_mode {clip_mode!r}; expected one of {sorted(VALID_TRAIN_MODES)}"
    )
if train_mode not in VALID_TRAIN_MODES[clip_mode]:
    raise ValueError(
        f"Unknown train_mode {train_mode!r} for clip_mode {clip_mode!r}; "
        f"expected one of {VALID_TRAIN_MODES[clip_mode]}"
    )

### Download the model
___

Models available (pass one of these names to `clip.load`):
- RN50
- RN101
- RN50x4
- ViT-B/32

In [None]:
!git clone https://github.com/openai/CLIP

In [None]:
from CLIP import clip

# jit=False makes clip.load return a regular nn.Module rather than a TorchScript module.
model, image_transform = clip.load("ViT-B/32", jit=False)
# Read the input resolution from the loaded model instead of hard-coding 224,
# so the value stays correct when a different model name is loaded above.
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

# Total number of scalar parameters in the model.
num_parameters = sum(p.numel() for p in model.parameters())

print("Model parameters:", f"{num_parameters:,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)

### Get the dataset from Kaggle
---

In [None]:
!pip install -q kaggle
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!ls ~/.kaggle
!chmod 600 /root/.kaggle/kaggle.json

In [None]:
!kaggle datasets download moltean/fruits
!unzip -q fruits.zip

In [None]:
# Directory holding the `Training` and `Test` sub-folders that the dataset cell reads.
# NOTE(review): assumes fruits.zip unpacks into ./fruits-360 - confirm against the archive layout.
dataset_root = 'fruits-360'

### Construction of the dataset and loaders
---

In [None]:
import torchvision.transforms as T
from torch.utils.data import DataLoader
from PIL import Image

In [None]:
# The dataset cell passes `prompt_transform=text_transformer` for every configuration
# except lite + classic. Gating this cell on clip_mode == 'pro' alone left
# `text_transformer` undefined for lite + cosine_similarity, which raised a NameError
# when the datasets were built, so the condition mirrors the dataset cell instead.
needs_text_pipeline = not (clip_mode == 'lite' and train_mode == 'classic')
if needs_text_pipeline:
    tokenizer = SimpleTokenizer()
    text_transformer = TextTransformer(tokenizer, context_length)

Redefine transformations if needed

In [None]:
# Per-channel mean and standard deviation handed to T.Normalize.
MEAN = (0.48145466, 0.4578275, 0.40821073)
STD = (0.26862954, 0.26130258, 0.27577711)

# Preprocessing steps, applied in order: bicubic resize, center crop,
# conversion to a tensor, then normalisation with the statistics above.
preprocessing_steps = [
    T.Resize(input_resolution, interpolation=T.InterpolationMode.BICUBIC),
    T.CenterCrop(input_resolution),
    T.ToTensor(),
    T.Normalize(mean=MEAN, std=STD),
]
image_transform = T.Compose(preprocessing_steps)

In [None]:
# Training and validation datasets differ only in their root folder, so the
# keyword arguments are chosen once and shared by both constructor calls.
if clip_mode == 'lite' and train_mode == 'classic':
    # lite + classic: samples are returned with indices, no prompt transform.
    dataset_kwargs = dict(
        image_transform=image_transform,
        return_indices=True,
    )
else:
    # Every other configuration: prompts go through the text transformer.
    dataset_kwargs = dict(
        image_transform=image_transform,
        prompt_transform=text_transformer,
        return_indices=False,
    )

train_dataset = CLIPDataset(f'{dataset_root}/Training', **dataset_kwargs)
val_dataset = CLIPDataset(f'{dataset_root}/Test', **dataset_kwargs)

In [None]:
# Both loaders share the batch size; only the training loader shuffles its samples.
loader_kwargs = dict(batch_size=8)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

### The training procedure
---

Initialize the PyTorch Lightning modules

In [None]:
import pytorch_lightning as pl

In [None]:
# Build the Lightning wrapper that matches the selected clip_mode.
if clip_mode == 'lite':
    # The feature width is read from the loaded model instead of being hard-coded:
    # the previous literal 1024 did not match the ViT-B/32 model loaded above
    # (the original note here said "1024 for RN, 512 for ViT").
    # NOTE(review): num_classes=131 presumably matches the Fruits-360 class count -
    # confirm against the downloaded dataset.
    clip_wrapper = CLIP_Lite(
        model,
        num_classes=131,
        clip_out_features=model.visual.output_dim,
        training_mode=train_mode,
    )
elif clip_mode == 'pro':
    clip_wrapper = CLIP_Pro(model, training_mode=train_mode)
else:
    # Without this branch an unknown mode would leave `clip_wrapper` undefined
    # and only surface as a NameError when training starts.
    raise ValueError(f"Unknown clip_mode {clip_mode!r}; expected 'lite' or 'pro'")

In [None]:
# Trainer configuration (PyTorch Lightning 1.x argument names).
# NOTE(review): in Lightning 1.x `auto_lr_find=True` only takes effect when
# `trainer.tune(...)` is called, which this notebook does not do - confirm the intent.
trainer = pl.Trainer(
    # Request a GPU only when one is available so the notebook also runs on a CPU-only runtime.
    gpus=1 if torch.cuda.is_available() else 0,
    gradient_clip_val=1e-3,
    amp_backend='native',
    auto_lr_find=True
)

In [None]:
# Train the wrapped model on train_loader, using val_loader for validation.
trainer.fit(clip_wrapper, train_loader, val_loader)