In [None]:
import torch
from torch.utils.data import DataLoader
import pandas as pd
from torch import nn, optim
import pickle
from tqdm import tqdm
import torch.nn.functional as F

In [None]:
class PokemonDataset:
    """Map-style dataset over a table holding "image" and "types" columns.

    Each item is a pair of float tensors ``(image, types)`` built from the
    selected row.
    """

    def __init__(self, df):
        # Keep a reference to the table; rows are converted lazily on access.
        self.df = df

    def __len__(self):
        """Number of rows in the wrapped table."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, types)`` for row ``idx`` as float tensors."""
        row = self.df.iloc[idx, :]
        return tuple(
            torch.tensor(row[column]).float() for column in ("image", "types")
        )

In [None]:
class SimpleModel(nn.Module):
    """Three-layer fully connected network: dim_image -> 64 -> 32 -> num_types.

    ReLU is applied after each of the two hidden layers; the output layer
    returns raw scores (no activation).
    """

    def __init__(self, dim_image, num_types):
        super().__init__()
        # Layers are registered in the same order they are applied.
        self.fc1 = nn.Linear(dim_image, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc_out = nn.Linear(32, num_types)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map an input batch to one score per type."""
        for hidden_layer in (self.fc1, self.fc2):
            x = self.relu(hidden_layer(x))
        return self.fc_out(x)

In [None]:
def _load_pickle(path):
    """Unpickle and return the object stored at ``path``."""
    # NOTE: pickle.load can execute arbitrary code; only use it on trusted files.
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Train split is read first, then the test split.
train, test = (_load_pickle(path) for path in ("./train.pkl", "./test.pkl"))

In [None]:
# Wrap the loaded train/test data so each can be indexed sample by sample.
train_dataset, test_dataset = (PokemonDataset(split) for split in (train, test))

# Same batch size for both splits; only the training loader shuffles.
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=32)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=32)

In [None]:
# The first training sample is an (image, types) pair; its two lengths give
# the model's input and output sizes.
dim_image, num_types = map(len, train_dataset[0])

In [None]:
# Seed torch's global RNG before building the model so the random weight
# initialisation (and anything else drawing from the global generator
# afterwards) is repeatable on a fresh Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
model = SimpleModel(dim_image, num_types)

In [None]:
# Training setup: Adam optimiser over all model parameters, cross-entropy loss.
optimizer = optim.Adam(model.parameters(), lr=1e-3)
model.train()
num_epochs = 100
loss_fn = nn.CrossEntropyLoss()


def _train_step(batch_images, batch_targets):
    """Run one optimisation step on a batch and return its loss as a float."""
    optimizer.zero_grad()
    scores = model(batch_images)
    batch_loss = loss_fn(scores, batch_targets)
    batch_loss.backward()
    optimizer.step()
    return batch_loss.item()


for e in range(num_epochs):
    # epoch_loss is the sum of the batch losses over the epoch (not an average).
    epoch_loss = sum(_train_step(image, types) for image, types in train_loader)
    print(f"Epoch {e}: Loss {epoch_loss}")

In [None]:
# Recall@k on the test split: for each sample, the share of its positive
# types that appear among the k highest-scored predictions.
k = 10
model.eval()
avg_recall = []  # per-batch mean recall, kept for inspection
recall_total = 0.0  # running sum of per-sample recalls
num_samples = 0
with torch.no_grad():
    for image, types in test_loader:
        pred = model(image)
        # argsort is ascending by default, which would select the k *lowest*
        # scores; rank descending so the k highest-scored types are kept.
        top_k_idx = pred.argsort(dim=1, descending=True)[:, :k]
        hits = types.gather(1, top_k_idx).sum(dim=1)
        # NOTE(review): assumes `types` is a multi-hot vector with at least one
        # positive entry per row; an all-zero row would produce NaN — confirm.
        recall = hits / types.sum(dim=1)
        avg_recall.append(recall.mean())
        recall_total += recall.sum().item()
        num_samples += recall.numel()
# Average over samples rather than over batch means, so a smaller final batch
# is not over-weighted. An empty loader yields NaN instead of raising.
recall_at_k = recall_total / num_samples if num_samples else float("nan")