In [17]:
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
import torch
import matplotlib.pyplot as plt
import numpy as np
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, confusion_matrix
import pandas as pd
from tqdm import tqdm
import os

BATCH_SIZE = 64
IMG_SIZE = 224
EPOCHS = 20
NUM_CLASSES = 102

if not os.path.exists('flower_classes.txt'):
    class_names = [f'class_{i+1:03d}' for i in range(NUM_CLASSES)]
    with open('flower_classes.txt', 'w') as f:
        f.write('\n'.join(class_names))

with open('flower_classes.txt') as f:
    class_names = [line.strip() for line in f.readlines()]


train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(IMG_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


train_data = datasets.Flowers102(
    root='./data',
    split='train',
    transform=train_transform,
    download=True
)

test_data = datasets.Flowers102(
    root='./data',
    split='test',
    transform=test_transform
)


train_loader = DataLoader(train_data, BATCH_SIZE, shuffle=True, num_workers=0)
test_loader = DataLoader(test_data, BATCH_SIZE, shuffle=False, num_workers=0)

model = models.mobilenet_v3_small(weights='IMAGENET1K_V1')
model.classifier[3] = torch.nn.Linear(model.classifier[3].in_features, NUM_CLASSES)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

for param in model.parameters():
    param.requires_grad = False
for param in model.classifier.parameters():
    param.requires_grad = True

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.0001, alpha=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

best_acc = 0.0
progress = []

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    
    for images, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{EPOCHS}'):
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * images.size(0)
    
    model.eval()
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            
            total += labels.size(0)
            correct += (preds == labels).sum().item()
    
    epoch_acc = 100 * correct / total
    progress.append(epoch_acc)
    
    print(f'\nEpoch {epoch+1} Results:')
    print(f'Accuracy: {epoch_acc:.2f}%')
    print('-'*50)
    
    if epoch_acc > best_acc:
        best_acc = epoch_acc
        torch.save(model.state_dict(), 'best_model.pth')

model.load_state_dict(torch.load('best_model.pth', weights_only=True))
model.eval()

print('\nClassification Report:')
print(classification_report(all_labels, all_preds, target_names=class_names))

plt.figure(figsize=(20,15))
plt.imshow(confusion_matrix(all_labels, all_preds), cmap='Blues')
plt.title('Confusion Matrix')
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.savefig('confusion_matrix.png')
plt.close()

plt.plot(progress)
plt.title('Validation Accuracy Progress')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.savefig('training_progress.png')
plt.close()

fig, axes = plt.subplots(3, 3, figsize=(15, 10))
for idx, ax in enumerate(axes.flat):
    image, label = test_data[np.random.randint(len(test_data))]
    with torch.no_grad():
        output = model(image.unsqueeze(0).to(device))
        pred = torch.argmax(output).item()
    
    image = image.cpu().numpy().transpose((1, 2, 0))
    image = image * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
    image = np.clip(image, 0, 1)
    
    ax.imshow(image)
    ax.set_title(f"True: {class_names[label]}\nPred: {class_names[pred]}", fontsize=9)
    ax.axis('off')
plt.tight_layout()
plt.savefig('sample_predictions.png')
plt.close()

pd.DataFrame({
    'True_Label': all_labels,
    'Predicted_Label': all_preds,
    'Correct': [t == p for t, p in zip(all_labels, all_preds)]
}).to_csv('predictions.csv', index=False)

print('\nOutput Files Generated:')
print('- confusion_matrix.png\n- training_progress.png\n- sample_predictions.png\n- predictions.csv\n- best_model.pth')

Epoch 1/20: 100%|██████████████████████████████████████████████████████████████████████| 16/16 [00:26<00:00,  1.64s/it]



Epoch 1 Results:
Accuracy: 2.24%
--------------------------------------------------


Epoch 2/20: 100%|██████████████████████████████████████████████████████████████████████| 16/16 [00:23<00:00,  1.49s/it]



Epoch 2 Results:
Accuracy: 4.99%
--------------------------------------------------


Epoch 3/20: 100%|██████████████████████████████████████████████████████████████████████| 16/16 [00:23<00:00,  1.50s/it]



Epoch 3 Results:
Accuracy: 10.07%
--------------------------------------------------


Epoch 4/20: 100%|██████████████████████████████████████████████████████████████████████| 16/16 [00:24<00:00,  1.51s/it]



Epoch 4 Results:
Accuracy: 17.65%
--------------------------------------------------


Epoch 5/20: 100%|██████████████████████████████████████████████████████████████████████| 16/16 [00:24<00:00,  1.52s/it]



Epoch 5 Results:
Accuracy: 26.75%
--------------------------------------------------


Epoch 6/20: 100%|██████████████████████████████████████████████████████████████████████| 16/16 [00:25<00:00,  1.59s/it]



Epoch 6 Results:
Accuracy: 34.66%
--------------------------------------------------


Epoch 7/20: 100%|██████████████████████████████████████████████████████████████████████| 16/16 [00:24<00:00,  1.53s/it]



Epoch 7 Results:
Accuracy: 41.75%
--------------------------------------------------


Epoch 8/20: 100%|██████████████████████████████████████████████████████████████████████| 16/16 [00:24<00:00,  1.54s/it]



Epoch 8 Results:
Accuracy: 47.26%
--------------------------------------------------


Epoch 9/20: 100%|██████████████████████████████████████████████████████████████████████| 16/16 [00:25<00:00,  1.57s/it]



Epoch 9 Results:
Accuracy: 52.63%
--------------------------------------------------


Epoch 10/20: 100%|█████████████████████████████████████████████████████████████████████| 16/16 [00:24<00:00,  1.56s/it]



Epoch 10 Results:
Accuracy: 56.66%
--------------------------------------------------


Epoch 11/20: 100%|█████████████████████████████████████████████████████████████████████| 16/16 [00:25<00:00,  1.61s/it]



Epoch 11 Results:
Accuracy: 60.37%
--------------------------------------------------


Epoch 12/20: 100%|█████████████████████████████████████████████████████████████████████| 16/16 [00:24<00:00,  1.54s/it]



Epoch 12 Results:
Accuracy: 63.25%
--------------------------------------------------


Epoch 13/20: 100%|█████████████████████████████████████████████████████████████████████| 16/16 [00:23<00:00,  1.47s/it]



Epoch 13 Results:
Accuracy: 65.39%
--------------------------------------------------


Epoch 14/20: 100%|█████████████████████████████████████████████████████████████████████| 16/16 [00:23<00:00,  1.47s/it]



Epoch 14 Results:
Accuracy: 67.86%
--------------------------------------------------


Epoch 15/20: 100%|█████████████████████████████████████████████████████████████████████| 16/16 [00:23<00:00,  1.47s/it]



Epoch 15 Results:
Accuracy: 69.85%
--------------------------------------------------


Epoch 16/20: 100%|█████████████████████████████████████████████████████████████████████| 16/16 [00:23<00:00,  1.46s/it]



Epoch 16 Results:
Accuracy: 71.51%
--------------------------------------------------


Epoch 17/20: 100%|█████████████████████████████████████████████████████████████████████| 16/16 [00:25<00:00,  1.60s/it]



Epoch 17 Results:
Accuracy: 72.37%
--------------------------------------------------


Epoch 18/20: 100%|█████████████████████████████████████████████████████████████████████| 16/16 [00:23<00:00,  1.49s/it]



Epoch 18 Results:
Accuracy: 74.35%
--------------------------------------------------


Epoch 19/20: 100%|█████████████████████████████████████████████████████████████████████| 16/16 [00:24<00:00,  1.50s/it]



Epoch 19 Results:
Accuracy: 75.61%
--------------------------------------------------


Epoch 20/20: 100%|█████████████████████████████████████████████████████████████████████| 16/16 [00:23<00:00,  1.48s/it]



Epoch 20 Results:
Accuracy: 76.00%
--------------------------------------------------

Classification Report:
              precision    recall  f1-score   support

   class_001       0.33      0.85      0.48        20
   class_002       1.00      0.93      0.96        40
   class_003       0.67      0.30      0.41        20
   class_004       0.41      0.36      0.38        36
   class_005       0.66      0.73      0.69        45
   class_006       0.73      0.88      0.80        25
   class_007       0.47      0.80      0.59        20
   class_008       0.85      0.98      0.91        65
   class_009       0.54      0.81      0.65        26
   class_010       0.92      0.96      0.94        25
   class_011       0.58      0.39      0.46        67
   class_012       0.92      0.87      0.89        67
   class_013       0.79      0.93      0.86        29
   class_014       0.64      0.96      0.77        28
   class_015       0.93      0.97      0.95        29
   class_016       0.50 