In [24]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18, ResNet18_Weights
import torch.optim as optim
from tqdm import tqdm
from PIL import Image

# Option 1: Simple CNN from scratch
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.pool(self.relu(self.conv3(x)))
        x = x.view(-1, 64 * 4 * 4)
        x = self.dropout(self.relu(self.fc1(x)))
        x = self.fc2(x)
        return x

# Option 2: Transfer Learning with ResNet
class ResNetTransfer(nn.Module):
    def __init__(self, num_classes=10):
        super(ResNetTransfer, self).__init__()
        # Load pretrained ResNet18
        self.resnet = resnet18(weights=ResNet18_Weights.DEFAULT)
        
        # Freeze all layers
        for param in self.resnet.parameters():
            param.requires_grad = False
            
        # Replace the final fully connected layer
        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_features, num_classes)

    def forward(self, x):
        return self.resnet(x)
        

# Example usage:
def main():
    # Create instances of both models
    simple_cnn = SimpleCNN(num_classes=10)
    resnet_transfer = ResNetTransfer(num_classes=10)
    
    # Print model summaries
    print("Simple CNN Architecture:")
    print(simple_cnn)
    print("\nResNet Transfer Learning Architecture:")
    print(resnet_transfer)

if __name__ == "__main__":
    main()

Simple CNN Architecture:
SimpleCNN(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=1024, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=10, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.5, inplace=False)
)

ResNet Transfer Learning Architecture:
ResNetTransfer(
  (resnet): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kern

In [25]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18, ResNet18_Weights
import os
from torch.utils.data import Dataset, DataLoader,  random_split

In [69]:
# Option 1: Simple CNN from scratch
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 21 * 28, 512)
        self.fc2 = nn.Linear(512, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.pool(self.relu(self.conv3(x)))

        print(x.shape)
        x = x.view(x.size(0), -1)
        x = self.dropout(self.relu(self.fc1(x)))
        x = self.fc2(x)

        return x
    
    def train_model(self, train_loader, num_epochs=10, lr=0.01):
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.parameters(), lr=lr)
        best_accuracy = 0.0

        for epoch in range(num_epochs):
            self.train()
            running_loss, correct, total = 0.0, 0, 0
            loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

            for images, labels in loop:
                images, labels = images.to(self.device), labels.to(self.device)

                optimizer.zero_grad()
                outputs = self(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)

                loop.set_postfix(loss=loss.item(), acc=100 * correct / total)

            epoch_acc = 100 * correct / total
            print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}, Accuracy: {epoch_acc:.2f}%")

            if epoch_acc > best_accuracy:
                best_accuracy = epoch_acc
                torch.save(self.state_dict(), "best_model.pth")
                print("Model saved!")


    def evaluate_model(self, test_loader):
        self.eval()
        correct, total, test_loss = 0, 0, 0.0
        criterion = nn.CrossEntropyLoss()

        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self(images)
                loss = criterion(outputs, labels)
                test_loss += loss.item()

                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)

        test_accuracy = 100 * correct / total
        print(f"\n📊 Test Loss: {test_loss / len(test_loader):.4f}, Test Accuracy: {test_accuracy:.2f}%")




In [70]:
class CustomImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.data = []
        self.class_to_idx =  {}

        # Assign class labels based on folder names
        class_folders = sorted(os.listdir(root_dir))
        for idx, class_name in enumerate(class_folders):
            class_path = os.path.join(root_dir, class_name)
            if os.path.isdir(class_path):
                self.class_to_idx[class_name] = idx # Assign a  label
                for img_name in os.listdir(class_path):
                    img_path = os.path.join(class_path, img_name)
                    self.data.append((img_path, idx))

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        img_path, label = self.data[idx]
        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label
    
transform = transforms.Compose([
    transforms.Resize((168, 224)),  # Smaller images
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5,0.5,0.5])
])

In [None]:
dataset_path = r'document_classification\document_classification'

full_dataset = CustomImageDataset(root_dir=dataset_path, transform=transform)

In [None]:
# Split into train (80%) and test (20%)
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)

# Check dataset structure
print(f"Classes: {full_dataset.class_to_idx}")
print(f"Total images: {len(full_dataset)}, Train: {len(train_dataset)}, Test: {len(test_dataset)}")

Classes: {'email': 0, 'resume': 1, 'scientific_publication': 2}
Total images: 165, Train: 132, Test: 33


In [75]:
# Initialize model
num_classes = len(full_dataset.class_to_idx)  
model = SimpleCNN(num_classes)

# Train and Evaluate
model.train_model(train_loader, num_epochs=15, lr=0.001)
model.evaluate_model(test_loader)

Epoch 1/15:   0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([32, 64, 21, 28])


Epoch 1/15:  20%|██        | 1/5 [00:00<00:03,  1.28it/s, acc=40.6, loss=1.09]

torch.Size([32, 64, 21, 28])


Epoch 1/15:  40%|████      | 2/5 [00:01<00:02,  1.24it/s, acc=39.1, loss=1.58]

torch.Size([32, 64, 21, 28])


Epoch 1/15:  60%|██████    | 3/5 [00:02<00:01,  1.22it/s, acc=41.7, loss=1.04]

torch.Size([32, 64, 21, 28])


Epoch 1/15: 100%|██████████| 5/5 [00:03<00:00,  1.43it/s, acc=38.6, loss=0.968]

torch.Size([4, 64, 21, 28])
Epoch [1/15], Loss: 1.2045, Accuracy: 38.64%





Model saved!


Epoch 2/15:   0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([32, 64, 21, 28])


Epoch 2/15:  20%|██        | 1/5 [00:00<00:03,  1.13it/s, acc=62.5, loss=1.02]

torch.Size([32, 64, 21, 28])


Epoch 2/15:  40%|████      | 2/5 [00:01<00:02,  1.18it/s, acc=59.4, loss=0.959]

torch.Size([32, 64, 21, 28])


Epoch 2/15:  60%|██████    | 3/5 [00:02<00:01,  1.27it/s, acc=61.5, loss=0.901]

torch.Size([32, 64, 21, 28])


Epoch 2/15: 100%|██████████| 5/5 [00:03<00:00,  1.45it/s, acc=59.8, loss=1.21] 

torch.Size([4, 64, 21, 28])
Epoch [2/15], Loss: 0.9957, Accuracy: 59.85%





Model saved!


Epoch 3/15:   0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([32, 64, 21, 28])


Epoch 3/15:  20%|██        | 1/5 [00:00<00:03,  1.29it/s, acc=50, loss=0.924]

torch.Size([32, 64, 21, 28])


Epoch 3/15:  40%|████      | 2/5 [00:01<00:02,  1.18it/s, acc=60.9, loss=0.721]

torch.Size([32, 64, 21, 28])


Epoch 3/15:  60%|██████    | 3/5 [00:02<00:01,  1.20it/s, acc=62.5, loss=0.781]

torch.Size([32, 64, 21, 28])


Epoch 3/15:  80%|████████  | 4/5 [00:03<00:00,  1.14it/s, acc=61.7, loss=0.843]

torch.Size([4, 64, 21, 28])


Epoch 3/15: 100%|██████████| 5/5 [00:03<00:00,  1.36it/s, acc=60.6, loss=1.09] 


Epoch [3/15], Loss: 0.8710, Accuracy: 60.61%
Model saved!


Epoch 4/15:   0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([32, 64, 21, 28])


Epoch 4/15:  20%|██        | 1/5 [00:00<00:03,  1.28it/s, acc=53.1, loss=0.727]

torch.Size([32, 64, 21, 28])


Epoch 4/15:  40%|████      | 2/5 [00:01<00:02,  1.31it/s, acc=60.9, loss=0.675]

torch.Size([32, 64, 21, 28])


Epoch 4/15:  60%|██████    | 3/5 [00:02<00:01,  1.27it/s, acc=65.6, loss=0.506]

torch.Size([32, 64, 21, 28])


Epoch 4/15: 100%|██████████| 5/5 [00:03<00:00,  1.49it/s, acc=64.4, loss=0.485]

torch.Size([4, 64, 21, 28])
Epoch [4/15], Loss: 0.6199, Accuracy: 64.39%





Model saved!


Epoch 5/15:   0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([32, 64, 21, 28])


Epoch 5/15:  20%|██        | 1/5 [00:00<00:02,  1.38it/s, acc=78.1, loss=0.432]

torch.Size([32, 64, 21, 28])


Epoch 5/15:  40%|████      | 2/5 [00:01<00:02,  1.33it/s, acc=75, loss=0.643]  

torch.Size([32, 64, 21, 28])


Epoch 5/15:  60%|██████    | 3/5 [00:02<00:01,  1.35it/s, acc=79.2, loss=0.356]

torch.Size([32, 64, 21, 28])


Epoch 5/15:  80%|████████  | 4/5 [00:03<00:00,  1.32it/s, acc=80.5, loss=0.456]

torch.Size([4, 64, 21, 28])


Epoch 5/15: 100%|██████████| 5/5 [00:03<00:00,  1.50it/s, acc=79.5, loss=0.538]


Epoch [5/15], Loss: 0.4849, Accuracy: 79.55%
Model saved!


Epoch 6/15:   0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([32, 64, 21, 28])


Epoch 6/15:  20%|██        | 1/5 [00:00<00:03,  1.30it/s, acc=78.1, loss=0.363]

torch.Size([32, 64, 21, 28])


Epoch 6/15:  40%|████      | 2/5 [00:01<00:02,  1.32it/s, acc=76.6, loss=0.538]

torch.Size([32, 64, 21, 28])


Epoch 6/15:  60%|██████    | 3/5 [00:02<00:01,  1.31it/s, acc=81.2, loss=0.388]

torch.Size([32, 64, 21, 28])


Epoch 6/15: 100%|██████████| 5/5 [00:03<00:00,  1.52it/s, acc=81.8, loss=0.476]

torch.Size([4, 64, 21, 28])
Epoch [6/15], Loss: 0.4455, Accuracy: 81.82%





Model saved!


Epoch 7/15:   0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([32, 64, 21, 28])


Epoch 7/15:  20%|██        | 1/5 [00:00<00:03,  1.29it/s, acc=90.6, loss=0.411]

torch.Size([32, 64, 21, 28])


Epoch 7/15:  40%|████      | 2/5 [00:01<00:02,  1.26it/s, acc=87.5, loss=0.383]

torch.Size([32, 64, 21, 28])


Epoch 7/15:  60%|██████    | 3/5 [00:02<00:01,  1.30it/s, acc=87.5, loss=0.535]

torch.Size([32, 64, 21, 28])


Epoch 7/15:  80%|████████  | 4/5 [00:03<00:00,  1.32it/s, acc=88.3, loss=0.243]

torch.Size([4, 64, 21, 28])


Epoch 7/15: 100%|██████████| 5/5 [00:03<00:00,  1.51it/s, acc=87.9, loss=0.825]


Epoch [7/15], Loss: 0.4796, Accuracy: 87.88%
Model saved!


Epoch 8/15:   0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([32, 64, 21, 28])


Epoch 8/15:  20%|██        | 1/5 [00:00<00:02,  1.36it/s, acc=90.6, loss=0.347]

torch.Size([32, 64, 21, 28])


Epoch 8/15:  40%|████      | 2/5 [00:01<00:02,  1.23it/s, acc=93.8, loss=0.159]

torch.Size([32, 64, 21, 28])


Epoch 8/15:  60%|██████    | 3/5 [00:02<00:01,  1.29it/s, acc=91.7, loss=0.371]

torch.Size([32, 64, 21, 28])


Epoch 8/15: 100%|██████████| 5/5 [00:03<00:00,  1.49it/s, acc=90.9, loss=0.114]

torch.Size([4, 64, 21, 28])
Epoch [8/15], Loss: 0.2832, Accuracy: 90.91%





Model saved!


Epoch 9/15:   0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([32, 64, 21, 28])


Epoch 9/15:  20%|██        | 1/5 [00:00<00:03,  1.24it/s, acc=96.9, loss=0.148]

torch.Size([32, 64, 21, 28])


Epoch 9/15:  40%|████      | 2/5 [00:01<00:02,  1.29it/s, acc=92.2, loss=0.322]

torch.Size([32, 64, 21, 28])


Epoch 9/15:  60%|██████    | 3/5 [00:02<00:01,  1.26it/s, acc=89.6, loss=0.387]

torch.Size([32, 64, 21, 28])


Epoch 9/15: 100%|██████████| 5/5 [00:03<00:00,  1.39it/s, acc=87.1, loss=0.296]


torch.Size([4, 64, 21, 28])
Epoch [9/15], Loss: 0.3074, Accuracy: 87.12%


Epoch 10/15:   0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([32, 64, 21, 28])


Epoch 10/15:  20%|██        | 1/5 [00:00<00:03,  1.33it/s, acc=90.6, loss=0.189]

torch.Size([32, 64, 21, 28])


Epoch 10/15:  40%|████      | 2/5 [00:01<00:02,  1.21it/s, acc=89.1, loss=0.43] 

torch.Size([32, 64, 21, 28])


Epoch 10/15:  60%|██████    | 3/5 [00:02<00:01,  1.19it/s, acc=91.7, loss=0.114]

torch.Size([32, 64, 21, 28])


Epoch 10/15: 100%|██████████| 5/5 [00:03<00:00,  1.43it/s, acc=90.2, loss=0.5]  


torch.Size([4, 64, 21, 28])
Epoch [10/15], Loss: 0.2956, Accuracy: 90.15%


Epoch 11/15:   0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([32, 64, 21, 28])


Epoch 11/15:  20%|██        | 1/5 [00:00<00:03,  1.07it/s, acc=90.6, loss=0.215]

torch.Size([32, 64, 21, 28])


Epoch 11/15:  40%|████      | 2/5 [00:01<00:02,  1.10it/s, acc=95.3, loss=0.0912]

torch.Size([32, 64, 21, 28])


Epoch 11/15:  60%|██████    | 3/5 [00:02<00:01,  1.18it/s, acc=95.8, loss=0.187] 

torch.Size([32, 64, 21, 28])


Epoch 11/15: 100%|██████████| 5/5 [00:03<00:00,  1.34it/s, acc=96.2, loss=0.135]

torch.Size([4, 64, 21, 28])
Epoch [11/15], Loss: 0.1561, Accuracy: 96.21%





Model saved!


Epoch 12/15:   0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([32, 64, 21, 28])


Epoch 12/15:  20%|██        | 1/5 [00:00<00:03,  1.24it/s, acc=96.9, loss=0.244]

torch.Size([32, 64, 21, 28])


Epoch 12/15:  40%|████      | 2/5 [00:01<00:02,  1.13it/s, acc=95.3, loss=0.161]

torch.Size([32, 64, 21, 28])


Epoch 12/15:  60%|██████    | 3/5 [00:02<00:01,  1.14it/s, acc=94.8, loss=0.148]

torch.Size([32, 64, 21, 28])


Epoch 12/15: 100%|██████████| 5/5 [00:03<00:00,  1.39it/s, acc=95.5, loss=0.081]


torch.Size([4, 64, 21, 28])
Epoch [12/15], Loss: 0.1522, Accuracy: 95.45%


Epoch 13/15:   0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([32, 64, 21, 28])


Epoch 13/15:  20%|██        | 1/5 [00:00<00:03,  1.28it/s, acc=93.8, loss=0.0935]

torch.Size([32, 64, 21, 28])


Epoch 13/15:  40%|████      | 2/5 [00:01<00:02,  1.31it/s, acc=95.3, loss=0.0671]

torch.Size([32, 64, 21, 28])


Epoch 13/15:  60%|██████    | 3/5 [00:02<00:01,  1.23it/s, acc=93.8, loss=0.132] 

torch.Size([32, 64, 21, 28])


Epoch 13/15: 100%|██████████| 5/5 [00:03<00:00,  1.45it/s, acc=93.9, loss=0.0168]


torch.Size([4, 64, 21, 28])
Epoch [13/15], Loss: 0.1092, Accuracy: 93.94%


Epoch 14/15:   0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([32, 64, 21, 28])


Epoch 14/15:  20%|██        | 1/5 [00:00<00:03,  1.15it/s, acc=90.6, loss=0.175]

torch.Size([32, 64, 21, 28])


Epoch 14/15:  40%|████      | 2/5 [00:01<00:02,  1.21it/s, acc=93.8, loss=0.12] 

torch.Size([32, 64, 21, 28])


Epoch 14/15:  60%|██████    | 3/5 [00:02<00:01,  1.24it/s, acc=95.8, loss=0.0134]

torch.Size([32, 64, 21, 28])


Epoch 14/15: 100%|██████████| 5/5 [00:03<00:00,  1.42it/s, acc=96.2, loss=0.00673]


torch.Size([4, 64, 21, 28])
Epoch [14/15], Loss: 0.0794, Accuracy: 96.21%


Epoch 15/15:   0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([32, 64, 21, 28])


Epoch 15/15:  20%|██        | 1/5 [00:00<00:03,  1.28it/s, acc=96.9, loss=0.0587]

torch.Size([32, 64, 21, 28])


Epoch 15/15:  40%|████      | 2/5 [00:01<00:02,  1.22it/s, acc=96.9, loss=0.0488]

torch.Size([32, 64, 21, 28])


Epoch 15/15:  60%|██████    | 3/5 [00:02<00:01,  1.18it/s, acc=96.9, loss=0.126] 

torch.Size([32, 64, 21, 28])


Epoch 15/15: 100%|██████████| 5/5 [00:03<00:00,  1.39it/s, acc=97, loss=0.00119] 

torch.Size([4, 64, 21, 28])
Epoch [15/15], Loss: 0.0629, Accuracy: 96.97%





Model saved!
torch.Size([32, 64, 21, 28])
torch.Size([1, 64, 21, 28])

📊 Test Loss: 0.3793, Test Accuracy: 84.85%


# Let's try again but augumenting data 

In [74]:
import os
import cv2
import numpy as np
import random
from pathlib import Path

def random_rotate(image: np.ndarray,
                  min_angle: float,
                  max_angle: float) -> np.ndarray:
    """
    Randomly rotates the grayscale image by an angle between [min_angle, max_angle].
    Fills blank regions with white (255).
    """
    (height, width) = image.shape[:2] if image.ndim == 3 else image.shape
    angle = random.uniform(min_angle, max_angle)
    center = (width // 2, height // 2)
    
    M = cv2.getRotationMatrix2D(center, angle, 1.0)
    rotated_image = cv2.warpAffine(
        image, M, (width, height),
        flags=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=255  # White for a single-channel (grayscale) image
    )
    return rotated_image

def random_perspective(image: np.ndarray,
                       min_shift: float,
                       max_shift: float) -> np.ndarray:
    """
    Applies a random perspective transform in grayscale. Each corner can shift by
    a value in [min_shift, max_shift] in both x and y directions.
    Fills blank regions with white (255).
    """
    (height, width) = image.shape[:2] if image.ndim == 3 else image.shape

    src_pts = np.float32([
        [0,      0],
        [width,  0],
        [width,  height],
        [0,      height]
    ])
    
    dst_pts = []
    for (x, y) in src_pts:
        shift_x = random.uniform(min_shift, max_shift)
        shift_y = random.uniform(min_shift, max_shift)
        dst_pts.append([x + shift_x, y + shift_y])
    
    dst_pts = np.float32(dst_pts)
    M = cv2.getPerspectiveTransform(src_pts, dst_pts)
    
    warped_image = cv2.warpPerspective(
        image, M, (width, height),
        flags=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=255  # White for a single-channel (grayscale) image
    )
    return warped_image

def generate_augmented_images(image: np.ndarray,
                              min_angle: float,
                              max_angle: float,
                              min_perspective_shift: float,
                              max_perspective_shift: float,
                              num_images: int = 5) -> list:
    """
    Given a grayscale image (np.ndarray), returns a list of 'num_images' augmented images.
    Each augmented image is generated by randomly applying BOTH:
      - rotation (angle in [min_angle, max_angle])
      - perspective shift in [min_perspective_shift, max_perspective_shift]
    """
    augmented_images = []
    for _ in range(num_images):
        # Random rotation
        rotated = random_rotate(image, min_angle, max_angle)
        # Random perspective transform
        warped = random_perspective(rotated, min_perspective_shift, max_perspective_shift)
        augmented_images.append(warped)
    return augmented_images

def augment_dataset(root_folder: str,
                    min_angle=-30, max_angle=30,
                    min_shift=-20, max_shift=20,
                    num_augmented=10):
    """
    Walks through root_folder (recursively), and for each image found,
    generates 'num_augmented' new images, saved in the same folder (grayscale).
    Naming convention:
        originalname_augment_1.ext, originalname_augment_2.ext, ...
    Supported image extensions can be changed as needed.
    """
    
    # Define which file extensions to consider images
    valid_extensions = {".png", ".jpg", ".jpeg", ".bmp", ".tiff"}
    
    for dirpath, dirnames, filenames in os.walk(root_folder):
        for filename in filenames:
            ext = os.path.splitext(filename)[1].lower()
            if ext in valid_extensions:
                # Construct the full path
                image_path = os.path.join(dirpath, filename)
                
                # Read the image in grayscale
                image_gray = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
                if image_gray is None:
                    print(f"Warning: Could not read {image_path}")
                    continue
                
                # Generate augmented images (grayscale in, grayscale out)
                augmented_images = generate_augmented_images(
                    image=image_gray,
                    min_angle=min_angle,
                    max_angle=max_angle,
                    min_perspective_shift=min_shift,
                    max_perspective_shift=max_shift,
                    num_images=num_augmented
                )
                
                # Save each augmented image
                original_name, ext_ = os.path.splitext(filename)
                for i, aug_img_gray in enumerate(augmented_images, start=1):
                    # Example naming: originalName_augment_1.jpg
                    new_filename = f"{original_name}_augment_{i}{ext_}"
                    new_filepath = os.path.join(dirpath, new_filename)
                    
                    # Save as grayscale
                    cv2.imwrite(new_filepath, aug_img_gray)
                    
                print(f"Augmented {filename} -> created {num_augmented} new images.")


In [76]:
folder_to_augment = r"document_classification\document_classification\email"
augment_dataset(root_folder=folder_to_augment,
            min_angle=-30, max_angle=30,
            min_shift=-20, max_shift=20,
            num_augmented=2)

folder_to_augment = r"document_classification\document_classification\resume"
augment_dataset(root_folder=folder_to_augment,
            min_angle=-30, max_angle=30,
            min_shift=-20, max_shift=20,
            num_augmented=2)

folder_to_augment = r"document_classification\document_classification\scientific_publication"
augment_dataset(root_folder=folder_to_augment,
            min_angle=-30, max_angle=30,
            min_shift=-20, max_shift=20,
            num_augmented=2)

Augmented doc_000042.png -> created 2 new images.
Augmented doc_000046.png -> created 2 new images.
Augmented doc_000076.png -> created 2 new images.
Augmented doc_000079.png -> created 2 new images.
Augmented doc_000111.png -> created 2 new images.
Augmented doc_000115.png -> created 2 new images.
Augmented doc_000133.png -> created 2 new images.
Augmented doc_000142.png -> created 2 new images.
Augmented doc_000148.png -> created 2 new images.
Augmented doc_000165.png -> created 2 new images.
Augmented doc_000195.png -> created 2 new images.
Augmented doc_000196.png -> created 2 new images.
Augmented doc_000238.png -> created 2 new images.
Augmented doc_000255.png -> created 2 new images.
Augmented doc_000260.png -> created 2 new images.
Augmented doc_000275.png -> created 2 new images.
Augmented doc_000278.png -> created 2 new images.
Augmented doc_000279.png -> created 2 new images.
Augmented doc_000282.png -> created 2 new images.
Augmented doc_000297.png -> created 2 new images.


In [78]:
dataset_path = r'document_classification\document_classification'

full_dataset = CustomImageDataset(root_dir=dataset_path, transform=transform)

# Split into train (80%) and test (20%)
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)

# Check dataset structure
print(f"Classes: {full_dataset.class_to_idx}")
print(f"Total images: {len(full_dataset)}, Train: {len(train_dataset)}, Test: {len(test_dataset)}")

Classes: {'email': 0, 'resume': 1, 'scientific_publication': 2}
Total images: 495, Train: 396, Test: 99


In [79]:
# Train and Evaluate
model.train_model(train_loader, num_epochs=15, lr=0.001)
model.evaluate_model(test_loader)

Epoch 1/15:   0%|          | 0/13 [00:00<?, ?it/s]

torch.Size([32, 64, 21, 28])


Epoch 1/15:   8%|▊         | 1/13 [00:01<00:12,  1.06s/it, acc=81.2, loss=1.11]

torch.Size([32, 64, 21, 28])


Epoch 1/15:  15%|█▌        | 2/13 [00:02<00:11,  1.02s/it, acc=81.2, loss=1.31]

torch.Size([32, 64, 21, 28])


Epoch 1/15:  23%|██▎       | 3/13 [00:02<00:09,  1.02it/s, acc=81.2, loss=0.781]

torch.Size([32, 64, 21, 28])


Epoch 1/15:  31%|███       | 4/13 [00:03<00:08,  1.01it/s, acc=83.6, loss=0.313]

torch.Size([32, 64, 21, 28])


Epoch 1/15:  38%|███▊      | 5/13 [00:04<00:07,  1.02it/s, acc=82.5, loss=0.578]

torch.Size([32, 64, 21, 28])


Epoch 1/15:  46%|████▌     | 6/13 [00:05<00:06,  1.03it/s, acc=83.3, loss=0.346]

torch.Size([32, 64, 21, 28])


Epoch 1/15:  54%|█████▍    | 7/13 [00:06<00:05,  1.03it/s, acc=83, loss=0.581]  

torch.Size([32, 64, 21, 28])


Epoch 1/15:  62%|██████▏   | 8/13 [00:07<00:05,  1.01s/it, acc=84, loss=0.723]

torch.Size([32, 64, 21, 28])


Epoch 1/15:  69%|██████▉   | 9/13 [00:08<00:03,  1.00it/s, acc=84.4, loss=0.536]

torch.Size([32, 64, 21, 28])


Epoch 1/15:  77%|███████▋  | 10/13 [00:09<00:03,  1.01s/it, acc=84.4, loss=0.372]

torch.Size([32, 64, 21, 28])


Epoch 1/15:  85%|████████▍ | 11/13 [00:11<00:02,  1.03s/it, acc=84.1, loss=0.407]

torch.Size([32, 64, 21, 28])


Epoch 1/15:  92%|█████████▏| 12/13 [00:12<00:01,  1.05s/it, acc=84.9, loss=0.3]  

torch.Size([12, 64, 21, 28])


Epoch 1/15: 100%|██████████| 13/13 [00:12<00:00,  1.03it/s, acc=85.1, loss=0.388]


Epoch [1/15], Loss: 0.5957, Accuracy: 85.10%
Model saved!


Epoch 2/15:   0%|          | 0/13 [00:00<?, ?it/s]

torch.Size([32, 64, 21, 28])


Epoch 2/15:   8%|▊         | 1/13 [00:00<00:10,  1.18it/s, acc=90.6, loss=0.47]

torch.Size([32, 64, 21, 28])


Epoch 2/15:  15%|█▌        | 2/13 [00:01<00:09,  1.12it/s, acc=90.6, loss=0.363]

torch.Size([32, 64, 21, 28])


Epoch 2/15:  23%|██▎       | 3/13 [00:02<00:08,  1.19it/s, acc=87.5, loss=0.467]

torch.Size([32, 64, 21, 28])


Epoch 2/15:  31%|███       | 4/13 [00:03<00:07,  1.22it/s, acc=85.9, loss=0.469]

torch.Size([32, 64, 21, 28])


Epoch 2/15:  38%|███▊      | 5/13 [00:04<00:06,  1.22it/s, acc=85.6, loss=0.436]

torch.Size([32, 64, 21, 28])


Epoch 2/15:  46%|████▌     | 6/13 [00:04<00:05,  1.22it/s, acc=85.9, loss=0.354]

torch.Size([32, 64, 21, 28])


Epoch 2/15:  54%|█████▍    | 7/13 [00:05<00:04,  1.24it/s, acc=84.8, loss=0.491]

torch.Size([32, 64, 21, 28])


Epoch 2/15:  62%|██████▏   | 8/13 [00:06<00:04,  1.24it/s, acc=86.3, loss=0.231]

torch.Size([32, 64, 21, 28])


Epoch 2/15:  69%|██████▉   | 9/13 [00:07<00:03,  1.22it/s, acc=85.8, loss=0.341]

torch.Size([32, 64, 21, 28])


Epoch 2/15:  77%|███████▋  | 10/13 [00:08<00:02,  1.23it/s, acc=86.2, loss=0.218]

torch.Size([32, 64, 21, 28])


Epoch 2/15:  85%|████████▍ | 11/13 [00:09<00:01,  1.24it/s, acc=86.6, loss=0.321]

torch.Size([32, 64, 21, 28])


Epoch 2/15:  92%|█████████▏| 12/13 [00:09<00:00,  1.25it/s, acc=87.5, loss=0.178]

torch.Size([12, 64, 21, 28])


Epoch 2/15: 100%|██████████| 13/13 [00:10<00:00,  1.27it/s, acc=87.4, loss=0.292]


Epoch [2/15], Loss: 0.3562, Accuracy: 87.37%
Model saved!


Epoch 3/15:   0%|          | 0/13 [00:00<?, ?it/s]

torch.Size([32, 64, 21, 28])


Epoch 3/15:   8%|▊         | 1/13 [00:00<00:09,  1.27it/s, acc=93.8, loss=0.154]

torch.Size([32, 64, 21, 28])


Epoch 3/15:  15%|█▌        | 2/13 [00:01<00:09,  1.12it/s, acc=92.2, loss=0.281]

torch.Size([32, 64, 21, 28])


Epoch 3/15:  23%|██▎       | 3/13 [00:02<00:08,  1.15it/s, acc=93.8, loss=0.121]

torch.Size([32, 64, 21, 28])


Epoch 3/15:  31%|███       | 4/13 [00:03<00:07,  1.19it/s, acc=90.6, loss=0.356]

torch.Size([32, 64, 21, 28])


Epoch 3/15:  38%|███▊      | 5/13 [00:04<00:06,  1.21it/s, acc=90, loss=0.22]   

torch.Size([32, 64, 21, 28])


Epoch 3/15:  46%|████▌     | 6/13 [00:05<00:05,  1.18it/s, acc=90.6, loss=0.179]

torch.Size([32, 64, 21, 28])


Epoch 3/15:  54%|█████▍    | 7/13 [00:05<00:05,  1.20it/s, acc=91.5, loss=0.141]

torch.Size([32, 64, 21, 28])


Epoch 3/15:  62%|██████▏   | 8/13 [00:06<00:04,  1.19it/s, acc=90.2, loss=0.545]

torch.Size([32, 64, 21, 28])


Epoch 3/15:  69%|██████▉   | 9/13 [00:07<00:03,  1.18it/s, acc=90.6, loss=0.188]

torch.Size([32, 64, 21, 28])


Epoch 3/15:  77%|███████▋  | 10/13 [00:08<00:02,  1.18it/s, acc=90.3, loss=0.221]

torch.Size([32, 64, 21, 28])


Epoch 3/15:  85%|████████▍ | 11/13 [00:09<00:01,  1.15it/s, acc=90.1, loss=0.316]

torch.Size([32, 64, 21, 28])


Epoch 3/15:  92%|█████████▏| 12/13 [00:10<00:00,  1.19it/s, acc=90.6, loss=0.0839]

torch.Size([12, 64, 21, 28])


Epoch 3/15: 100%|██████████| 13/13 [00:10<00:00,  1.23it/s, acc=90.7, loss=0.213] 


Epoch [3/15], Loss: 0.2322, Accuracy: 90.66%
Model saved!


Epoch 4/15:   0%|          | 0/13 [00:00<?, ?it/s]

torch.Size([32, 64, 21, 28])


Epoch 4/15:   8%|▊         | 1/13 [00:00<00:09,  1.21it/s, acc=96.9, loss=0.181]

torch.Size([32, 64, 21, 28])


Epoch 4/15:  15%|█▌        | 2/13 [00:01<00:09,  1.18it/s, acc=96.9, loss=0.126]

torch.Size([32, 64, 21, 28])


Epoch 4/15:  23%|██▎       | 3/13 [00:02<00:08,  1.17it/s, acc=97.9, loss=0.0681]

torch.Size([32, 64, 21, 28])


Epoch 4/15:  31%|███       | 4/13 [00:03<00:07,  1.14it/s, acc=96.9, loss=0.108] 

torch.Size([32, 64, 21, 28])


Epoch 4/15:  38%|███▊      | 5/13 [00:04<00:06,  1.19it/s, acc=96.9, loss=0.152]

torch.Size([32, 64, 21, 28])


Epoch 4/15:  46%|████▌     | 6/13 [00:05<00:05,  1.20it/s, acc=95.3, loss=0.289]

torch.Size([32, 64, 21, 28])


Epoch 4/15:  54%|█████▍    | 7/13 [00:05<00:05,  1.20it/s, acc=95.1, loss=0.198]

torch.Size([32, 64, 21, 28])


Epoch 4/15:  62%|██████▏   | 8/13 [00:06<00:04,  1.21it/s, acc=95.3, loss=0.106]

torch.Size([32, 64, 21, 28])


Epoch 4/15:  69%|██████▉   | 9/13 [00:07<00:03,  1.22it/s, acc=93.8, loss=0.279]

torch.Size([32, 64, 21, 28])


Epoch 4/15:  77%|███████▋  | 10/13 [00:08<00:02,  1.22it/s, acc=94.1, loss=0.133]

torch.Size([32, 64, 21, 28])


Epoch 4/15:  85%|████████▍ | 11/13 [00:09<00:01,  1.23it/s, acc=94, loss=0.155]  

torch.Size([32, 64, 21, 28])


Epoch 4/15:  92%|█████████▏| 12/13 [00:10<00:00,  1.18it/s, acc=94.3, loss=0.131]

torch.Size([12, 64, 21, 28])


Epoch 4/15: 100%|██████████| 13/13 [00:10<00:00,  1.24it/s, acc=94.2, loss=0.2]  


Epoch [4/15], Loss: 0.1635, Accuracy: 94.19%
Model saved!


Epoch 5/15:   0%|          | 0/13 [00:00<?, ?it/s]

torch.Size([32, 64, 21, 28])


Epoch 5/15:   8%|▊         | 1/13 [00:00<00:09,  1.24it/s, acc=93.8, loss=0.226]

torch.Size([32, 64, 21, 28])


Epoch 5/15:  15%|█▌        | 2/13 [00:01<00:09,  1.14it/s, acc=95.3, loss=0.114]

torch.Size([32, 64, 21, 28])


Epoch 5/15:  23%|██▎       | 3/13 [00:02<00:08,  1.17it/s, acc=95.8, loss=0.107]

torch.Size([32, 64, 21, 28])


Epoch 5/15:  31%|███       | 4/13 [00:03<00:07,  1.19it/s, acc=95.3, loss=0.163]

torch.Size([32, 64, 21, 28])


Epoch 5/15:  38%|███▊      | 5/13 [00:04<00:06,  1.16it/s, acc=95.6, loss=0.101]

torch.Size([32, 64, 21, 28])


Epoch 5/15:  46%|████▌     | 6/13 [00:05<00:06,  1.15it/s, acc=95.3, loss=0.172]

torch.Size([32, 64, 21, 28])


Epoch 5/15:  54%|█████▍    | 7/13 [00:05<00:05,  1.18it/s, acc=94.6, loss=0.189]

torch.Size([32, 64, 21, 28])


Epoch 5/15:  62%|██████▏   | 8/13 [00:06<00:04,  1.18it/s, acc=94.5, loss=0.0895]

torch.Size([32, 64, 21, 28])


Epoch 5/15:  69%|██████▉   | 9/13 [00:07<00:03,  1.21it/s, acc=94.8, loss=0.102] 

torch.Size([32, 64, 21, 28])


Epoch 5/15:  77%|███████▋  | 10/13 [00:08<00:02,  1.21it/s, acc=95.3, loss=0.0669]

torch.Size([32, 64, 21, 28])


Epoch 5/15:  85%|████████▍ | 11/13 [00:09<00:01,  1.21it/s, acc=95.5, loss=0.0787]

torch.Size([32, 64, 21, 28])


Epoch 5/15:  92%|█████████▏| 12/13 [00:10<00:00,  1.21it/s, acc=95.8, loss=0.0259]

torch.Size([12, 64, 21, 28])


Epoch 5/15: 100%|██████████| 13/13 [00:10<00:00,  1.23it/s, acc=96, loss=0.0762]  


Epoch [5/15], Loss: 0.1162, Accuracy: 95.96%
Model saved!


Epoch 6/15:   0%|          | 0/13 [00:00<?, ?it/s]

torch.Size([32, 64, 21, 28])


Epoch 6/15:   8%|▊         | 1/13 [00:00<00:09,  1.22it/s, acc=100, loss=0.0232]

torch.Size([32, 64, 21, 28])


Epoch 6/15:  15%|█▌        | 2/13 [00:01<00:09,  1.16it/s, acc=98.4, loss=0.0926]

torch.Size([32, 64, 21, 28])


Epoch 6/15:  23%|██▎       | 3/13 [00:02<00:08,  1.17it/s, acc=94.8, loss=0.229] 

torch.Size([32, 64, 21, 28])


Epoch 6/15:  31%|███       | 4/13 [00:03<00:07,  1.15it/s, acc=96.1, loss=0.0422]

torch.Size([32, 64, 21, 28])


Epoch 6/15:  38%|███▊      | 5/13 [00:04<00:06,  1.16it/s, acc=96.2, loss=0.0466]

torch.Size([32, 64, 21, 28])


Epoch 6/15:  46%|████▌     | 6/13 [00:05<00:06,  1.16it/s, acc=96.4, loss=0.052] 

torch.Size([32, 64, 21, 28])


Epoch 6/15:  54%|█████▍    | 7/13 [00:05<00:05,  1.18it/s, acc=96.9, loss=0.0409]

torch.Size([32, 64, 21, 28])


Epoch 6/15:  62%|██████▏   | 8/13 [00:06<00:04,  1.18it/s, acc=97.3, loss=0.0362]

torch.Size([32, 64, 21, 28])


Epoch 6/15:  69%|██████▉   | 9/13 [00:07<00:03,  1.14it/s, acc=97.6, loss=0.0351]

torch.Size([32, 64, 21, 28])


Epoch 6/15:  77%|███████▋  | 10/13 [00:08<00:02,  1.17it/s, acc=96.9, loss=0.171]

torch.Size([32, 64, 21, 28])


Epoch 6/15:  85%|████████▍ | 11/13 [00:09<00:01,  1.17it/s, acc=97.2, loss=0.0141]

torch.Size([32, 64, 21, 28])


Epoch 6/15:  92%|█████████▏| 12/13 [00:10<00:00,  1.18it/s, acc=97.4, loss=0.0184]

torch.Size([12, 64, 21, 28])


Epoch 6/15: 100%|██████████| 13/13 [00:10<00:00,  1.22it/s, acc=97.2, loss=0.129] 


Epoch [6/15], Loss: 0.0715, Accuracy: 97.22%
Model saved!


Epoch 7/15:   0%|          | 0/13 [00:00<?, ?it/s]

torch.Size([32, 64, 21, 28])


Epoch 7/15:   8%|▊         | 1/13 [00:00<00:10,  1.19it/s, acc=100, loss=0.0467]

torch.Size([32, 64, 21, 28])


Epoch 7/15:  15%|█▌        | 2/13 [00:01<00:09,  1.16it/s, acc=98.4, loss=0.0474]

torch.Size([32, 64, 21, 28])


Epoch 7/15:  23%|██▎       | 3/13 [00:02<00:08,  1.18it/s, acc=99, loss=0.0312]  

torch.Size([32, 64, 21, 28])


Epoch 7/15:  31%|███       | 4/13 [00:03<00:07,  1.16it/s, acc=99.2, loss=0.0158]

torch.Size([32, 64, 21, 28])


Epoch 7/15:  38%|███▊      | 5/13 [00:04<00:06,  1.18it/s, acc=98.1, loss=0.107] 

torch.Size([32, 64, 21, 28])


Epoch 7/15:  46%|████▌     | 6/13 [00:05<00:06,  1.14it/s, acc=97.4, loss=0.072]

torch.Size([32, 64, 21, 28])


Epoch 7/15:  54%|█████▍    | 7/13 [00:05<00:05,  1.17it/s, acc=97.8, loss=0.0292]

torch.Size([32, 64, 21, 28])


Epoch 7/15:  62%|██████▏   | 8/13 [00:06<00:04,  1.18it/s, acc=97.7, loss=0.14]  

torch.Size([32, 64, 21, 28])


Epoch 7/15:  69%|██████▉   | 9/13 [00:07<00:03,  1.18it/s, acc=97.9, loss=0.0595]

torch.Size([32, 64, 21, 28])


Epoch 7/15:  77%|███████▋  | 10/13 [00:08<00:02,  1.19it/s, acc=98.1, loss=0.0271]

torch.Size([32, 64, 21, 28])


Epoch 7/15:  85%|████████▍ | 11/13 [00:09<00:01,  1.20it/s, acc=98.3, loss=0.00479]

torch.Size([32, 64, 21, 28])


Epoch 7/15:  92%|█████████▏| 12/13 [00:10<00:00,  1.20it/s, acc=98.4, loss=0.024]  

torch.Size([12, 64, 21, 28])


Epoch 7/15: 100%|██████████| 13/13 [00:10<00:00,  1.23it/s, acc=98.5, loss=0.00439]


Epoch [7/15], Loss: 0.0469, Accuracy: 98.48%
Model saved!


Epoch 8/15:   0%|          | 0/13 [00:00<?, ?it/s]

torch.Size([32, 64, 21, 28])


Epoch 8/15:   8%|▊         | 1/13 [00:00<00:10,  1.10it/s, acc=100, loss=0.0242]

torch.Size([32, 64, 21, 28])


Epoch 8/15:  15%|█▌        | 2/13 [00:01<00:09,  1.10it/s, acc=100, loss=0.00867]

torch.Size([32, 64, 21, 28])


Epoch 8/15:  23%|██▎       | 3/13 [00:02<00:08,  1.15it/s, acc=100, loss=0.025]  

torch.Size([32, 64, 21, 28])


Epoch 8/15:  31%|███       | 4/13 [00:03<00:07,  1.16it/s, acc=100, loss=0.0114]

torch.Size([32, 64, 21, 28])


Epoch 8/15:  38%|███▊      | 5/13 [00:04<00:06,  1.18it/s, acc=99.4, loss=0.0282]

torch.Size([32, 64, 21, 28])


Epoch 8/15:  46%|████▌     | 6/13 [00:05<00:05,  1.19it/s, acc=99.5, loss=0.0135]

torch.Size([32, 64, 21, 28])


Epoch 8/15:  54%|█████▍    | 7/13 [00:06<00:05,  1.14it/s, acc=99.6, loss=0.00811]

torch.Size([32, 64, 21, 28])


Epoch 8/15:  62%|██████▏   | 8/13 [00:06<00:04,  1.17it/s, acc=99.2, loss=0.0397] 

torch.Size([32, 64, 21, 28])


Epoch 8/15:  69%|██████▉   | 9/13 [00:07<00:03,  1.17it/s, acc=99.3, loss=0.00873]

torch.Size([32, 64, 21, 28])


Epoch 8/15:  77%|███████▋  | 10/13 [00:08<00:02,  1.17it/s, acc=99.1, loss=0.0557]

torch.Size([32, 64, 21, 28])


Epoch 8/15:  85%|████████▍ | 11/13 [00:09<00:01,  1.19it/s, acc=99.1, loss=0.0135]

torch.Size([32, 64, 21, 28])


Epoch 8/15:  92%|█████████▏| 12/13 [00:10<00:00,  1.14it/s, acc=99.2, loss=0.0118]

torch.Size([12, 64, 21, 28])


Epoch 8/15: 100%|██████████| 13/13 [00:10<00:00,  1.21it/s, acc=99.2, loss=0.00296]


Epoch [8/15], Loss: 0.0193, Accuracy: 99.24%
Model saved!


Epoch 9/15:   0%|          | 0/13 [00:00<?, ?it/s]

torch.Size([32, 64, 21, 28])


Epoch 9/15:   8%|▊         | 1/13 [00:00<00:09,  1.24it/s, acc=100, loss=0.019]

torch.Size([32, 64, 21, 28])


Epoch 9/15:  15%|█▌        | 2/13 [00:01<00:09,  1.16it/s, acc=100, loss=0.0148]

torch.Size([32, 64, 21, 28])


Epoch 9/15:  23%|██▎       | 3/13 [00:02<00:08,  1.16it/s, acc=100, loss=0.00831]

torch.Size([32, 64, 21, 28])


Epoch 9/15:  31%|███       | 4/13 [00:03<00:07,  1.17it/s, acc=100, loss=0.00767]

torch.Size([32, 64, 21, 28])


Epoch 9/15:  38%|███▊      | 5/13 [00:04<00:06,  1.16it/s, acc=100, loss=0.00786]

torch.Size([32, 64, 21, 28])


Epoch 9/15:  46%|████▌     | 6/13 [00:05<00:05,  1.17it/s, acc=100, loss=0.00887]

torch.Size([32, 64, 21, 28])


Epoch 9/15:  54%|█████▍    | 7/13 [00:05<00:05,  1.19it/s, acc=100, loss=0.00338]

torch.Size([32, 64, 21, 28])


Epoch 9/15:  62%|██████▏   | 8/13 [00:06<00:04,  1.19it/s, acc=100, loss=0.0125] 

torch.Size([32, 64, 21, 28])


Epoch 9/15:  69%|██████▉   | 9/13 [00:07<00:03,  1.20it/s, acc=100, loss=0.00329]

torch.Size([32, 64, 21, 28])


Epoch 9/15:  77%|███████▋  | 10/13 [00:08<00:02,  1.20it/s, acc=100, loss=0.00912]

torch.Size([32, 64, 21, 28])


Epoch 9/15:  85%|████████▍ | 11/13 [00:09<00:01,  1.16it/s, acc=100, loss=0.000404]

torch.Size([32, 64, 21, 28])


Epoch 9/15:  92%|█████████▏| 12/13 [00:10<00:00,  1.18it/s, acc=100, loss=0.019]   

torch.Size([12, 64, 21, 28])


Epoch 9/15: 100%|██████████| 13/13 [00:10<00:00,  1.23it/s, acc=100, loss=0.000468]


Epoch [9/15], Loss: 0.0088, Accuracy: 100.00%
Model saved!


Epoch 10/15:   0%|          | 0/13 [00:00<?, ?it/s]

torch.Size([32, 64, 21, 28])


Epoch 10/15:   8%|▊         | 1/13 [00:00<00:09,  1.21it/s, acc=100, loss=0.00142]

torch.Size([32, 64, 21, 28])


Epoch 10/15:  15%|█▌        | 2/13 [00:01<00:09,  1.14it/s, acc=100, loss=0.0173] 

torch.Size([32, 64, 21, 28])


Epoch 10/15:  23%|██▎       | 3/13 [00:02<00:09,  1.09it/s, acc=100, loss=0.00578]

torch.Size([32, 64, 21, 28])


Epoch 10/15:  31%|███       | 4/13 [00:03<00:07,  1.14it/s, acc=100, loss=0.000689]

torch.Size([32, 64, 21, 28])


Epoch 10/15:  38%|███▊      | 5/13 [00:04<00:06,  1.16it/s, acc=100, loss=0.00234] 

torch.Size([32, 64, 21, 28])


Epoch 10/15:  46%|████▌     | 6/13 [00:05<00:05,  1.17it/s, acc=100, loss=0.00121]

torch.Size([32, 64, 21, 28])


Epoch 10/15:  54%|█████▍    | 7/13 [00:05<00:05,  1.19it/s, acc=100, loss=0.00462]

torch.Size([32, 64, 21, 28])


Epoch 10/15:  62%|██████▏   | 8/13 [00:06<00:04,  1.14it/s, acc=100, loss=0.00104]

torch.Size([32, 64, 21, 28])


Epoch 10/15:  69%|██████▉   | 9/13 [00:07<00:03,  1.17it/s, acc=100, loss=0.00312]

torch.Size([32, 64, 21, 28])


Epoch 10/15:  77%|███████▋  | 10/13 [00:08<00:02,  1.19it/s, acc=100, loss=0.00331]

torch.Size([32, 64, 21, 28])


Epoch 10/15:  85%|████████▍ | 11/13 [00:09<00:01,  1.17it/s, acc=100, loss=0.00236]

torch.Size([32, 64, 21, 28])


Epoch 10/15:  92%|█████████▏| 12/13 [00:10<00:00,  1.18it/s, acc=100, loss=0.0147] 

torch.Size([12, 64, 21, 28])


Epoch 10/15: 100%|██████████| 13/13 [00:10<00:00,  1.21it/s, acc=100, loss=0.00614]


Epoch [10/15], Loss: 0.0049, Accuracy: 100.00%


Epoch 11/15:   0%|          | 0/13 [00:00<?, ?it/s]

torch.Size([32, 64, 21, 28])


Epoch 11/15:   8%|▊         | 1/13 [00:00<00:10,  1.17it/s, acc=100, loss=0.00233]

torch.Size([32, 64, 21, 28])


Epoch 11/15:  15%|█▌        | 2/13 [00:01<00:09,  1.14it/s, acc=100, loss=0.00751]

torch.Size([32, 64, 21, 28])


Epoch 11/15:  23%|██▎       | 3/13 [00:02<00:08,  1.18it/s, acc=100, loss=0.00531]

torch.Size([32, 64, 21, 28])


Epoch 11/15:  31%|███       | 4/13 [00:03<00:07,  1.19it/s, acc=100, loss=0.000453]

torch.Size([32, 64, 21, 28])


Epoch 11/15:  38%|███▊      | 5/13 [00:04<00:06,  1.20it/s, acc=100, loss=0.00362] 

torch.Size([32, 64, 21, 28])


Epoch 11/15:  46%|████▌     | 6/13 [00:05<00:05,  1.21it/s, acc=100, loss=0.00209]

torch.Size([32, 64, 21, 28])


Epoch 11/15:  54%|█████▍    | 7/13 [00:05<00:04,  1.21it/s, acc=100, loss=0.00779]

torch.Size([32, 64, 21, 28])


Epoch 11/15:  62%|██████▏   | 8/13 [00:06<00:04,  1.21it/s, acc=100, loss=0.00253]

torch.Size([32, 64, 21, 28])


Epoch 11/15:  69%|██████▉   | 9/13 [00:07<00:03,  1.19it/s, acc=100, loss=0.00147]

torch.Size([32, 64, 21, 28])


Epoch 11/15:  77%|███████▋  | 10/13 [00:08<00:02,  1.17it/s, acc=100, loss=0.01]  

torch.Size([32, 64, 21, 28])


Epoch 11/15:  85%|████████▍ | 11/13 [00:09<00:01,  1.18it/s, acc=100, loss=0.00523]

torch.Size([32, 64, 21, 28])


Epoch 11/15:  92%|█████████▏| 12/13 [00:10<00:00,  1.17it/s, acc=100, loss=0.00159]

torch.Size([12, 64, 21, 28])


Epoch 11/15: 100%|██████████| 13/13 [00:10<00:00,  1.23it/s, acc=100, loss=0.000808]


Epoch [11/15], Loss: 0.0039, Accuracy: 100.00%


Epoch 12/15:   0%|          | 0/13 [00:00<?, ?it/s]

torch.Size([32, 64, 21, 28])


Epoch 12/15:   8%|▊         | 1/13 [00:00<00:09,  1.22it/s, acc=100, loss=0.0025]

torch.Size([32, 64, 21, 28])


Epoch 12/15:  15%|█▌        | 2/13 [00:01<00:09,  1.22it/s, acc=100, loss=0.00315]

torch.Size([32, 64, 21, 28])


Epoch 12/15:  23%|██▎       | 3/13 [00:02<00:08,  1.20it/s, acc=100, loss=0.00227]

torch.Size([32, 64, 21, 28])


Epoch 12/15:  31%|███       | 4/13 [00:03<00:07,  1.14it/s, acc=100, loss=0.00358]

torch.Size([32, 64, 21, 28])


Epoch 12/15:  38%|███▊      | 5/13 [00:04<00:06,  1.15it/s, acc=100, loss=0.00101]

torch.Size([32, 64, 21, 28])


Epoch 12/15:  46%|████▌     | 6/13 [00:05<00:05,  1.17it/s, acc=100, loss=0.0041] 

torch.Size([32, 64, 21, 28])


Epoch 12/15:  54%|█████▍    | 7/13 [00:05<00:05,  1.18it/s, acc=100, loss=0.00114]

torch.Size([32, 64, 21, 28])


Epoch 12/15:  62%|██████▏   | 8/13 [00:06<00:04,  1.19it/s, acc=100, loss=0.000337]

torch.Size([32, 64, 21, 28])


Epoch 12/15:  69%|██████▉   | 9/13 [00:07<00:03,  1.21it/s, acc=100, loss=0.00506] 

torch.Size([32, 64, 21, 28])


Epoch 12/15:  77%|███████▋  | 10/13 [00:08<00:02,  1.20it/s, acc=100, loss=0.00151]

torch.Size([32, 64, 21, 28])


Epoch 12/15:  85%|████████▍ | 11/13 [00:09<00:01,  1.18it/s, acc=100, loss=0.00392]

torch.Size([32, 64, 21, 28])


Epoch 12/15:  92%|█████████▏| 12/13 [00:10<00:00,  1.18it/s, acc=100, loss=0.00707]

torch.Size([12, 64, 21, 28])


Epoch 12/15: 100%|██████████| 13/13 [00:10<00:00,  1.23it/s, acc=100, loss=2.5e-5] 


Epoch [12/15], Loss: 0.0027, Accuracy: 100.00%


Epoch 13/15:   0%|          | 0/13 [00:00<?, ?it/s]

torch.Size([32, 64, 21, 28])


Epoch 13/15:   8%|▊         | 1/13 [00:00<00:10,  1.19it/s, acc=100, loss=0.00131]

torch.Size([32, 64, 21, 28])


Epoch 13/15:  15%|█▌        | 2/13 [00:01<00:08,  1.22it/s, acc=100, loss=0.000578]

torch.Size([32, 64, 21, 28])


Epoch 13/15:  23%|██▎       | 3/13 [00:02<00:08,  1.19it/s, acc=100, loss=0.00336] 

torch.Size([32, 64, 21, 28])


Epoch 13/15:  31%|███       | 4/13 [00:03<00:07,  1.21it/s, acc=100, loss=8.07e-5]

torch.Size([32, 64, 21, 28])


Epoch 13/15:  38%|███▊      | 5/13 [00:04<00:06,  1.16it/s, acc=100, loss=0.00132]

torch.Size([32, 64, 21, 28])


Epoch 13/15:  46%|████▌     | 6/13 [00:05<00:05,  1.18it/s, acc=100, loss=0.000274]

torch.Size([32, 64, 21, 28])


Epoch 13/15:  54%|█████▍    | 7/13 [00:05<00:05,  1.19it/s, acc=99.6, loss=0.0326] 

torch.Size([32, 64, 21, 28])


Epoch 13/15:  62%|██████▏   | 8/13 [00:06<00:04,  1.18it/s, acc=99.6, loss=0.000421]

torch.Size([32, 64, 21, 28])


Epoch 13/15:  69%|██████▉   | 9/13 [00:07<00:03,  1.19it/s, acc=99.7, loss=0.00267] 

torch.Size([32, 64, 21, 28])


Epoch 13/15:  77%|███████▋  | 10/13 [00:08<00:02,  1.20it/s, acc=99.7, loss=0.00134]

torch.Size([32, 64, 21, 28])


Epoch 13/15:  85%|████████▍ | 11/13 [00:09<00:01,  1.19it/s, acc=99.7, loss=0.00562]

torch.Size([32, 64, 21, 28])


Epoch 13/15:  92%|█████████▏| 12/13 [00:10<00:00,  1.20it/s, acc=99.7, loss=0.025]  

torch.Size([12, 64, 21, 28])


Epoch 13/15: 100%|██████████| 13/13 [00:10<00:00,  1.23it/s, acc=99.7, loss=0.000445]


Epoch [13/15], Loss: 0.0058, Accuracy: 99.75%


Epoch 14/15:   0%|          | 0/13 [00:00<?, ?it/s]

torch.Size([32, 64, 21, 28])


Epoch 14/15:   8%|▊         | 1/13 [00:00<00:10,  1.18it/s, acc=100, loss=0.01]

torch.Size([32, 64, 21, 28])


Epoch 14/15:  15%|█▌        | 2/13 [00:01<00:09,  1.20it/s, acc=100, loss=0.00299]

torch.Size([32, 64, 21, 28])


Epoch 14/15:  23%|██▎       | 3/13 [00:02<00:08,  1.21it/s, acc=100, loss=0.00146]

torch.Size([32, 64, 21, 28])


Epoch 14/15:  31%|███       | 4/13 [00:03<00:07,  1.18it/s, acc=100, loss=0.00561]

torch.Size([32, 64, 21, 28])


Epoch 14/15:  38%|███▊      | 5/13 [00:04<00:06,  1.21it/s, acc=100, loss=0.00395]

torch.Size([32, 64, 21, 28])


Epoch 14/15:  46%|████▌     | 6/13 [00:04<00:05,  1.20it/s, acc=100, loss=0.00306]

torch.Size([32, 64, 21, 28])


Epoch 14/15:  54%|█████▍    | 7/13 [00:05<00:05,  1.14it/s, acc=100, loss=0.00666]

torch.Size([32, 64, 21, 28])


Epoch 14/15:  62%|██████▏   | 8/13 [00:06<00:04,  1.15it/s, acc=100, loss=0.00106]

torch.Size([32, 64, 21, 28])


Epoch 14/15:  69%|██████▉   | 9/13 [00:07<00:03,  1.17it/s, acc=100, loss=0.00865]

torch.Size([32, 64, 21, 28])


Epoch 14/15:  77%|███████▋  | 10/13 [00:08<00:02,  1.18it/s, acc=99.7, loss=0.0785]

torch.Size([32, 64, 21, 28])


Epoch 14/15:  85%|████████▍ | 11/13 [00:09<00:01,  1.19it/s, acc=99.7, loss=0.0011]

torch.Size([32, 64, 21, 28])


Epoch 14/15:  92%|█████████▏| 12/13 [00:10<00:00,  1.17it/s, acc=99.7, loss=0.000832]

torch.Size([12, 64, 21, 28])


Epoch 14/15: 100%|██████████| 13/13 [00:10<00:00,  1.22it/s, acc=99.7, loss=0.000914]


Epoch [14/15], Loss: 0.0096, Accuracy: 99.75%


Epoch 15/15:   0%|          | 0/13 [00:00<?, ?it/s]

torch.Size([32, 64, 21, 28])


Epoch 15/15:   8%|▊         | 1/13 [00:00<00:10,  1.10it/s, acc=100, loss=0.015]

torch.Size([32, 64, 21, 28])


Epoch 15/15:  15%|█▌        | 2/13 [00:01<00:09,  1.15it/s, acc=98.4, loss=0.0604]

torch.Size([32, 64, 21, 28])


Epoch 15/15:  23%|██▎       | 3/13 [00:02<00:08,  1.17it/s, acc=99, loss=0.0264]  

torch.Size([32, 64, 21, 28])


Epoch 15/15:  31%|███       | 4/13 [00:03<00:07,  1.19it/s, acc=98.4, loss=0.0442]

torch.Size([32, 64, 21, 28])


Epoch 15/15:  38%|███▊      | 5/13 [00:04<00:06,  1.18it/s, acc=98.8, loss=0.00571]

torch.Size([32, 64, 21, 28])


Epoch 15/15:  46%|████▌     | 6/13 [00:05<00:05,  1.19it/s, acc=99, loss=0.00615]  

torch.Size([32, 64, 21, 28])


Epoch 15/15:  54%|█████▍    | 7/13 [00:05<00:05,  1.20it/s, acc=98.7, loss=0.117]

torch.Size([32, 64, 21, 28])


Epoch 15/15:  62%|██████▏   | 8/13 [00:06<00:04,  1.15it/s, acc=98.8, loss=0.0151]

torch.Size([32, 64, 21, 28])


Epoch 15/15:  69%|██████▉   | 9/13 [00:07<00:03,  1.15it/s, acc=98.6, loss=0.0623]

torch.Size([32, 64, 21, 28])


Epoch 15/15:  77%|███████▋  | 10/13 [00:08<00:02,  1.18it/s, acc=98.8, loss=0.00282]

torch.Size([32, 64, 21, 28])


Epoch 15/15:  85%|████████▍ | 11/13 [00:09<00:01,  1.18it/s, acc=98.9, loss=0.00468]

torch.Size([32, 64, 21, 28])


Epoch 15/15:  92%|█████████▏| 12/13 [00:10<00:00,  1.19it/s, acc=99, loss=0.0158]   

torch.Size([12, 64, 21, 28])


Epoch 15/15: 100%|██████████| 13/13 [00:10<00:00,  1.21it/s, acc=99, loss=0.0221]


Epoch [15/15], Loss: 0.0306, Accuracy: 98.99%
torch.Size([32, 64, 21, 28])
torch.Size([32, 64, 21, 28])
torch.Size([32, 64, 21, 28])
torch.Size([3, 64, 21, 28])

📊 Test Loss: 0.2131, Test Accuracy: 89.90%


After  augumentation we have an improvement of data  by 5%

# Let's try with Resnet

In [80]:
from torchvision import models

In [85]:
class ResNetTransfer(nn.Module):
    def __init__(self, num_classes=10):
        super(ResNetTransfer, self).__init__()
        
        # Load pretrained ResNet18 model
        self.resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        
        # Freeze all layers
        for param in self.resnet.parameters():
            param.requires_grad = False
        
        # Replace the final fully connected layer with a custom one
        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_features, num_classes)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, x):
        return self.resnet(x)
    
    def unfreeze_last_few_layers(self, num_layers=3):
        """
        Unfreeze the last `num_layers` of the ResNet model to fine-tune them.
        """
        layers = list(self.resnet.children())[:-1]  # Exclude the final FC layer
        for layer in layers[-num_layers:]:
            for param in layer.parameters():
                param.requires_grad = True

    def train_model(self, train_loader, num_epochs=10, lr=0.001):
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.parameters(), lr=lr)
        best_accuracy = 0.0
        
        for epoch in range(num_epochs):
            self.train()
            running_loss, correct, total = 0.0, 0, 0
            for images, labels in train_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                
                optimizer.zero_grad()
                outputs = self(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)

            epoch_acc = 100 * correct / total
            print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}, Accuracy: {epoch_acc:.2f}%")

            if epoch_acc > best_accuracy:
                best_accuracy = epoch_acc
                torch.save(self.state_dict(), "best_resnet_model.pth")
                print("Model saved!")

    def evaluate_model(self, test_loader):
        self.eval()
        correct, total, test_loss = 0, 0, 0.0
        criterion = nn.CrossEntropyLoss()

        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self(images)
                loss = criterion(outputs, labels)
                test_loss += loss.item()

                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)

        test_accuracy = 100 * correct / total
        print(f"\n📊 Test Loss: {test_loss / len(test_loader):.4f}, Test Accuracy: {test_accuracy:.2f}%")

In [86]:
# Initialize model
num_classes = len(full_dataset.class_to_idx)  
model = ResNetTransfer(num_classes)

In [87]:
model.unfreeze_last_few_layers(num_layers=3)

In [88]:
# Train and Evaluate
model.train_model(train_loader, num_epochs=15, lr=0.001)
model.evaluate_model(test_loader)

Epoch [1/15], Loss: 0.4498, Accuracy: 82.83%
Model saved!
Epoch [2/15], Loss: 0.1517, Accuracy: 94.95%
Model saved!
Epoch [3/15], Loss: 0.0622, Accuracy: 99.24%
Model saved!
Epoch [4/15], Loss: 0.0413, Accuracy: 98.99%
Epoch [5/15], Loss: 0.0229, Accuracy: 99.24%
Epoch [6/15], Loss: 0.0049, Accuracy: 100.00%
Model saved!
Epoch [7/15], Loss: 0.0038, Accuracy: 100.00%
Epoch [8/15], Loss: 0.0105, Accuracy: 99.75%
Epoch [9/15], Loss: 0.0285, Accuracy: 99.24%
Epoch [10/15], Loss: 0.0062, Accuracy: 100.00%
Epoch [11/15], Loss: 0.0047, Accuracy: 100.00%
Epoch [12/15], Loss: 0.0028, Accuracy: 100.00%
Epoch [13/15], Loss: 0.0009, Accuracy: 100.00%
Epoch [14/15], Loss: 0.0016, Accuracy: 100.00%
Epoch [15/15], Loss: 0.0004, Accuracy: 100.00%

📊 Test Loss: 0.0773, Test Accuracy: 96.97%
