In [6]:
from datetime import datetime
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import time
# from torch.cuda.amp import autocast
import numpy as np
from CVUSA_dataset import CVUSA_dataset_cropped, CVUSA_Dataset_Eval
# from CVUSA_dataset import CVUSA_Dataset_Eval
from custom_models import ResNet, VIT, CLIP_model
from losses import Contrastive_loss, SoftTripletBiLoss, InfoNCE
from train import train
from eval import predict, accuracy, calculate_scores
import torch.nn.functional as F
import copy
import math
from pytorch_metric_learning import losses as LS
from helper_func import get_rand_id, hyparam_info, save_losses
from transformers import CLIPProcessor




# data_path = '/media/fahimul/2B721C03261BDC8D/Research/datasets/CVUSA' #don't include the / at the end
# data_path = '/home/fa947945/datasets/CVUSA_Cropped/CVUSA' #don't include the / at the end
data_path = '/data/Research/Dataset/CVUSA_Cropped/CVUSA' #don't include the / at the end

# train_data= pd.read_csv(f'{data_path}/splits/train-19zl.csv')
train_data= pd.read_csv(f'{data_path}/splits/train-19zl_30.csv')
val_data= pd.read_csv(f'{data_path}/splits/val-19zl.csv')

# df_loss = pd.DataFrame(columns=['Loss'])

transform = transforms.Compose([
    # transforms.Resize((224, 224)),
    transforms.RandomCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                         std=[0.229, 0.224, 0.225]),
])


train_ds = CVUSA_dataset_cropped(df = train_data, path=data_path, transform=transform)
val_que = CVUSA_Dataset_Eval(data_folder=data_path, split='val', img_type='query', transforms=transform)
val_ref = CVUSA_Dataset_Eval(data_folder=data_path, split='val', img_type='reference', transforms=transform)





def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    embed_dim = 1000
    lr = 0.00001
    batch_size = 64
    epochs = 100
    expID = get_rand_id()
    loss_margin = 1


    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=False)
    val_loader_que = DataLoader(val_que, batch_size=batch_size, shuffle=False)
    val_loader_ref = DataLoader(val_ref, batch_size=batch_size, shuffle=False)

    os.mkdir(f'model_weights/{expID}')

    # model = ResNet(emb_dim=embed_dim).to(device)
    # model_r = ResNet(emb_dim=embed_dim).to(device)
    # model_q = ResNet(emb_dim=embed_dim).to(device)

    # model = ResNet().to(device)
    # model = VIT().to(device)
    model = CLIP_model(embed_dim=embed_dim)
    
    # torch.save(model, f'model_weights/{expID}/model_st.pth')

    # criterion = TripletLoss(margin=loss_margin)
    # criterion = nn.TripletMarginLoss(margin=0.5)
  
    # criterion = SoftTripletBiLoss()

    loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
    criterion = InfoNCE(loss_function=loss_fn,
                            device=device,
                            )



    parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
    # for name, param in model.named_parameters():
    #     if param.requires_grad:
    #         print(name)
    optimizer = optim.Adam(parameters, lr=lr)
    # optimizer = optim.AdamW(parameters, lr=lr)
    # optimizer = optim.SGD(parameters, lr=lr)

    
    
    hyparam_info(emb_dim=embed_dim, loss_id=expID, ln_rate=lr, batch=batch_size, epc=epochs, ls_mrgn=loss_margin, trn_sz=train_data.shape[0], mdl_nm=model.modelName)

    print("Training Start")
    all_loses = train(model, criterion, optimizer, train_loader, num_epochs=epochs, dev=device)
    df_loss = pd.DataFrame({'Loss': all_loses})
    df_loss.to_csv(f'losses/losses_{expID}.csv')



    print("\nExtract Features:")
    query_features, query_labels = predict(model=model, dataloader=val_loader_que, dev=device, isQuery=True)
    reference_features, reference_labels = predict(model = model, dataloader=val_loader_ref, dev=device, isQuery=False) 
    


    print("Compute Scores:")
    # r1 =  calculate_scores(query_features, reference_features, query_labels, reference_labels, step_size=1000, ranks=[1, 5, 10])
    r1 =  accuracy(query_features=query_features, reference_features=reference_features, query_labels=query_labels, topk=[1, 5, 10])
    
    save_losses(df=df_loss, 
                emb_dim=embed_dim, 
                loss_id=expID, 
                ln_rate=lr, 
                batch=batch_size, 
                epc=epochs, 
                ls_mrgn=loss_margin, 
                trn_sz=train_data.shape[0],
                mdl_nm=model.modelName,
                rslt=r1)


    print(r1) 
        

    torch.save(model, f'model_weights/{expID}/model_tr.pth')






if __name__ == '__main__':
    main()




Device: cuda





Hyperparameter info:
Exp ID: 6318896
Embedded dimension: 1000
Learning rate: 1e-05
Batch Size: 64
Loss Margin: 1
Epoch: 100
Training Size: 10658
Model Name: CLIP


Training Start

Date: 2024-05-21 15:14:57.089868

Epoch#0


100%|██████████| 167/167 [04:22<00:00,  1.57s/it]


Epoch: 1/100 Loss: 3.0422379970550537
Epoch#1


100%|██████████| 167/167 [04:16<00:00,  1.53s/it]


Epoch: 2/100 Loss: 2.707144260406494
Epoch#2


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 3/100 Loss: 2.6057398319244385
Epoch#3


100%|██████████| 167/167 [04:02<00:00,  1.45s/it]


Epoch: 4/100 Loss: 2.5404345989227295
Epoch#4


100%|██████████| 167/167 [04:02<00:00,  1.45s/it]


Epoch: 5/100 Loss: 2.4713857173919678
Epoch#5


100%|██████████| 167/167 [03:58<00:00,  1.43s/it]


Epoch: 6/100 Loss: 2.447958469390869
Epoch#6


100%|██████████| 167/167 [04:01<00:00,  1.44s/it]


Epoch: 7/100 Loss: 2.431533098220825
Epoch#7


100%|██████████| 167/167 [03:59<00:00,  1.44s/it]


Epoch: 8/100 Loss: 2.38909912109375
Epoch#8


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 9/100 Loss: 2.3742265701293945
Epoch#9


100%|██████████| 167/167 [04:01<00:00,  1.44s/it]


Epoch: 10/100 Loss: 2.372711420059204
Epoch#10


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 11/100 Loss: 2.3430628776550293
Epoch#11


100%|██████████| 167/167 [03:59<00:00,  1.43s/it]


Epoch: 12/100 Loss: 2.3155221939086914
Epoch#12


100%|██████████| 167/167 [03:59<00:00,  1.44s/it]


Epoch: 13/100 Loss: 2.299579620361328
Epoch#13


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 14/100 Loss: 2.28552508354187
Epoch#14


100%|██████████| 167/167 [04:01<00:00,  1.44s/it]


Epoch: 15/100 Loss: 2.2816250324249268
Epoch#15


100%|██████████| 167/167 [04:03<00:00,  1.46s/it]


Epoch: 16/100 Loss: 2.2694921493530273
Epoch#16


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 17/100 Loss: 2.248924493789673
Epoch#17


100%|██████████| 167/167 [03:59<00:00,  1.44s/it]


Epoch: 18/100 Loss: 2.236384391784668
Epoch#18


100%|██████████| 167/167 [03:59<00:00,  1.44s/it]


Epoch: 19/100 Loss: 2.220731735229492
Epoch#19


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 20/100 Loss: 2.2112810611724854
Epoch#20


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 21/100 Loss: 2.207322835922241
Epoch#21


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 22/100 Loss: 2.20607852935791
Epoch#22


100%|██████████| 167/167 [04:01<00:00,  1.44s/it]


Epoch: 23/100 Loss: 2.196237802505493
Epoch#23


100%|██████████| 167/167 [04:01<00:00,  1.44s/it]


Epoch: 24/100 Loss: 2.1787350177764893
Epoch#24


100%|██████████| 167/167 [04:01<00:00,  1.44s/it]


Epoch: 25/100 Loss: 2.1718051433563232
Epoch#25


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 26/100 Loss: 2.164632558822632
Epoch#26


100%|██████████| 167/167 [04:01<00:00,  1.45s/it]


Epoch: 27/100 Loss: 2.1537537574768066
Epoch#27


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 28/100 Loss: 2.146491050720215
Epoch#28


100%|██████████| 167/167 [03:59<00:00,  1.43s/it]


Epoch: 29/100 Loss: 2.1498870849609375
Epoch#29


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 30/100 Loss: 2.1483380794525146
Epoch#30


100%|██████████| 167/167 [03:59<00:00,  1.43s/it]


Epoch: 31/100 Loss: 2.1434521675109863
Epoch#31


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 32/100 Loss: 2.1423890590667725
Epoch#32


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 33/100 Loss: 2.1348018646240234
Epoch#33


100%|██████████| 167/167 [04:01<00:00,  1.44s/it]


Epoch: 34/100 Loss: 2.1220738887786865
Epoch#34


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 35/100 Loss: 2.1077494621276855
Epoch#35


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 36/100 Loss: 2.101602077484131
Epoch#36


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 37/100 Loss: 2.098109483718872
Epoch#37


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 38/100 Loss: 2.096135139465332
Epoch#38


100%|██████████| 167/167 [04:01<00:00,  1.45s/it]


Epoch: 39/100 Loss: 2.0971009731292725
Epoch#39


100%|██████████| 167/167 [04:01<00:00,  1.45s/it]


Epoch: 40/100 Loss: 2.095500946044922
Epoch#40


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 41/100 Loss: 2.0906169414520264
Epoch#41


100%|██████████| 167/167 [04:01<00:00,  1.45s/it]


Epoch: 42/100 Loss: 2.0825183391571045
Epoch#42


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 43/100 Loss: 2.0768864154815674
Epoch#43


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 44/100 Loss: 2.074903726577759
Epoch#44


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 45/100 Loss: 2.0694923400878906
Epoch#45


100%|██████████| 167/167 [03:59<00:00,  1.44s/it]


Epoch: 46/100 Loss: 2.0648975372314453
Epoch#46


100%|██████████| 167/167 [03:59<00:00,  1.44s/it]


Epoch: 47/100 Loss: 2.057082414627075
Epoch#47


100%|██████████| 167/167 [04:01<00:00,  1.45s/it]


Epoch: 48/100 Loss: 2.0511910915374756
Epoch#48


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 49/100 Loss: 2.0500054359436035
Epoch#49


100%|██████████| 167/167 [04:01<00:00,  1.45s/it]


Epoch: 50/100 Loss: 2.0517473220825195
Epoch#50


100%|██████████| 167/167 [03:59<00:00,  1.44s/it]


Epoch: 51/100 Loss: 2.0524280071258545
Epoch#51


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 52/100 Loss: 2.053260564804077
Epoch#52


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 53/100 Loss: 2.0509254932403564
Epoch#53


100%|██████████| 167/167 [04:01<00:00,  1.45s/it]


Epoch: 54/100 Loss: 2.0546786785125732
Epoch#54


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 55/100 Loss: 2.0536704063415527
Epoch#55


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 56/100 Loss: 2.0426247119903564
Epoch#56


100%|██████████| 167/167 [04:01<00:00,  1.44s/it]


Epoch: 57/100 Loss: 2.0376675128936768
Epoch#57


100%|██████████| 167/167 [04:01<00:00,  1.45s/it]


Epoch: 58/100 Loss: 2.0270206928253174
Epoch#58


100%|██████████| 167/167 [04:01<00:00,  1.45s/it]


Epoch: 59/100 Loss: 2.0174524784088135
Epoch#59


100%|██████████| 167/167 [04:01<00:00,  1.44s/it]


Epoch: 60/100 Loss: 2.013686180114746
Epoch#60


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 61/100 Loss: 2.011704444885254
Epoch#61


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 62/100 Loss: 2.016566753387451
Epoch#62


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 63/100 Loss: 2.0349183082580566
Epoch#63


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 64/100 Loss: 2.03066086769104
Epoch#64


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 65/100 Loss: 2.0166547298431396
Epoch#65


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 66/100 Loss: 2.010258436203003
Epoch#66


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 67/100 Loss: 2.0060670375823975
Epoch#67


100%|██████████| 167/167 [04:02<00:00,  1.45s/it]


Epoch: 68/100 Loss: 2.005100727081299
Epoch#68


100%|██████████| 167/167 [04:01<00:00,  1.45s/it]


Epoch: 69/100 Loss: 2.0035455226898193
Epoch#69


100%|██████████| 167/167 [04:01<00:00,  1.44s/it]


Epoch: 70/100 Loss: 2.001235008239746
Epoch#70


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 71/100 Loss: 1.9989218711853027
Epoch#71


100%|██████████| 167/167 [04:01<00:00,  1.45s/it]


Epoch: 72/100 Loss: 1.9966119527816772
Epoch#72


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 73/100 Loss: 1.995327115058899
Epoch#73


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 74/100 Loss: 1.9956880807876587
Epoch#74


100%|██████████| 167/167 [04:01<00:00,  1.45s/it]


Epoch: 75/100 Loss: 1.998033881187439
Epoch#75


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 76/100 Loss: 2.00451397895813
Epoch#76


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 77/100 Loss: 2.0427615642547607
Epoch#77


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 78/100 Loss: 2.024425506591797
Epoch#78


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 79/100 Loss: 2.0323123931884766
Epoch#79


100%|██████████| 167/167 [04:01<00:00,  1.45s/it]


Epoch: 80/100 Loss: 2.0282793045043945
Epoch#80


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 81/100 Loss: 1.9957841634750366
Epoch#81


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 82/100 Loss: 1.9774320125579834
Epoch#82


100%|██████████| 167/167 [04:01<00:00,  1.45s/it]


Epoch: 83/100 Loss: 1.9704527854919434
Epoch#83


100%|██████████| 167/167 [04:01<00:00,  1.44s/it]


Epoch: 84/100 Loss: 1.967181921005249
Epoch#84


100%|██████████| 167/167 [03:59<00:00,  1.43s/it]


Epoch: 85/100 Loss: 1.9669770002365112
Epoch#85


100%|██████████| 167/167 [04:01<00:00,  1.44s/it]


Epoch: 86/100 Loss: 1.9700608253479004
Epoch#86


100%|██████████| 167/167 [04:01<00:00,  1.45s/it]


Epoch: 87/100 Loss: 1.975066900253296
Epoch#87


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 88/100 Loss: 1.9823704957962036
Epoch#88


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 89/100 Loss: 1.9857062101364136
Epoch#89


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 90/100 Loss: 1.9835751056671143
Epoch#90


100%|██████████| 167/167 [04:01<00:00,  1.45s/it]


Epoch: 91/100 Loss: 1.9807987213134766
Epoch#91


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 92/100 Loss: 1.980095624923706
Epoch#92


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 93/100 Loss: 1.9824103116989136
Epoch#93


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 94/100 Loss: 1.9792791604995728
Epoch#94


100%|██████████| 167/167 [03:58<00:00,  1.43s/it]


Epoch: 95/100 Loss: 1.97551429271698
Epoch#95


100%|██████████| 167/167 [04:01<00:00,  1.45s/it]


Epoch: 96/100 Loss: 1.973526120185852
Epoch#96


100%|██████████| 167/167 [04:00<00:00,  1.44s/it]


Epoch: 97/100 Loss: 1.9728883504867554
Epoch#97


100%|██████████| 167/167 [04:01<00:00,  1.45s/it]


Epoch: 98/100 Loss: 1.9727100133895874
Epoch#98


100%|██████████| 167/167 [04:03<00:00,  1.46s/it]


Epoch: 99/100 Loss: 1.9749112129211426
Epoch#99


100%|██████████| 167/167 [03:59<00:00,  1.44s/it]


Epoch: 100/100 Loss: 2.039949655532837

Date: 2024-05-21 21:56:50.592842


Extract Features:


100%|██████████| 139/139 [00:18<00:00,  7.55it/s]
100%|██████████| 139/139 [01:08<00:00,  2.04it/s]


Compute Scores:
Percentage-top1:0.7203962179198559, top5:2.8590724898694284, top10:5.481764970733904, top1%:25.529040972534894, time:0.2287003993988037
[ 0.72039622  2.85907249  5.48176497 25.52904097]
