In [5]:
import sys
from pathlib import Path
ROOT = Path.cwd().parent.parent
sys.path.append(str(ROOT))

In [6]:
import torch
from torch import nn
from torch.optim import AdamW
from torch.cuda.amp import autocast, GradScaler
import time

In [12]:
from models.text_encoders import ClipDistilBert, ClipMiniLM
from models.vision_encoders import ClipVitTiny, ClipVitSmall
from models.ResNet50 import ClipResNet50

In [17]:
def log_text_model_information(model_class, output_dim=256, device="cuda"):
    print(f"MODEL: {model_class.__name__}")
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    mem_before = torch.cuda.memory_allocated()
    model = model_class(output_dim).to(device)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total Parameters: {total_params:,}")
    mem_after = torch.cuda.memory_allocated()
    print(f"GPU Memory Before Loading: {mem_before / 1e6:.2f} MB")
    print(f"GPU Memory After Loading: {mem_after / 1e6:.2f} MB")
    print(f"Model Memory Footprint: {(mem_after - mem_before) / 1e6:.2f} MB\n")
    caption = "This is a fake image."
    t0 = time.perf_counter()
    fake_caption = model.tokenizer(
        caption,
        padding="max_length",
        truncation=True,
        max_length=32,
        return_tensors="pt"
    ).to(device)
    t1 = time.perf_counter()
    with torch.autocast("cuda", dtype=torch.bfloat16):
        _ = model(**fake_caption)
    t2 = time.perf_counter()
    print(f"Tokenization Time: {(t1 - t0)*1000:.2f} ms")
    print(f"Forward Pass Time (BF16): {(t2 - t1)*1000:.2f} ms")
    print(f"Total Pipeline Time: {(t2 - t0)*1000:.2f} ms")
    peak_mem = torch.cuda.max_memory_allocated()
    print(f"Peak GPU Memory Used During Run: {peak_mem / 1e6:.2f} MB\n")
    return {
        "params": total_params,
        "mem_before": mem_before,
        "mem_after": mem_after,
        "peak_mem": peak_mem,
        "tokenize_time_ms": (t1 - t0)*1000,
        "forward_time_ms": (t2 - t1)*1000,
        "total_time_ms": (t2 - t0)*1000,
    }

def log_vision_encoder_info(encoder_class, output_dim=256, device="cuda"):
    print(f"VISION MODEL: {encoder_class.__name__}")
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    mem_before = torch.cuda.memory_allocated()
    model = encoder_class(output_dim).to(device)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total Parameters: {total_params:,}")
    mem_after = torch.cuda.memory_allocated()
    print(f"GPU Memory Before Loading: {mem_before / 1e6:.2f} MB")
    print(f"GPU Memory After Loading: {mem_after / 1e6:.2f} MB")
    print(f"Model Memory Footprint: {(mem_after - mem_before) / 1e6:.2f} MB\n")
    raw_image = torch.randn(1, 3, 224, 224).to(device)
    t0 = time.perf_counter()
    t1 = time.perf_counter()
    with torch.autocast("cuda", dtype=torch.bfloat16):
        _ = model(raw_image)
    t2 = time.perf_counter()
    print(f"Preprocessing Time: {(t1 - t0)*1000:.2f} ms")
    print(f"Forward Pass Time (BF16): {(t2 - t1)*1000:.2f} ms")
    print(f"Total Pipeline Time: {(t2 - t0)*1000:.2f} ms")
    peak_mem = torch.cuda.max_memory_allocated()
    print(f"Peak GPU Memory Used During Run: {peak_mem / 1e6:.2f} MB\n")
    return {
        "params": total_params,
        "mem_before": mem_before,
        "mem_after": mem_after,
        "peak_mem": peak_mem,
        "preprocess_ms": (t1 - t0)*1000,
        "forward_ms": (t2 - t1)*1000,
        "total_ms": (t2 - t0)*1000,
    }

def log_resnet_encoder_info(resnet_class, output_dim=256, device="cuda"):
    print(f"VISION MODEL: {resnet_class.__name__}")
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    mem_before = torch.cuda.memory_allocated()
    model = resnet_class(output_dim).to(device)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total Parameters: {total_params:,}")
    mem_after = torch.cuda.memory_allocated()
    print(f"GPU Memory Before Loading: {mem_before / 1e6:.2f} MB")
    print(f"GPU Memory After Loading: {mem_after / 1e6:.2f} MB")
    print(f"Model Memory Footprint: {(mem_after - mem_before) / 1e6:.2f} MB\n")
    raw_image = (torch.rand(224, 224, 3) * 255).to(torch.uint8)
    t0 = time.perf_counter()
    processed = raw_image.permute(2, 0, 1).float() / 255.0
    processed = processed.unsqueeze(0).to(device)
    t1 = time.perf_counter()
    with torch.autocast("cuda", dtype=torch.bfloat16):
        _ = model(processed)
    t2 = time.perf_counter()
    print(f"Preprocessing Time: {(t1 - t0)*1000:.2f} ms")
    print(f"Forward Pass Time (BF16): {(t2 - t1)*1000:.2f} ms")
    print(f"Total Pipeline Time: {(t2 - t0)*1000:.2f} ms")
    peak_mem = torch.cuda.max_memory_allocated()
    print(f"Peak GPU Memory Used During Run: {peak_mem / 1e6:.2f} MB\n")
    return {
        "params": total_params,
        "mem_before": mem_before,
        "mem_after": mem_after,
        "peak_mem": peak_mem,
        "preprocess_ms": (t1 - t0)*1000,
        "forward_ms": (t2 - t1)*1000,
        "total_ms": (t2 - t0)*1000,
    }

def log_information():
    for text_encoder in [ClipDistilBert, ClipMiniLM]:
        log_text_model_information(text_encoder)
    for vision_encoder in [ClipVitTiny, ClipVitSmall]:
        log_vision_encoder_info(vision_encoder)
    #log_resnet_encoder_info(ClipResNet50)
    

In [18]:
log_information()

MODEL: ClipDistilBert
Total Parameters: 66,559,744
GPU Memory Before Loading: 77.73 MB
GPU Memory After Loading: 344.58 MB
Model Memory Footprint: 266.85 MB

Tokenization Time: 1.74 ms
Forward Pass Time (BF16): 29.15 ms
Total Pipeline Time: 30.90 ms
Peak GPU Memory Used During Run: 446.48 MB

MODEL: ClipMiniLM
Total Parameters: 22,811,776
GPU Memory Before Loading: 77.73 MB
GPU Memory After Loading: 169.64 MB
Model Memory Footprint: 91.91 MB

Tokenization Time: 0.90 ms
Forward Pass Time (BF16): 11.93 ms
Total Pipeline Time: 12.84 ms
Peak GPU Memory Used During Run: 195.59 MB

VISION MODEL: ClipVitTiny
Total Parameters: 5,573,824
GPU Memory Before Loading: 77.73 MB
GPU Memory After Loading: 100.04 MB
Model Memory Footprint: 22.32 MB

Preprocessing Time: 0.00 ms
Forward Pass Time (BF16): 15.06 ms
Total Pipeline Time: 15.07 ms
Peak GPU Memory Used During Run: 128.76 MB

VISION MODEL: ClipVitSmall
Total Parameters: 21,764,224
GPU Memory Before Loading: 77.73 MB
GPU Memory After Loading: 16

### Text Encoder Comparison

DistilBert vs MiniLM
Parameters: 66M vs 22M
Forward pass time: 29.15 vs 11.93 ms
Memory Footprint: 266.85 vs 91.91 MB

Due to the size of MiniLM, it may struggle to have a similar expressive capability as DistilBert.

MiniLM is 2.4x faster and 2.9x lighter in memory.

### Vision Encoder Comparison

VitTiny vs VitSmall
Parameters: 5.5M vs 21.8M
Forward Pass: 15.06 ms vs 36.51ms
Memory Footprint: 22.32 MB vs 87.71 MB

VitTiny is extremely small which is potentially concerning. Most likely, ViTSmall will end up being used or even potentially upscaling to the next size.

ViTTiny is 2.4x faster and 3.9x lighter in memory. If ViTTiny works, it will be very efficient.

### Determining if a model works

1. Compare the gradient norms. If one encoder receives much smaller gradients than the other, its capacity is too small and the other encoder is overcompensating.
2. Can Freeze one encoder and see how the loss changes. If the loss barely decreases, the non-frozen model is too small.

In [None]:
def quick_train_cycle(model):
    pass