In [1]:
from datetime import datetime
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import time
# from torch.cuda.amp import autocast
import numpy as np
from CVUSA_dataset import CVUSA_dataset_cropped, CVUSA_Dataset_Eval
# from CVUSA_dataset import CVUSA_Dataset_Eval
from custom_models import ResNet, VIT, CLIP_model
from losses import Contrastive_loss, SoftTripletBiLoss, InfoNCE
from train import train
from eval import predict, accuracy, calculate_scores
import torch.nn.functional as F
import copy
import math
from pytorch_metric_learning import losses as LS
from helper_func import get_rand_id, hyparam_info, save_exp, write_to_file, write_to_rank_file
from transformers import CLIPProcessor
from attributes import Configuration as hypm




# data_path = '/media/fahimul/2B721C03261BDC8D/Research/datasets/CVUSA' #don't include the / at the end
# data_path = '/home/fa947945/datasets/CVUSA_Cropped/CVUSA' #don't include the / at the end
data_path = '/data/Research/Dataset/CVUSA_Cropped/CVUSA' #don't include the / at the end

train_data= pd.read_csv(f'{data_path}/splits/train-19zl.csv', header=None)
# train_data= pd.read_csv(f'{data_path}/splits/train-19zl_5.csv', header=None)
# train_data= pd.read_csv(f'{data_path}/splits/train-19zl_30.csv', header=None)

val_data= pd.read_csv(f'{data_path}/splits/val-19zl.csv', header=None)

# df_loss = pd.DataFrame(columns=['Loss'])

transform = transforms.Compose([
    # transforms.Resize((224, 224)),
    transforms.RandomCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                         std=[0.229, 0.224, 0.225]),
])


train_ds = CVUSA_dataset_cropped(df = train_data, path=data_path, transform=transform, train=True, lang=hypm.lang)
val_ds = CVUSA_dataset_cropped(df = val_data, path=data_path, transform=transform, train=False, lang=hypm.lang)

# val_que = CVUSA_Dataset_Eval(data_folder=data_path, split='val', img_type='query', transforms=transform)
# val_ref = CVUSA_Dataset_Eval(data_folder=data_path, split='val', img_type='reference', transforms=transform)





def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    embed_dim = 512
    lr = 0.000001
    batch_size = 64
    epochs = 100
    expID = get_rand_id()
    loss_margin = 1

    hypm.expID = expID






    # print(f"Device: {device}")


    train_loader = DataLoader(train_ds, batch_size=hypm.batch_size, shuffle=False)
    val_loader = DataLoader(val_ds, batch_size=hypm.batch_size, shuffle=False)
    # val_loader_ref = DataLoader(val_ref, batch_size=hypm.batch_size, shuffle=False)

    if hypm.save_weights:
        os.mkdir(f'model_weights/{hypm.expID}')

    # model = ResNet(emb_dim=embed_dim).to(device)
    # model_r = ResNet(emb_dim=embed_dim).to(device)
    # model_q = ResNet(emb_dim=embed_dim).to(device)

    # model = ResNet().to(device)
    # model = VIT().to(device)
    model = CLIP_model(embed_dim=hypm.embed_dim)

    # model = torch.load(f'model_weights/{7355080}/model_tr.pth')

    # torch.save(model, f'model_weights/{expID}/model_st.pth')

    # criterion = TripletLoss(margin=loss_margin)
    # criterion = nn.TripletMarginLoss(margin=0.5)
  
    # criterion = SoftTripletBiLoss()

    loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=hypm.label_smoothing)
    criterion = InfoNCE(loss_function=loss_fn,
                            device=hypm.device,
                            )



    parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
    # for name, param in model.named_parameters():
    #     if param.requires_grad:
    #         print(name)
    optimizer = optim.Adam(parameters, lr=hypm.lr)
    # optimizer = optim.AdamW(parameters, lr=lr)
    # optimizer = optim.SGD(parameters, lr=lr)


    
    
    hyparam_info(emb_dim = hypm.embed_dim, 
                 loss_id = hypm.expID, 
                 ln_rate = hypm.lr, 
                 batch = hypm.batch_size, 
                 epc = hypm.epochs, 
                 ls_mrgn = hypm.loss_margin, 
                 trn_sz = train_data.shape[0],
                 val_sz= val_data.shape[0],
                 mdl_nm = model.modelName)
    
    save_exp(emb_dim=hypm.embed_dim, 
                loss_id=hypm.expID, 
                ln_rate=hypm.lr, 
                batch=hypm.batch_size, 
                epc=hypm.epochs, 
                ls_mrgn=hypm.loss_margin, 
                trn_sz=train_data.shape[0],
                val_sz= val_data.shape[0],
                mdl_nm=model.modelName,
                msg= hypm.msg)

    print("Training Start")
    all_loses = train(model, criterion, optimizer, train_loader, num_epochs=hypm.epochs, dev=hypm.device)
    df_loss = pd.DataFrame({'Loss': all_loses})
    df_loss.to_csv(f'losses/losses_{hypm.expID}.csv')

    write_to_file(expID=hypm.expID, msg=f'End of training: ', content=datetime.now())


    print("\nExtract Features:")
    query_features, reference_features, labels = predict(model=model, dataloader=val_loader, dev=hypm.device, isQuery=True)
    # reference_features, reference_labels = predict(model = model, dataloader=val_loader_ref, dev=hypm.device, isQuery=False) 
    


    print("Compute Scores:")
    # r1 =  calculate_scores(query_features, reference_features, query_labels, reference_labels, step_size=1000, ranks=[1, 5, 10])
    r1 =  accuracy(query_features=query_features, reference_features=reference_features, query_labels=labels, topk=[1, 5, 10])
    print(f'{r1}\n') 

    write_to_file(expID=hypm.expID, msg=f'Final eval: ', content=r1)
    write_to_rank_file(expID=hypm.expID, step=hypm.epochs, row=r1)



    if hypm.save_weights:
        torch.save(model, f'model_weights/{hypm.expID}/model_tr.pth')
    





    torch.cuda.empty_cache()
        






if __name__ == '__main__':
    main()


  from .autonotebook import tqdm as notebook_tqdm



Hyperparameter info:
Exp ID: 9264928
Embedded dimension: 768
Learning rate: 1e-05
Batch Size: 64
Loss Margin: 1
Epoch: 100
Training Size: 35532
Validation Size: 8884
Model Name: CLIP


Training Start

Date: 2024-06-24 17:35:30.573150

Epoch#1


100%|██████████| 556/556 [23:23<00:00,  2.52s/it]


Epoch: 1/100 Loss: 3.6563124656677246
Epoch#2


100%|██████████| 556/556 [23:27<00:00,  2.53s/it]


Epoch: 2/100 Loss: 3.5065524578094482

Train Step Eval: 2


Number of Validation data: 8884

Extract Features:


100%|██████████| 139/139 [05:37<00:00,  2.43s/it]


Compute Scores:
Percentage-top1:13.991445294912202, top5:34.5114813147231, top10:46.443043674020714, top1%:82.91310220621342, time:0.222245454788208
[13.99144529 34.51148131 46.44304367 82.91310221]
Epoch#3


100%|██████████| 556/556 [23:22<00:00,  2.52s/it]


Epoch: 3/100 Loss: 3.471588373184204
Epoch#4


100%|██████████| 556/556 [23:18<00:00,  2.52s/it]


Epoch: 4/100 Loss: 3.4520132541656494

Train Step Eval: 4


Number of Validation data: 8884

Extract Features:


100%|██████████| 139/139 [05:41<00:00,  2.45s/it]


Compute Scores:
Percentage-top1:17.59342638451148, top5:40.9725348941918, top10:53.039171544349394, top1%:86.94281855020262, time:0.21639180183410645
[17.59342638 40.97253489 53.03917154 86.94281855]
Epoch#5


100%|██████████| 556/556 [23:27<00:00,  2.53s/it]


Epoch: 5/100 Loss: 3.4391725063323975
Epoch#6


100%|██████████| 556/556 [23:27<00:00,  2.53s/it]


Epoch: 6/100 Loss: 3.429814338684082

Train Step Eval: 6


Number of Validation data: 8884

Extract Features:


100%|██████████| 139/139 [05:39<00:00,  2.44s/it]


Compute Scores:
Percentage-top1:19.630796938316074, top5:43.95542548401621, top10:56.584871679423685, top1%:88.45114813147231, time:0.21959853172302246
[19.63079694 43.95542548 56.58487168 88.45114813]
Epoch#7


100%|██████████| 556/556 [23:21<00:00,  2.52s/it]


Epoch: 7/100 Loss: 3.4224894046783447
Epoch#8


100%|██████████| 556/556 [23:21<00:00,  2.52s/it]


Epoch: 8/100 Loss: 3.41652512550354

Train Step Eval: 8


Number of Validation data: 8884

Extract Features:


100%|██████████| 139/139 [05:39<00:00,  2.44s/it]


Compute Scores:
Percentage-top1:21.060333183250787, top5:46.16163890139577, top10:58.45339936965331, top1%:89.72309770373705, time:0.2379438877105713
[21.06033318 46.1616389  58.45339937 89.7230977 ]
Epoch#9


100%|██████████| 556/556 [23:33<00:00,  2.54s/it]


Epoch: 9/100 Loss: 3.4115021228790283
Epoch#10


100%|██████████| 556/556 [23:32<00:00,  2.54s/it]


Epoch: 10/100 Loss: 3.407140016555786

Train Step Eval: 10


Number of Validation data: 8884

Extract Features:


100%|██████████| 139/139 [05:36<00:00,  2.42s/it]


Compute Scores:
Percentage-top1:22.276001800990546, top5:47.613687528140474, top10:60.06303466906798, top1%:90.63484916704188, time:0.21936655044555664
[22.2760018  47.61368753 60.06303467 90.63484917]
Epoch#11


100%|██████████| 556/556 [23:08<00:00,  2.50s/it]


Epoch: 11/100 Loss: 3.4032504558563232
Epoch#12


100%|██████████| 556/556 [23:07<00:00,  2.50s/it]


Epoch: 12/100 Loss: 3.399721384048462

Train Step Eval: 12


Number of Validation data: 8884

Extract Features:


100%|██████████| 139/139 [05:36<00:00,  2.42s/it]


Compute Scores:
Percentage-top1:23.06393516434039, top5:48.86312471859523, top10:61.31247185952273, top1%:91.17514633048177, time:0.21869850158691406
[23.06393516 48.86312472 61.31247186 91.17514633]
Epoch#13


100%|██████████| 556/556 [23:06<00:00,  2.49s/it]


Epoch: 13/100 Loss: 3.3965072631835938
Epoch#14


100%|██████████| 556/556 [23:06<00:00,  2.49s/it]


Epoch: 14/100 Loss: 3.39359188079834

Train Step Eval: 14


Number of Validation data: 8884

Extract Features:


100%|██████████| 139/139 [05:36<00:00,  2.42s/it]


Compute Scores:
Percentage-top1:23.671769473210265, top5:49.86492570914003, top10:62.4831157136425, top1%:91.58036920306168, time:0.22087621688842773
[23.67176947 49.86492571 62.48311571 91.5803692 ]
Epoch#15


100%|██████████| 556/556 [23:15<00:00,  2.51s/it]


Epoch: 15/100 Loss: 3.390953540802002
Epoch#16


100%|██████████| 556/556 [23:20<00:00,  2.52s/it]


Epoch: 16/100 Loss: 3.388556480407715

Train Step Eval: 16


Number of Validation data: 8884

Extract Features:


100%|██████████| 139/139 [05:38<00:00,  2.44s/it]


Compute Scores:
Percentage-top1:24.122017109410177, top5:50.73165240882486, top10:63.34984241332733, top1%:91.87303016659163, time:0.21906805038452148
[24.12201711 50.73165241 63.34984241 91.87303017]
Epoch#17


100%|██████████| 556/556 [23:07<00:00,  2.50s/it]


Epoch: 17/100 Loss: 3.386362314224243
Epoch#18


100%|██████████| 556/556 [23:12<00:00,  2.50s/it]


Epoch: 18/100 Loss: 3.3843398094177246

Train Step Eval: 18


Number of Validation data: 8884

Extract Features:


100%|██████████| 139/139 [05:36<00:00,  2.42s/it]


Compute Scores:
Percentage-top1:24.74110760918505, top5:51.55335434488969, top10:63.78883385862224, top1%:92.39081494822152, time:0.2200307846069336
[24.74110761 51.55335434 63.78883386 92.39081495]
Epoch#19


100%|██████████| 556/556 [23:22<00:00,  2.52s/it]


Epoch: 19/100 Loss: 3.382462978363037
Epoch#20


100%|██████████| 556/556 [23:23<00:00,  2.52s/it]


Epoch: 20/100 Loss: 3.3807127475738525

Train Step Eval: 20


Number of Validation data: 8884

Extract Features:


100%|██████████| 139/139 [05:34<00:00,  2.41s/it]


Compute Scores:
Percentage-top1:25.506528590724898, top5:52.51013057181449, top10:64.43043674020711, top1%:92.67221972084646, time:0.2184147834777832
[25.50652859 52.51013057 64.43043674 92.67221972]
Epoch#21


100%|██████████| 556/556 [23:21<00:00,  2.52s/it]


Epoch: 21/100 Loss: 3.3790740966796875
Epoch#22


100%|██████████| 556/556 [23:21<00:00,  2.52s/it]


Epoch: 22/100 Loss: 3.3775341510772705

Train Step Eval: 22


Number of Validation data: 8884

Extract Features:


100%|██████████| 139/139 [05:36<00:00,  2.42s/it]


Compute Scores:
Percentage-top1:25.934263845114813, top5:52.85907248986943, top10:65.16208914903197, top1%:92.85231877532642, time:0.21867775917053223
[25.93426385 52.85907249 65.16208915 92.85231878]
Epoch#23


100%|██████████| 556/556 [23:06<00:00,  2.49s/it]


Epoch: 23/100 Loss: 3.3760833740234375
Epoch#24


100%|██████████| 556/556 [23:07<00:00,  2.50s/it]


Epoch: 24/100 Loss: 3.3747146129608154

Train Step Eval: 24


Number of Validation data: 8884

Extract Features:


100%|██████████| 139/139 [05:43<00:00,  2.47s/it]


Compute Scores:
Percentage-top1:26.14813147230977, top5:53.16298964430437, top10:65.60108059432687, top1%:92.98739306618641, time:0.21822309494018555
[26.14813147 53.16298964 65.60108059 92.98739307]
Epoch#25


100%|██████████| 556/556 [23:14<00:00,  2.51s/it]


Epoch: 25/100 Loss: 3.373420000076294
Epoch#26


100%|██████████| 556/556 [23:13<00:00,  2.51s/it]


Epoch: 26/100 Loss: 3.3721940517425537

Train Step Eval: 26


Number of Validation data: 8884

Extract Features:


100%|██████████| 139/139 [05:38<00:00,  2.43s/it]


Compute Scores:
Percentage-top1:26.452048626744713, top5:53.770823953174244, top10:65.9500225123818, top1%:93.23502926609635, time:0.21695923805236816
[26.45204863 53.77082395 65.95002251 93.23502927]
Epoch#27


100%|██████████| 556/556 [23:09<00:00,  2.50s/it]


Epoch: 27/100 Loss: 3.3710317611694336
Epoch#28


100%|██████████| 556/556 [23:17<00:00,  2.51s/it]


Epoch: 28/100 Loss: 3.3699283599853516

Train Step Eval: 28


Number of Validation data: 8884

Extract Features:


100%|██████████| 139/139 [05:37<00:00,  2.42s/it]


Compute Scores:
Percentage-top1:26.643403872129674, top5:54.20981539846916, top10:66.09635299414678, top1%:93.40387212967131, time:0.21810030937194824
[26.64340387 54.2098154  66.09635299 93.40387213]
Epoch#29


100%|██████████| 556/556 [23:29<00:00,  2.53s/it]


Epoch: 29/100 Loss: 3.368879556655884
Epoch#30


100%|██████████| 556/556 [23:30<00:00,  2.54s/it]


Epoch: 30/100 Loss: 3.367882251739502

Train Step Eval: 30


Number of Validation data: 8884

Extract Features:


100%|██████████| 139/139 [05:37<00:00,  2.42s/it]


Compute Scores:
Percentage-top1:26.82350292660963, top5:54.46870778928411, top10:66.25393966681675, top1%:93.49392165691131, time:0.22869277000427246
[26.82350293 54.46870779 66.25393967 93.49392166]
Epoch#31


100%|██████████| 556/556 [23:26<00:00,  2.53s/it]


Epoch: 31/100 Loss: 3.3669333457946777
Epoch#32


100%|██████████| 556/556 [23:26<00:00,  2.53s/it]


Epoch: 32/100 Loss: 3.366029977798462

Train Step Eval: 32


Number of Validation data: 8884

Extract Features:


100%|██████████| 139/139 [05:38<00:00,  2.44s/it]


Compute Scores:
Percentage-top1:26.969833408374605, top5:54.55875731652409, top10:66.5015758667267, top1%:93.49392165691131, time:0.22478818893432617
[26.96983341 54.55875732 66.50157587 93.49392166]
Epoch#33


 96%|█████████▌| 535/556 [22:27<00:53,  2.53s/it]