In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models, torchvision.datasets
from torch.utils.data import Dataset
from torchvision.io import read_image
import os
import shutil
import random
import matplotlib.pyplot as plt
from torchvision import transforms

In [2]:
!pip install --upgrade pip

Defaulting to user installation because normal site-packages is not writeable


In [3]:
!pip install datasets

Defaulting to user installation because normal site-packages is not writeable


In [4]:
!pip install transformers

Defaulting to user installation because normal site-packages is not writeable


In [2]:
from datasets import load_dataset

In [3]:
# Only using 1 parquet file. It contains about 5.8m examples
url = 'https://huggingface.co/datasets/kakaobrain/coyo-700m/resolve/refs%2Fconvert%2Fparquet/default/train/0000.parquet'

data_files = {"train": url}

pre_train_data = load_dataset("parquet", data_files=data_files, split="train")

In [4]:
pre_train_data

Dataset({
    features: ['id', 'url', 'text', 'width', 'height', 'image_phash', 'text_length', 'word_count', 'num_tokens_bert', 'num_tokens_gpt', 'num_faces', 'clip_similarity_vitb32', 'clip_similarity_vitl14', 'nsfw_score_opennsfw2', 'nsfw_score_gantman', 'watermark_score', 'aesthetic_score_laion_v2'],
    num_rows: 5836073
})

In [5]:
pre_train_data[0]['url']

'https://cdn.shopify.com/s/files/1/0286/3900/2698/products/TVN_Huile-olive-infuse-et-s-227x300_e9a90ffd-b6d2-4118-95a1-29a5c7a05a49_800x.jpg?v=1616684087'

In [6]:
pre_train_data[0]['text']

'Olive oil infused with Tuscany herbs'

In [7]:
# using only subset of data for now
pre_train_data = pre_train_data.with_format("torch")
test_data = pre_train_data.select(range(6000, 7000))
val_data = pre_train_data.select(range(5000, 6000))
pre_train_data = pre_train_data.select(range(5000))


In [8]:
pre_train_data

Dataset({
    features: ['id', 'url', 'text', 'width', 'height', 'image_phash', 'text_length', 'word_count', 'num_tokens_bert', 'num_tokens_gpt', 'num_faces', 'clip_similarity_vitb32', 'clip_similarity_vitl14', 'nsfw_score_opennsfw2', 'nsfw_score_gantman', 'watermark_score', 'aesthetic_score_laion_v2'],
    num_rows: 5000
})

In [9]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-base")

In [10]:
# TODO: Need to encode the text descriptions, and clean up images
# TODO: Need to create a final dataset with text, and images
# We can create a custom datalaoder that will load images from urls at runtime.


In [11]:
import requests
from PIL import Image
from io import BytesIO
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


class PreTrainDataset(Dataset):
    def __init__(self, dataset, tokenizer, transform=None):
        self.dataset = dataset
        self.transform = transform
        self.tokenizer = tokenizer
        
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        url = self.dataset[idx]['url']  
        text = self.dataset[idx]['text'] # text has already been encoded and padded 
#         text = self.encode_text(text)
        return {"url": url, "text": text}
        # try:
        #     response = requests.get(url, timeout=1)
        #     image = Image.open(BytesIO(response.content)).convert("RGB")
        #     if self.transform:
        #         image = self.transform(image)
        #     return image, text
        # except Exception:
        #     return None
#     def encode_text(self, example):
#         text = self.tokenizer(example, padding='max_length', max_length=max_seq_len, add_special_tokens=True) # hard-coded max_length for now
#         bos_id = tokenizer.convert_tokens_to_ids("<s>")
#          # add a bos token as well
#         text = {
#             "input_ids": [bos_id] + text["input_ids"],
#             "attention_mask": [1] + text["attention_mask"]
#         }

        return text

In [12]:
import numpy as np

In [13]:
def remove_none_fn(batch):
    batch_without_nones = [item for item in batch if item is not None]
    if not batch_without_nones:
        return []
    if len(batch_without_nones) < len(batch):
        batch_without_nones.extend([batch_without_nones[-1]] * (len(batch)-len(batch_without_nones)))
    images, texts = zip(*batch_without_nones)
    images = torch.stack(images)
    
    tokenized = tokenizer(
        texts,
        padding="longest",
        return_tensors="pt",
        add_special_tokens=True)
    return images, tokenized

In [105]:
import aiohttp
import asyncio
from PIL import Image
from io import BytesIO
from torchvision import transforms

async def fetch(session, url, text, idx):
    if text is None:
        return None
    try:
        async with session.get(url, timeout=2) as resp:
            if resp.status == 200:
                content = await resp.read()
                image = Image.open(BytesIO(content)).convert("RGB")
                return image, text
    except Exception:
        return None

async def fetch_valid_pairs(batch):
    async with aiohttp.ClientSession() as session:
        tasks = [
            fetch(session, item['url'], item['text'], idx)
            for idx, item in enumerate(batch)
        ]
        results = await asyncio.gather(*tasks)
        # results = results[1]

    # Filter out failed fetches
    # print(results)
    valid_samples = []
    for res in results:
        if res:
            valid_samples.append(res)
    return valid_samples



In [106]:
import nest_asyncio
import asyncio

nest_asyncio.apply()

In [111]:
def async_remove_none_fn(batch_data):
    # print(batch_data)
    # urls = [item['url'] for item in batch_data]
    # texts = [item['text'] for item in batch_data]
    loop = asyncio.get_event_loop()
    valid_samples = loop.run_until_complete(fetch_valid_pairs(batch_data))
    # images = asyncio.run(fetch_all(urls))  # async batch fetch
    images, texts = zip(*valid_samples)
    texts = list(texts)
    # print(texts)
    
    # Apply transforms if needed
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
    # valid_img = None
    # for img in images:
    #     if img:
    #         valid_img = img
    #         break
    # if not valid_img:
    #     return []
            
    images = [transform(img) if img else transform(valid_img) for img in images]
    # images = torch.Tensor(images)
    tokenized = tokenizer(
        texts,
        padding="longest",
        return_tensors="pt",
        add_special_tokens=True)
    images = torch.stack(images)
    # print(images)
    # print(tokenized)
    # return list(zip(images, texts))
    return images, tokenized


In [108]:
custom_transforms = transforms.Compose([
    transforms.Resize((272, 272)),
    transforms.RandomCrop(224),
    transforms.ToTensor(),
])
pre_train_dataset_cleaned = PreTrainDataset(pre_train_data, tokenizer= tokenizer, transform=custom_transforms)
val_dataset_cleaned = PreTrainDataset(val_data, tokenizer= tokenizer, transform=custom_transforms)
train_loader = DataLoader(pre_train_dataset_cleaned, batch_size=32, shuffle=True, collate_fn=async_remove_none_fn)
val_loader = DataLoader(val_dataset_cleaned, batch_size=10, shuffle=True, collate_fn=async_remove_none_fn)

In [89]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [76]:
def log_grad_norms(model, norm_type=2):
    for name, param in model.named_parameters():
        if param.grad is not None and param.requires_grad:
            grad_norm = param.grad.norm(norm_type).item()
            print(f"{name}: grad norm = {grad_norm:.6f}")


In [77]:
import pickle

In [112]:
def train(model, data, val_data, opt=None, lr=0.0001, weight_decay=0.00000, num_epochs=20, checkpoint_path='../checkpoints/', warmup_steps=100):

    model.train()

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    epoch = 0

#     n = 0
    
    model = model.to(device)
    train_losses = []
    train_contrastive_losses = []
    train_generative_losses = []
    
    val_losses = []
    val_contrastive_losses = []
    val_generative_losses = []
    epochs = []

    batch_size = 32
    n = 0
    accumulation_steps = 8
    if opt is not None:
        optimizer = opt
    else:
        optimizer = optim.Adam(model.parameters(),
                lr=lr,
                weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, total_iters=warmup_steps)
#     main_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
    # schedule = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warmup_scheduler, main_scheduler], milestones=[warmup_steps])
    # # for param in model.vit.parameters():
    #     param.requires_grad = False

    while epoch < num_epochs:

        # Using AdamW for now, can try with other optimizers too


        t_loss = 0
        t_contrastive_loss = 0
        t_generative_loss = 0

        for step, batch in enumerate(data):
            
#             print(batch[0], len(batch[0]))
            # input images, and texts
            if not batch:
                continue
            # print(batch[1])
            imgs = batch[0].type(torch.float32).to(device)
            text = batch[1]['input_ids'].type(torch.long).to(device)
#             print(text)
            unique_imgs = torch.unique(imgs.view(imgs.shape[0], -1), dim=0).size(0)
            print(f"Unique images in batch: {unique_imgs}/{imgs.shape[0]}")

#             if len(imgs) < batch_size:
#                 # Last batch will have less images, text pairs since it will be the
#                 # remainder of Total images / batch_size.

#                 # Adjust the learning rate of the last batch by 
#                 # (size(last_batch) / batch_size) to account 
#                 # for the smaller size.
#                 adj_lr = lr * (len(imgs) / batch_size)
#                 optimizer = optim.AdamW(model.parameters(),
#                     lr=adj_lr,
#                     weight_decay=weight_decay)
            # Since task is to predict next token, the labels will start form position 1
            text_labels = text[:, 1:] 
            total_loss, contrastive_loss, generative_loss = model(imgs, text, text_labels)
            total_loss = total_loss / accumulation_steps
            
            n += 1
            print("-----------------------------------------------------------")
            print(f"Iter: {n}   Total Loss: {total_loss.item() * accumulation_steps}   Gen Loss: {generative_loss.item()}   Contr Loss: {contrastive_loss.item()}")
            total_loss.backward(retain_graph=True)
            # contrastive_loss.backward(retain_graph=True)
            # print("contrastive_norms")
            # log_grad_norms(model)
            # generative_loss.backward(retain_graph=True)
            # print("generative_norms")
            # log_grad_norms(model)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2)
            # i = 0
#             for name, param in model.named_parameters():
#                 if param.grad is not None:
#                     print(f"{name}: grad norm = {param.grad.norm().item():.4f}")
#                 i += 1
#                 if i > 10:
#                     break
            if n % accumulation_steps == 0: # accumulate gradients to artifically increase batch size for learning
               
                optimizer.step()
                scheduler.step(total_loss)
                optimizer.zero_grad()

            
            # accumulate epoch loss
            t_loss += total_loss.detach() * accumulation_steps
            t_contrastive_loss += contrastive_loss.detach()
            t_generative_loss += generative_loss.detach()
            del imgs
            del text
            if n % 100 == 0:
#                 torch.save(model.state_dict(), f"{checkpoint_path}_iter_{n}")
                torch.save({
                            "model_state_dict": model.state_dict(),
                            "optimizer_state_dict": optimizer.state_dict(),
                        }, f"{checkpoint_path}_model_checkpoint.pt")


        # end of epoch


        epoch += 1

        train_losses.append(t_loss / len(data))
        train_contrastive_losses.append(t_contrastive_loss / len(data))
        train_generative_losses.append(t_generative_loss / len(data))

        epochs.append(epoch)

        val_loss, val_contrastive_loss, val_generative_loss = validation(model, val_data)
        val_losses.append(val_loss)
        val_contrastive_losses.append(val_contrastive_loss)
        val_generative_losses.append(val_generative_loss)
        
        with open(f"{checkpoint_path}_train_loss.pkl", 'wb') as f:
            pickle.dump(train_losses, f)
        with open(f"{checkpoint_path}_train_cont_loss.pkl", 'wb') as f:
            pickle.dump(train_contrastive_losses, f)
        with open(f"{checkpoint_path}_train_gen_loss.pkl", 'wb') as f:
            pickle.dump(train_generative_losses, f)

        with open(f"{checkpoint_path}_val_loss.pkl", 'wb') as f:
            pickle.dump(val_losses, f)
        with open(f"{checkpoint_path}_val_cont_loss.pkl", 'wb') as f:
            pickle.dump(val_contrastive_losses, f)
        with open(f"{checkpoint_path}_val_gen_loss.pkl", 'wb') as f:
            pickle.dump(val_generative_losses, f)
            
#         if epoch % 5 == 0: # save model every 5th epoch
#         torch.save(model.state_dict(), f"{checkpoint_path}_epoch_{epoch}")
        torch.save({
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                }, f"_model_checkpoint.pt")
            
        print("Epoch {}:  Train loss: {}   Train Contrastive Loss: {}   Train Generative Loss: {}]".format(epoch, t_loss / len(data), t_contrastive_loss / len(data), t_generative_loss / len(data)))
        print("Epoch {}:  Val loss: {}   Val Contrastive Loss: {}   Val Generative Loss: {}]".format(epoch, val_loss / len(val_data), val_contrastive_loss / len(val_data), val_generative_loss / len(val_data)))

    return train_losses, train_contrastive_losses, train_generative_losses, val_losses, val_contrastive_losses, val_generative_losses
    

In [79]:
def validation(model, data):
    
    model.eval()

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    epoch = 0

    model.to(device)

    val_loss = 0
    val_contrastive_loss = 0
    val_generative_loss = 0
    
    for step, batch in enumerate(data):

        # input images, and texts
        imgs = batch[0].type(torch.float32).to(device)
        text = batch[1]['input_ids'].type(torch.long).to(device)
        # Since task is to predict next token, the labels will start form position 1
        text_labels = text[:, 1:] 
        total_loss, contrastive_loss, generative_loss = model(imgs, text, text_labels)

        val_loss += total_loss.detach()
        val_contrastive_loss += contrastive_loss.detach()
        val_generative_loss += generative_loss.detach()

    return val_loss, val_contrastive_loss, val_generative_loss
    

In [80]:
import os
os.chdir("models")

FileNotFoundError: [Errno 2] No such file or directory: 'models'

In [113]:
from model import MaMMUT
model = MaMMUT(vocab_size=tokenizer.vocab_size,
                image_size= 224,
                patch_size = 16,
                vit_num_layers= 6,
                vit_num_heads= 8,
                vit_hidden_dim = 768,
                vit_mlp_dim = 2048,
                vit_dropout = 0.0, # Potential ablation / extension to add to the replication
                vit_attention_dropout = 0.0, # Potential ablation / extension to add to the replication
                contrastive_loss_weight = 0.5,
                generative_loss_weight = 1.0,
                text_decoder_depth = 4,
                text_decoder_embed_dim = 512,
                text_decoder_sub_layer_heads = 8,
                text_decoder_feedforward_dim = 2048,
                text_decoder_dk = 128,
                latent_dim = 512,
                contrastive_loss_gamma = 1.0
              )

In [24]:
c = torch.load('../checkpoints/_model_checkpoint.pt', weights_only=True)


In [25]:
model.load_state_dict(c['model_state_dict'])

<All keys matched successfully>

In [None]:
opt = optim.AdamW(model.parameters(),
                lr=0.0001,
                weight_decay=0)
opt.load_state_dict(c['optimizer_state_dict'])

In [None]:
train(model=model, opt=None, data=train_loader, val_data=val_loader)

Unique images in batch: 16/16
-----------------------------------------------------------
Iter: 1   Total Loss: 45.89055633544922   Gen Loss: 10.911336898803711   Contr Loss: 69.95843505859375
Unique images in batch: 19/19
-----------------------------------------------------------
Iter: 2   Total Loss: 35.15397644042969   Gen Loss: 10.906286239624023   Contr Loss: 48.49537658691406
Unique images in batch: 22/22
-----------------------------------------------------------
Iter: 3   Total Loss: 39.187435150146484   Gen Loss: 10.796002388000488   Contr Loss: 56.782867431640625
Unique images in batch: 22/22
-----------------------------------------------------------
Iter: 4   Total Loss: 28.674766540527344   Gen Loss: 10.839180946350098   Contr Loss: 35.671173095703125
Unique images in batch: 19/19
-----------------------------------------------------------
Iter: 5   Total Loss: 36.634891510009766   Gen Loss: 10.921138763427734   Contr Loss: 51.42750549316406
Unique images in batch: 26/26


In [None]:
accumulation_steps