In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import os
from torchvision import datasets, transforms
from PIL import Image
from torch.utils.data import DataLoader, random_split
from vit_pytorch.simple_vit import SimpleViT
from vit_pytorch.vit import ViT
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm  # To display progress bars

In [5]:
# Load the labelled images, hold out 20% for evaluation, and build loaders.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # must agree with the ViT's image_size
    transforms.ToTensor(),
    # Maps [0, 1] pixel values to [-1, 1] per channel.
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

dataset = datasets.ImageFolder(root='./data', transform=transform)

SPLIT_SEED = 42   # fixes which images end up in train vs. test
BATCH_SIZE = 32

# Split dataset into train and test (80 / 20).
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
# A seeded generator makes the split reproducible. Without it, every re-run of
# this cell reshuffles the split, so metrics are not comparable across runs and
# an already-trained model can end up being "tested" on images it trained on.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, test_dataset = random_split(
    dataset, [train_size, test_size], generator=split_generator
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [6]:
# Vision Transformer classifier. Architecture hyper-parameters come first,
# then the classification head size, then regularisation.
# image_size must agree with the Resize((224, 224)) transform; 224 is
# divisible by patch_size=32, giving a 7 x 7 grid of patches.
model = ViT(
    image_size=224,
    patch_size=32,
    dim=512,
    depth=6,
    heads=8,
    mlp_dim=1024,
    num_classes=2,
    dropout=0.1,
    emb_dropout=0.1,
)

# Alternative backbone previously tried: SimpleViT (imported above) with the
# same image_size, patch_size, num_classes, dim, depth, heads and mlp_dim, but
# without the dropout / emb_dropout arguments.



In [7]:


# Train the ViT, evaluate after every epoch, and checkpoint the best weights.
# NOTE(review): the per-epoch "validation" runs on test_loader, the same split
# the final evaluation cell reports on. Picking the checkpoint on the test
# split makes "Best Accuracy" an optimistic estimate; a separate validation
# split would avoid that.

# Check for mps device (Apple-silicon GPU); fall back to CPU otherwise.
device = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0005)

num_epochs = 10
best_accuracy = 0.0
# Cut-off on the softmax probability of class index 1 for a two-class head.
# 0.5 gives the same decisions as argmax; lower values favour class 1,
# higher values favour class 0.
# NOTE(review): class indices follow ImageFolder's class ordering — confirm
# that index 1 is the class this threshold is meant to apply to.
custom_threshold = 0.5  # Set your custom threshold here
checkpoint_path = 'best_model.pth'


def predict_with_threshold(logits, threshold=0.5, positive_class=1):
    """Convert a batch of logits into predicted class indices.

    For a two-class head, predict ``positive_class`` when its softmax
    probability is strictly greater than ``threshold``, otherwise the other
    class. (The previous mask-then-argmax approach broke for thresholds below
    0.5: both classes could pass, and argmax then always chose class 0.)
    For any other number of classes, fall back to plain argmax.

    Args:
        logits: tensor of shape (batch, num_classes).
        threshold: probability cut-off for ``positive_class`` (binary only).
        positive_class: 0 or 1, the class the threshold applies to.

    Returns:
        Long tensor of shape (batch,) with predicted class indices.
    """
    if logits.size(1) != 2:
        return logits.argmax(dim=1)
    probabilities = F.softmax(logits, dim=1)
    predicted = (probabilities[:, positive_class] > threshold).long()
    # True -> positive_class, False -> the other class.
    return predicted if positive_class == 1 else 1 - predicted


def train_one_epoch(model, loader, criterion, optimizer, device, desc):
    """Run one optimisation pass over ``loader``; return the mean per-sample loss.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    num_samples = 0
    for images, labels in tqdm(loader, desc=desc):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

        # CrossEntropyLoss averages over the batch; re-weight by batch size so
        # a smaller final batch does not skew the epoch mean.
        running_loss += loss.item() * images.size(0)
        num_samples += images.size(0)
    if num_samples == 0:
        raise ValueError("training loader yielded no samples")
    return running_loss / num_samples


def evaluate_accuracy(model, loader, device, threshold=0.5, desc="Validating"):
    """Return the percentage of samples in ``loader`` classified correctly.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(loader, desc=desc):
            images, labels = images.to(device), labels.to(device)
            predicted = predict_with_threshold(model(images), threshold)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("evaluation loader yielded no samples")
    return 100 * correct / total


for epoch in range(num_epochs):
    epoch_loss = train_one_epoch(
        model, train_loader, criterion, optimizer, device,
        desc=f"Training Epoch {epoch + 1}/{num_epochs}",
    )
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}')

    accuracy = evaluate_accuracy(model, test_loader, device, custom_threshold)
    print(f'Accuracy: {accuracy:.2f}%')

    # Save the best model (strict '>' keeps the earliest epoch on ties).
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(model.state_dict(), checkpoint_path)
        print('Model saved!')

print(f'Best Accuracy: {best_accuracy:.2f}%')


Training Epoch 1/10: 100%|██████████| 38/38 [00:29<00:00,  1.30it/s]


Epoch [1/10], Loss: 0.4072


Validating: 100%|██████████| 10/10 [00:05<00:00,  1.77it/s]


Accuracy: 94.04%
Model saved!


Training Epoch 2/10: 100%|██████████| 38/38 [00:27<00:00,  1.40it/s]


Epoch [2/10], Loss: 0.0972


Validating: 100%|██████████| 10/10 [00:05<00:00,  1.79it/s]


Accuracy: 94.04%


Training Epoch 3/10: 100%|██████████| 38/38 [00:27<00:00,  1.40it/s]


Epoch [3/10], Loss: 0.0839


Validating: 100%|██████████| 10/10 [00:05<00:00,  1.78it/s]


Accuracy: 96.36%
Model saved!


Training Epoch 4/10: 100%|██████████| 38/38 [00:27<00:00,  1.39it/s]


Epoch [4/10], Loss: 0.0665


Validating: 100%|██████████| 10/10 [00:05<00:00,  1.78it/s]


Accuracy: 95.03%


Training Epoch 5/10: 100%|██████████| 38/38 [00:27<00:00,  1.39it/s]


Epoch [5/10], Loss: 0.0582


Validating: 100%|██████████| 10/10 [00:05<00:00,  1.77it/s]


Accuracy: 94.04%


Training Epoch 6/10: 100%|██████████| 38/38 [00:27<00:00,  1.39it/s]


Epoch [6/10], Loss: 0.0630


Validating: 100%|██████████| 10/10 [00:05<00:00,  1.73it/s]


Accuracy: 95.36%


Training Epoch 7/10: 100%|██████████| 38/38 [00:27<00:00,  1.39it/s]


Epoch [7/10], Loss: 0.0497


Validating: 100%|██████████| 10/10 [00:05<00:00,  1.76it/s]


Accuracy: 95.36%


Training Epoch 8/10: 100%|██████████| 38/38 [00:27<00:00,  1.39it/s]


Epoch [8/10], Loss: 0.0428


Validating: 100%|██████████| 10/10 [00:05<00:00,  1.77it/s]


Accuracy: 96.69%
Model saved!


Training Epoch 9/10: 100%|██████████| 38/38 [00:27<00:00,  1.38it/s]


Epoch [9/10], Loss: 0.0437


Validating: 100%|██████████| 10/10 [00:05<00:00,  1.77it/s]


Accuracy: 96.03%


Training Epoch 10/10: 100%|██████████| 38/38 [00:27<00:00,  1.38it/s]


Epoch [10/10], Loss: 0.0360


Validating: 100%|██████████| 10/10 [00:05<00:00,  1.76it/s]

Accuracy: 97.02%
Model saved!
Best Accuracy: 97.02%





In [None]:
# Final evaluation: reload the best checkpoint, report accuracy on test_loader,
# and preview predictions for the first batch.
# NOTE(review): the training cell also used test_loader to choose the
# checkpoint, so this is not an unbiased held-out estimate.

CHECKPOINT_PATH = 'best_model.pth'
PREVIEW_BATCHES = 1  # how many batches to plot; plotting every batch floods the output

if os.path.exists(CHECKPOINT_PATH):
    # The training loop saves its best-scoring weights here. Without reloading,
    # this cell would evaluate whatever weights the final epoch left in memory.
    model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
model.eval()
correct = 0
total = 0


# Function to convert numerical labels to text labels
def convert_labels(label):
    """Map a numeric class index to its display name: 0 -> "D", anything else -> "ND"."""
    return "D" if label == 0 else "ND"


# Function to display images with predictions and true labels
def show_images(images, labels, preds, num_images=10):
    """Plot up to ``num_images`` images in a 4-column grid, titled with true/predicted labels.

    Args:
        images: CPU tensor of normalised images, assumed (N, C, H, W).
        labels: sequence of true class indices.
        preds: sequence of predicted class indices.
        num_images: maximum number of images to draw.
    """
    num_cols = 4
    num_shown = min(num_images, len(images))
    if num_shown == 0:
        return
    num_rows = -(-num_shown // num_cols)  # ceiling division
    fig, axes = plt.subplots(
        num_rows, num_cols, figsize=(num_cols * 5, num_rows * 5), squeeze=False
    )
    for i, ax in enumerate(axes.flat):
        ax.axis('off')
        if i >= num_shown:
            continue  # leave unused grid cells blank
        # Invert Normalize(mean=0.5, std=0.5); clamp float rounding into [0, 1]
        # so imshow does not warn about or clip out-of-range values.
        image = (images[i].permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1)
        ax.imshow(image.numpy())
        ax.set_title(f'True: {convert_labels(labels[i])}, Pred: {convert_labels(preds[i])}', fontsize=14)
    fig.subplots_adjust(wspace=0.3, hspace=0.3)  # Adjust the spacing between plots
    plt.show()
    plt.close(fig)  # release the figure so repeated runs do not accumulate memory


batches_shown = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        predicted = model(images).argmax(dim=1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Display a batch of images and predictions (only the first few batches).
        if batches_shown < PREVIEW_BATCHES:
            show_images(images.cpu(), labels.cpu().numpy(), predicted.cpu().numpy(), num_images=10)
            batches_shown += 1

if total == 0:
    raise ValueError("test_loader yielded no samples")
accuracy = 100 * correct / total
print(f'Test Accuracy: {accuracy:.2f}%')
