In [14]:
# import Sent2textDataset
import torch
from sklearn.metrics import precision_score
from torch.utils.data import DataLoader


from torch.utils.tensorboard import SummaryWriter



import pandas as pd
from sklearn.metrics import f1_score
from CustomDataset import Sent2textDataset
from clip.evaluate.utils import load_weights_only
from trochvisions import transforms
from torch.optim import lr_scheduler
# !rm -rf  data/images
# !ls data/
# !mkdir data/images

# model, args = load_weights_only("ViT-B/32-small",seq_length = 15)

200k_data_preproc.csv  60k_data_preproc.csv  images.json  metadata.json


In [15]:
from load_model import load_model

In [16]:
import numpy as np
from torch.utils.data import SubsetRandomSampler
def sempler(data_train, batch_size=4, split=.8):
    """Shuffle the dataset indices once and build two loaders over disjoint subsets.

    Args:
        data_train: Dataset handed to both loaders; only ``len()`` is used here.
        batch_size: Batch size shared by both loaders.
        split: Fraction of the shuffled indices given to the *validation*
            loader; the remaining indices go to the training loader.

    Returns:
        ``(train_loader, val_loader)``; each samples its own index subset
        through a ``SubsetRandomSampler``.

    NOTE(review): with the default ``split=.8`` the validation loader receives
    80% of the samples and the training loader 20% -- confirm this is intended.
    """
    n_items = len(data_train)
    cut = int(np.floor(split * n_items))

    order = list(range(n_items))
    # Uses the global NumPy RNG; seed it beforehand for a reproducible split.
    np.random.shuffle(order)

    def _loader(subset):
        return torch.utils.data.DataLoader(
            data_train,
            batch_size=batch_size,
            sampler=SubsetRandomSampler(subset),
        )

    # First `cut` shuffled indices -> validation, the rest -> training.
    return _loader(order[cut:]), _loader(order[:cut])

In [17]:
def train_model(model, train_loader, val_loader, loss, optimizer, num_epochs, writer):
    """Train ``model`` for ``num_epochs`` epochs, validating and logging after each one.

    Every batch is indexed as ``data[0]`` (images; dimension 1 is squeezed out),
    ``data[1]`` (text input) and ``data[2]`` (attention mask). The target for
    row ``i`` of a batch is the class index ``i`` (``torch.arange(batch_size)``),
    so the first value returned by the model is presumably a ``(batch, batch)``
    score matrix -- TODO confirm against the model definition.

    Batches are moved to the device of the model's first parameter (CUDA is
    used as a fallback when the model exposes no parameters).

    Args:
        model: Module called as ``model(img_input={"x": ...},
            text_input={"x": ..., "attention_mask": ...})`` and returning a pair
            whose first element is passed to ``loss``.
        train_loader: Iterable of training batches.
        val_loader: Iterable of validation batches, forwarded to ``compute_valid``.
        loss: Callable ``loss(prediction, labels)`` returning a scalar tensor.
        optimizer: Optimizer updating the trainable parameters of ``model``.
        num_epochs: Number of epochs to run.
        writer: Object with ``add_scalar(tag, value, step)``; receives the tags
            ``Loss/train``, ``Acc/train``, ``Loss/valid``, ``Acc/valid`` and
            ``Lr/epoch`` once per epoch.

    Returns:
        ``(metric_y_val, metric_p_val)`` exactly as returned by the last
        ``compute_valid`` call (``(None, None)`` when ``num_epochs`` is 0).

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    val_loss_hist = []
    metric_y_val = metric_p_val = None

    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer,
                                                         T_0=10,
                                                         T_mult=2,
                                                         eta_min=1e-9)
    for epoch in range(num_epochs):
        print(epoch)
        # compute_valid() switches the model to eval mode, so re-enable training each epoch.
        model.train()

        correct_samples = 0
        total_samples = 0
        loss_accum = 0.0
        num_batches = 0
        for data in train_loader:
            imgs = torch.squeeze(data[0].to(device), 1)
            texts = data[1].to(device)
            att_mask = data[2].to(device)
            labels = torch.arange(data[0].shape[0], device=device)

            prediction, _ = model(img_input={"x": imgs},
                                  text_input={"x": texts, "attention_mask": att_mask})
            loss_value = loss(prediction, labels)

            optimizer.zero_grad()
            loss_value.backward()
            optimizer.step()

            _, preds = torch.max(prediction.detach(), 1)

            # Accumulate plain Python numbers so no autograd history or device
            # tensors are kept alive across the epoch.
            correct_samples += int(torch.sum(preds == labels))
            loss_accum += loss_value.item()
            total_samples += labels.shape[0]
            num_batches += 1

        if num_batches == 0:
            raise ValueError("train_model: train_loader produced no batches")

        ave_loss = loss_accum / num_batches
        train_accuracy = correct_samples / total_samples
        writer.add_scalar("Loss/train", ave_loss, epoch)
        writer.add_scalar("Acc/train", train_accuracy, epoch)

        val_accuracy, loss_val, metric_y_val, metric_p_val = compute_valid(
            model, val_loader, loss, epoch, metric_y_val, metric_p_val)
        val_accuracy = float(val_accuracy)
        loss_val = float(loss_val)
        writer.add_scalar("Loss/valid", loss_val, epoch)
        writer.add_scalar("Acc/valid", val_accuracy, epoch)

        # Log the learning rate that was used during this epoch, then advance
        # the schedule by exactly one epoch for the next one.
        writer.add_scalar("Lr/epoch", scheduler.get_last_lr()[-1], epoch)
        scheduler.step()

        # Compare against the best loss of the *previous* epochs; comparing with
        # a list that already contains the new value could never succeed.
        is_best = not val_loss_hist or loss_val < min(val_loss_hist)
        val_loss_hist.append(loss_val)
        if epoch >= 2 and is_best:
            torch.save(model.state_dict(), f"GPTsmall_epoch_{epoch}.pt")

        print("Average loss: %f, Val loss: %f, Train accuracy: %f, Val accuracy: %f" % (ave_loss,loss_val, train_accuracy, val_accuracy))
        print('Epoch:', epoch, 'LR:', scheduler.get_last_lr())
    return metric_y_val, metric_p_val

def compute_valid(model, loader, loss, epoch, metric_y=None, metric_p=None):
    """Evaluate ``model`` on ``loader`` without gradients and collect predictions.

    Batches are indexed as ``data[0]`` (images; dimension 1 is squeezed out),
    ``data[1]`` (text input) and ``data[2]`` (attention mask). The target for
    row ``i`` of a batch is the class index ``i`` (``torch.arange(batch_size)``).
    Batches are moved to the device of the model's first parameter (CUDA is
    used as a fallback when the model exposes no parameters).

    Args:
        model: Module called as ``model(img_input={"x": ...},
            text_input={"x": ..., "attention_mask": ...})`` and returning a pair
            whose first element is passed to ``loss``. It is left in eval mode.
        loader: Iterable of validation batches.
        loss: Callable ``loss(prediction, labels)`` returning a scalar tensor.
        epoch: Current epoch number. On epoch 0 the metric arrays are restarted
            from the first batch instead of being extended.
        metric_y: Labels accumulated by earlier calls, or ``None`` to start fresh.
        metric_p: Predictions accumulated by earlier calls, or ``None`` to start fresh.

    Returns:
        ``(val_accuracy, loss_val, metric_y, metric_p)``: accuracy over all
        samples and mean per-batch loss (both 0-dim tensors), followed by the
        NumPy arrays of labels and predicted indices extended with this pass.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()

    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    correct_samples = 0
    total_samples = 0
    loss_accum = 0
    num_batches = 0

    with torch.no_grad():
        for i_step, data in enumerate(loader):
            imgs = torch.squeeze(data[0].to(device), 1)
            texts = data[1].to(device)
            att_mask = data[2].to(device)
            labels = torch.arange(data[0].shape[0], device=device)

            prediction, _ = model(img_input={"x": imgs},
                                  text_input={"x": texts, "attention_mask": att_mask})

            loss_value = loss(prediction, labels)
            _, preds = torch.max(prediction, 1)

            batch_y = labels.cpu().numpy()
            batch_p = preds.cpu().numpy()
            # Start new arrays on the first batch of epoch 0, and also whenever
            # nothing has been accumulated yet (e.g. a standalone call with the
            # default None arguments and epoch != 0).
            if (i_step == 0 and epoch == 0) or metric_y is None or metric_p is None:
                metric_y, metric_p = batch_y, batch_p
            else:
                metric_y = np.concatenate((metric_y, batch_y))
                metric_p = np.concatenate((metric_p, batch_p))

            correct_samples += torch.sum(preds == labels)
            total_samples += labels.shape[0]
            loss_accum += loss_value
            num_batches += 1

    if num_batches == 0:
        raise ValueError("compute_valid: loader produced no batches")

    loss_val = loss_accum / num_batches
    val_accuracy = correct_samples / total_samples
    return val_accuracy, loss_val, metric_y, metric_p

In [5]:
# load_model() returns three objects, unpacked here as (presumably) the model,
# the image transform and the text tokenizer -- confirm against load_model.py.
# NOTE(review): "img_transfrom" is a misspelling of "img_transform"; the name is
# kept because the dataset cell below references it.
model, img_transfrom, text_tokenizer = load_model()



In [None]:
path_t_csv = pd.read_csv("data/200k_data_preproc.csv")# NOTE: despite the name, this holds the loaded DataFrame, not a path
path_i_json = "data/images.json"
path_i_folder = "data/images"


# The CSV file must contain an array for each class.


# Build the dataset from the loaded table, the image JSON path, the image
# folder, and the tokenizer/transform returned by load_model().
# NOTE(review): down_data / check_img presumably control downloading and
# validating images -- confirm in CustomDataset.Sent2textDataset.
ds = Sent2textDataset(path_t_csv,path_i_json,
                      path_i_folder,
                      text_tokenizer,img_transfrom,
                      down_data = True,check_img = True,
                      n_classes = 5)

5462418it [00:11, 466104.95it/s]


Oyy response empty, miss 5796719
Oyy, miss 4109461
Oyy, miss 6023162
Oyy response empty, miss 2776781
Oyy, miss 1357095
Oyy response empty, miss 3999838
Oyy response empty, miss 1839073
Oyy, miss 2565899
Oyy response empty, miss 633849
Oyy response empty, miss 5361644
Oyy response empty, miss 4896271
Oyy, miss 3880284
Oyy response empty, miss 2046252
Oyy response empty, miss 6220633
Oyy response empty, miss 1400127
Oyy response empty, miss 5244823
Oyy response empty, miss 6139374
Oyy response empty, miss 3534040
Oyy response empty, miss 1927122
Oyy response empty, miss 1827343
Oyy response empty, miss 4941340
Oyy, miss 1980876
Oyy response empty, miss 3695509
Oyy, miss 5832403
Oyy, miss 1466996
Oyy response empty, miss 548321
Oyy, miss 6210434
Oyy response empty, miss 2989158
Oyy response empty, miss 5047225
Oyy, miss 1393991
Oyy response empty, miss 3315244
Oyy, miss 1093804
Oyy, miss 5487770
Oyy response empty, miss 4512688
Oyy response empty, miss 5492334
Oyy response empty, miss 26

Oyy response empty, miss 4552910
Oyy response empty, miss 5768356
Oyy response empty, miss 4323560
Oyy response empty, miss 1638721
Oyy response empty, miss 2986742
Oyy response empty, miss 4893451
Oyy, miss 3006553
Oyy response empty, miss 2869233
Oyy, miss 947997
Oyy response empty, miss 1723191
Oyy response empty, miss 537697
Oyy response empty, miss 2819576
Oyy response empty, miss 2471477
Oyy response empty, miss 1033474
Oyy, miss 3057680
Oyy response empty, miss 2056451
Oyy response empty, miss 5916407
Oyy response empty, miss 804237
Oyy response empty, miss 1060070
Oyy response empty, miss 2024046
Oyy response empty, miss 739737
Oyy response empty, miss 4145617
Oyy response empty, miss 3855383
Oyy response empty, miss 1568524
Oyy response empty, miss 2988809
Oyy response empty, miss 4097888
Oyy, miss 6034877
Oyy, miss 1998112
Oyy, miss 3462392
Oyy response empty, miss 1416731
Oyy, miss 4765483
Oyy, miss 5018138
Oyy response empty, miss 6105533
Oyy response empty, miss 5588425
Oy

Oyy, miss 5769145
Oyy, miss 625154
Oyy response empty, miss 768147
Oyy, miss 3544002
Oyy, miss 5084501
Oyy, miss 1515447
Oyy, miss 5166346
Oyy, miss 3285368
Oyy, miss 3800158
Oyy, miss 2643684
Oyy, miss 5567636
Oyy, miss 592716
Oyy response empty, miss 3269467
Oyy, miss 644976
Oyy, miss 1843449
Oyy, miss 1504774
Oyy, miss 1366943
Oyy, miss 4267223
Oyy, miss 199674
Oyy, miss 4799224
Oyy, miss 4989471
Oyy, miss 5372613
Oyy, miss 852577
Oyy, miss 99485
Oyy, miss 3603837
Oyy, miss 6090708
Oyy, miss 5642052
Oyy, miss 3398784
Oyy, miss 3942513
Oyy, miss 1510765
Oyy, miss 309961
Oyy, miss 4200646
Oyy response empty, miss 5298409
Oyy, miss 6169125
Oyy, miss 654185
Oyy, miss 3338520
Oyy, miss 4801412
Oyy, miss 3963057
Oyy, miss 1416841
Oyy, miss 1934523
Oyy response empty, miss 2169906
Oyy, miss 91961
Oyy, miss 5038607
Oyy, miss 3846810
Oyy, miss 4499243
Oyy, miss 2411366
Oyy, miss 5938787
Oyy response empty, miss 3366986
Oyy, miss 2508021
Oyy, miss 1742692
Oyy, miss 1009476
Oyy, miss 4646478
O

Oyy response empty, miss 548782
Oyy, miss 4713202
Oyy response empty, miss 2392300
Oyy response empty, miss 5846022
Oyy response empty, miss 582308
Oyy response empty, miss 5173286
Oyy response empty, miss 2085669
Oyy response empty, miss 5468022
Oyy response empty, miss 5741522
Oyy response empty, miss 2896802
Oyy response empty, miss 5616426
Oyy, miss 5150227
Oyy response empty, miss 4243815
Oyy response empty, miss 2831094
Oyy response empty, miss 2445667
Oyy response empty, miss 4250893
Oyy, miss 2651068
Oyy response empty, miss 5670250
Oyy response empty, miss 945200
Oyy response empty, miss 589372
Oyy response empty, miss 1777637
Oyy response empty, miss 1598817
Oyy response empty, miss 4549367
Oyy, miss 5818213
Oyy, miss 5275888
Oyy response empty, miss 5280413
Oyy response empty, miss 930345
Oyy, miss 2102791
Oyy, miss 1232554
Oyy, miss 5884384
Oyy, miss 1526732
Oyy response empty, miss 4517850
Oyy response empty, miss 2915096
Oyy, miss 1864947
Oyy response empty, miss 2559307


Oyy response empty, miss 5498994
Oyy response empty, miss 5320780
Oyy, miss 1894426
Oyy, miss 4484824
Oyy response empty, miss 4456023
Oyy response empty, miss 2657099
Oyy, miss 894829
Oyy, miss 5219920
Oyy response empty, miss 267593
Oyy, miss 4025358
Oyy response empty, miss 4417609
Oyy response empty, miss 4939466
Oyy, miss 2919933
Oyy response empty, miss 2291476
Oyy response empty, miss 3892618
Oyy, miss 5278985
Oyy, miss 3817477
Oyy response empty, miss 2430898
Oyy, miss 4840494
Oyy response empty, miss 334817
Oyy, miss 2742647
Oyy response empty, miss 948710
Oyy response empty, miss 1698765
Oyy, miss 3437491
Oyy response empty, miss 3759328
Oyy, miss 2624112
Oyy response empty, miss 801717
Oyy response empty, miss 1354462
Oyy response empty, miss 6032897
Oyy, miss 4254553
Oyy, miss 4923381
Oyy response empty, miss 1426257
Oyy response empty, miss 1818444
Oyy response empty, miss 4005657
Oyy response empty, miss 3654087
Oyy response empty, miss 2052287
Oyy response empty, miss 46

Oyy response empty, miss 3732104
Oyy response empty, miss 649497
Oyy, miss 5340864Oyy response empty, miss 1148192

Oyy response empty, miss 4129607
Oyy response empty, miss 4143810
Oyy response empty, miss 4021018
Oyy, miss 1377368
Oyy, miss 1978226
Oyy, miss 1042776
Oyy, miss 1762778
Oyy response empty, miss 1694389
Oyy response empty, miss 2089420
Oyy response empty, miss 4470202
Oyy response empty, miss 6119963
Oyy response empty, miss 3278711
Oyy, miss 670194
Oyy, miss 4757489
Oyy, miss 2729489
Oyy response empty, miss 4773114
Oyy response empty, miss 352381
Oyy, miss 4756294
Oyy, miss 1076908
Oyy response empty, miss 3407875
Oyy, miss 4947810
Oyy response empty, miss 3949763
Oyy response empty, miss 5574460
Oyy, miss 4929761
Oyy, miss 3022124
Oyy response empty, miss 3304959
Oyy, miss 3343318
Oyy response empty, miss 1213262
Oyy, miss 3072474
Oyy response empty, miss 3150690
Oyy response empty, miss 1693993
Oyy response empty, miss 5868255
Oyy response empty, miss 308954
Oyy resp

Oyy response empty, miss 5091938
Oyy response empty, miss 226054
Oyy response empty, miss 1637032
Oyy response empty, miss 3009053
Oyy, miss 648225
Oyy response empty, miss 5085298
Oyy, miss 5700519
Oyy response empty, miss 1925763
Oyy, miss 3219869
Oyy, miss 2451545
Oyy response empty, miss 5541222
Oyy, miss 3695178
Oyy response empty, miss 5527291
Oyy, miss 4203192
Oyy response empty, miss 2058398
Oyy, miss 2455316
Oyy response empty, miss 3371148
Oyy response empty, miss 989033
Oyy response empty, miss 2541555
Oyy, miss 3447981
Oyy response empty, miss 3349725
Oyy, miss 2647962
Oyy response empty, miss 5108769
Oyy response empty, miss 5014630
Oyy response empty, miss 1884713
Oyy, miss 6103481
Oyy response empty, miss 3818615
Oyy response empty, miss 695124
Oyy, miss 1244202
Oyy, miss 2477866
Oyy, miss 1798315
Oyy response empty, miss 6123390
Oyy response empty, miss 1312853
Oyy response empty, miss 3597139
Oyy response empty, miss 1473972
Oyy, miss 2048535
Oyy response empty, miss 2

Oyy response empty, miss 2169563
Oyy response empty, miss 5490210
Oyy, miss 1307893
Oyy response empty, miss 2814015Oyy response empty, miss 5231286

Oyy, miss 5797489
Oyy response empty, miss 731930
Oyy response empty, miss 830908
Oyy, miss 184421
Oyy, miss 123253
Oyy, miss 2919538
Oyy, miss 5827454
Oyy response empty, miss 2887123
Oyy response empty, miss 2121497
Oyy response empty, miss 5077622
Oyy, miss 2702192
Oyy response empty, miss 4157752
Oyy, miss 5372736
Oyy response empty, miss 1823527
Oyy, miss 1923147
Oyy response empty, miss 2577393
Oyy, miss 3649354
Oyy, miss 3142070
Oyy, miss 3773753
Oyy, miss 222704
Oyy, miss 4721472
Oyy, miss 3720212
Oyy response empty, miss 3540806
Oyy response empty, miss 270673
Oyy, miss 3266830
Oyy response empty, miss 3754863
Oyy response empty, miss 3347618
Oyy response empty, miss 717126
Oyy response empty, miss 1287506
Oyy, miss 6166663
Oyy, miss 4027755
Oyy response empty, miss 5242049
Oyy, miss 3003355
Oyy response empty, miss 4837571
Oyy r

Oyy, miss 6042387
Oyy response empty, miss 3000849
Oyy response empty, miss 3570070
Oyy response empty, miss 2166741
Oyy, miss 4483642
Oyy, miss 5687362
Oyy response empty, miss 4266071
Oyy response empty, miss 572091
Oyy response empty, miss 1359799
Oyy, miss 4860415
Oyy response empty, miss 3223571Oyy response empty, miss 1872791

Oyy response empty, miss 3518728
Oyy response empty, miss 2546293
Oyy response empty, miss 5176546
Oyy, miss 1760648
Oyy, miss 5407853
Oyy response empty, miss 5657396
Oyy, miss 5837368
Oyy, miss 5704820
Oyy, miss 4531032
Oyy, miss 26641
Oyy, miss 474213
Oyy, miss 6027392
Oyy response empty, miss 5557311
Oyy response empty, miss 5751410
Oyy, miss 400288
Oyy, miss 2002986
Oyy response empty, miss 865878
Oyy response empty, miss 6124376
Oyy response empty, miss 3925291
Oyy, miss 2839596
Oyy response empty, miss 3825432
Oyy, miss 3413873
Oyy response empty, miss 3719885
Oyy response empty, miss 5417012
Oyy response empty, miss 434845
Oyy response empty, miss 1

Oyy response empty, miss 867094
Oyy response empty, miss 4929276
Oyy response empty, miss 3437090
Oyy response empty, miss 4540441
Oyy response empty, miss 6003716
Oyy, miss 1037081
Oyy response empty, miss 3007970
Oyy response empty, miss 1728194
Oyy, miss 5437180
Oyy response empty, miss 5504655
Oyy, miss 6123866
Oyy response empty, miss 4520421
Oyy, miss 2880903
Oyy, miss 2136454
Oyy response empty, miss 3257971
Oyy response empty, miss 4638405
Oyy response empty, miss 3818413
Oyy response empty, miss 2653732
Oyy response empty, miss 1399396
Oyy response empty, miss 3452455
Oyy, miss 2470704
Oyy, miss 573048
Oyy, miss 1011407
Oyy response empty, miss 3654859
Oyy response empty, miss 3205653
Oyy response empty, miss 6046804
Oyy response empty, miss 1174611
Oyy, miss 2932086
Oyy response empty, miss 3785928
Oyy response empty, miss 5595818
Oyy response empty, miss 6057521
Oyy, miss 3305350
Oyy, miss 1796358
Oyy response empty, miss 402974
Oyy response empty, miss 5897389
Oyy, miss 523

Oyy, miss 5278416
Oyy, miss 496272
Oyy response empty, miss 2892969
Oyy, miss 2754783
Oyy response empty, miss 2605717
Oyy, miss 5000110
Oyy, miss 5395411
Oyy, miss 6120274
Oyy, miss 2485587
Oyy, miss 4818973
Oyy response empty, miss 2866157
Oyy response empty, miss 1433426
Oyy response empty, miss 4550287
Oyy response empty, miss 1162755
Oyy response empty, miss 4637623
Oyy, miss 1598572
Oyy, miss 4555853
Oyy response empty, miss 5528790
Oyy, miss 2806728
Oyy, miss 696103
Oyy, miss 5022750
Oyy response empty, miss 3720240
Oyy response empty, miss 5641546
Oyy, miss 978415
Oyy response empty, miss 407168
Oyy, miss 4772586Oyy response empty, miss 4165582

Oyy response empty, miss 1726925
Oyy, miss 3218633
Oyy response empty, miss 352857
Oyy, miss 12936
Oyy response empty, miss 1904837
Oyy response empty, miss 1276300
Oyy response empty, miss 2839317
Oyy, miss 756115
Oyy, miss 3801074
Oyy response empty, miss 2074272
Oyy, miss 876824
Oyy response empty, miss 2231737
Oyy response empty, mi

In [7]:
from torch.nn import CrossEntropyLoss
from torch.optim import Adam

In [8]:
# Move the model to the GPU in float32, then choose which parameters are trained.
model = model.cuda().float()
for name, param in model.named_parameters():
    # Freeze everything that belongs to either encoder ...
    if "visual_encoder" in name or "text_encoder" in name:
        param.requires_grad = False
    # ... then re-enable the projection layers and the logit scale.
    if "projection" in name or name == "logit_scale":
        param.requires_grad = True

# Cell output: names of the parameters that remain trainable.
[p_name for p_name, p in model.named_parameters() if p.requires_grad]

['logit_scale',
 'visual_encoder.projection.linear1.weight',
 'visual_encoder.projection.linear2.weight',
 'visual_encoder.projection.layer_norm.weight',
 'visual_encoder.projection.layer_norm.bias',
 'text_encoder.projection.linear1.weight',
 'text_encoder.projection.linear2.weight',
 'text_encoder.projection.layer_norm.weight',
 'text_encoder.projection.layer_norm.bias']

In [9]:
# Cross-entropy objective; Adam is given only the parameters left trainable above.
loss = CrossEntropyLoss()
optimizer = Adam(
    [weight for weight in model.parameters() if weight.requires_grad],
    lr=1e-3,
    weight_decay=5e-4,
)


In [10]:
# Build the train/validation loaders with 20 samples per batch.
# NOTE(review): sempler's default split=.8 assigns the first 80% of the shuffled
# indices to the validation loader and only the remaining 20% to training --
# confirm this is intended.
train_dl,val_dl =  sempler(ds, batch_size=20)

In [13]:
# Train for 64 epochs. The writer is flushed and closed in `finally` so the
# TensorBoard event file is complete even when training is interrupted
# (e.g. KeyboardInterrupt) or raises.
writer = SummaryWriter()
try:
    result = train_model(model, train_dl, val_dl, loss, optimizer, 64, writer)
finally:
    writer.flush()
    writer.close()




0


KeyboardInterrupt: 