Used dataset:


https://www.kaggle.com/datasets/vishalsubbiah/pokemon-images-and-types/data

In [28]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
import pandas as pd
from PIL import Image
import kagglehub
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score

In [29]:
torch.cuda.is_available()

True

In [30]:
path = kagglehub.dataset_download("vishalsubbiah/pokemon-images-and-types")

df = pd.read_csv(path+"/pokemon.csv")
df = df[['Name', 'Type1']]
df.rename(columns={'Type1': 'Type'}, inplace=True)

df = df.sample(frac=1, random_state=42).reset_index(drop=True)

train_df, test_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df['Type'])

train_df.to_csv("train.csv", index=False)
test_df.to_csv("test.csv", index=False)



In [31]:
train_df.groupby('Type').count()

Unnamed: 0_level_0,Name
Type,Unnamed: 1_level_1
Bug,58
Dark,23
Dragon,22
Electric,32
Fairy,14
Fighting,23
Fire,42
Flying,3
Ghost,22
Grass,62


In [32]:
test_df.groupby('Type').count()

Unnamed: 0_level_0,Name
Type,Unnamed: 1_level_1
Bug,14
Dark,6
Dragon,5
Electric,8
Fairy,4
Fighting,6
Fire,11
Ghost,5
Grass,16
Ground,6


In [33]:
class PokemonDataset(Dataset):
  def __init__(self, csv_file, path, transform=None):
    self.data = pd.read_csv(csv_file)
    self.path = path
    if transform:
      self.transform = transform
    else:
      self.transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    self.labels = {c: i for i, c in enumerate(self.data['Type'].unique())}

  def __len__(self):
    return len(self.data)

  def __getitem__(self, idx):
    img_path = f"{self.path}/images/{self.data.iloc[idx]['Name']}.png"
    image = Image.open(img_path).convert('RGB')
    label = self.labels[self.data.iloc[idx]['Type']]

    image = self.transform(image)

    return image, label

In [34]:
class WideCNN(nn.Module):
    def __init__(self, num_classes):
        super(WideCNN, self).__init__()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)

        self.pool = nn.MaxPool2d(2, 2)

        self.fc1 = nn.Linear(256 * 28 * 28, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, num_classes)

        self.drop = nn.Dropout(0.5)

    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))

        x = x.view(x.size(0), -1)
        x = self.drop(F.relu(self.fc1(x)))
        x = self.drop(F.relu(self.fc2(x)))
        x = self.fc3(x)

        return x

In [35]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),  # случайное отражение
    transforms.RandomRotation(10),  # случайный поворот
    transforms.ColorJitter(brightness=0.2, contrast=0.2),  # изменение яркости и контраста
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [36]:
batch_size = 64

train_dataset = PokemonDataset("train.csv", path=path, transform=transform)
test_dataset = PokemonDataset("test.csv", path=path, transform=None)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [37]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

num_classes = len(train_dataset.labels)
model = WideCNN(num_classes=num_classes)

model = model.to(device)

In [38]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.00001, weight_decay=1e-3)

In [47]:
num_epochs = 100

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    all_labels = []
    all_predictions = []

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        all_labels.extend(labels.cpu().numpy())
        all_predictions.extend(predicted.cpu().numpy())

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")
    print(f"Precision: {precision_score(all_labels, all_predictions, average='macro', zero_division=1):.2f}")
    print(f"Recall: {recall_score(all_labels, all_predictions, average='macro', zero_division=1):.2f}")


import IPython
display(IPython.display.Audio(url="https://static.sfdict.com/audio/C07/C0702600.mp3", autoplay=True))

Epoch [1/100], Loss: 1.0796
Precision: 0.71
Recall: 0.59
Epoch [2/100], Loss: 1.0052
Precision: 0.70
Recall: 0.60
Epoch [3/100], Loss: 1.0910
Precision: 0.70
Recall: 0.59
Epoch [4/100], Loss: 1.0285
Precision: 0.70
Recall: 0.60
Epoch [5/100], Loss: 0.9803
Precision: 0.75
Recall: 0.65
Epoch [6/100], Loss: 1.0531
Precision: 0.72
Recall: 0.60
Epoch [7/100], Loss: 0.9855
Precision: 0.73
Recall: 0.61
Epoch [8/100], Loss: 1.0254
Precision: 0.73
Recall: 0.63
Epoch [9/100], Loss: 1.0184
Precision: 0.71
Recall: 0.66
Epoch [10/100], Loss: 1.0741
Precision: 0.70
Recall: 0.57
Epoch [11/100], Loss: 1.0231
Precision: 0.73
Recall: 0.62
Epoch [12/100], Loss: 1.0438
Precision: 0.70
Recall: 0.58
Epoch [13/100], Loss: 0.9844
Precision: 0.72
Recall: 0.63
Epoch [14/100], Loss: 1.0175
Precision: 0.69
Recall: 0.62
Epoch [15/100], Loss: 1.0138
Precision: 0.67
Recall: 0.60
Epoch [16/100], Loss: 1.0682
Precision: 0.72
Recall: 0.63
Epoch [17/100], Loss: 0.9781
Precision: 0.72
Recall: 0.61
Epoch [18/100], Loss: 1

In [46]:
model.eval()

total = 0

correct = 0
test_loss = 0.0
all_labels = []
all_predictions = []

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)

        _, predicted = torch.max(outputs, 1)

        criterion = nn.CrossEntropyLoss()
        loss = criterion(outputs, labels)
        test_loss += loss.item() * labels.size(0)

        all_labels.extend(labels.cpu().numpy())
        all_predictions.extend(predicted.cpu().numpy())

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
test_loss = test_loss/total
precision = precision_score(all_labels, all_predictions, average='macro', zero_division=1)
recall = recall_score(all_labels, all_predictions, average='macro', zero_division=1)

print(f"Test Accuracy: {accuracy:.2f}%")
print(f"Test Loss: {test_loss:.4f}")
print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")

Test Accuracy: 7.41%
Test Loss: 4.0204
Precision: 0.18
Recall: 0.06


In [41]:
def imshow(img, std, mean, t_label, p_label):
    mean = torch.tensor(mean).view(3, 1, 1)
    std = torch.tensor(std).view(3, 1, 1)

    img = img * std + mean
    npimg = img.numpy()

    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.axis("off")
    plt.show()
    print(f"True label: {label_list[t_label]}")
    print(f"Predicted label: {label_list[p_label]}")

cnt = 0

label_list = list(test_dataset.labels.keys())
for images, labels in test_loader:
    images, labels = images.to(device), labels.to(device)

    outputs = model(images)
    _, predicted = torch.max(outputs, 1)

    cnt += 1

    imshow(images[0], mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], t_label = labels[0].item(), p_label = predicted[0].item())
    if cnt == 5:
      break

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!