In [1]:
!pip install torch torchvision transformers



In [2]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from transformers import ViTForImageClassification, ViTFeatureExtractor
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
from sklearn.model_selection import train_test_split
import os
import numpy as np
from sklearn.metrics import classification_report
from PIL import Image
import requests

In [3]:
# Set device to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [42]:
# Define transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

In [39]:
import subprocess
import os

def clone_repo(repo_url, save_path):

    # Combine save path and repo name
    target_dir = os.path.join(save_path, os.path.basename(repo_url))

    # Construct the git clone command
    cmd = ["git", "clone", repo_url, target_dir]  # Specify target directory

    try:
        # Execute the git clone command
        subprocess.run(cmd, check=True)
        print(f"Repository cloned successfully into {target_dir}")
    except subprocess.CalledProcessError as e:
        print(f"An error occurred while cloning the repository: {e}")

# Example usage
repo_url = "https://github.com/ari-dasci/OD-WeaponDetection/tree/master/Weapons%20and%20similar%20handled%20objects/Sohas_weapon-Classification"
save_path = "/content/drive/MyDrive/datasets/weapons_dataset"  # Example save path

clone_repo(repo_url, save_path)


An error occurred while cloning the repository: Command '['git', 'clone', 'https://github.com/ari-dasci/OD-WeaponDetection/tree/master/Weapons%20and%20similar%20handled%20objects/Sohas_weapon-Classification', '/content/drive/MyDrive/datasets/weapons_dataset/Sohas_weapon-Classification']' returned non-zero exit status 128.


In [40]:
# Load dataset
def load_data(data_path):
    full_dataset = datasets.ImageFolder(root=data_path, transform=transform)
    classes = full_dataset.classes
    train_idx, val_idx = train_test_split(np.arange(len(full_dataset.targets)), test_size=0.2, shuffle=True, stratify=full_dataset.targets)
    train_dataset = torch.utils.data.Subset(full_dataset, train_idx)
    val_dataset = torch.utils.data.Subset(full_dataset, val_idx)
    return train_dataset, val_dataset, classes

In [43]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [45]:
!unzip "/content/drive/MyDrive/datasets/SOHAS weapon detection.v2i.coco (1).zip"

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
 extracting: test/smartphone_1063_jpg.rf.e4ab0502b6a572044ae36def290e7ccc.jpg  
 extracting: test/smartphone_1064_jpg.rf.1037e653e81b11951e6a6740edead606.jpg  
 extracting: test/smartphone_1065_jpg.rf.37d0bd7c5c105be08249e855a3b21d56.jpg  
 extracting: test/smartphone_1067_jpg.rf.894bf4bddaae91893069946ef7ea843d.jpg  
 extracting: test/smartphone_1068_jpg.rf.7b2d5099e18639ca191a660557d91a39.jpg  
 extracting: test/smartphone_1069_jpg.rf.4b8b5db8070c456d1242f99f9c57e340.jpg  
 extracting: test/smartphone_1070_jpg.rf.961a0c9fb1d58f3fe1f3a38d1e25e699.jpg  
 extracting: test/smartphone_1071_jpg.rf.f9aa38a10cd6ca1bd1708f6cea9fc2af.jpg  
 extracting: test/smartphone_1115_jpg.rf.39fa409521a54a73c13d7d83073b9b0d.jpg  
 extracting: test/smartphone_9001_jpg.rf.eccfc0a982317f589ab509eff8d30306.jpg  
 extracting: test/smartphone_9003_jpg.rf.07ab9194dd58b801aefb30da2fccfd2c.jpg  
 extracting: test/smartphone_9004_jpg.rf.a436e736a17f6b

In [46]:
train_dataset, val_dataset, classes = load_data('/content/drive/MyDrive/datasets/SOHAS weapon detection.v2i.coco (1).zip (Unzipped Files)')
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [47]:
# Initialize the feature extractor
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')



In [49]:
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k',
                                                  num_labels=len(classes)).to(device)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [50]:
optimizer = AdamW(model.parameters(), lr=5e-5)
criterion = CrossEntropyLoss()

In [51]:
def train(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images).logits
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)

In [52]:
def validate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images).logits
            loss = criterion(outputs, labels)
            total_loss += loss.item()
    return total_loss / len(dataloader)

In [None]:
# Training loop
epochs = 4
for epoch in range(epochs):
    train_loss = train(model, train_loader, optimizer, criterion, device)
    val_loss = validate(model, val_loader, criterion, device)
    print(f"Epoch {epoch+1}, Train Loss: {train_loss}, Validation Loss: {val_loss}")

In [15]:
# Function to get predictions
def get_predictions(model, dataloader, device):
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            outputs = model(images).logits
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    return all_labels, all_preds

In [16]:
# Calculate predictions
true_labels, predictions = get_predictions(model, val_loader, device)
report = classification_report(true_labels, predictions, target_names=classes, digits=4)
print("Evaluation Report:\n", report)


Evaluation Report:
               precision    recall  f1-score   support

        test     0.0000    0.0000    0.0000       184
       train     0.7576    1.0000    0.8621       822
       valid     0.9310    0.4880    0.6403       166

    accuracy                         0.7705      1172
   macro avg     0.5629    0.4960    0.5008      1172
weighted avg     0.6632    0.7705    0.6953      1172



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [17]:
model_path = '/content/drive/My Drive/weapon_classification_model.pth'
torch.save(model.state_dict(), model_path)

In [18]:
# Load the model
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=len(classes))
model.load_state_dict(torch.load(model_path))
model = model.to(device)
model.eval()

# Initialize the feature extractor
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')

def predict_image(image_path, model, feature_extractor, classes, device):
    # Load and transform the image
    image = Image.open(image_path).convert("RGB")
    inputs = feature_extractor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        logits = model(**inputs).logits

    # Calculate softmax to get probabilities
    probs = torch.nn.functional.softmax(logits, dim=1)
    probs = probs.cpu().numpy().flatten()

    # Get the predicted class
    predicted_class_idx = probs.argmax()
    predicted_class = classes[predicted_class_idx]
    confidence = probs[predicted_class_idx]

    return predicted_class, confidence

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
