In [None]:
%pip install numpy datasets transformers matplotlib kagglehub

Collecting datasets
  Using cached datasets-4.0.0-py3-none-any.whl.metadata (19 kB)
Collecting torch
  Using cached torch-2.8.0-cp313-cp313-win_amd64.whl.metadata (30 kB)
Collecting matplotlib
  Using cached matplotlib-3.10.5-cp313-cp313-win_amd64.whl.metadata (11 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Using cached pyarrow-21.0.0-cp313-cp313-win_amd64.whl.metadata (3.4 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Using cached dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting pandas (from datasets)
  Using cached pandas-2.3.1-cp313-cp313-win_amd64.whl.metadata (19 kB)
Collecting xxhash (from datasets)
  Using cached xxhash-3.5.0-cp313-cp313-win_amd64.whl.metadata (13 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Using cached multiprocess-0.70.16-py312-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Using cached fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Collecting sympy>=1.13


[notice] A new release of pip is available: 25.0.1 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [None]:
# Maintenance step, intentionally disabled: left active, "Restart & Run All" would
# remove PyTorch immediately before the next cell runs `import torch`.
# Uncomment and run once by hand only when replacing the installed build (for
# example swapping a CPU-only wheel for a CUDA-enabled one), then reinstall before
# continuing. `-y` skips pip's confirmation prompt, which cannot be answered from
# a notebook cell.
# %pip uninstall -y torch torchvision

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import time
from tqdm import tqdm

# --- 1. Kaggle Dataset Setup ---
# This section handles the automatic download and extraction of the dataset using kagglehub.
# IMPORTANT: You must have your Kaggle API key set up for kagglehub.
# 1. Go to your Kaggle account settings -> API -> Create New API Token.
# 2. This will download a 'kaggle.json' file.
# 3. Place this file in '~/.kaggle/' on Linux/macOS or 'C:\Users\<Your-Username>\.kaggle\' on Windows.

# kagglehub usage: https://github.com/Kaggle/kagglehub
# Example: kagglehub.dataset_download('owner/dataset-name')

def setup_kaggle_dataset(dataset_slug, download_path):
    """Make the Kaggle dataset ``dataset_slug`` available at ``download_path``.

    If ``download_path`` already exists and is non-empty, nothing is downloaded.
    Otherwise the dataset is fetched with ``kagglehub`` (which stores it in its
    own cache directory) and the contents of that directory are copied into
    ``download_path``, so that later runs hit the "already found" shortcut and
    code reading from ``download_path`` actually sees the files.

    NOTE(review): the downloaded directory's contents are copied as-is; confirm
    that ``download_path`` matches the folder layout the dataset ships with.

    Args:
        dataset_slug: Kaggle dataset identifier in ``owner/dataset-name`` form.
        download_path: Local directory that should contain the dataset files.

    Returns:
        ``download_path`` when the dataset is present there after the call, or
        ``None`` when ``kagglehub`` is not installed and the download is skipped.

    Raises:
        RuntimeError: if the download or the copy into ``download_path`` fails.
            (Raised instead of calling ``exit()``, which would shut down a
            notebook kernel and discard all state.)
    """
    # isdir (rather than exists) keeps os.listdir from failing on a stray file.
    if os.path.isdir(download_path) and os.listdir(download_path):
        print(f"Dataset already found at '{download_path}'. Skipping download.")
        return download_path

    print(f"Downloading dataset '{dataset_slug}' to '{download_path}' using kagglehub...")
    # Only the import is guarded by ImportError so that an ImportError raised
    # deep inside the download itself is not mistaken for a missing package.
    try:
        import kagglehub
    except ImportError:
        print("Warning: 'kagglehub' library not found. Please install it using 'pip install kagglehub'.")
        print("Skipping automatic download. Please ensure the dataset is manually placed in the correct directory.")
        return None

    import shutil

    try:
        dataset_dir = kagglehub.dataset_download(dataset_slug)
        print(f"Dataset downloaded to {dataset_dir}")
        # kagglehub downloads into its own cache; mirror it to the requested
        # location. dirs_exist_ok tolerates an existing (empty) target folder.
        shutil.copytree(dataset_dir, download_path, dirs_exist_ok=True)
    except Exception as e:
        print(f"An error occurred during dataset download with kagglehub: {e}")
        print("Please ensure your kaggle.json is correctly configured and the dataset slug is valid.")
        raise RuntimeError(
            f"Could not prepare dataset '{dataset_slug}' at '{download_path}'"
        ) from e

    return download_path

# --- 2. Configuration and Hyperparameters ---
# Kaggle dataset identifier ('owner/dataset-name') handed to setup_kaggle_dataset.
KAGGLE_DATASET_SLUG = 'pdavpoojan/the-rvlcdip-dataset-test'
# Base directory under which the dataset is expected to live locally.
LOCAL_DATA_DIR = './rvlcdip_dataset'
# The path should point to the 'test' subfolder after extraction.
# The dataset class looks for 'handwritten' and 'email' subfolders directly inside it.
DATASET_PATH = os.path.join(LOCAL_DATA_DIR, 'the-rvlcdip-dataset-test', 'test')

# Number of images per training batch.
BATCH_SIZE = 16 # Reduced batch size to fit larger images in 8GB VRAM
# Number of full passes over the dataset.
NUM_EPOCHS = 50   # Increased epochs for more training
# Step size passed to the Adam optimizer.
LEARNING_RATE = 0.001
# Resolution every image is resized to before being fed to the model.
# Keep the model's classifier input size in sync when changing this
# (EnhancedCNN is sized for 512x512 inputs by default).
IMAGE_SIZE = (512, 512) # Increased image resolution

# --- 3. Custom Dataset Class ---
# Load both 'handwritten' and 'email' folders for binary classification.
class HandwrittenPrintedDataset(Dataset):
    """Binary image dataset: handwritten documents (0) vs. printed e-mails (1).

    Images are discovered in ``<root_dir>/handwritten`` and ``<root_dir>/email``.
    A missing class folder only produces a warning, so the dataset may end up
    empty; callers should check ``len(dataset)``.

    Args:
        root_dir: Directory containing the ``handwritten`` and ``email`` folders.
        transform: Optional callable applied to each loaded RGB PIL image.
        image_size: Optional ``(height, width)`` of the placeholder tensor that
            is returned for unreadable images. When ``None`` (default) the
            module-level ``IMAGE_SIZE`` is used, as before.
    """

    # File extensions (lower-case) that are treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tif', '.tiff')

    def __init__(self, root_dir, transform=None, image_size=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_size = image_size
        self.samples = []  # list of (image_path, integer_label)
        self.class_map = {'handwritten': 0, 'email': 1} # 0: Handwritten, 1: Printed
        self.class_names = {0: 'Handwritten', 1: 'Printed (Email)'}

        for class_name, label in self.class_map.items():
            class_dir = os.path.join(self.root_dir, class_name)
            if not os.path.isdir(class_dir):
                print(f"Warning: Directory not found: {class_dir}")
                continue
            # os.listdir order is arbitrary; sorting keeps sample indices
            # stable across runs and platforms.
            for filename in sorted(os.listdir(class_dir)):
                if filename.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.samples.append((os.path.join(class_dir, filename), label))

    def __len__(self):
        """Return the number of discovered image files."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        ``label`` is a scalar ``torch.long`` tensor. If the image cannot be
        read, the error is printed and an all-zero ``(3, H, W)`` placeholder is
        returned together with the sample's real label, so a corrupt file is
        neither fed to the model as random noise nor mislabelled as class 0.
        """
        img_path, label = self.samples[idx]
        target = torch.tensor(label, dtype=torch.long)
        try:
            # The context manager closes the underlying file handle promptly;
            # convert() returns an independent in-memory copy.
            with Image.open(img_path) as handle:
                image = handle.convert('RGB')
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            height, width = IMAGE_SIZE if self.image_size is None else self.image_size
            return torch.zeros(3, height, width), target

        if self.transform:
            image = self.transform(image)
        return image, target

# --- 4. Define the Enhanced CNN Model ---
# Output 2 classes for binary classification (handwritten vs printed)
class EnhancedCNN(nn.Module):
    """Three-stage convolutional classifier for RGB images.

    Each stage is Conv(3x3, padding 1) -> ReLU -> MaxPool(2x2), so the spatial
    size is halved three times (overall factor 8) while channels grow
    3 -> 16 -> 32 -> 64. The flattened features feed a 512-unit hidden layer
    with dropout and a final linear layer producing one logit per class.

    Args:
        image_size: ``(height, width)`` of the input images. Determines the
            input width of the first fully connected layer. The default
            ``(512, 512)`` reproduces the original ``64 * 64 * 64`` features.
        num_classes: Number of output logits (default 2: handwritten/printed).

    Raises:
        ValueError: if either image dimension is smaller than 8, because three
            2x2 poolings would leave no spatial positions.
    """

    def __init__(self, image_size=(512, 512), num_classes=2):
        super().__init__()
        height, width = image_size
        # Three MaxPool2d(2, 2) layers floor-divide each dimension by 2 each time.
        pooled_h, pooled_w = height // 8, width // 8
        if pooled_h < 1 or pooled_w < 1:
            raise ValueError(
                f"image_size must be at least 8x8 after resizing, got {image_size}"
            )

        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        # Layer order is kept identical to the original so state-dict keys
        # (and therefore previously saved checkpoints) stay compatible.
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * pooled_h * pooled_w, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        """Map a ``(batch, 3, H, W)`` tensor to ``(batch, num_classes)`` logits."""
        x = self.conv_layers(x)
        x = self.fc_layers(x)
        return x


In [None]:
# --- 5. Main Training Script ---
# Downloads the data if needed, builds the dataset and model, trains for
# NUM_EPOCHS epochs and saves the learned weights to disk.
if __name__ == '__main__':
    import multiprocessing

    # Step 1: Download and set up the dataset from Kaggle
    setup_kaggle_dataset(KAGGLE_DATASET_SLUG, DATASET_PATH)

    # Step 2: Set up device and transformations
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Resize to the configured resolution, convert to a tensor and normalise
    # each RGB channel with the given mean/std.
    transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Step 3: Create dataset and dataloader
    print("Loading dataset...")
    full_dataset = HandwrittenPrintedDataset(root_dir=DATASET_PATH, transform=transform)
    if len(full_dataset) == 0:
        # Raise instead of exit(): exit() would shut down the notebook kernel
        # and hide the traceback.
        raise ValueError("Dataset is empty. Check your DATASET_PATH and folder structure.")

    # DataLoader worker processes must be able to re-import the Dataset class.
    # With the 'spawn'/'forkserver' start methods (the default on Windows and
    # macOS) a class defined in a notebook cell cannot be found by the workers
    # and the first batch never arrives, so the loop stalls at 0%. Only use
    # workers where they are forked from this process; otherwise load in-process.
    loader_workers = 2 if multiprocessing.get_start_method() == 'fork' else 0
    train_loader = DataLoader(
        dataset=full_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=loader_workers,
    )
    print(f"Dataset loaded successfully with {len(full_dataset)} images.")

    # Step 4: Initialize model, loss, and optimizer
    model = EnhancedCNN().to(device)
    criterion = nn.CrossEntropyLoss() # For binary classification (2 classes)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Step 5: Training loop
    print("\n--- Starting Training ---")
    start_time = time.time()

    for epoch in range(NUM_EPOCHS):
        model.train()
        running_loss = 0.0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", leave=False)

        for images, labels in progress_bar:
            images = images.to(device)
            labels = labels.to(device) # CrossEntropy expects class indices

            # Clear gradients left over from the previous step before the
            # new forward/backward pass.
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss
            progress_bar.set_postfix(loss=batch_loss)

        # Mean of the per-batch losses for this epoch.
        epoch_loss = running_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Loss: {epoch_loss:.4f}")

    end_time = time.time()
    print("\n--- Training Finished ---")
    print(f"Total training time: {(end_time - start_time)/60:.2f} minutes")

    # Only the weights are saved; rebuild EnhancedCNN before loading them.
    torch.save(model.state_dict(), 'enhanced_classifier_model.pth')
    print("Model saved to enhanced_classifier_model.pth")


Dataset already found at './rvlcdip_dataset\the-rvlcdip-dataset-test\test'. Skipping download.
Using device: cpu
Loading dataset...
Dataset loaded successfully with 5048 images.

--- Starting Training ---

--- Starting Training ---


Epoch 1/50:   0%|          | 0/316 [00:00<?, ?it/s]

In [2]:
# Check whether this PyTorch installation can use a CUDA device; the training
# cell selects 'cpu' when this is False.
torch.cuda.is_available()

False