In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import os
from torchvision import datasets, transforms
from PIL import Image
from torch.utils.data import DataLoader, random_split
from vit_pytorch.simple_vit import SimpleViT
from vit_pytorch.vit import ViT
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm  # To display progress bars

In [36]:
# Fixed seed so the train/test split, the shuffling order and the model's
# initial weights are the same after every "Restart Kernel -> Run All".
SEED = 42
torch.manual_seed(SEED)

# Preprocessing applied to every image: resize to the 224x224 input size the
# model below is configured for, convert to a tensor, then normalise each of
# the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# ImageFolder derives the integer labels from the sub-directories of ./data.
dataset = datasets.ImageFolder(root='./data', transform=transform)

# Split dataset into train (80%) and test (remaining 20%).
# The explicit generator keeps the split identical between runs, so accuracy
# figures from different runs are measured on the same held-out images.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SEED),
)

# Only the training data is shuffled; evaluation order does not matter.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


# Vision Transformer: 224x224 images cut into 32x32 patches, two output classes.
model = ViT(
    image_size=224,
    patch_size=32,
    num_classes=2,
    dim=512,
    depth=6,
    heads=8,
    mlp_dim=1024,
    dropout=0.1,
    emb_dropout=0.1
)

# model = SimpleViT(
#     image_size=224,
#     patch_size=32,
#     num_classes=2,
#     dim=512,
#     depth=6,
#     heads=8,
#     mlp_dim=1024
# )



In [46]:
class DepthModelTrainer:
    """Train and validate an image classifier, checkpointing the best weights.

    Args:
        model: ``nn.Module`` mapping a batch of images to per-class logits
            of shape ``(batch, num_classes)``.
        train_loader: Iterable yielding ``(images, labels)`` training batches.
        test_loader: Iterable yielding ``(images, labels)`` validation batches.
        lr: Learning rate for the Adam optimizer.
        num_epochs: Number of passes over ``train_loader`` made by ``train``.
        custom_threshold: Minimum softmax probability the most likely class
            must exceed to be predicted; otherwise class 0 is predicted.
        checkpoint_path: File the best ``state_dict`` is written to.
        device: Optional device (string or ``torch.device``). When ``None``
            the first available of CUDA, MPS, CPU is used.
    """

    def __init__(self, model, train_loader, test_loader, lr=0.0005, num_epochs=10, custom_threshold=0.8,
                 checkpoint_path='best_model.pth', device=None):
        self.model = model
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.criterion = nn.CrossEntropyLoss()
        self.num_epochs = num_epochs
        self.custom_threshold = custom_threshold
        self.checkpoint_path = checkpoint_path
        self.best_accuracy = 0.0

        self.device = torch.device(device) if device is not None else self._select_device()
        self.model.to(self.device)

        # Build the optimizer only after the model lives on its final device.
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

    @staticmethod
    def _select_device():
        """Return the best available device: CUDA, then MPS, then CPU."""
        if torch.cuda.is_available():
            return torch.device('cuda')
        # Older torch builds have no torch.backends.mps attribute at all.
        mps_backend = getattr(torch.backends, 'mps', None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device('mps')
        return torch.device('cpu')

    def train(self):
        """Run ``num_epochs`` training epochs, validating after each one.

        Prints the mean per-sample training loss of every epoch and then calls
        ``evaluate``, which saves a checkpoint whenever accuracy improves.
        """
        for epoch in range(self.num_epochs):
            self.model.train()
            running_loss = 0.0
            samples_seen = 0
            for images, labels in tqdm(self.train_loader, desc=f"Training Epoch {epoch + 1}/{self.num_epochs}"):
                images, labels = images.to(self.device), labels.to(self.device)

                self.optimizer.zero_grad()

                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()

                # The criterion returns a batch mean; weight it by the batch
                # size so a smaller final batch does not skew the average.
                batch_size = images.size(0)
                running_loss += loss.item() * batch_size
                samples_seen += batch_size

            # Divide by the samples actually processed, which stays correct
            # with drop_last=True and with loaders that have no .dataset.
            epoch_loss = running_loss / samples_seen if samples_seen else float('nan')
            print(f'Epoch [{epoch + 1}/{self.num_epochs}], Loss: {epoch_loss:.4f}')

            self.evaluate()

    def evaluate(self):
        """Measure accuracy on ``test_loader`` and checkpoint on improvement.

        A sample is assigned its most probable class when that probability
        exceeds ``custom_threshold``; otherwise it is assigned class 0.

        Returns:
            Accuracy in percent as a float, or ``0.0`` when ``test_loader``
            yields no samples (nothing is saved in that case).
        """
        self.model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in tqdm(self.test_loader, desc="Validating"):
                images, labels = images.to(self.device), labels.to(self.device)
                probabilities = F.softmax(self.model(images), dim=1)

                # Pick the most probable class, not the first class over the
                # threshold, so low thresholds that several classes clear
                # still yield the best class.
                top_prob, top_class = probabilities.max(dim=1)
                predicted = torch.where(
                    top_prob > self.custom_threshold,
                    top_class,
                    torch.zeros_like(top_class),
                )

                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        if total == 0:
            # Avoid a ZeroDivisionError on an empty validation loader.
            print('Accuracy: n/a (validation loader is empty)')
            print(f'Best Accuracy: {self.best_accuracy:.2f}%')
            return 0.0

        accuracy = 100 * correct / total
        print(f'Accuracy: {accuracy:.2f}%')

        # Save the best model
        if accuracy > self.best_accuracy:
            self.best_accuracy = accuracy
            torch.save(self.model.state_dict(), self.checkpoint_path)
            print('Model saved!')

        print(f'Best Accuracy: {self.best_accuracy:.2f}%')
        return accuracy


In [38]:
# Build the trainer from the model and data loaders created in the cells above,
# using the default learning rate, epoch count and decision threshold.
trainer = DepthModelTrainer(model, train_loader, test_loader)

# Folder of frames; nothing in this cell reads it.
directory = "./frames/"

# Preprocessing pipeline with the same steps as the one used to build the
# dataset: resize to 224x224, convert to tensor, normalise each channel.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)


In [39]:
# Run the full training loop: num_epochs passes over train_loader, each followed
# by a validation pass that saves best_model.pth whenever accuracy improves.
trainer.train()

Training Epoch 1/10: 100%|██████████| 38/38 [00:27<00:00,  1.38it/s]


Epoch [1/10], Loss: 0.6389


Validating: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]


Accuracy: 91.72%
Model saved!
Best Accuracy: 91.72%


Training Epoch 2/10: 100%|██████████| 38/38 [00:27<00:00,  1.38it/s]


Epoch [2/10], Loss: 0.0946


Validating: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Accuracy: 93.38%
Model saved!
Best Accuracy: 93.38%


Training Epoch 3/10: 100%|██████████| 38/38 [00:27<00:00,  1.39it/s]


Epoch [3/10], Loss: 0.0730


Validating: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]


Accuracy: 92.72%
Best Accuracy: 93.38%


Training Epoch 4/10: 100%|██████████| 38/38 [00:27<00:00,  1.38it/s]


Epoch [4/10], Loss: 0.0668


Validating: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]


Accuracy: 95.70%
Model saved!
Best Accuracy: 95.70%


Training Epoch 5/10: 100%|██████████| 38/38 [00:27<00:00,  1.38it/s]


Epoch [5/10], Loss: 0.0530


Validating: 100%|██████████| 10/10 [00:05<00:00,  1.78it/s]


Accuracy: 94.70%
Best Accuracy: 95.70%


Training Epoch 6/10: 100%|██████████| 38/38 [00:27<00:00,  1.38it/s]


Epoch [6/10], Loss: 0.0508


Validating: 100%|██████████| 10/10 [00:05<00:00,  1.78it/s]


Accuracy: 94.04%
Best Accuracy: 95.70%


Training Epoch 7/10: 100%|██████████| 38/38 [00:27<00:00,  1.38it/s]


Epoch [7/10], Loss: 0.0665


Validating: 100%|██████████| 10/10 [00:05<00:00,  1.78it/s]


Accuracy: 93.05%
Best Accuracy: 95.70%


Training Epoch 8/10: 100%|██████████| 38/38 [00:27<00:00,  1.38it/s]


Epoch [8/10], Loss: 0.0611


Validating: 100%|██████████| 10/10 [00:05<00:00,  1.78it/s]


Accuracy: 94.37%
Best Accuracy: 95.70%


Training Epoch 9/10: 100%|██████████| 38/38 [00:27<00:00,  1.38it/s]


Epoch [9/10], Loss: 0.0408


Validating: 100%|██████████| 10/10 [00:05<00:00,  1.79it/s]


Accuracy: 94.04%
Best Accuracy: 95.70%


Training Epoch 10/10: 100%|██████████| 38/38 [00:27<00:00,  1.38it/s]


Epoch [10/10], Loss: 0.0497


Validating: 100%|██████████| 10/10 [00:05<00:00,  1.79it/s]

Accuracy: 94.04%
Best Accuracy: 95.70%





In [40]:
# Re-validate the weights left in memory by the last epoch. These are not the
# best checkpoint on disk, so the accuracy can be below "Best Accuracy".
trainer.evaluate()

Validating: 100%|██████████| 10/10 [00:06<00:00,  1.56it/s]

Accuracy: 94.04%
Best Accuracy: 95.70%



