In [75]:
import os
import io

import clip
import einops
import torch
import numpy as np
import torch
import webdataset as wds
import torch.nn.functional as F

from torch.utils.data import DataLoader
from einops import rearrange, repeat
from torch import nn, einsum


import pandas as pd

In [38]:
!pip3.8 install git+https://github.com/openai/CLIP.git
!pip3.8 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
!pip3.8 install webdataset
!pip3.8 install pandas
!pip3.8 install einops

Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable
Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu113
Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable


# DataLoader

In [21]:
# Dataset code:
"""
utils for processing datasets of format described in https://github.com/iejMac/clip-video-encode/pull/13
used https://github.com/rom1504/laion-prepro/blob/main/laion5B/usage_guide/dataloader_pytorch.py as template
"""


def standardize_embedding_shape(emb, seq_len):
    if len(emb) > seq_len:
        print(f"Warning: Raw embedding is longer than standard sequence length ({len(emb)} > {seq_len})")
        emb = emb[:seq_len]

    pad = np.zeros((seq_len - len(emb), emb.shape[1]), dtype=emb.dtype)
    padded_emb = np.concatenate([emb, pad])
    return padded_emb


def create_embeddingwebdataset(
    urls,
    embedding_transform=lambda emb: emb,
    standard_seq_len=-1,
    to_tensor=True,
    enable_text=True,
    enable_meta=True,
):
    """
    Create a WebDataset reader for Frame Embedding Dataset
    Input:
        standard_seq_len: sequence length to pad all embedding sequences to (for batching)
            !(-1) : pad to standard_seq_len
            -1: don't pad (dataset can't be used in DataLoader with batch_size > 1)
        enable_text: include text captions
        enable_meta: include metadata
    """

    dataset = wds.WebDataset(urls)
    # TODO: different tokeinzers??
    tokenizer = lambda text: clip.tokenize([text], truncate=True)[0]

    def preprocess_dataset(item):
        output = {}

        npy_data = item["npy"]
        stream = io.BytesIO(npy_data)
        emb = np.lib.format.read_array(stream)

        if standard_seq_len != -1:
            emb = standardize_embedding_shape(emb, standard_seq_len)
        if to_tensor:
            emb = torch.from_numpy(emb)

        output["embeddings"] = embedding_transform(emb)

        if enable_text:
            text_data = item["cap"]
            text = text_data.decode("utf-8")
            output["text"] = text
            output["text_tokens"] = tokenizer(text)
        if enable_meta:
            meta_data = item["json"]
            meta = meta_data.decode("utf-8")
            output["meta"] = meta
        return output

    transformed_dataset = dataset.map(preprocess_dataset, handler=wds.handlers.warn_and_continue)
    return transformed_dataset


def dataset_to_dataloader(dataset, batch_size, num_prepro_workers):
    """converts WebDataset to PyTorch DataLoader."""

    dl = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_prepro_workers,
        pin_memory=True,
        prefetch_factor=2,
    )

    return dl


class EmbeddingWebDatasetReader:
    """WebDataset reader for Embedding Datasets"""

    def __init__(
        self,
        urls,
        standard_seq_len,
        batch_size,
        num_prepro_workers,
        to_tensor=True,
        enable_text=True,
        enable_meta=False,
        embedding_transform=lambda emb: emb,
    ):
        self.batch_size = batch_size
        dataset = create_embeddingwebdataset(
            urls,
            embedding_transform,
            standard_seq_len,
            to_tensor,
            enable_text,
            enable_meta,
        )
        self.dataloader = dataset_to_dataloader(dataset, batch_size, num_prepro_workers)

    def __iter__(self):
        for batch in self.dataloader:
            yield batch

# Modeling

In [116]:
# cross attention - using multi-query + one-headed key / values as in PaLM w/ optional parallel feedforward
def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d


class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = x + self.pe[:x.size(0)]
        return x


class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
    
    
class CrossAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        context_dim=None,
        dim_head=64,
        heads=8,
        parallel_ff=False,
        ff_mult=4,
        norm_context=False
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head
        context_dim = default(context_dim, dim)

        self.norm = LayerNorm(dim)
        self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity()

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        # whether to have parallel feedforward

        ff_inner_dim = ff_mult * dim

        self.ff = nn.Sequential(
            nn.Linear(dim, ff_inner_dim * 2, bias=False),
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        ) if parallel_ff else None

    def forward(self, x, context):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        # pre-layernorm, for queries and context

        x = self.norm(x)
        context = self.context_norm(context)

        # get queries

        q = self.to_q(x)
        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)

        # scale

        q = q * self.scale

        # get key / values

        k, v = self.to_kv(context).chunk(2, dim=-1)

        # query / key similarity

        sim = einsum('b h i d, b j d -> b h i j', q, k)

        # attention

        sim = sim - sim.amax(dim=-1, keepdim=True)
        attn = sim.softmax(dim=-1)

        # aggregate

        out = einsum('b h i j, b j d -> b h i d', attn, v)

        # merge and combine heads

        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)

        # add parallel feedforward (for multimodal layers)

        if exists(self.ff):
            out = out + self.ff(x)

        return out
    
    
class VideoPooler(nn.Module):
    def __init__(self, dim, context_dim, seq_len, heads, dim_head, proj_dim=None):
        super().__init__()
        self.pos_encoding = PositionalEncoding(dim)
        self.cls_token = nn.Parameter(torch.randn(dim))

        self.img_queries = nn.Parameter(torch.randn(seq_len + 1, dim)) # num image queries for multimodal, but 1 extra CLS for contrastive learning
        self.img_attn_pool = CrossAttention(dim=dim, context_dim=dim, dim_head=dim_head, heads=heads, norm_context=True)
        self.img_attn_pool_norm = LayerNorm(dim)
        
        self.proj = None if proj_dim is None else nn.Sequential(
            nn.Linear(dim, (dim+proj_dim)//2),
            nn.GELU(),
            nn.Linear((dim+proj_dim)//2, proj_dim),
        )

    def forward(self, x):
        
        cls_tokens = repeat(self.cls_token, 'd -> b 1 d', b=x.shape[0])
        x = torch.cat((cls_tokens, x), dim=-2)
        
        x = self.pos_encoding(x)
        
        img_queries = repeat(self.img_queries, 'n d -> b n d', b=x.shape[0])
        img_queries = self.img_attn_pool(img_queries, x)
        img_queries = self.img_attn_pool_norm(img_queries)
        
        video_embedding = img_queries[:, 0]
        pred = video_embedding
        if self.proj is not None:
            pred = self.proj(video_embedding)

        return pred

In [117]:
DATA_DIR = "/home/iejmac/wds_kinetics"
splits = pd.read_csv(os.path.join(DATA_DIR, "splits.csv"))

In [118]:
train_tars = splits[splits["split"] == "train"]["tar_file"].tolist()
val_tars = splits[splits["split"] == "val"]["tar_file"].tolist()

train_tars_paths = [os.path.join(DATA_DIR, t + ".tar") for t in train_tars]
val_tars_paths = [os.path.join(DATA_DIR, t + ".tar") for t in val_tars]

In [119]:
all_labels = pd.read_csv(os.path.join(DATA_DIR, "annotations/train.csv"))["label"].unique().tolist()
len(all_labels)

700

# Training

In [120]:
# MODEL PARAMS:
DIM = 512
DEPTH = 12
HEADS = 8
DIM_HEAD = 128
MLP_DIM = 512
PROJ_DIM = 700
DROPOUT=0.0

In [121]:
pool = VideoPooler(
    dim=DIM,
    context_dim=DIM,
    seq_len=SEQ_LEN,
    heads=HEADS,
    dim_head=DIM_HEAD,
    proj_dim=PROJ_DIM,
)

device = "cuda" if torch.cuda.is_available() else "cpu"
pool = pool.to(device)

In [122]:
# from open_clip/training/scheduler.py
def assign_learning_rate(optimizer, new_lr):
    for param_group in optimizer.param_groups:
        param_group["lr"] = new_lr


def _warmup_lr(base_lr, warmup_length, step):
    return base_lr * (step + 1) / warmup_length


def cosine_lr(optimizer, base_lr, warmup_length, steps):
    def _lr_adjuster(step):
        if step < warmup_length:
            lr = _warmup_lr(base_lr, warmup_length, step)
        else:
            e = step - warmup_length
            es = steps - warmup_length
            lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
        assign_learning_rate(optimizer, lr)
        return lr
    return _lr_adjuster

In [123]:
LR = 5e-4
WEIGHT_DECAY = 0.0
GRAD_CLIP = 1.0
LAMBDA = 0.8
EPOCHS = 20

WARMUP_STEPS = 1000
ALL_STEPS = 120000

SEQ_LEN = 25
BATCH_SIZE = 128
NUM_PREPRO = 6

step = 0

In [124]:
loss_f = torch.nn.CrossEntropyLoss()
opt = torch.optim.Adam(pool.parameters(), lr=LR, weight_decay=0.0)

# lr_schedule = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda step: LAMBDA**step)
lr_schedule = cosine_lr(opt, LR, WARMUP_STEPS, ALL_STEPS)

In [125]:
val_reader = EmbeddingWebDatasetReader(
    urls=val_tars_paths,
    standard_seq_len=SEQ_LEN,
    batch_size=BATCH_SIZE,
    num_prepro_workers=NUM_PREPRO,
    to_tensor=True,
    enable_text=True,
    enable_meta=True,
    embedding_transform=lambda emb: emb,
)

train_reader = EmbeddingWebDatasetReader(
    urls=train_tars_paths,
    standard_seq_len=SEQ_LEN,
    batch_size=BATCH_SIZE,
    num_prepro_workers=NUM_PREPRO,
    to_tensor=True,
    enable_text=True,
    enable_meta=False,
    embedding_transform=lambda emb: emb,
)

In [136]:
running_loss = 0.0
for e in range(EPOCHS):
  print(f"Epoch {e}...")
  for i, b in enumerate(train_reader):
    step += 1
    lr_schedule(step)
    embeddings = b["embeddings"].float().to(device)
    labs = torch.Tensor([all_labels.index(l) for l in b["text"]]).long().to(device)

    opt.zero_grad()

    pred = pool(embeddings)
    loss = loss_f(pred, labs)

    loss.backward()
    
    # clip grads:
    torch.nn.utils.clip_grad_norm_(pool.parameters(), GRAD_CLIP)

    opt.step()

    running_loss += loss.item()

    if (step + 1) % 100 == 0:
      print(f"epoch {e} : step {step} average loss = {running_loss/100}")
      running_loss = 0.0

Epoch 0...
epoch 0 : step 73999 average loss = 0.6653468614816666
epoch 0 : step 74099 average loss = 0.9298540154099464
epoch 0 : step 74199 average loss = 0.9315595841407776
epoch 0 : step 74299 average loss = 0.9167524462938309
epoch 0 : step 74399 average loss = 0.9026793563365936
epoch 0 : step 74499 average loss = 0.8978964656591415
epoch 0 : step 74599 average loss = 0.9098487958312035
epoch 0 : step 74699 average loss = 0.8874560636281967
epoch 0 : step 74799 average loss = 0.9201191115379334
epoch 0 : step 74899 average loss = 0.9079224413633347
epoch 0 : step 74999 average loss = 0.9149649530649185
epoch 0 : step 75099 average loss = 0.9329281556606293
epoch 0 : step 75199 average loss = 0.8948884096741676
epoch 0 : step 75299 average loss = 0.9317408782243729
epoch 0 : step 75399 average loss = 0.9159096571803093
epoch 0 : step 75499 average loss = 0.9064727365970612
epoch 0 : step 75599 average loss = 0.9254825103282929
epoch 0 : step 75699 average loss = 0.9249770903587341

KeyboardInterrupt: 

In [137]:
correct = 0
all_ = 0
with torch.no_grad():
  for val_b in val_reader:
    embeddings = val_b["embeddings"].float().to(device)
    labs = torch.Tensor([all_labels.index(l) for l in val_b["text"]])

    pred = pool(embeddings).cpu()
    pred_cls = torch.argmax(pred, axis=-1)

    all_ += len(labs)
    correct += torch.sum(labs == pred_cls)

In [132]:
print(correct/all_)

tensor(0.4709)


In [135]:
print(correct/all_)

tensor(0.4825)


In [138]:
print(correct/all_)

tensor(0.4753)
