In [1]:
# https://www.digitalocean.com/community/tutorials/few-shot-learning

In [2]:
images_dir = r"C:\Users\akilarasan.p\Downloads\archive\samples_for_clients\samples_for_clients"
annotations_dir = r"C:\Users\akilarasan.p\Downloads\archive\annotations\annotations"

In [3]:
import os
import numpy as np
import torch
from PIL import Image
from xml.etree.ElementTree import parse
import albumentations as A
from albumentations.pytorch import ToTensorV2 as ToTensor


class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, images_dir, annotations_dir, transforms=None):
        self.images_dir = images_dir
        self.annotations_dir = annotations_dir
        self.imgs = list(sorted(os.listdir(images_dir)))
        self.mask_labels = ['keyboard', 'monitor', 'mouse', 'laptop', 'mobile']  

        bbox_params = A.BboxParams(format='albumentations', label_fields=['class_labels'])
        self.transforms = transforms or A.Compose([
            A.HorizontalFlip(p=0.1),
            A.VerticalFlip(p=0.1),
            A.RandomBrightnessContrast(p=0.1),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 
            ToTensor()], bbox_params=bbox_params)

    def __getitem__(self, idx):
        image_path = os.path.join(self.images_dir, self.imgs[idx])
        image = Image.open(image_path).convert('RGB')
        image = np.array(image)
        height, width, _ = image.shape

        # Parse XML annotation
        anno_path = os.path.join(self.annotations_dir, os.path.splitext(self.imgs[idx])[0] + '.xml')
        parser = parse(anno_path)
        labels, boxes = [], []

        for obj in parser.findall('object'):
            bndbox = obj.find('bndbox')
            if bndbox.find('xmin').text != bndbox.find('xmax').text:
                box = [float(bndbox.find(c).text) for c in ['xmin', 'ymin', 'xmax', 'ymax']]
                # Normalize bounding boxes
                box = [
                    max(0.0, min(box[0] / width, 1.0)),   # xmin
                    max(0.0, min(box[1] / height, 1.0)),  # ymin
                    max(0.0, min(box[2] / width, 1.0)),   # xmax
                    max(0.0, min(box[3] / height, 1.0))   # ymax
                ]
                label = obj.find('name').text
                label = self.mask_labels.index(label)
                boxes.append(box)
                labels.append(label)

        # Handle empty annotations
        if not boxes:
            boxes = [(0.0, 0.0, 1.0, 1.0)]
            labels = [0]

        # Apply transforms
        transformed = self.transforms(image=image, bboxes=boxes, class_labels=labels)
        image = transformed['image'].float()  # Ensure float32 type
        boxes = torch.tensor(transformed['bboxes'], dtype=torch.float32)
        labels = torch.tensor(transformed['class_labels'], dtype=torch.int64)

        # Additional target metadata
        area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        image_id = torch.tensor([idx])
        iscrowd = torch.zeros((len(labels),), dtype=torch.int64)

        # Construct target dictionary
        target = {
            'boxes': boxes,
            'labels': labels,
            'image_id': image_id,
            'area': area,
            'iscrowd': iscrowd
        }

        return image, target

    def __len__(self):
        return len(self.imgs)

images_dir = r"C:\Users\akilarasan.p\Downloads\archive\samples_for_clients\samples_for_clients"
annotations_dir = r"C:\Users\akilarasan.p\Downloads\archive\annotations\annotations"

dataset = CustomDataset(images_dir, annotations_dir)

  from .autonotebook import tqdm as notebook_tqdm
INFO:albumentations.check_version:A new version of Albumentations is available: 2.0.8 (you have 1.4.10). Upgrade using: pip install --upgrade albumentations


In [4]:
# pip install --upgrade albumentations


Collecting albumentations
  Downloading albumentations-2.0.8-py3-none-any.whl.metadata (43 kB)
Collecting albucore==0.0.24 (from albumentations)
  Using cached albucore-0.0.24-py3-none-any.whl.metadata (5.3 kB)
Collecting stringzilla>=3.10.4 (from albucore==0.0.24->albumentations)
  Downloading stringzilla-3.12.5-cp310-cp310-win_amd64.whl.metadata (81 kB)
Collecting simsimd>=5.9.2 (from albucore==0.0.24->albumentations)
  Downloading simsimd-6.4.9-cp310-cp310-win_amd64.whl.metadata (67 kB)
Downloading albumentations-2.0.8-py3-none-any.whl (369 kB)
Using cached albucore-0.0.24-py3-none-any.whl (15 kB)
Downloading simsimd-6.4.9-cp310-cp310-win_amd64.whl (94 kB)
Downloading stringzilla-3.12.5-cp310-cp310-win_amd64.whl (80 kB)
Installing collected packages: stringzilla, simsimd, albucore, albumentations
  Attempting uninstall: albucore
    Found existing installation: albucore 0.0.13
    Uninstalling albucore-0.0.13:
      Successfully uninstalled albucore-0.0.13
  Attempting uninstall: a

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
paddleocr 2.9.1 requires albucore==0.0.13, but you have albucore 0.0.24 which is incompatible.
paddleocr 2.9.1 requires albumentations==1.4.10, but you have albumentations 2.0.8 which is incompatible.


In [5]:
import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

def get_model(num_classes):
    model = fasterrcnn_resnet50_fpn(pretrained=True)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model

num_classes = 5
# ['keyboard', 'monitor', 'mouse', 'laptop', 'mobile']

model = get_model(num_classes)




In [6]:
dtype = torch.float
device = torch.accelerator.current_accelerator().type if torch.cuda.is_available() else "cpu"
print(device)

cpu


In [None]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

def collate_fn(batch):
    images, targets = zip(*batch)
    return list(images), list(targets)

data_loader = DataLoader(dataset, batch_size=5, shuffle=True, collate_fn=collate_fn)
optimizer = optim.SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
save_path = "model_epoch_{}.pth"
num_epochs = 5
for epoch in range(num_epochs):
    print(f"Starting Epoch {epoch + 1}/{num_epochs}")
    model.train()
    running_loss = 0.0  # Track loss for the epoch
    for batch_idx, (images, targets) in enumerate(data_loader, 1):
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Forward pass
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
    
        # Backward pass
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        print(f"Batch {batch_idx} Loss: {losses.item():.4f}")

    lr_scheduler.step()
    avg_loss = running_loss / len(dataset)
    print(f"Epoch {epoch + 1} Completed. Average Loss = {avg_loss:.4f}")
    torch.save(model.state_dict(), save_path.format(epoch + 1))
    print(f"Model saved at {save_path.format(epoch + 1)}")


Starting Epoch 1/5
Batch 1 Loss: 3.9589
Batch 2 Loss: 1.0440
Batch 3 Loss: 0.4855
Batch 4 Loss: 0.6695


In [None]:
#xml file read

In [None]:
from bs4 import BeautifulSoup
import os

annotations_path = r'C:\Users\akilarasan.p\Downloads\archive\annotations\annotations'
data = []
# Loop through each file in the directory
for filename in os.listdir(annotations_path):
    if filename.endswith('.xml'):
        file_path = os.path.join(annotations_path, filename)
        
        # Read the XML file
        with open(file_path, 'r', encoding='utf-8') as file:
            xml_content = file.read()
        
        # Parse the XML with BeautifulSoup
        soup = BeautifulSoup(xml_content, 'xml')
        
        # Find all object tags
        objects = soup.find_all('object')
        
        # Loop through objects and get their names
        object_names = [obj.find('name').text for obj in objects]
        
        # Print the object names for each file
        # print(f"File: {filename}")
        # print("Object names:", object_names)
        data.append(' '.join(object_names))
        # print("-" * 40)


list(set(' '.join(data).split(' ')))