In [None]:
import sys

import numpy as np
import torch
from torch.utils.data import DataLoader

# Make the parent directory importable so the local `src` package resolves.
sys.path.insert(0, "../")

from src import CLIP_Lite, CLIP_Pro
from src import TextTransformer
from src import SimpleTokenizer
from src import CLIPDataset

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Which wrapper to train: 'lite' or 'pro'.
clip_mode = 'lite'

# Training regime handed to the wrapper as `training_mode`.
train_mode = 'classic'

# Directory that contains the `Training/` and `Test/` image folders read by
# the dataset cells. Without this definition those cells raise NameError.
# NOTE(review): 'fruits-360' assumes that is the folder created by unzipping
# fruits.zip -- verify against the extracted archive and adjust if it differs.
dataset_root = 'fruits-360'

### Download the model
___

List all available models

In [None]:
# All checkpoints live under the same URL prefix; only the per-model path
# segment and the file name differ.
_CLIP_MODELS_URL = "https://openaipublic.azureedge.net/clip/models"

# model name -> (path segment, checkpoint file name)
_CHECKPOINTS = {
    "RN50": ("afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762", "RN50.pt"),
    "RN101": ("8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599", "RN101.pt"),
    "RN50x4": ("7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd", "RN50x4.pt"),
    "ViT-B/32": ("40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af", "ViT-B-32.pt"),
}

# Mapping from model name to the full download URL of its checkpoint.
MODELS = {
    model_name: f"{_CLIP_MODELS_URL}/{segment}/{filename}"
    for model_name, (segment, filename) in _CHECKPOINTS.items()
}

In [None]:
!wget {MODELS["RN50"]} -O model.pt

In [None]:
# Load the downloaded TorchScript checkpoint. `map_location` makes the archive
# deserialize directly onto `device`, so loading does not depend on the device
# the checkpoint was saved from (relevant when falling back to CPU).
model = torch.jit.load("model.pt", map_location=device).to(device)

# Read the model's configuration; `.item()` unwraps each value into a Python scalar.
input_resolution = model.input_resolution.item()
context_length = model.context_length.item()
vocab_size = model.vocab_size.item()

# Total number of elements across all parameter tensors.
num_parameters = sum(param.numel() for param in model.parameters())

print("Model parameters:", f"{num_parameters:,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)

### Get the dataset from Kaggle
---

In [None]:
!pip install -q kaggle
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!ls ~/.kaggle
!chmod 600 /root/.kaggle/kaggle.json

In [None]:
!kaggle datasets download moltean/fruits
!unzip -q fruits.zip

Install the packages required by the tokenizer (`ftfy`, `regex`) and download the BPE vocabulary file

In [None]:
!pip install -q ftfy regex
!wget https://openaipublic.azureedge.net/clip/bpe_simple_vocab_16e6.txt.gz -O bpe_simple_vocab_16e6.txt.gz

### Construction of the dataset and loaders
---

In [None]:
import torchvision.transforms as T
from PIL import Image

In [None]:
# The dataset cell passes `text_transformer` as `prompt_transform` for every
# configuration except lite + classic (which requests indices instead), so it
# must exist in exactly those cases. Creating it only for 'pro' left it
# undefined for 'lite' with a non-classic training mode.
if not (clip_mode == 'lite' and train_mode == 'classic'):
    tokenizer = SimpleTokenizer()
    text_transformer = TextTransformer(tokenizer, context_length)

In [None]:
# Per-channel mean and standard deviation handed to T.Normalize below.
MEAN = (0.48145466, 0.4578275, 0.40821073)
STD = (0.26862954, 0.26130258, 0.27577711)

# Preprocessing applied to every image, in order: bicubic resize, centre crop
# to the model's input resolution, conversion to a tensor, normalization.
_preprocess_steps = [
    T.Resize(input_resolution, interpolation=T.InterpolationMode.BICUBIC),
    T.CenterCrop(input_resolution),
    T.ToTensor(),
    T.Normalize(MEAN, STD),
]
image_transform = T.Compose(_preprocess_steps)

In [None]:
# Both splits are built with the same options; only the folder differs.
# lite + classic asks the dataset for indices, every other configuration
# supplies a prompt transform and turns indices off.
if clip_mode == 'lite' and train_mode == 'classic':
    _dataset_options = dict(
        image_transform=image_transform,
        return_indices=True,
    )
else:
    _dataset_options = dict(
        image_transform=image_transform,
        prompt_transform=text_transformer,
        return_indices=False,
    )

train_dataset = CLIPDataset(f'{dataset_root}/Training', **_dataset_options)
val_dataset = CLIPDataset(f'{dataset_root}/Test', **_dataset_options)

In [None]:
# Number of samples per batch, shared by both loaders.
BATCH_SIZE = 8

# Training batches are reshuffled; validation keeps the dataset order.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)

### The training procedure
---

Initialize the PyTorch Lightning modules

In [None]:
!pip install -q pytorch-lightning

In [None]:
# Wrap the pretrained model in the Lightning module selected by `clip_mode`.
# The result is stored in `clip_lite` for both modes because the training cell
# refers to that name.
# NOTE(review): 131 is presumably the number of classes in the dataset -- confirm.
if clip_mode == 'lite':
    clip_lite = CLIP_Lite(model, num_classes=131, training_mode=train_mode)
elif clip_mode == 'pro':
    clip_lite = CLIP_Pro(model, num_classes=131, training_mode=train_mode)
else:
    # Fail here with a clear message instead of a NameError in a later cell.
    raise ValueError(f"Unknown clip_mode {clip_mode!r}; expected 'lite' or 'pro'")

In [None]:
# Imported here, after the install cell above, because pytorch-lightning may
# not be available until that cell has run.
import pytorch_lightning as pl

trainer = pl.Trainer(
    # Request a GPU only when one was detected; a hard-coded `gpus=1` fails on
    # CPU-only machines even though `device` falls back to 'cpu'.
    gpus=1 if device == 'cuda' else 0,
    gradient_clip_val=1e-3,
    amp_backend='native',
    # NOTE(review): per the Lightning 1.x docs this flag only takes effect via
    # `trainer.tune(...)`, which the cells shown here never call -- confirm intent.
    auto_lr_find=True
)

In [None]:
# Run training: `clip_lite` is the module to fit, `train_loader` supplies the
# training batches and `val_loader` the validation batches.
trainer.fit(clip_lite, train_loader, val_loader)