In [None]:
import torch

from data.data_load import get_neckline_df, get_device, download_images
from data_set import get_dataset, print_dataset_names
from torch.utils.data import DataLoader
from fashion_clip.pretrained_model import get_model, train_step
from tqdm.auto import tqdm

In [None]:
# Ask the project helper for the compute device; later cells pass it to
# get_model / train_step and use it when building the loss targets.
# NOTE(review): presumably picks CUDA/MPS when available and falls back to CPU — confirm in data.data_load.
device = get_device()
# Bare expression: echoes the selected device as the cell output.
device

In [None]:
# Forwarded to get_neckline_df as its `data_size` keyword.
# NOTE(review): presumably caps how many rows/items are loaded — confirm against the helper.
data_size = 20

# Build the two DataFrames used by the rest of the notebook from the label CSV and the
# text file of test item numbers (file roles inferred from their names — confirm).
train_df, test_df = get_neckline_df('neck_line_concated.csv', 'neck_line_test_item_no_list.txt', data_size=data_size)

In [None]:
# Render the first frame returned by get_neckline_df as the cell output (visual sanity check).
train_df

In [None]:
# Render the second frame returned by get_neckline_df as the cell output (visual sanity check).
test_df

In [None]:
def _columns_as_lists(frame):
    """Pull the four columns the dataset needs out of ``frame`` as plain lists.

    Returns, in this order: image URLs (``detail_image_url_1``), labels
    (``neck_line_label``), item numbers (``item_no``) and label sentences
    (``neck_line_label_desc``).
    """
    wanted = ('detail_image_url_1', 'neck_line_label', 'item_no', 'neck_line_label_desc')
    return [frame[name].tolist() for name in wanted]


# Same extraction for both splits; the names below are the ones later cells consume.
train_urls, train_labels, train_item_no_list, train_sentences = _columns_as_lists(train_df)

test_urls, test_labels, test_item_no_list, test_sentences = _columns_as_lists(test_df)

In [None]:
# Hand both URL lists to the project helper download_images, train first, then test.
# NOTE(review): presumably fetches the images to local storage for the dataset to read —
# confirm where they land and whether already-downloaded files are skipped (data.data_load).
download_images(train_urls)
download_images(test_urls)

In [None]:
batch_size = 128

# NOTE(review): presumably prints the dataset type names that get_dataset accepts — confirm in data_set.
print_dataset_names()
# Build the training dataset from the item numbers, image URLs, label sentences and labels
# extracted from train_df; the last argument selects the dataset variant by name.
dataset = get_dataset(train_item_no_list, train_urls, train_sentences, train_labels, 'CustomDatasetWithOutProcessor')
# shuffle=True: the training loss uses the other items of the same batch as negatives,
# so without shuffling every epoch would see exactly the same batch composition
# (DataLoader's default is shuffle=False).
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Instantiate the model through the project helper, passing the selected device.
# NOTE(review): presumably loads pretrained FashionCLIP weights and places them on `device` —
# confirm in fashion_clip.pretrained_model.
clip_model = get_model(device)

In [None]:
lr = 1e-6

# AdamW rather than Adam: with torch.optim.Adam, `weight_decay` is added to the gradient
# as an L2 term and then rescaled by Adam's adaptive step, whereas AdamW applies the
# decay directly to the weights (decoupled). betas/eps/weight_decay are unchanged.
# NOTE(review): these values look like the usual CLIP fine-tuning recipe, which assumes
# decoupled decay — confirm; also consider excluding biases / norm gains from decay.
optimizer = torch.optim.AdamW(clip_model.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2)

# Default cross-entropy criteria (mean reduction). The training loop in this notebook
# calls torch.nn.functional.cross_entropy directly; these objects are kept for any
# cell that refers to them by name.
img_criterion = torch.nn.CrossEntropyLoss()
txt_criterion = torch.nn.CrossEntropyLoss()

# Cosine schedule with warm restarts: first cycle is 100 scheduler steps, each following
# cycle is twice as long, and the LR anneals down to 0 within a cycle.
# The training loop steps this once per epoch, so cycle lengths are measured in epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=100, T_mult=2, eta_min=0)

In [None]:
num_epochs = 1000

# The per-epoch averages below divide by the number of batches. Compute it once and
# fail fast with a clear message instead of a ZeroDivisionError at the first log line.
num_batches = len(train_dataloader)
if num_batches == 0:
    raise ValueError("train_dataloader yields no batches; nothing to train on")

for epoch in range(num_epochs):
    clip_model.train()
    running_loss = 0.0
    running_img_loss = 0.0
    running_txt_loss = 0.0
    # NOTE(review): never appended to in this cell; kept only so the names stay defined
    # for any code outside this view.
    predictions = []
    targets = []

    for batch_data in train_dataloader:
        # NOTE(review): return_loss=True is passed, but the loss is recomputed from the
        # logits below and any loss carried on `outputs` is not read — confirm it is needed.
        outputs = train_step(clip_model, device, batch_data, return_loss=True)

        logits_per_image = outputs.logits_per_image
        logits_per_text = outputs.logits_per_text

        # Symmetric contrastive objective: row i's correct class is column i, i.e. the
        # i-th image is paired with the i-th text of the same batch.
        # NOTE(review): if several items in a batch share the same label sentence, those
        # identical texts are still scored as negatives here — confirm this is intended.
        target = torch.arange(len(logits_per_image), device=device)

        img_loss = torch.nn.functional.cross_entropy(logits_per_image, target)
        txt_loss = torch.nn.functional.cross_entropy(logits_per_text, target)

        # Average of the image->text and text->image directions.
        loss = (img_loss + txt_loss) / 2

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        running_img_loss += img_loss.item()
        running_txt_loss += txt_loss.item()

    # One scheduler step per epoch (not per batch).
    scheduler.step()

    # Log the epoch-mean losses and the current learning rate every 10th epoch.
    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / num_batches}, Img Loss: {running_img_loss / num_batches}, Txt Loss: {running_txt_loss / num_batches}, LR: {scheduler.get_last_lr()}")

Epoch [10/1000], Loss: 3.1842846870422363, Img Loss: 3.124016284942627, Txt Loss: 3.2445532083511353, LR: [9.179036806841351e-07]
Epoch [20/1000], Loss: 3.1503570079803467, Img Loss: 3.0817757844924927, Txt Loss: 3.2189382314682007, LR: [8.698155474893048e-07]
Epoch [30/1000], Loss: 3.1723897457122803, Img Loss: 3.102574586868286, Txt Loss: 3.2422049045562744, LR: [8.126213281678525e-07]
Epoch [40/1000], Loss: 3.123449683189392, Img Loss: 3.066511034965515, Txt Loss: 3.180388331413269, LR: [7.477293342162037e-07]
Epoch [50/1000], Loss: 3.1395249366760254, Img Loss: 3.082792282104492, Txt Loss: 3.1962573528289795, LR: [6.767374218896286e-07]
Epoch [60/1000], Loss: 3.147813081741333, Img Loss: 3.081658959388733, Txt Loss: 3.2139673233032227, LR: [6.013936476782562e-07]
Epoch [70/1000], Loss: 3.1357086896896362, Img Loss: 3.0837225914001465, Txt Loss: 3.1876946687698364, LR: [5.235532253548213e-07]
Epoch [80/1000], Loss: 3.072700023651123, Img Loss: 3.028701901435852, Txt Loss: 3.11669814