In [None]:
import torch.nn as nn
from torchvision import models
import torch
import torchvision.transforms as transforms
from PIL import Image
from transformers import AutoTokenizer, AutoModel
import pandas as pd
import sys
import os
import torch.nn.functional as F

from datetime import datetime
from models.ContrastiveLoss import ContrastiveLoss



In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Annotation CSV that backs the dataset.
# NOTE(review): EmbeddingDataset is neither imported nor defined in the cells
# visible here - confirm where it comes from so a fresh kernel can run this cell.
annotation_csv = './data/anotation_new.csv'
dataset = EmbeddingDataset(annotation_csv)

In [None]:
from torch.utils.data import DataLoader

# Training hyper-parameters.
learning_rate = 1e-5
batch_size = 16
epochs = 4

# 80 / 20 train / test split sizes.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

# Seed the split so the same samples land in the test set on every run.
# Without this, reloading a saved checkpoint after a kernel restart would
# evaluate on samples the model has already been trained on.
split_seed = 42
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Pinned host memory only matters for host-to-GPU copies; requesting it
# without CUDA has no benefit.
use_pinned_memory = torch.cuda.is_available()

train_dataloader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, pin_memory=use_pinned_memory
)
# Evaluation does not need shuffling; a fixed order keeps it repeatable.
test_dataloader = DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, pin_memory=use_pinned_memory
)

In [None]:

# Put the parent directory on the module search path.
sys.path.append(os.path.abspath("../"))

# Build both encoders and move them to the selected device.
# NOTE(review): ImageEncoder / CaptionEncoder are not imported or defined in
# the cells visible here; 768 is presumably the embedding width - confirm.
image_model = ImageEncoder(768)
image_model = image_model.to(device)
caption_model = CaptionEncoder()
caption_model = caption_model.to(device)

# To resume from earlier checkpoints, uncomment the two lines below.
# image_model.load_state_dict(torch.load('model/model_image_2023-07-07.pth'))
# caption_model.load_state_dict(torch.load('model/model_caption_2023-07-06.pth'))

# One Adam optimizer per encoder, both using the same learning rate.
img_optimizer, cpt_optimizer = (
    torch.optim.Adam(encoder.parameters(), lr=learning_rate)
    for encoder in (image_model, caption_model)
)

# Loss used to compare image and caption embeddings, and the tokenizer that
# turns caption strings into token ids for the caption encoder.
loss_fn = ContrastiveLoss()
tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese-v2")

In [None]:
def _save_checkpoint(img_model, cpt_model):
    """Write both encoders' state dicts to date-stamped files in the working directory."""
    # Date formatted as YYYY-MM-DD; saves made on the same day overwrite each other.
    formatted_date = datetime.now().strftime("%Y-%m-%d")
    torch.save(cpt_model.state_dict(), f'model_caption_{formatted_date}.pth')
    torch.save(img_model.state_dict(), f'model_image_{formatted_date}.pth')


def train_loop(dataloader, img_model, cpt_model,  loss_fn, img_opt, cpt_opt):
    """Train the image and caption encoders for one pass over ``dataloader``.

    Relies on the module-level ``device`` and ``tokenizer``.

    Args:
        dataloader: Yields 4-tuples ``(img, cap, label, _)``; the fourth item is
            ignored. ``dataloader.dataset`` must support ``len()``.
        img_model: Module applied to the image batch.
        cpt_model: Module applied to the tokenized caption ids.
        loss_fn: Called as ``loss_fn(pred, target, label)``; ``backward()`` is
            invoked on its result.
        img_opt: Optimizer stepping ``img_model``'s parameters.
        cpt_opt: Optimizer stepping ``cpt_model``'s parameters.

    Side effects:
        Prints the loss every 30 batches. Saves both state dicts every 1000
        batches (including batch 0) and once more after the last batch, so the
        weights on disk reflect the completed epoch.
    """
    size = len(dataloader.dataset)

    # Make sure dropout / batch-norm layers are in training mode, whatever
    # mode an earlier evaluation or inference step left the models in.
    img_model.train()
    cpt_model.train()

    for batch, (img, cap, label, _) in enumerate(dataloader):
        # Forward pass and loss computation.
        img = img.to(device)
        label = label.to(device)
        pred = img_model(img)
        # NOTE(review): captions are padded to 256 tokens but only input_ids are
        # passed on; confirm cpt_model builds its own attention mask.
        ids = tokenizer.batch_encode_plus(cap, return_tensors='pt', padding='max_length', truncation=True, max_length=256, add_special_tokens=True).input_ids
        ids = ids.to(device)
        target = cpt_model(ids)
        # NOTE (original author): not confident this loss call is correct.
        loss = loss_fn(pred, target, label)

        # Backpropagation through both encoders.
        img_opt.zero_grad()
        cpt_opt.zero_grad()
        loss.backward()
        img_opt.step()
        cpt_opt.step()

        if batch % 30 == 0:
            # Use separate names so the loss tensor is not rebound to a float.
            batch_loss, current = loss.item() / len(img), batch * len(img)
            print(f"loss: {batch_loss:>7f}  [{current:>5d}/{size:>5d}]")
        if batch % 1000 == 0:
            # Save the models that were passed in, not the notebook globals.
            _save_checkpoint(img_model, cpt_model)

    # Persist the end-of-epoch weights; the periodic save alone would leave
    # every update after the last multiple of 1000 unsaved.
    _save_checkpoint(img_model, cpt_model)


def test_loop(dataloader, img_model, cpt_model,  loss_fn):
    """Evaluate both encoders on ``dataloader`` and print the average loss.

    Relies on the module-level ``device`` and ``tokenizer``.

    Args:
        dataloader: Yields 4-tuples ``(img, cap, label, _)``; the fourth item is
            ignored. ``dataloader.dataset`` must support ``len()``.
        img_model: Module applied to the image batch.
        cpt_model: Module applied to the tokenized caption ids.
        loss_fn: Called as ``loss_fn(pred, target, label)``.

    The models are switched to evaluation mode for the duration of the call and
    every sub-module's previous train/eval flag is restored afterwards, so a
    following training step is unaffected.
    """
    size = len(dataloader.dataset)
    test_loss = 0

    # Remember each sub-module's mode so it can be restored exactly.
    modules = [*img_model.modules(), *cpt_model.modules()]
    previous_modes = [module.training for module in modules]

    # Disable dropout and freeze batch-norm statistics while evaluating.
    img_model.eval()
    cpt_model.eval()
    try:
        with torch.no_grad():
            for (img, cap, label, _) in dataloader:
                img = img.to(device)
                label = label.to(device)
                pred = img_model(img)
                # NOTE(review): captions are padded to 256 tokens but only
                # input_ids are passed on; confirm cpt_model builds its own
                # attention mask.
                ids = tokenizer.batch_encode_plus(cap, return_tensors='pt', padding='max_length', truncation=True, max_length=256, add_special_tokens=True).input_ids
                ids = ids.to(device)
                target = cpt_model(ids)
                # NOTE (original author): not confident this loss call is correct.
                loss = loss_fn(pred, target, label).mean()
                test_loss += loss.item()
    finally:
        for module, mode in zip(modules, previous_modes):
            module.training = mode

    # NOTE(review): one value per batch is summed but the total is divided by
    # the number of samples; this is a per-sample average only if loss_fn sums
    # over the batch - confirm against ContrastiveLoss.
    test_loss = test_loss / size if size else float("nan")
    print(f"Avg loss: {test_loss:>8f} \n")

In [None]:
print("start")
for t in range(epochs):
    # "\n" puts the separator rule on its own line; "\-" is not a valid
    # escape sequence and printed a stray backslash.
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, image_model, caption_model, loss_fn, img_optimizer, cpt_optimizer)
    test_loop(test_dataloader, image_model, caption_model, loss_fn)
print("Done!")