In [23]:
import torch.nn as nn
from torchvision import models
import torch
import torchvision.transforms as transforms
from PIL import Image
from transformers import AutoTokenizer, AutoModel
import pandas as pd
import gc

In [24]:
# Pick the compute device once for the whole notebook:
# the first CUDA GPU when one is visible, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [25]:
import sys
import os

# Make the parent directory importable so the project-local `learning` package
# imported below resolves. os.path.abspath resolves "../" against the kernel's
# current working directory, so this depends on where the kernel was started.
sys.path.append(os.path.abspath("../"))
from learning.Entity.CustomDataset import PokemonMoveDataset
import json



In [26]:
# Source file for the Pokémon/move dataset (absolute path on the author's machine).
DATASET_PATH = 'D:/tanaka/Documents/poke-move/data/my-dataset/dataset_without_status_z_no_leadned.txt'
dataset = PokemonMoveDataset(DATASET_PATH, device)

In [27]:
from learning.Encoder.MoveEncoder import MoveEncoder
from learning.Encoder.PokemonEncoder import PokemonEncoder


pokemon_model = PokemonEncoder(device).to(device)
move_model = MoveEncoder(device).to(device)

# Trained encoder checkpoints.
# Alternative checkpoint kept for reference:
#   D:/tanaka/Documents/poke-move/learning/pokemon_model_no-bert_2023-08-10-epoch-9.pth
poke_model_name = 'D:/tanaka/Documents/poke-move/learning/pokemon_model_no-bert_2023-08-10.pth'
move_model_name = 'D:/tanaka/Documents/poke-move/learning/move_model_no-bert_2023-08-10.pth'

# map_location remaps tensors that were saved on another device (e.g. a CUDA
# checkpoint) onto `device`, so this cell also works on a CPU-only kernel.
# NOTE: torch.load unpickles the file; only load checkpoints you produced yourself.
pokemon_model.load_state_dict(torch.load(poke_model_name, map_location=device))
move_model.load_state_dict(torch.load(move_model_name, map_location=device))

# Switch both encoders to evaluation mode (matters for layers such as dropout /
# batch-norm, if the encoders contain any). The last expression displays the
# move encoder's module repr as the cell output.
pokemon_model.eval()
move_model.eval()

MoveEncoder(
  (type_nn): Linear(in_features=18, out_features=4, bias=True)
  (damage_class_nn): Linear(in_features=3, out_features=1, bias=True)
  (features_nn1): Linear(in_features=10, out_features=32, bias=True)
  (features_nn2): Linear(in_features=32, out_features=64, bias=True)
  (features_nn3): Linear(in_features=64, out_features=128, bias=True)
  (relu): ReLU()
)

In [28]:
# Embed every Pokémon once. Rows are appended in dataset.pokemons order, so row i of
# pokemon_vectors belongs to dataset.pokemons[i] (assuming the encoder returns one
# row per call — TODO confirm output shape is (1, D)).
pokemon_vectors = []
# Inference only: without no_grad every call would build an autograd graph and the
# stored embeddings would keep a grad_fn (the move loop below already does this).
with torch.no_grad():
    for pokemon in dataset.pokemons:
        poke_vector = dataset.pokemon_to_vector(pokemon)
        # Add a leading batch dimension of size 1 to every feature, building a new
        # dict rather than mutating the one returned by pokemon_to_vector.
        batched = {key: value.unsqueeze(0) for key, value in poke_vector.items()}
        pokemon_vectors.append(pokemon_model(batched).to('cpu'))
pokemon_vectors = torch.cat(pokemon_vectors, dim=0)

In [29]:
from memory_profiler import profile

# Embed every move. Rows are appended in dataset.moves order, so row k of
# move_vectors belongs to dataset.moves[k].
move_vectors = []
for move_index, move in enumerate(dataset.moves):
    raw_features = dataset.move_to_vector(move)
    for feature_name in raw_features.keys():
        if feature_name != 'description':
            # Tensor features get a leading batch dimension of size 1.
            raw_features[feature_name] = raw_features[feature_name].unsqueeze(0)
        else:
            # The text description is batched as a one-element list instead.
            raw_features[feature_name] = [raw_features[feature_name]]
    with torch.no_grad():
        embedding = move_model(raw_features).to('cpu')
    move_vectors.append(embedding)
    # Progress line, overwritten in place via the leading carriage return.
    print(f'\r{move_index}回終了', end='')
move_vectors = torch.cat(move_vectors, dim=0)

899回終了

In [30]:
import random

# How many Pokémon to sample for the qualitative nearest-move check.
N_SAMPLED_POKEMON = 5
# Fixed seed so the sampled Pokémon (and therefore output.txt) are reproducible.
RANKING_SEED = 0

rng = random.Random(RANKING_SEED)
# pokemon_ids are 0-based row indices into pokemon_vectors / dataset.pokemons.
# Sampling from range(len(...)) replaces the old randint(1, 1010): randint is
# inclusive on both ends, so it never picked index 0 and could return an index
# past the last row.
pokemon_ids = rng.sample(range(len(pokemon_vectors)), N_SAMPLED_POKEMON)

move_rankings = []
with torch.no_grad():
    for pokemon_id in pokemon_ids:
        # Euclidean distance from this Pokémon's embedding to every move embedding.
        distances = torch.norm(pokemon_vectors[pokemon_id] - move_vectors, dim=1)
        # Move row indices ordered from nearest to farthest.
        move_rankings.append(torch.argsort(distances))


In [31]:
# Folder holding one JSON file per move, named by 1-based move id.
MOVES_JSON_DIR = 'D:/tanaka/Documents/poke-move/data/moves'
# Number of nearest moves listed per Pokémon.
TOP_K = 100

# Move ids excluded from the rankings (selection criteria not recorded here).
# Ids are 1-based: move row index + 1. Long contiguous runs are written as ranges.
remove_move_ids = {
    2, 4, 11, 12, 14, 18, 26, 27, 28, 39, 41, 43, 45, 46, 47, 48, 49, 50, 54, 73,
    74, 77, 78, 79, 81, 86, 92, 95, 96, 97, 100, 102, 103, 104, 105, 106, 107, 108, 109, 110,
    111, 112, 113, 114, 115, 116, 118, 119, 121, 124, 125, 128, 131, 132, 133, 134, 135, 136, 137, 139,
    140, 142, 144, 146, 147, 148, 150, 151, 152, 155, 156, 158, 159, 160, 164, 165, 166, 167, 169, 170,
    171, 174, 176, 177, 178, 180, 182, 183, 184, 186, 187, 190, 191, 193, 194, 195, 197, 198, 199, 201,
    203, 204, 207, 208, 212, 213, 214, 215, 217, 219, 220, 221, 226, 227, 230, 233, 234, 235, 236, 238,
    240, 241, 244, 254, 256, 258, 259, 260, 261, 262, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274,
    275, 277, 278, 281, 284, 285, 286, 287, 288, 289, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301,
    302, 303, 307, 308, 309, 312, 313, 316, 319, 320, 321, 322, 323, 325, 327, 334, 335, 336, 338, 339,
    344, 346, 347, 349, 353, 354, 355, 356, 357, 361, 366, 367, 373, 375, 376, 377, 379, 380, 381, 382,
    383, 384, 385, 388, 390, 391, 392, 393, 395, 397, 415, 417, 418, 432, 433, 439, 443, 445, 446, 448,
    449, 454, 455, 456, 459, 460, 461, 462, 463, 464, 465, 467, 468, 469, 470, 471, 472, 475, 476, 477,
    478, 480, 483, 487, 489, 493, 494, 495, 501, 502, 504, 505, 508, 509, 511, 513, 516, 526, 531, 532,
    533, 536, 537, 538, 539, 540, 543, 544, 545, 546, 547, 548, 549, 550, 551, 552, 553, 554, 557, 558,
    559, 560, 561, 563, 564, 567, 568, 569, 570, 571, 575, 576, 578, 579, 580, 581, 582, 586, 587, 588,
    589, 590, 591, 592, 593, 594, 596, 597, 598, 599, 600, 601, 602, 603, 604, 606, 607, 608, 610,
    *range(613, 663),  # 613-662
    664, 665, 666, 668, 671, 672, 673, 674, 677, 678, 679, 680, 681, 682, 683, 685, 686, 687, 688, 689,
    690, 691, 692, 694, 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 708, 709,
    *range(711, 776),  # 711-775
    777, 778, 779, 780, 781, 782, 783, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 801, 810,
    811,
    *range(816, 901),  # 816-900
}

# Drop excluded moves from every ranking; each element is a 0-dim index tensor.
move_rankings = [
    [move_idx for move_idx in ranking if move_idx.item() + 1 not in remove_move_ids]
    for ranking in move_rankings
]

report_lines = []
for pokemon_id, ranking in zip(pokemon_ids, move_rankings):
    pokemon = dataset.pokemons[pokemon_id]
    report_lines.append(f'選ばれたポケモン: {pokemon.name}')
    report_lines.append(f'近い技 TOP{TOP_K}')
    # Ranks are 1-based so the list reads 1..TOP_K under the "TOP" header.
    for rank, move_idx in enumerate(ranking[:TOP_K], start=1):
        # Forward slashes only: the old mixed '\moves\{...}' form relied on invalid
        # escape sequences in a non-raw f-string.
        move_path = f'{MOVES_JSON_DIR}/{move_idx.item() + 1}.json'
        # The move JSON contains non-ASCII names; don't rely on the platform encoding.
        with open(move_path, encoding='utf-8') as f:
            move_json = json.load(f)
        ja_name = next(
            entry["name"] for entry in move_json["names"]
            if entry["language"]["name"] == 'ja'
        )
        report_lines.append(f'{rank}: {ja_name}')
text = ''.join(line + '\n' for line in report_lines)

In [32]:
# Quick check: the entry at position 1 (second-nearest after filtering) of the
# first sampled Pokémon's ranking. It displays a 0-dim index tensor; add 1 for the move id.
move_rankings[0][1]

tensor(231)

In [33]:
# Write the ranking report as UTF-8 explicitly: it contains Japanese text, and the
# platform default (e.g. cp932 on Japanese Windows) is neither portable nor
# guaranteed to encode every character.
with open('output.txt', 'w', encoding='utf-8') as f:
    f.write(text)

In [34]:
# Persist the embeddings for reuse. detach() guarantees the saved tensors carry no
# autograd history (requires_grad=False when loaded), however upstream cells
# computed them.
torch.save(pokemon_vectors.detach(), 'pokemon_vector.pt')
torch.save(move_vectors.detach(), 'move_vector.pt')