# Train ViT

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import timm
from tqdm import tqdm

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define image transformations 
transform = transforms.Compose([
    transforms.Resize((224, 224)),  
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.RandomAffine(degrees=15, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load the dataset using ImageFolder
dataset_dir = 'Data/VIT'  
dataset = datasets.ImageFolder(root=dataset_dir, transform=transform)

# Split dataset into training and validation sets (70% training, 10% validation, 20% test)
total_size = len(dataset)
train_size = int(0.7 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - train_size - val_size

torch.manual_seed(42)
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

train_dataset.dataset.transform = train_transform

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)


# Load Vision Transformer (ViT) model with advanced architecture
model = timm.create_model('vit_large_patch16_224', pretrained=True, num_classes=2)
model.to(device)

# Define loss function and optimizer with learning rate scheduler
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.7)

# Training loop with early stopping
num_epochs = 3
best_val_accuracy = 0
patience = 3
patience_counter = 0

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct_train = 0
    total_train = 0

    # Wrap the training loop with tqdm for progress bar
    with tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}") as train_bar:
        for images, labels in train_bar:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

            # Update the progress bar with loss and accuracy
            train_accuracy = 100 * correct_train / total_train
            train_bar.set_postfix(loss=running_loss/len(train_bar), acc=train_accuracy)

    # Print loss and accuracy after each epoch
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}, Training Accuracy: {train_accuracy:.2f}%", flush=True)

    # Validation
    model.eval()
    correct_val = 0
    total_val = 0
    with torch.no_grad():
        with tqdm(val_loader, desc="Validation") as val_bar:
            for images, labels in val_bar:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()

                # Update the validation progress bar
                val_accuracy = 100 * correct_val / total_val
                val_bar.set_postfix(acc=val_accuracy)

    print(f"Validation Accuracy: {val_accuracy:.2f}%")

    # Early stopping and model saving
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save(model.state_dict(), 'best_vit_model.pth')
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping due to no improvement.")
            break

    scheduler.step()

print(f"Best Validation Accuracy: {best_val_accuracy:.2f}%")


  from .autonotebook import tqdm as notebook_tqdm
Epoch 1/3: 100%|██████████| 438/438 [5:02:13<00:00, 41.40s/it, acc=88.5, loss=0.275]  

Epoch [1/3], Loss: 0.2746, Training Accuracy: 88.50%



Validation: 100%|██████████| 63/63 [14:30<00:00, 13.82s/it, acc=94]  


Validation Accuracy: 94.00%


Epoch 2/3: 100%|██████████| 438/438 [5:02:10<00:00, 41.39s/it, acc=94.1, loss=0.143]   

Epoch [2/3], Loss: 0.1426, Training Accuracy: 94.14%



Validation: 100%|██████████| 63/63 [14:29<00:00, 13.81s/it, acc=94.8]


Validation Accuracy: 94.85%


Epoch 3/3: 100%|██████████| 438/438 [5:02:14<00:00, 41.40s/it, acc=95.9, loss=0.103]   

Epoch [3/3], Loss: 0.1028, Training Accuracy: 95.90%



Validation: 100%|██████████| 63/63 [14:30<00:00, 13.82s/it, acc=95.7]


Validation Accuracy: 95.65%
Best Validation Accuracy: 95.65%


# Test ViT

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import timm
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose([
    transforms.Resize((224, 224)),  
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.RandomAffine(degrees=15, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load the dataset using ImageFolder
dataset_dir = 'Data/VIT'  
dataset = datasets.ImageFolder(root=dataset_dir, transform=transform)

# Split dataset into training and validation sets (70% training, 10% validation, 20% test)
total_size = len(dataset)
train_size = int(0.7 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - train_size - val_size

torch.manual_seed(42)
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

train_dataset.dataset.transform = train_transform

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

# Load model
model = timm.create_model('vit_large_patch16_224', pretrained=True, num_classes=2)
model.load_state_dict(torch.load('best_vit_model.pth'))
model.to(device)
model.eval()

# Evaluate on the test set
correct_test = 0
total_test = 0

with torch.no_grad():
    with tqdm(test_loader, desc="Testing") as test_bar:
        for images, labels in test_bar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total_test += labels.size(0)
            correct_test += (predicted == labels).sum().item()

            # Update progress bar
            test_accuracy = 100 * correct_test / total_test
            test_bar.set_postfix(acc=test_accuracy)

print(f"Test Accuracy: {test_accuracy:.2f}%")


  model.load_state_dict(torch.load('best_vit_model.pth'))
Testing: 100%|██████████| 125/125 [1:03:40<00:00, 30.56s/it, acc=96]  

Test Accuracy: 96.00%



