In [1]:

import torch.nn as nn
from torchvision import models
import torch
import torchvision.transforms as transforms
from PIL import Image
from transformers import AutoTokenizer, AutoModel
import pandas as pd

In [2]:
if torch.cuda.is_available():
    device = torch.device("cuda")  # GPUデバイスを取得
else:
    device = torch.device("cpu")  # CPUデバイスを取得

In [3]:
"""
画像処理のモデル
"""

class ImageEncoder(nn.Module):
    def __init__(self, embedding_size):
        super(ImageEncoder, self).__init__()
        self.resnet50 = models.resnet50(pretrained=True)
        self.fc = nn.Linear(self.resnet50.fc.out_features, embedding_size)
    
    def forward(self, x):
        x = self.resnet50(x)
        x = self.fc(x)
        return x

In [4]:
"""
テキスト処理のモデル
"""
class CaptionEncoder(nn.Module):
  def __init__(self):
    super().__init__()
    self.bert = AutoModel.from_pretrained("cl-tohoku/bert-base-japanese-v2")
  def forward(self, x):
    x = self.bert(x)
    x = x.last_hidden_state
    # print(x.shape)
    x = x[:,0,:] 
    # print(x.shape)
    return x

In [5]:
from util_class.CustomDataset import EmbeddingDataset


dataset = EmbeddingDataset('./data/anotation_new.csv')

In [6]:
from torch.utils.data import DataLoader
learning_rate = 1e-5
batch_size = 16
epochs = 4
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size]
)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)

In [7]:

import torch.nn.functional as F
import sys
import os
from datetime import datetime
sys.path.append(os.path.abspath("../"))
from models.ContrastiveLoss import ContrastiveLoss


image_model = ImageEncoder(768).to(device)
caption_model = CaptionEncoder().to(device)
image_model.load_state_dict(torch.load('model/model_image_2023-07-07.pth'))
# caption_model.load_state_dict(torch.load('model/model_caption_2023-07-06.pth'))
img_optimizer = torch.optim.Adam(image_model.parameters(), lr=learning_rate)
cpt_optimizer = torch.optim.Adam(caption_model.parameters(), lr=learning_rate)

loss_fn = ContrastiveLoss()
tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese-v2")

Some weights of the model checkpoint at cl-tohoku/bert-base-japanese-v2 were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [8]:
def train_loop(dataloader, img_model, cpt_model,  loss_fn, img_opt, cpt_opt):
    size = len(dataloader.dataset)
    for batch, (img, cap, label, _) in enumerate(dataloader):        
        # 予測と損失の計算
        img = img.to(device)
        label = label.to(device)
        pred = img_model(img)
        ids = tokenizer.batch_encode_plus(cap, return_tensors='pt', padding='max_length', truncation=True, max_length=256, add_special_tokens=True).input_ids
        ids = ids.to(device)
        target = cpt_model(ids)
        # ここ不安
        loss = loss_fn(pred, target, label)

        # バックプロパゲーション
        img_opt.zero_grad()
        cpt_opt.zero_grad()
        loss.backward()
        img_opt.step()
        cpt_opt.step()

        if batch % 30 == 0:
            loss, current = loss.item() / len(img), batch * len(img)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
        if batch % 1000 == 0:
            # 現在の日付を取得します
            now = datetime.now()

            # YYYY-MM-DD形式で日付を出力します
            formatted_date = now.strftime("%Y-%m-%d")
            torch.save(caption_model.state_dict(), f'model_caption_{formatted_date}.pth')
            torch.save(image_model.state_dict(), f'model_image_{formatted_date}.pth')


def test_loop(dataloader, img_model, cpt_model,  loss_fn):
    size = len(dataloader.dataset)
    test_loss = 0

    with torch.no_grad():
        for (img, cap, label, _) in dataloader:
            img = img.to(device)
            label = label.to(device)
            pred = img_model(img)
            ids = tokenizer.batch_encode_plus(cap, return_tensors='pt', padding='max_length', truncation=True, max_length=256, add_special_tokens=True).input_ids
            ids = ids.to(device)
            target = cpt_model(ids)
            # print(pred.shape, target.shape, len(X), len(y))
            # ここ不安
            loss = loss_fn(pred, target, label).mean()
            test_loss += loss.item()
            
    test_loss /= size
    print(f"Avg loss: {test_loss:>8f} \n")

In [9]:
print("start")
for t in range(epochs):
    print(f"Epoch {t+1}\-------------------------------")
    train_loop(train_dataloader, image_model, caption_model, loss_fn, img_optimizer, cpt_optimizer)
    test_loop(test_dataloader, image_model, caption_model, loss_fn)
print("Done!")

start
Epoch 1\-------------------------------
loss: 0.614378  [    0/1016307]
loss: 0.103117  [  480/1016307]
loss: 0.000000  [  960/1016307]
loss: 0.012993  [ 1440/1016307]
loss: 0.002273  [ 1920/1016307]
loss: 0.008556  [ 2400/1016307]
loss: 0.009081  [ 2880/1016307]
loss: 0.001411  [ 3360/1016307]
loss: 0.007506  [ 3840/1016307]
loss: 0.007643  [ 4320/1016307]
loss: 0.009015  [ 4800/1016307]
loss: 0.001251  [ 5280/1016307]
loss: 0.000113  [ 5760/1016307]
loss: 0.013051  [ 6240/1016307]
loss: 0.000010  [ 6720/1016307]
loss: 0.000135  [ 7200/1016307]
loss: 0.000216  [ 7680/1016307]
loss: 0.009200  [ 8160/1016307]
loss: 0.014269  [ 8640/1016307]
loss: 0.008257  [ 9120/1016307]
loss: 0.008445  [ 9600/1016307]
loss: 0.001368  [10080/1016307]
loss: 0.014872  [10560/1016307]
loss: 0.018161  [11040/1016307]
loss: 0.015502  [11520/1016307]
loss: 0.007564  [12000/1016307]
loss: 0.001985  [12480/1016307]
loss: 0.019885  [12960/1016307]
loss: 0.000702  [13440/1016307]
loss: 0.010535  [13920/101

KeyboardInterrupt: 

In [None]:
import gc
gc.collect()

: 

In [10]:
from datetime import datetime

# 現在の日付を取得します
now = datetime.now()

# YYYY-MM-DD形式で日付を出力します
formatted_date = now.strftime("%Y-%m-%d")

torch.save(caption_model.state_dict(), f'model_caption_{formatted_date}.pth')
torch.save(image_model.state_dict(), f'model_image_{formatted_date}.pth')