In [1]:
import utilities
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import open_clip

In [2]:
class CFG:
    """Notebook-wide settings, read as plain class attributes (never instantiated)."""

    # Architecture name handed to open_clip; also reused to name the saved soup file.
    model_name: str = 'ViT-L-14-336'

    # Pretrained-weights tag.
    # NOTE(review): not referenced by any cell visible here - confirm it is still needed.
    model_data: str = 'openai'

    # Output width of the linear projection built in Head.
    emb_size: int = 512

In [3]:
# Build the architecture named in CFG. No `pretrained` tag is passed here, so the
# weights that matter are expected to come from the checkpoints loaded further down.
# Per the open_clip docs the call returns (model, train preprocessing, eval
# preprocessing); the third element is discarded.
vit_backbone, model_transforms, _ = open_clip.create_model_and_transforms(
    model_name=CFG.model_name,
)

In [None]:
class Head(nn.Module):
    """Projection head that maps backbone features to `CFG.emb_size`-wide embeddings.

    Attributes:
        emb: bias-free linear layer, `hidden_size -> CFG.emb_size`.
        arc: optional callable applied to the embeddings. It is `None` by default;
            presumably an ArcFace-style margin layer is attached for training -
            TODO confirm against the training code.
        dropout: `utilities.Multisample_Dropout` instance, called as
            `dropout(x, layer)`. Its exact semantics live in `utilities`.
    """

    def __init__(self, hidden_size):
        """Create the head.

        Args:
            hidden_size: number of input features fed to the projection layer.
        """
        super(Head, self).__init__()

        self.emb = nn.Linear(hidden_size, CFG.emb_size, bias=False)
        # No arc layer is built here; forward() copes with it being absent.
        self.arc = None
        self.dropout = utilities.Multisample_Dropout()

    def forward(self, x):
        """Project `x` to embeddings and, when available, run the arc layer.

        Args:
            x: input features accepted by `self.emb`.

        Returns:
            `(output, embeddings)`. `output` is `self.arc(embeddings)` when an arc
            layer is attached, otherwise `None`.
        """
        embeddings = self.dropout(x, self.emb)

        # Previously this called `self.arc` unconditionally, which raised
        # TypeError because __init__ leaves it set to None.
        output = self.arc(embeddings) if self.arc is not None else None

        return output, embeddings

In [None]:
class Model(nn.Module):
    """Container pairing a backbone with a `Head`.

    No `forward` is defined here; in this notebook the class is only used to
    load, hold and export state dicts under the `vit_backbone.*` / `head.*` keys.
    """

    def __init__(self, vit_backbone, hidden_size=768):
        """Wrap `vit_backbone` and attach a fresh `Head`.

        Args:
            vit_backbone: module stored as `self.vit_backbone`. It is kept by
                reference, not copied, so several `Model` instances built from
                the same object share its parameters.
            hidden_size: input width of the head's projection layer. Defaults to
                768, the value previously hard-coded; pass a different value when
                the backbone's feature width differs.
        """
        super().__init__()

        self.vit_backbone = vit_backbone

        self.head = Head(hidden_size)

In [None]:
# Checkpoints whose weights are averaged into a single "soup" model.
path_list = [
    '../models/soup-v1/ViT-L-14-336',
    '../models/soup-v2/ViT-L-14-336',
    '../models/soup-v3/ViT-L-14-336',
    '../models/soup-v4/ViT-L-14-336',
]

# Load models weights
weight_list = []

for path in path_list:
    model = Model(vit_backbone)

    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    # map_location='cpu' lets checkpoints saved on a GPU load on a CPU-only machine.
    checkpoint = torch.load(path, map_location='cpu')

    # strict=False tolerates key mismatches between the checkpoint and this Model.
    # Unexpected keys are harmless, but missing keys would stay at their current
    # values and silently dilute the average, so surface them.
    load_result = model.load_state_dict(checkpoint, strict=False)
    if load_result.missing_keys:
        print(f'WARNING: {path} does not provide {len(load_result.missing_keys)} '
              f'model keys (first: {load_result.missing_keys[0]})')

    # Every Model wraps the same `vit_backbone` object, state_dict() returns
    # tensors that share storage with the live parameters, and load_state_dict
    # copies in place. Without cloning, each snapshot would end up holding the
    # weights of the last checkpoint loaded.
    weight_list.append({k: v.detach().clone() for k, v in model.state_dict().items()})

# Average weights
state_dict = {}
for k in weight_list[0]:
    stacked = torch.stack([w[k] for w in weight_list])
    if stacked.is_floating_point() or stacked.is_complex():
        state_dict[k] = stacked.mean(0)
    else:
        # mean() is undefined for integer/bool tensors; keep the first model's value.
        state_dict[k] = stacked[0]
model.load_state_dict(state_dict)


In [None]:
# File-name-safe variant of the configured architecture name: each '/' becomes '-'.
model_name = '-'.join(CFG.model_name.split('/'))

# Write the state dict currently held by `model` next to the source checkpoints.
torch.save(obj=model.state_dict(), f=f'../models/{model_name}-soup')