In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models, torchvision.datasets
from torch.utils.data import Dataset
from torchvision.io import read_image
import os
import shutil
import random
import matplotlib.pyplot as plt
from torchvision import transforms

In [2]:
!pip install --upgrade pip

Defaulting to user installation because normal site-packages is not writeable


In [3]:
!pip install datasets

Defaulting to user installation because normal site-packages is not writeable


In [4]:
!pip install transformers

Defaulting to user installation because normal site-packages is not writeable


In [5]:
from datasets import load_dataset

In [6]:
# Load a single parquet shard of COYO-700M (about 5.8M image-text rows) rather
# than the full dataset.
url = 'https://huggingface.co/datasets/kakaobrain/coyo-700m/resolve/refs%2Fconvert%2Fparquet/default/train/0000.parquet'
data_files = dict(train=url)

pre_train_data = load_dataset("parquet", data_files=data_files, split="train")

In [7]:
# Display the loaded split's summary (column names and row count).
pre_train_data

Dataset({
    features: ['id', 'url', 'text', 'width', 'height', 'image_phash', 'text_length', 'word_count', 'num_tokens_bert', 'num_tokens_gpt', 'num_faces', 'clip_similarity_vitb32', 'clip_similarity_vitl14', 'nsfw_score_opennsfw2', 'nsfw_score_gantman', 'watermark_score', 'aesthetic_score_laion_v2'],
    num_rows: 5836073
})

In [8]:
# Image URL of the first example.
pre_train_data[0]['url']

'https://cdn.shopify.com/s/files/1/0286/3900/2698/products/TVN_Huile-olive-infuse-et-s-227x300_e9a90ffd-b6d2-4118-95a1-29a5c7a05a49_800x.jpg?v=1616684087'

In [9]:
# Caption of the first example (raw, untokenized string).
pre_train_data[0]['text']

'Olive oil infused with Tuscany herbs'

In [10]:
from transformers import AutoTokenizer

# Pretrained tokenizer used to encode the captions.
TOKENIZER_CHECKPOINT = "t5-base"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_CHECKPOINT)

In [11]:
# TODO: Encode the text descriptions and clean up the images.
# TODO: Create a final dataset that pairs each text with its image.
# We can create a custom dataloader that loads images from their URLs at runtime.


In [44]:
import requests
from PIL import Image
from io import BytesIO
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class PreTrainDataset(Dataset):
    """Image-caption pairs whose images are downloaded lazily from their URLs.

    Each row of ``dataset`` must expose ``'url'`` and ``'text'`` fields (the
    COYO rows loaded in this notebook do). Images are fetched over HTTP in
    ``__getitem__``. Rows whose image cannot be fetched or decoded come back
    as ``None`` so that a collate function can drop them.

    Args:
        dataset: indexable collection of rows with ``'url'`` and ``'text'`` keys.
        tokenizer: Hugging Face-style tokenizer used to encode the captions.
        transform: optional callable applied to the decoded RGB PIL image.
        max_length: number of caption tokens to pad/truncate to, before the
            prepended BOS id.
        timeout: seconds to wait for each image download.
    """

    def __init__(self, dataset, tokenizer, transform=None, max_length=75, timeout=5):
        self.dataset = dataset
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.timeout = timeout

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, encoded_text)`` for row ``idx``, or ``None`` if the image is unavailable."""
        row = self.dataset[idx]  # fetch the row once instead of once per field
        try:
            response = requests.get(row['url'], timeout=self.timeout)
            # Fail fast on 4xx/5xx instead of trying to decode an error page as an image.
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
            if self.transform:
                image = self.transform(image)
        except Exception:
            # Deliberate best effort: dead links, timeouts and undecodable files
            # are common in web-scraped data. The collate function filters these
            # ``None`` items out of the batch.
            return None
        # The caption is still a raw string at this point; tokenize it now.
        return image, self.encode_text(row['text'])

    def encode_text(self, example):
        """Tokenize one caption and prepend a BOS id.

        Returns:
            dict with ``input_ids`` and ``attention_mask`` lists. These are
            ``max_length + 1`` long, assuming the tokenizer pads/truncates to
            exactly ``max_length``.
        """
        encoded = self.tokenizer(
            example,
            padding='max_length',
            # Without truncation, captions longer than max_length produce longer
            # sequences that cannot be stacked into a batch.
            truncation=True,
            max_length=self.max_length,
            add_special_tokens=True,
        )
        # NOTE(review): the T5 vocabulary may not define "<s>". In that case this
        # resolves to the unknown-token id, so confirm this is the intended BOS id.
        bos_id = self.tokenizer.convert_tokens_to_ids("<s>")
        return {
            "input_ids": [bos_id] + list(encoded["input_ids"]),
            "attention_mask": [1] + list(encoded["attention_mask"]),
        }

In [45]:
def remove_none_fn(batch):
    """Collate function that drops ``None`` samples and transposes the rest.

    ``[(img_a, txt_a), None, (img_b, txt_b)]`` becomes
    ``((img_a, img_b), (txt_a, txt_b))``. If every sample is ``None``, the
    result is an empty tuple.
    """
    valid_samples = (sample for sample in batch if sample is not None)
    return tuple(zip(*valid_samples))

In [46]:
# Resize every image to 272x272, then take a random 224x224 crop and convert it to a tensor.
custom_transforms = transforms.Compose(
    [
        transforms.Resize((272, 272)),
        transforms.RandomCrop(224),
        transforms.ToTensor(),
    ]
)

pre_train_dataset_cleaned = PreTrainDataset(
    pre_train_data,
    tokenizer=tokenizer,
    transform=custom_transforms,
)

# remove_none_fn drops samples whose image could not be downloaded.
train_loader = DataLoader(
    dataset=pre_train_dataset_cleaned,
    batch_size=16,
    shuffle=True,
    collate_fn=remove_none_fn,
)

In [47]:
# Test run
for i, j in train_loader:
    print(i)
    print(j)
    print(tokenizer.decode(j[0]['input_ids']))

    break

(tensor([[[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         ...,
         [0.8549, 0.8549, 0.8549,  ..., 0.8549, 0.8549, 0.8549],
         [0.8549, 0.8549, 0.8549,  ..., 0.8549, 0.8549, 0.8549],
         [0.8549, 0.8549, 0.8549,  ..., 0.8549, 0.8549, 0.8549]],

        [[0.9686, 0.9686, 0.9686,  ..., 0.9686, 0.9686, 0.9686],
         [0.9686, 0.9686, 0.9686,  ..., 0.9686, 0.9686, 0.9686],
         [0.9686, 0.9686, 0.9686,  ..., 0.9686, 0.9686, 0.9686],
         ...,
         [0.7490, 0.7490, 0.7490,  ..., 0.7490, 0.7490, 0.7490],
         [0.7490, 0.7490, 0.7490,  ..., 0.7490, 0.7490, 0.7490],
         [0.7490, 0.7490, 0.7490,  ..., 0.7490, 0.7490, 0.7490]],

        [[0.8941, 0.8941, 0.8941,  ..., 0.8941, 0.8941, 0.8941],
         [0.8941, 0.8941, 0.8941,  ..., 0.8941, 0.8941, 0.8941],
         [0.8941, 0.8941, 0.8941,  ..., 0.8941, 0.8941, 0

In [29]:
def train(model, data, lr, weight_decay, num_epochs, checkpoint_path, val_data=None):
    """Train ``model`` with AdamW and record per-epoch mean losses.

    Args:
        model: module called as ``model(imgs, text, text_labels)`` that returns
            ``(total_loss, contrastive_loss, generative_loss)``.
        data: iterable of training batches in the layout produced by
            ``remove_none_fn``: ``(images, texts)``, where ``images`` is a
            tuple of image tensors and ``texts`` is a tuple of dicts holding an
            ``'input_ids'`` list. Empty batches (every sample failed to load)
            are skipped.
        lr: base learning rate.
        weight_decay: AdamW weight decay.
        num_epochs: number of passes over ``data``.
        checkpoint_path: format string with one ``{}`` slot for the epoch
            number. The state dict is saved there every 5th epoch.
        val_data: optional validation batches in the same layout as ``data``.
            When given, ``validation`` runs after every epoch.

    Returns:
        Six lists of per-epoch mean losses, in this order: train
        total/contrastive/generative, then val total/contrastive/generative.
        The val lists stay empty when ``val_data`` is None.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # Build the optimizer once. Re-creating it (per epoch or per short batch)
    # would throw away AdamW's running moment estimates.
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # A DataLoader exposes its batch size. Plain iterables don't, and then no
    # per-batch learning-rate scaling is applied.
    full_batch_size = getattr(data, "batch_size", None)

    train_losses, train_contrastive_losses, train_generative_losses = [], [], []
    val_losses, val_contrastive_losses, val_generative_losses = [], [], []

    for epoch in range(1, num_epochs + 1):
        # validation() may leave the model in eval mode, so re-enable training every epoch.
        model.train()
        t_loss = t_contrastive_loss = t_generative_loss = 0.0
        num_batches = 0

        for batch in data:
            if not batch:
                continue  # every sample in this batch failed to load

            # Keep images as floats; casting pixel values to long would truncate them.
            imgs = torch.stack(batch[0]).to(device)
            text = torch.tensor(
                [sample["input_ids"] for sample in batch[1]],
                dtype=torch.long,
                device=device,
            )

            # Batches can be smaller than batch_size: the last batch, or any batch
            # where some images failed to download. For those steps, scale the
            # learning rate by size(batch) / batch_size.
            scale = 1.0
            if full_batch_size and imgs.size(0) < full_batch_size:
                scale = imgs.size(0) / full_batch_size
            for group in optimizer.param_groups:
                group["lr"] = lr * scale

            # NOTE(review): the full text (with <s>) is the input and text[:, 1:] the
            # labels; confirm the model shifts/aligns these internally.
            text_labels = text[:, 1:]  # labels are the same text just with the <s> token removed
            optimizer.zero_grad()
            total_loss, contrastive_loss, generative_loss = model(imgs, text, text_labels)
            total_loss.backward()
            optimizer.step()

            # float() extracts plain numbers, so the autograd graph isn't kept alive all epoch.
            t_loss += float(total_loss)
            t_contrastive_loss += float(contrastive_loss)
            t_generative_loss += float(generative_loss)
            num_batches += 1

        denom = max(num_batches, 1)  # avoid ZeroDivisionError on an epoch with no usable batches
        epoch_loss = t_loss / denom
        epoch_contrastive = t_contrastive_loss / denom
        epoch_generative = t_generative_loss / denom
        train_losses.append(epoch_loss)
        train_contrastive_losses.append(epoch_contrastive)
        train_generative_losses.append(epoch_generative)

        if val_data is not None:
            with torch.no_grad():  # evaluation only; no gradients needed
                val_loss, val_contrastive_loss, val_generative_loss = validation(model, val_data)
            # validation() returns sums over its batches; convert them to per-batch means.
            val_denom = max(len(val_data), 1)
            val_losses.append(float(val_loss) / val_denom)
            val_contrastive_losses.append(float(val_contrastive_loss) / val_denom)
            val_generative_losses.append(float(val_generative_loss) / val_denom)

        if epoch % 5 == 0:  # save model every 5th epoch
            torch.save(model.state_dict(), checkpoint_path.format(epoch))

        print("Epoch {}:  Train loss: {}   Train Contrastive Loss: {}   Train Generative Loss: {}".format(epoch, epoch_loss, epoch_contrastive, epoch_generative))
        if val_data is not None:
            print("Epoch {}:  Val loss: {}   Val Contrastive Loss: {}   Val Generative Loss: {}".format(epoch, val_losses[-1], val_contrastive_losses[-1], val_generative_losses[-1]))

    return train_losses, train_contrastive_losses, train_generative_losses, val_losses, val_contrastive_losses, val_generative_losses
    

In [None]:
def validation(model, data):
    """Evaluate ``model`` on ``data`` without gradients and return summed losses.

    Args:
        model: module called as ``model(imgs, text, text_labels)`` that returns
            ``(total_loss, contrastive_loss, generative_loss)``.
        data: iterable of batches in the layout produced by ``remove_none_fn``:
            ``(images, texts)``, where ``images`` is a tuple of image tensors and
            ``texts`` is a tuple of dicts holding an ``'input_ids'`` list. Empty
            batches are skipped.

    Returns:
        ``(val_loss, val_contrastive_loss, val_generative_loss)`` as Python
        floats. Each value is summed over the evaluated batches; divide by the
        batch count for a mean. The model's original train/eval mode is restored
        before returning.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model.to(device)

    was_training = model.training
    model.eval()

    val_loss = val_contrastive_loss = val_generative_loss = 0.0
    try:
        with torch.no_grad():  # no backward pass here, so don't build autograd graphs
            for batch in data:
                if not batch:
                    continue  # every sample in this batch failed to load

                # Keep images as floats; casting pixel values to long would truncate them.
                imgs = torch.stack(batch[0]).to(device)
                text = torch.tensor(
                    [sample["input_ids"] for sample in batch[1]],
                    dtype=torch.long,
                    device=device,
                )

                text_labels = text[:, 1:]  # labels are the same text just with the <s> token removed
                total_loss, contrastive_loss, generative_loss = model(imgs, text, text_labels)

                val_loss += float(total_loss)
                val_contrastive_loss += float(contrastive_loss)
                val_generative_loss += float(generative_loss)
    finally:
        model.train(was_training)  # restore the caller's train/eval mode

    return val_loss, val_contrastive_loss, val_generative_loss
    