In [1]:
import math
import os
import warnings

import numpy as np
import open_clip
import torch
import torch.nn as nn
import torch.nn.functional as F

import utilities


In [2]:
class Head(nn.Module):
    """Projection head: encoder features -> L2-normalised embeddings (+ optional ArcFace output).

    Args:
        hidden_size: feature size produced by the encoder (input size of ``emb``).
        slim: if True, only the embedding projection is built (no ``arc`` layer) and
            ``forward`` returns just the normalised embeddings.
        emb_size: embedding size. Defaults to the notebook-global ``CFG.emb_size``,
            read at construction time.
        n_classes: class count for the sub-center ArcFace layer (non-slim only).
            Defaults to the notebook-global ``CFG.n_classes``.

    The global ``CFG`` is redefined by later cells with a different ``n_classes``,
    so pass ``emb_size``/``n_classes`` explicitly when the value matters.
    """

    def __init__(self, hidden_size, slim, emb_size=None, n_classes=None):
        super().__init__()
        if emb_size is None:
            emb_size = CFG.emb_size
        self.emb = nn.Linear(hidden_size, emb_size, bias=False)
        self.slim = slim
        if not slim:
            # Read the global lazily, exactly as before: slim heads never touch n_classes.
            if n_classes is None:
                n_classes = CFG.n_classes
            self.arc = utilities.ArcMarginProduct_subcenter(emb_size, n_classes)

        # Applied as dropout(x, emb) in forward; semantics live in utilities
        # (presumably averages `emb` over several dropout masks — TODO confirm).
        self.dropout = utilities.Multisample_Dropout()

    def forward(self, x):
        """Project ``x`` to embeddings and optionally score them with the ArcFace layer.

        ``x`` is presumably (batch, hidden_size) — TODO confirm; ``F.normalize`` uses its
        default ``dim=1``.

        Returns:
            ``F.normalize(embeddings)`` when slim, otherwise
            ``(arc(embeddings), F.normalize(embeddings))``. The ArcFace layer sees the
            un-normalised embeddings.
        """
        embeddings = self.dropout(x, self.emb)
        if not self.slim:
            output = self.arc(embeddings)
            return output, F.normalize(embeddings)
        return F.normalize(embeddings)
class Model(nn.Module):
    """open_clip image encoder followed by a projection ``Head``.

    Args:
        vit_backbone: open_clip model; only its ``visual`` tower is used.
        head_size: feature size produced by the encoder (``Head``'s hidden_size).
        version: ``'v1'`` uses ``vit_backbone.visual`` directly; ``'v2'`` uses
            ``vit_backbone.visual.trunk`` (presumably the timm network wrapped by
            open_clip, e.g. ConvNeXt — TODO confirm for new backbones).
        slim: forwarded to ``Head``; when True no ArcFace classifier is built.

    Raises:
        ValueError: if ``version`` is neither ``'v1'`` nor ``'v2'``, instead of building
            a model without ``encoder`` that only fails later inside ``forward``.

    Note: ``encoder`` is the backbone's own submodule, not a copy, so every ``Model``
    built from the same backbone shares (and, on load, overwrites) encoder weights.
    """

    def __init__(self, vit_backbone, head_size, version='v1', slim=False):
        if version not in ('v1', 'v2'):
            raise ValueError(f"unknown version {version!r}; expected 'v1' or 'v2'")
        super().__init__()
        visual = vit_backbone.visual
        self.encoder = visual if version == 'v1' else visual.trunk
        self.head = Head(head_size, slim)

    def forward(self, x):
        """Encode ``x`` and pass the features through the head.

        Returns whatever ``Head.forward`` returns: normalised embeddings when slim,
        otherwise ``(arc_output, normalised_embeddings)``.
        """
        features = self.encoder(x)
        return self.head(features)


In [4]:
# Slim checkpoints to average into the soup, listed as
# (experiment run directory, checkpoint file) pairs under my_experiments/.
_SOUP_INGREDIENTS = [
    ('ViT-H-14-laion2b_s32b_b79k-cut_out-product-10k', 'model_best_epoch_2_mAP3_0.55_slim.pt'),
    ('ViT-H-14-laion2b_s32b_b79k-happy_whale-product-10k', 'model_best_epoch_3_mAP3_0.54_slim.pt'),
    ('ViT-H-14-laion2b_s32b_b79k-None-product-10k', 'model_best_epoch_2_mAP3_0.53_slim.pt'),
    ('vit_h_224_products-10k', 'model_best_epoch_3_mAP3_0.53_slim.pt'),
]
path_list = [f'my_experiments/{run}/{checkpoint}' for run, checkpoint in _SOUP_INGREDIENTS]

class CFG:
    # Settings for the model-soup cell. Head reads emb_size (and, for non-slim
    # heads, n_classes) from this global class when a Model is constructed.
    model_name = 'ViT-H-14'  # architecture name passed to open_clip
    version = 'v1'           # Model: 'v1' -> backbone.visual
    hidden_layer = 1024      # encoder feature size, passed to Head as hidden_size
    emb_size = 512           # output size of Head.emb
    n_classes = 9004         # only used by non-slim heads (ArcFace layer)


backbone, _, _ = open_clip.create_model_and_transforms(CFG.model_name, None)

# Uniformly average ("model soup") the weights of every checkpoint in path_list.
# Every Model built from `backbone` shares the same `backbone.visual` module,
# and `state_dict()` returns tensors that alias the live parameters. Storing
# raw state_dicts would leave every encoder entry holding the *last*
# checkpoint's weights, so the backbone would never be averaged. Accumulate
# detached clones as a running sum instead; this also keeps only one extra
# copy of the weights in memory.
assert path_list, 'path_list is empty - nothing to average'

soup_sum = None
for path in path_list:
    model = Model(backbone, CFG.hidden_layer, CFG.version, True)
    # The model lives on CPU; map_location lets GPU-saved checkpoints load anywhere.
    load_result = model.load_state_dict(torch.load(path, map_location='cpu'), strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        # strict=False is kept, but skipped keys would silently blend in weights
        # left over from a previous checkpoint or from random initialisation.
        warnings.warn(
            f'{path}: missing keys {load_result.missing_keys}, '
            f'unexpected keys {load_result.unexpected_keys}'
        )
    weights = model.state_dict()
    if soup_sum is None:
        soup_sum = {k: v.detach().clone() for k, v in weights.items()}
    else:
        for k, v in weights.items():
            # Non-float tensors cannot be averaged; keep the first checkpoint's value.
            if v.is_floating_point():
                soup_sum[k] += v

n_ingredients = len(path_list)
state_dict = {
    k: (v / n_ingredients if v.is_floating_point() else v)
    for k, v in soup_sum.items()
}
model.load_state_dict(state_dict)

torch.save(model.state_dict(), f'my_experiments/{CFG.model_name}-soup.pt')

In [3]:
def slim_model(model_path, CFG):
    """Strip the ArcFace classifier from a training checkpoint and save a slim copy.

    Loads ``checkpoint['model_state_dict']`` into a full (non-slim) ``Model``,
    moves its encoder and embedding projection into a slim ``Model``, and saves
    that state dict next to the input as ``<model_path without extension>_slim.pt``.

    Args:
        model_path: path to a checkpoint dict with a ``'model_state_dict'`` entry.
        CFG: config providing ``model_name``, ``hidden_layer`` and ``version``.
            ``Head`` still reads ``emb_size``/``n_classes`` from the notebook-global
            ``CFG``, not from this argument, so both must agree with the checkpoint.
    """
    name = os.path.splitext(model_path)[0]

    # The freshly created model lives on CPU; map_location lets checkpoints saved
    # during GPU training load on machines without CUDA.
    checkpoint = torch.load(model_path, map_location='cpu')
    backbone, _, _ = open_clip.create_model_and_transforms(CFG.model_name, None)

    model = Model(backbone, CFG.hidden_layer, CFG.version)
    model.load_state_dict(checkpoint['model_state_dict'])
    del checkpoint  # only the loaded model weights are needed from here on

    model_slim = Model(backbone, CFG.hidden_layer, CFG.version, True)
    model_slim.head.emb = model.head.emb
    # Both Models already wrap the same backbone module; the explicit assignment
    # keeps this correct even if Model ever stops sharing it.
    model_slim.encoder = model.encoder

    torch.save(model_slim.state_dict(), name + '_slim.pt')
    
class CFG:
    # Settings for slimming the '-all' ViT-H-14 run. This redefines the global
    # CFG; Head reads emb_size (and, for non-slim heads, n_classes) from it.
    model_name = 'ViT-H-14'  # architecture name passed to open_clip
    version = 'v1'           # Model: 'v1' -> backbone.visual
    hidden_layer = 1024      # encoder feature size, passed to Head as hidden_size
    emb_size = 512           # output size of Head.emb
    n_classes = 9691         # ArcFace class count of the full (non-slim) model
    

# Produce a slim copy of each saved epoch of the '-all' run.
for epoch in (1, 2, 3):
    slim_model(
        'my_experiments/ViT-H-14-laion2b_s32b_b79k-image_net-product-10k-all/'
        f'model_epoch_{epoch}_mAP5_0.55.pt',
        CFG(),
    )


In [1]:

    
# Slim the ConvNeXt-Base-W run. The backbone must be built from the checkpoint's own
# architecture ('convnext_base_w', per the experiment directory name). The global
# CFG.model_name is 'ViT-H-14' at this point, and that backbone cannot load these
# weights, nor (as far as open_clip's ViT tower goes) provide the `.trunk` that
# version='v2' needs.
# NOTE(review): Head still takes emb_size/n_classes from the global CFG
# (n_classes=9691 here) - confirm these match what this run was trained with.
class ConvNextCFG:
    model_name = 'convnext_base_w'  # open_clip architecture name
    hidden_layer = 1024             # encoder feature size (as previously hard-coded)
    version = 'v2'                  # Model: 'v2' -> backbone.visual.trunk


# Writes .../model_best_epoch_3_mAP3_0.47_slim.pt next to the input checkpoint.
slim_model(
    'my_experiments/convnext_base_w-laion2b_s13b_b82k_augreg-image_net-v2-product-10k/'
    'model_best_epoch_3_mAP3_0.47.pt',
    ConvNextCFG(),
)