In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
import time
import os
from PIL import ImageFile
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, classification_report
import numpy as np

In [7]:
import matplotlib.pyplot as plt
import seaborn as sns

In [3]:
#Biar kalau ada gambar yang corrupt, tetap jalan
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [4]:
#Parameters
data_directory = 'gen3_dataset'
batch_size = 32
learning_rate = 0.001
epochs = 30
img_size = 224


In [8]:
def train_model():
    #Pakai GPU kalo ada
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Datanya diubah-ubah agar model gak nginget gambar doang soalnya jumlah image dikit
    transform_train = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(img_size),
        transforms.RandomRotation(30),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    transform_val = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    full_data = datasets.ImageFolder(data_directory, transform=transform_train)
    #Train Val Split 0.8:0.2
    train_size = int(0.8 * len(full_data))
    val_size = len(full_data) - train_size
    train_data, val_data = random_split(full_data, [train_size, val_size])

    val_data.dataset.transform = transform_val
    
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=2)

    class_name = full_data.classes
    num_classes = len(class_name)
    print(f'Classes Found: {class_name}, Train Size: {train_size}, Val Size: {val_size}')

    model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V2)

    for param in model.features.parameters():
        param.requires_grad = False

    model.classifier[1] = nn.Linear(model.last_channel, num_classes)
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    best_acc = 0.0

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        
        epoch_loss = running_loss / train_size
        epoch_acc = correct / total

        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        val_top3 = 0
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                val_loss += loss.item() * inputs.size(0)

                _, predicted = torch.max(outputs, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()
                
                all_preds.extend(predicted.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

                _, top3_pred = outputs.topk(3, 1, True, True)
                val_top3 += torch.eq(top3_pred, labels.view(-1, 1)).sum().item()
        
        val_epoch_loss = val_loss / val_size
        val_epoch_acc = val_correct / val_total
        val_epoch_top3 = val_top3 / val_total

        precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
        recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
        f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
        print(f'Epoch {epoch+1}/{epochs} | Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f} | Val Loss: {val_epoch_loss:.4f} Acc: {val_epoch_acc:.4f} Top-3 Acc: {val_epoch_top3:.4f}')

        if val_epoch_acc > best_acc:
            best_acc = val_epoch_acc
            torch.save(model.state_dict(), 'gen3_model_v3.pth')
            cm = confusion_matrix(all_labels, all_preds)
            plt.figure(figsize = (30, 30))
            sns.heatmap(cm, annot= False, fmt='d', cmap='Blues', xticklabels=class_name, yticklabels=class_name)
            plt.xlabel('Predicted Label')
            plt.ylabel('True Label')
            plt.title(f'Confusion Matrix (Epoch {epoch+1})')
            plt.xticks(rotation=90)
            plt.yticks(rotation=0)
            plt.tight_layout()
            plt.savefig('confusion_matrix.png')
            plt.close()
            print(f'Model Saved ({best_acc:.4f})')
            print(f'Confusion Matrix:\n{cm}')
            print(classification_report(all_labels, all_preds, target_names=class_name))

    print(f'Training Complete. Best Val Acc: {best_acc:.4f}')
if __name__ == '__main__':
    train_model()

Classes Found: ['Absol Pokemon', 'Aggron Pokemon', 'Altaria Pokemon', 'Anorith Pokemon', 'Armaldo Pokemon', 'Aron Pokemon', 'Azurill Pokemon', 'Bagon Pokemon', 'Baltoy Pokemon', 'Banette Pokemon', 'Barboach Pokemon', 'Beautifly Pokemon', 'Beldum Pokemon', 'Blaziken Pokemon', 'Breloom Pokemon', 'Cacnea Pokemon', 'Cacturne Pokemon', 'Camerupt Pokemon', 'Carvanha Pokemon', 'Cascoon Pokemon', 'Castform Pokemon', 'Chimecho Pokemon', 'Clamperl Pokemon', 'Claydol Pokemon', 'Combusken Pokemon', 'Corphish Pokemon', 'Cradily Pokemon', 'Crawdaunt Pokemon', 'Delcatty Pokemon', 'Deoxys Pokemon', 'Dusclops Pokemon', 'Duskull Pokemon', 'Dustox Pokemon', 'Electrike Pokemon', 'Exploud Pokemon', 'Feebas Pokemon', 'Flygon Pokemon', 'Gardevoir Pokemon', 'Glalie Pokemon', 'Gorebyss Pokemon', 'Groudon Pokemon', 'Grovyle Pokemon', 'Grumpig Pokemon', 'Gulpin Pokemon', 'Hariyama Pokemon', 'Huntail Pokemon', 'Illumise Pokemon', 'Jirachi Pokemon', 'Kecleon Pokemon', 'Kirlia Pokemon', 'Kyogre Pokemon', 'Lairon Po

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 2/30 | Train Loss: 2.7945 Acc: 0.6266 | Val Loss: 2.7249 Acc: 0.5725 Top-3 Acc: 0.7206
Model Saved (0.5725)
Confusion Matrix:
[[5 0 0 ... 0 0 0]
 [0 2 0 ... 0 0 2]
 [0 0 3 ... 0 0 0]
 ...
 [0 0 0 ... 2 0 0]
 [1 0 0 ... 0 5 0]
 [0 0 0 ... 0 0 5]]
                    precision    recall  f1-score   support

     Absol Pokemon       0.71      0.83      0.77         6
    Aggron Pokemon       1.00      0.18      0.31        11
   Altaria Pokemon       0.75      0.25      0.38        12
   Anorith Pokemon       1.00      0.40      0.57         5
   Armaldo Pokemon       0.00      0.00      0.00         5
      Aron Pokemon       0.23      1.00      0.38         3
   Azurill Pokemon       0.25      0.33      0.29         3
     Bagon Pokemon       0.50      0.25      0.33         4
    Baltoy Pokemon       0.83      0.45      0.59        11
   Banette Pokemon       1.00      0.09      0.17        11
  Barboach Pokemon       0.33      0.50      0.40         4
 Beautifly Pokemon       0.

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 3/30 | Train Loss: 1.9528 Acc: 0.7735 | Val Loss: 2.3012 Acc: 0.6061 Top-3 Acc: 0.7689
Model Saved (0.6061)
Confusion Matrix:
[[5 0 0 ... 0 0 0]
 [0 5 0 ... 0 0 0]
 [0 0 3 ... 0 0 0]
 ...
 [0 0 0 ... 2 0 0]
 [1 0 0 ... 0 5 0]
 [0 0 0 ... 0 0 5]]
                    precision    recall  f1-score   support

     Absol Pokemon       0.71      0.83      0.77         6
    Aggron Pokemon       0.83      0.45      0.59        11
   Altaria Pokemon       0.60      0.25      0.35        12
   Anorith Pokemon       1.00      0.40      0.57         5
   Armaldo Pokemon       0.33      0.20      0.25         5
      Aron Pokemon       0.43      1.00      0.60         3
   Azurill Pokemon       0.33      0.33      0.33         3
     Bagon Pokemon       0.20      0.25      0.22         4
    Baltoy Pokemon       1.00      0.45      0.62        11
   Banette Pokemon       1.00      0.27      0.43        11
  Barboach Pokemon       0.50      0.50      0.50         4
 Beautifly Pokemon       1.

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 4/30 | Train Loss: 1.4582 Acc: 0.8400 | Val Loss: 1.9541 Acc: 0.6618 Top-3 Acc: 0.7826
Model Saved (0.6618)
Confusion Matrix:
[[5 0 1 ... 0 0 0]
 [0 8 0 ... 0 0 0]
 [0 0 9 ... 0 0 0]
 ...
 [0 0 0 ... 2 0 0]
 [0 0 2 ... 0 5 0]
 [0 0 0 ... 0 0 4]]
                    precision    recall  f1-score   support

     Absol Pokemon       0.83      0.83      0.83         6
    Aggron Pokemon       0.80      0.73      0.76        11
   Altaria Pokemon       0.53      0.75      0.62        12
   Anorith Pokemon       1.00      0.40      0.57         5
   Armaldo Pokemon       0.50      0.20      0.29         5
      Aron Pokemon       0.60      1.00      0.75         3
   Azurill Pokemon       0.33      0.33      0.33         3
     Bagon Pokemon       0.50      0.50      0.50         4
    Baltoy Pokemon       0.70      0.64      0.67        11
   Banette Pokemon       0.62      0.45      0.53        11
  Barboach Pokemon       0.27      0.75      0.40         4
 Beautifly Pokemon       1.

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 7/30 | Train Loss: 0.7595 Acc: 0.9317 | Val Loss: 1.5922 Acc: 0.6796 Top-3 Acc: 0.8046
Epoch 8/30 | Train Loss: 0.6339 Acc: 0.9446 | Val Loss: 1.4972 Acc: 0.6891 Top-3 Acc: 0.8162
Model Saved (0.6891)
Confusion Matrix:
[[5 0 1 ... 0 0 0]
 [0 8 0 ... 0 0 0]
 [0 0 8 ... 0 0 0]
 ...
 [0 0 0 ... 2 0 0]
 [0 0 2 ... 0 5 0]
 [0 0 0 ... 0 0 5]]
                    precision    recall  f1-score   support

     Absol Pokemon       0.83      0.83      0.83         6
    Aggron Pokemon       0.73      0.73      0.73        11
   Altaria Pokemon       0.53      0.67      0.59        12
   Anorith Pokemon       1.00      0.40      0.57         5
   Armaldo Pokemon       0.40      0.40      0.40         5
      Aron Pokemon       0.50      1.00      0.67         3
   Azurill Pokemon       0.50      0.33      0.40         3
     Bagon Pokemon       0.50      0.50      0.50         4
    Baltoy Pokemon       0.78      0.64      0.70        11
   Banette Pokemon       0.67      0.36      0.47     

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 9/30 | Train Loss: 0.5440 Acc: 0.9611 | Val Loss: 1.4392 Acc: 0.6975 Top-3 Acc: 0.8309
Model Saved (0.6975)
Confusion Matrix:
[[5 0 1 ... 0 0 0]
 [0 9 0 ... 0 0 0]
 [1 0 7 ... 0 0 0]
 ...
 [0 0 0 ... 2 0 0]
 [0 0 2 ... 0 5 0]
 [0 0 0 ... 0 0 5]]
                    precision    recall  f1-score   support

     Absol Pokemon       0.71      0.83      0.77         6
    Aggron Pokemon       0.75      0.82      0.78        11
   Altaria Pokemon       0.50      0.58      0.54        12
   Anorith Pokemon       1.00      0.40      0.57         5
   Armaldo Pokemon       0.33      0.20      0.25         5
      Aron Pokemon       0.50      1.00      0.67         3
   Azurill Pokemon       0.67      0.67      0.67         3
     Bagon Pokemon       0.50      0.50      0.50         4
    Baltoy Pokemon       0.67      0.73      0.70        11
   Banette Pokemon       0.80      0.36      0.50        11
  Barboach Pokemon       0.50      0.75      0.60         4
 Beautifly Pokemon       0.

In [6]:
pip install streamlit

Collecting streamlit
  Downloading streamlit-1.52.1-py3-none-any.whl.metadata (9.8 kB)
Collecting altair!=5.4.0,!=5.4.1,<7,>=4.0 (from streamlit)
  Downloading altair-6.0.0-py3-none-any.whl.metadata (11 kB)
Collecting blinker<2,>=1.5.0 (from streamlit)
  Downloading blinker-1.9.0-py3-none-any.whl.metadata (1.6 kB)
Collecting cachetools<7,>=4.0 (from streamlit)
  Downloading cachetools-6.2.2-py3-none-any.whl.metadata (5.6 kB)
Collecting click<9,>=7.0 (from streamlit)
  Downloading click-8.3.1-py3-none-any.whl.metadata (2.6 kB)
Collecting pyarrow>=7.0 (from streamlit)
  Downloading pyarrow-22.0.0-cp313-cp313-win_amd64.whl.metadata (3.3 kB)
Collecting tenacity<10,>=8.1.0 (from streamlit)
  Downloading tenacity-9.1.2-py3-none-any.whl.metadata (1.2 kB)
Collecting toml<2,>=0.10.1 (from streamlit)
  Downloading toml-0.10.2-py2.py3-none-any.whl.metadata (7.1 kB)
Collecting watchdog<7,>=2.1.5 (from streamlit)
  Downloading watchdog-6.0.0-py3-none-win_amd64.whl.metadata (44 kB)
Collecting gitpyt


[notice] A new release of pip is available: 25.2 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip
