In [None]:
import torch

from data.data_load import get_neckline_df, get_device, download_images
from multi_modal.data_set import get_dataset, print_dataset_names
from torch.utils.data import DataLoader
from multi_modal.open_clip.pretrained_model import get_model, train_step

In [None]:
# Ask the project helper for the compute device and echo it as the cell output.
# The same `device` object is later handed to get_model, train_step and torch.arange.
device = get_device()
device

In [None]:
# Forwarded to get_neckline_df below. Whether this is a fraction of the data or a
# row count is not visible from this notebook — TODO confirm against the helper.
data_size = 1

# Build the train and test frames from the CSV file and the test item-number list file.
train_df, test_df = get_neckline_df('neck_line_concated.csv', 'neck_line_test_item_no_list.txt', data_size=data_size)

In [None]:
# Display the training frame returned by get_neckline_df as the cell output.
train_df

In [None]:
# Display the test frame returned by get_neckline_df as the cell output.
test_df

In [None]:
def _columns_as_lists(df):
    """Pull the four columns used for training out of ``df`` as plain Python lists.

    Returns a tuple ``(urls, labels, item_no_list, sentences)`` taken from the
    ``detail_image_url_1``, ``neck_line_label``, ``item_no`` and
    ``neck_line_label_desc`` columns respectively, in row order.
    """
    return (
        df['detail_image_url_1'].tolist(),
        df['neck_line_label'].tolist(),
        df['item_no'].tolist(),
        df['neck_line_label_desc'].tolist(),
    )


# Same extraction for both splits; `.tolist()` already yields a fresh list, so no
# extra comprehension is needed around it.
train_urls, train_labels, train_item_no_list, train_sentences = _columns_as_lists(train_df)
test_urls, test_labels, test_item_no_list, test_sentences = _columns_as_lists(test_df)

In [None]:
# Run download_images over each split's URL list, training split first.
for _split_urls in (train_urls, test_urls):
    download_images(_split_urls)

In [None]:
# Print the dataset names exposed by multi_modal.data_set.
# Presumably these are the valid values for get_dataset's last argument — confirm.
print_dataset_names()

In [None]:
batch_size = 64

# Wrap the training columns in the project dataset selected by name.
dataset = get_dataset(train_item_no_list, train_urls, train_sentences, train_labels, 'CustomDatasetWithPreprocessor')

# Shuffle the training data so each epoch sees differently composed batches.
# The training loop scores every sample against the other samples in its batch,
# so a fixed order would replay the same comparisons every epoch.
# NOTE(review): shuffle=True requires a map-style dataset — confirm that
# get_dataset does not return an IterableDataset.
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Build the model through the project helper, passing the selected device.
open_clip_model = get_model(device)

In [None]:
# Optimisation hyper-parameters.
# `lr` is the single source of truth for the learning rate. It is set to 1e-5,
# the value the optimizer was actually being built with (the previous
# `lr=1e-6` assignment was never used).
lr = 1e-5
EPOCH = 5000

# Alternative configuration kept for reference (it was tried with lr=1e-6):
# optimizer = torch.optim.Adam(open_clip_model.parameters(), lr=1e-6, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2)
optimizer = torch.optim.Adam(open_clip_model.parameters(), lr=lr)

# One cross-entropy criterion per logits direction (image-side and text-side).
img_criterion = torch.nn.CrossEntropyLoss()
txt_criterion = torch.nn.CrossEntropyLoss()

# T_max counts scheduler.step() calls. It is batches-per-epoch * EPOCH, so the
# cosine reaches its minimum only if the scheduler is stepped once per batch for
# EPOCH epochs.
# Alternative kept for reference:
# scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=100, T_mult=2, eta_min=0)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(train_dataloader)*EPOCH)

In [None]:
def convert_models_to_fp32(model):
    """Cast every parameter of ``model``, and its gradient if present, to float32 in place.

    Parameters that have no gradient yet (``p.grad is None``, e.g. frozen
    parameters or before the first backward pass) only have their data converted.

    Args:
        model: any object exposing ``parameters()`` that yields torch parameters.
    """
    for p in model.parameters():
        p.data = p.data.float()
        # `grad` stays None until a backward pass has populated it.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()

In [None]:
# Symmetric image/text cross-entropy training loop.
# NOTE(review): the scheduler horizon is built from EPOCH in the configuration
# cell, not from num_epochs — confirm whether the two are meant to match.
num_epochs = 1000

for epoch in range(num_epochs):
    open_clip_model.train()
    running_loss = 0.0
    running_img_loss = 0.0
    running_txt_loss = 0.0
    # Never filled by this loop; still reset here in case later cells read them.
    predictions = []
    targets = []

    for batch_data in train_dataloader:
        optimizer.zero_grad()

        logits_per_image, logits_per_text, _ = train_step(open_clip_model, device, batch_data)

        # Row i is scored against column i: the target class for each row is its
        # own index within the batch.
        target = torch.arange(len(logits_per_image), device=device)

        # Each criterion gets its own logits direction. Feeding logits_per_text to
        # both would make the image-side loss a copy of the text-side loss.
        img_loss = img_criterion(logits_per_image, target)
        txt_loss = txt_criterion(logits_per_text, target)

        loss = (img_loss + txt_loss) / 2

        loss.backward()
        # (disabled) convert_models_to_fp32(open_clip_model)
        optimizer.step()
        # (disabled) clip.model.convert_weights(open_clip_model)

        # The scheduler's T_max is expressed in batches, so advance it once per
        # optimizer step rather than once per epoch.
        scheduler.step()

        running_loss += loss.item()
        running_img_loss += img_loss.item()
        running_txt_loss += txt_loss.item()

    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_dataloader)}, Img Loss: {running_img_loss / len(train_dataloader)}, Txt Loss: {running_txt_loss / len(train_dataloader)}, LR: {scheduler.get_last_lr()}")