In [19]:
import clip

In [20]:
# Show which CLIP model names the installed `clip` package exposes (rendered below).
# NOTE(review): CONFIG['clip_type'], used for the feature-cache filename further
# down, presumably holds one of these names — confirm.
clip.available_models()

['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']

In [1]:
import wandb
import os
import torch
import torch.nn as nn
from torch.cuda import amp

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
# The classifier head built below is moved onto this device.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [3]:
# Re-attach to an existing wandb run (fixed id) so that its config drives the
# head's layer sizes below and its saved files can be restored.
WANDB_PROJECT = "clip_cls_9"
WANDB_RUN_ID = "h4vl8o8o"
wandb.init(project=WANDB_PROJECT, id=WANDB_RUN_ID, resume="must")
CONFIG = wandb.config

[34m[1mwandb[0m: Currently logged in as: [33mshivamshrirao[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [61]:
# Load the cached CLIP feature/label tensors for the configured backbone.
# The cache holds CUDA tensors (the label histograms below report device='cuda:0'),
# so map them onto `device`; without map_location, torch.load fails on a
# CPU-only machine even though `device` falls back to "cpu".
# NOTE(review): if CONFIG['clip_type'] contains "/" (e.g. "ViT-B/32"), this path
# points into a subdirectory — it must match however the cache was written.
features_cache = f"{CONFIG['clip_type']}_features.pth"
ft_dict = torch.load(features_cache, map_location=device)
train_features = ft_dict["train_features"]
train_labels = ft_dict["train_labels"]
test_features = ft_dict["test_features"]
test_labels = ft_dict["test_labels"]

In [62]:
# Label histogram for the training split: each distinct label value and its count.
train_labels.unique(return_counts=True)

(tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
        device='cuda:0'),
 tensor([3594, 3524, 3563, 3530, 3582, 3565, 3564, 3557, 3550, 3549, 3554, 3570,
         3555, 3542, 3531, 3579, 3770, 3530, 3762, 3558, 3750, 3560, 3549, 3788,
         3616, 3561, 3561, 3566, 3790, 3556, 3581, 3548, 3580, 3756, 3834, 3759],
        device='cuda:0'))

In [63]:
# Label histogram for the test split: each distinct label value and its count.
test_labels.unique(return_counts=True)

(tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
        device='cuda:0'),
 tensor([189, 210, 163, 205, 171, 171, 180, 183, 193, 191, 180, 193, 182, 199,
         201, 189, 189, 201, 197, 178, 217, 210, 188, 209, 157, 198, 178, 181,
         208, 188, 198, 180, 180, 198, 182, 198], device='cuda:0'))

In [64]:
# Collapse the 36 fine labels (0..35, see the histograms above) into 9 classes:
# 8 direction classes, where class i is centred on i*45 degrees, plus a
# catch-all class 8 for every label not claimed by a direction.
# Keys are built in degrees and then divided by 10, so labels appear to be
# 10-degree angle bins — TODO confirm the binning convention.
OTHER_CLASS = 8

cnv = {}
for direction in range(8):
    centre = direction * 45
    if direction % 2:
        # Diagonals (45, 135, 225, 315): two bins, e.g. 40 and 50 -> bins 4 and 5.
        cnv[centre - 5] = cnv[centre + 5] = direction
    else:
        # Axis-aligned (0, 90, 180, 270): three bins wrapping around 360,
        # e.g. 350, 0, 10 -> bins 35, 0, 1.
        cnv[(centre - 10) % 360] = cnv[centre] = cnv[(centre + 10) % 360] = direction
cnv = {deg // 10: cls for deg, cls in cnv.items()}


def remap_angle_labels(labels, mapping=cnv, other=OTHER_CLASS):
    """Map integer ``labels`` through ``mapping``; unmapped values become ``other``.

    Vectorised, out-of-place equivalent of applying ``mapping.get(label, other)``
    to every element. A small lookup table is built once and indexed on the
    labels' own device. This replaces a Python loop that called ``.item()``
    (one device sync) per element.

    Args:
        labels: integer tensor of any shape; it is not modified.
        mapping: dict from non-negative int label to int class.
        other: class given to labels absent from ``mapping`` (including negatives).

    Returns:
        A new tensor with the same shape, dtype and device as ``labels``.
    """
    table_size = max(max(mapping, default=-1) + 1, 1)
    lut = torch.tensor(
        [mapping.get(k, other) for k in range(table_size)],
        dtype=labels.dtype,
        device=labels.device,
    )
    in_table = (labels >= 0) & (labels < table_size)
    # clamp keeps the gather in bounds; out-of-table positions are replaced by `other`.
    looked_up = lut[labels.clamp(0, table_size - 1).long()]
    return torch.where(in_table, looked_up, torch.full_like(labels, other))


# Rebuild from the untouched tensors in ft_dict instead of mutating train_labels
# in place, so re-running this cell cannot remap already-remapped labels.
train_labels = remap_angle_labels(ft_dict["train_labels"])
test_labels = remap_angle_labels(ft_dict["test_labels"])

In [5]:
# Eight direction classes (0-7) plus the catch-all class 8 used by the label remap.
num_classes = 8 + 1

In [6]:
# Two-layer MLP classification head over the cached CLIP features. The hidden
# size and dropout come from the run config, so the architecture presumably
# matches the checkpoint restored in the next cell.
feature_dim = train_features.shape[1]
hidden_dim = CONFIG["hid_dim"]
cls_head = nn.Sequential(
    nn.Linear(feature_dim, hidden_dim),
    nn.ReLU(),
    nn.Dropout(CONFIG["dropout"]),
    nn.Linear(hidden_dim, num_classes),
)
cls_head = cls_head.to(device).eval()

In [7]:
# Fetch the run's saved head weights via wandb.restore and load them into cls_head.
# map_location=device lets the checkpoint load on a CPU-only machine even if it
# was saved from GPU tensors.
# NOTE(review): torch.load unpickles the file. Only restore checkpoints from runs
# you trust, and pass weights_only=True if the installed torch supports it.
best_weights_path = wandb.restore("best_weights_new.pth").name
cls_head.load_state_dict(torch.load(best_weights_path, map_location=device))

<All keys matched successfully>

In [16]:
# Fixed sample batch: the example input for tracing and the batch used by the
# timing cells below.
SAMPLE_BATCH = 256
data = train_features[:SAMPLE_BATCH]
print(data.shape)

torch.Size([256, 512])


In [9]:
class Wrapped_linear_model(torch.nn.Module):
    """Inference wrapper that turns a classifier's logits into class probabilities.

    The wrapped ``model`` is called on the input, and a softmax over dim 1 is
    applied, so each row of the output sums to 1. The forward pass runs under
    ``torch.inference_mode()``, so no autograd graph is recorded.

    Args:
        model: module producing logits; assumed to return (batch, classes),
            since the softmax is taken over dim 1 — TODO confirm for other heads.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    @torch.inference_mode()
    def forward(self, data, fp16=True):
        """Return softmax probabilities for ``data``.

        Args:
            data: input batch, passed straight to the wrapped model.
            fp16: enable CUDA autocast (mixed precision) for the forward pass.
                Only CUDA ops are affected.

        Returns:
            Probability tensor with the same shape as the model's output.
        """
        # torch.autocast("cuda", ...) is the supported spelling of the deprecated
        # torch.cuda.amp.autocast(...). Autocast runs softmax in float32, so the
        # returned probabilities are float32 even when fp16 is enabled.
        with torch.autocast(device_type="cuda", enabled=fp16):
            logits = self.model(data)
            probs = torch.softmax(logits, dim=1)
        return probs
# Probability-emitting wrapper around the restored head; this is what gets traced.
wrp_model = Wrapped_linear_model(model=cls_head)

In [10]:
%%time
torch.cuda.synchronize()
with torch.no_grad():
    svd_out = wrp_model(data, True)
torch.cuda.synchronize()
print(svd_out.shape, svd_out.dtype)

torch.Size([256, 9]) torch.float32
CPU times: user 4.29 ms, sys: 824 µs, total: 5.11 ms
Wall time: 4.11 ms


In [17]:
# Trace the wrapper into TorchScript with the sample batch as the example input,
# then freeze/optimize the graph for inference and show the generated code.
# NOTE(review): the printed graph contains no fp16 casts, so the exported module
# appears to run the head in float32 regardless of the `fp16` flag — confirm
# this is intended.
with torch.inference_mode():
    with torch.jit.optimized_execution(True):
        traced_graph = torch.jit.trace(wrp_model, data)
        traced_script_module = torch.jit.optimize_for_inference(traced_graph)
print(traced_script_module.code)

def forward(self,
    data: Tensor) -> Tensor:
  _0 = torch.add(torch.matmul(data, CONSTANTS.c0), CONSTANTS.c1)
  input = torch.relu(_0)
  _1 = torch.add(torch.matmul(input, CONSTANTS.c2), CONSTANTS.c3)
  return torch.softmax(_1, 1, 6)



In [18]:
# Export the optimized TorchScript module, then reload it so the timing cell
# below exercises the deserialized artifact rather than the in-memory trace.
# NOTE(review): the "<name>/<version>/model.pt" layout looks like a model-serving
# repository entry (e.g. Triton) — confirm before renaming.
OUT_PATH = "car_angle_classifier_9/1/"
os.makedirs(OUT_PATH, exist_ok=True)

# os.path.join avoids the "1//model.pt" double slash produced by the f-string.
model_file = os.path.join(OUT_PATH, "model.pt")
traced_script_module.save(model_file)
traced_script_module = torch.jit.load(model_file)

In [15]:
%%time
torch.cuda.synchronize()
with torch.no_grad():
    o = traced_script_module(data)
torch.cuda.synchronize()
print(o.shape, o.dtype)

torch.Size([256, 9]) torch.float32
CPU times: user 1.15 ms, sys: 538 µs, total: 1.69 ms
Wall time: 743 µs
