In [142]:
from pandas import read_csv

from datasets.severstal_coco import DatasetCOCO
from torchvision import transforms
from models.clipseg import CLIPDensePredT
import torch
from PIL import Image
import numpy as np
import pandas as pd
import os
from tqdm import tqdm

In [143]:
# ---------------------------------------------------------------------------
# Configuration: input locations, preprocessing, class prompts and the model.
# ---------------------------------------------------------------------------

# Folder with the training sub-images and the checkpoint to load.
data_path = '/home/eas/Enol/pycharm_projects/clipseg_steel_defect/Severstal/train_subimages'
weights = '/home/eas/Enol/pycharm_projects/clipseg_steel_defect/logs/rd64-7K-vit16-cbh-coco-enol-5classes_no_neg/weights.pth'

# Preprocessing: square resize -> tensor -> per-channel normalisation.
# The same statistic is used for all three channels.
image_size = 256
mean = [0.34388125] * 3
std = [0.13965334] * 3
preprocessing_steps = [
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
]
transform = transforms.Compose(preprocessing_steps)

# Text description for each class id (not referenced by the cells shown here).
COCO_CLASSES = {
    1: 'network of fine, hairline cracks or fissures on the surface of the steel',
    2: 'scale—oxides embedded into the steel plate',
    3: 'shallow, narrow grooves or lines on the surface of the steel',
    4: 'impurity or foreign material embedded within the steel matrix',
    5: 'defects on a steel plate',
}

# Build the model, then load the checkpoint onto the CPU first.
# strict=False: keys that are missing from / unexpected in the checkpoint are
# tolerated instead of raising.
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
checkpoint = torch.load(weights, weights_only=True, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint, strict=False)

# Move to the GPU and switch to evaluation mode for inference.
model.cuda()
model.eval()

# Dataset split handed to the dataset wrapper.
split = 'train'

In [144]:
# Empty results table: one row per annotation, holding its class id and the
# visual embedding extracted for it.
result_columns = ['class_id', 'embedding']
df = pd.DataFrame(columns=result_columns)

In [145]:
# Build the project's COCO-style dataset wrapper for the configured split,
# passing it the image folder and the preprocessing pipeline defined above.
# NOTE(review): the meaning of the 4th positional argument (True) is not
# visible from this notebook — confirm against DatasetCOCO's signature.
coco_dataset = DatasetCOCO(data_path, transform, split, True)

loading annotations into memory...
Done (t=1.04s)
creating index...
index created!


In [146]:
# Grab the two lookups used by the extraction loop:
#   class_ids - iterated by class id; class_ids[class_id] yields annotation ids
#   metadata  - annotation index queried through metadata.loadAnns(ids=...)
class_ids, metadata = coco_dataset.ids_by_class, coco_dataset.img_metadata

In [147]:
# Compute one visual embedding per annotation, class by class, and collect the
# results in `df` (columns: class_id, embedding).
#
# Rows are buffered in a plain Python list and converted to a DataFrame once,
# after the loop. Enlarging a DataFrame with `df.loc[idx] = ...` makes pandas
# re-allocate its storage on every insert, which gets slower as the table
# grows (~62k rows here).
embedding_rows = []

# Inference only: make sure no autograd state is recorded for the forward
# passes, regardless of what visual_forward does internally.
with torch.no_grad():
    for class_id in class_ids:
        for ann in tqdm(class_ids[class_id], desc=f'class {class_id}'):
            # loadAnns returns a list; the single requested annotation is item 0.
            query = metadata.loadAnns(ids=ann)[0]
            image_path = os.path.join(data_path, query['image_id'])

            # Image.open is lazy and keeps the file open; the context manager
            # releases the handle as soon as the pixels are turned into a tensor.
            with Image.open(image_path) as image:
                q = transform(image).unsqueeze(0)  # add the batch dimension

            # Only the first of the three returned values is stored.
            visual_q, _, _ = model.visual_forward(q.cuda())
            embedding_rows.append(
                (class_id, visual_q.squeeze(0).cpu().numpy().tolist())
            )

# Single construction; the default 0..n-1 index matches the previous idx counter.
df = pd.DataFrame(embedding_rows, columns=['class_id', 'embedding'])


100%|██████████| 1296/1296 [00:18<00:00, 69.37it/s]
100%|██████████| 251/251 [00:03<00:00, 69.25it/s]
100%|██████████| 9846/9846 [02:24<00:00, 67.99it/s]
100%|██████████| 1596/1596 [00:24<00:00, 66.01it/s]
100%|██████████| 48985/48985 [13:30<00:00, 60.47it/s]


In [148]:
# Persist the embeddings table. pandas' defaults are kept, so the row index is
# written as the first (unnamed) column and each embedding list is stored as
# its string representation.
embeddings_csv = '../Severstal/train_embeddings.csv'
df.to_csv(embeddings_csv)

In [149]:
# Preview the first five rows of the embeddings table.
df.head(5)

Unnamed: 0,class_id,embedding
0,1,"[-0.38728979229927063, 0.1325313150882721, 0.10961241275072098, -0.0776892900466919, -0.2376573383808136, 0.396444708108902, 0.4293062090873718, -0.7315817475318909, 0.4360772669315338, -0.21442754566669464, 0.38084766268730164, -0.058034516870975494, -0.16634228825569153, -0.09514516592025757, -0.035816311836242676, -0.07400727272033691, -0.08032824844121933, -0.19801253080368042, -0.10912565886974335, 0.031279195100069046, 0.3076708912849426, 0.07886675000190735, 0.24258442223072052, -0.4971737265586853, -0.10252947360277176, 0.2857493758201599, 0.22882609069347382, 0.2271960973739624, 0.32311880588531494, 0.15770292282104492, 0.48655959963798523, -0.3500517010688782, -0.1687813401222229, -0.0450945645570755, -0.008440330624580383, 0.059139810502529144, 0.0031319819390773773, -0.08210553228855133, 0.39164549112319946, -0.4161904454231262, 0.23444592952728271, -0.1493268609046936, -0.06440040469169617, -0.515926718711853, -0.24705302715301514, 0.0696784183382988, -0.032016247510910034, 0.05659513175487518, -0.0454588308930397, 0.02564576268196106, -1.0520837306976318, 0.27455171942710876, -0.21860134601593018, -0.03460659459233284, -0.41712895035743713, -0.20225539803504944, 0.08583514392375946, 0.10130418837070465, -0.29573720693588257, 0.38771989941596985, -0.19859738647937775, 0.2832903563976288, 0.0646592304110527, 0.2973759174346924, 0.01853802800178528, -0.00280144065618515, -0.19019627571105957, -0.042342446744441986, 0.060694217681884766, 0.023326363414525986, 0.2721039056777954, 0.21078523993492126, 0.21950533986091614, 0.29747268557548523, 0.09350097179412842, 0.01789318025112152, 0.01065075397491455, 0.05529436469078064, 0.04901812970638275, -0.24137505888938904, 0.020862117409706116, 1.1850284337997437, -0.38424891233444214, 0.05827374383807182, -0.4147138297557831, 0.04594685882329941, 0.019365280866622925, 0.3230944573879242, -0.11523213982582092, 0.4022407829761505, -0.29988038539886475, 0.005788251757621765, 0.44290459156036377, 
-0.4417787194252014, 0.04298654571175575, -0.05820927023887634, 0.4811069369316101, 0.06407156586647034, 0.10631725937128067, 0.04234514757990837, ...]"
1,1,"[-0.2801094055175781, 0.15718357264995575, -0.04015563800930977, -0.09137618541717529, -0.458618700504303, 0.29957664012908936, 0.2690992057323456, -0.6623987555503845, 0.46521905064582825, 0.15326252579689026, 0.40036439895629883, -0.26047730445861816, -0.10022713243961334, 0.04868073761463165, 0.059188250452280045, 0.016262397170066833, -0.5134696960449219, 0.00570862740278244, -0.04755609482526779, 0.0433519184589386, 0.22405296564102173, 0.21138912439346313, 0.1691884696483612, -0.10479582846164703, 0.12893274426460266, 0.5140072107315063, 0.1602390557527542, 0.25570589303970337, 0.24100430309772491, 0.23052600026130676, 0.1857987642288208, -0.14518029987812042, 0.23109591007232666, 0.02023504301905632, 0.12283474206924438, -0.1041703149676323, -0.17244169116020203, -0.03495006263256073, 0.18781721591949463, -0.3569689393043518, 0.20010986924171448, -0.21559344232082367, -0.04534490406513214, -0.2725090980529785, -0.23522822558879852, -0.08121226727962494, 0.039956748485565186, -0.17639410495758057, -0.21518222987651825, 0.10395504534244537, -0.7462994456291199, 0.22589901089668274, -0.32542985677719116, 0.1859058141708374, -0.2912631034851074, -0.3586724102497101, 0.08585558086633682, 0.026710210368037224, -0.3648502826690674, 0.17597728967666626, -0.1491468846797943, 0.25386619567871094, 0.08484192937612534, 0.26596954464912415, -0.126654714345932, 0.3657137453556061, -0.01639195904135704, -0.3195495009422302, 0.11555413156747818, -0.14012359082698822, 0.02452421933412552, 0.2973252236843109, -0.10456810891628265, 0.22481457889080048, -0.1663247048854828, 0.2770361304283142, -0.15046796202659607, 0.3947797417640686, 0.08398444950580597, -0.2013736367225647, 0.00796101987361908, 1.3161910772323608, -0.5672309398651123, -0.15336467325687408, -0.137942835688591, 0.1709488034248352, 0.43361765146255493, 0.15788224339485168, -0.14363700151443481, 0.3549407720565796, -0.15281563997268677, -0.18558694422245026, 0.2837427854537964, -0.267708957195282, 
-0.053555555641651154, -0.26588520407676697, 0.8909883499145508, 0.2520841956138611, -0.019264470785856247, -0.005950059741735458, ...]"
2,1,"[-0.3826413154602051, 0.19792422652244568, -0.015792451798915863, -0.08102995902299881, -0.2054748833179474, 0.1281183660030365, 0.14337828755378723, -0.757550835609436, 0.166789248585701, 0.013629473745822906, 0.38061296939849854, -0.06911234557628632, 0.009134829044342041, -0.03420183062553406, 0.10616698116064072, -0.04714302346110344, -0.41679635643959045, 0.13644759356975555, -0.03780820220708847, -0.25945281982421875, 0.0925293117761612, 0.15511386096477509, 0.03263287991285324, -0.11344002187252045, 0.08718059211969376, 0.4830605983734131, 0.025937557220458984, 0.17101450264453888, 0.2598717510700226, 0.40909138321876526, 0.31641846895217896, -0.3739442825317383, 0.04813726991415024, -0.26016756892204285, -0.1012931764125824, -0.19630461931228638, -0.16139961779117584, -0.10216903686523438, -0.12341705709695816, -0.4359232187271118, 0.1934213936328888, -0.15702378749847412, 0.2110065370798111, -0.15277066826820374, -0.14771755039691925, -0.12442879378795624, -0.198712557554245, -0.004791121929883957, -0.3633420169353485, 0.19739015400409698, -0.8776089549064636, 0.1173323318362236, -0.39160847663879395, 0.4435184597969055, -0.3400459587574005, -0.40590980648994446, 0.1222209706902504, -0.005894474685192108, -0.1256919801235199, 0.41026583313941956, -0.2542388141155243, 0.04133158177137375, -0.05463869124650955, 0.07671163231134415, -0.20728012919425964, 0.22208088636398315, 0.157108873128891, -0.30035436153411865, 0.18219202756881714, -0.2852945029735565, 0.09975612163543701, 0.41688358783721924, 0.16289372742176056, 0.356706440448761, -0.1674250066280365, 0.25994086265563965, -0.06548383831977844, 0.4162321388721466, -0.004793688654899597, -0.20832321047782898, 0.07054979354143143, 1.4425147771835327, -0.49129945039749146, -0.2341916561126709, -0.11153580248355865, 0.2685663402080536, 0.34025269746780396, 0.212144136428833, -0.15651459991931915, 0.20628207921981812, 0.07791882008314133, -0.19316071271896362, 0.32160553336143494, -0.21518990397453308, 
-0.155859112739563, -0.40599364042282104, 1.0284093618392944, 0.18012401461601257, -0.023329565301537514, 0.0575430691242218, ...]"
3,1,"[-0.5552012324333191, 0.25537413358688354, -0.05433167517185211, -0.0723295658826828, -0.3331466317176819, 0.434346079826355, 0.443824827671051, -0.22699512541294098, 0.40405386686325073, -0.08588922023773193, 0.38622772693634033, -0.15862399339675903, 0.14688104391098022, 0.09161332249641418, 0.010543867945671082, 0.4040035605430603, -0.39630162715911865, 0.10248055309057236, -0.004529334604740143, -0.227260023355484, -0.04099203273653984, -0.14193426072597504, 0.25031787157058716, -0.22132211923599243, 0.017247825860977173, 0.4299442768096924, 0.2451317310333252, 0.46108248829841614, 0.26815536618232727, 0.26772361993789673, -0.2542440593242645, -0.14139333367347717, 0.20492717623710632, 0.09640364348888397, -0.14132022857666016, 0.015251852571964264, 0.006130457855761051, -0.01009000837802887, 0.1305685192346573, -0.4192633628845215, 0.1464238166809082, -0.4067828059196472, 0.057264216244220734, -0.04992176592350006, -0.2484186291694641, 0.11850886791944504, 0.036084383726119995, -0.09272708743810654, -0.3314101994037628, -0.07930420339107513, -0.9210882186889648, 0.08782438933849335, -0.29749196767807007, 0.2722328007221222, -0.34898701310157776, 0.012838363647460938, 0.175644651055336, -0.035728126764297485, -0.27088671922683716, 0.3705482482910156, -0.19683125615119934, 0.22314828634262085, -0.09510277211666107, 0.2921000123023987, -0.16711658239364624, 0.35968801379203796, -0.27169740200042725, -0.36874663829803467, 0.2776908278465271, 0.10276833921670914, -0.014862548559904099, 0.22605876624584198, 0.17820735275745392, 0.2241518199443817, -0.3176015019416809, -0.02536264806985855, -0.3014776110649109, 0.33755430579185486, 0.07293783128261566, -0.048624299466609955, -0.0777970403432846, 1.4658552408218384, -0.4953076243400574, -0.007828395813703537, -0.0008215755224227905, 0.05545738339424133, 0.4566596448421478, 0.33677253127098083, -0.11973616480827332, 0.40669530630111694, -0.42459312081336975, -0.45860937237739563, 0.542346715927124, 
-0.19214734435081482, 0.1163230612874031, -0.30559441447257996, 0.9781761169433594, 0.28333592414855957, 0.03523014858365059, -0.06308187544345856, ...]"
4,1,"[-0.5057097673416138, -0.054634831845760345, 0.03243769705295563, -0.21033257246017456, -0.3331807255744934, 0.263546884059906, 0.4639630615711212, -0.6098024249076843, 0.20411211252212524, 0.18121975660324097, 0.24534720182418823, -0.11414247006177902, -0.27681803703308105, 0.3927534520626068, -0.18839898705482483, 0.3391340970993042, -0.27027401328086853, 0.03127623349428177, -0.20408865809440613, -0.010059282183647156, 0.24302157759666443, -0.03152664750814438, 0.20778140425682068, -0.4009173810482025, -0.015259236097335815, 0.4284830689430237, 0.32880550622940063, 0.20947419106960297, 0.3058100640773773, 0.26454365253448486, 0.23988443613052368, -0.01714957505464554, 0.175520658493042, -0.05424857139587402, -0.25042733550071716, -0.00017038360238075256, -0.2460595965385437, -0.013925820589065552, 0.45958247780799866, -0.07351827621459961, 0.029161028563976288, -0.3363761305809021, 0.1708647757768631, 0.026638159528374672, -0.402912437915802, -0.008655814453959465, -0.18696379661560059, -0.1515980064868927, -0.24339058995246887, -0.030558854341506958, -0.8097818493843079, 0.12612834572792053, 0.05124777555465698, 0.1713104546070099, -0.19655084609985352, -0.17416852712631226, 0.10548857599496841, 0.09420032799243927, -0.13004091382026672, -0.2725737988948822, -0.18443123996257782, 0.3950836658477783, 0.2654860019683838, 0.36432042717933655, 0.09661847352981567, -0.10494248569011688, -0.396776020526886, -0.25181955099105835, 0.05330274999141693, -0.06717322021722794, 0.03404320776462555, 0.26656410098075867, -0.024881087243556976, 0.1342945098876953, 0.3882603049278259, 0.18244193494319916, -0.22614315152168274, 0.05216924846172333, -0.10981924831867218, -0.07750173658132553, -0.039257586002349854, 0.9777536988258362, -0.3185191750526428, -0.029184281826019287, 0.188857302069664, 0.1698978841304779, 0.20738965272903442, 0.2642240822315216, 0.051154159009456635, 0.3144454061985016, -0.02250901609659195, 0.005115527659654617, 0.35163456201553345, 
-0.3566664755344391, -0.07341723889112473, 0.29449906945228577, 0.9347761273384094, 0.16617105901241302, 0.06639858335256577, 0.03998081386089325, ...]"


In [150]:
# Sanity check: length of the embedding stored in the first row.
len(df.at[0, 'embedding'])

512