In [8]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models, torchvision.datasets
from torch.utils.data import Dataset
from torchvision.io import read_image
import os
import shutil
import random
import matplotlib.pyplot as plt
from torchvision import transforms

In [2]:
!pip install --upgrade pip

Defaulting to user installation because normal site-packages is not writeable
^C
[31mERROR: Operation cancelled by user[0m[31m
[0m

In [3]:
!pip install datasets

Defaulting to user installation because normal site-packages is not writeable


In [4]:
!pip install transformers

Defaulting to user installation because normal site-packages is not writeable


In [5]:
from datasets import load_dataset

In [6]:
# Only using 1 parquet file. It contains about 5.8m examples
url = 'https://huggingface.co/datasets/kakaobrain/coyo-700m/resolve/refs%2Fconvert%2Fparquet/default/train/0000.parquet'

data_files = {"train": url}

pre_train_data = load_dataset("parquet", data_files=data_files, split="train")

In [7]:
pre_train_data

Dataset({
    features: ['id', 'url', 'text', 'width', 'height', 'image_phash', 'text_length', 'word_count', 'num_tokens_bert', 'num_tokens_gpt', 'num_faces', 'clip_similarity_vitb32', 'clip_similarity_vitl14', 'nsfw_score_opennsfw2', 'nsfw_score_gantman', 'watermark_score', 'aesthetic_score_laion_v2'],
    num_rows: 5836073
})

In [6]:
pre_train_data[0]['url']

'https://cdn.shopify.com/s/files/1/0286/3900/2698/products/TVN_Huile-olive-infuse-et-s-227x300_e9a90ffd-b6d2-4118-95a1-29a5c7a05a49_800x.jpg?v=1616684087'

In [7]:
pre_train_data[0]['text']

'Olive oil infused with Tuscany herbs'

In [23]:
# using only subset of data for now
pre_train_data = pre_train_data.with_format("torch")
test_data = pre_train_data.select(range(11000, 12000))
val_data = pre_train_data.select(range(10000, 11000))
pre_train_data = pre_train_data.select(range(10000))


In [24]:
pre_train_data

Dataset({
    features: ['id', 'url', 'text', 'width', 'height', 'image_phash', 'text_length', 'word_count', 'num_tokens_bert', 'num_tokens_gpt', 'num_faces', 'clip_similarity_vitb32', 'clip_similarity_vitl14', 'nsfw_score_opennsfw2', 'nsfw_score_gantman', 'watermark_score', 'aesthetic_score_laion_v2'],
    num_rows: 10000
})

In [25]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-base")

In [26]:
# TODO: Need to encode the text descriptions, and clean up images
# TODO: Need to create a final dataset with text, and images
# We can create a custom datalaoder that will load images from urls at runtime.


In [27]:
import requests
from PIL import Image
from io import BytesIO
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class PreTrainDataset(Dataset):
    def __init__(self, dataset, tokenizer, transform=None):
        self.dataset = dataset
        self.transform = transform
        self.tokenizer = tokenizer
        
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        url = self.dataset[idx]['url']  
        text = self.dataset[idx]['text'] # text has already been encoded and padded 
#         text = self.encode_text(text)
        try:
            response = requests.get(url, timeout=1)
            image = Image.open(BytesIO(response.content)).convert("RGB")
            if self.transform:
                image = self.transform(image)
            return image, text
        except Exception:
            return None
#     def encode_text(self, example):
#         text = self.tokenizer(example, padding='max_length', max_length=max_seq_len, add_special_tokens=True) # hard-coded max_length for now
#         bos_id = tokenizer.convert_tokens_to_ids("<s>")
#          # add a bos token as well
#         text = {
#             "input_ids": [bos_id] + text["input_ids"],
#             "attention_mask": [1] + text["attention_mask"]
#         }

        return text

In [28]:
import numpy as np

In [29]:
def remove_none_fn(batch):
    batch_without_nones = [item for item in batch if item is not None]
    if not batch_without_nones:
        return []
    if len(batch_without_nones) < len(batch):
        batch_without_nones.extend([batch_without_nones[-1]] * (len(batch)-len(batch_without_nones)))
    images, texts = zip(*batch_without_nones)
    images = torch.stack(images)
    
    tokenized = tokenizer(
        texts,
        padding="longest",
        return_tensors="pt",
        add_special_tokens=True)
    return images, tokenized

In [30]:
custom_transforms = transforms.Compose([
    transforms.Resize((272, 272)),
    transforms.RandomCrop(224),
    transforms.ToTensor(),
])
pre_train_dataset_cleaned = PreTrainDataset(pre_train_data, tokenizer= tokenizer, transform=custom_transforms)
val_dataset_cleaned = PreTrainDataset(val_data, tokenizer= tokenizer, transform=custom_transforms)
train_loader = DataLoader(pre_train_dataset_cleaned, batch_size=64, shuffle=True, collate_fn=remove_none_fn)
val_loader = DataLoader(val_dataset_cleaned, batch_size=10, shuffle=True, collate_fn=remove_none_fn)

In [33]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [32]:
def log_grad_norms(model, norm_type=2):
    for name, param in model.named_parameters():
        if param.grad is not None and param.requires_grad:
            grad_norm = param.grad.norm(norm_type).item()
            print(f"{name}: grad norm = {grad_norm:.6f}")


In [21]:
def train(model, data, val_data, opt=None, lr=0.0001, weight_decay=0.00000, num_epochs=20, checkpoint_path='../checkpoints/'):

    model.train()

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    epoch = 0

#     n = 0

    model = model.to(device)
    train_losses = []
    train_contrastive_losses = []
    train_generative_losses = []
    
    val_losses = []
    val_contrastive_losses = []
    val_generative_losses = []
    epochs = []

    batch_size = 64
    n = 0
    accumulation_steps = 4
    if opt is not None:
        optimizer = opt
    else:
        optimizer = optim.AdamW(model.parameters(),
                lr=lr,
                weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
    while epoch < num_epochs:

        # Using AdamW for now, can try with other optimizers too


        t_loss = 0
        t_contrastive_loss = 0
        t_generative_loss = 0

        for step, batch in enumerate(data):
            
#             print(batch[0], len(batch[0]))
            # input images, and texts
            if not batch:
                continue
            imgs = batch[0].type(torch.float32).to(device)
            text = batch[1]['input_ids'].type(torch.long).to(device)
#             print(text)

#             if len(imgs) < batch_size:
#                 # Last batch will have less images, text pairs since it will be the
#                 # remainder of Total images / batch_size.

#                 # Adjust the learning rate of the last batch by 
#                 # (size(last_batch) / batch_size) to account 
#                 # for the smaller size.
#                 adj_lr = lr * (len(imgs) / batch_size)
#                 optimizer = optim.AdamW(model.parameters(),
#                     lr=adj_lr,
#                     weight_decay=weight_decay)
            # Since task is to predict next token, the labels will start form position 1
            text_labels = text[:, 1:] 
            total_loss, contrastive_loss, generative_loss = model(imgs, text, text_labels)
            total_loss = total_loss / accumulation_steps
            
            n += 1
            print("-----------------------------------------------------------")
            print(f"Iter: {n}   Total Loss: {total_loss.item() * accumulation_steps}   Gen Loss: {generative_loss.item()}   Contr Loss: {contrastive_loss.item()}")
#             total_loss.backward()
            contrastive_loss.backward(retain_graph=True)
            print("contrastive_norms")
            log_grad_norms(model)
            generative_loss.backward(retain_graph=True)
            print("generative_norms")
            log_grad_norms(model)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)
            i = 0
#             for name, param in model.named_parameters():
#                 if param.grad is not None:
#                     print(f"{name}: grad norm = {param.grad.norm().item():.4f}")
#                 i += 1
#                 if i > 10:
#                     break
            if n % accumulation_steps == 0: # accumulate gradients to artifically increase batch size for learning
                optimizer.step()
                scheduler.step(total_loss)
                optimizer.zero_grad()

            
            # accumulate epoch loss
            t_loss += total_loss.detach()
            t_contrastive_loss += contrastive_loss.detach()
            t_generative_loss += generative_loss.detach()
            del imgs
            del text
            if n % 100 == 0:
#                 torch.save(model.state_dict(), f"{checkpoint_path}_iter_{n}")
                torch.save({
                            "model_state_dict": model.state_dict(),
                            "optimizer_state_dict": optimizer.state_dict(),
                        }, f"{checkpoint_path}_model_checkpoint.pt")

        # end of epoch


        epoch += 1

        train_losses.append(t_loss / len(data))
        train_contrastive_losses.append(t_contrastive_loss / len(data))
        train_generative_losses.append(t_generative_loss / len(data))

        epochs.append(epoch)

        val_loss, val_contrastive_loss, val_generative_loss = validation(model, val_data)
        val_losses.append(val_loss)
        val_contrastive_losses.append(val_contrastive_loss)
        val_generative_losses.append(val_generative_loss)
        
#         if epoch % 5 == 0: # save model every 5th epoch
#         torch.save(model.state_dict(), f"{checkpoint_path}_epoch_{epoch}")
        torch.save({
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                }, f"_model_checkpoint.pt")
            
        print("Epoch {}:  Train loss: {}   Train Contrastive Loss: {}   Train Generative Loss: {}]".format(epoch, t_loss / len(data), t_contrastive_loss / len(data), t_generative_loss / len(data)))
        print("Epoch {}:  Val loss: {}   Val Contrastive Loss: {}   Val Generative Loss: {}]".format(epoch, val_loss / len(val_data), val_contrastive_loss / len(val_data), val_generative_loss / len(val_data)))

    return train_losses, train_contrastive_losses, train_generative_losses, val_losses, val_contrastive_losses, val_generative_losses
    

In [16]:
import os
os.chdir("models")

In [17]:
from model import MaMMUT
model = MaMMUT(vocab_size=tokenizer.vocab_size, contrastive_loss_weight=2)

In [18]:
c = torch.load('../checkpoints/_model_checkpoint.pt', weights_only=True)


In [None]:
model.load_state_dict(c['model_state_dict'])

In [43]:
opt = optim.AdamW(model.parameters(),
                lr=0.0001,
                weight_decay=0)
opt.load_state_dict(c['optimizer_state_dict'])

In [34]:
train(model=model, opt=None, data=train_loader, val_data=val_loader)



-----------------------------------------------------------
Iter: 1   Total Loss: 18.854232788085938   Gen Loss: 10.431276321411133   Contr Loss: 4.211477756500244
contrastive_norms
pos_embedding: grad norm = 0.035248
text_cls_token: grad norm = 0.244320
vit.class_token: grad norm = 0.044123
vit.conv_proj.weight: grad norm = 0.380109
vit.conv_proj.bias: grad norm = 0.075046
vit.encoder.pos_embedding: grad norm = 0.044734
vit.encoder.layers.encoder_layer_0.ln_1.weight: grad norm = 0.020298
vit.encoder.layers.encoder_layer_0.ln_1.bias: grad norm = 0.034181
vit.encoder.layers.encoder_layer_0.self_attention.in_proj_weight: grad norm = 0.711618
vit.encoder.layers.encoder_layer_0.self_attention.in_proj_bias: grad norm = 0.039051
vit.encoder.layers.encoder_layer_0.self_attention.out_proj.weight: grad norm = 1.448126
vit.encoder.layers.encoder_layer_0.self_attention.out_proj.bias: grad norm = 0.056747
vit.encoder.layers.encoder_layer_0.ln_2.weight: grad norm = 0.036612
vit.encoder.layers.encod

pos_embedding: grad norm = 0.037735
text_cls_token: grad norm = 0.244320
vit.class_token: grad norm = 0.046358
vit.conv_proj.weight: grad norm = 0.407854
vit.conv_proj.bias: grad norm = 0.075534
vit.encoder.pos_embedding: grad norm = 0.046944
vit.encoder.layers.encoder_layer_0.ln_1.weight: grad norm = 0.021165
vit.encoder.layers.encoder_layer_0.ln_1.bias: grad norm = 0.034864
vit.encoder.layers.encoder_layer_0.self_attention.in_proj_weight: grad norm = 0.752884
vit.encoder.layers.encoder_layer_0.self_attention.in_proj_bias: grad norm = 0.040038
vit.encoder.layers.encoder_layer_0.self_attention.out_proj.weight: grad norm = 1.524135
vit.encoder.layers.encoder_layer_0.self_attention.out_proj.bias: grad norm = 0.058481
vit.encoder.layers.encoder_layer_0.ln_2.weight: grad norm = 0.038067
vit.encoder.layers.encoder_layer_0.ln_2.bias: grad norm = 0.029651
vit.encoder.layers.encoder_layer_0.mlp.0.weight: grad norm = 1.365962
vit.encoder.layers.encoder_layer_0.mlp.0.bias: grad norm = 0.039167
v

-----------------------------------------------------------
Iter: 2   Total Loss: 18.57837677001953   Gen Loss: 10.146315574645996   Contr Loss: 4.216030120849609
contrastive_norms
pos_embedding: grad norm = 0.022373
text_cls_token: grad norm = 0.081758
vit.class_token: grad norm = 0.017312
vit.conv_proj.weight: grad norm = 0.157158
vit.conv_proj.bias: grad norm = 0.026102
vit.encoder.pos_embedding: grad norm = 0.017479
vit.encoder.layers.encoder_layer_0.ln_1.weight: grad norm = 0.009651
vit.encoder.layers.encoder_layer_0.ln_1.bias: grad norm = 0.013896
vit.encoder.layers.encoder_layer_0.self_attention.in_proj_weight: grad norm = 0.329171
vit.encoder.layers.encoder_layer_0.self_attention.in_proj_bias: grad norm = 0.014815
vit.encoder.layers.encoder_layer_0.self_attention.out_proj.weight: grad norm = 0.601913
vit.encoder.layers.encoder_layer_0.self_attention.out_proj.bias: grad norm = 0.018694
vit.encoder.layers.encoder_layer_0.ln_2.weight: grad norm = 0.011209
vit.encoder.layers.encode

pos_embedding: grad norm = 0.022564
text_cls_token: grad norm = 0.081758
vit.class_token: grad norm = 0.025342
vit.conv_proj.weight: grad norm = 0.210282
vit.conv_proj.bias: grad norm = 0.027699
vit.encoder.pos_embedding: grad norm = 0.025466
vit.encoder.layers.encoder_layer_0.ln_1.weight: grad norm = 0.011705
vit.encoder.layers.encoder_layer_0.ln_1.bias: grad norm = 0.015874
vit.encoder.layers.encoder_layer_0.self_attention.in_proj_weight: grad norm = 0.419620
vit.encoder.layers.encoder_layer_0.self_attention.in_proj_bias: grad norm = 0.018709
vit.encoder.layers.encoder_layer_0.self_attention.out_proj.weight: grad norm = 0.796260
vit.encoder.layers.encoder_layer_0.self_attention.out_proj.bias: grad norm = 0.026827
vit.encoder.layers.encoder_layer_0.ln_2.weight: grad norm = 0.015409
vit.encoder.layers.encoder_layer_0.ln_2.bias: grad norm = 0.013698
vit.encoder.layers.encoder_layer_0.mlp.0.weight: grad norm = 0.555653
vit.encoder.layers.encoder_layer_0.mlp.0.bias: grad norm = 0.018320
v

KeyboardInterrupt: 

##### def validation(model, data):
    
    model.eval()

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    epoch = 0

    model.to(device)

    val_loss = 0
    val_contrastive_loss = 0
    val_generative_loss = 0
    
    for step, batch in enumerate(data):

        # input images, and texts
        imgs = batch[0].type(torch.float32).to(device)
        text = batch[1]['input_ids'].type(torch.long).to(device)
        # Since task is to predict next token, the labels will start form position 1
        text_labels = text[:, 1:] 
        total_loss, contrastive_loss, generative_loss = model(imgs, text, text_labels)

        val_loss += total_loss.detach()
        val_contrastive_loss += contrastive_loss.detach()
        val_generative_loss += generative_loss.detach()

    return val_loss, val_contrastive_loss, val_generative_loss
    