In [8]:

import torch.nn as nn
from torchvision import models
import torch
import torchvision.transforms as transforms
from PIL import Image
from transformers import AutoTokenizer, AutoModel
import pandas as pd

In [9]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [10]:
"""
画像処理のモデル
"""

class ImageEncoder(nn.Module):
    """Image branch: a pretrained ResNet-50 followed by a linear projection.

    The projection layer is sized from ``resnet50.fc.out_features``, i.e. it
    consumes the output of the backbone's own final ``fc`` layer and maps it
    to ``embedding_size`` features.

    Args:
        embedding_size: Number of output features of the projection layer.
    """

    def __init__(self, embedding_size):
        super().__init__()
        # Attribute names are part of the state-dict keys; keep them stable
        # so previously saved checkpoints still load.
        backbone = models.resnet50(pretrained=True)
        self.resnet50 = backbone
        self.fc = nn.Linear(backbone.fc.out_features, embedding_size)

    def forward(self, x):
        """Run ``x`` through the backbone and project its output with ``fc``."""
        return self.fc(self.resnet50(x))

In [11]:
"""
テキスト処理のモデル
"""
class CaptionEncoder(nn.Module):
    """Text branch: a pretrained Japanese BERT with max pooling over tokens.

    The encoder output's ``last_hidden_state`` is reduced with an element-wise
    maximum along dimension 1, giving one vector per input sequence.
    """

    def __init__(self):
        super().__init__()
        self.bert = AutoModel.from_pretrained("cl-tohoku/bert-base-japanese-v2")

    def forward(self, x, attention_mask=None):
        """Encode token ids into one pooled vector per sequence.

        Args:
            x: Token ids passed to BERT as its first positional argument.
            attention_mask: Optional mask with the same leading two dimensions
                as ``x`` (non-zero = real token, zero = padding). When given,
                it is forwarded to BERT and padded positions are excluded from
                the max pooling. When omitted, every position is attended to
                and pooled, exactly as before this parameter existed.

        Returns:
            Tensor holding the per-feature maximum over dimension 1 of BERT's
            ``last_hidden_state``.
        """
        if attention_mask is None:
            hidden = self.bert(x).last_hidden_state
            return torch.max(hidden, dim=1)[0]  # max pooling

        hidden = self.bert(x, attention_mask=attention_mask).last_hidden_state
        # Push padded positions to the lowest representable value so they can
        # never win the max.
        keep = attention_mask.to(hidden.device).unsqueeze(-1).bool()
        hidden = hidden.masked_fill(~keep, torch.finfo(hidden.dtype).min)
        return torch.max(hidden, dim=1)[0]  # max pooling over real tokens only

In [12]:
from util_class.CustomDataset import EmbeddingDataset


# Annotation CSV consumed by EmbeddingDataset (path is relative to the notebook's working directory).
ANNOTATION_CSV = './data/anotation_new.csv'
dataset = EmbeddingDataset(ANNOTATION_CSV)

(635192, 3)


In [19]:
from torch.utils.data import DataLoader
# Training hyper-parameters.
learning_rate = 1e-5
batch_size = 16
epochs = 4

# Seed for the train/test split. Without a fixed seed every kernel restart
# produces a different split, so samples a resumed checkpoint was trained on
# could end up in the held-out set.
split_seed = 42

# 80 % of the samples go to training, the remainder to evaluation.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
# Evaluation only accumulates a loss, so its sample order does not need shuffling.
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)

In [15]:

import torch.nn.functional as F
import sys
import os
sys.path.append(os.path.abspath("../"))
from models.ContrastiveLoss import ContrastiveLoss


# Build both encoders on the selected device.
image_model = ImageEncoder(768).to(device)
caption_model = CaptionEncoder().to(device)

# Resume from the saved checkpoints. map_location remaps the stored tensors
# onto the current device, so a checkpoint written on a GPU machine still
# loads when this notebook falls back to the CPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
image_model.load_state_dict(
    torch.load('model/model_image_2023-07-02.pth', map_location=device)
)
caption_model.load_state_dict(
    torch.load('model/model_caption_2023-07-02.pth', map_location=device)
)

# One Adam optimizer per encoder, both using the shared learning rate.
img_optimizer = torch.optim.Adam(image_model.parameters(), lr=learning_rate)
cpt_optimizer = torch.optim.Adam(caption_model.parameters(), lr=learning_rate)

loss_fn = ContrastiveLoss()
# Tokenizer matching the BERT checkpoint used inside CaptionEncoder.
tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese-v2")

Some weights of the model checkpoint at cl-tohoku/bert-base-japanese-v2 were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [16]:
def train_loop(dataloader, img_model, cpt_model,  loss_fn, img_opt, cpt_opt):
    """Run one pass over ``dataloader``, updating both encoders jointly.

    Each batch is unpacked as ``(img, cap, label, _)``. Images go through
    ``img_model``; captions are tokenized with the module-level ``tokenizer``
    (padded/truncated to 256 tokens) and go through ``cpt_model``. The loss
    ``loss_fn(image_output, caption_output, label)`` is back-propagated once
    and both optimizers are stepped. Progress is printed every 30 batches.

    Relies on the module-level ``device`` and ``tokenizer``.

    Args:
        dataloader: Iterable of batches exposing ``.dataset`` with a length.
        img_model: Image encoder.
        cpt_model: Caption encoder taking token ids.
        loss_fn: Callable ``(image_output, caption_output, label) -> loss``.
        img_opt: Optimizer for ``img_model``.
        cpt_opt: Optimizer for ``cpt_model``.
    """
    size = len(dataloader.dataset)
    optimizers = (img_opt, cpt_opt)

    for step, (images, captions, labels, _) in enumerate(dataloader):
        # Forward pass of the image branch.
        images = images.to(device)
        labels = labels.to(device)
        image_out = img_model(images)

        # Forward pass of the caption branch.
        encoded = tokenizer.batch_encode_plus(
            captions,
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=256,
        )
        token_ids = encoded.input_ids.to(device)
        caption_out = cpt_model(token_ids)

        # NOTE: the original author flagged this loss computation as uncertain.
        batch_loss = loss_fn(image_out, caption_out, labels)

        # Back-propagation: clear both optimizers, one backward pass, then step both.
        for opt in optimizers:
            opt.zero_grad()
        batch_loss.backward()
        for opt in optimizers:
            opt.step()

        if step % 30 != 0:
            continue

        # Report the batch loss divided by the batch length, plus samples seen so far.
        n_in_batch = len(images)
        loss = batch_loss.item() / n_in_batch
        current = step * n_in_batch
        print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, img_model, cpt_model,  loss_fn):
    size = len(dataloader.dataset)
    test_loss = 0

    with torch.no_grad():
        for (img, cap, label, _) in dataloader:
            img = img.to(device)
            label = label.to(device)
            pred = img_model(img)
            ids = tokenizer.encode(cap, return_tensors='pt')
            ids = ids.to(device)
            target = cpt_model(ids)
            # print(pred.shape, target.shape, len(X), len(y))
            # ここ不安
            loss = loss_fn(pred, target, label).mean()
            test_loss += loss.item()
            
    test_loss /= size
    print(f"Avg loss: {test_loss:>8f} \n")

In [17]:
# Run the configured number of epochs: one training pass followed by one
# evaluation pass per epoch.
print("start")
for t in range(epochs):
    # "\n" puts the separator on its own line; the previous "\-" was an
    # invalid escape sequence that printed a literal backslash.
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, image_model, caption_model, loss_fn, img_optimizer, cpt_optimizer)
    test_loop(test_dataloader, image_model, caption_model, loss_fn)
print("Done!")

start
Epoch 1\-------------------------------
loss: 0.040809  [    0/508153]
loss: 0.031671  [  480/508153]
loss: 0.030572  [  960/508153]
loss: 0.032185  [ 1440/508153]
loss: 0.031886  [ 1920/508153]
loss: 0.038456  [ 2400/508153]
loss: 0.046150  [ 2880/508153]
loss: 0.036169  [ 3360/508153]
loss: 0.047579  [ 3840/508153]
loss: 0.035831  [ 4320/508153]
loss: 0.039821  [ 4800/508153]
loss: 0.033201  [ 5280/508153]
loss: 0.033310  [ 5760/508153]
loss: 0.030289  [ 6240/508153]
loss: 0.030229  [ 6720/508153]
loss: 0.037267  [ 7200/508153]
loss: 0.037608  [ 7680/508153]
loss: 0.029941  [ 8160/508153]
loss: 0.025767  [ 8640/508153]
loss: 0.073675  [ 9120/508153]
loss: 0.032276  [ 9600/508153]
loss: 0.042545  [10080/508153]
loss: 0.035856  [10560/508153]
loss: 0.033655  [11040/508153]
loss: 0.028770  [11520/508153]
loss: 0.030893  [12000/508153]
loss: 0.034909  [12480/508153]
loss: 0.081878  [12960/508153]
loss: 0.059162  [13440/508153]
loss: 0.032088  [13920/508153]
loss: 0.034830  [14400/5

KeyboardInterrupt: 

In [None]:
import gc
# Force a full garbage-collection pass. As the last expression of the cell,
# the value returned by gc.collect() is what the notebook displays.
gc.collect()

203

In [18]:
from datetime import datetime

# Take the current date and time.
now = datetime.now()

# Render it as YYYY-MM-DD; the string becomes part of the checkpoint file names.
formatted_date = datetime.strftime(now, "%Y-%m-%d")

# Write the caption encoder's weights first, then the image encoder's.
caption_weights = caption_model.state_dict()
image_weights = image_model.state_dict()
torch.save(caption_weights, f'model_caption_{formatted_date}.pth')
torch.save(image_weights, f'model_image_{formatted_date}.pth')