In [1]:
import torch.nn as nn
from torchvision import models
import torch
import torchvision.transforms as transforms
from PIL import Image
from transformers import AutoTokenizer, AutoModel
import pandas as pd

In [2]:
if torch.cuda.is_available():
    device = torch.device("cuda")  # GPUデバイスを取得
else:
    device = torch.device("cpu")  # CPUデバイスを取得

In [3]:
import importlib
import Entity
from Entity import CustomDataset


In [4]:
from learning.Entity.CustomDataset import PokemonMoveDataset


dataset = PokemonMoveDataset('D:/tanaka/Documents/poke-move/data/my-dataset/dataset.txt', device)

In [5]:
from torch.utils.data import DataLoader
learning_rate = 1e-5
batch_size = 16
epochs = 5
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size]
)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

In [6]:
import torch.nn.functional as F
import sys
import os
from datetime import datetime
from learning.Encoder.MoveEncoder import MoveEncoder
from learning.Encoder.PokemonEncoder import PokemonEncoder

from learning.LossFunctions.ContractiveLoss import ContrastiveLoss


pokemon_model = PokemonEncoder(device).to(device)
move_model = MoveEncoder(device).to(device)
# image_model.load_state_dict(torch.load('model/model_image_2023-07-07.pth'))
# caption_model.load_state_dict(torch.load('model/model_caption_2023-07-06.pth'))
pokemon_optimizer = torch.optim.Adam(pokemon_model.parameters(), lr=learning_rate)
move_optimizer = torch.optim.Adam(move_model.parameters(), lr=learning_rate)

loss_fn = ContrastiveLoss()

In [12]:
def train_loop(dataloader, pokemon_model: PokemonEncoder, move_model: MoveEncoder, loss_fn, poke_opt, move_opt):
    size = len(dataloader.dataset)
    for batch, (pokemon, move, label) in enumerate(dataloader):        
        # 予測と損失の計算
        label = label.to(device)
        pred = pokemon_model(pokemon)
        target = move_model(move)
        # ここ不安
        # print(target.shape, pred.shape)
        loss = loss_fn(pred, target, label)

        # バックプロパゲーション
        poke_opt.zero_grad()
        move_opt.zero_grad()
        loss.backward()
        poke_opt.step()
        move_opt.step()

        if batch % 30 == 0:
            loss, current = loss.item() / len(pokemon), batch * len(pokemon)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
        if batch % 1000 == 0:
            # 現在の日付を取得します
            now = datetime.now()

            # YYYY-MM-DD形式で日付を出力します
            formatted_date = now.strftime("%Y-%m-%d")
            torch.save(pokemon_model.state_dict(), f'./pokemon_model_{formatted_date}.pth')
            torch.save(move_model.state_dict(), f'./move_model_{formatted_date}.pth')

def test_loop(dataloader, pokemon_model: PokemonEncoder, move_model: MoveEncoder, loss_fn):
    size = len(dataloader.dataset)
    test_loss = 0

    with torch.no_grad():
        for (pokemon, move, label) in dataloader:
            label = label.to(device)
            pred = pokemon_model(pokemon)
            target = move_model(move)
            loss = loss_fn(pred, target, label).mean()
            test_loss += loss.item()

    test_loss /= size
    print(f"Avg loss: {test_loss:>8f} \n")

In [13]:
print("start")
for t in range(epochs):
    print(f"Epoch {t+1}\-------------------------------")
    train_loop(train_dataloader, pokemon_model, move_model, loss_fn, pokemon_optimizer, move_optimizer)
    test_loop(test_dataloader, pokemon_model, move_model, loss_fn)
print("Done!")

start
Epoch 1\-------------------------------
loss: 0.000000  [    0/727199]
loss: 0.743896  [  180/727199]
loss: 0.000000  [  360/727199]
loss: 0.000000  [  540/727199]
loss: 0.352064  [  720/727199]
loss: 4.651002  [  900/727199]
loss: 0.257963  [ 1080/727199]
loss: 0.000000  [ 1260/727199]
loss: 0.000000  [ 1440/727199]
loss: 0.000000  [ 1620/727199]
loss: 0.401096  [ 1800/727199]
loss: 0.427737  [ 1980/727199]
loss: 0.794831  [ 2160/727199]
loss: 0.831454  [ 2340/727199]
loss: 2.319511  [ 2520/727199]
loss: 0.000000  [ 2700/727199]
loss: 0.469415  [ 2880/727199]
loss: 1.679926  [ 3060/727199]
loss: 0.000000  [ 3240/727199]
loss: 0.566458  [ 3420/727199]
loss: 0.214042  [ 3600/727199]
loss: 0.995253  [ 3780/727199]
loss: 1.070386  [ 3960/727199]
loss: 0.331354  [ 4140/727199]
loss: 0.635836  [ 4320/727199]
loss: 0.811649  [ 4500/727199]
loss: 2.158277  [ 4680/727199]
loss: 1.006296  [ 4860/727199]
loss: 0.286115  [ 5040/727199]
loss: 0.000000  [ 5220/727199]
loss: 0.205028  [ 5400/7

KeyboardInterrupt: 