In [1]:
import torch
from pathlib import Path
from model.Img2Vec import Img2Vec
from model.ResNet import ResNet
from model.focalnet import FocalNetBackbone

In [5]:
# Checkpoint (.ckpt) of the trained Img2Vec run whose class embeddings are exported.
ckpt_path = "/home/rohitn/D1/Hindi_embed_logs/Img2Vec_FocalNet_wo_B_r256/147196/checkpoints/epoch=7-val_loss=0.41-val_acc=0.77.ckpt"

# File name for the exported embeddings, and the GroupNet folder it is written to.
name = "focalnet_wo_B_r256.pth"
embeddings_dir = Path("/home/rohitn/IndianSTR/Img2Vec/GroupNet/embeddings/new")
save_path = embeddings_dir.joinpath(name)

### Define Backbone
Refer to the checkpoint's Hydra config and initialize the backbone with exactly the same architecture and settings used during training.


In [3]:
# Select the backbone that matches the checkpoint's hydra config. The current
# ckpt_path is a FocalNet run (per its directory name), so "focalnet" is the
# default. Previously both backbones were built in separate cells and the
# ResNet one was silently overwritten on Restart & Run All.
BACKBONE_NAME = "focalnet"  # one of: "resnet", "focalnet"

if BACKBONE_NAME == "resnet":
    backbone = ResNet(
        version=50,
        img_size=128,
        out_features=1024,
    )
elif BACKBONE_NAME == "focalnet":
    # Imported here so the ResNet path does not require `transformers`.
    import transformers

    backbone = FocalNetBackbone(
        config=transformers.FocalNetConfig(
            image_size=128,
            patch_size=4,
            num_channels=3,
            embed_dim=96,
            use_conv_embed=True,
            hidden_sizes=[96, 192, 384, 768],
            depths=[2, 2, 18, 2],
            focal_levels=[2, 2, 2, 2],
            focal_windows=[3, 3, 3, 3],
            hidden_act="gelu",
            mlp_ratio=4.0,
            hidden_dropout_prob=0.2,
            drop_path_rate=0.1,
            # NOTE(review): HF documents `use_layerscale` as a bool, with the scale
            # itself in `layerscale_value` (default 1e-4). 1.0e-4 is truthy, so
            # layerscale is enabled. Kept as-is to mirror the training config.
            use_layerscale=1.0e-4,
            use_post_layernorm=False,
            use_post_layernorm_in_modulation=False,
            normalize_modulator=False,
            initializer_range=2.0e-2,
            layer_norm_eps=1.0e-5,
            encoder_stride=32,
            out_features=None,
        ),
        out_features=768,
    )
else:
    raise ValueError(f"Unknown BACKBONE_NAME: {BACKBONE_NAME!r}")

### Initialize Img2Vec model

In [6]:
# Rebuild Img2Vec from the checkpoint, passing in the backbone built above.
# Its architecture must match the one the checkpoint was trained with.
# (A duplicated copy of this cell, which loaded the checkpoint twice, was removed.)
model = Img2Vec.load_from_checkpoint(ckpt_path, backbone=backbone)

### Get the Embeddings

In [7]:
# The weight matrix of each classification head is exported as that head's
# class embeddings. Recorded shapes: half-character heads 35x256,
# character head 70x256, diacritic head 16x256.
# `.detach()` replaces the discouraged `.data` accessor. `.cpu()` keeps the
# saved file loadable even if the model was restored onto a GPU.
h_c_2_emb = model.half_character2_head.weight.detach().cpu()
h_c_1_emb = model.half_character1_head.weight.detach().cpu()
f_c_emb = model.character_head.weight.detach().cpu()
d_emb = model.diacritic_head.weight.detach().cpu()

In [8]:
# Shapes of the four embedding matrices, in the order they are saved.
tuple(emb.shape for emb in (h_c_2_emb, h_c_1_emb, f_c_emb, d_emb))

(torch.Size([35, 256]),
 torch.Size([35, 256]),
 torch.Size([70, 256]),
 torch.Size([16, 256]))

In [10]:
# Check that embedding rows line up with the class lists saved next to them
# (presumably row i of a head corresponds to class i; confirm in Img2Vec).
# NOTE(review): the half-character heads have 35 rows but there are 34
# half-character classes in the recorded output. This is presumably one
# extra "no half character" slot, so it is not asserted here. TODO confirm.
assert f_c_emb.shape[0] == len(model.character_classes), "character head rows != character classes"
assert d_emb.shape[0] == len(model.diacritic_classes), "diacritic head rows != diacritic classes"
assert h_c_1_emb.shape == h_c_2_emb.shape, "half-character heads differ in shape"
len(model.half_character_classes), len(model.character_classes), len(model.diacritic_classes)

(34, 70, 16)

In [11]:
# Save the embeddings together with their class lists into GroupNet's
# embeddings folder. torch.save does not create missing directories,
# so make sure the target folder exists first.
save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(
    {
        "h_c_2_emb": h_c_2_emb,
        "h_c_1_emb": h_c_1_emb,
        "f_c_emb": f_c_emb,
        "d_emb": d_emb,
        "h_c_classes": model.half_character_classes,
        "f_c_classes": model.character_classes,
        "d_classes": model.diacritic_classes,
    },
    save_path,
)

### Sanity Check

In [12]:
# Reload the exported file to verify it round-trips. map_location="cpu" lets
# this check run on a machine that lacks the device the tensors were saved from.
# NOTE(review): newer torch defaults to weights_only=True. That is fine for
# tensors and plain lists of str, but would fail if the class lists are a
# custom type (e.g. a config container). Confirm if loading errors appear.
loaded_dict = torch.load(save_path, map_location="cpu")

In [13]:
# The reloaded file must contain exactly the keys written by the save cell.
expected_keys = {
    "h_c_2_emb", "h_c_1_emb", "f_c_emb", "d_emb",
    "h_c_classes", "f_c_classes", "d_classes",
}
assert set(loaded_dict) == expected_keys, f"unexpected keys: {sorted(loaded_dict)}"
loaded_dict.keys()

dict_keys(['h_c_2_emb', 'h_c_1_emb', 'f_c_emb', 'd_emb', 'h_c_classes', 'f_c_classes', 'd_classes'])

In [14]:
# Expected half-character classes, in the order they should be stored.
expected_half_chars = [
    'क', 'ख', 'ग', 'घ', 'ङ',
    'च', 'छ', 'ज', 'झ', 'ञ',
    'ट', 'ठ', 'ड', 'ढ', 'ण',
    'त', 'थ', 'द', 'ध', 'न',
    'प', 'फ', 'ब', 'भ', 'म',
    'य', 'र', 'ल', 'ळ', 'व', 'श',
    'ष', 'स', 'ह',
]
loaded_dict["h_c_classes"] == expected_half_chars

True

In [15]:
# Number of half-character classes stored in the exported file.
half_char_classes = loaded_dict["h_c_classes"]
len(half_char_classes)

34