In [1]:
import torch

In [2]:
import torchvision.transforms as transforms

In [3]:
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

In [4]:
from torch.utils.data import DataLoader, random_split
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim
from PIL import Image

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
device

device(type='cpu')

In [7]:
transform = transforms.Compose([
    transforms.Resize((227, 227)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [8]:
transform

Compose(
    Resize(size=(227, 227), interpolation=bilinear, max_size=None, antialias=True)
    RandomHorizontalFlip(p=0.5)
    RandomRotation(degrees=[-10.0, 10.0], interpolation=nearest, expand=False, fill=0)
    ToTensor()
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)

In [9]:
dataset = ImageFolder(root='dataset/images', transform=transform)

In [10]:
dataset

Dataset ImageFolder
    Number of datapoints: 2503
    Root location: dataset/images
    StandardTransform
Transform: Compose(
               Resize(size=(227, 227), interpolation=bilinear, max_size=None, antialias=True)
               RandomHorizontalFlip(p=0.5)
               RandomRotation(degrees=[-10.0, 10.0], interpolation=nearest, expand=False, fill=0)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )

In [11]:
# Check if the dataset is loaded correctly
print(f"Number of classes: {len(dataset.classes)}")
print(f"Class names: {dataset.classes}")
print(f"Number of images: {len(dataset)}")

Number of classes: 898
Class names: ['Abomasnow', 'Abra', 'Absol', 'Accelgor', 'Aegislash', 'Aerodactyl', 'Aggron', 'Aipom', 'Alakazam', 'Alcremie', 'Alomomola', 'Altaria', 'Amaura', 'Ambipom', 'Amoonguss', 'Ampharos', 'Anorith', 'Appletun', 'Applin', 'Araquanid', 'Arbok', 'Arcanine', 'Arceus', 'Archen', 'Archeops', 'Arctovish', 'Arctozolt', 'Ariados', 'Armaldo', 'Aromatisse', 'Aron', 'Arrokuda', 'Articuno', 'Audino', 'Aurorus', 'Avalugg', 'Axew', 'Azelf', 'Azumarill', 'Azurill', 'Bagon', 'Baltoy', 'Banette', 'Barbaracle', 'Barboach', 'Barraskewda', 'Basculin', 'Bastiodon', 'Bayleef', 'Beartic', 'Beautifly', 'Beedrill', 'Beheeyem', 'Beldum', 'Bellossom', 'Bellsprout', 'Bergmite', 'Bewear', 'Bibarel', 'Bidoof', 'Binacle', 'Bisharp', 'Blacephalon', 'Blastoise', 'Blaziken', 'Blipbug', 'Blissey', 'Blitzle', 'Boldore', 'Boltund', 'Bonsly', 'Bouffalant', 'Bounsweet', 'Braixen', 'Braviary', 'Breloom', 'Brionne', 'Bronzong', 'Bronzor', 'Bruxish', 'Budew', 'Buizel', 'Bulbasaur', 'Buneary', 'Bun

In [12]:
# Split the dataset into training, validation, and test sets
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size

In [13]:
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

In [14]:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [15]:
# Loading the pre-trained AlexNet model
model = models.alexnet(pretrained=True)



In [16]:
# Modify the classifier to match the number of classes in your dataset
num_classes = len(dataset.classes)
model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)

In [17]:
num_classes

898

In [18]:
# Move the model to the device (GPU/CPU)
model = model.to(device)

In [19]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

In [20]:
# Number of epochs to train
num_epochs = 100

# Try to load checkpoint if exists
try:
    checkpoint = torch.load('model_checkpoint.pth')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    loss = checkpoint['loss']
    print(f"Loaded model checkpoint from epoch {start_epoch}")
except FileNotFoundError:
    start_epoch = 0
    print("No checkpoint found, starting training from scratch.")

Loaded model checkpoint from epoch 10


In [21]:
# Training loop
for epoch in range(start_epoch, num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader)}")

Epoch [11/100], Loss: 1.472314415736632
Epoch [12/100], Loss: 0.34906055141579023
Epoch [13/100], Loss: 0.09423339443958618
Epoch [14/100], Loss: 0.07873648642985658
Epoch [15/100], Loss: 0.07670900476461445
Epoch [16/100], Loss: 0.07765107488429004
Epoch [17/100], Loss: 0.03896035540171645
Epoch [18/100], Loss: 0.05776336357336153
Epoch [19/100], Loss: 0.03375180332328786
Epoch [20/100], Loss: 0.034716771940954703
Epoch [21/100], Loss: 0.02448503641670951
Epoch [22/100], Loss: 0.05092530411943285
Epoch [23/100], Loss: 0.054636814223009755
Epoch [24/100], Loss: 0.03769262673820115
Epoch [25/100], Loss: 0.048040961260928515
Epoch [26/100], Loss: 0.03326623877827925
Epoch [27/100], Loss: 0.031844433267939495
Epoch [28/100], Loss: 0.03969400988214395
Epoch [29/100], Loss: 0.02560193930945719
Epoch [30/100], Loss: 0.052604428155940366
Epoch [31/100], Loss: 0.07839573502424173
Epoch [32/100], Loss: 0.062363267002034595
Epoch [33/100], Loss: 0.06119555445514958
Epoch [34/100], Loss: 0.058067

In [22]:
checkpoint_path = 'model_checkpoint.pth'

# Save the model and optimizer state
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}, checkpoint_path)

print(f"Model checkpoint saved at epoch {epoch+1}")

Model checkpoint saved at epoch 100


In [23]:
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Validation Accuracy: {100 * correct / total}%')

Validation Accuracy: 64.26666666666667%


In [30]:
image_path='./Input_Pokaemon/4.jpg'

def predict_image(image_path, model, transform):
    image = Image.open(image_path)
    image = transform(image).unsqueeze(0)  # Add batch dimension
    image = image.to(device)
    model.eval()
    with torch.no_grad():
        outputs = model(image)
        _, predicted = torch.max(outputs.data, 1)
    return dataset.classes[predicted.item()]

predicted_class = predict_image(image_path, model, transform)
print(f'The predicted class is: {predicted_class}')

The predicted class is: Zigzagoon


In [24]:
import os
os.getcwd()

'/home/pantho/projects/Own project/pokemon/pokemon'