In [1]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import pandas as pd

from config import Config
from datetime import datetime
from img_embedding import ImageEmbedding
from img_transformer import ImgTransformer
from img_util import show_img_tensor_CHW
from fliker_comment_tokenizer import FlikerCommentTokenizer
from fliker_img_comment_dataset import ImgCommentDataset
from model_util import count_parameters
from pathlib import Path
from text_token_embedding import TextTokenEmbedding
from text_casual_mask_transformer import TextMaskedTransformer
from torch.utils.tensorboard import SummaryWriter


import torch
import torch.nn as nn
from torch.profiler import profile, record_function, ProfilerActivity
from torch.utils.data import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
config = Config()

train_dataset = ImgCommentDataset(config, split="train")
eval_dataset = ImgCommentDataset(config, split="eval")
test_dataset= ImgCommentDataset(config, split="test")
print(f"train_dataset:  {len(train_dataset)}")
print(f"eval_dataset:  {len(eval_dataset)}")
print(f"test_dataset:  {len(test_dataset)}")


# Data Loader
BATCH_SIZE = 10
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
eval_dataloader = DataLoader(eval_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)
print(f"train_dataloader:  {len(test_dataset)}")
print(f"eval_data_loader:  {len(test_dataset)}")
print(f"test_data_loader:  {len(test_dataset)}")


Enriched img id: /tmp/enriched_results.csv
tokens: 128000
tokenizer.is_fast: True
Enriched img id: /tmp/enriched_results.csv
tokens: 128000
tokenizer.is_fast: True
Enriched img id: /tmp/enriched_results.csv
tokens: 128000
tokenizer.is_fast: True
train_dataset:  114418
eval_dataset:  28605
test_dataset:  15892
train_dataloader:  15892
eval_data_loader:  15892
test_data_loader:  15892


In [3]:
class ImgLanguageModel(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.config = config

        self.img_embedding = ImageEmbedding(config=config)
        self.img_transfomer = ImgTransformer(config=config)
        self.img_flatten = nn.Flatten(start_dim=1)
        self.img_proj = nn.Linear(in_features=config.img_patches* config.img_patch_embedding, out_features=config.img_text_proj_features)
        self.img_softmax = nn.LogSoftmax(dim=-1)

        # self.text_embedding = TextTokenEmbedding(config=config)
        self.text_transformer = TextMaskedTransformer(config=config)
        self.text_flatten = nn.Flatten(start_dim=1)
        self.text_proj = nn.Linear(in_features=config.max_text_len* config.text_token_embedding, out_features=config.img_text_proj_features)
        self.text_softmax = nn.LogSoftmax(dim=-1)
        
        self.diag_mask = torch.diag(torch.ones(config.img_text_proj_features))
        self.loss_fn = nn.NLLLoss()

    def forward(self, batch_img_tensor: torch.tensor, batch_text_tensor: torch.tensor, batch_img_id_tensor: torch.tensor=None):
        """
        batch_img_tensor: B x IMG_PATCHES x IMG_EMB
        batch_text_tensor: B x TEXT_TOKEN
        """
        img_embedding = self.img_embedding(batch_img_tensor) # B x IMG_PATCHES x IMG_EMB
        # print(f"img_encoding: {img_embedding.size()}")

        img_feature = self.img_transfomer(img_embedding) # B x IMG_PATCHES x IMG_EMB
        # print(f"img_feature: {img_feature.size()}")

        img_feature = self.img_flatten(img_feature)
        # print(f"img_feature: {img_feature.size()}")

        img_feature = self.img_proj(img_feature)
        # print(f"img_feature: {img_feature.size()}")  # B x img_text_proj_features

        # text_embedding = self.text_embedding(batch_text_tensor)
        # print(f"text_embedding: {text_embedding.size()}")

        text_feature = self.text_transformer(batch_text_tensor)
        # print(f"text_feature: {text_feature.size()}")

        text_feature = self.text_flatten(text_feature)
        # print(f"text_feature: {text_feature.size()}")

        text_feature = self.text_proj(text_feature)
        # print(f"text_feature: {text_feature.size()}")  # B x img_text_proj_features

        # Contrastive learning
        contrastive_scores = img_feature @ text_feature.T
        # print(f"contractive_scores: {contrastive_scores}")  # B x img_text_proj_features

        img_contrastive_prob = self.img_softmax(contrastive_scores)
        # print(f"img_contrastive_prob: {img_contrastive_prob}")  # B x img_text_proj_features
        
        target = torch.arange(img_contrastive_prob.size()[0], device=img_contrastive_prob.device)
        img_loss = self.loss_fn(img_contrastive_prob, target)
        # img_loss = self.loss_fn(img_contrastive_prob, self.target.expand(img_contrastive_prob.size()[0], -1))
        # print(f"img_loss: {img_loss}")

        text_contrastive_prob = self.text_softmax(contrastive_scores.T)
        # print(f"text_contrastive_prob: {text_contrastive_prob}")  # B x img_text_proj_features
        text_loss = self.loss_fn(text_contrastive_prob, target)
        # print(f"text_loss: {text_loss}")
        
        return img_loss, text_loss

In [4]:
# train_dataset = ImgCommentDataset(config, split="train")

# BATCH_SIZE = 10
# train_data_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

batch_img_tensor, batch_img_id_tensor, batch_comment_encoding = next(iter(train_dataloader))
print(f"batch_img_tensor: {batch_img_tensor.size()}")
print(f"batch_img_id_tensor: {batch_img_id_tensor.size()}")
print(f"batch_comment_encoding: {batch_comment_encoding.size()}")

batch_img_tensor: torch.Size([10, 3, 512, 512])
batch_img_id_tensor: torch.Size([10])
batch_comment_encoding: torch.Size([10, 50])


In [5]:
model = ImgLanguageModel(config=config)
img_loss, text_loss = model(batch_img_tensor=batch_img_tensor, batch_text_tensor=batch_comment_encoding, batch_img_id_tensor=batch_img_id_tensor)

pytorch_total_params = sum(p.numel() for p in model.parameters())
pytorch_total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"pytorch_total_params: {pytorch_total_params/10**6} m")
print(f"pytorch_total_trainable_params: {pytorch_total_trainable_params/10**6} m")
count_parameters(model)

In [None]:
EPOCHES = 1
EVAL_INTERVAL = 100
EVAL_STEPS = 10
lr = 0.001

device = torch.device("mps")

model = model.to(device)
optimizer =  torch.optim.AdamW(params=model.parameters(), lr = lr)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(train_dataloader), epochs=EPOCHES)

def eval(model: ImgLanguageModel, global_step : int, writer: SummaryWriter):
    model.eval()

    avg_eval_loss = None
    eval_loss_std = None
    with torch.no_grad():
        eval_losses = []
        for i, data in enumerate(eval_dataloader):
            if i > EVAL_STEPS:
                # It takes significant time to do one full eval.
                break

            batch_img_tensor, batch_target_tensor = data
            batch_img_tensor = batch_img_tensor.to(device)
            batch_target_tensor = batch_target_tensor.to(device)
            img_loss, text_loss = model(batch_img_tensor, batch_target_tensor)
            writer.add_scalar("eval/Img Loss", img_loss, global_step)
            writer.add_scalar("eval/Text Loss", text_loss, global_step)
            eval_losses.append(img_loss + text_loss)
        eval_losses = torch.tensor(eval_losses)
        avg_eval_loss = eval_losses.mean()
        eval_loss_std = eval_losses.std()
        writer.add_scalar("eval/Loss", avg_eval_loss, global_step)
        writer.add_scalar("Loss/eval-std", eval_loss_std, global_step)
    model.train()
    writer.flush()
    return avg_eval_loss, eval_loss_std
    


def train(model: ImgLanguageModel, writer: SummaryWriter):
    best_vloss = torch.tensor(1_000_000)
    with torch.profiler.profile(
            schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
            on_trace_ready=torch.profiler.tensorboard_trace_handler('./runs'),
            activities=[torch.profiler.ProfilerActivity.CPU], 
            record_shapes=True,
            profile_memory=True,
            with_stack=True
    ) as prof:
    # with torch.mps.profiler.profile(mode="interval", wait_until_completed=False):
        for epoch in range(EPOCHES): 
            for train_step, data in enumerate(train_dataloader):
                global_step = epoch * len(train_dataloader) + train_step

                # Profile
                if global_step < 1 + 1 + 3:
                    prof.step()

                batch_img_tensor, batch_img_id_tensor, batch_target_tensor = data
                batch_img_tensor = batch_img_tensor.to(device)
                batch_target_tensor = batch_target_tensor.to(device)

                # Viz Model
                if global_step == 0:
                    writer.add_graph(model, (batch_img_tensor, batch_target_tensor))

                optimizer.zero_grad()
                img_loss, text_loss = model(batch_img_tensor, batch_target_tensor)
                writer.add_scalar("train/Img Loss", img_loss, global_step)
                writer.add_scalar("train/Text Loss", text_loss, global_step)
                writer.add_scalar("train/Loss", img_loss+text_loss, global_step)
                writer.add_scalar("Learning Rate", scheduler.get_last_lr()[-1], global_step)
                loss = img_loss + text_loss
                loss.backward()
                optimizer.step()
                scheduler.step()

                if train_step > 0 and train_step % EVAL_INTERVAL == 0:
                    avg_vloss, _ = eval(model=model, global_step=global_step, writer=writer)
                
                    if avg_vloss is not None and avg_vloss < best_vloss:
                        best_vloss = avg_vloss
                        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
                        model_path = f"multi_label_img_classifier_{epoch}_{timestamp}"
                        torch.save(model.state_dict(), model_path)

with SummaryWriter(flush_secs=1) as writer:
    train(model=model, writer=writer)