In [1]:
!pip install open_clip_torch

Collecting open_clip_torch
  Downloading open_clip_torch-2.32.0-py3-none-any.whl.metadata (31 kB)
Collecting ftfy (from open_clip_torch)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.9.0->open_clip_torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.9.0->open_clip_torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch>=1.9.0->open_clip_torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch>=1.9.0->open_clip_torch)
  Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch>=1.9.0->open_clip_torch)
  Downloading nvidia_cusolv

In [2]:
import torch
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

In [3]:
import open_clip
from open_clip import tokenize
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from tqdm import tqdm
import json

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
model.to(device)

model_path = "/kaggle/input/clip_model/pytorch/default/1/clip_model.pt"
checkpoints = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoints['model_state_dict'])


open_clip_model.safetensors:   0%|          | 0.00/1.71G [00:00<?, ?B/s]



<All keys matched successfully>

In [5]:
for param in model.parameters():
    param.requires_grad = False

In [6]:

for name, param in model.named_parameters():
    print(f"{name} norm: {param.norm().item()}")

positional_embedding norm: 2.0288569927215576
text_projection norm: 11.265134811401367
logit_scale norm: 4.605087757110596
visual.class_embedding norm: 5.109840393066406
visual.positional_embedding norm: 11.497537612915039
visual.proj norm: 12.942944526672363
visual.conv1.weight norm: 13.028203010559082
visual.ln_pre.weight norm: 21.016136169433594
visual.ln_pre.bias norm: 2.6796717643737793
visual.transformer.resblocks.0.ln_1.weight norm: 14.776494026184082
visual.transformer.resblocks.0.ln_1.bias norm: 3.969780683517456
visual.transformer.resblocks.0.attn.in_proj_weight norm: 17.87546157836914
visual.transformer.resblocks.0.attn.in_proj_bias norm: 35.37779235839844
visual.transformer.resblocks.0.attn.out_proj.weight norm: 12.688957214355469
visual.transformer.resblocks.0.attn.out_proj.bias norm: 2.594679355621338
visual.transformer.resblocks.0.ln_2.weight norm: 32.28483581542969
visual.transformer.resblocks.0.ln_2.bias norm: 4.974412441253662
visual.transformer.resblocks.0.mlp.c_fc.w

In [7]:
json_path="/kaggle/input/data-info1/output.json"
image_path="/kaggle/input/images/images/"

In [8]:
with open(json_path,"r",encoding="utf-8") as f:
    data=json.load(f)

image_list = []
title_list = []

for i in data:
    image_list.append(image_path + i['image'])
    title_list.append(i['caption'])
class MyDataset(Dataset):
    def __init__(self, list_image_path, list_txt):
        self.image_path = list_image_path
        self.title = tokenize(list_txt)
    def __len__(self):
        return len(self.image_path)
    def __getitem__(self, idx):
        image = preprocess(Image.open(self.image_path[idx]).convert('RGB'))
        title = self.title[idx]
        return image, title


dataset=MyDataset(image_list,title_list)
dataloader=DataLoader(dataset, batch_size=4, shuffle=False)

In [9]:

class CrossAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(4 * embed_dim, embed_dim),
            nn.Dropout(dropout)
        )
        
    def forward(self, q, k, v):
        attn_output, _ = self.multihead_attn(
            q.transpose(0, 1),
            k.transpose(0, 1),
            v.transpose(0, 1)
        )
        attn_output = attn_output.transpose(0, 1)
        q = q + self.dropout(attn_output)
        q = self.norm1(q)
        
        q2 = self.mlp(q)
        q = q + self.dropout(q2)
        q = self.norm2(q)
        
        return q

class CLIPCrossAttentionWrapper(nn.Module):
    def __init__(self, clip_model, embed_dim=768, num_heads=8, dropout=0.1):
        super().__init__()
        self.clip_model = clip_model
        self.clip_embed_dim = embed_dim  
        
        self.img_to_text_attn = CrossAttention(embed_dim, num_heads, dropout)
        
        self.text_to_img_attn = CrossAttention(embed_dim, num_heads, dropout)
        
        self.img_proj = nn.Linear(self.clip_model.visual.output_dim, embed_dim)
        self.text_proj = nn.Linear(self.clip_model.transformer.resblocks[-1].attn.out_proj.out_features, embed_dim)
        
        self.final_img_proj = nn.Linear(embed_dim, self.clip_model.visual.output_dim)
        self.final_text_proj = nn.Linear(embed_dim, self.clip_model.transformer.resblocks[-1].attn.out_proj.out_features)
        
        self.logit_scale = self.clip_model.logit_scale
        
    def forward(self, images, texts):
        image_features = self.clip_model.encode_image(images)
        text_features = self.clip_model.encode_text(texts)
        
        img_features_proj = self.img_proj(image_features.unsqueeze(1))  
        text_features_proj = self.text_proj(text_features.unsqueeze(1))  
        
        img_attended = self.img_to_text_attn(img_features_proj, text_features_proj, text_features_proj)
        text_attended = self.text_to_img_attn(text_features_proj, img_features_proj, img_features_proj)
        
        img_features_final = self.final_img_proj(img_attended.squeeze(1))
        text_features_final = self.final_text_proj(text_attended.squeeze(1))
        

        image_features = F.normalize(img_features_final, dim=-1)
        text_features = F.normalize(text_features_final, dim=-1)
        
        return image_features, text_features, self.logit_scale

def clip_loss(logits_per_image, logits_per_text):
    ground_truth = torch.arange(len(logits_per_image), dtype=torch.long, device=device)
    loss_i = F.cross_entropy(logits_per_image, ground_truth)
    loss_t = F.cross_entropy(logits_per_text, ground_truth)
    return (loss_i + loss_t) / 2

In [10]:
clip_model = CLIPCrossAttentionWrapper(model).to(device)
clip_model

CLIPCrossAttentionWrapper(
  (clip_model): CLIP(
    (visual): VisionTransformer(
      (conv1): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
      (patch_dropout): Identity()
      (ln_pre): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (transformer): Transformer(
        (resblocks): ModuleList(
          (0-23): 24 x ResidualAttentionBlock(
            (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
            )
            (ls_1): Identity()
            (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=1024, out_features=4096, bias=True)
              (gelu): GELU(approximate='none')
              (c_proj): Linear(in_features=4096, out_features=1024, bias=True)
            )
            (ls_2): Identity

In [11]:

optimizer = torch.optim.AdamW([
    {'params':clip_model.img_to_text_attn.parameters()},
    {'params':clip_model.text_to_img_attn.parameters()},
    {'params':clip_model.img_proj.parameters()},
    {'params':clip_model.text_proj.parameters()},
    {'params':clip_model.final_img_proj.parameters()},
    {'params':clip_model.final_text_proj.parameters()}
], lr=1e-6)

In [12]:
def train(dataloader, model, optimizer, save_path=""):
    train_loader = tqdm(dataloader, total=len(dataloader))
    num = 0
    best_loss = float('inf')

    for batch in train_loader:
        num += 1
        optimizer.zero_grad()

        images, texts = batch
        images = images.to(device)
        texts = texts.to(device)

        image_features, text_features, logit_scale = model(images, texts)
        
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        loss = clip_loss(logits_per_image, logits_per_text)

        loss.backward()
        optimizer.step()

        if num % 100 == 0:
            print(f"Step {num}, total_loss: {loss.item()}")
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'step': num,
                'loss': best_loss
            }, save_path)
            
        if loss.item() < best_loss:
            print(f"New best loss: {loss.item()} (previous: {best_loss})")
            best_loss = loss.item()

train(dataloader, clip_model, optimizer, save_path="/kaggle/working/clip_cross_attention_model.pt")    

  0%|          | 1/16359 [00:01<7:52:48,  1.73s/it]

New best loss: 1.3757402896881104 (previous: inf)


  0%|          | 4/16359 [00:03<3:22:31,  1.35it/s]

New best loss: 1.3747957944869995 (previous: 1.3757402896881104)


  0%|          | 6/16359 [00:04<2:58:01,  1.53it/s]

New best loss: 1.3680009841918945 (previous: 1.3747957944869995)


  0%|          | 27/16359 [00:17<2:42:04,  1.68it/s]

New best loss: 1.3469129800796509 (previous: 1.3680009841918945)


  0%|          | 43/16359 [00:26<2:44:37,  1.65it/s]

New best loss: 1.3447465896606445 (previous: 1.3469129800796509)


  0%|          | 68/16359 [00:42<2:52:29,  1.57it/s]

New best loss: 1.343186855316162 (previous: 1.3447465896606445)


  1%|          | 99/16359 [01:02<2:58:47,  1.52it/s]

Step 100, total_loss: 1.3786380290985107


  1%|          | 158/16359 [01:46<3:12:17,  1.40it/s]

New best loss: 1.314701795578003 (previous: 1.343186855316162)


  1%|          | 199/16359 [02:16<3:29:06,  1.29it/s]

Step 200, total_loss: 1.383666753768921


  1%|▏         | 218/16359 [02:37<3:33:31,  1.26it/s]

New best loss: 1.300710916519165 (previous: 1.314701795578003)


  2%|▏         | 299/16359 [03:39<3:21:17,  1.33it/s]

Step 300, total_loss: 1.3797162771224976


  2%|▏         | 399/16359 [05:01<3:18:09,  1.34it/s]

Step 400, total_loss: 1.3965859413146973


  3%|▎         | 499/16359 [06:25<3:17:58,  1.34it/s]

Step 500, total_loss: 1.3923895359039307


  4%|▎         | 598/16359 [07:46<3:18:15,  1.32it/s]

New best loss: 1.2680070400238037 (previous: 1.300710916519165)


  4%|▎         | 599/16359 [07:47<3:17:34,  1.33it/s]

Step 600, total_loss: 1.3924286365509033


  4%|▍         | 699/16359 [09:09<3:18:34,  1.31it/s]

Step 700, total_loss: 1.3853884935379028


  5%|▍         | 799/16359 [10:32<3:23:03,  1.28it/s]

Step 800, total_loss: 1.378734827041626


  5%|▌         | 899/16359 [11:55<3:14:08,  1.33it/s]

Step 900, total_loss: 1.3760240077972412


  6%|▌         | 999/16359 [13:18<3:12:17,  1.33it/s]

Step 1000, total_loss: 1.3928844928741455


  6%|▌         | 1013/16359 [13:33<3:14:12,  1.32it/s]

New best loss: 1.2360457181930542 (previous: 1.2680070400238037)


  7%|▋         | 1099/16359 [14:40<3:12:25,  1.32it/s]

Step 1100, total_loss: 1.3813016414642334


  7%|▋         | 1113/16359 [14:56<3:11:08,  1.33it/s]

New best loss: 1.2029707431793213 (previous: 1.2360457181930542)


  7%|▋         | 1199/16359 [16:03<3:08:37,  1.34it/s]

Step 1200, total_loss: 1.3618552684783936


  8%|▊         | 1299/16359 [17:25<3:10:27,  1.32it/s]

Step 1300, total_loss: 1.3702573776245117


  9%|▊         | 1399/16359 [18:47<3:14:35,  1.28it/s]

Step 1400, total_loss: 1.3506860733032227


  9%|▉         | 1483/16359 [19:57<3:12:03,  1.29it/s]

New best loss: 1.186737298965454 (previous: 1.2029707431793213)


  9%|▉         | 1499/16359 [20:09<3:09:31,  1.31it/s]

Step 1500, total_loss: 1.3775677680969238


  9%|▉         | 1503/16359 [20:17<5:02:11,  1.22s/it]

New best loss: 1.1835392713546753 (previous: 1.186737298965454)


 10%|▉         | 1573/16359 [21:12<3:14:30,  1.27it/s]

New best loss: 1.130072832107544 (previous: 1.1835392713546753)


 10%|▉         | 1599/16359 [21:32<3:09:23,  1.30it/s]

Step 1600, total_loss: 1.3671098947525024


 10%|▉         | 1603/16359 [21:40<5:05:01,  1.24s/it]

New best loss: 1.1157636642456055 (previous: 1.130072832107544)


 10%|█         | 1688/16359 [22:47<3:07:04,  1.31it/s]

New best loss: 1.0736465454101562 (previous: 1.1157636642456055)


 10%|█         | 1699/16359 [22:55<3:17:45,  1.24it/s]

Step 1700, total_loss: 1.379514455795288


 11%|█         | 1799/16359 [24:20<3:11:05,  1.27it/s]

Step 1800, total_loss: 1.383650541305542


 11%|█▏        | 1878/16359 [25:27<3:12:43,  1.25it/s]

New best loss: 1.0656883716583252 (previous: 1.0736465454101562)


 12%|█▏        | 1899/16359 [25:44<3:08:51,  1.28it/s]

Step 1900, total_loss: 1.3818011283874512


 12%|█▏        | 1928/16359 [26:10<3:09:02,  1.27it/s]

New best loss: 1.0134153366088867 (previous: 1.0656883716583252)


 12%|█▏        | 1938/16359 [26:18<3:17:08,  1.22it/s]

New best loss: 0.9871490001678467 (previous: 1.0134153366088867)


 12%|█▏        | 1999/16359 [27:07<3:02:45,  1.31it/s]

Step 2000, total_loss: 1.3554348945617676


 13%|█▎        | 2099/16359 [28:30<3:13:51,  1.23it/s]

Step 2100, total_loss: 1.3708897829055786


 13%|█▎        | 2199/16359 [29:53<3:01:12,  1.30it/s]

Step 2200, total_loss: 1.3465462923049927


 14%|█▍        | 2299/16359 [31:15<2:58:58,  1.31it/s]

Step 2300, total_loss: 1.3609870672225952


 14%|█▍        | 2369/16359 [32:15<3:03:54,  1.27it/s]

New best loss: 0.9681435227394104 (previous: 0.9871490001678467)


 15%|█▍        | 2399/16359 [32:38<3:04:07,  1.26it/s]

Step 2400, total_loss: 1.3611416816711426


 15%|█▌        | 2483/16359 [33:47<2:56:56,  1.31it/s]

New best loss: 0.9173741340637207 (previous: 0.9681435227394104)


 15%|█▌        | 2499/16359 [34:00<2:55:54,  1.31it/s]

Step 2500, total_loss: 1.343855857849121


 16%|█▌        | 2578/16359 [35:05<2:56:40,  1.30it/s]

New best loss: 0.908983051776886 (previous: 0.9173741340637207)


 16%|█▌        | 2599/16359 [35:21<2:53:16,  1.32it/s]

Step 2600, total_loss: 1.3576991558074951


 16%|█▋        | 2699/16359 [36:43<2:56:07,  1.29it/s]

Step 2700, total_loss: 1.3509268760681152


 17%|█▋        | 2763/16359 [37:37<2:59:23,  1.26it/s]

New best loss: 0.8386680483818054 (previous: 0.908983051776886)


 17%|█▋        | 2799/16359 [38:05<2:50:36,  1.32it/s]

Step 2800, total_loss: 1.367249846458435


 18%|█▊        | 2899/16359 [39:27<2:51:37,  1.31it/s]

Step 2900, total_loss: 1.280134916305542


 18%|█▊        | 2953/16359 [40:13<2:56:39,  1.26it/s]

New best loss: 0.8324097394943237 (previous: 0.8386680483818054)


 18%|█▊        | 2998/16359 [40:47<2:48:13,  1.32it/s]

New best loss: 0.8305566310882568 (previous: 0.8324097394943237)


 18%|█▊        | 2999/16359 [40:48<2:49:00,  1.32it/s]

Step 3000, total_loss: 1.3744707107543945


 18%|█▊        | 3013/16359 [41:03<2:49:27,  1.31it/s]

New best loss: 0.7681369781494141 (previous: 0.8305566310882568)


 19%|█▉        | 3099/16359 [42:10<2:46:43,  1.33it/s]

Step 3100, total_loss: 1.3318116664886475


 20%|█▉        | 3199/16359 [43:32<2:45:04,  1.33it/s]

Step 3200, total_loss: 1.22950279712677


 20%|█▉        | 3228/16359 [43:58<2:47:31,  1.31it/s]

New best loss: 0.7537909746170044 (previous: 0.7681369781494141)


 20%|█▉        | 3268/16359 [44:30<2:51:46,  1.27it/s]

New best loss: 0.7497565746307373 (previous: 0.7537909746170044)


 20%|██        | 3278/16359 [44:37<2:47:41,  1.30it/s]

New best loss: 0.7394378185272217 (previous: 0.7497565746307373)


 20%|██        | 3299/16359 [44:53<2:44:36,  1.32it/s]

Step 3300, total_loss: 1.3118896484375


 21%|██        | 3393/16359 [46:11<2:43:24,  1.32it/s]

New best loss: 0.7227813005447388 (previous: 0.7394378185272217)


 21%|██        | 3399/16359 [46:15<2:43:05,  1.32it/s]

Step 3400, total_loss: 1.3383970260620117


 21%|██▏       | 3478/16359 [47:20<2:46:23,  1.29it/s]

New best loss: 0.7119156718254089 (previous: 0.7227813005447388)


 21%|██▏       | 3499/16359 [47:36<2:42:53,  1.32it/s]

Step 3500, total_loss: 1.3579866886138916


 22%|██▏       | 3599/16359 [48:58<2:41:06,  1.32it/s]

Step 3600, total_loss: 1.3281282186508179


 23%|██▎       | 3688/16359 [50:11<2:41:05,  1.31it/s]

New best loss: 0.7098078727722168 (previous: 0.7119156718254089)


 23%|██▎       | 3699/16359 [50:19<2:39:51,  1.32it/s]

Step 3700, total_loss: 1.348804235458374


 23%|██▎       | 3799/16359 [51:41<2:39:38,  1.31it/s]

Step 3800, total_loss: 1.336711049079895


 24%|██▍       | 3899/16359 [53:02<2:37:40,  1.32it/s]

Step 3900, total_loss: 1.334659457206726


 24%|██▍       | 3999/16359 [54:25<2:35:30,  1.32it/s]

Step 4000, total_loss: 1.3005081415176392


 25%|██▌       | 4099/16359 [55:47<2:37:42,  1.30it/s]

Step 4100, total_loss: 1.26384699344635


 26%|██▌       | 4199/16359 [57:09<2:32:33,  1.33it/s]

Step 4200, total_loss: 1.3468416929244995


 26%|██▋       | 4299/16359 [58:32<2:31:09,  1.33it/s]

Step 4300, total_loss: 1.2766938209533691


 27%|██▋       | 4358/16359 [59:22<2:41:08,  1.24it/s]

New best loss: 0.6805641055107117 (previous: 0.7098078727722168)


 27%|██▋       | 4399/16359 [59:54<2:30:23,  1.33it/s]

Step 4400, total_loss: 1.3146283626556396


 28%|██▊       | 4499/16359 [1:01:16<2:30:11,  1.32it/s]

Step 4500, total_loss: 1.284731149673462


 28%|██▊       | 4599/16359 [1:02:38<2:30:07,  1.31it/s]

Step 4600, total_loss: 1.226114273071289


 28%|██▊       | 4638/16359 [1:03:13<2:35:25,  1.26it/s]

New best loss: 0.6608313918113708 (previous: 0.6805641055107117)


 29%|██▊       | 4699/16359 [1:04:00<2:27:53,  1.31it/s]

Step 4700, total_loss: 1.2402360439300537


 29%|██▉       | 4799/16359 [1:05:22<2:26:43,  1.31it/s]

Step 4800, total_loss: 1.2763668298721313


 30%|██▉       | 4899/16359 [1:06:44<2:25:59,  1.31it/s]

Step 4900, total_loss: 1.298356056213379


 30%|███       | 4953/16359 [1:07:31<2:32:25,  1.25it/s]

New best loss: 0.58295738697052 (previous: 0.6608313918113708)


 31%|███       | 4999/16359 [1:08:06<2:22:25,  1.33it/s]

Step 5000, total_loss: 1.2730212211608887


 31%|███       | 5099/16359 [1:09:27<2:22:26,  1.32it/s]

Step 5100, total_loss: 1.216477632522583


 32%|███▏      | 5199/16359 [1:10:49<2:21:01,  1.32it/s]

Step 5200, total_loss: 1.362165093421936


 32%|███▏      | 5299/16359 [1:12:11<2:18:14,  1.33it/s]

Step 5300, total_loss: 1.2534687519073486


 33%|███▎      | 5399/16359 [1:13:33<2:17:47,  1.33it/s]

Step 5400, total_loss: 1.216454267501831


 34%|███▎      | 5499/16359 [1:14:56<2:17:19,  1.32it/s]

Step 5500, total_loss: 1.3094244003295898


 34%|███▍      | 5599/16359 [1:16:18<2:16:17,  1.32it/s]

Step 5600, total_loss: 1.2208542823791504


 35%|███▍      | 5699/16359 [1:17:40<2:14:47,  1.32it/s]

Step 5700, total_loss: 1.2099196910858154


 35%|███▌      | 5799/16359 [1:19:02<2:12:44,  1.33it/s]

Step 5800, total_loss: 1.0163447856903076


 36%|███▌      | 5899/16359 [1:20:24<2:12:52,  1.31it/s]

Step 5900, total_loss: 1.1360101699829102


 37%|███▋      | 5999/16359 [1:21:46<2:10:53,  1.32it/s]

Step 6000, total_loss: 1.2617275714874268


 37%|███▋      | 6074/16359 [1:22:49<2:12:22,  1.30it/s]

New best loss: 0.568085789680481 (previous: 0.58295738697052)


 37%|███▋      | 6099/16359 [1:23:08<2:08:43,  1.33it/s]

Step 6100, total_loss: 1.3137774467468262


 38%|███▊      | 6199/16359 [1:24:29<2:08:32,  1.32it/s]

Step 6200, total_loss: 1.3609641790390015


 39%|███▊      | 6299/16359 [1:25:51<2:06:27,  1.33it/s]

Step 6300, total_loss: 1.1375641822814941


 39%|███▊      | 6314/16359 [1:26:07<2:06:06,  1.33it/s]

New best loss: 0.5246502161026001 (previous: 0.568085789680481)


 39%|███▉      | 6399/16359 [1:27:12<2:04:47,  1.33it/s]

Step 6400, total_loss: 1.1908892393112183


 40%|███▉      | 6478/16359 [1:28:18<2:06:24,  1.30it/s]

New best loss: 0.4982089698314667 (previous: 0.5246502161026001)


 40%|███▉      | 6499/16359 [1:28:34<2:05:58,  1.30it/s]

Step 6500, total_loss: 1.1792371273040771


 40%|████      | 6599/16359 [1:29:55<2:03:09,  1.32it/s]

Step 6600, total_loss: 1.1769956350326538


 41%|████      | 6699/16359 [1:31:16<2:01:54,  1.32it/s]

Step 6700, total_loss: 1.0939701795578003


 41%|████      | 6748/16359 [1:31:58<2:06:17,  1.27it/s]

New best loss: 0.495275616645813 (previous: 0.4982089698314667)


 42%|████▏     | 6799/16359 [1:32:38<2:00:05,  1.33it/s]

Step 6800, total_loss: 1.1050002574920654


 42%|████▏     | 6899/16359 [1:33:59<1:58:53,  1.33it/s]

Step 6900, total_loss: 1.0919311046600342


 43%|████▎     | 6998/16359 [1:35:20<1:57:41,  1.33it/s]

New best loss: 0.4596462845802307 (previous: 0.495275616645813)


 43%|████▎     | 6999/16359 [1:35:21<1:57:47,  1.32it/s]

Step 7000, total_loss: 0.8239762783050537


 43%|████▎     | 7083/16359 [1:36:30<1:58:02,  1.31it/s]

New best loss: 0.4382830560207367 (previous: 0.4596462845802307)


 43%|████▎     | 7099/16359 [1:36:42<1:57:30,  1.31it/s]

Step 7100, total_loss: 1.1059041023254395


 44%|████▎     | 7129/16359 [1:37:09<1:58:14,  1.30it/s]

New best loss: 0.43080049753189087 (previous: 0.4382830560207367)


 44%|████▍     | 7199/16359 [1:38:04<1:56:37,  1.31it/s]

Step 7200, total_loss: 1.2304344177246094


 45%|████▍     | 7299/16359 [1:39:25<1:54:11,  1.32it/s]

Step 7300, total_loss: 1.1427186727523804


 45%|████▌     | 7399/16359 [1:40:47<1:53:03,  1.32it/s]

Step 7400, total_loss: 1.1113924980163574


 46%|████▌     | 7499/16359 [1:42:08<1:51:15,  1.33it/s]

Step 7500, total_loss: 1.0370632410049438


 46%|████▋     | 7599/16359 [1:43:29<1:50:17,  1.32it/s]

Step 7600, total_loss: 0.7377674579620361


 47%|████▋     | 7699/16359 [1:44:51<1:48:55,  1.33it/s]

Step 7700, total_loss: 1.2778041362762451


 48%|████▊     | 7799/16359 [1:46:13<1:47:30,  1.33it/s]

Step 7800, total_loss: 1.3023353815078735


 48%|████▊     | 7899/16359 [1:47:34<1:47:50,  1.31it/s]

Step 7900, total_loss: 1.2265980243682861


 49%|████▉     | 7999/16359 [1:48:56<1:45:32,  1.32it/s]

Step 8000, total_loss: 1.1550544500350952


 50%|████▉     | 8099/16359 [1:50:18<1:43:46,  1.33it/s]

Step 8100, total_loss: 1.0486305952072144


 50%|█████     | 8199/16359 [1:51:40<1:43:00,  1.32it/s]

Step 8200, total_loss: 1.327019453048706


 51%|█████     | 8299/16359 [1:53:01<1:41:25,  1.32it/s]

Step 8300, total_loss: 0.9464402198791504


 51%|█████▏    | 8399/16359 [1:54:23<1:40:23,  1.32it/s]

Step 8400, total_loss: 0.8139749765396118


 52%|█████▏    | 8463/16359 [1:55:17<1:42:33,  1.28it/s]

New best loss: 0.35412347316741943 (previous: 0.43080049753189087)


 52%|█████▏    | 8499/16359 [1:55:44<1:39:06,  1.32it/s]

Step 8500, total_loss: 0.9561963081359863


 53%|█████▎    | 8599/16359 [1:57:06<1:39:09,  1.30it/s]

Step 8600, total_loss: 1.114651083946228


 53%|█████▎    | 8699/16359 [1:58:29<1:36:46,  1.32it/s]

Step 8700, total_loss: 0.874369204044342


 53%|█████▎    | 8723/16359 [1:58:52<1:37:43,  1.30it/s]

New best loss: 0.32611382007598877 (previous: 0.35412347316741943)


 54%|█████▍    | 8799/16359 [1:59:51<1:35:29,  1.32it/s]

Step 8800, total_loss: 1.0362887382507324


 54%|█████▍    | 8899/16359 [2:01:13<1:34:24,  1.32it/s]

Step 8900, total_loss: 0.9726827144622803


 55%|█████▍    | 8928/16359 [2:01:40<1:36:36,  1.28it/s]

New best loss: 0.32181516289711 (previous: 0.32611382007598877)


 55%|█████▌    | 8999/16359 [2:02:36<1:33:09,  1.32it/s]

Step 9000, total_loss: 1.040514349937439


 55%|█████▌    | 9023/16359 [2:02:58<1:33:39,  1.31it/s]

New best loss: 0.31167861819267273 (previous: 0.32181516289711)


 56%|█████▌    | 9099/16359 [2:03:58<1:32:29,  1.31it/s]

Step 9100, total_loss: 0.9111467003822327


 56%|█████▌    | 9199/16359 [2:05:20<1:29:55,  1.33it/s]

Step 9200, total_loss: 0.9494462013244629


 57%|█████▋    | 9298/16359 [2:06:41<1:29:24,  1.32it/s]

New best loss: 0.30221834778785706 (previous: 0.31167861819267273)


 57%|█████▋    | 9299/16359 [2:06:42<1:29:34,  1.31it/s]

Step 9300, total_loss: 0.8474338054656982


 57%|█████▋    | 9398/16359 [2:08:04<1:28:21,  1.31it/s]

New best loss: 0.22626137733459473 (previous: 0.30221834778785706)


 57%|█████▋    | 9399/16359 [2:08:04<1:28:27,  1.31it/s]

Step 9400, total_loss: 1.0693671703338623


 58%|█████▊    | 9499/16359 [2:09:26<1:26:32,  1.32it/s]

Step 9500, total_loss: 0.8908640146255493


 59%|█████▊    | 9599/16359 [2:10:48<1:25:35,  1.32it/s]

Step 9600, total_loss: 1.3473080396652222


 59%|█████▉    | 9699/16359 [2:12:11<1:23:43,  1.33it/s]

Step 9700, total_loss: 0.9428017139434814


 60%|█████▉    | 9799/16359 [2:13:33<1:23:15,  1.31it/s]

Step 9800, total_loss: 0.9770540595054626


 61%|██████    | 9899/16359 [2:14:54<1:21:14,  1.33it/s]

Step 9900, total_loss: 1.092768907546997


 61%|██████    | 9999/16359 [2:16:16<1:19:26,  1.33it/s]

Step 10000, total_loss: 0.8600308299064636


 62%|██████▏   | 10099/16359 [2:17:39<1:17:48,  1.34it/s]

Step 10100, total_loss: 0.9482225179672241


 62%|██████▏   | 10199/16359 [2:19:02<1:19:09,  1.30it/s]

Step 10200, total_loss: 0.5256328582763672


 63%|██████▎   | 10299/16359 [2:20:26<1:17:56,  1.30it/s]

Step 10300, total_loss: 0.9393396377563477


 64%|██████▎   | 10399/16359 [2:21:50<1:16:45,  1.29it/s]

Step 10400, total_loss: 0.7167413234710693


 64%|██████▍   | 10499/16359 [2:23:15<1:15:24,  1.30it/s]

Step 10500, total_loss: 0.6499214172363281


 65%|██████▍   | 10599/16359 [2:24:41<1:16:36,  1.25it/s]

Step 10600, total_loss: 0.6889206171035767


 65%|██████▌   | 10699/16359 [2:26:05<1:13:00,  1.29it/s]

Step 10700, total_loss: 0.4301990270614624


 66%|██████▌   | 10799/16359 [2:27:29<1:12:25,  1.28it/s]

Step 10800, total_loss: 0.4852668046951294


 67%|██████▋   | 10899/16359 [2:28:53<1:10:12,  1.30it/s]

Step 10900, total_loss: 0.6188136339187622


 67%|██████▋   | 10999/16359 [2:30:17<1:08:28,  1.30it/s]

Step 11000, total_loss: 0.9056320190429688


 68%|██████▊   | 11099/16359 [2:31:41<1:07:20,  1.30it/s]

Step 11100, total_loss: 0.5256688594818115


 68%|██████▊   | 11199/16359 [2:33:04<1:06:55,  1.28it/s]

Step 11200, total_loss: 0.45544418692588806


 69%|██████▉   | 11299/16359 [2:34:28<1:04:51,  1.30it/s]

Step 11300, total_loss: 0.4283469319343567


 70%|██████▉   | 11399/16359 [2:35:53<1:04:22,  1.28it/s]

Step 11400, total_loss: 0.6565444469451904


 70%|███████   | 11471/16359 [2:36:57<1:04:08,  1.27it/s]

New best loss: 0.21773403882980347 (previous: 0.22626137733459473)


 70%|███████   | 11499/16359 [2:37:18<1:02:32,  1.30it/s]

Step 11500, total_loss: 0.32162201404571533


 71%|███████   | 11599/16359 [2:38:42<1:01:39,  1.29it/s]

Step 11600, total_loss: 0.5340944528579712


 71%|███████   | 11627/16359 [2:39:08<1:02:28,  1.26it/s]

New best loss: 0.2080584466457367 (previous: 0.21773403882980347)


 71%|███████▏  | 11696/16359 [2:40:03<59:56,  1.30it/s]  

New best loss: 0.20360788702964783 (previous: 0.2080584466457367)


 72%|███████▏  | 11699/16359 [2:40:05<59:32,  1.30it/s]

Step 11700, total_loss: 0.34146565198898315


 72%|███████▏  | 11799/16359 [2:41:28<59:04,  1.29it/s]

Step 11800, total_loss: 0.49316495656967163


 73%|███████▎  | 11899/16359 [2:42:52<57:45,  1.29it/s]

Step 11900, total_loss: 0.4100702404975891


 73%|███████▎  | 11999/16359 [2:44:15<56:13,  1.29it/s]

Step 12000, total_loss: 0.3435804545879364


 74%|███████▍  | 12099/16359 [2:45:39<54:55,  1.29it/s]

Step 12100, total_loss: 0.3210922181606293


 74%|███████▍  | 12151/16359 [2:46:25<56:47,  1.24it/s]

New best loss: 0.19677144289016724 (previous: 0.20360788702964783)


 75%|███████▍  | 12199/16359 [2:47:02<53:39,  1.29it/s]

Step 12200, total_loss: 0.33322465419769287


 75%|███████▍  | 12223/16359 [2:47:25<53:55,  1.28it/s]

New best loss: 0.18837827444076538 (previous: 0.19677144289016724)


 75%|███████▌  | 12299/16359 [2:48:26<52:26,  1.29it/s]

Step 12300, total_loss: 0.46520674228668213


 75%|███████▌  | 12322/16359 [2:48:49<52:41,  1.28it/s]

New best loss: 0.15346843004226685 (previous: 0.18837827444076538)


 76%|███████▌  | 12399/16359 [2:49:50<51:03,  1.29it/s]

Step 12400, total_loss: 0.5296943187713623


 76%|███████▋  | 12499/16359 [2:51:13<49:53,  1.29it/s]

Step 12500, total_loss: 0.35296744108200073


 77%|███████▋  | 12599/16359 [2:52:41<52:16,  1.20it/s]

Step 12600, total_loss: 0.437079519033432


 78%|███████▊  | 12699/16359 [2:54:06<48:39,  1.25it/s]

Step 12700, total_loss: 0.4206457734107971


 78%|███████▊  | 12799/16359 [2:55:31<46:17,  1.28it/s]

Step 12800, total_loss: 0.319267600774765


 79%|███████▊  | 12862/16359 [2:56:26<47:02,  1.24it/s]

New best loss: 0.15173965692520142 (previous: 0.15346843004226685)


 79%|███████▊  | 12876/16359 [2:56:38<47:13,  1.23it/s]

New best loss: 0.13342371582984924 (previous: 0.15173965692520142)


 79%|███████▉  | 12899/16359 [2:56:56<46:39,  1.24it/s]

Step 12900, total_loss: 0.615349292755127


 79%|███████▉  | 12999/16359 [2:58:21<43:43,  1.28it/s]

Step 13000, total_loss: 0.3913733959197998


 80%|████████  | 13099/16359 [2:59:45<42:30,  1.28it/s]

Step 13100, total_loss: 0.37930992245674133


 81%|████████  | 13199/16359 [3:01:09<40:44,  1.29it/s]

Step 13200, total_loss: 0.42350998520851135


 81%|████████▏ | 13299/16359 [3:02:33<39:32,  1.29it/s]

Step 13300, total_loss: 0.47047001123428345


 82%|████████▏ | 13387/16359 [3:03:48<38:24,  1.29it/s]

New best loss: 0.12200698256492615 (previous: 0.13342371582984924)


 82%|████████▏ | 13399/16359 [3:03:57<38:09,  1.29it/s]

Step 13400, total_loss: 0.24855850636959076


 83%|████████▎ | 13499/16359 [3:05:21<37:00,  1.29it/s]

Step 13500, total_loss: 0.2729440927505493


 83%|████████▎ | 13599/16359 [3:06:45<35:46,  1.29it/s]

Step 13600, total_loss: 0.32561129331588745


 84%|████████▎ | 13699/16359 [3:08:09<34:18,  1.29it/s]

Step 13700, total_loss: 0.38716012239456177


 84%|████████▍ | 13799/16359 [3:09:33<33:08,  1.29it/s]

Step 13800, total_loss: 0.4721275866031647


 85%|████████▍ | 13899/16359 [3:10:57<31:56,  1.28it/s]

Step 13900, total_loss: 0.4926617741584778


 86%|████████▌ | 13999/16359 [3:12:21<30:34,  1.29it/s]

Step 14000, total_loss: 0.23482754826545715


 86%|████████▌ | 14036/16359 [3:12:55<32:29,  1.19it/s]

New best loss: 0.11638238281011581 (previous: 0.12200698256492615)


 86%|████████▌ | 14047/16359 [3:13:05<34:04,  1.13it/s]

New best loss: 0.08099118620157242 (previous: 0.11638238281011581)


 86%|████████▌ | 14099/16359 [3:13:47<29:35,  1.27it/s]

Step 14100, total_loss: 0.3238750696182251


 87%|████████▋ | 14199/16359 [3:15:11<28:02,  1.28it/s]

Step 14200, total_loss: 0.36124956607818604


 87%|████████▋ | 14299/16359 [3:16:34<26:40,  1.29it/s]

Step 14300, total_loss: 0.34691521525382996


 88%|████████▊ | 14399/16359 [3:17:59<25:12,  1.30it/s]

Step 14400, total_loss: 0.346305787563324


 89%|████████▊ | 14499/16359 [3:19:22<23:49,  1.30it/s]

Step 14500, total_loss: 0.17176693677902222


 89%|████████▉ | 14599/16359 [3:20:46<22:40,  1.29it/s]

Step 14600, total_loss: 0.13653604686260223


 90%|████████▉ | 14699/16359 [3:22:10<21:11,  1.31it/s]

Step 14700, total_loss: 0.1740751564502716


 90%|█████████ | 14799/16359 [3:23:34<20:09,  1.29it/s]

Step 14800, total_loss: 0.5511655807495117


 91%|█████████ | 14899/16359 [3:24:58<18:53,  1.29it/s]

Step 14900, total_loss: 0.16170796751976013


 92%|█████████▏| 14996/16359 [3:26:20<17:27,  1.30it/s]

New best loss: 0.0809800773859024 (previous: 0.08099118620157242)


 92%|█████████▏| 14999/16359 [3:26:22<17:20,  1.31it/s]

Step 15000, total_loss: 0.18096435070037842


 92%|█████████▏| 15099/16359 [3:27:45<16:44,  1.25it/s]

Step 15100, total_loss: 0.3162330985069275


 93%|█████████▎| 15199/16359 [3:29:15<16:20,  1.18it/s]

Step 15200, total_loss: 0.2895810604095459


 94%|█████████▎| 15299/16359 [3:30:44<14:19,  1.23it/s]

Step 15300, total_loss: 0.2715851068496704


 94%|█████████▍| 15399/16359 [3:32:09<12:13,  1.31it/s]

Step 15400, total_loss: 0.23320619761943817


 95%|█████████▍| 15499/16359 [3:33:33<11:06,  1.29it/s]

Step 15500, total_loss: 0.16817814111709595


 95%|█████████▌| 15599/16359 [3:34:57<09:43,  1.30it/s]

Step 15600, total_loss: 0.36170694231987


 96%|█████████▌| 15694/16359 [3:36:17<08:34,  1.29it/s]

New best loss: 0.08010323345661163 (previous: 0.0809800773859024)


 96%|█████████▌| 15699/16359 [3:36:21<08:30,  1.29it/s]

Step 15700, total_loss: 0.3156900405883789


 96%|█████████▋| 15776/16359 [3:37:27<07:39,  1.27it/s]

New best loss: 0.0668879896402359 (previous: 0.08010323345661163)


 97%|█████████▋| 15799/16359 [3:37:45<07:13,  1.29it/s]

Step 15800, total_loss: 0.1285453736782074


 97%|█████████▋| 15899/16359 [3:39:09<05:56,  1.29it/s]

Step 15900, total_loss: 0.1322653591632843


 98%|█████████▊| 15999/16359 [3:40:35<04:55,  1.22it/s]

Step 16000, total_loss: 0.128218412399292


 98%|█████████▊| 16099/16359 [3:42:00<03:22,  1.28it/s]

Step 16100, total_loss: 0.18470236659049988


 99%|█████████▉| 16199/16359 [3:43:25<02:05,  1.27it/s]

Step 16200, total_loss: 0.26308974623680115


100%|█████████▉| 16299/16359 [3:44:50<00:48,  1.25it/s]

Step 16300, total_loss: 0.15562966465950012


100%|██████████| 16359/16359 [3:45:44<00:00,  1.21it/s]


In [13]:

for name, param in clip_model.named_parameters():
    print(f"{name} norm: {param.norm().item()}")

logit_scale norm: 4.605087757110596
clip_model.positional_embedding norm: 2.0288569927215576
clip_model.text_projection norm: 11.265134811401367
clip_model.visual.class_embedding norm: 5.109840393066406
clip_model.visual.positional_embedding norm: 11.497537612915039
clip_model.visual.proj norm: 12.942944526672363
clip_model.visual.conv1.weight norm: 13.028203010559082
clip_model.visual.ln_pre.weight norm: 21.016136169433594
clip_model.visual.ln_pre.bias norm: 2.6796717643737793
clip_model.visual.transformer.resblocks.0.ln_1.weight norm: 14.776494026184082
clip_model.visual.transformer.resblocks.0.ln_1.bias norm: 3.969780683517456
clip_model.visual.transformer.resblocks.0.attn.in_proj_weight norm: 17.87546157836914
clip_model.visual.transformer.resblocks.0.attn.in_proj_bias norm: 35.37779235839844
clip_model.visual.transformer.resblocks.0.attn.out_proj.weight norm: 12.688957214355469
clip_model.visual.transformer.resblocks.0.attn.out_proj.bias norm: 2.594679355621338
clip_model.visual.t