In [1]:
import torch.nn as nn
from torchvision import models
import torch
import torchvision.transforms as transforms
from PIL import Image
from transformers import AutoTokenizer, AutoModel
import pandas as pd
import gc

In [2]:
if torch.cuda.is_available():
    device = torch.device("cuda")  # GPUデバイスを取得
else:
    device = torch.device("cpu")  # CPUデバイスを取得

In [3]:
import sys
import os

sys.path.append(os.path.abspath("../"))
from learning.Entity.CustomDataset import PokemonMoveDataset


dataset = PokemonMoveDataset('D:/tanaka/Documents/poke-move/data/my-dataset/dataset.txt', device)

pokemon_vectors = torch.load('./pokemon_vector.pt')
move_vectors = torch.load('./move_vector.pt')

In [4]:
import random

move_rankings = []
pokemon_ids = [random.randint(1, 1010) for i in range(5)]
for pokemon_id in pokemon_ids:
    # 各ベクトルのユークリッド距離を計算
    distances = torch.norm(pokemon_vectors[pokemon_id] - move_vectors, dim=1)

    # 距離を用いてベクトルBの添え字を並び替え
    sorted_indices = torch.argsort(distances)
    move_rankings.append(sorted_indices)


In [5]:
import json
text = ''
remove_move_ids = set([14, 18, 28, 39, 43, 45, 46, 47, 48, 50, 54, 73, 74, 77, 78, 79, 81, 86, 92, 95, 96, 97, 100, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 118, 119, 133, 134, 135, 137, 139, 142, 144, 147, 148, 150, 151, 156, 159, 160, 164, 165, 166, 169, 170, 171, 174, 176, 178, 180, 182, 184, 186, 187, 191, 193, 194, 195, 197, 199, 201, 203, 204, 207, 208, 212, 213, 214, 215, 219, 220, 226, 227, 230, 234, 235, 236, 240, 241, 244, 254, 256, 258, 259, 260, 261, 262, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 277, 278, 281, 285, 286, 287, 288, 289, 293, 294, 297, 298, 300, 303, 312, 313, 316, 319, 320, 321, 322, 334, 335, 336, 339, 346, 347, 349, 355, 356, 357, 361, 366, 367, 373, 375, 377, 379, 380, 381, 382, 383, 384, 385, 388, 390, 391, 392, 393, 397, 415, 417, 432, 433, 445, 446, 455, 456, 461, 464, 468, 469, 470, 471, 472, 475, 476, 477, 478, 483, 487, 489, 493, 494, 495, 501, 502, 504, 505, 508, 511, 513, 516, 526, 538, 561, 563, 564, 567, 568, 569, 571, 575, 576, 578, 579, 580, 581, 582, 587, 588, 589, 590, 596, 597, 598, 599, 600, 601, 602, 603, 604, 606, 607, 608, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, 653, 654, 655, 656, 657, 658, 659, 661, 666, 668, 671, 672, 673, 674, 678, 683, 685, 689, 694, 695, 696, 697, 698, 699, 700, 701, 702, 703, 715, 719, 723, 724, 725, 726, 727, 728, 743, 747, 748, 749, 750, 752, 753, 756, 757, 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 777, 791, 792, 810, 811, 816, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900])
for i in range(len(move_rankings)):
    move_rankings[i] = [move_id for move_id in move_rankings[i] if move_id.item() + 1 not in remove_move_ids]

for i, pokemon_id in enumerate(pokemon_ids):
    with open(f'D:/tanaka/Documents/poke-move/data\pokemons\{pokemon_id}.json') as f:
        pokemon = json.load(f)
    learned_moves = list(map(lambda x: x["move"]["name"], pokemon["moves"]))
    print(f'選ばれたポケモン: {pokemon["name"]}')
    
    for s in range(0, 801, 100):
        cnt = 0
        for j, move_id in enumerate(move_rankings[i][s:s+100]):
            move = dataset.moves[move_id]
            cnt += 1 if move.name in learned_moves else 0
        print(f'top{s}~top{s+100}: {cnt}', end=' ')
    print()

選ばれたポケモン: tropius
top0~top100: 9 top100~top200: 16 top200~top300: 12 top300~top400: 10 top400~top500: 7 top500~top600: 0 top600~top700: 0 top700~top800: 0 top800~top900: 0 
選ばれたポケモン: nihilego
top0~top100: 10 top100~top200: 8 top200~top300: 12 top300~top400: 9 top400~top500: 3 top500~top600: 0 top600~top700: 0 top700~top800: 0 top800~top900: 0 
選ばれたポケモン: umbreon
top0~top100: 7 top100~top200: 13 top200~top300: 12 top300~top400: 8 top400~top500: 6 top500~top600: 0 top600~top700: 0 top700~top800: 0 top800~top900: 0 
選ばれたポケモン: tauros
top0~top100: 12 top100~top200: 14 top200~top300: 18 top300~top400: 14 top400~top500: 10 top500~top600: 0 top600~top700: 0 top700~top800: 0 top800~top900: 0 
選ばれたポケモン: banette
top0~top100: 10 top100~top200: 7 top200~top300: 10 top300~top400: 8 top400~top500: 6 top500~top600: 0 top600~top700: 0 top700~top800: 0 top800~top900: 0 
