In [1]:
# from fashion_clip.fashion_clip import FashionCLIP

from sklearn.metrics import *
import torch.nn as nn
import torch
from transformers import AutoTokenizer, AutoModel

In [2]:
# Pick the compute device once: the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
"""
テキスト処理のモデル
"""
class CaptionEncoder(nn.Module):
    """Text-processing model: maps tokenized captions to fixed-size embeddings.

    A pretrained BERT encoder produces per-token hidden states. These are
    max-pooled over the sequence dimension, ignoring padding positions, and
    then projected by a linear layer.

    Args:
        model_name: Hugging Face model id passed to ``AutoModel.from_pretrained``.
        embed_dim: Output size of the final projection layer.
    """

    def __init__(self, model_name="cl-tohoku/bert-base-japanese-v2", embed_dim=512):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        # Read the encoder width from the config instead of hard-coding 768. For
        # the default checkpoint this is still 768, so saved state_dicts stay
        # compatible.
        self.fc = nn.Linear(self.bert.config.hidden_size, embed_dim)

    def forward(self, x, attention_mask=None):
        """Encode a batch of token ids.

        Args:
            x: Token-id tensor, presumably of shape (batch, seq_len); dim 1 is
                pooled over.
            attention_mask: Optional tensor shaped like ``x``: 1 for real
                tokens, 0 for padding. When omitted, it is derived from
                ``self.bert.config.pad_token_id`` if that is set. Unpadded input
                then gets an all-ones mask, which matches BERT's default.

        Returns:
            Tensor of shape (batch, embed_dim).
        """
        if attention_mask is None:
            pad_id = getattr(self.bert.config, "pad_token_id", None)
            if pad_id is not None:
                attention_mask = (x != pad_id).long()

        hidden = self.bert(x, attention_mask=attention_mask).last_hidden_state

        if attention_mask is not None:
            # Padding positions must not win the max-pool: fill them with the
            # smallest finite value of the dtype. A finite value is used instead
            # of -inf so that no NaN/inf can reach the projection layer.
            keep = attention_mask.unsqueeze(-1).bool()
            hidden = hidden.masked_fill(~keep, torch.finfo(hidden.dtype).min)

        pooled = torch.max(hidden, dim=1)[0]  # max pooling over the sequence
        return self.fc(pooled)

In [4]:
from FashionClipDataset import FashionClipDataset

# Inputs: the caption annotation CSV and a file of precomputed image tensors.
ANNOTATION_CSV = './data/anotation_new.csv'
IMAGE_TENSOR_FILE = 'image_tensor/tensor_0-100000.pt'

dataset = FashionClipDataset(ANNOTATION_CSV, IMAGE_TENSOR_FILE)

In [5]:
from torch.utils.data import DataLoader

# Training hyperparameters
learning_rate = 1e-5
batch_size = 32
epochs = 10
# Seed for the train/validation split. With a fixed seed, re-running the
# notebook, or reloading a saved checkpoint, evaluates on the same held-out
# samples and never on samples the model was trained on.
SPLIT_SEED = 42

# 80/20 train/validation split
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Pinned host memory only speeds up host->GPU copies, so skip it on CPU-only machines.
use_pinned_memory = torch.cuda.is_available()

train_dataloader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, pin_memory=use_pinned_memory
)
# No shuffling for validation. Using the same batches every epoch keeps the
# validation loss comparable between epochs; a contrastive loss may depend on
# which samples share a batch.
test_dataloader = DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, pin_memory=use_pinned_memory
)

In [6]:

import torch.nn.functional as F
import sys
import os

# Make the parent directory importable so the project's `models` package resolves.
sys.path.append(os.path.abspath("../"))
from models.ContrastiveLoss import ContrastiveLoss

# The tokenizer must come from the same pretrained checkpoint as the encoder weights.
BERT_CHECKPOINT = "cl-tohoku/bert-base-japanese-v2"

caption_model = CaptionEncoder().to(device)
cpt_optimizer = torch.optim.SGD(caption_model.parameters(), lr=learning_rate)

loss_fn = ContrastiveLoss()
tokenizer = AutoTokenizer.from_pretrained(BERT_CHECKPOINT)

Some weights of the model checkpoint at cl-tohoku/bert-base-japanese-v2 were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [7]:
def train_loop(dataloader, cpt_model,  loss_fn, cpt_opt):
    size = len(dataloader.dataset)
    for batch, (img, cap, label) in enumerate(dataloader):        
        # 予測と損失の計算
        img = img.to(device)
        label = label.to(device)
        ids = tokenizer.encode(cap, return_tensors='pt')
        ids = ids.to(device)
        target = cpt_model(ids)
        # print(pred.shape, target.shape, len(X), len(y))
        # ここ不安
        loss = loss_fn(img, target, label)

        # バックプロパゲーション
        cpt_opt.zero_grad()
        loss.backward()
        cpt_opt.step()

        if batch % 100 == 0:
            loss, current = loss.item() / len(img), batch * len(img)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, cpt_model,  loss_fn):
    size = len(dataloader.dataset)
    test_loss = 0

    with torch.no_grad():
        for (img, cap, label) in dataloader:
            # 予測と損失の計算
            img = img.to(device)
            label = label.to(device)

            ids = tokenizer.encode(cap, return_tensors='pt')
            ids = ids.to(device)
            target = cpt_model(ids)
            # print(pred.shape, target.shape, len(X), len(y))
            # ここ不安
            loss = loss_fn(img, target, label).mean()
            test_loss += loss.item()
            
    test_loss /= size
    print(f"Avg loss: {test_loss:>8f} \n")
    return test_loss

In [8]:
# Alternate one training pass and one validation pass per epoch (epochs counted from 1).
for epoch_idx in range(1, epochs + 1):
    print(f"Epoch {epoch_idx}\n-------------------------------")
    train_loop(train_dataloader, caption_model, loss_fn, cpt_optimizer)
    val_loss = test_loop(test_dataloader, caption_model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 1.849588  [    0/160000]
loss: 0.999104  [ 3200/160000]
loss: 1.083622  [ 6400/160000]
loss: 0.659812  [ 9600/160000]
loss: 0.538239  [12800/160000]
loss: 0.503493  [16000/160000]
loss: 0.520776  [19200/160000]
loss: 0.547975  [22400/160000]
loss: 0.470417  [25600/160000]
loss: 0.367276  [28800/160000]
loss: 0.500437  [32000/160000]
loss: 0.433457  [35200/160000]
loss: 0.313173  [38400/160000]
loss: 0.454808  [41600/160000]
loss: 0.505045  [44800/160000]
loss: 0.392350  [48000/160000]
loss: 0.474050  [51200/160000]
loss: 0.541405  [54400/160000]
loss: 0.422453  [57600/160000]
loss: 0.457403  [60800/160000]
loss: 0.513172  [64000/160000]
loss: 0.368042  [67200/160000]
loss: 0.649226  [70400/160000]
loss: 0.533731  [73600/160000]
loss: 0.505528  [76800/160000]
loss: 0.510948  [80000/160000]
loss: 0.717671  [83200/160000]
loss: 0.443483  [86400/160000]
loss: 0.516442  [89600/160000]
loss: 0.667437  [92800/160000]
loss: 0.508607  [96000/160000]

KeyboardInterrupt: 

In [11]:
from datetime import datetime

# Stamp the checkpoint file name with today's date (YYYY-MM-DD). Runs on
# different days write different files; runs on the same day overwrite each other.
formatted_date = datetime.now().strftime("%Y-%m-%d")
checkpoint_path = f'clip_model_caption_{formatted_date}.pth'

torch.save(caption_model.state_dict(), checkpoint_path)