In [9]:
import torch.nn as nn
from torchvision import models
import torch
import torchvision.transforms as transforms
from PIL import Image
from transformers import AutoTokenizer, AutoModel
import pandas as pd

In [10]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [11]:
import importlib
import Entity
from Entity import CustomDataset


In [12]:
from learning.Entity.CustomDataset import PokemonMoveDataset


# Machine-specific absolute path to the dataset file.
# A raw string is used because the Windows backslashes ("\D", "\d", "\m", "\p") are not
# valid escape sequences and raise SyntaxWarning on Python 3.12+; the path value is unchanged.
# TODO: make this configurable instead of hard-coding a local absolute path.
DATASET_PATH = r'D:/tanaka\Documents\poke-move\data\my-dataset\dataset_without_low_count_move.txt'
dataset = PokemonMoveDataset(DATASET_PATH, device)

In [13]:
from torch.utils.data import DataLoader
# Training hyperparameters.
learning_rate = 1e-5
batch_size = 16
epochs = 10

# 80/20 split of the full dataset; the 20% part is used as the evaluation ("test") set.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

# Seed the global torch RNG so the split below (and anything later that draws from the
# global RNG) is reproducible.
torch.manual_seed(42)
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size]
)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Evaluation gains nothing from shuffling; a fixed order gives identical evaluation batches
# every epoch and avoids consuming the global RNG during evaluation.
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [14]:
import torch.nn.functional as F
import sys
import os
from datetime import datetime
from learning.Encoder.MoveEncoder import MoveEncoder
from learning.Encoder.PokemonEncoder import PokemonEncoder

from learning.LossFunctions.ContractiveLoss import ContrastiveLoss


# The encoders receive the device in their constructor and are then moved onto it.
# Construction order (pokemon first, then move) is kept so RNG-dependent init is unchanged.
pokemon_model = PokemonEncoder(device).to(device)
move_model = MoveEncoder(device).to(device)

# One Adam optimizer per encoder, both with the same learning rate.
pokemon_optimizer, move_optimizer = (
    torch.optim.Adam(encoder.parameters(), lr=learning_rate)
    for encoder in (pokemon_model, move_model)
)

# Margin changed (per the author's note) to 10.
loss_fn = ContrastiveLoss(margin=10)

In [15]:
def train_loop(dataloader, pokemon_model: "PokemonEncoder", move_model: "MoveEncoder", loss_fn, poke_opt, move_opt,
               log_every: int = 30, checkpoint_every: "int | None" = 1000):
    """Run one training epoch over ``dataloader``, updating both encoders on every batch.

    Args:
        dataloader: yields ``(pokemon, move, label)`` batches. ``label`` must be a tensor
            whose first dimension is the batch size.
        pokemon_model: encoder applied to ``pokemon``; its output is passed as ``pred``.
        move_model: encoder applied to ``move``; its output is passed as ``target``.
        loss_fn: called as ``loss_fn(pred, target, label)``; must return a scalar tensor
            (``backward()`` is called on it directly).
        poke_opt: optimizer for ``pokemon_model``; stepped after every batch.
        move_opt: optimizer for ``move_model``; stepped after every batch.
        log_every: print the loss every this many batches (0 or None disables logging).
        checkpoint_every: every this many batches (batch 0 included), save both
            ``state_dict``s to date-stamped ``.pth`` files in the working directory,
            overwriting same-day files. ``None`` or 0 disables saving.
    """
    size = len(dataloader.dataset)
    # Move labels to the device the model lives on instead of relying on a notebook global.
    device = next(pokemon_model.parameters()).device
    # Make sure training-time behaviour (e.g. dropout) is active even if an evaluation
    # pass switched the models to eval mode.
    pokemon_model.train()
    move_model.train()

    for batch, (pokemon, move, label) in enumerate(dataloader):
        # Forward pass and loss.
        label = label.to(device)
        pred = pokemon_model(pokemon)
        target = move_model(move)
        # NOTE(review): the author flagged this call as uncertain - confirm that pred and
        # target have the shapes loss_fn expects.
        loss = loss_fn(pred, target, label)

        # Backpropagation: both encoders are updated from the same loss.
        poke_opt.zero_grad()
        move_opt.zero_grad()
        loss.backward()
        poke_opt.step()
        move_opt.step()

        batch_len = len(label)
        if log_every and batch % log_every == 0:
            # Dividing by the batch size assumes loss_fn sums over the batch - TODO confirm.
            loss_per_sample = loss.item() / batch_len
            current = batch * batch_len
            print(f"loss: {loss_per_sample:>7f}  [{current:>5d}/{size:>5d}]")
        if checkpoint_every and batch % checkpoint_every == 0:
            # Date formatted as YYYY-MM-DD.
            formatted_date = datetime.now().strftime("%Y-%m-%d")
            torch.save(pokemon_model.state_dict(), f'./pokemon_model_no-bert_{formatted_date}.pth')
            torch.save(move_model.state_dict(), f'./move_model_no-bert_{formatted_date}.pth')
def test_loop(dataloader, pokemon_model: PokemonEncoder, move_model: MoveEncoder, loss_fn):
    size = len(dataloader.dataset)
    test_loss = 0

    with torch.no_grad():
        for (pokemon, move, label) in dataloader:
            label = label.to(device)
            pred = pokemon_model(pokemon)
            target = move_model(move)
            loss = loss_fn(pred, target, label).mean()
            test_loss += loss.item()

    test_loss /= size
    print(f"Avg loss: {test_loss:>8f} \n")

In [16]:
print("start")
for t in range(epochs):
    # "\n" puts the separator on its own line; "\-" is an invalid escape sequence that
    # printed a literal backslash.
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, pokemon_model, move_model, loss_fn, pokemon_optimizer, move_optimizer)
    test_loop(test_dataloader, pokemon_model, move_model, loss_fn)
    # Checkpoints are written from inside train_loop to date-stamped .pth files.
print("Done!")

start
Epoch 1\-------------------------------
loss: 18105.593750  [    0/260983]
loss: 22408.174479  [  480/260983]
loss: 5945.565104  [  960/260983]
loss: 9233.703776  [ 1440/260983]
loss: 1363.132731  [ 1920/260983]
loss: 2708.937174  [ 2400/260983]
loss: 1483.838379  [ 2880/260983]
loss: 10762.949870  [ 3360/260983]
loss: 3035.807292  [ 3840/260983]
loss: 6725.849609  [ 4320/260983]
loss: 1871.967448  [ 4800/260983]
loss: 3250.437174  [ 5280/260983]
loss: 6850.070312  [ 5760/260983]
loss: 8475.554036  [ 6240/260983]
loss: 12197.738281  [ 6720/260983]
loss: 31180.812500  [ 7200/260983]
loss: 20614.860677  [ 7680/260983]
loss: 1262.207357  [ 8160/260983]
loss: 18676.968750  [ 8640/260983]
loss: 7176.917969  [ 9120/260983]
loss: 764.520752  [ 9600/260983]
loss: 23637.854167  [10080/260983]
loss: 24073.890625  [10560/260983]
loss: 2102.763346  [11040/260983]
loss: 1173.437663  [11520/260983]
loss: 1465.898438  [12000/260983]
loss: 3341.667643  [12480/260983]
loss: 14039.252604  [12960/2