## Dataset Configuration

### Download the Dataset
- After downloading, place the dataset at data/dataset.zip

In [None]:
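# zip creates archives; unzip extracts them. Assumes the archive's top-level folder is datasets/,
# so that the files end up under data/datasets/ as the paths below expect.
!unzip data/dataset.zip -d data/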

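### Split Vehicle and Person Labels (Training and Validation Sets)
- When reproducing, split these labels manually (the provided dataset already includes the split files)
- Save the vehicle labels of the training set as train_car_label.txt, in the same folder as train_label.txt
- Save the person labels of the training set as train_person_label.txt, in the same folder as train_label.txt
- Save the vehicle labels of the validation set as val_car_label.txt, in the same folder as val_label.txt
- Save the person labels of the validation set as val_person_label.txt, in the same folder as val_label.txt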

In [56]:
# Done manually by copying and pasting on a local machine (the provided dataset already includes the split files)
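If you would rather split the label files with a script than by hand, the following is a minimal sketch. It rests on two assumptions that the notebook does not confirm. First, each line of train_label.txt / val_label.txt is one sample. Second, every vehicle caption mentions at least one of the six vehicle type names printed by script/run_attr_parser.sh further below. The helper names split_label_lines and split_label_file are hypothetical. Compare the resulting files against the ones shipped with the dataset before relying on them.

```python
from pathlib import Path

# Assumption: the six vehicle types printed by script/run_attr_parser.sh.
VEHICLE_TYPES = ["Bus", "Microbus", "Minivan", "SUV", "Sedan", "Truck"]


def split_label_lines(lines, keywords=VEHICLE_TYPES):
    """Return (car_lines, person_lines); a line counts as a vehicle if it contains any keyword."""
    car, person = [], []
    for line in lines:
        (car if any(k in line for k in keywords) else person).append(line)
    return car, person


def split_label_file(label_path, car_path, person_path, keywords=VEHICLE_TYPES):
    """Hypothetical helper: split one label file into vehicle and person label files."""
    with open(label_path, encoding="utf-8") as f:
        # Skip blank lines and drop line endings (both \n and \r\n).
        lines = [line.rstrip("\r\n") for line in f if line.strip()]
    car, person = split_label_lines(lines, keywords)
    # Join with "\n" and no trailing newline, like save_lst_to_file below.
    Path(car_path).write_text("\n".join(car), encoding="utf-8")
    Path(person_path).write_text("\n".join(person), encoding="utf-8")
    return len(car), len(person)
```

For example, `split_label_file("data/datasets/train/train_label.txt", "data/datasets/train/train_car_label.txt", "data/datasets/train/train_person_label.txt")` handles the training set. Run it the same way on the val folder.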

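### Split Vehicle and Person Queries (Test Set)
1. Load the vehicle attribute lists (types and brands) collected from the training set.
2. Check each query for a vehicle attribute: if it contains one, classify it as a vehicle query; otherwise classify it as a person query.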

In [1]:
import json


def check_car_attr(line, type_attrs, brand_attrs):
    """Return True if the query mentions any vehicle type or brand (plain substring match)."""
    return any(attr in line for attr in type_attrs) or any(attr in line for attr in brand_attrs)


def save_lst_to_file(save_path, path_lst):
    """Write one item per line, with no trailing newline after the last item."""
    with open(save_path, "w") as f:
        f.write("\n".join(path_lst))

In [2]:
!sh script/run_attr_parser.sh

data successfully split
11 6 65
Colors: ['black', 'blue', 'brown', 'green', 'grey', 'orange', 'pink', 'purple', 'red', 'white', 'yellow']
Brands: ['Audi', 'BAOJUN', 'BESTUNE', 'BMW', 'BYD', 'Balong Heavy Truck', 'Bentley', 'Benz', 'Buick', 'Cadillac', 'Chana', 'Chery', 'Chevrolet', 'China-Moto', 'Citroen', 'Dongfeng', 'FAW', 'FORLAND', 'FOTON', 'Ford', 'Geely', 'Golden Dragon', 'GreatWall', 'HAFEI', 'Haima', 'Honda', 'Hongyan', 'Hyundai', 'Infiniti', 'Isuzu', 'Iveco', 'JAC', 'JMC', 'Jeep', 'Jinbei', 'KINGLONG', 'Karma', 'Kia', 'LEOPAARD', 'Landrover', 'Lexus', 'Luxgen', 'MORRIS-GARAGE', 'Mazda', 'Mini', 'Mitsubishi', 'Nissan', 'OPEL', 'PEUGEOT', 'Porsche', 'ROEWE', 'SGMW', 'SKODA', 'Shacman', 'Shuanghuan', 'Soueast', 'Style', 'Subaru', 'Suzuki', 'Toyota', 'Volkswagen', 'Volvo', 'XIALI', 'Yutong', 'ZXAUTO']
Types: ['Bus', 'Microbus', 'Minivan', 'SUV', 'Sedan', 'Truck']
car attribute files are saved to data/car_attribute.json


In [3]:
import json

test_query_path = 'data/datasets/test/test_text.txt'
car_attr_path = 'data/car_attribute.json'

test_car_lst = []
test_car_query_path = 'data/datasets/test/test_car_text.txt'

test_person_lst = []
test_person_query_path = 'data/datasets/test/test_person_text.txt'

# Vehicle attributes collected by script/run_attr_parser.sh (including the 'Types' and 'Brands' keys)
with open(car_attr_path) as f:
    auto_attrs = json.load(f)

# A query that mentions any vehicle type or brand is treated as a vehicle query;
# every other query is treated as a person query.
with open(test_query_path) as f:
    for line in f:
        if check_car_attr(line, auto_attrs['Types'], auto_attrs['Brands']):
            test_car_lst.append(line.strip())
        else:
            test_person_lst.append(line.strip())

save_lst_to_file(test_car_query_path, test_car_lst)
save_lst_to_file(test_person_query_path, test_person_lst)
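A quick consistency check can confirm that every test query landed in exactly one of the two files. This is a minimal sketch that is not part of the original pipeline. check_query_split is a hypothetical name, and the check assumes the one-query-per-line format written by save_lst_to_file.

```python
from collections import Counter


def read_lines(path):
    """Read non-empty, stripped lines from a text file."""
    with open(path, encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]


def check_query_split(all_path, car_path, person_path):
    """Return (num_car, num_person); raise if the two parts do not reproduce the original file."""
    total = Counter(read_lines(all_path))
    car = Counter(read_lines(car_path))
    person = Counter(read_lines(person_path))
    # Multiset comparison: catches lost, duplicated, or altered queries.
    if car + person != total:
        raise ValueError("split files do not match the original query file")
    return sum(car.values()), sum(person.values())
```

With the variables from the cell above, call `check_query_split(test_query_path, test_car_query_path, test_person_query_path)`.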

## Data Augmentation
Download the augmented label files to data/augmented_ViT-bigG-14_train_label.txt and data/augmented_ViT-bigG-14_val_label.txt <br/>


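[Important] Leaderboard B data is not visible to participants, so we could not analyze it and do not know which vehicle attributes it contains. To reproduce our data augmentation exactly on leaderboard B, the following conditions must hold:
- data/car_attribute.json must store all vehicle attributes, grouped by category
- In the generated augmented_ViT-bigG-14_train_label.txt and augmented_ViT-bigG-14_val_label.txt, every missing attribute must be filled in (except brand)
- "Except brand" means that an existing brand in a vehicle sample is kept, but a missing brand is not added back by data augmentation
- Because the prefixes used in leaderboard B are also unknown, augmentation fills in attributes without removing prefixes <br/>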
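The filling rule above can be written as a small function. This is an illustrative sketch only, not the authors' augmentation script. attrs_to_fill is a hypothetical name, and it uses the same substring matching as check_car_attr. It only decides which attribute categories are missing. It does not predict their values or edit the caption, so any prefix is left untouched.

```python
# Attribute lists as printed by script/run_attr_parser.sh (brands deliberately omitted).
COLORS = ["black", "blue", "brown", "green", "grey", "orange",
          "pink", "purple", "red", "white", "yellow"]
TYPES = ["Bus", "Microbus", "Minivan", "SUV", "Sedan", "Truck"]


def attrs_to_fill(caption, colors=COLORS, types=TYPES):
    """Return the attribute categories that augmentation should add to a vehicle caption."""
    missing = []
    if not any(c in caption for c in colors):
        missing.append("color")
    if not any(t in caption for t in types):
        missing.append("type")
    # Brand is never returned: a present brand is kept and a missing brand stays missing.
    return missing
```

For example, `attrs_to_fill("white Audi Sedan")` returns `[]` and `attrs_to_fill("Audi")` returns `["color", "type"]`.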


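[Note 1] The batchsize parameter in the augmentation scripts (such as run_data_aug_val.sh) must evenly divide the number of samples in the car label file being augmented (such as train_car_label.txt); otherwise the script raises an error. If the number of car samples happens to be prime, remove a few samples.<br/>
[Note 2] If data augmentation takes too long, edit run_data_aug_train.sh and run_data_aug_val.sh to prefix the python command with export CUDA_VISIBLE_DEVICES=0; and export CUDA_VISIBLE_DEVICES=1; respectively. Then run the two scripts concurrently so that each augmentation job runs on its own GPU (see run_model.sh).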
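Note 1 comes down to simple arithmetic. A batch size b works only if n % b == 0 for the n samples in the file, and a prime n leaves only b = 1 (or b = n). The sketch below uses hypothetical helper names. It lists the usable batch sizes up to a limit and shows how many samples to drop so that a chosen batch size divides what remains.

```python
def valid_batch_sizes(n, max_batch):
    """All batch sizes in [1, max_batch] that evenly divide n samples."""
    return [b for b in range(1, max_batch + 1) if n % b == 0]


def samples_to_drop(n, batch_size):
    """How many samples to remove so that batch_size divides the remaining count."""
    return n % batch_size


def is_prime(n):
    """Trial division; fast enough for dataset-sized numbers."""
    if n < 2:
        return False
    d = 2
    while d * d <= n:
        if n % d == 0:
            return False
        d += 1
    return True
```

For example, 3413 is prime, so `valid_batch_sizes(3413, 32)` returns `[1]`, while `samples_to_drop(3413, 4)` returns `1` (3412 = 4 * 853).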

## Training Set Augmentation

In [4]:
!bash script/run_data_aug_train.sh

  0%|                                                  | 0/2669 [00:00<?, ?it/s]Label probs: [[0.16489657759666443, 0.1609216034412384, 0.16526083648204803, 0.1745302826166153, 0.16811691224575043, 0.166273832321167]]  argmax_id: 3
Label probs: [[0.16728202998638153, 0.16117610037326813, 0.16496524214744568, 0.16740703582763672, 0.17127221822738647, 0.1678972989320755]]  argmax_id: 4
Label probs: [[0.16386529803276062, 0.1561427265405655, 0.16417595744132996, 0.16957470774650574, 0.17818279564380646, 0.16805851459503174]]  argmax_id: 4
Label probs: [[0.16441763937473297, 0.1593405306339264, 0.16698625683784485, 0.16844594478607178, 0.17314794659614563, 0.16766174137592316]]  argmax_id: 4
Label probs: [[0.16444313526153564, 0.15932312607765198, 0.16627995669841766, 0.1671518087387085, 0.17661020159721375, 0.16619175672531128]]  argmax_id: 4
Label probs: [[0.16574232280254364, 0.15901197493076324, 0.16464126110076904, 0.1669035106897354, 0.17764607071876526, 0.1660548448562622]]  argmax_

  0%|                                       | 4/2669 [01:28<16:01:30, 21.65s/it]Label probs: [[0.16756515204906464, 0.15593190491199493, 0.16250726580619812, 0.1677710860967636, 0.17622025310993195, 0.17000433802604675]]  argmax_id: 4
Label probs: [[0.16283932328224182, 0.15688027441501617, 0.16573883593082428, 0.17354416847229004, 0.17454512417316437, 0.16645225882530212]]  argmax_id: 4
Label probs: [[0.168327197432518, 0.1589937061071396, 0.16499894857406616, 0.16935575008392334, 0.1714991331100464, 0.16682523488998413]]  argmax_id: 4
Label probs: [[0.16604012250900269, 0.1553892195224762, 0.16284982860088348, 0.170619934797287, 0.17502108216285706, 0.17007982730865479]]  argmax_id: 4
Label probs: [[0.1662401407957077, 0.15597613155841827, 0.16454823315143585, 0.17080867290496826, 0.17526565492153168, 0.16716112196445465]]  argmax_id: 4
Label probs: [[0.1654333621263504, 0.1570606231689453, 0.1661173403263092, 0.1673283874988556, 0.17634430527687073, 0.16771601140499115]]  argmax_id:

  0%|                                       | 8/2669 [02:51<15:25:12, 20.86s/it]Label probs: [[0.16526809334754944, 0.15939858555793762, 0.1636074185371399, 0.16768455505371094, 0.18194057047367096, 0.16210076212882996]]  argmax_id: 4
Label probs: [[0.1664077341556549, 0.160002201795578, 0.16353082656860352, 0.167372465133667, 0.17807340621948242, 0.16461332142353058]]  argmax_id: 4
Label probs: [[0.16691771149635315, 0.15993233025074005, 0.1645280122756958, 0.1639566719532013, 0.18177641928195953, 0.1628888100385666]]  argmax_id: 4
Label probs: [[0.16527342796325684, 0.16096879541873932, 0.16665174067020416, 0.1686939001083374, 0.17544424533843994, 0.16296793520450592]]  argmax_id: 4
Label probs: [[0.1641179472208023, 0.1594994217157364, 0.16507014632225037, 0.1666514128446579, 0.18060755729675293, 0.16405346989631653]]  argmax_id: 4
Label probs: [[0.16402621567249298, 0.1616906225681305, 0.16846638917922974, 0.166480153799057, 0.17622098326683044, 0.1631157100200653]]  argmax_id: 4

  0%|▏                                     | 12/2669 [04:13<15:02:47, 20.39s/it]Label probs: [[0.16102094948291779, 0.1648569405078888, 0.16880996525287628, 0.16932997107505798, 0.17517796158790588, 0.16080418229103088]]  argmax_id: 4
Label probs: [[0.16227905452251434, 0.15963169932365417, 0.17015528678894043, 0.16874907910823822, 0.1790376752614975, 0.16014724969863892]]  argmax_id: 4
Label probs: [[0.1631673276424408, 0.16362744569778442, 0.16948844492435455, 0.165256068110466, 0.17775936424732208, 0.16070139408111572]]  argmax_id: 4
Label probs: [[0.15914613008499146, 0.15981045365333557, 0.17032812535762787, 0.17089785635471344, 0.18150262534618378, 0.1583147943019867]]  argmax_id: 4
Label probs: [[0.158831387758255, 0.1595308780670166, 0.17132382094860077, 0.17225593328475952, 0.17601777613162994, 0.16204017400741577]]  argmax_id: 4
Label probs: [[0.16117507219314575, 0.16045652329921722, 0.16936802864074707, 0.1708802431821823, 0.17653009295463562, 0.16159000992774963]]  argmax_

  1%|▏                                     | 16/2669 [05:34<15:10:20, 20.59s/it]Label probs: [[0.1652195155620575, 0.1610451638698578, 0.16476574540138245, 0.16936342418193817, 0.1744295358657837, 0.1651766300201416]]  argmax_id: 4
Label probs: [[0.16475604474544525, 0.16187059879302979, 0.1675264686346054, 0.1695823073387146, 0.17236922681331635, 0.16389530897140503]]  argmax_id: 4
Label probs: [[0.16221055388450623, 0.16475141048431396, 0.1716194897890091, 0.1681647002696991, 0.17112542688846588, 0.16212838888168335]]  argmax_id: 2
Label probs: [[0.1655394732952118, 0.16298122704029083, 0.16742680966854095, 0.1655367612838745, 0.1756875216960907, 0.1628282070159912]]  argmax_id: 4
Label probs: [[0.162850022315979, 0.16214774549007416, 0.1721813678741455, 0.1682204306125641, 0.17314334213733673, 0.1614570915699005]]  argmax_id: 4
Label probs: [[0.16603150963783264, 0.16255083680152893, 0.1662949025630951, 0.173151895403862, 0.17032188177108765, 0.1616489589214325]]  argmax_id: 3

  1%|▎                                     | 20/2669 [06:54<14:46:11, 20.07s/it]Label probs: [[0.1642032414674759, 0.16323940455913544, 0.16822503507137299, 0.1711057424545288, 0.16816700994968414, 0.16505959630012512]]  argmax_id: 3
Label probs: [[0.16733472049236298, 0.1620672643184662, 0.16572847962379456, 0.16866005957126617, 0.17003227770328522, 0.1661771684885025]]  argmax_id: 4
Label probs: [[0.16548192501068115, 0.16321489214897156, 0.1649785339832306, 0.17265117168426514, 0.1679496020078659, 0.16572390496730804]]  argmax_id: 3
Label probs: [[0.15847642719745636, 0.16357456147670746, 0.17145800590515137, 0.17377673089504242, 0.16970673203468323, 0.16300755739212036]]  argmax_id: 3
Label probs: [[0.1658472865819931, 0.1669626235961914, 0.16698293387889862, 0.1686287522315979, 0.17064650356769562, 0.16093188524246216]]  argmax_id: 4
Label probs: [[0.15763519704341888, 0.15852218866348267, 0.1692308485507965, 0.17062988877296448, 0.1793706864118576, 0.16461120545864105]]  argmax_i

  1%|▎                                     | 24/2669 [08:14<14:41:12, 19.99s/it]Label probs: [[0.16796401143074036, 0.164394348859787, 0.16668248176574707, 0.16574519872665405, 0.1619485467672348, 0.17326538264751434]]  argmax_id: 5
Label probs: [[0.16800256073474884, 0.16499991714954376, 0.16487278044223785, 0.16353683173656464, 0.1613757610321045, 0.177212193608284]]  argmax_id: 5
Label probs: [[0.16697660088539124, 0.1597275286912918, 0.16191285848617554, 0.16705258190631866, 0.16245627403259277, 0.18187415599822998]]  argmax_id: 5
Label probs: [[0.16437682509422302, 0.16808819770812988, 0.1728496551513672, 0.16453798115253448, 0.16717293858528137, 0.16297437250614166]]  argmax_id: 2
Label probs: [[0.16963736712932587, 0.16246956586837769, 0.15994714200496674, 0.1603584885597229, 0.16393420100212097, 0.18365322053432465]]  argmax_id: 5
Label probs: [[0.1632048338651657, 0.15715241432189941, 0.1640118658542633, 0.16887377202510834, 0.1792546659708023, 0.16750243306159973]]  argmax_id

  1%|▍                                     | 28/2669 [09:34<14:42:55, 20.06s/it]Label probs: [[0.16247640550136566, 0.16131676733493805, 0.16624051332473755, 0.17365391552448273, 0.17237715423107147, 0.16393530368804932]]  argmax_id: 3
Label probs: [[0.16168639063835144, 0.15727895498275757, 0.16701604425907135, 0.17455434799194336, 0.17919237911701202, 0.16027186810970306]]  argmax_id: 4
Label probs: [[0.16176089644432068, 0.15763162076473236, 0.16968707740306854, 0.17153404653072357, 0.17862528562545776, 0.1607610583305359]]  argmax_id: 4
Label probs: [[0.16438278555870056, 0.1588619351387024, 0.16436631977558136, 0.1731536090373993, 0.17574894428253174, 0.16348646581172943]]  argmax_id: 4
Label probs: [[0.16391558945178986, 0.1572026163339615, 0.16666623950004578, 0.17201879620552063, 0.1765293926000595, 0.16366738080978394]]  argmax_id: 4
Label probs: [[0.1631019413471222, 0.15741680562496185, 0.1665404736995697, 0.17378757894039154, 0.17597797513008118, 0.16317524015903473]]  argm

  1%|▍                                     | 32/2669 [10:53<14:27:45, 19.74s/it]Label probs: [[0.16331228613853455, 0.15852785110473633, 0.1689092069864273, 0.16926787793636322, 0.1752677857875824, 0.16471491754055023]]  argmax_id: 4
Label probs: [[0.16485252976417542, 0.15956325829029083, 0.16697542369365692, 0.16901355981826782, 0.1755463033914566, 0.16404889523983002]]  argmax_id: 4
Label probs: [[0.164462611079216, 0.16000822186470032, 0.16515718400478363, 0.16850078105926514, 0.17966929078102112, 0.1622018963098526]]  argmax_id: 4
Label probs: [[0.16423873603343964, 0.15830032527446747, 0.16730310022830963, 0.17191362380981445, 0.17648962140083313, 0.1617545485496521]]  argmax_id: 4
Label probs: [[0.16070684790611267, 0.15827308595180511, 0.1666957587003708, 0.1819131225347519, 0.16806510090827942, 0.16434600949287415]]  argmax_id: 3
Label probs: [[0.16534292697906494, 0.15966075658798218, 0.16368073225021362, 0.1692245453596115, 0.17093701660633087, 0.17115402221679688]]  argmax_

  1%|▌                                     | 36/2669 [12:13<14:32:24, 19.88s/it]Label probs: [[0.15697930753231049, 0.16288597881793976, 0.17404475808143616, 0.1767958104610443, 0.1706307977437973, 0.15866336226463318]]  argmax_id: 3
Label probs: [[0.15806159377098083, 0.16762883961200714, 0.17273908853530884, 0.17182818055152893, 0.16652143001556396, 0.1632208377122879]]  argmax_id: 2
Label probs: [[0.16251292824745178, 0.15771542489528656, 0.16503584384918213, 0.17747627198696136, 0.1702421009540558, 0.1670173555612564]]  argmax_id: 3
Label probs: [[0.15755251049995422, 0.15666161477565765, 0.1706237941980362, 0.17825046181678772, 0.1723044067621231, 0.16460716724395752]]  argmax_id: 3
Label probs: [[0.15714304149150848, 0.15893302857875824, 0.16905392706394196, 0.17999674379825592, 0.1697915643453598, 0.1650816947221756]]  argmax_id: 3
Label probs: [[0.1611216515302658, 0.1588270217180252, 0.16714069247245789, 0.17338882386684418, 0.17678125202655792, 0.1627405285835266]]  argmax_id

  1%|▌                                     | 40/2669 [13:32<14:31:34, 19.89s/it]Label probs: [[0.16143205761909485, 0.15988288819789886, 0.1669357568025589, 0.16887803375720978, 0.17896074056625366, 0.16391055285930634]]  argmax_id: 4
Label probs: [[0.1596846729516983, 0.15940313041210175, 0.17230935394763947, 0.1717710644006729, 0.17660269141197205, 0.16022908687591553]]  argmax_id: 4
Label probs: [[0.16063012182712555, 0.1595686376094818, 0.16767926514148712, 0.17070548236370087, 0.17923615872859955, 0.16218039393424988]]  argmax_id: 4
Label probs: [[0.16145065426826477, 0.16035722196102142, 0.1667199730873108, 0.16982761025428772, 0.178280770778656, 0.1633637547492981]]  argmax_id: 4
Label probs: [[0.16183948516845703, 0.16006149351596832, 0.1677587777376175, 0.16863064467906952, 0.1779382824897766, 0.16377130150794983]]  argmax_id: 4
Label probs: [[0.1603005826473236, 0.15936651825904846, 0.17021358013153076, 0.17067743837833405, 0.1792002022266388, 0.16024158895015717]]  argmax_id

  2%|▋                                     | 44/2669 [14:52<14:35:33, 20.01s/it]Label probs: [[0.15949763357639313, 0.15959110856056213, 0.167582705616951, 0.1686665564775467, 0.18018808960914612, 0.16447389125823975]]  argmax_id: 4
Label probs: [[0.15838877856731415, 0.15747088193893433, 0.1701897531747818, 0.18055227398872375, 0.16989487409591675, 0.16350342333316803]]  argmax_id: 3
Label probs: [[0.15779054164886475, 0.1583356410264969, 0.1678750216960907, 0.17104311287403107, 0.1838916838169098, 0.1610639989376068]]  argmax_id: 4
Label probs: [[0.15993911027908325, 0.1610000878572464, 0.16797560453414917, 0.16891323029994965, 0.17802831530570984, 0.1641436070203781]]  argmax_id: 4
Label probs: [[0.1614440232515335, 0.16439460217952728, 0.1639162003993988, 0.16528354585170746, 0.18119597434997559, 0.16376559436321259]]  argmax_id: 4
Label probs: [[0.15993747115135193, 0.1615794152021408, 0.16780361533164978, 0.16578161716461182, 0.1810758411884308, 0.16382208466529846]]  argmax_id: 

  2%|▋                                     | 48/2669 [16:14<14:46:04, 20.28s/it]Label probs: [[0.16410255432128906, 0.15947335958480835, 0.16094541549682617, 0.17463545501232147, 0.16786940395832062, 0.17297381162643433]]  argmax_id: 3
Label probs: [[0.16585853695869446, 0.15856610238552094, 0.16135522723197937, 0.17137663066387177, 0.17555224895477295, 0.16729123890399933]]  argmax_id: 4
Label probs: [[0.1632855236530304, 0.1586054265499115, 0.16181106865406036, 0.1728309541940689, 0.1710146814584732, 0.17245237529277802]]  argmax_id: 3
Label probs: [[0.1667303740978241, 0.1591535061597824, 0.16015173494815826, 0.17097142338752747, 0.17593799531459808, 0.16705496609210968]]  argmax_id: 4
Label probs: [[0.16599391400814056, 0.16122692823410034, 0.16165229678153992, 0.1678892970085144, 0.17771941423416138, 0.1655181348323822]]  argmax_id: 4
Label probs: [[0.16651169955730438, 0.16558773815631866, 0.16859912872314453, 0.17051388323307037, 0.16090843081474304, 0.1678791493177414]]  argmax

  2%|▋                                     | 52/2669 [17:33<14:30:09, 19.95s/it]Label probs: [[0.1674436628818512, 0.1580120176076889, 0.16388587653636932, 0.1697017103433609, 0.1762310117483139, 0.16472575068473816]]  argmax_id: 4
Label probs: [[0.1671593338251114, 0.16066376864910126, 0.16510868072509766, 0.17074881494045258, 0.17302274703979492, 0.1632966548204422]]  argmax_id: 4
Label probs: [[0.16313062608242035, 0.16304783523082733, 0.16695600748062134, 0.17398622632026672, 0.16861087083816528, 0.1642684042453766]]  argmax_id: 3
Label probs: [[0.16327033936977386, 0.15890692174434662, 0.16521640121936798, 0.17885778844356537, 0.16707366704940796, 0.16667479276657104]]  argmax_id: 3
Label probs: [[0.16342419385910034, 0.1617487370967865, 0.16662254929542542, 0.1751675307750702, 0.1691090166568756, 0.16392797231674194]]  argmax_id: 3
Label probs: [[0.1673138588666916, 0.16063062846660614, 0.16300882399082184, 0.1708858609199524, 0.1739843487739563, 0.16417643427848816]]  argmax_id:

  2%|▊                                     | 56/2669 [18:53<14:26:20, 19.89s/it]Label probs: [[0.15931127965450287, 0.16133320331573486, 0.17139792442321777, 0.1849873811006546, 0.16394828259944916, 0.15902189910411835]]  argmax_id: 3
Label probs: [[0.16318480670452118, 0.16792032122612, 0.1705377846956253, 0.168427973985672, 0.16213878989219666, 0.16779029369354248]]  argmax_id: 2
Label probs: [[0.16104717552661896, 0.16211606562137604, 0.16705195605754852, 0.17686273157596588, 0.1728053241968155, 0.16011668741703033]]  argmax_id: 3
Label probs: [[0.16099171340465546, 0.15974336862564087, 0.16481170058250427, 0.18308287858963013, 0.16913187503814697, 0.16223843395709991]]  argmax_id: 3
Label probs: [[0.171514093875885, 0.16533401608467102, 0.16172213852405548, 0.16569776833057404, 0.1600344181060791, 0.17569753527641296]]  argmax_id: 5
Label probs: [[0.15898600220680237, 0.15792356431484222, 0.16615724563598633, 0.18613451719284058, 0.1712195873260498, 0.1595790982246399]]  argmax_id:

  2%|▊                                     | 60/2669 [20:13<14:27:50, 19.96s/it]Label probs: [[0.16234371066093445, 0.15784704685211182, 0.16692204773426056, 0.1691596806049347, 0.17907683551311493, 0.16465070843696594]]  argmax_id: 4
Label probs: [[0.1612260490655899, 0.15848694741725922, 0.16665761172771454, 0.17858386039733887, 0.17005527019500732, 0.16499027609825134]]  argmax_id: 3
Label probs: [[0.15969648957252502, 0.1570950150489807, 0.16814205050468445, 0.17649288475513458, 0.1705300509929657, 0.16804352402687073]]  argmax_id: 3
Label probs: [[0.16056664288043976, 0.15829704701900482, 0.16526293754577637, 0.17232656478881836, 0.17819839715957642, 0.1653483510017395]]  argmax_id: 4
Label probs: [[0.16317546367645264, 0.16856855154037476, 0.1663607656955719, 0.16411887109279633, 0.1741977483034134, 0.16357854008674622]]  argmax_id: 4
Label probs: [[0.1597650796175003, 0.16165150701999664, 0.16875937581062317, 0.172272726893425, 0.1734202355146408, 0.1641310751438141]]  argmax_id

  2%|▉                                     | 64/2669 [21:33<14:31:16, 20.07s/it]Label probs: [[0.16551361978054047, 0.15859189629554749, 0.16695503890514374, 0.1680910289287567, 0.171645388007164, 0.16920310258865356]]  argmax_id: 4
Label probs: [[0.16502784192562103, 0.15800915658473969, 0.16615310311317444, 0.17090047895908356, 0.17251871526241302, 0.1673906296491623]]  argmax_id: 4
Label probs: [[0.16553781926631927, 0.15991619229316711, 0.16693115234375, 0.16756129264831543, 0.1731206178665161, 0.1669328361749649]]  argmax_id: 4
Label probs: [[0.16402043402194977, 0.1584988832473755, 0.1653171330690384, 0.17274583876132965, 0.17221881449222565, 0.16719891130924225]]  argmax_id: 3
Label probs: [[0.16595660150051117, 0.1597747504711151, 0.16866855323314667, 0.1710214465856552, 0.16935139894485474, 0.1652272641658783]]  argmax_id: 3
Label probs: [[0.16536778211593628, 0.15769138932228088, 0.16724520921707153, 0.17048639059066772, 0.17319153249263763, 0.16601774096488953]]  argmax_id: 

  3%|▉                                     | 68/2669 [22:53<14:25:56, 19.98s/it]Label probs: [[0.1637546271085739, 0.16301225125789642, 0.1683395355939865, 0.17264650762081146, 0.16812075674533844, 0.16412632167339325]]  argmax_id: 3
Label probs: [[0.1648194044828415, 0.16408610343933105, 0.16989904642105103, 0.16916532814502716, 0.16754788160324097, 0.1644822359085083]]  argmax_id: 2
Label probs: [[0.16588431596755981, 0.16678164899349213, 0.1712152361869812, 0.1635669469833374, 0.16971807181835175, 0.16283376514911652]]  argmax_id: 2
Label probs: [[0.16642053425312042, 0.16279520094394684, 0.16551807522773743, 0.1632109433412552, 0.17640988528728485, 0.16564537584781647]]  argmax_id: 4
Label probs: [[0.16692428290843964, 0.16465283930301666, 0.1707342118024826, 0.1701907515525818, 0.16488032042980194, 0.16261760890483856]]  argmax_id: 2
Label probs: [[0.16559889912605286, 0.16543462872505188, 0.17080676555633545, 0.16535444557666779, 0.17095793783664703, 0.16184735298156738]]  argmax

  3%|█                                     | 72/2669 [24:13<14:24:57, 19.98s/it]Label probs: [[0.16749659180641174, 0.15850767493247986, 0.16061590611934662, 0.1698455512523651, 0.17293241620063782, 0.17060185968875885]]  argmax_id: 4
Label probs: [[0.16653203964233398, 0.1566142439842224, 0.1615203619003296, 0.17129118740558624, 0.1743827760219574, 0.16965937614440918]]  argmax_id: 4
Label probs: [[0.1644170880317688, 0.15942427515983582, 0.16780215501785278, 0.16822339594364166, 0.17692236602306366, 0.1632106900215149]]  argmax_id: 4
Label probs: [[0.16553252935409546, 0.15917633473873138, 0.16696113348007202, 0.16726864874362946, 0.1786360889673233, 0.16242524981498718]]  argmax_id: 4
Label probs: [[0.16394545137882233, 0.1570846438407898, 0.16148021817207336, 0.17136038839817047, 0.17526529729366302, 0.1708640158176422]]  argmax_id: 4
Label probs: [[0.16361583769321442, 0.15957263112068176, 0.1619357317686081, 0.17230437695980072, 0.17348498106002808, 0.1690864861011505]]  argmax_i
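Each `Label probs` line has six entries, matching the six vehicle types in data/car_attribute.json. The wording also matches the zero-shot example in the CLIP README. The augmentation step therefore appears to score each image against one text prompt per vehicle type and take the argmax. The sketch below shows that technique (a softmax over scaled cosine similarities) in plain Python with toy embeddings. The type order, the logit_scale default of 100, and the function names are assumptions, not read from the authors' script. All logged probabilities stay close to 1/6 ≈ 0.167, which would be consistent with a small effective scale, though the log alone cannot confirm this.

```python
import math

# Assumed label order: the six vehicle types printed by run_attr_parser.sh.
VEHICLE_TYPES = ["Bus", "Microbus", "Minivan", "SUV", "Sedan", "Truck"]


def l2_normalize(v):
    """Scale a non-zero vector to unit length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def zero_shot_probs(image_emb, text_embs, logit_scale=100.0):
    """CLIP-style zero-shot scores: softmax over scaled image-text cosine similarities."""
    img = l2_normalize(image_emb)
    logits = [logit_scale * sum(a * b for a, b in zip(img, l2_normalize(t))) for t in text_embs]
    m = max(logits)  # subtract the max logit for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_type(image_emb, text_embs, logit_scale=100.0):
    """Return (type name of the argmax, full probability list)."""
    probs = zero_shot_probs(image_emb, text_embs, logit_scale)
    argmax_id = max(range(len(probs)), key=probs.__getitem__)
    return VEHICLE_TYPES[argmax_id], probs
```

Calling `predict_type` with `logit_scale=1.0` instead of the default gives a much flatter distribution for the same embeddings.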

## Validation Set Augmentation

In [58]:
!bash script/run_data_aug_val.sh

  0%|                                                  | 0/3413 [00:00<?, ?it/s]Label probs: [[0.1670745611190796, 0.15980379283428192, 0.1625368446111679, 0.16763070225715637, 0.17469225823879242, 0.1682618409395218]]  argmax_id: 4
  0%|                                       | 1/3413 [00:23<22:35:38, 23.84s/it]Label probs: [[0.1648598611354828, 0.16240912675857544, 0.16564446687698364, 0.1688956916332245, 0.1720568984746933, 0.16613395512104034]]  argmax_id: 4
  0%|                                       | 2/3413 [00:45<21:19:56, 22.51s/it]Label probs: [[0.16584745049476624, 0.15640753507614136, 0.16592133045196533, 0.16833040118217468, 0.17283488810062408, 0.1706583946943283]]  argmax_id: 4
  0%|                                       | 3/3413 [01:05<20:15:30, 21.39s/it]Label probs: [[0.1661825031042099, 0.15741696953773499, 0.16467656195163727, 0.16829253733158112, 0.17511403560638428, 0.16831746697425842]]  argmax_id: 4
  0%|                                       | 4/3413 [01:25<19:4

# Model Training

## Generate Text-Image Pairs for CLIP Contrastive Learning

In [4]:
!sh script/run_aug_parser.sh
!sh script/run_split_parser.sh

data/augmented_ViT-bigG-14_train_label.txt is sucessfully parserd
parsed contents are saved to src/train/data/train_car_data.csv
data/augmented_ViT-bigG-14_val_label.txt is sucessfully parserd
parsed contents are saved to src/train/data/val_car_data.csv
data/datasets/train/train_person_label.txt is successfully parsed
parsed contents are saved to src/train/data/train_person_data.csv
data/datasets/val/val_person_label.txt is successfully parsed
parsed contents are saved to src/train/data/val_person_data.csv
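The parsed files are tab-separated tables with a `filepath` column and a `title` column. This matches the `--csv-separator '\t'`, `--csv-img-key filepath` and `--csv-caption-key title` arguments that script/run_model.sh passes to open_clip (see the training log below). The following is a minimal sketch of writing such a file from (image path, caption) pairs. write_pairs_csv and read_pairs_csv are hypothetical helpers. Extracting the pairs from the label files is left to the authors' parser scripts.

```python
import csv


def write_pairs_csv(pairs, out_path):
    """Write (image_path, caption) pairs as a tab-separated file with a filepath/title header."""
    with open(out_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f, delimiter="\t", lineterminator="\n")
        writer.writerow(["filepath", "title"])
        writer.writerows(pairs)


def read_pairs_csv(path):
    """Read the file back as a list of {"filepath": ..., "title": ...} dicts."""
    with open(path, newline="", encoding="utf-8") as f:
        return list(csv.DictReader(f, delimiter="\t"))
```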


In [5]:
!sh script/run_merge_csv.sh

src/train/data/train_car_data.csv and src/train/data/train_person_data.csv are merged to src/train/data/train_data.csv
src/train/data/val_car_data.csv and src/train/data/val_person_data.csv are merged to src/train/data/val_data.csv
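Merging amounts to concatenating the vehicle and person tables under a single header. The following stdlib-only sketch is hypothetical (merge_tsv is not the contents of script/run_merge_csv.sh). It also refuses to merge files whose headers differ, which catches malformed parser output early.

```python
import csv


def merge_tsv(in_paths, out_path):
    """Concatenate tab-separated files that share one header; return the number of data rows."""
    header, rows = None, []
    for path in in_paths:
        with open(path, newline="", encoding="utf-8") as f:
            reader = csv.reader(f, delimiter="\t")
            file_header = next(reader)
            if header is None:
                header = file_header
            elif file_header != header:
                raise ValueError(f"header mismatch in {path}: {file_header} != {header}")
            rows.extend(reader)  # remaining lines are data rows
    with open(out_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f, delimiter="\t", lineterminator="\n")
        writer.writerow(header)
        writer.writerows(rows)
    return len(rows)
```

For the training split, that would be `merge_tsv(["src/train/data/train_car_data.csv", "src/train/data/train_person_data.csv"], "src/train/data/train_data.csv")`.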


In [16]:
!sh script/run_model.sh

['--save-frequency', '5', '--zeroshot-frequency', '0', '--report-to', 'tensorboard', '--train-data', 'src/train/data/val_data.csv', '--val-data', 'src/train/data/val_data.csv', '--dataset-type', 'csv', '--csv-separator', '\\t', '--csv-img-key', 'filepath', '--csv-caption-key', 'title', '--warmup', '10000', '--batch-size', '25', '--lr', '4e-7', '--wd', '0.01', '--epochs', '30', '--workers', '15', '--model', 'ViT-L-14', '--pretrained', 'openai']
2023-05-19,20:13:22 | INFO | Running with a single process. Device cuda:0.
2023-05-19,20:13:22 | INFO | Loading pretrained ViT-L-14 from OpenAI.
2023-05-19,20:13:28 | INFO | Model:
2023-05-19,20:13:28 | INFO | CLIP(
  (visual): VisionTransformer(
    (patchnorm_pre_ln): Identity()
    (conv1): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
    (patch_dropout): Identity()
    (ln_pre): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): ModuleList(
        (0): ResidualAtten

                                      filepath             title
0  data/datasets/val/val_images/kcbluaihml.jpg  white Audi Sedan
1  data/datasets/val/val_images/rbidlukstq.jpg  white Audi Sedan
2  data/datasets/val/val_images/mtblifktzu.jpg  yellow BMW Sedan
3  data/datasets/val/val_images/aopgmhgwft.jpg  yellow BMW Sedan
4  data/datasets/val/val_images/ebhftfswxw.jpg  yellow BMW Sedan
filepath
['filepath' 'title']
Done loading data.
Loading csv data from src/train/data/val_data.csv.
  return func(*args, **kwargs)
                                      filepath             title
0  data/datasets/val/val_images/kcbluaihml.jpg  white Audi Sedan
1  data/datasets/val/val_images/rbidlukstq.jpg  white Audi Sedan
2  data/datasets/val/val_images/mtblifktzu.jpg  yellow BMW Sedan
3  data/datasets/val/val_images/aopgmhgwft.jpg  yellow BMW Sedan
4  data/datasets/val/val_images/ebhftfswxw.jpg  yellow BMW Sedan
filepath
['filepath' 'title']
Done loading data.
2023-05-19 20:13:29.019350: I tensorflow
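The logged arguments include `--warmup 10000`, `--lr 4e-7` and `--epochs 30`. As far as we know, open_clip's default cosine scheduler ramps the learning rate up linearly over the warmup steps and then decays it along a cosine curve toward zero. The sketch below re-implements that shape in plain Python so you can check how the warmup compares with the total number of optimizer steps. Treat it as an approximation of open_clip's scheduler, not a copy of it. The total step count is epochs times batches per epoch, which depends on your dataset size. total_steps assumes incomplete final batches are dropped.

```python
import math


def lr_at(step, base_lr, warmup_steps, total_steps):
    """Linear warmup to base_lr, then cosine decay toward 0 over the remaining steps."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1 + math.cos(math.pi * progress)) * base_lr


def total_steps(num_samples, batch_size, epochs):
    """Optimizer steps for the whole run, assuming incomplete final batches are dropped."""
    return (num_samples // batch_size) * epochs
```

For instance, with a hypothetical 20,000 training pairs and `--batch-size 25`, `total_steps(20000, 25, 30)` gives 24000. In that case, 10,000 warmup steps would cover about 42% of training.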

# Generate Result Files

## Reproduction Check

In [30]:
!sh script/run_infer.sh

model loading...
Compose(
    Lambda()
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=None)
    CenterCrop(size=(224, 224))
    <function _convert_to_rgb at 0x7ff67d7d2170>
    ToTensor()
    Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)
data loading...
100%|████████████████████████████████| 10000/10000 [00:00<00:00, 3529667.59it/s]
100%|████████████████████████████████| 10000/10000 [00:00<00:00, 3469521.05it/s]
query text number: 10000
key image number: 17611
start loading text features
100%|█████████████████████████████████████████████| 5/5 [01:26<00:00, 17.36s/it]
start loading image features
  0%|                                                    | 0/89 [00:22<?, ?it/s]
similarity_argsort.shape (200, 10000)
model loading...
Compose(
    Lambda()
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=None)
    CenterCrop(size=(224, 224))
    <function _convert_to_rgb at 0x7faac3cf1170>
    ToTenso
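Inference ranks the test images for each text query by the similarity between CLIP text and image features. The `--topk 10` flag in the commands below suggests that ten images are kept per query. The following is a minimal sketch of that ranking step in plain Python. It assumes L2-normalized features, so the dot product equals cosine similarity. topk_images is a hypothetical name, and the sketch is not the code of script/run_infer.sh or src/infer/open_clip_infer.py.

```python
import heapq


def topk_images(text_feats, image_feats, image_ids, k=10):
    """For each text feature, return the ids of the k images with the highest dot product.

    Assumes all feature vectors are already L2-normalized (dot product == cosine similarity).
    """
    results = []
    for t in text_feats:
        scores = [sum(a * b for a, b in zip(t, img)) for img in image_feats]
        # Indices of the k largest scores, best first.
        best = heapq.nlargest(k, range(len(scores)), key=scores.__getitem__)
        results.append([image_ids[i] for i in best])
    return results
```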

In [32]:
!python src/infer/merge_json.py

## One-Click Run for the Best Result
- Download the checkpoint to

In [None]:
!export CUDA_VISIBLE_DEVICES=0; python src/infer/open_clip_infer.py --model_name ViT-bigG-14 --pt_path /data1/code/yyg/image_retrival/yyg/open_clip_G_real/open_clip-G/open_clip/logs/epoch_5.pt --image_root data/datasets/test/test_images/ --text_path data/datasets/test/test_person_text.txt --run_name open_clip_person_infer --text_steps 2000 --image_batch 2000 --topk 10

In [None]:
!export CUDA_VISIBLE_DEVICES=1; python src/infer/open_clip_infer.py --model_name ViT-bigG-14 --pt_path /data1/code/yyg/image_retrival/yyg/open_clip_G_real/open_clip-G/open_clip/logs/epoch_5.pt --image_root data/datasets/test/test_images/ --text_path data/datasets/test/test_car_text.txt --run_name open_clip_car_infer --text_steps 2000 --image_batch 2000 --topk 10
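The two cells above pin the person and vehicle inference runs to GPU 0 and GPU 1, but notebook cells still execute one after another. To run them at the same time, one option is to launch both as background processes, each with its own CUDA_VISIBLE_DEVICES. Note 2 applies the same idea to data augmentation. The following is a minimal sketch using Python's subprocess module. run_on_gpus is a hypothetical helper. Pass it the python command lines from the two cells above, without the export prefix.

```python
import os
import subprocess


def run_on_gpus(jobs):
    """Start each (gpu_id, argv) job concurrently on its own GPU; return their exit codes.

    CUDA_VISIBLE_DEVICES is set per process, so each job sees only its assigned GPU.
    """
    procs = []
    for gpu_id, argv in jobs:
        env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(gpu_id))
        procs.append(subprocess.Popen(argv, env=env))
    # All jobs run in parallel; block until every one has finished.
    return [p.wait() for p in procs]
```

For example, pass `(0, ["python", "src/infer/open_clip_infer.py", ...])` with the person arguments and `(1, ...)` with the car arguments.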