In [None]:
import os
import time

import timm
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from huggingface_hub import login
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder

SEED = 42  # seeds torch RNGs so the data split (and loader shuffling / head init) are reproducible
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Define dataset path
data_dir = './train'

# Resize to 224x224 and normalize with the standard ImageNet channel mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load dataset: one sub-folder per class. ImageFolder assigns label indices in sorted
# folder-name order, so print the mapping to check it against the class names used
# when reporting test results.
dataset = ImageFolder(root=data_dir, transform=transform)
print(f"Class-to-index mapping: {dataset.class_to_idx}")

# Hold out 20% as a test set, then 20% of the remainder for validation. A seeded
# generator keeps the split identical across kernel restarts (for the same dataset),
# so checkpoints reloaded later are evaluated on the same held-out images.
split_generator = torch.Generator().manual_seed(SEED)

test_size = int(0.2 * len(dataset))  # 20% for testing
train_val_size = len(dataset) - test_size  # Remaining 80% for training & validation
train_val_dataset, test_dataset = random_split(
    dataset, [train_val_size, test_size], generator=split_generator
)

val_size = int(0.2 * train_val_size)
train_size = train_val_size - val_size
train_dataset, val_dataset = random_split(
    train_val_dataset, [train_size, val_size], generator=split_generator
)

# Create DataLoaders. Worker count is capped by the available CPUs, and pinned
# memory is only requested when batches will be copied to a GPU.
batch_size = 128
num_workers = min(32, os.cpu_count() or 1)
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=device.type == "cuda",
)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

print(f"Dataset split: {len(train_dataset)} training, {len(val_dataset)} validation, {len(test_dataset)} testing.")

# Define models (timm model identifiers)
model_names = ["convnext_tiny", "deit_tiny_patch16_224", "mobilenetv3_large_100", "resnet50", "swin_tiny_patch4_window7_224", "vgg19_bn", "efficientnet_b3"]

# Training function
def train_model(model, train_loader, val_loader, optimizer, epochs=12):
    criterion = nn.BCEWithLogitsLoss()

    for epoch in range(epochs):
        start_time = time.time()
        model.train()
        train_loss, correct, total = 0.0, 0, 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.float().to(device)
            labels = labels.view(-1, 1)  # Reshape for BCE Loss

            optimizer.zero_grad()
            outputs = model(images)  # Keep [batch_size, 1]
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            predicted = (torch.sigmoid(outputs) > 0.5).float()
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

        train_acc = correct / total

        # Validation
        model.eval()
        val_loss, correct, total = 0.0, 0, 0

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.float().to(device)
                labels = labels.view(-1, 1)

                outputs = model(images)  # Keep [batch_size, 1]
                loss = criterion(outputs, labels)
                val_loss += loss.item()

                predicted = (torch.sigmoid(outputs) > 0.5).float()
                correct += (predicted == labels).sum().item()
                total += labels.size(0)

        val_acc = correct / total

        end_time = time.time()
        epoch_duration = end_time - start_time  # Calculate time taken
        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss/len(train_loader):.4f} | Train Acc: {train_acc:.4f} | Val Loss: {val_loss/len(val_loader):.4f} | Val Acc: {val_acc:.4f} | Time: {epoch_duration:.2f} sec")

# Fine-tune each pretrained timm backbone as a single-logit binary classifier and
# save its weights to "<model_name>_classifier.pth". After the loop, `model` refers
# to the last architecture in `model_names`.
for model_name in model_names:
    print(f"\nTraining {model_name}...\n")

    model = timm.create_model(model_name, pretrained=True, num_classes=1).to(device)
    adam = optim.Adam(model.parameters(), lr=1e-4)
    train_model(model, train_loader, val_loader, optimizer=adam, epochs=25)

    checkpoint_path = f"{model_name}_classifier.pth"
    torch.save(model.state_dict(), checkpoint_path)
    print("Training complete.")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report

def test_model_with_confusion_matrix(model, test_loader, class_names=("Illegal", "Legal"),
                                     title="Confusion Matrix", device=None):
    """Evaluate a single-logit binary classifier on ``test_loader`` and report results.

    Prints the per-sample mean test loss and accuracy, the 2x2 confusion matrix
    (also drawn as a heatmap), and scikit-learn's classification report.

    Args:
        model: Module returning a ``(batch, 1)`` logit tensor per batch of images.
        test_loader: Iterable of ``(images, labels)`` batches with 0/1 labels.
        class_names: Display names for label 0 and label 1, in that order. They must
            agree with the dataset's class-to-index mapping (ImageFolder sorts folders).
        title: Title of the confusion-matrix figure.
        device: Device for the batches; defaults to the device of the model's first
            parameter (CPU if the model has no parameters).

    Raises:
        ValueError: if ``class_names`` does not have exactly two entries.
    """
    if len(class_names) != 2:
        raise ValueError(f"class_names must have exactly 2 entries, got {len(class_names)}")
    class_names = list(class_names)
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    criterion = nn.BCEWithLogitsLoss()

    loss_sum, correct, total = 0.0, 0, 0
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.float().to(device).view(-1, 1)  # (batch, 1) to match the logits

            outputs = model(images)
            batch_size = labels.size(0)
            # Weight the batch-mean loss by batch size so the reported loss is per sample.
            loss_sum += criterion(outputs, labels).item() * batch_size

            predicted = (torch.sigmoid(outputs) > 0.5).float()
            correct += (predicted == labels).sum().item()
            total += batch_size

            # Keep integer class ids for scikit-learn.
            all_labels.extend(labels.cpu().numpy().astype(int).ravel().tolist())
            all_preds.extend(predicted.cpu().numpy().astype(int).ravel().tolist())

    if total == 0:
        print("Test loader yielded no samples; nothing to evaluate.")
        return

    print(f"Test Loss: {loss_sum / total:.4f} | Test Accuracy: {correct / total:.4f}")

    # Pin the label set to [0, 1] so the matrix is always 2x2 and the report keeps
    # both rows, even if one class is missing from both labels and predictions.
    label_ids = [0, 1]
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
    print("\nConfusion Matrix:")
    print(cm)

    # Plot Confusion Matrix
    fig, ax = plt.subplots(figsize=(5, 4))
    sns.heatmap(cm, annot=True, fmt='d', cmap="Blues",
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_xlabel("Predicted Label")
    ax.set_ylabel("True Label")
    ax.set_title(title)
    plt.show()

    # Classification Report
    print("\nClassification Report:")
    print(classification_report(all_labels, all_preds, labels=label_ids,
                                target_names=class_names, zero_division=0))

# Evaluate every architecture trained above on the held-out test split by reloading
# its saved checkpoint, rather than only the `model` left over from the training loop.
# Every weight is overwritten by the checkpoint, so pretrained weights are not needed.
# Afterwards `model` holds the last architecture in `model_names`.
for model_name in model_names:
    print(f"\nEvaluating {model_name}...\n")
    model = timm.create_model(model_name, pretrained=False, num_classes=1)
    model.load_state_dict(torch.load(f"{model_name}_classifier.pth", map_location=device))
    model = model.to(device)
    test_model_with_confusion_matrix(model, test_loader)

In [None]:
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import cv2

# Grad-CAM below explains whichever `model` is currently in the kernel. To inspect a
# different trained architecture, rebuild it and load its checkpoint first, e.g.:
#   model = timm.create_model("resnet50", pretrained=False, num_classes=1)
#   model.load_state_dict(torch.load("resnet50_classifier.pth", map_location="cpu"))
model.eval()

# Image preprocessing: same resize + normalization as the training transform.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize for model input
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load and transform image.
# NOTE(review): absolute, machine-specific path; adjust for your environment.
image_path = "/home/train/aug_legal_folder/aug_0_3000.png"
# Open in a context manager so the file handle is released; convert() returns a
# new image, so `image` stays usable after the file is closed.
with Image.open(image_path) as img:
    image = img.convert("RGB")
input_tensor = transform(image).unsqueeze(0)  # Add batch dimension -> (1, 3, 224, 224)

# Grad-CAM needs the target layer's forward activations and the gradient of the
# chosen score w.r.t. them; the hooks below stash both in these module-level slots.
gradients = None
activations = None


def save_gradients(grad):
    """Stash the gradient w.r.t. the target layer's output (called from the backward hook)."""
    global gradients
    gradients = grad


# Pick the last nn.Conv2d in module-registration order as the Grad-CAM target.
# NOTE(review): for transformer backbones (deit/swin) this is likely a patch-embedding
# conv, so the resulting map may not be meaningful for those models.
target_layer = None  # reset so a layer left over from a previous run is never reused
for module in model.modules():
    if isinstance(module, torch.nn.Conv2d):
        target_layer = module
if target_layer is None:
    raise RuntimeError("Grad-CAM: model has no nn.Conv2d layer to hook.")


def forward_hook(module, input, output):
    """Stash the target layer's forward output (its activation map)."""
    global activations
    activations = output


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)  # Move model to the appropriate device
input_tensor = input_tensor.to(device)  # Move input tensor to the same device

# Keep the handles and remove the hooks afterwards. Otherwise every re-run of this
# cell stacks another pair of hooks on the model, and they stay active for any later
# forward/backward pass through it.
hook_handles = [
    target_layer.register_forward_hook(forward_hook),
    target_layer.register_full_backward_hook(
        lambda module, grad_in, grad_out: save_gradients(grad_out[0])
    ),
]
try:
    output = model(input_tensor)  # (1, 1): one raw logit for the single image

    # The model emits a logit, so the decision boundary is 0 (sigmoid(0) == 0.5).
    target_class = int(output.item() > 0)
    # Backpropagate +logit for class 1 and -logit for class 0, so the map shows
    # evidence for the predicted class rather than always for class 1.
    score_sign = 1.0 if target_class == 1 else -1.0
    model.zero_grad()
    output.backward(torch.full_like(output, score_sign))
finally:
    for handle in hook_handles:
        handle.remove()

print(f"Predicted class index: {target_class} (logit = {output.item():.4f})")

# Compute Grad-CAM: weight each channel of the target activations by its spatially
# averaged gradient, sum over channels, and keep only positive evidence (ReLU).
weights = gradients.mean(dim=(2, 3), keepdim=True)  # assumes (N, C, H, W) conv activations
# Index the single batch item instead of squeeze() so tiny (e.g. 1x1) maps stay 2-D.
gradcam = torch.relu((weights * activations).sum(dim=1))[0].detach().cpu().numpy()

# Upsample to the display size and min-max normalize to [0, 1].
gradcam = cv2.resize(gradcam, (224, 224))
cam_min, cam_max = float(gradcam.min()), float(gradcam.max())
if cam_max > cam_min:
    gradcam = (gradcam - cam_min) / (cam_max - cam_min)
else:
    # Constant map (e.g. all zeros after the ReLU): avoid 0/0 -> NaN.
    gradcam = np.zeros_like(gradcam)
heatmap = cv2.applyColorMap(np.uint8(255 * gradcam), cv2.COLORMAP_JET)
# applyColorMap produces BGR; convert to RGB to match the PIL image before blending.
heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

# Overlay on original image (both are 224x224x3 uint8 RGB here)
image_np = np.array(image.resize((224, 224)))
overlay = cv2.addWeighted(image_np, 0.5, heatmap, 0.5, 0)

# Side-by-side view: the input image, the raw Grad-CAM map, and the blended overlay.
panels = [
    ("Original Image", image_np, {}),
    ("Grad-CAM Heatmap", gradcam, {"cmap": "jet"}),
    ("Overlay", overlay, {}),
]

plt.figure(figsize=(10, 5))
for position, (panel_title, panel_data, imshow_kwargs) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.imshow(panel_data, **imshow_kwargs)
    plt.title(panel_title)
    plt.axis("off")

plt.show()