In [1]:

import torch.nn as nn
from torchvision import models
import torch
import torchvision.transforms as transforms
from PIL import Image
from transformers import AutoTokenizer, AutoModel
import pandas as pd
import os 
import sys

In [2]:
# Add the project checkout to the import search path.
# NOTE(review): presumably this is what lets `learning.*` be imported in a
# later cell -- confirm. The path is machine-specific.
project_root = 'd:/M1/fashion'
sys.path.append(project_root)

In [3]:
# Use the GPU device when CUDA is available; otherwise fall back to the CPU device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
"""
画像処理のモデル
"""

class ImageEncoder(nn.Module):
    """Image-processing model.

    Runs the input through a pretrained ResNet-50 and projects the output of
    the ResNet's own final ``fc`` layer to ``embedding_size`` features with an
    extra linear layer.

    Args:
        embedding_size: Number of output features of the projection layer.
    """

    def __init__(self, embedding_size):
        super().__init__()
        self.resnet50 = models.resnet50(pretrained=True)
        # The projection consumes whatever the backbone's last layer emits.
        backbone_out_dim = self.resnet50.fc.out_features
        self.fc = nn.Linear(backbone_out_dim, embedding_size)

    def forward(self, x):
        """Return the projected backbone output for the image batch ``x``."""
        return self.fc(self.resnet50(x))

In [5]:
"""
テキスト処理のモデル
"""
class CaptionEncoder(nn.Module):
    """Text-processing model.

    Wraps the ``cl-tohoku/bert-base-japanese-v2`` BERT model and returns the
    hidden state at the first token position of every sequence (the ``[CLS]``
    position when the ids were produced with special tokens added).
    """

    def __init__(self):
        super().__init__()
        self.bert = AutoModel.from_pretrained("cl-tohoku/bert-base-japanese-v2")

    def forward(self, x, attention_mask=None):
        """Encode a batch of token ids.

        Args:
            x: Integer tensor of token ids, one row per sequence.
            attention_mask: Optional mask with 1 for real tokens and 0 for
                padding. When omitted it is derived from the model's
                ``pad_token_id`` so padded positions are not attended to.

        Returns:
            Tensor holding the last hidden state at position 0 for each row.
        """
        if attention_mask is None:
            # Without a mask BERT attends to every position, including padding,
            # which contaminates the first-token embedding of padded inputs.
            pad_id = getattr(self.bert.config, "pad_token_id", None)
            if pad_id is not None:
                attention_mask = (x != pad_id).long()

        outputs = self.bert(x, attention_mask=attention_mask)
        return outputs.last_hidden_state[:, 0, :]

In [6]:
class FashionItemEncoder(nn.Module):
    """Joint image + caption encoder for a fashion item.

    The image goes through ``ImageEncoder(768)`` and the caption through
    ``CaptionEncoder``. The two vectors are concatenated along the feature
    axis and passed through ReLU -> ``fc1`` -> ReLU -> ``fc2``, giving a
    768-feature embedding per item.
    """

    def __init__(self):
        super().__init__()
        self.image_model = ImageEncoder(768)
        self.caption_model = CaptionEncoder()
        self.fc1 = nn.Linear(768 * 2, 768)
        self.fc2 = nn.Linear(768, 768)
        self.relu = nn.ReLU()
        self.tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese-v2")

    def load_image_dict(self, image_model_path: str) -> None:
        """Load weights for the image sub-model from ``image_model_path``.

        The checkpoint is mapped to CPU first so a file saved on a GPU machine
        can also be loaded where CUDA is unavailable; ``load_state_dict`` then
        copies the values into the existing parameters on their own device.
        NOTE: ``torch.load`` unpickles data -- only load trusted checkpoints.
        """
        state = torch.load(image_model_path, map_location="cpu")
        self.image_model.load_state_dict(state)

    def load_caption_dict(self, caption_model_path: str) -> None:
        """Load weights for the caption sub-model from ``caption_model_path``.

        Uses the same CPU-first mapping as ``load_image_dict``.
        NOTE: ``torch.load`` unpickles data -- only load trusted checkpoints.
        """
        state = torch.load(caption_model_path, map_location="cpu")
        self.caption_model.load_state_dict(state)

    def forward(self, image, caption):
        """Embed a batch of (image, caption) pairs.

        Args:
            image: Image batch tensor accepted by the image sub-model.
            caption: Batch of caption strings, one per image.

        Returns:
            Tensor with one 768-feature embedding per item.
        """
        image_vector = self.image_model(image)

        # Captions are padded/truncated to exactly 256 tokens.
        ids = self.tokenizer.batch_encode_plus(caption, return_tensors='pt',
                                                padding='max_length',
                                                truncation=True,
                                                max_length=256,
                                                add_special_tokens=True).input_ids
        # Follow the device of the image embedding instead of a notebook-level
        # global, so the module works wherever it has been moved.
        ids = ids.to(image_vector.device)
        caption_vector = self.caption_model(ids)

        concat_vector = torch.cat((image_vector, caption_vector), dim=1)
        concat_vector = self.relu(concat_vector)
        y = self.fc1(concat_vector)
        y = self.relu(y)
        y = self.fc2(y)
        return y

In [7]:
from learning.util_class.CompatibilityDataset import CompatibilityDataset

# Annotation CSVs passed to CompatibilityDataset: positive pairs first,
# negative pairs second (going by the file names).
positive_annotation_csv = '../learning/data/anotation_positive_compatibility.csv'
negative_annotation_csv = '../learning/data/anotation_negative_compatibility.csv'

dataset = CompatibilityDataset(positive_annotation_csv, negative_annotation_csv)

In [8]:
# import pandas as pd
# postive_annotations = pd.read_csv('../learning/data/anotation_positive_compatibility.csv', header=None)[:1000]

In [9]:
from torch.utils.data import DataLoader

# Training hyperparameters.
learning_rate = 1e-5
batch_size = 16
epochs = 4

# 80% of the samples are used for training, the remainder for evaluation.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

# Seed the split so the same samples land in the test set on every run;
# otherwise resuming from a saved checkpoint after a kernel restart can
# evaluate on samples the model was already trained on.
split_seed = 42
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
# The evaluation loader is not shuffled, so its batches come in a fixed order.
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)

In [10]:
import os
import sys
from datetime import datetime

import torch.nn.functional as F

# Put the parent directory on the import path before importing from `models`.
parent_dir = os.path.abspath("../")
sys.path.append(parent_dir)
from models.ContrastiveLoss import ContrastiveLoss

# Build the encoder on the selected device, then the optimizer over all of
# its parameters, then the loss.
model = FashionItemEncoder()
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = ContrastiveLoss()

Some weights of the model checkpoint at cl-tohoku/bert-base-japanese-v2 were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [11]:
def train_loop(dataloader, model, loss_fn, opt):
    """Train ``model`` for one pass over ``dataloader``.

    Each batch yields ``(img_1, cap_1, img_2, cap_2, label)``. Both items of
    a pair are embedded with the same model and the loss is computed on the
    two embeddings and the label.

    Args:
        dataloader: Iterable of batches; must expose ``.dataset`` for sizing.
        model: Callable as ``model(image, caption)``.
        loss_fn: Callable as ``loss_fn(pred1, pred2, label)``; must return a
            value that supports ``.backward()`` and ``.item()``.
        opt: Optimizer over the model's parameters.

    Side effects:
        Prints the loss every 30 batches and writes a checkpoint named
        ``model_compatibility_<YYYY-MM-DD>.pth`` to the working directory
        every 1000 batches (including batch 0). Checkpoints written on the
        same date overwrite each other.
    """
    size = len(dataloader.dataset)

    # Enable training mode explicitly: if the model was left in eval mode
    # (e.g. by an evaluation pass), dropout and batch-norm would otherwise
    # stay frozen while the weights are being updated.
    model.train()

    for batch, (img_1, cap_1, img_2, cap_2, label) in enumerate(dataloader):
        # Images and labels go to the compute device; captions stay as given.
        img_1 = img_1.to(device)
        img_2 = img_2.to(device)
        label = label.to(device)

        # Predictions for both items of the pair.
        pred1 = model(img_1, cap_1)
        pred2 = model(img_2, cap_2)

        loss = loss_fn(pred1, pred2, label)

        # Backpropagation.
        opt.zero_grad()
        loss.backward()
        opt.step()

        if batch % 30 == 0:
            # Logged value is the batch loss divided by the batch size.
            # NOTE(review): this is a per-sample loss only if loss_fn sums
            # over the batch -- confirm against ContrastiveLoss.
            loss_per_sample = loss.item() / len(label)
            current = batch * len(label)
            print(f"loss: {loss_per_sample:>7f}  [{current:>5d}/{size:>5d}]")
        if batch % 1000 == 0:
            # Name the checkpoint after the current date (YYYY-MM-DD).
            formatted_date = datetime.now().strftime("%Y-%m-%d")
            torch.save(model.state_dict(), f'model_compatibility_{formatted_date}.pth')


# 1. lossを確認
# 2. compatibilityの推定を確認
def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    test_loss = 0
    posi_sum = 0
    nega_sum = 0
    with torch.no_grad():
        for (img_1, cap_1, img_2, cap_2, label) in dataloader:
            # 画像
            img_1 = img_1.to(device)
            img_2 = img_2.to(device)

            # 予測
            pred1 = model(img_1, cap_1)
            pred2 = model(img_2, cap_2)

            label = label.to(device)

            loss = loss_fn(pred1, pred2, label).mean()
            test_loss += loss.item()

            distance = torch.norm(pred1 - pred2, dim=1, keepdim=True)
            posi_sum += torch.sum(distance[label == 1]) / len(label)
            nega_sum += torch.sum(distance[label == 0]) / len(label)
            
    test_loss /= size
    print(f"Avg loss: {test_loss:>8f} \n positive_distance: {posi_sum} negative_distance: {nega_sum}")

In [12]:
# Run the configured number of epochs: one training pass followed by one
# evaluation pass per epoch.
print("start")
for t in range(epochs):
    # "\n" puts the separator on its own line; the previous "\-" was an
    # invalid escape sequence that printed a literal backslash.
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
print("Done!")

start
Epoch 1\-------------------------------


loss: 0.066608  [    0/136000]
loss: 0.021142  [  480/136000]
loss: 0.000516  [  960/136000]
loss: 0.007632  [ 1440/136000]
loss: 0.016191  [ 1920/136000]
loss: 0.014739  [ 2400/136000]
loss: 0.000767  [ 2880/136000]
loss: 0.001167  [ 3360/136000]
loss: 0.008446  [ 3840/136000]
loss: 0.007812  [ 4320/136000]
loss: 0.000541  [ 4800/136000]
loss: 0.018323  [ 5280/136000]
loss: 0.008274  [ 5760/136000]
loss: 0.006827  [ 6240/136000]
loss: 0.000504  [ 6720/136000]
loss: 0.000376  [ 7200/136000]
loss: 0.006516  [ 7680/136000]


UnidentifiedImageError: cannot identify image file 'D:/M1/fashion/IQON/IQON3000/996095/3967667/37764010_m.jpg'